vis4d.engine.callbacks.base
Base module for callbacks.
Classes
Callback – Base class for Callbacks.
- class Callback(epoch_based=True, train_connector=None, test_connector=None)
Base class for Callbacks.
Init callback.
- Parameters:
epoch_based (bool, optional) – Whether the callback is epoch based. Defaults to True.
train_connector (None | CallbackConnector, optional) – Defines which kwargs to use during training for different callbacks. Defaults to None.
test_connector (None | CallbackConnector, optional) – Defines which kwargs to use during testing for different callbacks. Defaults to None.
- get_test_callback_inputs(outputs, batch)
Returns the data connector results for inference.
It extracts the required data from the model outputs and the batch data and passes it to the next component under the provided new key.
- Parameters:
outputs (DictData) – Outputs of the model.
batch (DictData) – Batch data.
- Returns:
Data connector results.
- Return type:
dict[str, Tensor | DictStrArrNested]
- Raises:
AssertionError – If test connector is None.
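The key renaming described above can be sketched without the library. This is a minimal illustration, not the vis4d implementation: the `connect` function and the `(source, key)` mapping format are hypothetical stand-ins for whatever the configured CallbackConnector does internally.

```python
from __future__ import annotations

from typing import Any

# Hypothetical mapping format: new key -> (source, original key), where
# source is "prediction" (model outputs) or "data" (the batch).
KeyMapping = dict[str, tuple[str, str]]


def connect(
    key_mapping: KeyMapping | None,
    outputs: dict[str, Any],
    batch: dict[str, Any],
) -> dict[str, Any]:
    """Pick the mapped entries from outputs and batch and rename them."""
    # Mirrors the documented AssertionError when no connector is set.
    assert key_mapping is not None, "Connector is None."
    sources = {"prediction": outputs, "data": batch}
    return {
        new_key: sources[source][key]
        for new_key, (source, key) in key_mapping.items()
    }


# Example data; all key names here are made up for the sketch.
mapping = {
    "boxes": ("prediction", "pred_boxes"),
    "names": ("data", "sample_names"),
}
inputs = connect(
    mapping,
    {"pred_boxes": [[0, 0, 4, 4]], "scores": [0.9]},
    {"sample_names": ["frame_0"]},
)
```

Entries that the mapping does not mention (here `scores`) are simply not forwarded.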
- get_train_callback_inputs(outputs, batch)
Returns the data connector results for training.
It extracts the required data from the model outputs and the batch data and passes it to the next component under the provided new key.
- Parameters:
outputs (DictData) – Outputs of the model.
batch (DictData) – Batch data.
- Returns:
Data connector results.
- Return type:
dict[str, Tensor | DictStrArrNested]
- Raises:
AssertionError – If train connector is None.
- on_test_batch_end(trainer_state, model, outputs, batch, batch_idx, dataloader_idx=0)
Hook to run at the end of a testing batch.
- Parameters:
trainer_state (TrainerState) – Trainer state.
model (nn.Module) – Model that is being trained.
outputs (DictData) – Model prediction output.
batch (DictData) – Dataloader output data batch.
batch_idx (int) – Index of the batch.
dataloader_idx (int, optional) – Index of the dataloader. Defaults to 0.
- Return type:
None
- on_test_epoch_end(trainer_state, model)
Hook to run at the end of a testing epoch.
- Parameters:
trainer_state (TrainerState) – Trainer state.
model (nn.Module) – Model that is being trained.
- Return type:
Optional[Dict[str, Union[float, int, Tensor]]]
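As a rough sketch of how the two test hooks can work together, the class below accumulates a value per batch and returns a metric dictionary at the end of the epoch. It is written as a standalone class so it runs without vis4d; a real callback would subclass Callback. The `score` output key and the metric names are assumptions made for the example.

```python
from __future__ import annotations

from typing import Any


class MeanScoreCallback:
    """Illustrative callback using the documented test hook signatures."""

    def __init__(self) -> None:
        self.scores: list[float] = []

    def on_test_batch_end(
        self,
        trainer_state: dict[str, Any],
        model: Any,
        outputs: dict[str, Any],
        batch: dict[str, Any],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        # "score" is a hypothetical output key used only in this sketch.
        self.scores.append(float(outputs["score"]))

    def on_test_epoch_end(
        self, trainer_state: dict[str, Any], model: Any
    ) -> dict[str, float | int] | None:
        # Return None when there is nothing to report.
        if not self.scores:
            return None
        result = {
            "mean_score": sum(self.scores) / len(self.scores),
            "num_batches": len(self.scores),
        }
        self.scores.clear()  # reset for the next testing epoch
        return result


callback = MeanScoreCallback()
for idx, score in enumerate([1.0, 2.0, 3.0]):
    callback.on_test_batch_end({}, None, {"score": score}, {}, idx)
metrics = callback.on_test_epoch_end({}, None)
```

Clearing the accumulated state in `on_test_epoch_end` keeps results from one testing epoch from leaking into the next.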
- on_test_epoch_start(trainer_state, model)
Hook to run at the beginning of a testing epoch.
- Parameters:
trainer_state (TrainerState) – Trainer state.
model (nn.Module) – Model that is being trained.
- Return type:
None
- on_train_batch_end(trainer_state, model, loss_module, outputs, batch, batch_idx)
Hook to run at the end of a training batch.
- Parameters:
trainer_state (TrainerState) – Trainer state.
model (nn.Module) – Model that is being trained.
loss_module (LossModule) – Loss module.
outputs (DictData) – Model prediction output.
batch (DictData) – Dataloader output data batch.
batch_idx (int) – Index of the batch.
- Return type:
Optional[Dict[str, Union[float, int, Tensor]]]
- on_train_batch_start(trainer_state, model, loss_module, batch, batch_idx)
Hook to run at the start of a training batch.
- Parameters:
trainer_state (TrainerState) – Trainer state.
model (nn.Module) – Model that is being trained.
loss_module (LossModule) – Loss module.
batch (DictData) – Dataloader output data batch.
batch_idx (int) – Index of the batch.
- Return type:
None
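One possible use of the batch-level training hooks is timing: record a timestamp in `on_train_batch_start` and return the elapsed time from `on_train_batch_end`. The sketch below is a standalone class that only mirrors the documented signatures; the `BatchTimer` name, the injectable `clock` argument, and the `batch_time` key are all hypothetical.

```python
from __future__ import annotations

import time
from typing import Any, Callable


class BatchTimer:
    """Illustrative callback that reports the duration of each batch."""

    def __init__(self, clock: Callable[[], float] = time.perf_counter) -> None:
        self.clock = clock  # injectable so the sketch can be tested
        self.start: float | None = None

    def on_train_batch_start(
        self,
        trainer_state: dict[str, Any],
        model: Any,
        loss_module: Any,
        batch: dict[str, Any],
        batch_idx: int,
    ) -> None:
        self.start = self.clock()

    def on_train_batch_end(
        self,
        trainer_state: dict[str, Any],
        model: Any,
        loss_module: Any,
        outputs: dict[str, Any],
        batch: dict[str, Any],
        batch_idx: int,
    ) -> dict[str, float] | None:
        # Nothing to report if the start hook never ran for this batch.
        if self.start is None:
            return None
        elapsed = self.clock() - self.start
        self.start = None
        return {"batch_time": elapsed}


# Fake clock so the example is deterministic.
ticks = iter([10.0, 10.5])
timer = BatchTimer(clock=lambda: next(ticks))
timer.on_train_batch_start({}, None, None, {}, 0)
log = timer.on_train_batch_end({}, None, None, {}, {}, 0)
```

With the default `time.perf_counter` clock the same class measures wall-clock time per batch.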
- on_train_epoch_end(trainer_state, model, loss_module)
Hook to run at the end of a training epoch.
- Parameters:
trainer_state (TrainerState) – Trainer state.
model (nn.Module) – Model that is being trained.
loss_module (LossModule) – Loss module.
- Return type:
None
- on_train_epoch_start(trainer_state, model, loss_module)
Hook to run at the beginning of a training epoch.
- Parameters:
trainer_state (TrainerState) – Trainer state.
model (nn.Module) – Model that is being trained.
loss_module (LossModule) – Loss module.
- Return type:
None
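The hook names suggest the order in which a trainer calls them during one training epoch. The loop below is a simplified assumption about that order, not the actual vis4d trainer; `RecordingCallback` and `run_epoch` are hypothetical helpers that only log which hook fired.

```python
from __future__ import annotations

from typing import Any


class RecordingCallback:
    """Records the name of every training hook as it is invoked."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def on_train_epoch_start(self, trainer_state, model, loss_module) -> None:
        self.calls.append("epoch_start")

    def on_train_batch_start(
        self, trainer_state, model, loss_module, batch, batch_idx
    ) -> None:
        self.calls.append(f"batch_start:{batch_idx}")

    def on_train_batch_end(
        self, trainer_state, model, loss_module, outputs, batch, batch_idx
    ) -> None:
        self.calls.append(f"batch_end:{batch_idx}")

    def on_train_epoch_end(self, trainer_state, model, loss_module) -> None:
        self.calls.append("epoch_end")


def run_epoch(callback: RecordingCallback, batches: list[dict[str, Any]]) -> None:
    """Assumed call order for one epoch; model and loss are placeholders."""
    state: dict[str, Any] = {}
    model, loss_module = None, None
    callback.on_train_epoch_start(state, model, loss_module)
    for batch_idx, batch in enumerate(batches):
        callback.on_train_batch_start(state, model, loss_module, batch, batch_idx)
        outputs: dict[str, Any] = {}  # placeholder for the forward pass
        callback.on_train_batch_end(
            state, model, loss_module, outputs, batch, batch_idx
        )
    callback.on_train_epoch_end(state, model, loss_module)


recorder = RecordingCallback()
run_epoch(recorder, [{"x": 1}, {"x": 2}])
```

After the call, `recorder.calls` lists the epoch-start hook first, then a start/end pair per batch, then the epoch-end hook.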