Source code for vis4d.engine.connectors.multi_sensor
"""Data connector for multi-sensor dataset."""from__future__importannotationsfromtypingimportNamedTuplefromvis4d.data.typingimportDictData,DictDataOrListfrom.baseimportCallbackConnector,DataConnector,LossConnectorfrom.utilimportSourceKeyDescription,get_field_from_prediction
class MultiSensorDataConnector(DataConnector):
    """Data connector for multi-sensor data dict."""

    def __init__(self, key_mapping: dict[str, str | SourceKeyDescription]):
        """Initializes the data connector with static remapping of the keys.

        Entries are split into two groups:
        - Plain string values go to the base ``DataConnector`` unchanged.
        - Dict values whose ``"sensors"`` entry is missing or None also go to
          the base ``DataConnector``, as ``target -> value["key"]``.
        - Dict values with a non-None ``"sensors"`` entry are stored in
          ``self.multi_sensor_key_mapping`` and resolved per sensor in
          ``__call__``.

        Args:
            key_mapping (dict[str, str | SourceKeyDescription]): Defines
                which kwargs to pass onto the module.

        TODO: Add Simple Example Configuration:
        """
        plain_mapping = {}
        sensor_mapping = {}
        for target_key, source in key_mapping.items():
            if not isinstance(source, dict):
                # Plain string: a direct key remap handled by the base class.
                plain_mapping[target_key] = source
            elif source.get("sensors") is None:
                # Description without sensors: only its "key" is forwarded.
                plain_mapping[target_key] = source["key"]
            else:
                # Per-sensor lookup, resolved at call time.
                sensor_mapping[target_key] = source
        super().__init__(plain_mapping)
        self.multi_sensor_key_mapping = sensor_mapping

    def __call__(self, data: DictDataOrList) -> DictData:
        """Returns the train input for the model.

        Starts from the inputs produced by the base ``DataConnector``. It then
        adds one entry per multi-sensor mapping:
        - For a single data dict, the entry is
          ``[data[sensor][key] for sensor in sensors]``.
        - For a list of data dicts, that per-sensor list is built for every
          element, giving a list of lists.
        """
        input_dict = super().__call__(data)
        batched = isinstance(data, list)
        for target_key, source in self.multi_sensor_key_mapping.items():
            field_key = source["key"]
            sensors = source["sensors"]
            if batched:
                gathered = [
                    [sample[sensor][field_key] for sensor in sensors]
                    for sample in data
                ]
            else:
                gathered = [data[sensor][field_key] for sensor in sensors]
            input_dict[target_key] = gathered
        return input_dict
class MultiSensorLossConnector(LossConnector):
    """Multi-sensor Data connector for loss module of the training pipeline."""

    def __call__(
        self, prediction: DictData | NamedTuple, data: DictData
    ) -> DictData:
        """Returns the kwargs that are passed to the loss module.

        The work is delegated to ``get_multi_sensor_inputs``, called with
        this connector's ``key_mapping``.

        Args:
            prediction (DictData | NamedTuple): The output from model.
            data (DictData): The data dictionary from the dataloader which
                contains all data that was loaded.

        Returns:
            DictData: kwargs that are passed onto the loss.
        """
        mapping = self.key_mapping
        return get_multi_sensor_inputs(mapping, prediction, data)
class MultiSensorCallbackConnector(CallbackConnector):
    """Multi-sensor data connector for the callback."""

    def __call__(
        self, prediction: DictData | NamedTuple, data: DictData
    ) -> DictData:
        """Returns the kwargs that are passed to the callback.

        The work is delegated to ``get_multi_sensor_inputs``, called with
        this connector's ``key_mapping``.

        Args:
            prediction (DictData | NamedTuple): The output from model.
            data (DictData): The data dictionary from the dataloader which
                contains all data that was loaded.

        Returns:
            DictData: kwargs that are passed onto the callback.
        """
        mapping = self.key_mapping
        return get_multi_sensor_inputs(mapping, prediction, data)
def get_multi_sensor_inputs(
    connection_dict: dict[str, SourceKeyDescription],
    prediction: DictData | NamedTuple,
    data: DictData,
) -> DictData:
    """Extracts multi-sensor input data from the provided SourceKeyDescription.

    Each entry of ``connection_dict`` maps a new kwarg name to a description
    with a ``"source"`` (``"data"`` or ``"prediction"``) and a ``"key"``.

    For ``"data"`` sources:
    - If the description has a non-None ``"sensors"`` entry, the value is
      ``[data[sensor][key] for sensor in sensors]``, in sensor order.
    - Otherwise ``data[key]`` is used directly.

    ``"prediction"`` sources are resolved via ``get_field_from_prediction``.

    Args:
        connection_dict (dict[str, SourceKeyDescription]): Input key
            description which is used to gather and remap data from the two
            data dicts.
        prediction (DictData | NamedTuple): The model prediction output.
        data (DictData): Dict containing the dataloader output.

    Raises:
        ValueError: If the data source is unknown, or if a ``"data"`` key
            without sensors is not present in ``data``.

    Returns:
        DictData: Dict containing the new kwarg names mapped to the data
            extracted from the data dicts.
    """
    out: DictData = {}
    for new_key_name, description in connection_dict.items():
        source = description["source"]
        if source == "data":
            key = description["key"]
            sensors = description.get("sensors")
            if sensors is None:
                if key not in data:
                    raise ValueError(
                        f"Key {key} not found in data dict."
                        f" Available keys: {data.keys()}"
                    )
                out[new_key_name] = data[key]
            else:
                # A missing sensor or key surfaces as KeyError from the
                # lookup itself; kept as-is so callers' handling is unchanged.
                out[new_key_name] = [data[sensor][key] for sensor in sensors]
        elif source == "prediction":
            out[new_key_name] = get_field_from_prediction(
                prediction, description
            )
        else:
            # The leading space keeps the two sentences from running together.
            raise ValueError(
                f"Unknown data source {source}."
                " Available: [prediction, data]"
            )
    return out