class SubdividingIterableDataset(IterableDataset[DictData]):
    """Subdivides a given dataset into smaller chunks.

    This also adds a field called 'source_index' to the data struct in
    order to relate each chunk to its source entry.

    Example:
        Given a dataset (ds) that outputs tensors of the shape (10, 3):

        sub_ds = SubdividingIterableDataset(ds, n_samples_per_batch=5)
        next(iter(sub_ds))['key'].shape
        >> torch.Size([5, 3])
        next(iter(DataLoader(sub_ds, batch_size=4)))['key'].shape
        >> torch.Size([4, 5, 3])

        Assuming the dataset returns two entries with shape (10, 3):

        [e['source_index'].item() for e in sub_ds]
        >> [0, 0, 1, 1]
    """

    def __init__(
        self,
        dataset: Dataset[DictData],
        n_samples_per_batch: int,
        preprocess_fn: Callable[
            [list[DictData]], list[DictData]
        ] = lambda x: x,
    ) -> None:
        """Creates a new Dataset.

        Args:
            dataset (Dataset): The dataset which should be subdivided.
            n_samples_per_batch: How many samples each chunk should contain.
                The first dimension of each entry in the dataset must be
                divisible by this number.
            preprocess_fn (Callable[[list[DictData]], list[DictData]]):
                Preprocessing function. Defaults to the identity.
        """
        super().__init__()
        self.dataset = dataset
        self.n_samples_per_batch = n_samples_per_batch
        self.preprocess_fn = preprocess_fn
    def __getitem__(self, index: int) -> DictData:
        """Indexing is not supported for IterableDatasets."""
        raise NotImplementedError("IterableDataset does not support indexing.")
    def __iter__(self) -> Iterator[DictData]:
        """Iterates over the dataset, supporting distributed sampling."""
        worker_info = get_worker_info()
        if worker_info is None:  # not distributed
            num_workers = 1
            worker_id = 0
        else:  # pragma: no cover
            num_workers = worker_info.num_workers
            worker_id = worker_info.id

        assert hasattr(
            self.dataset, "__len__"
        ), "Dataset must have __len__ in order to be subdivided."
        n_samples = len(self.dataset)

        # Shard source entries round-robin across DataLoader workers.
        for i in range(math.ceil(n_samples / num_workers)):
            data_idx = i * num_workers + worker_id
            if data_idx >= n_samples:
                continue
            data_sample = self.dataset[data_idx]
            n_elements = next(iter(data_sample.values())).shape[0]
            for idx in range(n_elements // self.n_samples_per_batch):
                # 'source_index' records from which source entry the chunk
                # was loaded (first entry, second entry, ...). This is
                # required if we e.g. want to subdivide a room that is too
                # big into equal-sized chunks and stitch them back together
                # for visualization.
                out_data: DictData = {"source_index": np.array([data_idx])}
                start_idx = idx * self.n_samples_per_batch
                end_idx = (idx + 1) * self.n_samples_per_batch
                for key in data_sample:
                    if len(data_sample[key]) < self.n_samples_per_batch:
                        # Entries smaller than the chunk size are passed
                        # through unchanged.
                        out_data[key] = data_sample[key]
                    else:
                        out_data[key] = data_sample[key][
                            start_idx:end_idx, ...
                        ]
                yield self.preprocess_fn([out_data])[0]
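

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module): the toy
# _PointDataset below is a hypothetical stand-in for any map-style dataset
# returning dicts of equally sized tensors. Assumes torch and numpy are
# available.

import torch
from torch.utils.data import DataLoader, Dataset


class _PointDataset(Dataset):
    """Toy dataset: two entries, each a (10, 3) point cloud."""

    def __len__(self) -> int:
        return 2

    def __getitem__(self, index: int) -> dict:
        return {"points": torch.randn(10, 3)}


if __name__ == "__main__":
    sub_ds = SubdividingIterableDataset(_PointDataset(), n_samples_per_batch=5)

    # Each (10, 3) entry is split into two (5, 3) chunks, so iterating
    # yields four dicts in total; 'source_index' tracks the source entry.
    print([chunk["source_index"].item() for chunk in sub_ds])  # [0, 0, 1, 1]

    # A DataLoader stacks chunks along a new leading batch dimension.
    batch = next(iter(DataLoader(sub_ds, batch_size=4)))
    print(batch["points"].shape)  # torch.Size([4, 5, 3])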