"""NuScenes trajectory dataset."""
from __future__ import annotations
import json
import numpy as np
from scipy.spatial.distance import cdist
from tqdm import tqdm
from vis4d.common.imports import NUSCENES_AVAILABLE
from vis4d.common.logging import rank_zero_info
from vis4d.common.typing import DictStrAny, NDArrayF32
from vis4d.data.typing import DictData
from .base import Dataset
from .util import CacheMappingMixin
if NUSCENES_AVAILABLE:
from nuscenes import NuScenes as NuScenesDevkit
from nuscenes.eval.detection.utils import category_to_detection_name
from nuscenes.utils.data_classes import Quaternion
from nuscenes.utils.splits import create_splits_scenes
else:
raise ImportError("nusenes-devkit is not available.")
[docs]
class NuScenesTrajectory(CacheMappingMixin, Dataset):
"""NuScenes Trajectory dataset with given detection results.
It will generate a trajectory data pair with minimum sequence length. The
detection results will be matched with the ground truth trajectory
according to the BEV distance.
"""
def __init__(
self,
detector: str,
pure_detection: str,
data_root: str,
version: str = "v1.0-trainval",
split: str = "train",
min_seq_len: int = 10,
cache_as_binary: bool = False,
cached_file_path: str | None = None,
) -> None:
"""Init dataset.
Args:
detector (str): The detector name.
pure_detection (str): The path to the pure detection results. It
should be the same format as nuScenes submission format.
data_root (str): The root path of the dataset.
version (str, optional): The version of the dataset. Defaults to
"v1.0-trainval".
split (str, optional): The split of the dataset. Defaults to
"train".
min_seq_len (int, optional): The minimum sequence length of the
trajectory. Defaults to 10.
cache_as_binary (bool, optional): Whether to cache the dataset as
binary. Defaults to False.
cached_file_path (str | None, optional): The path to the cached
file. Defaults to None.
"""
super().__init__()
self.data_root = data_root
self.version = version
self.split = split
self.detector = detector
self.min_seq_len = min_seq_len
self.pure_detection = pure_detection
# Load trajectories
self.samples, _ = self._load_mapping(
self._generate_data_mapping,
cache_as_binary=cache_as_binary,
cached_file_path=cached_file_path,
)
rank_zero_info(f"Generated {len(self.samples)} trajectories.")
[docs]
def __repr__(self) -> str:
"""Concise representation of the dataset."""
return f"NuScenes Trajectory Data with {self.detector} detection"
def _match_gt_pred(
self,
gt_world: NDArrayF32,
gt_class: str,
predictions: list[DictStrAny],
) -> tuple[NDArrayF32, bool]:
"""Match gt and pred according to BEV center distance.
If the distance is less than 2 meters, the prediction will be used
instead of the ground truth.
"""
if len(predictions) > 0:
same_class_preds = [
pred
for pred in predictions
if pred["detection_name"] == gt_class
]
if len(same_class_preds) > 0:
preds_center = [
pred["translation"][:2] for pred in same_class_preds
]
distance_matrix = (
cdist( # pylint: disable=unsubscriptable-object
gt_world[:, :2],
np.array(preds_center).reshape(-1, 2),
)[0]
)
if distance_matrix[distance_matrix.argmin()] <= 2:
match_pred = same_class_preds[distance_matrix.argmin()]
# WLH -> HWL
w, l, h = match_pred["size"]
dimensions = [h, w, l]
yaw = Quaternion(match_pred["rotation"]).yaw_pitch_roll[0]
pred_world = np.array(
[
[
*match_pred["translation"],
*dimensions,
yaw,
match_pred["detection_score"],
]
],
dtype=np.float32,
)
return pred_world, False
return gt_world, True
def _generate_data_mapping(self) -> list[dict[str, NDArrayF32]]:
"""Generate trajectories predction and groundtruth.
Trajectories will be generated for each scene. Each trajectory consists
of [x, y, z, h, w, l, yaw, score] in world coordinate.
Returns:
list[dict[str, NDArrayF32]]: The list of trajectories.
"""
data = NuScenesDevkit(
version=self.version, dataroot=self.data_root, verbose=False
)
scene_names_per_split = create_splits_scenes()
scenes = [
scene
for scene in data.scene
if scene["name"] in scene_names_per_split[self.split]
]
instance_tokens = []
with open(self.pure_detection, "r", encoding="utf-8") as f:
predictions = json.load(f)
num_gt_boxes = 0
num_pred_boxes = 0
total_traj = []
for scene in tqdm(scenes):
local_traj: dict[int, dict[str, list[NDArrayF32]]] = {}
sample_token = scene["first_sample_token"]
while sample_token:
sample = data.get("sample", sample_token)
preds = predictions["results"][sample_token]
for ann_token in sample["anns"]:
ann_info = data.get("sample_annotation", ann_token)
box3d_class = category_to_detection_name(
ann_info["category_name"]
)
if box3d_class is None:
continue
box3d = data.get_box(ann_info["token"])
instance_token = data.get(
"sample_annotation", box3d.token
)["instance_token"]
if not instance_token in instance_tokens:
instance_tokens.append(instance_token)
track_id = instance_tokens.index(instance_token)
if track_id not in local_traj:
local_traj[track_id] = {"gt": [], "pred": []}
# WLH -> HWL
w, l, h = box3d.wlh
dimensions = [h, w, l]
yaw = box3d.orientation.yaw_pitch_roll[0]
gt_world = np.array(
[[*box3d.center, *dimensions, yaw, 1.0]],
dtype=np.float32,
)
local_traj[track_id]["gt"].append(gt_world)
matched_pred, is_gt = self._match_gt_pred(
gt_world, box3d_class, preds
)
local_traj[track_id]["pred"].append(matched_pred)
if is_gt:
num_gt_boxes += 1
else:
num_pred_boxes += 1
sample_token = sample["next"]
for _, traj in local_traj.items():
if len(traj["gt"]) >= self.min_seq_len:
trajectory = {
"gt": np.concatenate(traj["gt"]),
"pred": np.concatenate(traj["pred"]),
}
total_traj.append(trajectory)
rank_zero_info(f"Use {num_gt_boxes} gt boxes.")
rank_zero_info(f"Use {num_pred_boxes} pred boxes.")
return total_traj
[docs]
def __len__(self) -> int:
"""Return the length of the dataset."""
return len(self.samples)
[docs]
def __getitem__(self, idx: int) -> DictData:
"""Return the item at the given index.
The trajectory will be randomly cropped to the minimum sequence length.
"""
trajectory = self.samples[idx]
data_dict: DictData = {}
traj_len = len(trajectory["gt"])
if traj_len > self.min_seq_len:
first_frame = np.random.randint(traj_len - self.min_seq_len)
else:
first_frame = 0
data_dict["gt_traj"] = trajectory["gt"][
first_frame : first_frame + self.min_seq_len
]
data_dict["pred_traj"] = trajectory["pred"][
first_frame : first_frame + self.min_seq_len
]
return data_dict