"""Provides functionalities to wrap torchvision datasets."""from__future__importannotationsfromcollections.abcimportCallablefromtypingimportAnyimportnumpyasnpfromPIL.ImageimportImagefromtorchvision.datasetsimportVisionDatasetfromtorchvision.transformsimportToTensorfrom..constimportCommonKeysasKfrom..typingimportDictDatafrom.baseimportDataset

class TorchvisionDataset(Dataset):
    """Wrapper for torchvision datasets.

    This class wraps torchvision datasets and converts them to the format
    that is expected by the vis4d framework. The return value of the
    torchvision dataset's __getitem__ is passed to the data_converter,
    which needs to be provided by the user. The data_converter is expected
    to return a DictData object following the vis4d conventions.

    For well-defined data formats, such as classification, there are
    already implemented wrappers that can be used. See
    `TorchvisionClassificationDataset` for an example.
    """

    def __init__(  # type: ignore
        self,
        torchvision_ds: VisionDataset,
        data_converter: Callable[[Any], DictData],
    ) -> None:
        """Creates a new instance of the class.

        Args:
            torchvision_ds (VisionDataset): Torchvision dataset that should
                be converted.
            data_converter (Callable[[Any], DictData]): Function that
                converts the output of the torchvision dataset's
                __getitem__ to the format expected by the vis4d framework.
        """
        super().__init__()
        self.torchvision_ds = torchvision_ds
        self.data_converter = data_converter

    def __getitem__(self, idx: int) -> DictData:
        """Returns a new sample from the dataset.

        Args:
            idx (int): Index of the sample.

        Returns:
            DictData: Data in vis4d format.
        """
        return self.data_converter(self.torchvision_ds[idx])

    def __len__(self) -> int:
        """Returns the number of samples in the dataset.

        Returns:
            int: Length of the dataset.
        """
        return len(self.torchvision_ds)
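

# A minimal usage sketch for ``TorchvisionDataset`` with a custom converter.
# The dataset choice (CIFAR10) and the converter below are hypothetical and
# only illustrate the expected contract: take the raw __getitem__ output of
# the wrapped dataset and return a DictData in vis4d format.
#
#     import numpy as np
#     from torchvision.datasets import CIFAR10
#
#     def _cifar_converter(sample: tuple[Image, int]) -> DictData:
#         img, class_id = sample  # PIL image and integer class label
#         data: DictData = {}
#         # HWC image with a leading batch dimension -> (1, H, W, C).
#         data[K.images] = np.asarray(img)[None, ...].astype(np.float32)
#         data[K.categories] = np.array([class_id], dtype=np.int64)
#         return data
#
#     ds = TorchvisionDataset(
#         CIFAR10("data/cifar10", download=True), _cifar_converter
#     )
#     sample = ds[0]  # {'images': (1, 32, 32, 3), 'categories': (1,)}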


class TorchvisionClassificationDataset(TorchvisionDataset):
    """Wrapper for torchvision classification datasets.

    This class wraps torchvision classification datasets and converts them
    to the format that is expected by the vis4d framework. It expects the
    torchvision dataset to return a tuple of (image, class_id), where the
    image is a PIL Image and the class_id is an integer.

    If you want to use a torchvision dataset that returns a different
    format, you can provide a custom data_converter function to the
    `TorchvisionDataset` class.

    The returned sample will have the following keys and values:
        images: ndarray of dimension (1, H, W, C).
        categories: ndarray of dimension (1,).

    Example:
        >>> from torchvision.datasets.mnist import MNIST
        >>> ds = TorchvisionClassificationDataset(
        >>>     MNIST("data/mnist_ds", train=False)
        >>> )
        >>> data = next(iter(ds))
        >>> print(data.keys())
        dict_keys(['images', 'categories'])
    """

    def __init__(self, detection_ds: VisionDataset) -> None:
        """Creates a new instance of the class.

        Args:
            detection_ds (VisionDataset): Torchvision dataset that returns
                a tuple of (image, class_id), where the image is a PIL
                Image and the class_id is an integer.
        """
        img_to_tensor = ToTensor()

        def _data_converter(img_and_target: tuple[Image, int]) -> DictData:
            """Converts the output of a torchvision dataset.

            The output is converted to the format expected by the vis4d
            framework.

            Args:
                img_and_target (tuple[Image, int]): Output of the dataset's
                    __getitem__ method.

            Returns:
                DictData: Sample in vis4d format.
            """
            img, class_id = img_and_target
            data: DictData = {}
            # ToTensor yields (C, H, W) in [0, 1]; add a batch dimension
            # and permute to the vis4d layout (1, H, W, C).
            data[K.images] = (
                img_to_tensor(img).unsqueeze(0).permute(0, 2, 3, 1).numpy()
            )
            data[K.categories] = np.array([class_id], dtype=np.int64)
            return data

        super().__init__(detection_ds, _data_converter)
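

# A minimal usage sketch for ``TorchvisionClassificationDataset``. The data
# directory and ``download=True`` are assumptions; MNIST images are 28x28
# grayscale, so the converted sample has image shape (1, 28, 28, 1).
#
#     from torchvision.datasets.mnist import MNIST
#
#     ds = TorchvisionClassificationDataset(
#         MNIST("data/mnist_ds", train=False, download=True)
#     )
#     sample = ds[0]
#     assert sample[K.images].shape == (1, 28, 28, 1)
#     assert sample[K.categories].shape == (1,)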