class ClassificationEvaluator(Evaluator):
    """Multi-class classification evaluator."""

    METRIC_CLASSIFICATION = "Cls"
    KEY_ACCURACY = "Acc@1"
    KEY_ACCURACY_TOP5 = "Acc@5"

    def __init__(self) -> None:
        """Initialize the classification evaluator."""
        super().__init__()
        self._metrics_list: list[dict[str, float]] = []

    @property
    def metrics(self) -> list[str]:
        """Supported metrics."""
        return [
            self.KEY_ACCURACY,
            self.KEY_ACCURACY_TOP5,
        ]

    def reset(self) -> None:
        """Reset evaluator for new round of evaluation."""
        self._metrics_list = []

    def _is_correct(
        self, pred: NDArrayNumber, target: NDArrayI64, top_k: int = 1
    ) -> bool:
        """Check if the prediction is correct for top-k.

        Args:
            pred (NDArrayNumber): Prediction logits, in shape (C, ).
            target (NDArrayI64): Target class index, in shape (1, ).
            top_k (int, optional): Top-k to check. Defaults to 1.

        Returns:
            bool: Whether the prediction is correct.
        """
        top_k = min(top_k, pred.shape[0])
        top_k_idx = np.argsort(pred)[-top_k:]
        return bool(np.any(top_k_idx == target))
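
    # Worked example of the top-k check above (illustrative values, not part
    # of the original module): with pred = [0.1, 0.3, 0.6], target = 1 and
    # top_k = 2, np.argsort(pred) yields [0, 1, 2], so the top-2 indices are
    # [1, 2]. The target 1 is in that set, hence the sample counts as top-2
    # correct even though it is not top-1 correct (the argmax is class 2).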

    def process_batch(  # type: ignore # pylint: disable=arguments-differ
        self, prediction: ArrayLike, groundtruth: ArrayLike
    ) -> None:
        """Process a batch of predictions and groundtruths.

        Args:
            prediction (ArrayLike): Prediction, in shape (N, C).
            groundtruth (ArrayLike): Groundtruth, in shape (N, ).
        """
        pred = array_to_numpy(prediction, n_dims=None, dtype=np.float32)
        gt = array_to_numpy(groundtruth, n_dims=None, dtype=np.int64)
        for i in range(pred.shape[0]):
            self._metrics_list.append(
                {
                    "top1_correct": self._is_correct(pred[i], gt[i], top_k=1),
                    "top5_correct": self._is_correct(pred[i], gt[i], top_k=5),
                }
            )

    def gather(self, gather_func: GenericFunc) -> None:
        """Accumulate predictions across processes."""
        all_metrics = gather_func(self._metrics_list)
        if all_metrics is not None:
            self._metrics_list = list(itertools.chain(*all_metrics))
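
    # Note (not part of the original module): `gather_func` is expected to
    # return a list holding one metrics list per process, which is then
    # flattened. In a single-process run, a stand-in such as
    # `lambda metrics: [metrics]` leaves `self._metrics_list` unchanged,
    # since itertools.chain(*[metrics]) simply re-yields the same entries.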

    def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
        """Evaluate predictions.

        Returns a dict containing the raw data and a short description string
        containing a readable result.

        Args:
            metric (str): Metric to use. See @property metrics.

        Returns:
            metric_data, description tuple containing the metric data (dict
                with metric name and value) as well as a short string with
                shortened information.

        Raises:
            RuntimeError: If no data has been registered to be evaluated.
            ValueError: If the metric is not supported.
        """
        if len(self._metrics_list) == 0:
            raise RuntimeError(
                "No data registered to calculate metric. "
                "Register data using .process() first!"
            )

        metric_data: MetricLogs = {}
        short_description = ""

        if metric == self.METRIC_CLASSIFICATION:
            # Top-1 accuracy
            top1_correct = np.array(
                [m["top1_correct"] for m in self._metrics_list]
            )
            top1_acc = np.mean(top1_correct)
            metric_data[self.KEY_ACCURACY] = top1_acc
            short_description += f"Top1 Accuracy: {top1_acc:.4f}\n"

            # Top-5 accuracy
            top5_correct = np.array(
                [m["top5_correct"] for m in self._metrics_list]
            )
            top5_acc = np.mean(top5_correct)
            metric_data[self.KEY_ACCURACY_TOP5] = top5_acc
            short_description += f"Top5 Accuracy: {top5_acc:.4f}\n"
        else:
            raise ValueError(
                f"Unsupported metric: {metric}"
            )  # pragma: no cover

        return metric_data, short_description
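
# Usage sketch (hypothetical, not part of the original module): feeds one
# batch of random logits and integer labels through the evaluator and prints
# the formatted accuracy summary. Assumes `np` (numpy) is imported at the top
# of this module, as the methods above already rely on it.
if __name__ == "__main__":
    evaluator = ClassificationEvaluator()
    evaluator.reset()

    rng = np.random.default_rng(0)
    logits = rng.random((8, 10), dtype=np.float32)  # (N, C) prediction scores
    labels = rng.integers(0, 10, size=(8,))  # (N, ) groundtruth class indices

    evaluator.process_batch(logits, labels)
    metric_data, description = evaluator.evaluate(
        ClassificationEvaluator.METRIC_CLASSIFICATION
    )
    print(description)  # "Top1 Accuracy: ...\nTop5 Accuracy: ..."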