def threshold_and_flatten(
    prediction: NDArrayNumber, target: NDArrayNumber, threshold_value: float
) -> tuple[NDArrayBool, NDArrayBool]:
    """Thresholds the predictions based on the provided threshold value.

    Applies the following actions:
        prediction -> prediction >= threshold_value
        pred, gt = pred.ravel().bool(), gt.ravel().bool()

    Args:
        prediction: Prediction array with continuous values.
        target: Groundtruth values {0, 1}.
        threshold_value: Value used to convert the continuous prediction
            into a binary one.

    Returns:
        Tuple of two boolean arrays: prediction and target.
    """
    prediction_bin: NDArrayBool = prediction >= threshold_value
    return prediction_bin.ravel().astype(bool), target.ravel().astype(bool)
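A quick sketch of how this behaves on plain NumPy arrays (the inputs below are illustrative and not part of the module):

import numpy as np

scores = np.array([[0.2, 0.7], [0.9, 0.4]])
labels = np.array([[0, 1], [1, 1]])

pred, gt = threshold_and_flatten(scores, labels, threshold_value=0.5)
# pred -> [False, True, True, False]  (scores >= 0.5, flattened)
# gt   -> [False, True, True, True]   (labels cast to bool, flattened)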
class BinaryEvaluator(Evaluator):
    """Creates a new Evaluator that evaluates binary predictions."""

    METRIC_BINARY = "BinaryCls"
    KEY_IOU = "IoU"
    KEY_ACCURACY = "Accuracy"
    KEY_F1 = "F1"
    KEY_PRECISION = "Precision"
    KEY_RECALL = "Recall"

    def __init__(self, threshold: float = 0.5) -> None:
        """Creates a new binary evaluator.

        Args:
            threshold (float): Threshold used to convert predictions to
                binary. All predictions higher than or equal to this value
                are assigned the 'True' label.
        """
        super().__init__()
        self.threshold = threshold
        self.true_positives: list[float] = []
        self.false_positives: list[float] = []
        self.true_negatives: list[float] = []
        self.false_negatives: list[float] = []
        self.n_samples: list[float] = []
        self.has_samples = False
        self.reset()

    def _calc_confusion_matrix(
        self, prediction: NDArrayBool, target: NDArrayBool
    ) -> None:
        """Calculates the confusion matrix counts and appends them to the internal lists.

        Args:
            prediction: The prediction (binary) (N, Pts).
            target: The groundtruth (binary) (N, Pts).
        """
        tp = int(np.sum(np.logical_and(prediction == 1, target == 1)))
        fp = int(np.sum(np.logical_and(prediction == 1, target == 0)))
        tn = int(np.sum(np.logical_and(prediction == 0, target == 0)))
        fn = int(np.sum(np.logical_and(prediction == 0, target == 1)))
        self.true_positives.append(tp)
        self.false_positives.append(fp)
        self.true_negatives.append(tn)
        self.false_negatives.append(fn)
        self.n_samples.append(tp + fp + tn + fn)

    @property
    def metrics(self) -> list[str]:
        """Supported metrics."""
        return [self.METRIC_BINARY]
    def reset(self) -> None:
        """Reset the saved predictions to start a new round of evaluation."""
        self.true_positives = []
        self.false_positives = []
        self.true_negatives = []
        self.false_negatives = []
        self.n_samples = []
        self.has_samples = False
    def process_batch(  # type: ignore # pylint: disable=arguments-differ
        self,
        prediction: ArrayLike,
        groundtruth: ArrayLike,
    ) -> None:
        """Processes a new batch of predictions.

        Calculates the confusion matrix counts and caches them internally.

        Args:
            prediction: The prediction (continuous values or binary) (Batch x Pts).
            groundtruth: The groundtruth (binary) (Batch x Pts).
        """
        pred, gt = threshold_and_flatten(
            array_to_numpy(prediction, n_dims=None, dtype=np.float32),
            array_to_numpy(groundtruth, n_dims=None, dtype=np.bool_),
            self.threshold,
        )
        # Confusion Matrix
        self._calc_confusion_matrix(pred, gt)
        self.has_samples = True
    def evaluate(self, metric: str) -> tuple[MetricLogs, str]:
        """Evaluate predictions.

        Returns a dict containing the raw data and a short, human-readable
        description string of the result.

        Args:
            metric (str): Metric to use. See @property metrics.

        Returns:
            metric_data, description tuple containing the metric data
            (dict with metric name and value) as well as a short summary
            string.

        Raises:
            RuntimeError: If no data has been registered to be evaluated.
            ValueError: If the metric is not supported.
        """
        if not self.has_samples:
            raise RuntimeError(
                "No data registered to calculate metric. "
                "Register data using .process() first!"
            )

        metric_data: MetricLogs = {}
        short_description = ""

        if metric == self.METRIC_BINARY:
            # IoU
            iou = sum(self.true_positives) / (
                sum(self.n_samples) - sum(self.true_negatives) + 1e-6
            )
            metric_data[self.KEY_IOU] = iou
            short_description += f"IoU: {iou:.3f}\n"

            # Accuracy
            acc = (sum(self.true_positives) + sum(self.true_negatives)) / sum(
                self.n_samples
            )
            metric_data[self.KEY_ACCURACY] = acc
            short_description += f"Accuracy: {acc:.3f}\n"

            # Precision
            tp_fp = sum(self.true_positives) + sum(self.false_positives)
            precision = sum(self.true_positives) / tp_fp if tp_fp != 0 else 1
            metric_data[self.KEY_PRECISION] = precision
            short_description += f"Precision: {precision:.3f}\n"

            # Recall
            tp_fn = sum(self.true_positives) + sum(self.false_negatives)
            recall = sum(self.true_positives) / tp_fn if tp_fn != 0 else 1
            metric_data[self.KEY_RECALL] = recall
            short_description += f"Recall: {recall:.3f}\n"

            # F1
            f1 = 2 * precision * recall / (precision + recall + 1e-8)
            metric_data[self.KEY_F1] = f1
            short_description += f"F1: {f1:.3f}\n"
        else:
            raise ValueError(f"Unsupported metric: {metric}")  # pragma: no cover

        return metric_data, short_description
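A minimal end-to-end sketch of the evaluator on a toy batch (inputs are illustrative; this assumes array_to_numpy accepts plain NumPy arrays):

import numpy as np

evaluator = BinaryEvaluator(threshold=0.5)
evaluator.process_batch(
    np.array([[0.9, 0.2, 0.8, 0.1]]),  # continuous scores
    np.array([[1, 0, 0, 1]]),          # binary groundtruth
)
metric_data, description = evaluator.evaluate(BinaryEvaluator.METRIC_BINARY)
# With tp=1, fp=1, tn=1, fn=1 this yields approximately:
# IoU ~ 0.333, Accuracy = 0.5, Precision = 0.5, Recall = 0.5, F1 ~ 0.5
print(description)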