"""CLI interface using PyTorch Lightning."""from__future__importannotationsimportloggingimportos.pathasospimporttorchfromabslimportapp# pylint: disable=no-name-in-modulefromlightning.fabric.utilities.exceptionsimportMisconfigurationExceptionfromlightning.pytorchimportCallbackfromtorch.utils.collect_envimportget_pretty_env_infofromvis4d.commonimportArgsTypefromvis4d.common.loggingimportdump_config,rank_zero_info,setup_loggerfromvis4d.common.utilimportset_tf32fromvis4d.configimportinstantiate_classesfromvis4d.config.typingimportExperimentConfigfromvis4d.engine.callbacksimportCheckpointCallbackfromvis4d.engine.flagimport_CKPT,_CONFIG,_GPUS,_RESUME,_SHOW_CONFIGfromvis4d.engine.parserimportpprints_configfromvis4d.pl.callbacksimportCallbackWrapper,LRSchedulerCallbackfromvis4d.pl.data_moduleimportDataModulefromvis4d.pl.trainerimportPLTrainerfromvis4d.pl.training_moduleimportTrainingModule
def main(argv: ArgsType) -> None:
    """Main entry point for the CLI.

    Drives a full experiment from an `ExperimentConfig`: sets up logging,
    dumps the resolved config, instantiates the model/data/callback classes,
    and dispatches to PyTorch Lightning's ``fit`` or ``test``.

    Example to run this script:
    >>> python -m vis4d.pl.run fit --config configs/faster_rcnn/faster_rcnn_coco.py

    Args:
        argv: Command-line arguments as passed through by ``absl.app.run``;
            ``argv[1]`` selects the mode (``fit`` or ``test``).

    Raises:
        ValueError: If the mode argument is neither ``fit`` nor ``test``.
        MisconfigurationException: If an entry of ``config.pl_callbacks`` is
            not a pytorch_lightning ``Callback``.
    """
    # Get config. NOTE: use an explicit raise rather than `assert` — asserts
    # are stripped under `python -O`, which would let bad modes through.
    mode = argv[1]
    if mode not in {"fit", "test"}:
        raise ValueError(f"Invalid mode: {mode}")
    config: ExperimentConfig = _CONFIG.value
    num_gpus = _GPUS.value

    # Setup logging: route both vis4d and PL logs to a timestamped file.
    logger_vis4d = logging.getLogger("vis4d")
    logger_pl = logging.getLogger("pytorch_lightning")
    log_file = osp.join(config.output_dir, f"log_{config.timestamp}.txt")
    setup_logger(logger_vis4d, log_file)
    setup_logger(logger_pl, log_file)

    # Dump config so the exact experiment settings are reproducible.
    config_file = osp.join(config.output_dir, f"config_{config.timestamp}.yaml")
    dump_config(config, config_file)

    rank_zero_info("Environment info: %s", get_pretty_env_info())

    # PyTorch Setting
    set_tf32(config.use_tf32, config.tf32_matmul_precision)
    torch.hub.set_dir(f"{config.work_dir}/.cache/torch/hub")

    # Setup device
    if num_gpus > 0:
        config.pl_trainer.accelerator = "gpu"
        config.pl_trainer.devices = num_gpus
    else:
        config.pl_trainer.accelerator = "cpu"
        config.pl_trainer.devices = 1

    trainer_args = instantiate_classes(config.pl_trainer).to_dict()

    if _SHOW_CONFIG.value:
        rank_zero_info(pprints_config(config))

    # Instantiate classes: training-only pieces are skipped in test mode.
    if mode == "fit":
        train_data_connector = instantiate_classes(config.train_data_connector)
        loss = instantiate_classes(config.loss)
    else:
        train_data_connector = None
        loss = None

    if config.test_data_connector is not None:
        test_data_connector = instantiate_classes(config.test_data_connector)
    else:
        test_data_connector = None

    # Callbacks
    callbacks: list[Callback] = []
    for cb in config.callbacks:
        callback = instantiate_classes(cb)
        # Skip checkpoint callback to use PL ModelCheckpoint
        if not isinstance(callback, CheckpointCallback):
            callbacks.append(CallbackWrapper(callback))

    if "pl_callbacks" in config:
        pl_callbacks = [instantiate_classes(cb) for cb in config.pl_callbacks]
    else:
        pl_callbacks = []

    for cb in pl_callbacks:
        if not isinstance(cb, Callback):
            raise MisconfigurationException(
                "Callback must be a subclass of pytorch_lightning Callback. "
                f"Provided callback: {cb} is not!"
            )
        callbacks.append(cb)

    # Add needed callbacks
    callbacks.append(LRSchedulerCallback())

    # Checkpoint path
    ckpt_path = _CKPT.value

    # Resume training: without an explicit checkpoint, fall back to the
    # last checkpoint written into this experiment's output directory.
    resume = _RESUME.value
    if resume:
        if ckpt_path is None:
            resume_ckpt_path = osp.join(
                config.output_dir, "checkpoints/last.ckpt"
            )
        else:
            resume_ckpt_path = ckpt_path
    else:
        resume_ckpt_path = None

    trainer = PLTrainer(callbacks=callbacks, **trainer_args)

    # Copy trainer_args so updating the logged hyper-parameters does not
    # mutate the dict the trainer was constructed from (they were aliased).
    hyper_params = dict(trainer_args)
    if config.get("params", None) is not None:
        hyper_params.update(config.params.to_dict())

    # When resuming, PL restores weights via resume_ckpt_path instead, so
    # the module itself gets no checkpoint to load.
    training_module = TrainingModule(
        config.model,
        config.optimizers,
        loss,
        train_data_connector,
        test_data_connector,
        hyper_params,
        config.seed,
        ckpt_path if not resume else None,
    )
    data_module = DataModule(config.data)

    if mode == "fit":
        trainer.fit(
            training_module, datamodule=data_module, ckpt_path=resume_ckpt_path
        )
    elif mode == "test":
        trainer.test(training_module, datamodule=data_module, verbose=False)
[docs]defentrypoint()->None:"""Entry point for the CLI."""app.run(main)