"""Class-balanced Grouping and Sampling for 3D Object Detection.Implementation of `Class-balanced Grouping and Sampling for Point Cloud 3DObject Detection <https://arxiv.org/abs/1908.09492>`_."""from__future__importannotationsimportnumpyasnpfromtorch.utils.dataimportDatasetfromvis4d.common.distributedimportbroadcast,rank_zero_onlyfromvis4d.common.loggingimportrank_zero_infofromvis4d.common.timeimportTimerfrom.datasets.utilimportprint_class_histogramfrom.referenceimportMultiViewDatasetfrom.typingimportDictDataOrList# TODO: Support sensor selection.
class CBGSDataset(Dataset[DictDataOrList]):
    """Balance the number of scenes under different classes.

    Wraps another dataset and exposes a resampled list of its indices.
    Every class that occurs in the data is drawn about equally often, so
    the wrapper's length is roughly the total number of (sample, class)
    pairs found in the wrapped dataset.
    """

    def __init__(
        self,
        dataset: Dataset[DictDataOrList],
        class_map: dict[str, int],
        ignore: int = -1,
    ) -> None:
        """Creates an instance of the class.

        Args:
            dataset (Dataset[DictDataOrList]): Dataset to wrap. It, or the
                object stored in its ``dataset`` attribute if it has one,
                must provide ``get_cat_ids(idx)``.
            class_map (dict[str, int]): Mapping from category name to
                category id.
            ignore (int, optional): Category id that is skipped while
                collecting the per-class sample indices. Defaults to -1.
        """
        super().__init__()
        self.dataset = dataset
        self.has_reference = isinstance(dataset, MultiViewDataset)
        # Keep the categories ordered by id so that iteration order (and
        # with it the order of the random draws) does not depend on the
        # order in which the caller listed the classes.
        self.cat2id = dict(sorted(class_map.items(), key=lambda x: x[1]))
        self.ignore = ignore

        rank_zero_info("Wrapping dataset with CBGS...")
        # `_get_sample_indices` is wrapped in `rank_zero_only`; its result
        # is handed to `broadcast` before being stored.
        # NOTE(review): presumably this makes all ranks share the indices
        # sampled on rank zero - confirm against vis4d.common.distributed.
        sample_indices = self._get_sample_indices()
        self.sample_indices = broadcast(sample_indices)

    def _show_histogram(
        self,
        sample_indices: list[int],
        sample_frequencies: list[dict[str, int]],
    ) -> None:
        """Show class histogram.

        Args:
            sample_indices (list[int]): Resampled indices; an index that
                was drawn several times is counted several times.
            sample_frequencies (list[dict[str, int]]): For every index of
                the underlying dataset, the number of boxes per category
                name.
        """
        frequencies = dict.fromkeys(self.cat2id, 0)
        for idx in sample_indices:
            for box3d_class, count in sample_frequencies[idx].items():
                frequencies[box3d_class] += count

        print_class_histogram(frequencies)

    def _get_class_sample_indices(
        self,
    ) -> tuple[dict[int, list[int]], list[dict[str, int]]]:
        """Get sample indices.

        Returns:
            tuple[dict[int, list[int]], list[dict[str, int]]]: For every
                category id the indices of the samples that contain it
                (each sample listed at most once per category), and for
                every sample the number of occurrences per category name.
        """
        class_sample_indices: dict[int, list[int]] = {
            cat_id: [] for cat_id in self.cat2id.values()
        }
        sample_frequencies: list[dict[str, int]] = []
        inv_class_map = {v: k for k, v in self.cat2id.items()}

        # Handle the case that dataset is already wrapped.
        dataset = getattr(self.dataset, "dataset", self.dataset)

        # The check does not depend on the sample, so run it once up front
        # instead of once per index.
        assert hasattr(
            dataset, "get_cat_ids"
        ), "The dataset must have a method `get_cat_ids` to get cat ids."

        for idx in range(len(dataset)):
            frequencies = dict.fromkeys(self.cat2id, 0)
            seen_cat_ids: set[int] = set()
            for cat_id in dataset.get_cat_ids(idx):
                if cat_id == self.ignore:
                    continue
                frequencies[inv_class_map[cat_id]] += 1
                # A sample is registered once per category, no matter how
                # many boxes of that category it contains.
                if cat_id not in seen_cat_ids:
                    seen_cat_ids.add(cat_id)
                    class_sample_indices[cat_id].append(idx)

            sample_frequencies.append(frequencies)

        return class_sample_indices, sample_frequencies

    @rank_zero_only
    def _get_sample_indices(self) -> list[int]:
        """Load sample indices.

        Classes without any sample are skipped: they have nothing to draw
        from and would otherwise cause a division by zero.

        Returns:
            list[int]: List of indices after class sampling.
        """
        t = Timer()
        (
            class_sample_indices,
            sample_frequencies,
        ) = self._get_class_sample_indices()

        duplicated_samples = sum(
            len(v) for v in class_sample_indices.values()
        )

        empty_classes = [
            name
            for name, cat_id in self.cat2id.items()
            if not class_sample_indices[cat_id]
        ]
        if empty_classes:
            rank_zero_info(
                f"CBGS: skipping classes without samples: {empty_classes}"
            )

        sample_indices: list[int] = []

        frac = 1.0 / len(self.cat2id)
        for cls_inds in class_sample_indices.values():
            if not cls_inds:
                continue
            # Share of this class among all (sample, class) pairs. It is
            # non-zero here, which also implies duplicated_samples > 0.
            class_share = len(cls_inds) / duplicated_samples
            ratio = frac / class_share
            # np.random.choice draws with replacement by default, so rare
            # classes are oversampled and frequent ones subsampled.
            sample_indices += np.random.choice(
                cls_inds, int(len(cls_inds) * ratio)
            ).tolist()

        self._show_histogram(sample_indices, sample_frequencies)

        rank_zero_info(
            f"Generating {len(sample_indices)} CBGS samples takes "
            + f"{t.time():.2f} seconds."
        )

        return sample_indices

    def __len__(self) -> int:
        """Return the length of sample indices.

        Returns:
            int: Length of sample indices.
        """
        return len(self.sample_indices)

    def __getitem__(self, idx: int) -> DictDataOrList:
        """Get original dataset idx according to the given index.

        Args:
            idx (int): The index of self.sample_indices.

        Returns:
            DictDataOrList: Data of the corresponding index.
        """
        ori_index = self.sample_indices[idx]
        return self.dataset[ori_index]