# Source code for vis4d.zoo.qdtrack.qdtrack_frcnn_r50_fpn_augs_1x_bdd100k

# pylint: disable=duplicate-code
"""QDTrack with Faster R-CNN on BDD100K."""
from __future__ import annotations

from lightning.pytorch.callbacks import ModelCheckpoint
from torch.optim import SGD
from torch.optim.lr_scheduler import LinearLR, MultiStepLR

from vis4d.config import class_config
from vis4d.config.typing import ExperimentConfig, ExperimentParameters
from vis4d.data.datasets.bdd100k import bdd100k_track_map
from vis4d.data.io.hdf5 import HDF5Backend
from vis4d.engine.callbacks import (
    EvaluatorCallback,
    VisualizerCallback,
    YOLOXModeSwitchCallback,
)
from vis4d.engine.connectors import CallbackConnector, DataConnector
from vis4d.eval.bdd100k import BDD100KTrackEvaluator
from vis4d.op.base import ResNet
from vis4d.vis.image import BoundingBoxVisualizer
from vis4d.zoo.base import (
    get_default_callbacks_cfg,
    get_default_cfg,
    get_default_pl_trainer_cfg,
    get_lr_scheduler_cfg,
    get_optimizer_cfg,
)
from vis4d.zoo.base.data_connectors import CONN_BBOX_2D_TRACK_VIS
from vis4d.zoo.base.datasets.bdd100k import CONN_BDD100K_TRACK_EVAL
from vis4d.zoo.base.models.qdtrack import (
    CONN_BBOX_2D_TEST,
    CONN_BBOX_2D_TRAIN,
    get_qdtrack_cfg,
)
from vis4d.zoo.qdtrack.data_yolox import get_bdd100k_track_cfg


def get_config() -> ExperimentConfig:
    """Create the experiment config for QDTrack with Faster R-CNN on BDD100K.

    The experiment uses a ResNet-50 backbone and HDF5-backed BDD100K
    tracking data with image size (720, 1280). Training runs for 12 epochs
    with SGD, using a 1000-iteration linear warm-up followed by a step decay
    (x0.1 at milestones 8 and 11). The augmentation mode is switched at
    epoch 9.

    Returns:
        ExperimentConfig: The configuration
    """
    # ---------------------------------------------------------------- #
    # General config and high level hyper parameters
    # ---------------------------------------------------------------- #
    config = get_default_cfg(exp_name="qdtrack_frcnn_r50_fpn_augs_1x_bdd100k")

    params = ExperimentParameters()
    # Effective batch size: 4 GPUs x 4 samples per GPU = 16.
    params.samples_per_gpu = 4
    params.workers_per_gpu = 8
    params.lr = 0.02
    params.num_epochs = 12
    config.params = params

    # ---------------------------------------------------------------- #
    # Datasets with augmentations
    # ---------------------------------------------------------------- #
    config.data = get_bdd100k_track_cfg(
        data_backend=class_config(HDF5Backend),
        image_size=(720, 1280),
        normalize_image=True,
        samples_per_gpu=params.samples_per_gpu,
        workers_per_gpu=params.workers_per_gpu,
    )

    # ---------------------------------------------------------------- #
    # Model and loss
    # ---------------------------------------------------------------- #
    backbone = class_config(
        ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=3
    )
    config.model, config.loss = get_qdtrack_cfg(
        num_classes=len(bdd100k_track_map), basemodel=backbone
    )

    # ---------------------------------------------------------------- #
    # Optimizer and learning-rate schedule
    # ---------------------------------------------------------------- #
    sgd = class_config(SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001)
    # Iteration-based linear warm-up over the first 1000 steps.
    warmup = get_lr_scheduler_cfg(
        class_config(LinearLR, start_factor=0.1, total_iters=1000),
        end=1000,
        epoch_based=False,
    )
    # Multiply the learning rate by 0.1 at milestones 8 and 11.
    step_decay = get_lr_scheduler_cfg(
        class_config(MultiStepLR, milestones=[8, 11], gamma=0.1),
    )
    config.optimizers = [
        get_optimizer_cfg(optimizer=sgd, lr_schedulers=[warmup, step_decay])
    ]

    # ---------------------------------------------------------------- #
    # Data connectors
    # ---------------------------------------------------------------- #
    config.train_data_connector = class_config(
        DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
    )
    config.test_data_connector = class_config(
        DataConnector, key_mapping=CONN_BBOX_2D_TEST
    )

    # ---------------------------------------------------------------- #
    # Callbacks
    # ---------------------------------------------------------------- #
    # Shared logger and checkpoint callbacks come first.
    callbacks = get_default_callbacks_cfg(config.output_dir)
    callbacks.extend(
        [
            # Switch the (strong) augmentation mode at epoch 9.
            class_config(YOLOXModeSwitchCallback, switch_epoch=9),
            # Bounding box visualization during testing.
            class_config(
                VisualizerCallback,
                visualizer=class_config(
                    BoundingBoxVisualizer, vis_freq=500, image_mode="BGR"
                ),
                save_prefix=config.output_dir,
                test_connector=class_config(
                    CallbackConnector, key_mapping=CONN_BBOX_2D_TRACK_VIS
                ),
            ),
            # BDD100K tracking evaluation on the validation labels.
            class_config(
                EvaluatorCallback,
                evaluator=class_config(
                    BDD100KTrackEvaluator,
                    annotation_path="data/bdd100k/labels/box_track_20/val/",
                ),
                test_connector=class_config(
                    CallbackConnector, key_mapping=CONN_BDD100K_TRACK_EVAL
                ),
            ),
        ]
    )
    config.callbacks = callbacks

    # ---------------------------------------------------------------- #
    # PyTorch Lightning trainer
    # ---------------------------------------------------------------- #
    pl_trainer = get_default_pl_trainer_cfg(config)
    pl_trainer.max_epochs = params.num_epochs
    # Checkpoint every epoch; monitoring "step" with mode="max" keeps the
    # four checkpoints with the highest step count, plus the last one.
    pl_trainer.checkpoint_callback = class_config(
        ModelCheckpoint,
        dirpath=config.get_ref("output_dir") + "/checkpoints",
        verbose=True,
        save_last=True,
        save_on_train_epoch_end=True,
        every_n_epochs=1,
        save_top_k=4,
        mode="max",
        monitor="step",
    )
    pl_trainer.wandb = True
    pl_trainer.gradient_clip_val = 35
    config.pl_trainer = pl_trainer

    return config.value_mode()