vis4d.engine.callbacks.base

Base module for callbacks.

Classes

Callback([epoch_based, train_connector, ...])

Base class for Callbacks.

class Callback(epoch_based=True, train_connector=None, test_connector=None)

Base class for Callbacks.

Initialize the callback.

Parameters:
  • epoch_based (bool, optional) – Whether the callback is epoch based. Defaults to True.

  • train_connector (None | CallbackConnector, optional) – Defines which keyword arguments are extracted from the model outputs and the batch and passed to the callback during training. Defaults to None.

  • test_connector (None | CallbackConnector, optional) – Defines which keyword arguments are extracted from the model outputs and the batch and passed to the callback during testing. Defaults to None.
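The constructor only stores three settings. The sketch below mirrors that documented signature with a hypothetical stand-in class (`CallbackSketch`) so it runs without vis4d installed; it assumes, for illustration only, that a connector can be modeled as a plain callable mapping `(outputs, batch)` to a kwargs dict. The real `CallbackConnector` may be configured differently.

```python
from __future__ import annotations

from typing import Any, Callable

# Hypothetical connector model: (outputs, batch) -> kwargs for the callback.
# This is an assumption for illustration, not vis4d's CallbackConnector API.
Connector = Callable[[dict[str, Any], dict[str, Any]], dict[str, Any]]


class CallbackSketch:
    """Hypothetical stand-in mirroring the documented Callback constructor."""

    def __init__(
        self,
        epoch_based: bool = True,
        train_connector: Connector | None = None,
        test_connector: Connector | None = None,
    ) -> None:
        self.epoch_based = epoch_based
        self.train_connector = train_connector
        self.test_connector = test_connector


class EvalEveryEpoch(CallbackSketch):
    """Hypothetical subclass that only needs data at test time."""

    def __init__(self, test_connector: Connector) -> None:
        super().__init__(epoch_based=True, test_connector=test_connector)


# The connector renames the model's "boxes" output to the "pred" kwarg.
cb = EvalEveryEpoch(test_connector=lambda outputs, batch: {"pred": outputs["boxes"]})
```

A subclass that never runs during training can leave `train_connector` as None; only the connector actually used needs to be set.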

get_test_callback_inputs(outputs, batch)

Returns the data connector results for inference.

It extracts the required data from the model outputs and the batch data and passes it to the next component under the provided new key.

Parameters:
  • outputs (DictData) – Outputs of the model.

  • batch (DictData) – Batch data.

Returns:

Data connector results.

Return type:

dict[str, Tensor | DictStrArrNested]

Raises:

AssertionError – If the test connector is None.
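The extraction step can be sketched as a key remapping. In this illustration the connector is assumed to be a hypothetical mapping `new_key -> (source, key)`, where `source` is either `"outputs"` or `"batch"`; the key names `boxes` and `boxes2d` are likewise made up. The assertion reproduces the documented failure mode when no test connector is configured.

```python
from __future__ import annotations

from typing import Any

# Hypothetical connector format: new_key -> (source, key), where source is
# "outputs" (model predictions) or "batch" (dataloader data).
KeyMapping = dict[str, tuple[str, str]]


def get_test_callback_inputs(
    test_connector: KeyMapping | None,
    outputs: dict[str, Any],
    batch: dict[str, Any],
) -> dict[str, Any]:
    """Sketch: pick values from outputs/batch and rename them."""
    assert test_connector is not None, "Test connector is None."
    sources = {"outputs": outputs, "batch": batch}
    return {
        new_key: sources[source][key]
        for new_key, (source, key) in test_connector.items()
    }


mapping = {"pred_boxes": ("outputs", "boxes"), "gt_boxes": ("batch", "boxes2d")}
inputs = get_test_callback_inputs(
    mapping,
    outputs={"boxes": [[0, 0, 1, 1]]},
    batch={"boxes2d": [[0, 0, 2, 2]]},
)
```

The callback then receives `pred_boxes` and `gt_boxes` as keyword arguments without needing to know how the model or dataset named them.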

get_train_callback_inputs(outputs, batch)

Returns the data connector results for training.

It extracts the required data from the model outputs and the batch data and passes it to the next component under the provided new key.

Parameters:
  • outputs (DictData) – Outputs of the model.

  • batch (DictData) – Batch data.

Returns:

Data connector results.

Return type:

dict[str, Tensor | DictStrArrNested]

Raises:

AssertionError – If the train connector is None.
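On the training side, the main behavior worth illustrating is the guard: the method fails loudly instead of silently returning nothing when no train connector is set. This sketch again models the connector as a hypothetical callable, and `loss_inputs` with its `logits`/`label` keys is an illustrative assumption.

```python
from __future__ import annotations

from typing import Any, Callable

# Hypothetical connector model: (outputs, batch) -> kwargs dict.
Connector = Callable[[dict[str, Any], dict[str, Any]], dict[str, Any]]


def get_train_callback_inputs(
    train_connector: Connector | None,
    outputs: dict[str, Any],
    batch: dict[str, Any],
) -> dict[str, Any]:
    """Sketch: delegate to the connector, failing loudly if none is set."""
    assert train_connector is not None, "Train connector is None."
    return train_connector(outputs, batch)


def loss_inputs(outputs: dict[str, Any], batch: dict[str, Any]) -> dict[str, Any]:
    # Hypothetical connector: forward the prediction and the target label.
    return {"logits": outputs["logits"], "target": batch["label"]}


result = get_train_callback_inputs(loss_inputs, {"logits": [0.2, 0.8]}, {"label": 1})

# Without a connector, the documented AssertionError is raised.
try:
    get_train_callback_inputs(None, {}, {})
    error = None
except AssertionError as err:
    error = str(err)
```

Note that Python strips `assert` statements when run with `-O`, so a check based on `assert` only fires in normal (non-optimized) runs.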

on_test_batch_end(trainer_state, model, outputs, batch, batch_idx, dataloader_idx=0)

Hook to run at the end of a testing batch.

Parameters:
  • trainer_state (TrainerState) – Trainer state.

  • model (Module) – Model that is being evaluated.

  • outputs (DictData) – Model prediction output.

  • batch (DictData) – Dataloader output data batch.

  • batch_idx (int) – Index of the batch.

  • dataloader_idx (int, optional) – Index of the dataloader. Defaults to 0.

Return type:

None
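Because `dataloader_idx` distinguishes multiple test dataloaders, a callback that accumulates per-batch results may want to keep them separate. The class below is a hypothetical sketch (not a vis4d callback); the `scores` output key is an assumption, and `trainer_state` and `model` are unused here.

```python
from __future__ import annotations

from collections import defaultdict
from typing import Any


class PredictionCollector:
    """Hypothetical callback sketch that stores per-batch outputs."""

    def __init__(self) -> None:
        # dataloader_idx -> list of (batch_idx, scores) pairs
        self.collected: defaultdict[int, list[tuple[int, Any]]] = defaultdict(list)

    def on_test_batch_end(
        self,
        trainer_state: Any,
        model: Any,
        outputs: dict[str, Any],
        batch: dict[str, Any],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        # Keep results from different test dataloaders in separate lists.
        self.collected[dataloader_idx].append((batch_idx, outputs["scores"]))


collector = PredictionCollector()
collector.on_test_batch_end(None, None, {"scores": [0.9]}, {}, batch_idx=0)
collector.on_test_batch_end(
    None, None, {"scores": [0.4]}, {}, batch_idx=0, dataloader_idx=1
)
```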

on_test_epoch_end(trainer_state, model)

Hook to run at the end of a testing epoch.

Parameters:
  • trainer_state (TrainerState) – Trainer state.

  • model (nn.Module) – Model that is being evaluated.

Return type:

Optional[Dict[str, Union[float, int, Tensor]]]
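The optional return type suggests two outcomes: a dict of scalar results, or None when there is nothing to report. The evaluator below is a hypothetical sketch of that contract; plain Python floats stand in for Tensors, and the `scores` key and metric names are illustrative assumptions.

```python
from __future__ import annotations

from typing import Any


class MeanScoreEvaluator:
    """Hypothetical callback sketch returning a metric dict at epoch end."""

    def __init__(self) -> None:
        self.scores: list[float] = []

    def on_test_batch_end(
        self,
        trainer_state: Any,
        model: Any,
        outputs: dict[str, Any],
        batch: dict[str, Any],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self.scores.extend(outputs["scores"])

    def on_test_epoch_end(
        self, trainer_state: Any, model: Any
    ) -> dict[str, float | int] | None:
        # Return None when no batches were seen: nothing to report.
        if not self.scores:
            return None
        return {
            "mean_score": sum(self.scores) / len(self.scores),
            "num_samples": len(self.scores),
        }


evaluator = MeanScoreEvaluator()
empty_result = evaluator.on_test_epoch_end(None, None)
evaluator.on_test_batch_end(None, None, {"scores": [0.5, 1.0]}, {}, batch_idx=0)
metrics = evaluator.on_test_epoch_end(None, None)
```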

on_test_epoch_start(trainer_state, model)

Hook to run at the beginning of a testing epoch.

Parameters:
  • trainer_state (TrainerState) – Trainer state.

  • model (nn.Module) – Model that is being evaluated.

Return type:

None
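A common use of the start hook is clearing state left over from the previous evaluation, so results do not leak from one test epoch into the next. This is a hypothetical sketch; the `scores` attribute is an assumption carried over from an accumulating evaluator.

```python
from __future__ import annotations

from typing import Any


class ResettingEvaluator:
    """Hypothetical callback sketch that clears accumulated results."""

    def __init__(self) -> None:
        self.scores: list[float] = []

    def on_test_epoch_start(self, trainer_state: Any, model: Any) -> None:
        # Drop results from the previous test epoch before new batches arrive.
        self.scores.clear()


evaluator = ResettingEvaluator()
evaluator.scores.append(0.3)  # stale value from an earlier run
evaluator.on_test_epoch_start(None, None)
```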

on_train_batch_end(trainer_state, model, loss_module, outputs, batch, batch_idx)

Hook to run at the end of a training batch.

Parameters:
  • trainer_state (TrainerState) – Trainer state.

  • model (Module) – Model that is being trained.

  • loss_module (LossModule) – Loss module.

  • outputs (DictData) – Model prediction output.

  • batch (DictData) – Dataloader output data batch.

  • batch_idx (int) – Index of the batch.

Return type:

Optional[Dict[str, Union[float, int, Tensor]]]
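Like the test epoch end hook, the training batch end hook may return a dict of values or None. The sketch below reports a running sample count only every `log_every` batches and returns None otherwise. `FakeTrainerState`, its fields, and the `images` batch key are hypothetical stand-ins; the real `TrainerState` may expose different fields.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class FakeTrainerState:
    """Hypothetical stand-in; the real TrainerState fields may differ."""

    epoch: int
    global_step: int


class SampleCounter:
    """Hypothetical callback sketch: report samples seen every N batches."""

    def __init__(self, log_every: int = 50) -> None:
        self.log_every = log_every
        self.samples_seen = 0

    def on_train_batch_end(
        self,
        trainer_state: FakeTrainerState,
        model: Any,
        loss_module: Any,
        outputs: dict[str, Any],
        batch: dict[str, Any],
        batch_idx: int,
    ) -> dict[str, int] | None:
        self.samples_seen += len(batch["images"])
        # Only return values on every `log_every`-th batch.
        if (batch_idx + 1) % self.log_every != 0:
            return None
        return {"samples_seen": self.samples_seen, "global_step": trainer_state.global_step}


counter = SampleCounter(log_every=2)
state = FakeTrainerState(epoch=0, global_step=0)
first = counter.on_train_batch_end(state, None, None, {}, {"images": [1, 2, 3, 4]}, 0)
state.global_step = 1
second = counter.on_train_batch_end(state, None, None, {}, {"images": [5, 6, 7, 8]}, 1)
```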

on_train_batch_start(trainer_state, model, loss_module, batch, batch_idx)

Hook to run at the start of a training batch.

Parameters:
  • trainer_state (TrainerState) – Trainer state.

  • model (Module) – Model that is being trained.

  • loss_module (LossModule) – Loss module.

  • batch (DictData) – Dataloader output data batch.

  • batch_idx (int) – Index of the batch.

Return type:

None
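Pairing the start and end batch hooks lets a callback measure per-batch work. This hypothetical timer records `time.perf_counter()` at the start hook and the elapsed wall-clock time at the end hook; it is a sketch under the assumption that both hooks are called once per batch, in that order.

```python
from __future__ import annotations

import time
from typing import Any


class BatchTimer:
    """Hypothetical callback sketch measuring wall-clock time per batch."""

    def __init__(self) -> None:
        self._start: float | None = None
        self.durations: list[float] = []

    def on_train_batch_start(
        self, trainer_state: Any, model: Any, loss_module: Any,
        batch: dict[str, Any], batch_idx: int,
    ) -> None:
        self._start = time.perf_counter()

    def on_train_batch_end(
        self, trainer_state: Any, model: Any, loss_module: Any,
        outputs: dict[str, Any], batch: dict[str, Any], batch_idx: int,
    ) -> None:
        assert self._start is not None, "on_train_batch_start was not called."
        # Elapsed time covers everything between the two hooks.
        self.durations.append(time.perf_counter() - self._start)
        self._start = None


timer = BatchTimer()
timer.on_train_batch_start(None, None, None, {}, 0)
timer.on_train_batch_end(None, None, None, {}, {}, 0)
```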

on_train_epoch_end(trainer_state, model, loss_module)

Hook to run at the end of a training epoch.

Parameters:
  • trainer_state (TrainerState) – Trainer state.

  • model (nn.Module) – Model that is being trained.

  • loss_module (LossModule) – Loss module.

Return type:

None

on_train_epoch_start(trainer_state, model, loss_module)

Hook to run at the beginning of a training epoch.

Parameters:
  • trainer_state (TrainerState) – Trainer state.

  • model (nn.Module) – Model that is being trained.

  • loss_module (LossModule) – Loss module.

Return type:

None
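The two training epoch hooks bracket each epoch, which makes them a natural place to reset and then commit per-epoch state. This hypothetical sketch counts training batches per epoch and stores the total under `trainer_state.epoch`; `FakeTrainerState` is an illustrative stand-in, not vis4d's `TrainerState`.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class FakeTrainerState:
    """Hypothetical stand-in; the real TrainerState may expose other fields."""

    epoch: int


class BatchesPerEpoch:
    """Hypothetical callback sketch: count training batches in each epoch."""

    def __init__(self) -> None:
        self._count = 0
        self.history: dict[int, int] = {}

    def on_train_epoch_start(self, trainer_state: Any, model: Any, loss_module: Any) -> None:
        # Reset per-epoch state before the first batch.
        self._count = 0

    def on_train_batch_end(self, trainer_state: Any, model: Any, loss_module: Any,
                           outputs: Any, batch: Any, batch_idx: int) -> None:
        self._count += 1

    def on_train_epoch_end(self, trainer_state: Any, model: Any, loss_module: Any) -> None:
        # Commit the epoch total under the epoch index from the trainer state.
        self.history[trainer_state.epoch] = self._count


cb = BatchesPerEpoch()
for epoch, num_batches in enumerate([3, 2]):
    state = FakeTrainerState(epoch=epoch)
    cb.on_train_epoch_start(state, None, None)
    for i in range(num_batches):
        cb.on_train_batch_end(state, None, None, {}, {}, i)
    cb.on_train_epoch_end(state, None, None)
```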

setup()

Set up the callback.

Return type:

None
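To show where `setup()` might sit relative to the other hooks, the sketch below drives a recording callback through a single hypothetical epoch. The loop `run_sketch_loop` and the call order it uses are assumptions for illustration only; vis4d's actual trainer may order or skip hooks differently. `setup()` is written to be idempotent so that repeated calls are harmless.

```python
from __future__ import annotations

from typing import Any


class RecordingCallback:
    """Hypothetical callback that records which hooks are called, in order."""

    def __init__(self) -> None:
        self.calls: list[str] = []
        self._ready = False

    def setup(self) -> None:
        # One-time initialization; later calls are no-ops.
        if not self._ready:
            self._ready = True
            self.calls.append("setup")

    def on_train_epoch_start(self, trainer_state: Any, model: Any, loss_module: Any) -> None:
        self.calls.append("train_epoch_start")

    def on_train_batch_start(self, trainer_state, model, loss_module, batch, batch_idx) -> None:
        self.calls.append(f"train_batch_start:{batch_idx}")

    def on_train_batch_end(self, trainer_state, model, loss_module, outputs, batch, batch_idx):
        self.calls.append(f"train_batch_end:{batch_idx}")
        return None

    def on_train_epoch_end(self, trainer_state: Any, model: Any, loss_module: Any) -> None:
        self.calls.append("train_epoch_end")

    def on_test_epoch_start(self, trainer_state: Any, model: Any) -> None:
        self.calls.append("test_epoch_start")

    def on_test_batch_end(self, trainer_state, model, outputs, batch, batch_idx,
                          dataloader_idx: int = 0) -> None:
        self.calls.append(f"test_batch_end:{batch_idx}")

    def on_test_epoch_end(self, trainer_state: Any, model: Any):
        self.calls.append("test_epoch_end")
        return None


def run_sketch_loop(cb: RecordingCallback, train_batches: list, test_batches: list) -> None:
    """Hypothetical single-epoch loop; not vis4d's actual trainer."""
    cb.setup()
    cb.on_train_epoch_start(None, None, None)
    for i, batch in enumerate(train_batches):
        cb.on_train_batch_start(None, None, None, batch, i)
        outputs: dict[str, Any] = {}  # a real loop would run model and loss here
        cb.on_train_batch_end(None, None, None, outputs, batch, i)
    cb.on_train_epoch_end(None, None, None)
    cb.on_test_epoch_start(None, None)
    for i, batch in enumerate(test_batches):
        cb.on_test_batch_end(None, None, {}, batch, i)
    cb.on_test_epoch_end(None, None)


recorder = RecordingCallback()
run_sketch_loop(recorder, train_batches=[{}, {}], test_batches=[{}])
recorder.setup()  # second call does not record anything
```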