"""Utility functions for the connectors module."""from__future__importannotationsfromcollections.abcimportSequencefromcopyimportdeepcopyfromtypingimportNamedTuple,TypedDictfromtorchimportTensorfromtyping_extensionsimportNotRequiredfromvis4d.common.dictimportget_dict_nestedfromvis4d.common.named_tupleimportget_from_namedtuple,is_namedtuplefromvis4d.common.typingimportDictStrArrNestedfromvis4d.data.typingimportDictData
class SourceKeyDescription(TypedDict):
    """Describes one data entry by its lookup key and the source it is in.

    Attributes:
        key (str): Key used to index the data in the selected source.
        source (str): Name of the datasource to read from. Options are
            ['data', 'prediction'], where data refers to the output of the
            dataloader and prediction refers to the model output.
        sensors (Sequence[str]): Which sensors to use for the data. This
            entry is optional and may be left out of the dict.
    """

    key: str
    source: str
    sensors: NotRequired[Sequence[str]]
def remap_pred_keys(
    info: dict[str, SourceKeyDescription], parent_key: str
) -> dict[str, SourceKeyDescription]:
    """Prefixes the key of every prediction entry with a parent key.

    The input mapping is deep-copied first, so the caller's mapping and the
    descriptions it holds are left untouched.

    Args:
        info (dict[str, SourceKeyDescription]): Connection mapping whose
            prediction entries should be remapped.
        parent_key (str): Parent key that is prepended, separated by a dot.

    Returns:
        dict[str, SourceKeyDescription]: Copy of the mapping in which every
            entry with source 'prediction' has the key
            '<parent_key>.<key>'. Entries with any other source are copied
            unchanged.
    """
    remapped = deepcopy(info)
    for description in remapped.values():
        if description["source"] != "prediction":
            # Only model outputs are nested below the parent key.
            continue
        description["key"] = ".".join((parent_key, description["key"]))
    return remapped
def data_key(
    key: str, sensors: Sequence[str] | None = None
) -> SourceKeyDescription:
    """Builds a SourceKeyDescription that reads from the data source.

    Args:
        key (str): Key to use for the data entry.
        sensors (Sequence[str] | None, optional): Which sensors to use for
            the data. When None, the returned description carries no
            'sensors' entry at all. Defaults to None.

    Returns:
        SourceKeyDescription: Description with 'data' as its source.
    """
    if sensors is not None:
        return SourceKeyDescription(key=key, source="data", sensors=sensors)
    # Omit the optional 'sensors' entry instead of storing None in it.
    return SourceKeyDescription(key=key, source="data")
def pred_key(key: str) -> SourceKeyDescription:
    """Builds a SourceKeyDescription that reads from the prediction source.

    Args:
        key (str): Key to use for the data entry.

    Returns:
        SourceKeyDescription: Description with 'prediction' as its source.
    """
    description = SourceKeyDescription(key=key, source="prediction")
    return description
def get_field_from_prediction(
    prediction: DictData | NamedTuple,
    old_key_name: SourceKeyDescription,
) -> Tensor | DictStrArrNested:
    """Looks up a single field in the model prediction.

    Args:
        prediction (DictData | NamedTuple): Model prediction output, either
            a dict or a named tuple.
        old_key_name (SourceKeyDescription): Description of the data to
            extract. Only its 'key' entry is read here.

    Returns:
        Tensor | DictStrArrNested: Data extracted from the prediction.
    """
    if not is_namedtuple(prediction):
        # Dict predictions: the key is split on '.' and the resulting
        # segments are handed to get_dict_nested.
        key_path = old_key_name["key"].split(".")
        return get_dict_nested(prediction, key_path)  # type: ignore

    # Named tuple predictions: the key is passed on unsplit.
    return get_from_namedtuple(
        prediction, old_key_name["key"]  # type: ignore
    )
def get_inputs_for_pred_and_data(
    connection_dict: dict[str, SourceKeyDescription],
    prediction: DictData | NamedTuple,
    data: DictData,
) -> dict[str, Tensor | DictStrArrNested]:
    """Gathers inputs from the data and prediction according to a mapping.

    Args:
        connection_dict (dict[str, SourceKeyDescription]): Maps each new key
            name to the description of where its value comes from.
        prediction (DictData | NamedTuple): Model prediction output.
        data (DictData): Dict containing the dataloader output.

    Raises:
        ValueError: If a description with source 'data' names a key that is
            not in the data dict, or if the source is neither 'data' nor
            'prediction'.

    Returns:
        out (dict[str, Tensor | DictStrArrNested]): Dict containing the new
            key names and the values extracted for them.
    """
    gathered = {}
    for target_name, description in connection_dict.items():
        source = description["source"]

        # Field taken from the model prediction.
        if source == "prediction":
            gathered[target_name] = get_field_from_prediction(
                prediction, description
            )
            continue

        if source != "data":
            raise ValueError(
                f"Unknown data source {source}."
                f" Available: [prediction, data]"
            )

        # Field taken from the dataloader output, indexed by the key as is.
        source_key = description["key"]
        if source_key not in data:
            raise ValueError(
                f"Key {source_key} not found in data dict."
                f" Available keys: {data.keys()}"
            )
        gathered[target_name] = data[source_key]

    return gathered