"""Scalabel base evaluator."""from__future__importannotationsimportitertoolsfromcollections.abcimportCallablefromtypingimportAnyfromvis4d.common.importsimportSCALABEL_AVAILABLEfromvis4d.common.typingimportMetricLogsfromvis4d.eval.baseimportEvaluatorifSCALABEL_AVAILABLE:fromscalabel.label.ioimportloadfromscalabel.label.typingimportConfig,Framefromscalabel.label.utilsimportget_leaf_categorieselse:raiseImportError("scalabel is not installed.")
class ScalabelEvaluator(Evaluator):
    """Shared base for evaluators working on Scalabel-format frames.

    Ground-truth frames are loaded once at construction time. Predicted
    frames live in ``self.frames``; they are cleared by ``reset`` and
    merged across processes by ``gather`` (presumably they are filled by
    subclass ``process_batch`` implementations). Subclasses must provide
    ``process_batch`` and ``evaluate``.
    """

    def __init__(
        self, annotation_path: str, config: Config | None = None
    ) -> None:
        """Load the ground truth and build the category-index lookup.

        Args:
            annotation_path (str): Location passed to scalabel's ``load``;
                frames are read with ``validate_frames=False``.
            config (Config | None): Dataset config to use. If ``None``, the
                config that comes with the loaded annotations is used.
        """
        super().__init__()
        self.annotation_path = annotation_path
        self.frames: list[Frame] = []

        dataset = load(self.annotation_path, validate_frames=False)
        self.gt_frames = dataset.frames
        # A caller-supplied config overrides the one stored with the data.
        self.config: Config | None = (
            dataset.config if config is None else config
        )

        # Position of each leaf category -> its name; stays empty when no
        # category list is available.
        categories = None if self.config is None else self.config.categories
        if categories is None:
            self.inverse_cat_map: dict[int, str] = {}
        else:
            leaf_categories = get_leaf_categories(categories)
            self.inverse_cat_map = dict(
                enumerate(leaf.name for leaf in leaf_categories)
            )

        # Dispatches to the (possibly overridden) reset hook.
        self.reset()

    def gather(  # type: ignore # pragma: no cover
        self, gather_func: Callable[[Any], Any]
    ) -> None:
        """Merge frames collected by several processes, when applicable.

        Args:
            gather_func (Callable[[Any], Any]): Called with this instance's
                ``self.frames``. It returns either an iterable of frame
                iterables (presumably one per process), which is flattened
                into ``self.frames``, or ``None``, in which case the local
                frames are left untouched.
        """
        gathered = gather_func(self.frames)
        if gathered is None:
            return
        # Concatenate the per-part frame sequences into one flat list.
        self.frames = list(itertools.chain.from_iterable(gathered))

    def reset(self) -> None:
        """Drop every accumulated prediction frame."""
        self.frames = []

    def process_batch(  # type: ignore # pragma: no cover
        self, *args: Any, **kwargs: Any
    ) -> None:
        """Consume one batch of inputs; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
        """Compute the requested metric; must be implemented by subclasses.

        Args:
            metric (str): Name of the metric to evaluate.

        Returns:
            tuple[MetricLogs, str]: Declared return type. Presumably these are
                the metric values plus a printable summary; confirm against
                subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError