"""Callback for updating EMA model."""from__future__importannotationsfromtorchimportnnfromvis4d.common.distributedimportis_module_wrapperfromvis4d.common.typingimportMetricLogsfromvis4d.data.typingimportDictDatafromvis4d.engine.loss_moduleimportLossModulefromvis4d.model.adapterimportModelEMAAdapterfrom.baseimportCallbackfrom.trainer_stateimportTrainerState
class EMACallback(Callback):
    """Callback that triggers the EMA update after each training batch.

    The model handed to the hook must be a ``ModelEMAAdapter``, either
    directly or as the ``.module`` of a module wrapper.
    """

    def on_train_batch_end(  # pylint: disable=useless-return
        self,
        trainer_state: TrainerState,
        model: nn.Module,
        loss_module: LossModule,
        outputs: DictData,
        batch: DictData,
        batch_idx: int,
    ) -> None | MetricLogs:
        """Update the EMA model once a training batch has finished.

        Args:
            trainer_state (TrainerState): Trainer state; only its
                ``"global_step"`` entry is read here.
            model (nn.Module): Model being trained. If
                ``is_module_wrapper`` reports it as a wrapper, its
                ``.module`` attribute is used instead.
            loss_module (LossModule): Not used by this callback.
            outputs (DictData): Not used by this callback.
            batch (DictData): Not used by this callback.
            batch_idx (int): Not used by this callback.

        Returns:
            None | MetricLogs: Always ``None``; this hook produces no
                metric logs.

        Raises:
            AssertionError: If the (unwrapped) model is not a
                ``ModelEMAAdapter``.
        """
        # Look through a wrapper (when present) to reach the actual model.
        ema_model = model.module if is_module_wrapper(model) else model

        assert isinstance(ema_model, ModelEMAAdapter), (
            "Model should be wrapped with ModelEMAAdapter when using "
            "EMACallback."
        )

        # The adapter's update is driven by the current global step.
        ema_model.update(trainer_state["global_step"])
        return None