"""Base dataset classes.We implement a typed version of the PyTorch dataset class here. In addition, weprovide a number of Mixin classes which a dataset can inherit from to implementadditional functionality."""from__future__importannotationsfromcollections.abcimportSequencefromtypingimportTypedDictfromtorch.utils.dataimportDatasetasTorchDatasetfromvis4d.commonimportArgsTypefromvis4d.data.io.baseimportDataBackendfromvis4d.data.io.fileimportFileBackendfromvis4d.data.typingimportDictData
class Dataset(TorchDataset[DictData]):
    """Typed PyTorch dataset whose samples are returned as ``DictData``.

    This is an interface class: ``__len__`` and ``__getitem__`` raise
    ``NotImplementedError`` here and have to be provided by subclasses.
    """

    # Dataset metadata.
    DESCRIPTION = ""
    HOMEPAGE = ""
    PAPER = ""
    LICENSE = ""

    # List of all keys supported by this dataset.
    KEYS: Sequence[str] = []

    def __init__(
        self,
        image_channel_mode: str = "RGB",
        data_backend: None | DataBackend = None,
    ) -> None:
        """Store the channel mode and the data backend.

        Args:
            image_channel_mode (str): Image channel mode to use.
                Default: RGB.
            data_backend (None | DataBackend): Data backend to use. If None,
                a new FileBackend is created. Default: None.
        """
        self.image_channel_mode = image_channel_mode
        # Only build the default backend when the caller did not pass one.
        if data_backend is None:
            data_backend = FileBackend()
        self.data_backend = data_backend

    def __len__(self) -> int:
        """Return the number of samples in the dataset.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError

    def __getitem__(self, idx: int) -> DictData:
        """Return the sample at the given index in Vis4D data format.

        Args:
            idx (int): Index of the sample.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError

    def validate_keys(self, keys_to_load: Sequence[str]) -> None:
        """Check that every requested key is listed in ``KEYS``.

        Args:
            keys_to_load (Sequence[str]): Keys requested for loading.

        Raises:
            ValueError: If a key is not contained in ``KEYS``. The first
                offending key (in iteration order) is reported.
        """
        for key in keys_to_load:
            if key in self.KEYS:
                continue
            raise ValueError(f"Key '{key}' is not supported!")
class VideoMapping(TypedDict):
    """Per-video grouping of dataset sample indices and frame indices.

    Both dictionaries are keyed by the (string) video ID.
    """

    # Video ID -> list of integer dataset sample indices.
    video_to_indices: dict[str, list[int]]
    # Video ID -> list of integer frame ids.
    video_to_frame_ids: dict[str, list[int]]
class VideoDataset(Dataset):
    """Video datasets.

    Provides video_mapping attribute for video based interface and reference
    view samplers.
    """

    def __init__(self, *args: ArgsType, **kwargs: ArgsType) -> None:
        """Initialize dataset.

        All arguments are forwarded to the parent constructor. The
        ``video_mapping`` attribute starts out empty.
        """
        super().__init__(*args, **kwargs)
        self.video_mapping: VideoMapping = {
            "video_to_indices": {},
            "video_to_frame_ids": {},
        }

    def _sort_video_mapping(
        self, video_mapping: VideoMapping
    ) -> VideoMapping:
        """Sort video mapping by frame ids.

        For every video, the sample indices and frame ids are paired up and
        reordered by ascending frame id. The sort is stable, so entries with
        equal frame ids keep their relative order. Videos without any samples
        are left as empty lists.

        Args:
            video_mapping (VideoMapping): Mapping to sort. It is modified in
                place.

        Returns:
            VideoMapping: The same mapping object, with sorted lists.
        """
        video_to_indices = video_mapping["video_to_indices"]
        video_to_frame_ids = video_mapping["video_to_frame_ids"]
        for seq in video_to_indices:
            pairs = sorted(
                zip(video_to_indices[seq], video_to_frame_ids[seq]),
                key=lambda pair: pair[1],
            )
            # Build each list with a comprehension rather than unpacking
            # ``zip(*pairs)``, which raises ValueError when a video is empty.
            # Only values of existing keys are replaced, so mutating the
            # dict during iteration is safe.
            video_to_indices[seq] = [index for index, _ in pairs]
            video_to_frame_ids[seq] = [frame_id for _, frame_id in pairs]
        return video_mapping

    def _generate_video_mapping(self) -> VideoMapping:
        """Group dataset sample by their associated video ID.

        The sample index is an integer while video IDs are string.

        Returns:
            VideoMapping: Mapping of video IDs to sample indices and frame
                IDs.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError