vis4d.engine.callbacks.ema

Callback for updating the exponential moving average (EMA) model.

Classes

EMACallback([epoch_based, train_connector, ...])

Callback for EMA.

class EMACallback(epoch_based=True, train_connector=None, test_connector=None)[source]

Callback for EMA.

Init callback.

Parameters:
  • epoch_based (bool, optional) – Whether the callback is epoch based. Defaults to True.

  • train_connector (None | CallbackConnector, optional) – Defines which kwargs to use during training for different callbacks. Defaults to None.

  • test_connector (None | CallbackConnector, optional) – Defines which kwargs to use during testing for different callbacks. Defaults to None.

on_train_batch_end(trainer_state, model, loss_module, outputs, batch, batch_idx)[source]

Hook to run at the end of a training batch.

Return type:

Optional[Dict[str, Union[float, int, Tensor]]]
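For orientation, the kind of update an EMA callback typically triggers at the end of each training batch can be sketched in plain Python. This is a generic illustration, not the vis4d implementation: the helper name ema_update, the decay value, and the use of a plain dict of floats in place of model parameters are all assumptions made for the example.

```python
from typing import Dict


def ema_update(
    ema_params: Dict[str, float],
    model_params: Dict[str, float],
    decay: float = 0.999,
) -> None:
    """Apply one EMA step in place (hypothetical helper, not a vis4d API).

    Each averaged value moves toward the current model value:
        ema = decay * ema + (1 - decay) * current
    """
    for name, current in model_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * current


# Toy "model" with a single parameter; real models hold tensors instead.
model = {"w": 0.0}
ema = {"w": 1.0}

# Simulate two training batches, updating the average after each one.
for _ in range(2):
    ema_update(ema, model, decay=0.9)
```

A decay close to 1 makes the averaged parameters change slowly, which smooths out batch-to-batch noise in the trained weights.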