"""DataPipe wraps datasets to share the prepossessing pipeline."""from__future__importannotationsimportrandomfromcollections.abcimportCallable,Iterablefromtorch.utils.dataimportConcatDataset,Datasetfrom.referenceimportMultiViewDatasetfrom.transforms.baseimportTFunctorfrom.typingimportDictData,DictDataOrList
class DataPipe(ConcatDataset[DictDataOrList]):
    """DataPipe class.

    This class wraps one or multiple instances of a PyTorch Dataset so that
    the preprocessing steps can be shared across those datasets. It composes
    the dataset(s) and the preprocessing pipeline.
    """

    def __init__(
        self,
        datasets: Dataset[DictDataOrList] | Iterable[Dataset[DictDataOrList]],
        preprocess_fn: Callable[
            [list[DictData]], list[DictData]
        ] = lambda x: x,
    ) -> None:
        """Creates an instance of the class.

        Args:
            datasets (Dataset | Iterable[Dataset]): Dataset(s) to be wrapped
                by this data pipeline. Any iterable is accepted, including
                one-shot iterators such as generators.
            preprocess_fn (Callable[[list[DictData]], list[DictData]]):
                Preprocessing function of a single sample. It takes a list
                of samples and returns a list of samples. Defaults to
                identity function.

        Raises:
            ValueError: If some, but not all, of the datasets are
                MultiViewDataset instances / have a reference.
        """
        if isinstance(datasets, Dataset):
            datasets = [datasets]
        else:
            # Materialize exactly once. The parent constructor consumes the
            # iterable, so a generator would otherwise be exhausted before
            # the reference checks below and those checks would see nothing.
            datasets = list(datasets)
        super().__init__(datasets)
        self.preprocess_fn = preprocess_fn

        # Evaluate the reference flag once per dataset and reuse the result
        # for both the "any" and the "all" test.
        reference_flags = [_check_reference(dataset) for dataset in datasets]
        self.has_reference = any(reference_flags)

        if self.has_reference and not all(reference_flags):
            raise ValueError(
                "All datasets must be MultiViewDataset / has reference if "
                + "one of them is."
            )

    def __getitem__(self, idx: int) -> DictDataOrList:
        """Wrap getitem to apply augmentations.

        Args:
            idx (int): Index of the sample in the concatenated dataset.

        Returns:
            DictDataOrList: The preprocessed sample. A list is returned if the
                underlying dataset yields a list, a single item otherwise.
        """
        samples = super().__getitem__(idx)
        if isinstance(samples, list):
            return self.preprocess_fn(samples)
        # preprocess_fn always operates on lists, so wrap the single sample
        # and unwrap the result again.
        return self.preprocess_fn([samples])[0]
class MultiSampleDataPipe(DataPipe):
    """MultiSampleDataPipe class.

    Extension of DataPipe for augmentations that need more than one image
    (e.g., Mosaic and Mixup). For such transforms, additional indices are
    sampled for every requested item.

    A transform that requires multi-sample augmentation has to define
    NUM_SAMPLES as a class attribute.
    """

    def __init__(
        self,
        datasets: Dataset[DictDataOrList] | Iterable[Dataset[DictDataOrList]],
        preprocess_fn: list[list[TFunctor]],
    ):
        """Creates an instance of the class.

        Args:
            datasets (Dataset | Iterable[Dataset]): Dataset(s) to be wrapped
                by this data pipeline.
            preprocess_fn (list[list[TFunctor]]): Preprocessing functions of
                a single sample. Unlike DataPipe, this is a list of lists of
                transformation functions. Each inner list groups transforms
                that need to share the same sampled indices (e.g.,
                GenMosaicParameters and MosaicImages); the outer list
                enumerates the different groups.
        """
        super().__init__(datasets)
        self.preprocess_fns = preprocess_fn

    def _sample_indices(self, idx: int, num_samples: int) -> list[int]:
        """Sample additional indices for multi-sample augmentation.

        The requested index always comes first, followed by
        ``num_samples - 1`` indices drawn uniformly from the whole dataset.
        """
        extra_indices = [
            random.randint(0, len(self) - 1) for _ in range(1, num_samples)
        ]
        return [idx, *extra_indices]

    def __getitem__(self, idx: int) -> DictDataOrList:
        """Wrap getitem to apply augmentations."""
        # Skip DataPipe.__getitem__ on purpose: the raw item is needed here,
        # without DataPipe's preprocess_fn applied.
        fetched = super(DataPipe, self).__getitem__(idx)
        single_view = not isinstance(fetched, list)
        samples = [fetched] if single_view else fetched

        for transform_group in self.preprocess_fns:
            # The first transform of a group decides whether the group is a
            # multi-sample one.
            head = transform_group[0]
            if hasattr(head, "NUM_SAMPLES"):
                num_samples = head.NUM_SAMPLES
                extra_indices = self._sample_indices(idx, num_samples)[1:]
                extras = [
                    super(DataPipe, self).__getitem__(extra_idx)
                    for extra_idx in extra_indices
                ]
                # Layout: each original view is directly followed by its
                # additional samples, so every num_samples-th entry is an
                # original view.
                batch = []
                for view_idx, view in enumerate(samples):
                    batch.append(view)
                    batch.extend(
                        [
                            extra[view_idx]
                            if isinstance(extra, list)
                            else extra
                            for extra in extras
                        ]
                    )
            else:
                num_samples = 1
                batch = samples

            for transform in transform_group:
                batch = transform.apply_to_data(batch)  # type: ignore

            # Keep only the entries at the positions of the original views.
            samples = batch[::num_samples]

        if single_view:
            return samples[0]
        return samples
def _check_reference(dataset: Dataset[DictDataOrList]) -> bool:
    """Tell whether a dataset provides reference views.

    A dataset qualifies if it has a truthy ``has_reference`` attribute or if
    it is an instance of MultiViewDataset.
    """
    flag = getattr(dataset, "has_reference", False)
    if flag:
        return flag
    return isinstance(dataset, MultiViewDataset)