Source code for vis4d.model.track.qdtrack

"""Quasi-dense instance similarity learning model."""

from __future__ import annotations

from typing import NamedTuple

import torch
from torch import Tensor, nn

from vis4d.common.ckpt import load_model_checkpoint
from vis4d.model.detect.yolox import REV_KEYS as YOLOX_REV_KEYS
from vis4d.op.base import BaseModel, CSPDarknet, ResNet
from vis4d.op.box.box2d import scale_and_clip_boxes
from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder
from vis4d.op.box.poolers import MultiScaleRoIAlign
from vis4d.op.detect.faster_rcnn import FasterRCNNHead, FRCNNOut
from vis4d.op.detect.rcnn import RoI2Det
from vis4d.op.detect.yolox import YOLOXHead, YOLOXOut, YOLOXPostprocess
from vis4d.op.fpp import FPN, YOLOXPAFPN, FeaturePyramidProcessing
from vis4d.op.track.common import TrackOut
from vis4d.op.track.qdtrack import (
    QDSimilarityHead,
    QDTrackAssociation,
    QDTrackHead,
)
from vis4d.state.track.qdtrack import QDTrackGraph

from .util import split_key_ref_indices

REV_KEYS = [
    (r"^faster_rcnn_heads\.", "faster_rcnn_head."),
    (r"^backbone.body\.", "basemodel."),
    (r"^qdtrack\.", "qdtrack_head."),
]
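
# Note (added illustration, not part of the original module): rev_keys-style
# remapping is applied to checkpoint keys before they are matched against this
# model's state dict. A minimal sketch of the idea, using a hypothetical
# checkpoint key:
#
#     import re
#     key = "faster_rcnn_heads.rpn_head.conv.weight"  # hypothetical key
#     for pattern, replacement in REV_KEYS:
#         key = re.sub(pattern, replacement, key)
#     # key == "faster_rcnn_head.rpn_head.conv.weight"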


class FasterRCNNQDTrackOut(NamedTuple):
    """Output of QDTrack model."""

    detector_out: FRCNNOut
    key_images_hw: list[tuple[int, int]]
    key_target_boxes: list[Tensor]
    key_embeddings: list[Tensor]
    ref_embeddings: list[list[Tensor]]
    key_track_ids: list[Tensor]
    ref_track_ids: list[list[Tensor]]


class FasterRCNNQDTrack(nn.Module):
    """Wrap QDTrack with Faster R-CNN detector."""

    def __init__(
        self,
        num_classes: int,
        basemodel: BaseModel | None = None,
        faster_rcnn_head: FasterRCNNHead | None = None,
        rcnn_box_decoder: DeltaXYWHBBoxDecoder | None = None,
        qdtrack_head: QDTrackHead | None = None,
        track_graph: QDTrackGraph | None = None,
        weights: None | str = None,
    ) -> None:
        """Creates an instance of the class.

        Args:
            num_classes (int): Number of object categories.
            basemodel (BaseModel, optional): Base model network. Defaults to
                None. If None, will use ResNet50.
            faster_rcnn_head (FasterRCNNHead, optional): Faster RCNN head.
                Defaults to None. If None, will use default FasterRCNNHead.
            rcnn_box_decoder (DeltaXYWHBBoxDecoder, optional): Decoder for
                RCNN bounding boxes. Defaults to None.
            qdtrack_head (QDTrackHead, optional): QDTrack head. Defaults to
                None. If None, will use default QDTrackHead.
            track_graph (QDTrackGraph, optional): Track graph. Defaults to
                None. If None, will use default QDTrackGraph.
            weights (str, optional): Weights to load for model.
        """
        super().__init__()
        self.basemodel = (
            ResNet(
                resnet_name="resnet50", pretrained=True, trainable_layers=3
            )
            if basemodel is None
            else basemodel
        )

        self.fpn = FPN(self.basemodel.out_channels[2:], 256)

        if faster_rcnn_head is None:
            self.faster_rcnn_head = FasterRCNNHead(num_classes=num_classes)
        else:
            self.faster_rcnn_head = faster_rcnn_head

        self.roi2det = RoI2Det(rcnn_box_decoder)

        self.qdtrack_head = (
            QDTrackHead() if qdtrack_head is None else qdtrack_head
        )

        self.track_graph = (
            QDTrackGraph() if track_graph is None else track_graph
        )

        if weights is not None:
            load_model_checkpoint(
                self, weights, map_location="cpu", rev_keys=REV_KEYS
            )

    def forward(
        self,
        images: list[Tensor] | Tensor,
        images_hw: list[list[tuple[int, int]]] | list[tuple[int, int]],
        original_hw: list[list[tuple[int, int]]] | list[tuple[int, int]],
        frame_ids: list[list[int]] | list[int],
        boxes2d: None | list[list[Tensor]] = None,
        boxes2d_classes: None | list[list[Tensor]] = None,
        boxes2d_track_ids: None | list[list[Tensor]] = None,
        keyframes: None | list[list[bool]] = None,
    ) -> TrackOut | FasterRCNNQDTrackOut:
        """Forward."""
        if self.training:
            assert (
                isinstance(images, list)
                and boxes2d is not None
                and boxes2d_classes is not None
                and boxes2d_track_ids is not None
                and keyframes is not None
            )
            return self._forward_train(
                images,
                images_hw,  # type: ignore
                boxes2d,
                boxes2d_classes,
                boxes2d_track_ids,
                keyframes,
            )
        return self._forward_test(images, images_hw, original_hw, frame_ids)  # type: ignore # pylint: disable=line-too-long

    def _forward_train(
        self,
        images: list[Tensor],
        images_hw: list[list[tuple[int, int]]],
        target_boxes: list[list[Tensor]],
        target_classes: list[list[Tensor]],
        target_track_ids: list[list[Tensor]],
        keyframes: list[list[bool]],
    ) -> FasterRCNNQDTrackOut:
        """Forward training stage.

        Args:
            images (list[Tensor]): Input images.
            images_hw (list[list[tuple[int, int]]]): Input image resolutions.
            target_boxes (list[list[Tensor]]): Bounding box labels.
            target_classes (list[list[Tensor]]): Class labels.
            target_track_ids (list[list[Tensor]]): Track IDs.
            keyframes (list[list[bool]]): Whether the frame is a keyframe.

        Returns:
            FasterRCNNQDTrackOut: Raw model outputs.
        """
        key_index, ref_indices = split_key_ref_indices(keyframes)

        # feature extraction
        key_features = self.fpn(self.basemodel(images[key_index]))
        ref_features = [
            self.fpn(self.basemodel(images[ref_index]))
            for ref_index in ref_indices
        ]

        key_detector_out = self.faster_rcnn_head(
            key_features,
            images_hw[key_index],
            target_boxes[key_index],
            target_classes[key_index],
        )

        with torch.no_grad():
            ref_detector_out = [
                self.faster_rcnn_head(
                    ref_features[i],
                    images_hw[ref_index],
                    target_boxes[ref_index],
                    target_classes[ref_index],
                )
                for i, ref_index in enumerate(ref_indices)
            ]

        key_proposals = key_detector_out.proposals.boxes
        ref_proposals = [ref.proposals.boxes for ref in ref_detector_out]

        key_target_boxes = target_boxes[key_index]
        ref_target_boxes = [
            target_boxes[ref_index] for ref_index in ref_indices
        ]
        key_target_track_ids = target_track_ids[key_index]
        ref_target_track_ids = [
            target_track_ids[ref_index] for ref_index in ref_indices
        ]

        (
            key_embeddings,
            ref_embeddings,
            key_track_ids,
            ref_track_ids,
        ) = self.qdtrack_head(
            features=[key_features, *ref_features],
            det_boxes=[key_proposals, *ref_proposals],
            target_boxes=[key_target_boxes, *ref_target_boxes],
            target_track_ids=[key_target_track_ids, *ref_target_track_ids],
        )
        assert (
            ref_embeddings is not None
            and key_track_ids is not None
            and ref_track_ids is not None
        )

        return FasterRCNNQDTrackOut(
            detector_out=key_detector_out,
            key_images_hw=images_hw[key_index],
            key_target_boxes=key_target_boxes,
            key_embeddings=key_embeddings,
            ref_embeddings=ref_embeddings,
            key_track_ids=key_track_ids,
            ref_track_ids=ref_track_ids,
        )

    def _forward_test(
        self,
        images: Tensor,
        images_hw: list[tuple[int, int]],
        original_hw: list[tuple[int, int]],
        frame_ids: list[int],
    ) -> TrackOut:
        """Forward inference stage."""
        features = self.basemodel(images)
        features = self.fpn(features)
        detector_out = self.faster_rcnn_head(features, images_hw)

        boxes, scores, class_ids = self.roi2det(
            *detector_out.roi, detector_out.proposals.boxes, images_hw
        )

        embeddings, _, _, _ = self.qdtrack_head(features, boxes)

        tracks = self.track_graph(
            embeddings, boxes, scores, class_ids, frame_ids
        )

        for i, boxs in enumerate(tracks.boxes):
            tracks.boxes[i] = scale_and_clip_boxes(
                boxs, original_hw[i], images_hw[i]
            )
        return tracks

    def __call__(
        self,
        images: list[Tensor] | Tensor,
        images_hw: list[list[tuple[int, int]]] | list[tuple[int, int]],
        original_hw: list[tuple[int, int]],
        frame_ids: list[list[int]] | list[int],
        boxes2d: None | list[list[Tensor]] = None,
        boxes2d_classes: None | list[list[Tensor]] = None,
        boxes2d_track_ids: None | list[list[Tensor]] = None,
        keyframes: None | list[list[bool]] = None,
    ) -> TrackOut | FasterRCNNQDTrackOut:
        """Type definition for call implementation."""
        return self._call_impl(
            images,
            images_hw,
            original_hw,
            frame_ids,
            boxes2d,
            boxes2d_classes,
            boxes2d_track_ids,
            keyframes,
        )
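
# Example (illustrative sketch, not part of the original module): at inference
# time the wrapper takes a batched image tensor plus per-image metadata and
# returns a TrackOut with boxes rescaled to the original resolution. The
# class count and resolutions below are assumptions for illustration only.
#
#     model = FasterRCNNQDTrack(num_classes=8)
#     model.eval()
#     images = torch.rand(1, 3, 800, 1280)
#     with torch.no_grad():
#         tracks = model(
#             images,
#             images_hw=[(800, 1280)],
#             original_hw=[(1080, 1920)],
#             frame_ids=[0],
#         )
#     # tracks.boxes holds the per-frame 2D boxes (see TrackOut).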


class YOLOXQDTrackOut(NamedTuple):
    """Output of QDTrack YOLOX model."""

    detector_out: YOLOXOut
    key_images_hw: list[tuple[int, int]]
    key_target_boxes: list[Tensor]
    key_target_classes: list[Tensor]
    key_embeddings: list[Tensor]
    ref_embeddings: list[list[Tensor]]
    key_track_ids: list[Tensor]
    ref_track_ids: list[list[Tensor]]


class YOLOXQDTrack(nn.Module):
    """Wrap QDTrack with YOLOX detector."""

    def __init__(
        self,
        num_classes: int,
        basemodel: BaseModel | None = None,
        fpn: FeaturePyramidProcessing | None = None,
        yolox_head: YOLOXHead | None = None,
        train_postprocessor: YOLOXPostprocess | None = None,
        test_postprocessor: YOLOXPostprocess | None = None,
        qdtrack_head: QDTrackHead | None = None,
        track_graph: QDTrackGraph | None = None,
        weights: None | str = None,
    ) -> None:
        """Creates an instance of the class.

        Args:
            num_classes (int): Number of object categories.
            basemodel (BaseModel, optional): Base model. Defaults to None.
                If None, will use CSPDarknet.
            fpn (FeaturePyramidProcessing, optional): Feature Pyramid
                Processing. Defaults to None. If None, will use YOLOXPAFPN.
            yolox_head (YOLOXHead, optional): YOLOX head. Defaults to None.
                If None, will use YOLOXHead.
            train_postprocessor (YOLOXPostprocess, optional): Post processor
                for training. Defaults to None. If None, will use
                YOLOXPostprocess.
            test_postprocessor (YOLOXPostprocess, optional): Post processor
                for testing. Defaults to None. If None, will use
                YOLOXPostprocess.
            qdtrack_head (QDTrackHead, optional): QDTrack head. Defaults to
                None. If None, will use default QDTrackHead.
            track_graph (QDTrackGraph, optional): Track graph. Defaults to
                None. If None, will use default QDTrackGraph.
            weights (str, optional): Weights to load for model.
        """
        super().__init__()
        self.basemodel = (
            CSPDarknet(deepen_factor=1.33, widen_factor=1.25)
            if basemodel is None
            else basemodel
        )

        self.fpn = (
            YOLOXPAFPN([320, 640, 1280], 320, num_csp_blocks=4)
            if fpn is None
            else fpn
        )

        self.yolox_head = (
            YOLOXHead(
                num_classes=num_classes, in_channels=320, feat_channels=320
            )
            if yolox_head is None
            else yolox_head
        )

        self.train_postprocessor = (
            YOLOXPostprocess(
                self.yolox_head.point_generator,
                self.yolox_head.box_decoder,
                nms_threshold=0.7,
                score_thr=0.0,
                nms_pre=2000,
                max_per_img=1000,
            )
            if train_postprocessor is None
            else train_postprocessor
        )

        self.test_postprocessor = (
            YOLOXPostprocess(
                self.yolox_head.point_generator,
                self.yolox_head.box_decoder,
                nms_threshold=0.65,
                score_thr=0.1,
            )
            if test_postprocessor is None
            else test_postprocessor
        )

        self.qdtrack_head = (
            QDTrackHead(
                QDSimilarityHead(
                    MultiScaleRoIAlign(
                        resolution=[7, 7],
                        strides=[8, 16, 32],
                        sampling_ratio=0,
                    ),
                    in_dim=320,
                )
            )
            if qdtrack_head is None
            else qdtrack_head
        )

        self.track_graph = (
            QDTrackGraph(
                track=QDTrackAssociation(
                    init_score_thr=0.5, obj_score_thr=0.35
                )
            )
            if track_graph is None
            else track_graph
        )

        if weights is not None:
            load_model_checkpoint(
                self, weights, map_location="cpu", rev_keys=YOLOX_REV_KEYS
            )

    def forward(
        self,
        images: list[Tensor] | Tensor,
        images_hw: list[list[tuple[int, int]]] | list[tuple[int, int]],
        original_hw: list[list[tuple[int, int]]] | list[tuple[int, int]],
        frame_ids: list[list[int]] | list[int],
        boxes2d: None | list[list[Tensor]] = None,
        boxes2d_classes: None | list[list[Tensor]] = None,
        boxes2d_track_ids: None | list[list[Tensor]] = None,
        keyframes: None | list[list[bool]] = None,
    ) -> TrackOut | YOLOXQDTrackOut:
        """Forward."""
        if self.training:
            assert (
                isinstance(images, list)
                and boxes2d is not None
                and boxes2d_classes is not None
                and boxes2d_track_ids is not None
                and keyframes is not None
            )
            return self._forward_train(
                images,
                images_hw,  # type: ignore
                boxes2d,
                boxes2d_classes,
                boxes2d_track_ids,
                keyframes,
            )
        return self._forward_test(images, images_hw, original_hw, frame_ids)  # type: ignore # pylint: disable=line-too-long

    def _forward_train(
        self,
        images: list[Tensor],
        images_hw: list[list[tuple[int, int]]],
        target_boxes: list[list[Tensor]],
        target_classes: list[list[Tensor]],
        target_track_ids: list[list[Tensor]],
        keyframes: list[list[bool]],
    ) -> YOLOXQDTrackOut:
        """Forward training stage.

        Args:
            images (list[Tensor]): Input images.
            images_hw (list[list[tuple[int, int]]]): Input image resolutions.
            target_boxes (list[list[Tensor]]): Bounding box labels.
            target_classes (list[list[Tensor]]): Class labels.
            target_track_ids (list[list[Tensor]]): Track IDs.
            keyframes (list[list[bool]]): Whether the frame is a keyframe.

        Returns:
            YOLOXQDTrackOut: Raw model outputs.
        """
        key_index, ref_indices = split_key_ref_indices(keyframes)

        # feature extraction
        key_features = self.fpn(
            self.basemodel(images[key_index].contiguous())
        )
        ref_features = [
            self.fpn(self.basemodel(images[ref_index].contiguous()))
            for ref_index in ref_indices
        ]

        key_detector_out = self.yolox_head(key_features[-3:])
        key_proposals, _, _ = self.train_postprocessor(
            cls_outs=key_detector_out.cls_score,
            reg_outs=key_detector_out.bbox_pred,
            obj_outs=key_detector_out.objectness,
            images_hw=images_hw[key_index],
        )

        with torch.no_grad():
            ref_detector_out = [
                self.yolox_head(ref_feat[-3:]) for ref_feat in ref_features
            ]
            ref_proposals = [
                self.train_postprocessor(
                    cls_outs=ref_out.cls_score,
                    reg_outs=ref_out.bbox_pred,
                    obj_outs=ref_out.objectness,
                    images_hw=images_hw[ref_index],
                )[0]
                for ref_index, ref_out in zip(ref_indices, ref_detector_out)
            ]

        key_target_boxes = target_boxes[key_index]
        ref_target_boxes = [
            target_boxes[ref_index] for ref_index in ref_indices
        ]
        key_target_classes = target_classes[key_index]
        key_target_track_ids = target_track_ids[key_index]
        ref_target_track_ids = [
            target_track_ids[ref_index] for ref_index in ref_indices
        ]

        (
            key_embeddings,
            ref_embeddings,
            key_track_ids,
            ref_track_ids,
        ) = self.qdtrack_head(
            features=[key_features, *ref_features],
            det_boxes=[key_proposals, *ref_proposals],
            target_boxes=[key_target_boxes, *ref_target_boxes],
            target_track_ids=[key_target_track_ids, *ref_target_track_ids],
        )
        assert (
            ref_embeddings is not None
            and key_track_ids is not None
            and ref_track_ids is not None
        )

        return YOLOXQDTrackOut(
            detector_out=key_detector_out,
            key_images_hw=images_hw[key_index],
            key_target_boxes=key_target_boxes,
            key_target_classes=key_target_classes,
            key_embeddings=key_embeddings,
            ref_embeddings=ref_embeddings,
            key_track_ids=key_track_ids,
            ref_track_ids=ref_track_ids,
        )

    def _forward_test(
        self,
        images: torch.Tensor,
        images_hw: list[tuple[int, int]],
        original_hw: list[tuple[int, int]],
        frame_ids: list[int],
    ) -> TrackOut:
        """Forward inference stage."""
        features = self.fpn(self.basemodel(images))
        outs = self.yolox_head(features[-3:])
        boxes, scores, class_ids = self.test_postprocessor(
            cls_outs=outs.cls_score,
            reg_outs=outs.bbox_pred,
            obj_outs=outs.objectness,
            images_hw=images_hw,
        )

        embeddings, _, _, _ = self.qdtrack_head(features, boxes)

        tracks = self.track_graph(
            embeddings, boxes, scores, class_ids, frame_ids
        )

        for i, boxs in enumerate(tracks.boxes):
            tracks.boxes[i] = scale_and_clip_boxes(
                boxs, original_hw[i], images_hw[i]
            )
        return tracks

    def __call__(
        self,
        images: list[Tensor] | Tensor,
        images_hw: list[list[tuple[int, int]]] | list[tuple[int, int]],
        original_hw: list[list[tuple[int, int]]] | list[tuple[int, int]],
        frame_ids: list[list[int]] | list[int],
        boxes2d: None | list[list[Tensor]] = None,
        boxes2d_classes: None | list[list[Tensor]] = None,
        boxes2d_track_ids: None | list[list[Tensor]] = None,
        keyframes: None | list[list[bool]] = None,
    ) -> TrackOut | YOLOXQDTrackOut:
        """Type definition for call implementation."""
        return self._call_impl(
            images,
            images_hw,
            original_hw,
            frame_ids,
            boxes2d,
            boxes2d_classes,
            boxes2d_track_ids,
            keyframes,
        )
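
# Example (illustrative sketch, not part of the original module): the YOLOX
# variant is called the same way; consecutive frames of a sequence are passed
# with increasing frame_ids so that the track graph can associate detections
# over time. The class count and resolutions below are assumptions for
# illustration only.
#
#     model = YOLOXQDTrack(num_classes=8)
#     model.eval()
#     for frame_id in range(2):
#         frame = torch.rand(1, 3, 800, 1440)
#         with torch.no_grad():
#             tracks = model(
#                 frame,
#                 images_hw=[(800, 1440)],
#                 original_hw=[(1080, 1920)],
#                 frame_ids=[frame_id],
#             )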