"""This module contains utilities for callbacks."""from__future__importannotationsimportosfromtorchimportnnfromvis4d.commonimportArgsTypefromvis4d.common.distributedimportbroadcastfromvis4d.data.typingimportDictDatafromvis4d.engine.loss_moduleimportLossModulefromvis4d.vis.baseimportVisualizerfrom.baseimportCallbackfrom.trainer_stateimportTrainerState
class VisualizerCallback(Callback):
    """Callback for model visualization.

    After a batch, the outputs are handed to the wrapped visualizer, which
    is then optionally asked to show the result and/or save it to disk, and
    finally reset. Test batches are always visualized; training batches only
    when ``visualize_train`` is set.
    """

    def __init__(
        self,
        *args: ArgsType,
        visualizer: Visualizer,
        visualize_train: bool = False,
        show: bool = False,
        save_to_disk: bool = True,
        save_prefix: str | None = None,
        **kwargs: ArgsType,
    ) -> None:
        """Init callback.

        Args:
            *args (ArgsType): Positional arguments forwarded to the base
                callback.
            visualizer (Visualizer): Visualizer.
            visualize_train (bool): If the training data should be
                visualized. Defaults to False.
            show (bool): If the visualizations should be shown. Defaults to
                False.
            save_to_disk (bool): If the visualizations should be saved to
                disk. Defaults to True.
            save_prefix (str | None): Output directory for saving the
                visualizations. Required when save_to_disk is True; the
                visualizations go to ``<save_prefix>/vis``.
            **kwargs (ArgsType): Keyword arguments forwarded to the base
                callback.

        Raises:
            AssertionError: If save_to_disk is True but save_prefix is None.
        """
        super().__init__(*args, **kwargs)
        self.visualizer = visualizer
        self.visualize_train = visualize_train
        self.save_prefix = save_prefix
        self.show = show
        self.save_to_disk = save_to_disk

        if self.save_to_disk:
            # Raised explicitly instead of using an ``assert`` statement so
            # the check is not stripped under ``python -O``. The exception
            # type and message are kept for backward compatibility.
            if save_prefix is None:
                raise AssertionError(
                    "If save_to_disk is True, save_prefix must be provided."
                )
            self.output_dir = f"{self.save_prefix}/vis"

    def setup(self) -> None:  # pragma: no cover
        """Setup callback.

        When saving to disk, the output directory is replaced by the value
        returned from ``broadcast``.
        """
        if self.save_to_disk:
            # NOTE(review): presumably this makes every process use the same
            # output directory in distributed runs - confirm against
            # vis4d.common.distributed.broadcast.
            self.output_dir = broadcast(self.output_dir)

    def _flush_visualizations(self, cur_iter: int, split: str) -> None:
        """Show and/or save the processed visualizations, then reset.

        Args:
            cur_iter (int): Current iteration passed on to the visualizer.
            split (str): Sub-folder of the output directory to save into,
                e.g. "train" or "test".
        """
        if self.show:
            self.visualizer.show(cur_iter=cur_iter)

        if self.save_to_disk:
            output_folder = f"{self.output_dir}/{split}"
            os.makedirs(output_folder, exist_ok=True)
            self.visualizer.save_to_disk(
                cur_iter=cur_iter,
                output_folder=output_folder,
            )

        self.visualizer.reset()

    def on_train_batch_end(
        self,
        trainer_state: TrainerState,
        model: nn.Module,
        loss_module: LossModule,
        outputs: DictData,
        batch: DictData,
        batch_idx: int,
    ) -> None:
        """Hook to run at the end of a training batch.

        Does nothing unless ``visualize_train`` is set.

        Args:
            trainer_state (TrainerState): Trainer state (unused here).
            model (nn.Module): Model (unused here).
            loss_module (LossModule): Loss module (unused here).
            outputs (DictData): Model outputs for the batch.
            batch (DictData): Input batch.
            batch_idx (int): Zero-based batch index; the visualizer receives
                ``batch_idx + 1`` as the current iteration.
        """
        if not self.visualize_train:
            return

        cur_iter = batch_idx + 1
        self.visualizer.process(
            cur_iter=cur_iter,
            **self.get_train_callback_inputs(outputs, batch),
        )
        self._flush_visualizations(cur_iter, "train")

    def on_test_batch_end(
        self,
        trainer_state: TrainerState,
        model: nn.Module,
        outputs: DictData,
        batch: DictData,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Hook to run at the end of a testing batch.

        Args:
            trainer_state (TrainerState): Trainer state (unused here).
            model (nn.Module): Model (unused here).
            outputs (DictData): Model outputs for the batch.
            batch (DictData): Input batch.
            batch_idx (int): Zero-based batch index; the visualizer receives
                ``batch_idx + 1`` as the current iteration.
            dataloader_idx (int): Index of the dataloader (unused here).
                Defaults to 0.
        """
        cur_iter = batch_idx + 1
        self.visualizer.process(
            cur_iter=cur_iter,
            **self.get_test_callback_inputs(outputs, batch),
        )
        self._flush_visualizations(cur_iter, "test")