"""Resample index to recover the original dataset length."""from__future__importannotationsimportnumpyasnpfromtorch.utils.dataimportDatasetfromvis4d.common.loggingimportrank_zero_infofrom.referenceimportMultiViewDatasetfrom.typingimportDictDataOrList
class ResampleDataset(Dataset[DictDataOrList]):
    """Dataset wrapper to recover the filtered samples through resampling.

    In MMEngine and Detectron2, the dataset might return None when the sample
    has no valid annotations. They will resample the index and try to get the
    valid training data. The length of the dataset therefore depends on
    whether empty samples are filtered first. This dataset wrapper resamples
    the index to recover the original dataset length (before filtering empty
    frames) to align with the other codebases' implementation.

    https://github.com/open-mmlab/mmengine/blob/main/mmengine/dataset/base_dataset.py#L411
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/common.py#L96
    """

    def __init__(self, dataset: Dataset[DictDataOrList]) -> None:
        """Creates an instance of the class.

        Args:
            dataset (Dataset[DictDataOrList]): The (filtered) dataset to wrap.
                Either it, or the object stored in its ``dataset`` attribute
                when it is itself a wrapper, must expose ``original_len``.

        Raises:
            AttributeError: If the unwrapped dataset has no ``original_len``
                attribute.
        """
        super().__init__()
        self.dataset = dataset
        self.has_reference = isinstance(dataset, MultiViewDataset)
        # Number of samples actually available in the wrapped dataset.
        self.valid_len = len(dataset)  # type: ignore

        # Handle the case that dataset is already wrapped: look one level in.
        _dataset = getattr(self.dataset, "dataset", self.dataset)

        # Explicit check instead of ``assert`` so the validation is not
        # stripped when Python runs with ``-O``.
        if not hasattr(_dataset, "original_len"):
            raise AttributeError(
                "The dataset must have the attribute `original_len` to resample "
                "index to recover the original length."
            )
        self.original_len = _dataset.original_len
        rank_zero_info(
            f"Recover {_dataset} to {self.original_len} samples by resampling "
            "index."
        )

    def __len__(self) -> int:
        """Return the length of dataset.

        Returns:
            int: The original length reported by the wrapped dataset's
                ``original_len`` attribute.
        """
        return self.original_len

    def __getitem__(self, idx: int) -> DictDataOrList:
        """Get original dataset idx according to the given index.

        Resample index to recover the original dataset length. Indices below
        ``valid_len`` (including negative indices) are forwarded to the
        wrapped dataset unchanged; indices in ``[valid_len, original_len)``
        are replaced by a uniformly random index in ``[0, valid_len)`` drawn
        from NumPy's global random state.

        Args:
            idx (int): The index of original dataset length.

        Returns:
            DictDataOrList: Data of the corresponding index.

        Raises:
            IndexError: If ``idx >= original_len``.
        """
        if idx >= self.original_len:
            # Without this bound every index past the end would be resampled,
            # so sequence-protocol iteration (``for x in dataset``), which
            # stops on IndexError, would never terminate.
            raise IndexError(
                f"Index {idx} out of range for dataset of length "
                f"{self.original_len}."
            )
        if idx < self.valid_len:
            index = idx
        else:
            index = np.random.randint(0, self.valid_len)
        return self.dataset[index]