"""Memory for CC-3DT inference."""
from __future__ import annotations
from typing import TypedDict
import torch
from torch import Tensor, nn
from vis4d.common.typing import DictStrAny
from vis4d.op.box.box2d import bbox_iou
from vis4d.op.track3d.cc_3dt import CC3DTrackAssociation, get_track_3d_out
from vis4d.op.track3d.common import Track3DOut
from vis4d.op.track.assignment import TrackIDCounter
from .motion import BaseMotionModel, KF3DMotionModel, LSTM3DMotionModel
class Track(TypedDict):
"""CC-3DT Track state.
Attributes:
box_2d (Tensor): In shape (4,) and contains x1, y1, x2, y2.
score_2d (Tensor): In shape (1,).
box_3d (Tensor): In shape (12,) and contains x, y, z, h, w, l, rx, ry, rz, vx, vy, vz.
score_3d (Tensor): In shape (1,).
class_id (Tensor): In shape (1,).
embed (Tensor): In shape (E,). E is the embedding dimension.
motion_model (BaseMotionModel): The motion model.
velocity (Tensor): In shape (motion_dims,).
last_frame (int): The last frame the track was updated.
acc_frame (int): The number of frames the track was updated.
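Example:
    A construction sketch (the embedding size 128 and the motion model
    arguments are illustrative; tracks are normally created via
    CC3DTrackGraph.create_track)::

        obs_3d = torch.zeros(7)  # x, y, z, h, w, l, rotation
        track = Track(
            box_2d=torch.zeros(4),
            score_2d=torch.ones(1),
            box_3d=torch.zeros(12),
            score_3d=torch.ones(1),
            class_id=torch.zeros(1, dtype=torch.long),
            embed=torch.zeros(128),
            motion_model=KF3DMotionModel(
                num_frames=5, obs_3d=obs_3d, motion_dims=7, fps=2
            ),
            velocity=torch.zeros(7),
            last_frame=0,
            acc_frame=0,
        )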
"""
box_2d: Tensor
score_2d: Tensor
box_3d: Tensor
score_3d: Tensor
class_id: Tensor
embed: Tensor
motion_model: BaseMotionModel
velocity: Tensor
last_frame: int
acc_frame: int
class CC3DTrackGraph:
"""CC-3DT tracking graph."""
def __init__(
self,
track: CC3DTrackAssociation | None = None,
memory_size: int = 10,
memory_momentum: float = 0.8,
backdrop_memory_size: int = 1,
nms_backdrop_iou_thr: float = 0.3,
motion_model: str = "KF3D",
lstm_model: nn.Module | None = None,
motion_dims: int = 7,
num_frames: int = 5,
fps: int = 2,
update_3d_score: bool = True,
add_backdrops: bool = True,
) -> None:
"""Creates an instance of the class."""
assert memory_size >= 0
self.memory_size = memory_size
assert 0 <= memory_momentum <= 1.0
self.memory_momentum = memory_momentum
assert backdrop_memory_size >= 0
self.backdrop_memory_size = backdrop_memory_size
self.nms_backdrop_iou_thr = nms_backdrop_iou_thr
self.tracker = CC3DTrackAssociation() if track is None else track
self.tracklets: dict[int, Track] = {}
self.backdrops: list[DictStrAny] = []
if motion_model == "VeloLSTM":
assert (
lstm_model is not None
), "lstm_model must be provided for VeloLSTM"
self.lstm_model = lstm_model
self.motion_model = motion_model
self.motion_dims = motion_dims
self.num_frames = num_frames
self.fps = fps
self.update_3d_score = update_3d_score
self.add_backdrops = add_backdrops
def reset(self) -> None:
"""Empty the memory."""
self.tracklets.clear()
self.backdrops.clear()
def is_empty(self) -> bool:
"""Check if the memory is empty."""
return len(self.tracklets) == 0
def get_tracks(
self,
device: torch.device,
frame_id: int | None = None,
add_backdrops: bool = False,
) -> tuple[
Tensor,
Tensor,
Tensor,
Tensor,
Tensor,
Tensor,
Tensor,
list[BaseMotionModel],
Tensor,
]:
"""Get tracklests.
If the frame_id is not provided, will return the latest state of all
tracklets. Otherwise, will return the state of all tracklets at the
given frame_id. If add_backdrops is True, will also return the
backdrops.
Args:
device (torch.device): Device to put the tensors on.
frame_id (int, optional): Frame id to query. Defaults to None.
add_backdrops (bool, optional): Whether to add backdrops to the
output. Defaults to False.
Returns:
boxes_2d (Tensor): 2D boxes in shape (N, 4).
scores_2d (Tensor): 2D scores in shape (N,).
boxes_3d (Tensor): 3D boxes in shape (N, 12).
scores_3d (Tensor): 3D scores in shape (N,).
class_ids (Tensor): Class ids in shape (N,).
track_ids (Tensor): Track ids in shape (N,).
embeds (Tensor): Embeddings in shape (N, E).
motion_models (list[BaseMotionModel]): Motion models.
velocities (Tensor): Velocities in shape (N, motion_dims).
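Example:
    A query sketch, assuming ``graph`` is a populated CC3DTrackGraph::

        boxes_2d, _, boxes_3d, *rest = graph.get_tracks(
            torch.device("cpu")
        )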
"""
(
boxes_2d_list,
scores_2d_list,
boxes_3d_list,
scores_3d_list,
class_ids_list,
embeds_list,
motion_models,
velocities_list,
track_ids_list,
) = ([], [], [], [], [], [], [], [], [])
for track_id, track in self.tracklets.items():
if frame_id is None or track["last_frame"] == frame_id:
boxes_2d_list.append(track["box_2d"].unsqueeze(0))
scores_2d_list.append(track["score_2d"].unsqueeze(0))
boxes_3d_list.append(track["box_3d"].unsqueeze(0))
scores_3d_list.append(track["score_3d"].unsqueeze(0))
class_ids_list.append(track["class_id"].unsqueeze(0))
embeds_list.append(track["embed"].unsqueeze(0))
motion_models.append(track["motion_model"])
velocities_list.append(track["velocity"].unsqueeze(0))
track_ids_list.append(track_id)
boxes_2d = (
torch.cat(boxes_2d_list)
if len(boxes_2d_list) > 0
else torch.empty((0, 4), device=device)
)
scores_2d = (
torch.cat(scores_2d_list)
if len(scores_2d_list) > 0
else torch.empty((0,), device=device)
)
boxes_3d = (
torch.cat(boxes_3d_list)
if len(boxes_3d_list) > 0
else torch.empty((0, 12), device=device)
)
scores_3d = (
torch.cat(scores_3d_list)
if len(scores_3d_list) > 0
else torch.empty((0,), device=device)
)
class_ids = (
torch.cat(class_ids_list)
if len(class_ids_list) > 0
else torch.empty((0,), device=device)
)
embeds = (
torch.cat(embeds_list)
if len(embeds_list) > 0
else torch.empty((0,), device=device)
)
velocities = (
torch.cat(velocities_list)
if len(velocities_list) > 0
else torch.empty((0, self.motion_dims), device=device)
)
track_ids = torch.tensor(track_ids_list, device=device)
if add_backdrops:
for backdrop in self.backdrops:
backdrop_ids = torch.full(
(len(backdrop["embeddings"]),),
-1,
dtype=torch.long,
device=device,
)
track_ids = torch.cat([track_ids, backdrop_ids])
boxes_2d = torch.cat([boxes_2d, backdrop["boxes_2d"]])
scores_2d = torch.cat([scores_2d, backdrop["scores_2d"]])
boxes_3d = torch.cat([boxes_3d, backdrop["boxes_3d"]])
scores_3d = torch.cat([scores_3d, backdrop["scores_3d"]])
class_ids = torch.cat([class_ids, backdrop["class_ids"]])
embeds = torch.cat([embeds, backdrop["embeddings"]])
motion_models.extend(backdrop["motion_models"])
backdrop_vs = torch.zeros_like(
backdrop["boxes_3d"][:, : self.motion_dims]
)
velocities = torch.cat([velocities, backdrop_vs])
return (
boxes_2d,
scores_2d,
boxes_3d,
scores_3d,
class_ids,
track_ids,
embeds,
motion_models,
velocities,
)
def __call__(
self,
boxes_2d: Tensor,
scores_2d: Tensor,
camera_ids: Tensor,
boxes_3d: Tensor,
scores_3d: Tensor,
class_ids: Tensor,
embeddings: Tensor,
frame_id: int,
) -> Track3DOut:
"""Update the tracker with new detections."""
if frame_id == 0:
self.reset()
TrackIDCounter.reset()
if not self.is_empty():
(
_,
_,
memo_boxes_3d,
_,
memo_class_ids,
memo_track_ids,
memo_embeds,
memo_motion_models,
memo_velocities,
) = self.get_tracks(
boxes_2d.device, add_backdrops=self.add_backdrops
)
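# Association operates on a 7-dim box state: center (x, y, z), dims (h, w, l), and the rotation stored at index 8.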
memory_boxes_3d = torch.cat(
[memo_boxes_3d[:, :6], memo_boxes_3d[:, 8].unsqueeze(1)],
dim=1,
)
memory_track_ids = memo_track_ids
memory_class_ids = memo_class_ids
memory_embeddings = memo_embeds
memory_boxes_3d_predict = memory_boxes_3d.clone()
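# Advance each track one step with its motion model; the entries after motion_dims hold the predicted center displacement.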
for i, memo_motion_model in enumerate(memo_motion_models):
pd_box_3d = memo_motion_model.predict(
update_state=memo_motion_model.age != 0
)
memory_boxes_3d_predict[i, :3] += pd_box_3d[self.motion_dims :]
memory_velocities = memo_velocities
else:
memory_boxes_3d = None
memory_track_ids = None
memory_class_ids = None
memory_embeddings = None
memory_boxes_3d_predict = None
memory_velocities = None
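# Reduce observed boxes to the same 7-dim state used for the memory.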
obs_boxes_3d = torch.cat(
[boxes_3d[:, :6], boxes_3d[:, 8].unsqueeze(1)], dim=1
)
track_ids, filter_indices = self.tracker(
boxes_2d,
camera_ids,
scores_2d,
obs_boxes_3d,
scores_3d,
class_ids,
embeddings,
memory_boxes_3d,
memory_track_ids,
memory_class_ids,
memory_embeddings,
memory_boxes_3d_predict,
memory_velocities,
self.update_3d_score,
)
self.update(
frame_id,
track_ids,
boxes_2d[filter_indices],
scores_2d[filter_indices],
camera_ids[filter_indices],
boxes_3d[filter_indices],
scores_3d[filter_indices],
class_ids[filter_indices],
embeddings[filter_indices],
obs_boxes_3d[filter_indices],
)
(
_,
scores_2d,
boxes_3d,
scores_3d,
class_ids,
track_ids,
_,
_,
_,
) = self.get_tracks(boxes_2d.device, frame_id=frame_id)
# Optionally fuse the 2D and 3D scores into the final 3D score
if self.update_3d_score:
track_scores_3d = scores_2d * scores_3d
else:
track_scores_3d = scores_3d
return get_track_3d_out(
boxes_3d, class_ids, track_scores_3d, track_ids
)
def update(
self,
frame_id: int,
track_ids: Tensor,
boxes_2d: Tensor,
scores_2d: Tensor,
camera_ids: Tensor,
boxes_3d: Tensor,
scores_3d: Tensor,
class_ids: Tensor,
embeddings: Tensor,
obs_boxes_3d: Tensor,
) -> None:
"""Update the track memory with a new state."""
valid_tracks = track_ids > -1
# Update matched tracks and create new ones
for (
track_id,
box_2d,
score_2d,
box_3d,
score_3d,
class_id,
embed,
obs_box_3d,
) in zip(
track_ids[valid_tracks],
boxes_2d[valid_tracks],
scores_2d[valid_tracks],
boxes_3d[valid_tracks],
scores_3d[valid_tracks],
class_ids[valid_tracks],
embeddings[valid_tracks],
obs_boxes_3d[valid_tracks],
):
track_id = int(track_id)
if track_id in self.tracklets:
self.update_track(
track_id,
box_2d,
score_2d,
box_3d,
score_3d,
class_id,
embed,
obs_box_3d,
frame_id,
)
else:
self.create_track(
track_id,
box_2d,
score_2d,
box_3d,
score_3d,
class_id,
embed,
obs_box_3d,
frame_id,
)
# Handle vanished tracklets
for track_id, track in self.tracklets.items():
if frame_id > track["last_frame"] and track_id > -1:
pd_box_3d = track["motion_model"].predict()
track["box_3d"][:6] = pd_box_3d[:6]
track["box_3d"][8] = pd_box_3d[6]
# Backdrops: unmatched detections (track id -1) kept in memory so they can still be associated in later frames
backdrop_inds = torch.nonzero(
torch.eq(track_ids, -1), as_tuple=False
).squeeze(1)
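# Suppress backdrops that overlap an earlier detection of the same camera above the IoU threshold; cross-camera overlaps are masked out.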
valid_ious = torch.eq(
camera_ids[backdrop_inds].unsqueeze(1),
camera_ids.unsqueeze(0),
).int()
ious = bbox_iou(boxes_2d[backdrop_inds], boxes_2d)
ious *= valid_ious
for i, ind in enumerate(backdrop_inds):
if (ious[i, :ind] > self.nms_backdrop_iou_thr).any():
backdrop_inds[i] = -1
backdrop_inds = backdrop_inds[backdrop_inds > -1]
backdrop_motion_model = []
for bd_ind in backdrop_inds:
backdrop_motion_model.append(
self.build_motion_model(obs_boxes_3d[bd_ind])
)
self.backdrops.insert(
0,
{
"boxes_2d": boxes_2d[backdrop_inds],
"scores_2d": scores_2d[backdrop_inds],
"boxes_3d": boxes_3d[backdrop_inds],
"scores_3d": scores_3d[backdrop_inds],
"class_ids": class_ids[backdrop_inds],
"embeddings": embeddings[backdrop_inds],
"motion_models": backdrop_motion_model,
},
)
# Delete tracks that have not been updated within the memory window
invalid_ids = []
for k, v in self.tracklets.items():
if frame_id - v["last_frame"] >= self.memory_size:
invalid_ids.append(k)
for invalid_id in invalid_ids:
self.tracklets.pop(invalid_id)
if len(self.backdrops) > self.backdrop_memory_size:
self.backdrops.pop()
def update_track(
self,
track_id: int,
box_2d: Tensor,
score_2d: Tensor,
box_3d: Tensor,
score_3d: Tensor,
class_id: Tensor,
embed: Tensor,
obs_box_3d: Tensor,
frame_id: int,
) -> None:
"""Update a track."""
self.tracklets[track_id]["box_2d"] = box_2d
self.tracklets[track_id]["score_2d"] = score_2d
self.tracklets[track_id]["motion_model"].update(obs_box_3d, score_3d)
pd_box_3d = self.tracklets[track_id]["motion_model"].get_state()[
: self.motion_dims
]
prev_obs = torch.cat(
[
self.tracklets[track_id]["box_3d"][:6],
self.tracklets[track_id]["box_3d"][8].unsqueeze(0),
]
)
self.tracklets[track_id]["box_3d"] = box_3d
self.tracklets[track_id]["box_3d"][:6] = pd_box_3d[:6]
self.tracklets[track_id]["box_3d"][8] = pd_box_3d[6]
self.tracklets[track_id]["box_3d"][9:12] = self.tracklets[track_id][
"motion_model"
].predict_velocity()
self.tracklets[track_id]["score_3d"] = score_3d
self.tracklets[track_id]["class_id"] = class_id
self.tracklets[track_id]["embed"] = (
1 - self.memory_momentum
) * self.tracklets[track_id]["embed"] + self.memory_momentum * embed
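# Per-frame displacement of the 7-dim state since the last update, folded into a running average over all updates below.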
velocity = (pd_box_3d - prev_obs) / (
frame_id - self.tracklets[track_id]["last_frame"]
)
self.tracklets[track_id]["velocity"] = (
self.tracklets[track_id]["velocity"]
* self.tracklets[track_id]["acc_frame"]
+ velocity
) / (self.tracklets[track_id]["acc_frame"] + 1)
self.tracklets[track_id]["last_frame"] = frame_id
self.tracklets[track_id]["acc_frame"] += 1
def create_track(
self,
track_id: int,
box_2d: Tensor,
score_2d: Tensor,
box_3d: Tensor,
score_3d: Tensor,
class_id: Tensor,
embed: Tensor,
obs_box_3d: Tensor,
frame_id: int,
) -> None:
"""Create a new track."""
motion_model = self.build_motion_model(obs_box_3d)
self.tracklets[track_id] = Track(
box_2d=box_2d,
score_2d=score_2d,
box_3d=box_3d,
score_3d=score_3d,
class_id=class_id,
embed=embed,
motion_model=motion_model,
velocity=torch.zeros(self.motion_dims, device=box_3d.device),
last_frame=frame_id,
acc_frame=0,
)
def build_motion_model(self, obs_3d: Tensor) -> BaseMotionModel:
"""Build motion model."""
if self.motion_model == "KF3D":
return KF3DMotionModel(
num_frames=self.num_frames,
obs_3d=obs_3d,
motion_dims=self.motion_dims,
fps=self.fps,
)
if self.motion_model == "VeloLSTM":
return LSTM3DMotionModel(
num_frames=self.num_frames,
lstm_model=self.lstm_model,
obs_3d=obs_3d,
motion_dims=self.motion_dims,
fps=self.fps,
)
raise NotImplementedError(
f"Motion model: {self.motion_model} not known!"
)
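# A minimal per-frame usage sketch (hypothetical inputs; N detections with
# E-dim embeddings would come from a CC-3DT detector):
#
#     graph = CC3DTrackGraph(motion_model="KF3D")
#     for frame_id, inputs in enumerate(sequence):
#         tracks = graph(
#             inputs.boxes_2d,  # (N, 4)
#             inputs.scores_2d,  # (N,)
#             inputs.camera_ids,  # (N,)
#             inputs.boxes_3d,  # (N, 12)
#             inputs.scores_3d,  # (N,)
#             inputs.class_ids,  # (N,)
#             inputs.embeddings,  # (N, E)
#             frame_id,
#         )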