"""This module contains utilities for callbacks."""from__future__importannotationsfromcollectionsimportdefaultdictfromtorchimportnnfromvis4d.commonimportArgsType,MetricLogsfromvis4d.common.loggingimportrank_zero_infofromvis4d.common.progressimportcompose_log_strfromvis4d.common.timeimportTimerfromvis4d.data.typingimportDictDatafromvis4d.engine.loss_moduleimportLossModulefrom.baseimportCallbackfrom.trainer_stateimportTrainerState


class LoggingCallback(Callback):
    """Callback for logging."""

    def __init__(
        self, *args: ArgsType, refresh_rate: int = 50, **kwargs: ArgsType
    ) -> None:
        """Init callback."""
        super().__init__(*args, **kwargs)
        self._refresh_rate = refresh_rate
        self._metrics: dict[str, list[float]] = defaultdict(list)
        self.train_timer = Timer()
        self.test_timer = Timer()
        self.last_step = 0

    def on_train_batch_start(
        self,
        trainer_state: TrainerState,
        model: nn.Module,
        loss_module: LossModule,
        batch: DictData,
        batch_idx: int,
    ) -> None:
        """Hook to run at the start of a training batch."""
        if not self.epoch_based and self.train_timer.paused:
            self.train_timer.resume()

    def on_train_epoch_start(
        self,
        trainer_state: TrainerState,
        model: nn.Module,
        loss_module: LossModule,
    ) -> None:
        """Hook to run at the start of a training epoch."""
        if self.epoch_based:
            self.train_timer.reset()
            self.last_step = 0
            self._metrics.clear()
        elif trainer_state["global_step"] == 0:
            self.train_timer.reset()

    def on_train_batch_end(
        self,
        trainer_state: TrainerState,
        model: nn.Module,
        loss_module: LossModule,
        outputs: DictData,
        batch: DictData,
        batch_idx: int,
    ) -> None | MetricLogs:
        """Hook to run at the end of a training batch."""
        if "metrics" in trainer_state:
            for k, v in trainer_state["metrics"].items():
                self._metrics[k].append(v)

        if self.epoch_based:
            cur_iter = batch_idx + 1
            total_iters = (
                trainer_state["num_train_batches"]
                if trainer_state["num_train_batches"] is not None
                else -1
            )
        else:
            # After optimizer.step(), global_step is already incremented by 1.
            cur_iter = trainer_state["global_step"]
            total_iters = trainer_state["num_steps"]

        log_dict: None | MetricLogs = None
        if cur_iter % self._refresh_rate == 0 and cur_iter != self.last_step:
            prefix = (
                f"Epoch {trainer_state['current_epoch'] + 1}"
                if self.epoch_based
                else "Iter"
            )
            log_dict = {
                k: sum(v) / len(v) if len(v) > 0 else float("NaN")
                for k, v in self._metrics.items()
            }
            rank_zero_info(
                compose_log_str(
                    prefix, cur_iter, total_iters, self.train_timer, log_dict
                )
            )
            self._metrics.clear()
            self.last_step = cur_iter
        return log_dict

    def on_test_epoch_start(
        self, trainer_state: TrainerState, model: nn.Module
    ) -> None:
        """Hook to run at the start of a testing epoch."""
        self.test_timer.reset()
        if not self.epoch_based:
            self.train_timer.pause()

    def on_test_batch_end(
        self,
        trainer_state: TrainerState,
        model: nn.Module,
        outputs: DictData,
        batch: DictData,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to run at the end of a testing batch."""
        cur_iter = batch_idx + 1
        total_iters = (
            trainer_state["num_test_batches"][dataloader_idx]
            if trainer_state["num_test_batches"] is not None
            else -1
        )
        if cur_iter % self._refresh_rate == 0:
            rank_zero_info(
                compose_log_str(
                    "Testing", cur_iter, total_iters, self.test_timer
                )
            )
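

# Usage sketch (illustrative, not part of this module; the surrounding
# trainer API is an assumption): the callback is typically instantiated with
# a refresh rate and handed to the engine together with the other callbacks,
# e.g.
#
#     callbacks = [LoggingCallback(refresh_rate=10)]
#
# ``refresh_rate`` controls how many iterations pass between log lines, while
# the ``epoch_based`` flag inherited from ``Callback`` switches the progress
# prefix between "Epoch <n>" and "Iter".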