"""Post process after transformation."""
from __future__ import annotations
import torch
from vis4d.common.typing import NDArrayF32, NDArrayI64
from vis4d.data.const import CommonKeys as K
from vis4d.op.box.box2d import bbox_area, bbox_clip
from .base import Transform
[docs]
@Transform(
    in_keys=[
        K.boxes2d,
        K.boxes2d_classes,
        K.boxes2d_track_ids,
        K.input_hw,
        K.boxes3d,
        K.boxes3d_classes,
        K.boxes3d_track_ids,
    ],
    out_keys=[
        K.boxes2d,
        K.boxes2d_classes,
        K.boxes2d_track_ids,
        K.boxes3d,
        K.boxes3d_classes,
        K.boxes3d_track_ids,
    ],
)
class PostProcessBoxes2D:
    """Post process after transformation.

    Optionally clips the 2D boxes to the image and drops every box whose
    area is below ``min_area``. The same per-sample keep mask is applied to
    the 2D classes and, when provided, to the 2D track ids and the 3D boxes,
    classes and track ids.
    """

    def __init__(
        self, min_area: float = 7.0 * 7.0, clip_bboxes_to_image: bool = True
    ) -> None:
        """Creates an instance of the class.

        Args:
            min_area (float): Minimum area of the bounding box. Defaults to
                7.0 * 7.0.
            clip_bboxes_to_image (bool): Whether to clip the bounding boxes to
                the image size. Defaults to True.
        """
        self.min_area = min_area
        self.clip_bboxes_to_image = clip_bboxes_to_image

    def __call__(
        self,
        boxes_list: list[NDArrayF32],
        classes_list: list[NDArrayI64],
        track_ids_list: list[NDArrayI64] | None,
        input_hw_list: list[tuple[int, int]],
        boxes3d_list: list[NDArrayF32] | None,
        boxes3d_classes_list: list[NDArrayI64] | None,
        boxes3d_track_ids_list: list[NDArrayI64] | None,
    ) -> tuple[
        list[NDArrayF32],
        list[NDArrayI64],
        list[NDArrayI64] | None,
        list[NDArrayF32] | None,
        list[NDArrayI64] | None,
        list[NDArrayI64] | None,
    ]:
        """Post process according to boxes2D after transformation.

        All outputs are returned as newly created lists; the input lists are
        not reassigned. Optional inputs that are None stay None in the output.

        Args:
            boxes_list (list[NDArrayF32]): The bounding boxes to be post
                processed.
            classes_list (list[NDArrayI64]): The classes of the bounding
                boxes.
            track_ids_list (list[NDArrayI64] | None): The track ids of the
                bounding boxes.
            input_hw_list (list[tuple[int, int]]): The height and width of the
                input image.
            boxes3d_list (list[NDArrayF32] | None): The 3D bounding boxes to be
                post processed.
            boxes3d_classes_list (list[NDArrayI64] | None): The classes of the
                3D bounding boxes.
            boxes3d_track_ids_list (list[NDArrayI64] | None): The track ids of
                the 3D bounding boxes.

        Returns:
            tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None,
                list[NDArrayF32] | None, list[NDArrayI64] | None,
                list[NDArrayI64] | None]: The post processed results.
        """
        new_boxes: list[NDArrayF32] = []
        new_classes: list[NDArrayI64] = []
        new_track_ids: list[NDArrayI64] | None = (
            [] if track_ids_list is not None else None
        )
        new_boxes3d: list[NDArrayF32] | None = (
            [] if boxes3d_list is not None else None
        )
        new_boxes3d_classes: list[NDArrayI64] | None = (
            [] if boxes3d_classes_list is not None else None
        )
        new_boxes3d_track_ids: list[NDArrayI64] | None = (
            [] if boxes3d_track_ids_list is not None else None
        )

        for i, (boxes, classes) in enumerate(zip(boxes_list, classes_list)):
            boxes_ = torch.from_numpy(boxes)
            if self.clip_bboxes_to_image:
                boxes_ = bbox_clip(boxes_, input_hw_list[i])

            # Boxes are kept when their (post-clip) area reaches min_area.
            keep = (bbox_area(boxes_) >= self.min_area).numpy()

            # Take the boxes from the tensor returned by bbox_clip so the
            # output is clipped whether or not bbox_clip writes in place.
            # NOTE(review): bbox_clip is defined elsewhere; confirm its
            # in-place behavior if the input arrays must stay unclipped.
            new_boxes.append(boxes_.numpy()[keep])
            new_classes.append(classes[keep])

            # The asserts below only narrow the Optional types for the type
            # checker; each new_* list is non-None exactly when its input is.
            if track_ids_list is not None:
                assert new_track_ids is not None
                new_track_ids.append(track_ids_list[i][keep])

            if boxes3d_list is not None:
                assert new_boxes3d is not None
                new_boxes3d.append(boxes3d_list[i][keep])

            if boxes3d_classes_list is not None:
                assert new_boxes3d_classes is not None
                new_boxes3d_classes.append(boxes3d_classes_list[i][keep])

            if boxes3d_track_ids_list is not None:
                assert new_boxes3d_track_ids is not None
                new_boxes3d_track_ids.append(boxes3d_track_ids_list[i][keep])

        return (
            new_boxes,
            new_classes,
            new_track_ids,
            new_boxes3d,
            new_boxes3d_classes,
            new_boxes3d_track_ids,
        )
[docs]
@Transform(in_keys=[K.boxes2d_track_ids], out_keys=[K.boxes2d_track_ids])
class RescaleTrackIDs:
    """Rescale track ids.

    Every distinct track id found across the given arrays is replaced by a
    consecutive integer starting at zero, numbered in order of first
    appearance.
    """

    def __call__(self, track_ids_list: list[NDArrayI64]) -> list[NDArrayI64]:
        """Rescale the track ids.

        The input arrays are left unmodified; remapped copies are returned.

        Args:
            track_ids_list (list[NDArrayI64]): The track ids to be
                rescaled.

        Returns:
            list[NDArrayI64]: The rescaled track ids.
        """
        # First pass: assign each unseen id the next free consecutive index.
        track_ids_all: dict[int, int] = {}
        for track_ids in track_ids_list:
            for track_id in track_ids:
                if track_id not in track_ids_all:
                    track_ids_all[track_id] = len(track_ids_all)

        # Second pass: write the new ids into copies while reading from the
        # untouched inputs, so an array that appears more than once in the
        # list is never looked up with already-remapped values.
        rescaled: list[NDArrayI64] = []
        for track_ids in track_ids_list:
            remapped = track_ids.copy()
            for i, track_id in enumerate(track_ids):
                remapped[i] = track_ids_all[track_id]
            rescaled.append(remapped)
        return rescaled