"""Memory for QDTrack inference."""
from __future__ import annotations
from typing import TypedDict
import torch
from torch import Tensor
from vis4d.op.box.box2d import bbox_iou
from vis4d.op.track.assignment import TrackIDCounter
from vis4d.op.track.common import TrackOut
from vis4d.op.track.qdtrack import QDTrackAssociation
class Track(TypedDict):
"""QDTrack Track state.
Attributes:
box (Tensor): In shape (4,) and contains x1, y1, x2, y2.
score (Tensor): In shape (1,).
class_id (Tensor): In shape (1,).
embedding (Tensor): In shape (E,). E is the embedding dimension.
last_frame (int): Last frame id.
"""
box: Tensor
score: Tensor
class_id: Tensor
embed: Tensor
last_frame: int
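# Illustrative sketch (not part of the original module): a Track entry as it
# would be stored in QDTrackGraph.tracklets below, assuming an embedding
# dimension of 256.
#
#   track = Track(
#       box=torch.tensor([10.0, 20.0, 50.0, 80.0]),
#       score=torch.tensor(0.9),
#       class_id=torch.tensor(2),
#       embed=torch.zeros(256),
#       last_frame=0,
#   )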
class QDTrackGraph:
"""Quasi-dense embedding similarity based graph."""
def __init__(
self,
track: QDTrackAssociation | None = None,
memory_size: int = 10,
memory_momentum: float = 0.8,
nms_backdrop_iou_thr: float = 0.3,
backdrop_memory_size: int = 1,
) -> None:
"""Init."""
assert memory_size >= 0
self.memory_size = memory_size
assert 0 <= memory_momentum <= 1.0
self.memory_momentum = memory_momentum
assert backdrop_memory_size >= 0
self.backdrop_memory_size = backdrop_memory_size
self.nms_backdrop_iou_thr = nms_backdrop_iou_thr
self.tracker = QDTrackAssociation() if track is None else track
self.tracklets: dict[int, Track] = {}
self.backdrops: list[dict[str, Tensor]] = []
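# Illustrative sketch (not part of the original module): constructing a graph
# that keeps tracks alive for 30 frames and keeps one frame of backdrops,
# assuming the default association settings.
#
#   graph = QDTrackGraph(memory_size=30, backdrop_memory_size=1)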
def reset(self) -> None:
"""Empty the memory."""
self.tracklets.clear()
self.backdrops.clear()
def is_empty(self) -> bool:
"""Check if the memory is empty."""
return len(self.tracklets) == 0
def get_tracks(
self,
device: torch.device,
frame_id: int | None = None,
add_backdrops: bool = False,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Get tracklests.
If the frame_id is not provided, will return the latest state of all
tracklets. Otherwise, will return the state of all tracklets at the
given frame_id. If add_backdrops is True, will also return the
backdrops.
Args:
device (torch.device): Device to put the tensors on.
frame_id (int, optional): Frame id to query. Defaults to None.
add_backdrops (bool, optional): Whether to add backdrops to the
output. Defaults to False.
Returns:
boxes (Tensor): 2D boxes in shape (N, 4).
scores (Tensor): 2D scores in shape (N,).
class_ids (Tensor): Class ids in shape (N,).
track_ids (Tensor): Track ids in shape (N,).
embeddings (Tensor): Embeddings in shape (N, E).
"""
(
boxes_list,
scores_list,
class_ids_list,
embeddings_list,
track_ids_list,
) = ([], [], [], [], [])
for track_id, track in self.tracklets.items():
if frame_id is None or track["last_frame"] == frame_id:
boxes_list.append(track["box"].unsqueeze(0))
scores_list.append(track["score"].unsqueeze(0))
class_ids_list.append(track["class_id"].unsqueeze(0))
embeddings_list.append(track["embed"].unsqueeze(0))
track_ids_list.append(track_id)
boxes = (
torch.cat(boxes_list)
if len(boxes_list) > 0
else torch.empty((0, 4), device=device)
)
scores = (
torch.cat(scores_list)
if len(scores_list) > 0
else torch.empty((0,), device=device)
)
class_ids = (
torch.cat(class_ids_list)
if len(class_ids_list) > 0
else torch.empty((0,), device=device)
)
embeddings = (
torch.cat(embeddings_list)
if len(embeddings_list) > 0
else torch.empty((0,), device=device)
)
track_ids = torch.tensor(track_ids_list, device=device)
if add_backdrops:
for backdrop in self.backdrops:
backdrop_ids = torch.full(
(len(backdrop["embeddings"]),),
-1,
dtype=torch.long,
device=device,
)
track_ids = torch.cat([track_ids, backdrop_ids])
boxes = torch.cat([boxes, backdrop["boxes"]])
scores = torch.cat([scores, backdrop["scores"]])
class_ids = torch.cat([class_ids, backdrop["class_ids"]])
embeddings = torch.cat([embeddings, backdrop["embeddings"]])
return boxes, scores, class_ids, track_ids, embeddings
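# Illustrative sketch (not part of the original module): querying the memory,
# assuming `graph` already holds tracklets from a previous update.
#
#   boxes, scores, class_ids, track_ids, embeds = graph.get_tracks(
#       torch.device("cpu"), frame_id=5, add_backdrops=False
#   )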
def __call__(
self,
embeddings_list: list[Tensor],
det_boxes_list: list[Tensor],
det_scores_list: list[Tensor],
class_ids_list: list[Tensor],
frame_id_list: list[int],
) -> TrackOut:
"""Forward during test."""
(
batched_boxes,
batched_scores,
batched_class_ids,
batched_track_ids,
) = ([], [], [], [])
for frame_id, det_boxes, det_scores, class_ids, embeddings in zip(
frame_id_list,
det_boxes_list,
det_scores_list,
class_ids_list,
embeddings_list,
):
# reset the graph at the beginning of a sequence
if frame_id == 0:
self.reset()
TrackIDCounter.reset()
if not self.is_empty():
(
_,
_,
memo_class_ids,
memo_track_ids,
memo_embeds,
) = self.get_tracks(det_boxes.device, add_backdrops=True)
else:
memo_class_ids = None
memo_track_ids = None
memo_embeds = None
track_ids, filter_indices = self.tracker(
det_boxes,
det_scores,
class_ids,
embeddings,
memo_track_ids,
memo_class_ids,
memo_embeds,
)
self.update(
frame_id,
track_ids,
det_boxes[filter_indices],
det_scores[filter_indices],
class_ids[filter_indices],
embeddings[filter_indices],
)
(
boxes,
scores,
class_ids,
track_ids,
_,
) = self.get_tracks(det_boxes.device, frame_id=frame_id)
batched_boxes.append(boxes)
batched_scores.append(scores)
batched_class_ids.append(class_ids)
batched_track_ids.append(track_ids)
return TrackOut(
boxes=batched_boxes,
class_ids=batched_class_ids,
scores=batched_scores,
track_ids=batched_track_ids,
)
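# Illustrative sketch (not part of the original module): the expected per-frame
# input format, assuming N_t detections with E-dimensional embeddings in
# frame t.
#
#   out = graph(
#       embeddings_list=[embeds_0, embeds_1],       # each (N_t, E)
#       det_boxes_list=[boxes_0, boxes_1],          # each (N_t, 4)
#       det_scores_list=[scores_0, scores_1],       # each (N_t,)
#       class_ids_list=[class_ids_0, class_ids_1],  # each (N_t,)
#       frame_id_list=[0, 1],                       # frame id 0 resets the memory
#   )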
def update(
self,
frame_id: int,
track_ids: Tensor,
boxes: Tensor,
scores: Tensor,
class_ids: Tensor,
embeddings: Tensor,
) -> None:
"""Update the track memory with a new state."""
valid_tracks = track_ids > -1
# update the track memory with matched detections
for track_id, box, score, class_id, embed in zip(
track_ids[valid_tracks],
boxes[valid_tracks],
scores[valid_tracks],
class_ids[valid_tracks],
embeddings[valid_tracks],
):
track_id = int(track_id)
if track_id in self.tracklets:
self.update_track(
track_id, box, score, class_id, embed, frame_id
)
else:
self.create_track(
track_id, box, score, class_id, embed, frame_id
)
# keep unmatched detections as backdrops, suppressing overlaps
backdrop_inds = torch.nonzero(
torch.eq(track_ids, -1), as_tuple=False
).squeeze(1)
ious = bbox_iou(boxes[backdrop_inds], boxes)
for i, ind in enumerate(backdrop_inds):
if (ious[i, :ind] > self.nms_backdrop_iou_thr).any():
backdrop_inds[i] = -1
backdrop_inds = backdrop_inds[backdrop_inds > -1]
self.backdrops.insert(
0,
{
"boxes": boxes[backdrop_inds],
"scores": scores[backdrop_inds],
"class_ids": class_ids[backdrop_inds],
"embeddings": embeddings[backdrop_inds],
},
)
# delete invalid tracks from memory
invalid_ids = []
for k, v in self.tracklets.items():
if frame_id - v["last_frame"] >= self.memory_size:
invalid_ids.append(k)
for invalid_id in invalid_ids:
self.tracklets.pop(invalid_id)
if len(self.backdrops) > self.backdrop_memory_size:
self.backdrops.pop()
def update_track(
self,
track_id: int,
box: Tensor,
score: Tensor,
class_id: Tensor,
embedding: Tensor,
frame_id: int,
) -> None:
"""Update a specific track with a new models."""
self.tracklets[track_id]["box"] = box
self.tracklets[track_id]["score"] = score
self.tracklets[track_id]["class_id"] = class_id
self.tracklets[track_id]["embed"] = (
1 - self.memory_momentum
) * self.tracklets[track_id][
"embed"
] + self.memory_momentum * embedding
self.tracklets[track_id]["last_frame"] = frame_id
def create_track(
self,
track_id: int,
box: Tensor,
score: Tensor,
class_id: Tensor,
embedding: Tensor,
frame_id: int,
) -> None:
"""Create a new track from a models."""
self.tracklets[track_id] = Track(
box=box,
score=score,
class_id=class_id,
embed=embedding,
last_frame=frame_id,
)
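# Illustrative usage sketch (not part of the original module): run the graph on
# two frames of synthetic detections. The detection count, embedding dimension,
# and random inputs below are arbitrary assumptions; association uses the
# default QDTrackAssociation settings.
def _example_qdtrack_graph() -> None:
    """Run QDTrackGraph on two frames of random detections."""
    graph = QDTrackGraph()
    num_dets, embed_dim = 5, 256
    for frame_id in range(2):
        # Build valid x1, y1, x2, y2 boxes with positive width and height.
        top_left = torch.rand(num_dets, 2) * 100
        wh = torch.rand(num_dets, 2) * 50 + 1
        boxes = torch.cat([top_left, top_left + wh], dim=1)
        out = graph(
            embeddings_list=[torch.rand(num_dets, embed_dim)],
            det_boxes_list=[boxes],
            det_scores_list=[torch.rand(num_dets)],
            class_ids_list=[torch.zeros(num_dets, dtype=torch.long)],
            frame_id_list=[frame_id],
        )
        # TrackOut holds one entry per input frame.
        print(frame_id, out.track_ids[0].tolist())


if __name__ == "__main__":
    _example_qdtrack_graph()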