Source code for vis4d.zoo.shift.mask_rcnn.mask_rcnn_r50_12e_shift

# pylint: disable=duplicate-code
"""Mask RCNN SHIFT training example."""
from __future__ import annotations

import lightning.pytorch as pl
from torch.optim import SGD
from torch.optim.lr_scheduler import LinearLR, MultiStepLR

from vis4d.config import FieldConfigDict, class_config
from vis4d.data.io.hdf5 import HDF5Backend
from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
from vis4d.engine.connectors import CallbackConnector, DataConnector
from vis4d.eval.shift import SHIFTDetectEvaluator
from vis4d.op.base import ResNet
from vis4d.vis.image import SegMaskVisualizer
from vis4d.zoo.base import (
    get_default_callbacks_cfg,
    get_default_cfg,
    get_default_pl_trainer_cfg,
    get_lr_scheduler_cfg,
    get_optimizer_cfg,
)
from vis4d.zoo.base.data_connectors import (
    CONN_BBOX_2D_TEST,
    CONN_BBOX_2D_TRAIN,
    CONN_INS_MASK_2D_VIS,
)
from vis4d.zoo.base.datasets.shift import (
    CONN_SHIFT_INS_EVAL,
    get_shift_instance_seg_config,
)
from vis4d.zoo.base.models.mask_rcnn import get_mask_rcnn_cfg


[docs] def get_config() -> FieldConfigDict: """Returns the Faster-RCNN config dict for the SHIFT detection task. Returns: FieldConfigDict: The configuration """ ###################################################### ## General Config ## ###################################################### config = get_default_cfg(exp_name="mask_rcnn_r50_12e_shift") # High level hyper parameters params = FieldConfigDict() params.samples_per_gpu = 2 params.workers_per_gpu = 2 params.lr = 0.02 params.num_epochs = 12 params.num_classes = 6 config.params = params ###################################################### ## Datasets with augmentations ## ###################################################### data_root = "data/shift/" views_to_load = ["front"] train_split = "train" test_split = "val" domain_attr = [{"weather_coarse": "clear", "timeofday_coarse": "daytime"}] data_backend = class_config(HDF5Backend) config.data = get_shift_instance_seg_config( data_root=data_root, train_split=train_split, test_split=test_split, train_views_to_load=views_to_load, test_views_to_load=views_to_load, train_attributes_to_load=domain_attr, test_attributes_to_load=domain_attr, data_backend=data_backend, samples_per_gpu=params.samples_per_gpu, workers_per_gpu=params.workers_per_gpu, ) ###################################################### ## MODEL & LOSS ## ###################################################### basemodel = class_config( ResNet, resnet_name="resnet50", pretrained=True, trainable_layers=4 ) config.model, config.loss = get_mask_rcnn_cfg( num_classes=params.num_classes, basemodel=basemodel, no_overlap=True, ) ###################################################### ## OPTIMIZERS ## ###################################################### config.optimizers = [ get_optimizer_cfg( optimizer=class_config( SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001 ), lr_schedulers=[ get_lr_scheduler_cfg( class_config( LinearLR, start_factor=0.001, total_iters=500 ), end=500, epoch_based=False, ), get_lr_scheduler_cfg( class_config(MultiStepLR, milestones=[8, 11], gamma=0.1), ), ], ) ] ###################################################### ## DATA CONNECTOR ## ###################################################### config.train_data_connector = class_config( DataConnector, key_mapping=CONN_BBOX_2D_TRAIN ) config.test_data_connector = class_config( DataConnector, key_mapping=CONN_BBOX_2D_TEST ) ###################################################### ## CALLBACKS ## ###################################################### # Logger and Checkpoint callbacks = get_default_callbacks_cfg(config.output_dir) # Visualizer callbacks.append( class_config( VisualizerCallback, visualizer=class_config(SegMaskVisualizer, vis_freq=25), save_prefix=config.output_dir, test_connector=class_config( CallbackConnector, key_mapping=CONN_INS_MASK_2D_VIS ), ) ) # Evaluator callbacks.append( class_config( EvaluatorCallback, evaluator=class_config( SHIFTDetectEvaluator, annotation_path=( f"{data_root}/discrete/images/val/front/det_insseg_2d.json" ), attributes_to_load=domain_attr, ), test_connector=class_config( CallbackConnector, key_mapping=CONN_SHIFT_INS_EVAL ), metrics_to_eval=[ SHIFTDetectEvaluator.METRICS_DET, SHIFTDetectEvaluator.METRICS_INS_SEG, ], ) ) config.callbacks = callbacks ###################################################### ## PL CLI ## ###################################################### # PL Trainer args pl_trainer = get_default_pl_trainer_cfg(config) pl_trainer.max_epochs = params.num_epochs config.pl_trainer = pl_trainer # PL Callbacks pl_callbacks: list[pl.callbacks.Callback] = [] config.pl_callbacks = pl_callbacks return config.value_mode()