Source code for vis4d.state.track.qdtrack

"""Memory for QDTrack inference."""

from __future__ import annotations

from typing import TypedDict

import torch
from torch import Tensor

from vis4d.op.box.box2d import bbox_iou
from vis4d.op.track.assignment import TrackIDCounter
from vis4d.op.track.common import TrackOut
from vis4d.op.track.qdtrack import QDTrackAssociation


class Track(TypedDict):
    """QDTrack Track state.

    Attributes:
        box (Tensor): In shape (4,) and contains x1, y1, x2, y2.
        score (Tensor): In shape (1,).
        class_id (Tensor): In shape (1,).
        embed (Tensor): In shape (E,). E is the embedding dimension.
        last_frame (int): Last frame id.
    """

    box: Tensor
    score: Tensor
    class_id: Tensor
    embed: Tensor
    last_frame: int
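# Illustrative only (not part of the original module): a Track entry for a
# single detection could be built as follows. The concrete values and the
# embedding dimension E = 256 are assumptions for demonstration.
#
#     track: Track = Track(
#         box=torch.tensor([10.0, 20.0, 50.0, 80.0]),  # (4,) x1, y1, x2, y2
#         score=torch.tensor([0.9]),                    # (1,)
#         class_id=torch.tensor([1]),                   # (1,)
#         embed=torch.randn(256),                       # (E,)
#         last_frame=0,
#     )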
class QDTrackGraph:
    """Quasi-dense embedding similarity based graph."""

    def __init__(
        self,
        track: QDTrackAssociation | None = None,
        memory_size: int = 10,
        memory_momentum: float = 0.8,
        nms_backdrop_iou_thr: float = 0.3,
        backdrop_memory_size: int = 1,
    ) -> None:
        """Init."""
        assert memory_size >= 0
        self.memory_size = memory_size
        assert 0 <= memory_momentum <= 1.0
        self.memory_momentum = memory_momentum
        assert backdrop_memory_size >= 0
        self.backdrop_memory_size = backdrop_memory_size
        self.nms_backdrop_iou_thr = nms_backdrop_iou_thr

        self.tracker = QDTrackAssociation() if track is None else track

        self.tracklets: dict[int, Track] = {}
        self.backdrops: list[dict[str, Tensor]] = []
    def reset(self) -> None:
        """Empty the memory."""
        self.tracklets.clear()
        self.backdrops.clear()
    def is_empty(self) -> bool:
        """Check if the memory is empty."""
        return len(self.tracklets) == 0
    def get_tracks(
        self,
        device: torch.device,
        frame_id: int | None = None,
        add_backdrops: bool = False,
    ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Get tracklets.

        If frame_id is not provided, returns the latest state of all
        tracklets. Otherwise, returns the state of all tracklets at the given
        frame_id. If add_backdrops is True, the backdrops are appended to the
        output.

        Args:
            device (torch.device): Device to put the tensors on.
            frame_id (int, optional): Frame id to query. Defaults to None.
            add_backdrops (bool, optional): Whether to add backdrops to the
                output. Defaults to False.

        Returns:
            boxes (Tensor): 2D boxes in shape (N, 4).
            scores (Tensor): 2D scores in shape (N,).
            class_ids (Tensor): Class ids in shape (N,).
            track_ids (Tensor): Track ids in shape (N,).
            embeddings (Tensor): Embeddings in shape (N, E).
        """
        (
            boxes_list,
            scores_list,
            class_ids_list,
            embeddings_list,
            track_ids_list,
        ) = ([], [], [], [], [])

        for track_id, track in self.tracklets.items():
            if frame_id is None or track["last_frame"] == frame_id:
                boxes_list.append(track["box"].unsqueeze(0))
                scores_list.append(track["score"].unsqueeze(0))
                class_ids_list.append(track["class_id"].unsqueeze(0))
                embeddings_list.append(track["embed"].unsqueeze(0))
                track_ids_list.append(track_id)

        boxes = (
            torch.cat(boxes_list)
            if len(boxes_list) > 0
            else torch.empty((0, 4), device=device)
        )
        scores = (
            torch.cat(scores_list)
            if len(scores_list) > 0
            else torch.empty((0,), device=device)
        )
        class_ids = (
            torch.cat(class_ids_list)
            if len(class_ids_list) > 0
            else torch.empty((0,), device=device)
        )
        embeddings = (
            torch.cat(embeddings_list)
            if len(embeddings_list) > 0
            else torch.empty((0,), device=device)
        )
        track_ids = torch.tensor(track_ids_list, device=device)

        if add_backdrops:
            for backdrop in self.backdrops:
                backdrop_ids = torch.full(
                    (len(backdrop["embeddings"]),),
                    -1,
                    dtype=torch.long,
                    device=device,
                )
                track_ids = torch.cat([track_ids, backdrop_ids])
                boxes = torch.cat([boxes, backdrop["boxes"]])
                scores = torch.cat([scores, backdrop["scores"]])
                class_ids = torch.cat([class_ids, backdrop["class_ids"]])
                embeddings = torch.cat([embeddings, backdrop["embeddings"]])

        return boxes, scores, class_ids, track_ids, embeddings
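    # Illustrative only (hypothetical call, assuming the graph already holds
    # tracklets in memory):
    #
    #     boxes, scores, class_ids, track_ids, embeds = graph.get_tracks(
    #         torch.device("cpu"), frame_id=5, add_backdrops=False
    #     )
    #
    # restricts the result to tracklets last updated at frame 5, while
    # frame_id=None returns the latest state of every tracklet.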
    def __call__(
        self,
        embeddings_list: list[Tensor],
        det_boxes_list: list[Tensor],
        det_scores_list: list[Tensor],
        class_ids_list: list[Tensor],
        frame_id_list: list[int],
    ) -> TrackOut:
        """Forward during test."""
        (
            batched_boxes,
            batched_scores,
            batched_class_ids,
            batched_track_ids,
        ) = ([], [], [], [])

        for frame_id, det_boxes, det_scores, class_ids, embeddings in zip(
            frame_id_list,
            det_boxes_list,
            det_scores_list,
            class_ids_list,
            embeddings_list,
        ):
            # Reset graph at the beginning of a sequence.
            if frame_id == 0:
                self.reset()
                TrackIDCounter.reset()

            if not self.is_empty():
                (
                    _,
                    _,
                    memo_class_ids,
                    memo_track_ids,
                    memo_embeds,
                ) = self.get_tracks(det_boxes.device, add_backdrops=True)
            else:
                memo_class_ids = None
                memo_track_ids = None
                memo_embeds = None

            track_ids, filter_indices = self.tracker(
                det_boxes,
                det_scores,
                class_ids,
                embeddings,
                memo_track_ids,
                memo_class_ids,
                memo_embeds,
            )

            self.update(
                frame_id,
                track_ids,
                det_boxes[filter_indices],
                det_scores[filter_indices],
                class_ids[filter_indices],
                embeddings[filter_indices],
            )

            (
                boxes,
                scores,
                class_ids,
                track_ids,
                _,
            ) = self.get_tracks(det_boxes.device, frame_id=frame_id)

            batched_boxes.append(boxes)
            batched_scores.append(scores)
            batched_class_ids.append(class_ids)
            batched_track_ids.append(track_ids)

        return TrackOut(
            boxes=batched_boxes,
            class_ids=batched_class_ids,
            scores=batched_scores,
            track_ids=batched_track_ids,
        )
    def update(
        self,
        frame_id: int,
        track_ids: Tensor,
        boxes: Tensor,
        scores: Tensor,
        class_ids: Tensor,
        embeddings: Tensor,
    ) -> None:
        """Update the track memory with a new state."""
        valid_tracks = track_ids > -1

        # Update memo.
        for track_id, box, score, class_id, embed in zip(
            track_ids[valid_tracks],
            boxes[valid_tracks],
            scores[valid_tracks],
            class_ids[valid_tracks],
            embeddings[valid_tracks],
        ):
            track_id = int(track_id)
            if track_id in self.tracklets:
                self.update_track(
                    track_id, box, score, class_id, embed, frame_id
                )
            else:
                self.create_track(
                    track_id, box, score, class_id, embed, frame_id
                )

        # Backdrops: keep unmatched detections (track_id == -1) unless they
        # overlap an earlier box above the NMS IoU threshold.
        backdrop_inds = torch.nonzero(
            torch.eq(track_ids, -1), as_tuple=False
        ).squeeze(1)

        ious = bbox_iou(boxes[backdrop_inds], boxes)
        for i, ind in enumerate(backdrop_inds):
            if (ious[i, :ind] > self.nms_backdrop_iou_thr).any():
                backdrop_inds[i] = -1
        backdrop_inds = backdrop_inds[backdrop_inds > -1]

        self.backdrops.insert(
            0,
            {
                "boxes": boxes[backdrop_inds],
                "scores": scores[backdrop_inds],
                "class_ids": class_ids[backdrop_inds],
                "embeddings": embeddings[backdrop_inds],
            },
        )

        # Delete invalid tracks from memory.
        invalid_ids = []
        for k, v in self.tracklets.items():
            if frame_id - v["last_frame"] >= self.memory_size:
                invalid_ids.append(k)
        for invalid_id in invalid_ids:
            self.tracklets.pop(invalid_id)

        if len(self.backdrops) > self.backdrop_memory_size:
            self.backdrops.pop()
    def update_track(
        self,
        track_id: int,
        box: Tensor,
        score: Tensor,
        class_id: Tensor,
        embedding: Tensor,
        frame_id: int,
    ) -> None:
        """Update a specific track with a new detection."""
        self.tracklets[track_id]["box"] = box
        self.tracklets[track_id]["score"] = score
        self.tracklets[track_id]["class_id"] = class_id
        # Exponential moving average of the appearance embedding.
        self.tracklets[track_id]["embed"] = (
            1 - self.memory_momentum
        ) * self.tracklets[track_id][
            "embed"
        ] + self.memory_momentum * embedding
        self.tracklets[track_id]["last_frame"] = frame_id
    def create_track(
        self,
        track_id: int,
        box: Tensor,
        score: Tensor,
        class_id: Tensor,
        embedding: Tensor,
        frame_id: int,
    ) -> None:
        """Create a new track from a detection."""
        self.tracklets[track_id] = Track(
            box=box,
            score=score,
            class_id=class_id,
            embed=embedding,
            last_frame=frame_id,
        )
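# Usage sketch (illustrative only, not part of the original module): run the
# graph on per-frame detections and embeddings. The tensor shapes follow the
# __call__ signature above; the random inputs and the embedding dimension
# E = 256 are assumptions for demonstration, and the default
# QDTrackAssociation settings are used.
if __name__ == "__main__":
    graph = QDTrackGraph(memory_size=10, memory_momentum=0.8)

    num_dets, embed_dim = 5, 256
    boxes = torch.rand(num_dets, 4) * 100.0  # (N, 4) x1, y1, x2, y2
    boxes[:, 2:] += boxes[:, :2]  # ensure x2 > x1 and y2 > y1
    scores = torch.rand(num_dets)  # (N,)
    class_ids = torch.zeros(num_dets, dtype=torch.long)  # (N,)
    embeddings = torch.randn(num_dets, embed_dim)  # (N, E)

    # Two frames with the same dummy detections; frame_id == 0 resets memory.
    out = graph(
        embeddings_list=[embeddings, embeddings],
        det_boxes_list=[boxes, boxes],
        det_scores_list=[scores, scores],
        class_ids_list=[class_ids, class_ids],
        frame_id_list=[0, 1],
    )
    print(out.track_ids)  # per-frame track ids assigned by the association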