Source code for vis4d.zoo.bdd100k.qdtrack.qdtrack_frcnn_r50_fpn_1x_bdd100k
# pylint: disable=duplicate-code"""QDTrack with Faster R-CNN on BDD100K."""from__future__importannotationsimportlightning.pytorchasplfromtorch.optimimportSGDfromtorch.optim.lr_schedulerimportLinearLR,MultiStepLRfromvis4d.configimportclass_configfromvis4d.config.typingimportExperimentConfig,ExperimentParametersfromvis4d.data.datasets.bdd100kimportbdd100k_track_mapfromvis4d.data.io.hdf5importHDF5Backendfromvis4d.engine.callbacksimportEvaluatorCallbackfromvis4d.engine.connectorsimportCallbackConnector,DataConnectorfromvis4d.eval.bdd100kimportBDD100KTrackEvaluatorfromvis4d.op.baseimportResNetfromvis4d.zoo.baseimport(get_default_callbacks_cfg,get_default_cfg,get_default_pl_trainer_cfg,get_lr_scheduler_cfg,get_optimizer_cfg,)fromvis4d.zoo.base.datasets.bdd100kimport(CONN_BDD100K_TRACK_EVAL,get_bdd100k_track_cfg,)fromvis4d.zoo.base.models.qdtrackimport(CONN_BBOX_2D_TEST,CONN_BBOX_2D_TRAIN,get_qdtrack_cfg,)
[docs]defget_config()->ExperimentConfig:"""Returns the config dict for qdtrack on bdd100k. Returns: ExperimentConfig: The configuration """######################################################## General Config ########################################################config=get_default_cfg(exp_name="qdtrack_frcnn_r50_fpn_1x_bdd100k")# High level hyper parametersparams=ExperimentParameters()params.samples_per_gpu=4# batch size = 4 GPUs * 4 samples per GPU = 16params.workers_per_gpu=4params.lr=0.02params.num_epochs=12config.params=params######################################################## Datasets with augmentations ########################################################data_backend=class_config(HDF5Backend)config.data=get_bdd100k_track_cfg(data_backend=data_backend,samples_per_gpu=params.samples_per_gpu,workers_per_gpu=params.workers_per_gpu,)######################################################## MODEL ########################################################num_classes=len(bdd100k_track_map)basemodel=class_config(ResNet,resnet_name="resnet50",pretrained=True,trainable_layers=3)config.model,config.loss=get_qdtrack_cfg(num_classes=num_classes,basemodel=basemodel)######################################################## OPTIMIZERS ########################################################config.optimizers=[get_optimizer_cfg(optimizer=class_config(SGD,lr=params.lr,momentum=0.9,weight_decay=0.0001),lr_schedulers=[get_lr_scheduler_cfg(class_config(LinearLR,start_factor=0.1,total_iters=1000),end=1000,epoch_based=False,),get_lr_scheduler_cfg(class_config(MultiStepLR,milestones=[8,11],gamma=0.1),),],)]######################################################## DATA CONNECTOR ########################################################config.train_data_connector=class_config(DataConnector,key_mapping=CONN_BBOX_2D_TRAIN)config.test_data_connector=class_config(DataConnector,key_mapping=CONN_BBOX_2D_TEST)######################################################## CALLBACKS ######################################################### Logger and Checkpointcallbacks=get_default_callbacks_cfg(config.output_dir)# Evaluatorcallbacks.append(class_config(EvaluatorCallback,evaluator=class_config(BDD100KTrackEvaluator,annotation_path="data/bdd100k/labels/box_track_20/val/",),test_connector=class_config(CallbackConnector,key_mapping=CONN_BDD100K_TRACK_EVAL),))config.callbacks=callbacks######################################################## PL CLI ######################################################### PL Trainer argspl_trainer=get_default_pl_trainer_cfg(config)pl_trainer.max_epochs=params.num_epochsconfig.pl_trainer=pl_trainerpl_trainer.gradient_clip_val=35# PL Callbackspl_callbacks:list[pl.callbacks.Callback]=[]config.pl_callbacks=pl_callbacksreturnconfig.value_mode()