
"""CC-3DT model implementation.

This file composes the operations associated with
`CC-3DT <https://arxiv.org/abs/2212.01247>`_ into the full model
implementation.
"""

from __future__ import annotations

from collections.abc import Sequence
from typing import NamedTuple

import torch
from torch import Tensor, nn

from vis4d.data.const import AxisMode
from vis4d.model.track.qdtrack import FasterRCNNQDTrackOut
from vis4d.op.base import BaseModel, ResNet
from vis4d.op.box.anchor import AnchorGenerator
from vis4d.op.box.box2d import bbox_area, bbox_clip
from vis4d.op.box.box3d import boxes3d_to_corners, transform_boxes3d
from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder
from vis4d.op.detect3d.qd_3dt import QD3DTBBox3DHead, RoI2Det3D
from vis4d.op.detect3d.util import bev_3d_nms
from vis4d.op.detect.faster_rcnn import FasterRCNNHead
from vis4d.op.detect.rcnn import RCNNHead, RoI2Det
from vis4d.op.fpp import FPN
from vis4d.op.geometry.projection import project_points
from vis4d.op.geometry.rotation import (
    quaternion_to_matrix,
    rotation_matrix_yaw,
)
from vis4d.op.geometry.transform import inverse_rigid_transform
from vis4d.op.track3d.cc_3dt import (
    CC3DTrackAssociation,
    cam_to_global,
    get_track_3d_out,
)
from vis4d.op.track3d.common import Track3DOut
from vis4d.op.track.qdtrack import QDTrackHead
from vis4d.state.track3d.cc_3dt import CC3DTrackGraph

from ..track.util import split_key_ref_indices


class FasterRCNNCC3DTOut(NamedTuple):
    """Output of CC-3DT model with Faster R-CNN detector."""

    detector_3d_out: Tensor
    detector_3d_target: Tensor
    detector_3d_labels: Tensor
    qdtrack_out: FasterRCNNQDTrackOut


class FasterRCNNCC3DT(nn.Module):
    """CC-3DT with Faster R-CNN detector."""

    def __init__(
        self,
        num_classes: int,
        basemodel: BaseModel | None = None,
        faster_rcnn_head: FasterRCNNHead | None = None,
        rcnn_box_decoder: DeltaXYWHBBoxDecoder | None = None,
        qdtrack_head: QDTrackHead | None = None,
        track_graph: CC3DTrackGraph | None = None,
        pure_det: bool = False,
    ) -> None:
        """Creates an instance of the class.

        Args:
            num_classes (int): Number of object categories.
            basemodel (BaseModel, optional): Base model network. Defaults to
                None. If None, will use ResNet50.
            faster_rcnn_head (FasterRCNNHead, optional): Faster R-CNN head.
                Defaults to None. If None, will use default FasterRCNNHead.
            rcnn_box_decoder (DeltaXYWHBBoxDecoder, optional): Decoder for
                RCNN bounding boxes. Defaults to None.
            qdtrack_head (QDTrackHead, optional): QDTrack head. Defaults to
                None. If None, will use default QDTrackHead.
            track_graph (CC3DTrackGraph, optional): Track graph. Defaults to
                None. If None, will use default CC3DTrackGraph.
            pure_det (bool, optional): Whether to use pure detection.
                Defaults to False.
        """
        super().__init__()
        self.basemodel = (
            ResNet(
                resnet_name="resnet50", pretrained=True, trainable_layers=3
            )
            if basemodel is None
            else basemodel
        )

        self.fpn = FPN(self.basemodel.out_channels[2:], 256)

        if faster_rcnn_head is None:
            anchor_generator = AnchorGenerator(
                scales=[4, 8],
                ratios=[0.25, 0.5, 1.0, 2.0, 4.0],
                strides=[4, 8, 16, 32, 64],
            )
            roi_head = RCNNHead(num_shared_convs=4, num_classes=num_classes)
            self.faster_rcnn_head = FasterRCNNHead(
                num_classes=num_classes,
                anchor_generator=anchor_generator,
                roi_head=roi_head,
            )
        else:
            self.faster_rcnn_head = faster_rcnn_head

        self.roi2det = RoI2Det(rcnn_box_decoder)
        self.bbox_3d_head = QD3DTBBox3DHead(num_classes=num_classes)
        self.roi2det_3d = RoI2Det3D()
        self.qdtrack_head = (
            QDTrackHead() if qdtrack_head is None else qdtrack_head
        )
        self.track_graph = (
            CC3DTrackGraph() if track_graph is None else track_graph
        )
        self.pure_det = pure_det

    def forward(
        self,
        images: list[Tensor],
        images_hw: list[list[tuple[int, int]]],
        intrinsics: list[Tensor],
        extrinsics: list[Tensor] | None = None,
        frame_ids: list[int] | None = None,
        boxes2d: list[list[Tensor]] | None = None,
        boxes3d: list[list[Tensor]] | None = None,
        boxes3d_classes: list[list[Tensor]] | None = None,
        boxes3d_track_ids: list[list[Tensor]] | None = None,
        keyframes: list[list[bool]] | None = None,
    ) -> FasterRCNNCC3DTOut | Track3DOut:
        """Forward."""
        if self.training:
            assert (
                boxes2d is not None
                and boxes3d is not None
                and boxes3d_classes is not None
                and boxes3d_track_ids is not None
                and keyframes is not None
            )
            return self._forward_train(
                images,
                images_hw,
                intrinsics,
                boxes2d,
                boxes3d,
                boxes3d_classes,
                boxes3d_track_ids,
                keyframes,
            )
        assert extrinsics is not None and frame_ids is not None
        return self._forward_test(
            images, images_hw, intrinsics, extrinsics, frame_ids
        )

    def _forward_train(
        self,
        images: list[Tensor],
        images_hw: list[list[tuple[int, int]]],
        intrinsics: list[Tensor],
        target_boxes2d: list[list[Tensor]],
        target_boxes3d: list[list[Tensor]],
        target_classes: list[list[Tensor]],
        target_track_ids: list[list[Tensor]],
        keyframes: list[list[bool]],
    ) -> FasterRCNNCC3DTOut:
        """Forward training stage."""
        key_index, ref_indices = split_key_ref_indices(keyframes)

        # feature extraction
        key_features = self.fpn(self.basemodel(images[key_index]))
        ref_features = [
            self.fpn(self.basemodel(images[ref_index]))
            for ref_index in ref_indices
        ]

        key_detector_out = self.faster_rcnn_head(
            key_features,
            images_hw[key_index],
            target_boxes2d[key_index],
            target_classes[key_index],
        )

        with torch.no_grad():
            ref_detector_out = [
                self.faster_rcnn_head(
                    ref_features[i],
                    images_hw[ref_index],
                    target_boxes2d[ref_index],
                    target_classes[ref_index],
                )
                for i, ref_index in enumerate(ref_indices)
            ]

        key_proposals = key_detector_out.proposals.boxes
        ref_proposals = [ref.proposals.boxes for ref in ref_detector_out]

        key_target_boxes = target_boxes2d[key_index]
        ref_target_boxes = [
            target_boxes2d[ref_index] for ref_index in ref_indices
        ]

        key_target_track_ids = target_track_ids[key_index]
        ref_target_track_ids = [
            target_track_ids[ref_index] for ref_index in ref_indices
        ]

        (
            key_embeddings,
            ref_embeddings,
            key_track_ids,
            ref_track_ids,
        ) = self.qdtrack_head(
            features=[key_features, *ref_features],
            det_boxes=[key_proposals, *ref_proposals],
            target_boxes=[key_target_boxes, *ref_target_boxes],
            target_track_ids=[key_target_track_ids, *ref_target_track_ids],
        )
        assert (
            ref_embeddings is not None
            and key_track_ids is not None
            and ref_track_ids is not None
        )

        predictions, targets, labels = self.bbox_3d_head(
            features=key_features,
            det_boxes=key_proposals,
            intrinsics=intrinsics[key_index],
            target_boxes=key_target_boxes,
            target_boxes3d=target_boxes3d[key_index],
            target_class_ids=target_classes[key_index],
        )
        detector_3d_out = torch.cat(predictions)
        assert targets is not None and labels is not None

        return FasterRCNNCC3DTOut(
            detector_3d_out=detector_3d_out,
            detector_3d_target=targets,
            detector_3d_labels=labels,
            qdtrack_out=FasterRCNNQDTrackOut(
                detector_out=key_detector_out,
                key_images_hw=images_hw[key_index],
                key_target_boxes=key_target_boxes,
                key_embeddings=key_embeddings,
                ref_embeddings=ref_embeddings,
                key_track_ids=key_track_ids,
                ref_track_ids=ref_track_ids,
            ),
        )

    def _forward_test(
        self,
        images_list: list[Tensor],
        images_hw: list[list[tuple[int, int]]],
        intrinsics_list: list[Tensor],
        extrinsics_list: list[Tensor],
        frame_ids: list[int],
    ) -> Track3DOut:
        """Forward inference stage.

        Currently only works with a single batch per GPU.
""" # (N, 1, 3, H, W) -> (N, 3, H, W) images = torch.cat(images_list) # (N, 1, 3, 3) -> (N, 3, 3) intrinsics = torch.cat(intrinsics_list) # (N, 1, 4, 4) -> (N, 4, 4) extrinsics = torch.cat(extrinsics_list) # (N, 1) -> (N,) frame_id = frame_ids[0] images_hw_list: list[tuple[int, int]] = sum(images_hw, []) features = self.basemodel(images) features = self.fpn(features) _, roi, proposals, _, _, _ = self.faster_rcnn_head( features, images_hw_list ) boxes_2d_list, scores_2d_list, class_ids_list = self.roi2det( *roi, proposals.boxes, images_hw_list ) predictions, _, _ = self.bbox_3d_head( features, det_boxes=boxes_2d_list ) boxes_3d_list, scores_3d_list = self.roi2det_3d( predictions, boxes_2d_list, class_ids_list, intrinsics ) embeddings_list, _, _, _ = self.qdtrack_head(features, boxes_2d_list) # Assign camera id camera_ids_list = [] for i, boxes_2d in enumerate(boxes_2d_list): camera_ids_list.append( (torch.mul(torch.ones(len(boxes_2d)), i)).to(boxes_2d.device) ) # Move 3D boxes to world coordinate boxes_3d_list = cam_to_global(boxes_3d_list, extrinsics) # Merge boxes from all cameras boxes_2d = torch.cat(boxes_2d_list) scores_2d = torch.cat(scores_2d_list) camera_ids = torch.cat(camera_ids_list) boxes_3d = torch.cat(boxes_3d_list) scores_3d = torch.cat(scores_3d_list) class_ids = torch.cat(class_ids_list) embeddings = torch.cat(embeddings_list) if self.pure_det: return get_track_3d_out( boxes_3d, class_ids, scores_3d, torch.zeros_like(class_ids) ) # 3D NMS in world coordinate keep_indices = bev_3d_nms( center_x=boxes_3d[:, 0].unsqueeze(1), center_y=boxes_3d[:, 1].unsqueeze(1), width=boxes_3d[:, 4].unsqueeze(1), length=boxes_3d[:, 5].unsqueeze(1), angle=180.0 / torch.pi * boxes_3d[:, 8].unsqueeze(1), scores=scores_2d * scores_3d, ) boxes_2d = boxes_2d[keep_indices] scores_2d = scores_2d[keep_indices] camera_ids = camera_ids[keep_indices] boxes_3d = boxes_3d[keep_indices] scores_3d = scores_3d[keep_indices] class_ids = class_ids[keep_indices] embeddings = embeddings[keep_indices] outs = self.track_graph( boxes_2d, scores_2d, camera_ids, boxes_3d, scores_3d, class_ids, embeddings, frame_id, ) return outs
    def __call__(
        self,
        images: list[Tensor] | Tensor,
        images_hw: list[list[tuple[int, int]]],
        intrinsics: list[Tensor] | Tensor,
        extrinsics: Tensor | None = None,
        frame_ids: list[list[int]] | None = None,
        boxes2d: list[list[Tensor]] | None = None,
        boxes3d: list[list[Tensor]] | None = None,
        boxes3d_classes: list[list[Tensor]] | None = None,
        boxes3d_track_ids: list[list[Tensor]] | None = None,
        keyframes: list[list[bool]] | None = None,
    ) -> FasterRCNNCC3DTOut | Track3DOut:
        """Type definition for call implementation."""
        return self._call_impl(
            images,
            images_hw,
            intrinsics,
            extrinsics,
            frame_ids,
            boxes2d,
            boxes3d,
            boxes3d_classes,
            boxes3d_track_ids,
            keyframes,
        )
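
# The commented block below is a minimal inference sketch, not part of the
# original module. It shows how the test-time path of ``FasterRCNNCC3DT`` is
# driven: one list entry per camera, each holding a batch-1 tensor. The
# number of cameras, image resolution, ``num_classes`` value, and the
# identity intrinsics/extrinsics are illustrative assumptions only.
#
#     model = FasterRCNNCC3DT(num_classes=10)
#     model.eval()
#     num_cams = 6
#     images = [torch.rand(1, 3, 512, 928) for _ in range(num_cams)]
#     images_hw = [[(512, 928)] for _ in range(num_cams)]
#     intrinsics = [torch.eye(3).unsqueeze(0) for _ in range(num_cams)]
#     extrinsics = [torch.eye(4).unsqueeze(0) for _ in range(num_cams)]
#     with torch.no_grad():
#         tracks = model(
#             images, images_hw, intrinsics, extrinsics, frame_ids=[0]
#         )  # -> Track3DOut

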
class CC3DT(nn.Module):
    """CC-3DT with custom detection results."""

    def __init__(
        self,
        basemodel: BaseModel | None = None,
        qdtrack_head: QDTrackHead | None = None,
        track_graph: CC3DTrackGraph | None = None,
        detection_range: Sequence[float] | None = None,
    ) -> None:
        """Creates an instance of the class.

        Args:
            basemodel (BaseModel, optional): Base model network. Defaults to
                None. If None, will use ResNet50.
            qdtrack_head (QDTrackHead, optional): QDTrack head. Defaults to
                None. If None, will use default QDTrackHead.
            track_graph (CC3DTrackGraph, optional): Track graph. Defaults to
                None. If None, will use default CC3DTrackGraph.
            detection_range (Sequence[float], optional): Detection range for
                each class. Defaults to None.
        """
        super().__init__()
        self.basemodel = (
            ResNet(
                resnet_name="resnet50", pretrained=True, trainable_layers=3
            )
            if basemodel is None
            else basemodel
        )

        self.fpn = FPN(self.basemodel.out_channels[2:], 256)

        self.qdtrack_head = (
            QDTrackHead() if qdtrack_head is None else qdtrack_head
        )

        self.track_graph = track_graph or CC3DTrackGraph(
            track=CC3DTrackAssociation(init_score_thr=0.2, obj_score_thr=0.1),
            update_3d_score=False,
            add_backdrops=False,
        )

        self.detection_range = detection_range

    def forward(
        self,
        images_list: list[Tensor],
        images_hw: list[list[tuple[int, int]]],
        intrinsics_list: list[Tensor],
        extrinsics_list: list[Tensor],
        frame_ids: list[int],
        pred_boxes3d: list[list[Tensor]],
        pred_boxes3d_classes: list[list[Tensor]],
        pred_boxes3d_scores: list[list[Tensor]],
        pred_boxes3d_velocities: list[list[Tensor]],
    ) -> Track3DOut:
        """Forward inference stage.

        Currently only works with a single batch per GPU.
        """
        # (N, 1, 3, H, W) -> (N, 3, H, W)
        images = torch.cat(images_list)
        # (N, 1, 3, 3) -> (N, 3, 3)
        intrinsics = torch.cat(intrinsics_list)
        # (N, 1, 4, 4) -> (N, 4, 4)
        extrinsics = torch.cat(extrinsics_list)
        # (N, 1) -> (N,)
        frame_id = frame_ids[0]
        images_hw_list: list[tuple[int, int]] = sum(images_hw, [])

        features = self.basemodel(images)
        features = self.fpn(features)

        # (1, 1, B,) -> (B,)
        boxes_3d = pred_boxes3d[0][0]
        class_ids = pred_boxes3d_classes[0][0]
        scores_3d = pred_boxes3d_scores[0][0]
        velocities = pred_boxes3d_velocities[0][0]

        # Get 2D boxes and assign camera id
        global_to_cams = inverse_rigid_transform(extrinsics)
        boxes_3d_list = []
        boxes_2d_list = []
        class_ids_list = []
        scores_list = []
        camera_ids_list = []
        for i, global_to_cam in enumerate(global_to_cams):
            boxes3d_cam = transform_boxes3d(
                boxes_3d,
                global_to_cam,
                source_axis_mode=AxisMode.ROS,
                target_axis_mode=AxisMode.OPENCV,
            )

            corners = boxes3d_to_corners(
                boxes3d_cam, axis_mode=AxisMode.OPENCV
            )
            corners_2d = project_points(corners, intrinsics[i])
            boxes_2d = self._to_boxes2d(corners_2d)
            boxes_2d = bbox_clip(boxes_2d, images_hw_list[i], 1)

            mask = (
                (boxes3d_cam[:, 2] > 0)
                & (bbox_area(boxes_2d) > 0)
                & (
                    bbox_area(boxes_2d)
                    < (images_hw_list[i][0] - 1) * (images_hw_list[i][1] - 1)
                )
                & self._filter_distance(class_ids, boxes3d_cam)
            )

            cc_3dt_boxes_3d = boxes_3d.new_zeros(len(boxes_2d[mask]), 12)
            cc_3dt_boxes_3d[:, :3] = boxes_3d[mask][:, :3]
            # WLH -> HWL
            cc_3dt_boxes_3d[:, 3:6] = boxes_3d[mask][:, [5, 3, 4]]
            cc_3dt_boxes_3d[:, 6:9] = rotation_matrix_yaw(
                quaternion_to_matrix(boxes_3d[mask][:, 6:]), AxisMode.ROS
            )
            cc_3dt_boxes_3d[:, 9:] = velocities[mask]

            boxes_3d_list.append(cc_3dt_boxes_3d)
            boxes_2d_list.append(boxes_2d[mask])
            class_ids_list.append(class_ids[mask])
            scores_list.append(scores_3d[mask])
            camera_ids_list.append(
                (torch.mul(torch.ones(len(cc_3dt_boxes_3d)), i)).to(
                    boxes_2d.device
                )
            )

        embeddings_list, _, _, _ = self.qdtrack_head(features, boxes_2d_list)

        boxes_3d = torch.cat(boxes_3d_list)
        boxes_2d = torch.cat(boxes_2d_list)
        camera_ids = torch.cat(camera_ids_list)
        scores = torch.cat(scores_list)
        class_ids = torch.cat(class_ids_list)
        embeddings = torch.cat(embeddings_list)

        # Select projected boxes2d according to bbox area
        keep_indices = embeddings.new_ones(len(boxes_3d)).bool()
        boxes_2d_area = bbox_area(boxes_2d)
        for i, box3d in enumerate(boxes_3d):
            for same_idx in (
                (box3d[:3] == boxes_3d[:, :3]).all(dim=1).nonzero()
            ):
                if (
                    same_idx != i
                    and boxes_2d_area[same_idx] > boxes_2d_area[i]
                ):
                    keep_indices[i] = False
                    break

        boxes_3d = boxes_3d[keep_indices]
        boxes_2d = boxes_2d[keep_indices]
        camera_ids = camera_ids[keep_indices]
        scores = scores[keep_indices]
        class_ids = class_ids[keep_indices]
        embeddings = embeddings[keep_indices]

        outs = self.track_graph(
            boxes_2d,
            scores,
            camera_ids,
            boxes_3d,
            scores,
            class_ids,
            embeddings,
            frame_id,
        )

        return outs

    def _to_boxes2d(self, corners_2d: Tensor) -> Tensor:
        """Project 3D boxes (camera coordinates) to 2D boxes."""
        min_x = torch.min(corners_2d[:, :, 0], 1).values.unsqueeze(-1)
        min_y = torch.min(corners_2d[:, :, 1], 1).values.unsqueeze(-1)
        max_x = torch.max(corners_2d[:, :, 0], 1).values.unsqueeze(-1)
        max_y = torch.max(corners_2d[:, :, 1], 1).values.unsqueeze(-1)
        return torch.cat([min_x, min_y, max_x, max_y], dim=1)

    def _filter_distance(
        self, class_ids: Tensor, boxes3d: Tensor, tolerance: float = 2.0
    ) -> Tensor:
        """Filter boxes3d on distance."""
        if self.detection_range is None:
            return torch.ones_like(class_ids, dtype=torch.bool)
        return torch.linalg.norm(  # pylint: disable=not-callable
            boxes3d[:, [0, 2]], dim=1
        ) <= torch.tensor(
            [
                self.detection_range[class_id] + tolerance
                for class_id in class_ids
            ]
        ).to(
            class_ids.device
        )

    def __call__(
        self,
        images_list: list[Tensor],
        images_hw: list[list[tuple[int, int]]],
        intrinsics_list: list[Tensor],
        extrinsics_list: list[Tensor],
        frame_ids: list[int],
        pred_boxes3d: list[list[Tensor]],
        pred_boxes3d_classes: list[list[Tensor]],
        pred_boxes3d_scores: list[list[Tensor]],
        pred_boxes3d_velocities: list[list[Tensor]],
    ) -> Track3DOut:
        """Type definition for call implementation."""
        return self._call_impl(
            images_list,
            images_hw,
            intrinsics_list,
            extrinsics_list,
            frame_ids,
            pred_boxes3d,
            pred_boxes3d_classes,
            pred_boxes3d_scores,
            pred_boxes3d_velocities,
        )
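
# The commented block below is a minimal usage sketch for ``CC3DT``, not part
# of the original module. It feeds precomputed 3D detections to the tracker.
# The number of cameras, image resolution, and the box layout (3 center
# values, 3 dimensions, 4 quaternion components, plus separate class ids,
# scores, and 3D velocities) are assumptions inferred from how ``forward``
# indexes these tensors; the identity camera matrices and zero velocities are
# for illustration only.
#
#     model = CC3DT()
#     model.eval()
#     num_cams, num_dets = 6, 5
#     images_list = [torch.rand(1, 3, 512, 928) for _ in range(num_cams)]
#     images_hw = [[(512, 928)] for _ in range(num_cams)]
#     intrinsics_list = [torch.eye(3).unsqueeze(0) for _ in range(num_cams)]
#     extrinsics_list = [torch.eye(4).unsqueeze(0) for _ in range(num_cams)]
#     boxes3d = torch.zeros(num_dets, 10)
#     boxes3d[:, :3] = torch.rand(num_dets, 3) * 10.0  # center (x, y, z)
#     boxes3d[:, 3:6] = 1.0                            # dimensions
#     boxes3d[:, 6] = 1.0                              # identity quaternion
#     with torch.no_grad():
#         tracks = model(
#             images_list,
#             images_hw,
#             intrinsics_list,
#             extrinsics_list,
#             frame_ids=[0],
#             pred_boxes3d=[[boxes3d]],
#             pred_boxes3d_classes=[[torch.zeros(num_dets, dtype=torch.long)]],
#             pred_boxes3d_scores=[[torch.rand(num_dets)]],
#             pred_boxes3d_velocities=[[torch.zeros(num_dets, 3)]],
#         )  # -> Track3DOut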