Source code for vis4d.data.transforms.post_process

"""Post process after transformation."""

from __future__ import annotations

import torch

from vis4d.common.typing import NDArrayF32, NDArrayI64
from vis4d.data.const import CommonKeys as K
from vis4d.op.box.box2d import bbox_area, bbox_clip

from .base import Transform


@Transform(
    in_keys=[
        K.boxes2d,
        K.boxes2d_classes,
        K.boxes2d_track_ids,
        K.input_hw,
        K.boxes3d,
        K.boxes3d_classes,
        K.boxes3d_track_ids,
    ],
    out_keys=[
        K.boxes2d,
        K.boxes2d_classes,
        K.boxes2d_track_ids,
        K.boxes3d,
        K.boxes3d_classes,
        K.boxes3d_track_ids,
    ],
)
class PostProcessBoxes2D:
    """Post process 2D boxes (and their aligned annotations).

    For every sample the 2D boxes are optionally clipped to the image and
    every box whose area is below ``min_area`` is dropped. The same keep-mask
    is applied to the classes, the track ids and the 3D annotations of that
    sample, so all of them are expected to be index-aligned with the 2D boxes.
    """

    def __init__(
        self, min_area: float = 7.0 * 7.0, clip_bboxes_to_image: bool = True
    ) -> None:
        """Creates an instance of the class.

        Args:
            min_area (float): Minimum area of the bounding box. Boxes with a
                smaller area are removed. Defaults to 7.0 * 7.0.
            clip_bboxes_to_image (bool): Whether to clip the bounding boxes to
                the image size before the area is measured. Defaults to True.
        """
        self.min_area = min_area
        self.clip_bboxes_to_image = clip_bboxes_to_image

    def __call__(
        self,
        boxes_list: list[NDArrayF32],
        classes_list: list[NDArrayI64],
        track_ids_list: list[NDArrayI64] | None,
        input_hw_list: list[tuple[int, int]],
        boxes3d_list: list[NDArrayF32] | None,
        boxes3d_classes_list: list[NDArrayI64] | None,
        boxes3d_track_ids_list: list[NDArrayI64] | None,
    ) -> tuple[
        list[NDArrayF32],
        list[NDArrayI64],
        list[NDArrayI64] | None,
        list[NDArrayF32] | None,
        list[NDArrayI64] | None,
        list[NDArrayI64] | None,
    ]:
        """Post process according to boxes2D after transformation.

        The input lists themselves are not modified; new lists are returned.

        NOTE(review): ``torch.from_numpy`` shares memory with the input box
        array, so if ``bbox_clip`` clips in place the input arrays are clipped
        as well -- confirm against ``bbox_clip``.

        Args:
            boxes_list (list[NDArrayF32]): The bounding boxes to be post
                processed.
            classes_list (list[NDArrayI64]): The classes of the bounding
                boxes.
            track_ids_list (list[NDArrayI64] | None): The track ids of the
                bounding boxes.
            input_hw_list (list[tuple[int, int]]): The height and width of the
                input image. Only read when clipping is enabled.
            boxes3d_list (list[NDArrayF32] | None): The 3D bounding boxes to
                be post processed.
            boxes3d_classes_list (list[NDArrayI64] | None): The classes of the
                3D bounding boxes.
            boxes3d_track_ids_list (list[NDArrayI64] | None): The track ids of
                the 3D bounding boxes.

        Returns:
            tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None,
                list[NDArrayF32] | None, list[NDArrayI64] | None,
                list[NDArrayI64] | None]: The filtered 2D boxes, classes,
                track ids, 3D boxes, 3D classes and 3D track ids. An optional
                output is None exactly when its input was None.
        """
        new_boxes: list[NDArrayF32] = []
        new_classes: list[NDArrayI64] = []
        new_track_ids: list[NDArrayI64] | None = (
            [] if track_ids_list is not None else None
        )
        new_boxes3d: list[NDArrayF32] | None = (
            [] if boxes3d_list is not None else None
        )
        new_boxes3d_classes: list[NDArrayI64] | None = (
            [] if boxes3d_classes_list is not None else None
        )
        new_boxes3d_track_ids: list[NDArrayI64] | None = (
            [] if boxes3d_track_ids_list is not None else None
        )

        for i, (boxes, classes) in enumerate(zip(boxes_list, classes_list)):
            boxes_ = torch.from_numpy(boxes)
            if self.clip_bboxes_to_image:
                boxes_ = bbox_clip(boxes_, input_hw_list[i])

            # The mask is computed on the (possibly clipped) boxes.
            keep = (bbox_area(boxes_) >= self.min_area).numpy()

            # Select from the clipped tensor explicitly, so the output is
            # clipped whether or not bbox_clip works in place.
            new_boxes.append(boxes_.numpy()[keep])
            new_classes.append(classes[keep])

            # The asserts below only narrow the Optional types for the type
            # checker; each new_* list is non-None whenever its input is.
            if track_ids_list is not None:
                assert new_track_ids is not None
                new_track_ids.append(track_ids_list[i][keep])

            if boxes3d_list is not None:
                assert new_boxes3d is not None
                new_boxes3d.append(boxes3d_list[i][keep])

            if boxes3d_classes_list is not None:
                assert new_boxes3d_classes is not None
                new_boxes3d_classes.append(boxes3d_classes_list[i][keep])

            if boxes3d_track_ids_list is not None:
                assert new_boxes3d_track_ids is not None
                new_boxes3d_track_ids.append(boxes3d_track_ids_list[i][keep])

        return (
            new_boxes,
            new_classes,
            new_track_ids,
            new_boxes3d,
            new_boxes3d_classes,
            new_boxes3d_track_ids,
        )
@Transform(in_keys=[K.boxes2d_track_ids], out_keys=[K.boxes2d_track_ids])
class RescaleTrackIDs:
    """Rescale track ids."""

    def __call__(self, track_ids_list: list[NDArrayI64]) -> list[NDArrayI64]:
        """Rescale the track ids.

        Every distinct id found across all entries of ``track_ids_list`` is
        replaced by a consecutive index starting at 0, numbered in order of
        first appearance. The arrays are rewritten in place.

        Args:
            track_ids_list (list[NDArrayI64]): The track ids to be rescaled.

        Returns:
            list[NDArrayI64]: The same list object, holding the rescaled
                track ids.
        """
        remap: dict[int, int] = {}

        # First pass: give each unseen id the next free index. The default
        # ``len(remap)`` is evaluated before any insertion takes place.
        for ids in track_ids_list:
            for raw_id in ids:
                remap.setdefault(raw_id, len(remap))

        # Second pass: overwrite each entry with its new index.
        for ids in track_ids_list:
            for pos, raw_id in enumerate(ids):
                ids[pos] = remap[raw_id]

        return track_ids_list