"""This module contains utilities for callbacks."""from__future__importannotationsimportosimporttorchfromtorchimportnnfromvis4d.commonimportArgsTypefromvis4d.common.distributedimportbroadcast,rank_zero_onlyfromvis4d.data.typingimportDictDatafromvis4d.engine.callbacks.trainer_stateimportTrainerStatefromvis4d.engine.loss_moduleimportLossModulefrom.baseimportCallbackfrom.trainer_stateimportTrainerState
class CheckpointCallback(Callback):
    """Callback for model checkpointing.

    Triggers a checkpoint save every ``checkpoint_period`` training steps or
    epochs. Which of the two units applies is decided by ``self.epoch_based``.

    NOTE(review): ``self.epoch_based`` and ``self._save_checkpoint`` are used
    below but are not defined in the visible part of this file; presumably
    they are provided by ``Callback`` or elsewhere in this class - confirm.
    """

    def __init__(
        self,
        *args: ArgsType,
        save_prefix: str,
        checkpoint_period: int = 1,
        **kwargs: ArgsType,
    ) -> None:
        """Init callback.

        Args:
            *args (ArgsType): Positional arguments forwarded to the parent
                callback.
            save_prefix (str): Prefix of checkpoint path for saving.
            checkpoint_period (int, optional): Checkpoint period. Must be
                positive. Defaults to 1.
            **kwargs (ArgsType): Keyword arguments forwarded to the parent
                callback.

        Raises:
            ValueError: If checkpoint_period is not positive.
        """
        # Fail fast: a period of zero would otherwise only surface as a
        # ZeroDivisionError from the modulo inside the training hooks.
        if checkpoint_period <= 0:
            raise ValueError(
                "checkpoint_period must be a positive integer, "
                f"got {checkpoint_period}."
            )
        super().__init__(*args, **kwargs)
        self.output_dir = f"{save_prefix}/checkpoints"
        self.checkpoint_period = checkpoint_period

    def setup(self) -> None:  # pragma: no cover
        """Setup callback.

        Replaces the output directory with the value returned by
        ``broadcast`` and creates that directory if it does not exist yet.
        """
        # NOTE(review): presumably broadcast() makes every process use the
        # same output directory - confirm against vis4d.common.distributed.
        self.output_dir = broadcast(self.output_dir)
        os.makedirs(self.output_dir, exist_ok=True)

    def on_train_batch_end(
        self,
        trainer_state: TrainerState,
        model: nn.Module,
        loss_module: LossModule,
        outputs: DictData,
        batch: DictData,
        batch_idx: int,
    ) -> None:
        """Hook to run at the end of a training batch.

        Saves a checkpoint when the callback is not epoch based and the
        global step is a multiple of the checkpoint period.

        Args:
            trainer_state (TrainerState): Trainer state; its "global_step"
                entry is read here.
            model (nn.Module): Model passed on to the checkpoint saving.
            loss_module (LossModule): Unused in this hook.
            outputs (DictData): Unused in this hook.
            batch (DictData): Unused in this hook.
            batch_idx (int): Unused in this hook.
        """
        if self.epoch_based:
            return
        if trainer_state["global_step"] % self.checkpoint_period == 0:
            self._save_checkpoint(trainer_state, model)

    def on_train_epoch_end(
        self,
        trainer_state: TrainerState,
        model: nn.Module,
        loss_module: LossModule,
    ) -> None:
        """Hook to run at the end of a training epoch.

        Saves a checkpoint when the callback is epoch based and the current
        epoch plus one is a multiple of the checkpoint period.

        Args:
            trainer_state (TrainerState): Trainer state; its "current_epoch"
                entry is read here.
            model (nn.Module): Model passed on to the checkpoint saving.
            loss_module (LossModule): Unused in this hook.
        """
        if not self.epoch_based:
            return
        # NOTE(review): the +1 suggests "current_epoch" is zero-based, so
        # this is the number of completed epochs - confirm.
        finished_epochs = trainer_state["current_epoch"] + 1
        if finished_epochs % self.checkpoint_period == 0:
            self._save_checkpoint(trainer_state, model)