vis4d.engine.connectors.base
Base data connector to define data structures for data connection.
Classes
| CallbackConnector(key_mapping) | Data connector for the callback. |
| DataConnector(key_mapping) | Defines which data to pass to which component. |
| LossConnector(key_mapping) | Defines which data to pass to the loss module of the training pipeline. |
- class CallbackConnector(key_mapping)
Data connector for the callback.
It extracts the required data from the prediction and the data and passes it to the next component under the provided new key.
Initializes the data connector with static remapping of the keys.
- __call__(prediction, data)
Returns the kwargs that are passed to the callback.
- Parameters:
prediction (DictData | NamedTuple) – The output from the model.
data (DictData) – The data dictionary from the dataloader which contains all data that was loaded.
- Returns:
kwargs that are passed onto the callback.
- Return type:
dict[str, Tensor | DictStrArrNested]
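The entry above does not spell out how key_mapping tells the connector whether a key comes from the prediction or from the data. A minimal sketch, assuming each entry maps a new kwarg name to a hypothetical (source, key) pair, with source being "prediction" or "data", and assuming the prediction may be either a dict or a NamedTuple as the parameter types state. The helper names callback_kwargs and _lookup and the key format are illustrative assumptions, not vis4d's actual implementation or API.

```python
from typing import Any, Dict, Mapping, NamedTuple, Tuple, Union

# Hypothetical key format: new kwarg name -> (source, key), with source in
# {"prediction", "data"}. Assumed for illustration; not vis4d's real format.
KeyMapping = Mapping[str, Tuple[str, str]]


def _lookup(container: Union[Mapping[str, Any], tuple], key: str) -> Any:
    """Fetch `key` from a dict, or the attribute `key` of a NamedTuple."""
    if isinstance(container, Mapping):
        return container[key]
    return getattr(container, key)


def callback_kwargs(
    key_mapping: KeyMapping, prediction: Any, data: Mapping[str, Any]
) -> Dict[str, Any]:
    """Collect the kwargs a callback receives, renamed per key_mapping."""
    sources = {"prediction": prediction, "data": data}
    return {
        new_key: _lookup(sources[src], old_key)
        for new_key, (src, old_key) in key_mapping.items()
    }


class DetOut(NamedTuple):
    """Toy model output used only for this example."""

    boxes: list
    scores: list


pred = DetOut(boxes=[[0, 0, 10, 10]], scores=[0.9])
batch = {"images": "img0", "sample_names": ["frame_000"]}
mapping = {"boxes": ("prediction", "boxes"), "names": ("data", "sample_names")}
print(callback_kwargs(mapping, pred, batch))
# {'boxes': [[0, 0, 10, 10]], 'names': ['frame_000']}
```

Keys that are not listed in the mapping (here "images" and "scores") never reach the callback, which keeps each callback's inputs explicit.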
- class DataConnector(key_mapping)
Defines which data to pass to which component.
It extracts the required data from a DictData object and passes it to the next component under the provided new key.
Initializes the data connector with static remapping of the keys.
- Parameters:
key_mapping (dict[str, str]) – Defines which kwargs to pass onto the module.
Simple Example Configuration:
>>> train = dict(images="images", gt="gt_images")
>>> train_data_connector = DataConnector(train)
>>> test = dict(images="images")
>>> test_data_connector = DataConnector(test)
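The configuration above reads as "kwarg name expected by the module: key in the data dictionary", so gt="gt_images" hands the loaded gt_images entry to the module as gt. The following is a minimal plain-Python sketch of that static remapping, assuming the connector is called with the data dictionary alone (no __call__ is documented for DataConnector here). SketchDataConnector is a hypothetical stand-in, not the vis4d class.

```python
from typing import Any, Dict, Mapping


class SketchDataConnector:
    """Hypothetical stand-in for DataConnector: static key remapping."""

    def __init__(self, key_mapping: Mapping[str, str]) -> None:
        # Maps the kwarg name the next component expects -> key in the data dict.
        self.key_mapping = dict(key_mapping)

    def __call__(self, data: Mapping[str, Any]) -> Dict[str, Any]:
        # Rename the selected entries; keys absent from key_mapping are dropped.
        return {new_key: data[old_key] for new_key, old_key in self.key_mapping.items()}


batch = {"images": "img_tensor", "gt_images": "gt_tensor", "extra": 123}
train_connector = SketchDataConnector(dict(images="images", gt="gt_images"))
test_connector = SketchDataConnector(dict(images="images"))
print(train_connector(batch))  # {'images': 'img_tensor', 'gt': 'gt_tensor'}
print(test_connector(batch))   # {'images': 'img_tensor'}
```

Using separate connectors for training and testing lets the same data dictionary feed a module that needs ground truth during training but only images at test time.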
- class LossConnector(key_mapping)
Defines which data to pass to the loss module of the training pipeline.
It extracts the required data from the prediction and the data and passes it to the next component under the provided new key.
Initializes the data connector with static remapping of the keys.
- __call__(prediction, data)
Returns the kwargs that are passed to the loss module.
- Parameters:
prediction (DictData | NamedTuple) – The output from the model.
data (DictData) – The data dictionary from the dataloader which contains all data that was loaded.
- Returns:
kwargs that are passed onto the loss.
- Return type:
dict[str, Tensor | DictStrArrNested]
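Because the returned dictionary is a set of kwargs, a training step can unpack it straight into the loss call. A minimal sketch, assuming the same hypothetical (source, key) mapping format as in the callback discussion (not confirmed by this page) and a NamedTuple prediction normalized via its _asdict() method. SketchLossConnector and l1_loss are illustrative names only, and plain Python lists stand in for tensors so the example runs without PyTorch.

```python
from typing import Any, Dict, Mapping, NamedTuple, Tuple


class SketchLossConnector:
    """Hypothetical stand-in for LossConnector (not vis4d's implementation).

    key_mapping maps each loss kwarg name to a (source, key) pair, where
    source is "prediction" or "data". This key format is an assumption.
    """

    def __init__(self, key_mapping: Mapping[str, Tuple[str, str]]) -> None:
        self.key_mapping = dict(key_mapping)

    def __call__(self, prediction: Any, data: Mapping[str, Any]) -> Dict[str, Any]:
        # Normalize a NamedTuple prediction to a dict so both sources share one lookup.
        if isinstance(prediction, tuple) and hasattr(prediction, "_asdict"):
            prediction = prediction._asdict()
        sources = {"prediction": prediction, "data": data}
        return {new: sources[src][old] for new, (src, old) in self.key_mapping.items()}


def l1_loss(output: list, target: list) -> float:
    """Toy mean absolute error standing in for a real loss module."""
    return sum(abs(o - t) for o, t in zip(output, target)) / len(output)


class ModelOut(NamedTuple):
    depth: list


connector = SketchLossConnector(
    {"output": ("prediction", "depth"), "target": ("data", "gt_depth")}
)
pred = ModelOut(depth=[1.0, 2.0, 3.0, 4.0])
batch = {"images": None, "gt_depth": [1.0, 3.0, 3.0, 2.0]}
loss = l1_loss(**connector(pred, batch))
print(loss)  # 0.75
```

The loss function only sees the names it declares (output, target), so it stays independent of how the model and the dataloader happen to name their fields.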