"""BDD100K segmentation evaluator.

Requires the optional ``scalabel`` and ``bdd100k`` packages; importing this
module raises ``ImportError`` if either of them is unavailable.
"""
from __future__ import annotations

import itertools
from collections.abc import Callable
from typing import Any

import numpy as np

from vis4d.common.array import array_to_numpy
from vis4d.common.imports import BDD100K_AVAILABLE, SCALABEL_AVAILABLE
from vis4d.common.typing import ArrayLike, MetricLogs
from vis4d.data.datasets.bdd100k import bdd100k_seg_map

from ..base import Evaluator

# Both optional dependencies are needed by the evaluator below, so the
# module refuses to import (rather than failing later) when either is missing.
if SCALABEL_AVAILABLE and BDD100K_AVAILABLE:
    from bdd100k.common.utils import load_bdd100k_config
    from bdd100k.label.to_scalabel import bdd100k_to_scalabel
    from scalabel.eval.sem_seg import evaluate_sem_seg
    from scalabel.label.io import load
    from scalabel.label.transforms import mask_to_rle
    from scalabel.label.typing import Frame, Label
else:
    raise ImportError("scalabel or bdd100k is not installed.")
class BDD100KSegEvaluator(Evaluator):
    """BDD100K segmentation evaluation class.

    Predictions are accumulated as scalabel ``Frame`` objects through
    ``process_batch`` and scored against the ground-truth frames (converted
    from BDD100K to scalabel format at construction) with scalabel's
    semantic segmentation evaluation.
    """

    # Integer class id -> category name, obtained by inverting
    # ``bdd100k_seg_map``; used to name the predicted labels.
    inverse_seg_map = {v: k for k, v in bdd100k_seg_map.items()}

    def __init__(self, annotation_path: str) -> None:
        """Initialize the evaluator.

        Args:
            annotation_path (str): Path to the BDD100K annotations, loaded
                with scalabel's ``load``.
        """
        super().__init__()
        self.annotation_path = annotation_path
        self.frames: list[Frame] = []
        bdd100k_anns = load(annotation_path)
        frames = bdd100k_anns.frames
        # The same "sem_seg" config converts the ground truth here and
        # supplies the scalabel config used later in ``evaluate``.
        self.config = load_bdd100k_config("sem_seg")
        self.gt_frames = bdd100k_to_scalabel(frames, self.config)
        self.reset()

    def __repr__(self) -> str:
        """Concise representation of the dataset evaluator."""
        return "BDD100K Segmentation Evaluator"

    def gather(  # type: ignore # pragma: no cover
        self, gather_func: Callable[[Any], Any]
    ) -> None:
        """Gather variables in case of distributed setting (if needed).

        Args:
            gather_func (Callable[[Any], Any]): Gather function. It is called
                with this process's frames. A ``None`` result leaves the local
                frames untouched; any other result is treated as an iterable
                of frame lists and flattened into ``self.frames``.
        """
        all_preds = gather_func(self.frames)
        if all_preds is not None:
            self.frames = list(itertools.chain.from_iterable(all_preds))

    def reset(self) -> None:
        """Reset the evaluator by discarding all accumulated predictions."""
        self.frames = []

    def process_batch(  # type: ignore # pylint: disable=arguments-differ
        self, data_names: list[str], masks_list: list[ArrayLike]
    ) -> None:
        """Process a batch of semantic segmentation predictions.

        Each predicted mask becomes one scalabel ``Frame`` containing one
        ``Label`` per distinct class id in the mask; each label stores the
        RLE-encoded binary mask of its class.

        Args:
            data_names (list[str]): Frame names, matched to ``masks_list``
                by position.
            masks_list (list[ArrayLike]): Predicted masks, presumably (H, W)
                maps of integer class ids — TODO confirm.

        Raises:
            KeyError: If a mask contains a class id that is not present in
                ``inverse_seg_map``.
        """
        masks_numpy = [array_to_numpy(m, None) for m in masks_list]
        for data_name, masks in zip(data_names, masks_numpy):
            labels = []
            # np.unique returns the distinct ids in sorted order; the label
            # id is the position of the class in that order.
            for i, class_id in enumerate(np.unique(masks)):
                label = Label(
                    rle=mask_to_rle((masks == class_id).astype(np.uint8)),
                    category=self.inverse_seg_map[int(class_id)],
                    id=str(i),
                )
                labels.append(label)
            self.frames.append(Frame(name=data_name, labels=labels))

    def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
        """Evaluate the accumulated predictions against the ground truth.

        Args:
            metric (str): Metric to compute; only ``"sem_seg"`` is supported.

        Returns:
            tuple[MetricLogs, str]: An empty metric-log dict and the string
                form of the scalabel evaluation result.

        Raises:
            NotImplementedError: If ``metric`` is not ``"sem_seg"``.
        """
        if metric != "sem_seg":
            raise NotImplementedError(
                f"Metric {metric} is not supported by {self}."
            )
        # NOTE(review): nproc=0 presumably disables scalabel's worker pool
        # (serial evaluation) — confirm against the scalabel version in use.
        results = evaluate_sem_seg(
            ann_frames=self.gt_frames,
            pred_frames=self.frames,
            config=self.config.scalabel,
            nproc=0,
        )
        return {}, str(results)