# Source code for vis4d.zoo.cc_3dt.model

"""CC-3DT model config."""

from __future__ import annotations

from ml_collections import ConfigDict, FieldReference

from vis4d.config import class_config
from vis4d.data.const import CommonKeys as K
from vis4d.engine.connectors import LossConnector, pred_key, remap_pred_keys
from vis4d.engine.loss_module import LossModule
from vis4d.model.track3d.cc_3dt import FasterRCNNCC3DT
from vis4d.op.box.anchor import AnchorGenerator
from vis4d.op.detect3d.qd_3dt import Box3DUncertaintyLoss
from vis4d.op.detect.faster_rcnn import FasterRCNNHead
from vis4d.op.detect.rcnn import RCNNHead, RCNNLoss
from vis4d.op.detect.rpn import RPNLoss
from vis4d.op.loss.common import smooth_l1_loss
from vis4d.op.track.qdtrack import QDTrackInstanceSimilarityLoss
from vis4d.state.track3d.cc_3dt import CC3DTrackGraph
from vis4d.zoo.base import get_callable_cfg
from vis4d.zoo.base.models.faster_rcnn import (
    get_default_rcnn_box_codec_cfg,
    get_default_rpn_box_codec_cfg,
)
from vis4d.zoo.base.models.qdtrack import CONN_ROI_LOSS_2D as _CONN_ROI_LOSS_2D
from vis4d.zoo.base.models.qdtrack import (
    CONN_TRACK_LOSS_2D as _CONN_TRACK_LOSS_2D,
)

# Root key under which the QDTrack part of the prediction output is read.
PRED_PREFIX = "qdtrack_out"

# RPN loss inputs: maps each loss argument to a prediction path nested
# below ``PRED_PREFIX``.
CONN_RPN_LOSS_2D = {
    loss_arg: pred_key(f"{PRED_PREFIX}.{sub_path}")
    for loss_arg, sub_path in (
        ("cls_outs", "detector_out.rpn.cls"),
        ("reg_outs", "detector_out.rpn.box"),
        ("target_boxes", "key_target_boxes"),
        ("images_hw", "key_images_hw"),
    )
}

# Reuse the QDTrack 2D RoI / tracking loss connectors, remapped with
# ``PRED_PREFIX`` via ``remap_pred_keys`` (presumably prefixing their
# prediction keys — confirm against remap_pred_keys).
CONN_ROI_LOSS_2D = remap_pred_keys(_CONN_ROI_LOSS_2D, PRED_PREFIX)
CONN_TRACK_LOSS_2D = remap_pred_keys(_CONN_TRACK_LOSS_2D, PRED_PREFIX)

# 3D box loss inputs; these prediction keys are used without the prefix.
CONN_DET_3D_LOSS = dict(
    pred=pred_key("detector_3d_out"),
    target=pred_key("detector_3d_target"),
    labels=pred_key("detector_3d_labels"),
)

# Data connector for 3D training. Most entries come from ``CommonKeys``;
# "keyframes" is referenced by its raw key name.
CONN_BBOX_3D_TRAIN = dict(
    images=K.images,
    images_hw=K.input_hw,
    intrinsics=K.intrinsics,
    boxes2d=K.boxes2d,
    boxes3d=K.boxes3d,
    boxes3d_classes=K.boxes3d_classes,
    boxes3d_track_ids=K.boxes3d_track_ids,
    keyframes="keyframes",
)


def get_cc_3dt_cfg(
    num_classes: int | FieldReference,
    basemodel: ConfigDict,
    pure_det: bool | FieldReference = False,
    motion_model: str | FieldReference = "KF3D",
    lstm_model: ConfigDict | None = None,
    fps: int | FieldReference = 2,
) -> tuple[ConfigDict, ConfigDict]:
    """Get CC-3DT model config.

    Args:
        num_classes (int | FieldReference): Number of classes.
        basemodel (ConfigDict): Base model config.
        pure_det (bool | FieldReference, optional): Whether to use pure
            detection mode. Defaults to False.
        motion_model (str | FieldReference, optional): Motion model passed
            to the ``CC3DTrackGraph`` config. Defaults to "KF3D".
        lstm_model (ConfigDict, optional): LSTM model config passed to the
            ``CC3DTrackGraph`` config. Defaults to None.
        fps (int | FieldReference, optional): FPS passed to the
            ``CC3DTrackGraph`` config. Defaults to 2.

    Returns:
        tuple[ConfigDict, ConfigDict]: The ``FasterRCNNCC3DT`` model config
            and the ``LossModule`` config that combines the RPN, RCNN,
            tracking and 3D box losses.
    """
    ######################################################
    ##                        MODEL                     ##
    ######################################################
    anchor_generator = class_config(
        AnchorGenerator,
        scales=[4, 8],
        ratios=[0.25, 0.5, 1.0, 2.0, 4.0],
        strides=[4, 8, 16, 32, 64],
    )

    roi_head = class_config(
        RCNNHead,
        num_shared_convs=4,
        num_classes=num_classes,
    )

    faster_rcnn_head = class_config(
        FasterRCNNHead,
        num_classes=num_classes,
        anchor_generator=anchor_generator,
        roi_head=roi_head,
    )

    track_graph = class_config(
        CC3DTrackGraph,
        motion_model=motion_model,
        lstm_model=lstm_model,
        fps=fps,
    )

    model = class_config(
        FasterRCNNCC3DT,
        num_classes=num_classes,
        basemodel=basemodel,
        faster_rcnn_head=faster_rcnn_head,
        track_graph=track_graph,
        pure_det=pure_det,
    )

    ######################################################
    ##                        LOSS                      ##
    ######################################################
    # Only the first element of each codec tuple (the encoder) is used by the
    # losses; the second element is ignored here.
    rpn_box_encoder, _ = get_default_rpn_box_codec_cfg()
    rcnn_box_encoder, _ = get_default_rcnn_box_codec_cfg()

    # The RPN loss reuses the same anchor generator config as the detection
    # head (presumably so loss targets are built on the predicted anchors).
    rpn_loss = class_config(
        RPNLoss,
        anchor_generator=anchor_generator,
        box_encoder=rpn_box_encoder,
        loss_bbox=get_callable_cfg(smooth_l1_loss, beta=1.0 / 9.0),
    )

    rcnn_loss = class_config(
        RCNNLoss,
        box_encoder=rcnn_box_encoder,
        num_classes=num_classes,
        loss_bbox=get_callable_cfg(smooth_l1_loss, beta=1.0 / 9.0),
    )

    track_loss = class_config(QDTrackInstanceSimilarityLoss)

    # Each loss is paired with a connector that maps model predictions to
    # the loss arguments; the RCNN loss is weighted by 5.0.
    loss = class_config(
        LossModule,
        losses=[
            {
                "loss": rpn_loss,
                "connector": class_config(
                    LossConnector, key_mapping=CONN_RPN_LOSS_2D
                ),
            },
            {
                "loss": rcnn_loss,
                "connector": class_config(
                    LossConnector, key_mapping=CONN_ROI_LOSS_2D
                ),
                "weight": 5.0,
            },
            {
                "loss": track_loss,
                "connector": class_config(
                    LossConnector, key_mapping=CONN_TRACK_LOSS_2D
                ),
            },
            {
                "loss": class_config(Box3DUncertaintyLoss),
                "connector": class_config(
                    LossConnector, key_mapping=CONN_DET_3D_LOSS
                ),
            },
        ],
    )

    return model, loss