"""This module contains utilities for callbacks."""from__future__importannotationsimportosfromtorchimportnnfromvis4d.commonimportArgsType,MetricLogsfromvis4d.common.distributedimport(all_gather_object_cpu,broadcast,rank_zero_only,synchronize,)fromvis4d.common.loggingimportrank_zero_infofromvis4d.data.typingimportDictDatafromvis4d.eval.baseimportEvaluatorfrom.baseimportCallbackfrom.trainer_stateimportTrainerState

class EvaluatorCallback(Callback):
    """Callback for model evaluation.

    Args:
        evaluator (Evaluator): Evaluator.
        metrics_to_eval (list[str], optional): Metrics to evaluate. If None,
            all metrics in the evaluator will be evaluated. Defaults to None.
        save_predictions (bool): If the predictions should be saved. Defaults
            to False.
        save_prefix (str, optional): Output directory for saving the
            evaluation results. Defaults to None.
    """

    def __init__(
        self,
        *args: ArgsType,
        evaluator: Evaluator,
        metrics_to_eval: list[str] | None = None,
        save_predictions: bool = False,
        save_prefix: None | str = None,
        **kwargs: ArgsType,
    ) -> None:
        """Init callback."""
        super().__init__(*args, **kwargs)
        self.evaluator = evaluator
        self.save_predictions = save_predictions
        self.metrics_to_eval = metrics_to_eval or self.evaluator.metrics

        if self.save_predictions:
            assert (
                save_prefix is not None
            ), "If save_predictions is True, save_prefix must be provided."
            self.output_dir = save_prefix
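
    # Usage sketch (illustrative only): `MyEvaluator` is a hypothetical
    # subclass of `vis4d.eval.base.Evaluator`; the callback is constructed
    # once and later driven by the trainer's test hooks.
    #
    #     callback = EvaluatorCallback(
    #         evaluator=MyEvaluator(),
    #         metrics_to_eval=None,  # None evaluates all evaluator metrics
    #         save_predictions=True,
    #         save_prefix="outputs/eval",  # required if save_predictions=True
    #     )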

    def setup(self) -> None:  # pragma: no cover
        """Setup callback."""
        if self.save_predictions:
            self.output_dir = broadcast(self.output_dir)
            for metric in self.metrics_to_eval:
                output_dir = os.path.join(self.output_dir, metric)
                os.makedirs(output_dir, exist_ok=True)

        self.evaluator.reset()

    def on_test_batch_end(
        self,
        trainer_state: TrainerState,
        model: nn.Module,
        outputs: DictData,
        batch: DictData,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to run at the end of a testing batch."""
        self.evaluator.process_batch(
            **self.get_test_callback_inputs(outputs, batch)
        )

        for metric in self.metrics_to_eval:
            # Save output predictions in current batch.
            if self.save_predictions:
                output_dir = os.path.join(self.output_dir, metric)
                self.evaluator.save_batch(metric, output_dir)

    def on_test_epoch_end(
        self, trainer_state: TrainerState, model: nn.Module
    ) -> None | MetricLogs:
        """Hook to run at the end of a testing epoch."""
        self.evaluator.gather(all_gather_object_cpu)
        synchronize()

        log_dict = self.evaluate()
        log_dict = broadcast(log_dict)

        self.evaluator.reset()

        return log_dict

    @rank_zero_only
    def evaluate(self) -> MetricLogs:
        """Evaluate the performance after processing all input/output pairs.

        Returns:
            MetricLogs: A dictionary containing the evaluation results. The
                keys are formatted as {metric_name}/{key_name}, and the
                values are the corresponding evaluated values.
        """
        rank_zero_info("Running evaluator %s...", str(self.evaluator))

        self.evaluator.process()

        log_dict = {}
        for metric in self.metrics_to_eval:
            # Save output predictions. This is done here instead of
            # on_test_batch_end because the evaluator may not have processed
            # all batches yet.
            if self.save_predictions:
                output_dir = os.path.join(self.output_dir, metric)
                self.evaluator.save(metric, output_dir)

            # Evaluate metric
            metric_dict, metric_str = self.evaluator.evaluate(metric)
            for k, v in metric_dict.items():
                log_k = metric + "/" + k
                rank_zero_info("%s: %.4f", log_k, v)
                log_dict[f"{metric}/{k}"] = v
            rank_zero_info("Showing results for metric: %s", metric)
            rank_zero_info(metric_str)

        return log_dict
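
# Minimal sketch of how a test loop might drive this callback (assumed flow,
# not the actual vis4d trainer; `trainer_state`, `model` and `test_loader`
# are placeholders):
#
#     callback.setup()
#     for batch_idx, batch in enumerate(test_loader):
#         outputs = model(**batch)
#         callback.on_test_batch_end(
#             trainer_state, model, outputs, batch, batch_idx
#         )
#     log_dict = callback.on_test_epoch_end(trainer_state, model)
#     # log_dict maps "{metric}/{key}" keys to the evaluated values.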