"""Utility functions for bounding boxes."""
from __future__ import annotations
import torch
from torch import Tensor
from torchvision.ops import batched_nms, nms
from vis4d.common.logging import rank_zero_warn
from vis4d.op.geometry.transform import transform_points
# [docs]
def bbox_scale(
    boxes: torch.Tensor, scale_factor_xy: tuple[float, float]
) -> torch.Tensor:
    """Scale bounding box tensor.

    Note: the input tensor is modified in place and also returned.

    Args:
        boxes (torch.Tensor): Bounding boxes with shape [N, 4]
        scale_factor_xy (tuple[float, float]): Scaling factor for x and y

    Returns:
        torch.Tensor with bounding boxes scaled by the given factors in
            x and y direction
    """
    scale_x, scale_y = scale_factor_xy
    # Columns 0/2 hold x-coordinates, columns 1/3 hold y-coordinates.
    boxes[:, 0::2] *= scale_x
    boxes[:, 1::2] *= scale_y
    return boxes
# [docs]
def bbox_clip(
    boxes: torch.Tensor,
    image_hw: tuple[float, float],
    epsilon: int = 0,
) -> torch.Tensor:
    """Clip bounding boxes to image dims.

    Note: the input tensor is modified in place and also returned.

    Args:
        boxes (torch.Tensor): Bounding boxes with shape [N, 4]
        image_hw (tuple[float, float]): Image dimensions.
        epsilon (int): Epsilon for clipping. Defaults to 0.

    Returns:
        torch.Tensor: Clipped bounding boxes.
    """
    im_h, im_w = image_hw
    # x-coordinates are bounded by the image width, y by the height.
    boxes[:, 0::2].clamp_(0, im_w - epsilon)
    boxes[:, 1::2].clamp_(0, im_h - epsilon)
    return boxes
# [docs]
def scale_and_clip_boxes(
    boxes: torch.Tensor,
    original_hw: tuple[int, int],
    current_hw: tuple[int, int],
    clip: bool = True,
) -> torch.Tensor:
    """Postprocess boxes by scaling and clipping to given image dims.

    Args:
        boxes (torch.Tensor): Bounding boxes with shape [N, 4].
        original_hw (tuple[int, int]): Original height / width of image.
        current_hw (tuple[int, int]): Current height / width of image.
        clip (bool): If true, clips box corners to image bounds.

    Returns:
        torch.Tensor: Rescaled and possibly clipped bounding boxes.
    """
    # Scale from the current image size back to the original one.
    ratio_x = original_hw[1] / current_hw[1]
    ratio_y = original_hw[0] / current_hw[0]
    boxes = bbox_scale(boxes, (ratio_x, ratio_y))
    if clip:
        boxes = bbox_clip(boxes, original_hw)
    return boxes
# [docs]
def bbox_area(boxes: torch.Tensor) -> torch.Tensor:
    """Compute bounding box areas.

    Args:
        boxes (torch.Tensor): [N, 4] tensor of 2D boxes
            in format (x1, y1, x2, y2).

    Returns:
        torch.Tensor: [N,] tensor of box areas.
    """
    # Degenerate boxes (x2 < x1 or y2 < y1) get zero width/height.
    widths = (boxes[:, 2] - boxes[:, 0]).clamp(min=0)
    heights = (boxes[:, 3] - boxes[:, 1]).clamp(min=0)
    return widths * heights
# [docs]
def bbox_intersection(boxes1: Tensor, boxes2: Tensor) -> torch.Tensor:
    """Given two lists of boxes of size N and M, compute N x M intersection.

    Args:
        boxes1: N 2D boxes in format (x1, y1, x2, y2)
        boxes2: M 2D boxes in format (x1, y1, x2, y2)

    Returns:
        Tensor: intersection (N, M).
    """
    # Pairwise overlap rectangle: top-left is the max of the mins,
    # bottom-right is the min of the maxes.
    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # (N, M, 2)
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # (N, M, 2)
    wh = (bottom_right - top_left).clamp(min=0)
    return wh[..., 0] * wh[..., 1]
# [docs]
def bbox_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
    """Compute IoU between all pairs of boxes.

    Args:
        boxes1: N 2D boxes in format (x1, y1, x2, y2)
        boxes2: M 2D boxes in format (x1, y1, x2, y2)

    Returns:
        Tensor: IoU (N, M).
    """
    area1 = bbox_area(boxes1)
    area2 = bbox_area(boxes2)
    inter = bbox_intersection(boxes1, boxes2)
    union = area1[:, None] + area2 - inter
    # Since inter <= min(area1, area2), union == 0 implies inter == 0, so
    # masking on union > 0 alone is sufficient; degenerate (zero-area)
    # pairs get IoU 0 instead of NaN. This replaces the previous version,
    # which zeroed inter against the union and then recomputed the union.
    iou = torch.where(
        union > 0,
        inter / union,
        torch.zeros(1, dtype=inter.dtype, device=inter.device),
    )
    return iou
# [docs]
def bbox_intersection_aligned(boxes1: Tensor, boxes2: Tensor) -> torch.Tensor:
    """Given two lists of boxes both of size N, compute N intersection.

    Args:
        boxes1: N 2D boxes in format (x1, y1, x2, y2)
        boxes2: N 2D boxes in format (x1, y1, x2, y2)

    Returns:
        Tensor: intersection (N).
    """
    # Element-wise overlap rectangle of each aligned pair.
    top_left = torch.max(boxes1[:, :2], boxes2[:, :2])  # (N, 2)
    bottom_right = torch.min(boxes1[:, 2:], boxes2[:, 2:])  # (N, 2)
    wh = (bottom_right - top_left).clamp(min=0)
    return wh[:, 0] * wh[:, 1]
# [docs]
def bbox_iou_aligned(
    boxes1: torch.Tensor, boxes2: torch.Tensor
) -> torch.Tensor:
    """Compute IoU between aligned pairs of boxes.

    The number of boxes in both inputs must be the same.

    Args:
        boxes1: N 2D boxes in format (x1, y1, x2, y2)
        boxes2: N 2D boxes in format (x1, y1, x2, y2)

    Returns:
        Tensor: IoU (N).
    """
    inter = bbox_intersection_aligned(boxes1, boxes2)
    union = bbox_area(boxes1) + bbox_area(boxes2) - inter
    # Pairs with empty intersection (incl. degenerate boxes) get IoU 0.
    zero = torch.zeros(1, dtype=inter.dtype, device=inter.device)
    return torch.where(inter > 0, inter / union, zero)
# TODO, refactor? move to utils?
# [docs]
def random_choice(tensor: torch.Tensor, sample_size: int) -> torch.Tensor:
    """Randomly choose elements from a tensor.

    If sample_size <= len(tensor) this function will sample without
    repetition, otherwise certain elements will be repeated.

    Args:
        tensor (torch.Tensor): Tensor to sample from.
        sample_size (int): Number of elements to sample.

    Returns:
        torch.Tensor containing sample_size randomly sampled entries.
    """
    perm = torch.randperm(len(tensor), device=tensor.device)[:sample_size]
    if sample_size > len(tensor):
        # Not enough unique entries: draw the remaining indices uniformly
        # over the FULL tensor, with repetition. The upper bound must be
        # len(tensor); the previous code used remaining_samples, which
        # biased the draw and indexed out of bounds whenever
        # remaining_samples > len(tensor).
        remaining_samples = sample_size - len(tensor)
        extra = torch.randint(
            len(tensor),
            (remaining_samples,),
            device=tensor.device,
        )
        perm = torch.concat([extra, perm])
    return tensor[perm]
# [docs]
def non_intersection(
    tensor_a: torch.Tensor, tensor_b: torch.Tensor
) -> torch.Tensor:
    """Get the elements of tensor_a that are not present in tensor_b.

    Args:
        tensor_a (torch.Tensor): First tensor.
        tensor_b (torch.Tensor): Second tensor.

    Returns:
        torch.Tensor containing the elements of tensor_a that do NOT
            occur in tensor_b. (The previous docstring incorrectly
            described the intersection.)
    """
    # Broadcast-compare every element of tensor_b against every element of
    # tensor_a; an element of tensor_a is kept iff it differs from all of
    # tensor_b. `.all` replaces the previous `prod(1) == 1` bool trick.
    compareview = tensor_b.repeat(tensor_a.shape[0], 1).T
    return tensor_a[(compareview != tensor_a).T.all(dim=1)]
# [docs]
def apply_mask(
    masks: list[torch.Tensor], *args: list[torch.Tensor]
) -> tuple[list[torch.Tensor], ...]:
    """Apply given masks (either bool or indices) to given list of tensors.

    Args:
        masks (list[torch.Tensor]): Masks to apply on tensors.
        *args (list[torch.Tensor]): List of tensors to apply the masks on.

    Returns:
        tuple[list[torch.Tensor], ...]: Masked tensor lists.
    """
    masked_lists = []
    for tensor_list in args:
        # Empty tensors are passed through unchanged.
        masked_lists.append(
            [
                tensor[mask] if len(tensor) > 0 else tensor
                for tensor, mask in zip(tensor_list, masks)
            ]
        )
    return tuple(masked_lists)
# [docs]
def filter_boxes_by_area(
    boxes: torch.Tensor, min_area: float = 0.0
) -> tuple[torch.Tensor, torch.Tensor]:
    """Filter a set of 2D bounding boxes given a minimum area.

    Args:
        boxes (Tensor): 2D bounding boxes [N, 4].
        min_area (float, optional): Minimum area. Defaults to 0.0.

    Returns:
        tuple[Tensor, Tensor]: filtered boxes, boolean mask
    """
    if min_area > 0.0:
        areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        keep = areas >= min_area
        if not keep.all():
            return boxes[keep], keep
    # Nothing filtered out: return the input and an all-true mask.
    return boxes, boxes.new_ones((len(boxes),), dtype=torch.bool)
# [docs]
def hbox2corner(boxes: Tensor) -> Tensor:
    """Convert box coordinates from boxes to corners.

    Boxes are represented as (x1, y1, x2, y2).
    Corners are represented as ((x1, y1), (x2, y1), (x1, y2), (x2, y2)).

    Args:
        boxes (Tensor): Horizontal box tensor with shape of (..., 4).

    Returns:
        Tensor: Corner tensor with shape of (..., 4, 2).
    """
    x1, y1, x2, y2 = boxes.unbind(dim=-1)
    # Corner order: top-left, top-right, bottom-left, bottom-right.
    flat = torch.stack((x1, y1, x2, y1, x1, y2, x2, y2), dim=-1)
    return flat.reshape(*flat.shape[:-1], 4, 2)
# [docs]
def corner2hbox(corners: Tensor) -> Tensor:
    """Convert box coordinates from corners to boxes.

    Boxes are represented as (x1, y1, x2, y2).
    Corners are represented as ((x1, y1), (x2, y1), (x1, y2), (x2, y2)).

    Args:
        corners (Tensor): Corner tensor with shape of (..., 4, 2).

    Returns:
        Tensor: Horizontal box tensor with shape of (..., 4).
    """
    if corners.numel() == 0:
        return corners.new_zeros((0, 4))
    # Enclosing axis-aligned box: min corner followed by max corner.
    return torch.cat(
        [corners.amin(dim=-2), corners.amax(dim=-2)], dim=-1
    )
# [docs]
def bbox_project(boxes: Tensor, homography_matrix: Tensor) -> Tensor:
    """Apply a geometric transform to 2D boxes.

    Note: contrary to the previous docstring, this is NOT in-place; the
    input tensor is left untouched and a new tensor is returned.

    Args:
        boxes (Tensor): Horizontal box tensor with shape of (..., 4).
        homography_matrix (Tensor): Shape (3, 3) for geometric transformation.

    Returns:
        Tensor: Axis-aligned boxes enclosing the projected corners,
            with shape (..., 4).
    """
    corners = hbox2corner(boxes)
    # Lift corners to homogeneous coordinates: (..., 4, 2) -> (..., 4, 3).
    corners = torch.cat(
        [corners, corners.new_ones(*corners.shape[:-1], 1)], dim=-1
    )
    corners_t = torch.transpose(corners, -1, -2)
    corners_t = torch.matmul(homography_matrix, corners_t)
    corners = torch.transpose(corners_t, -1, -2)
    # Back from homogeneous coordinates by perspective division.
    corners = corners[..., :2] / corners[..., 2:3]
    return corner2hbox(corners)
# [docs]
def multiclass_nms(
    multi_bboxes: Tensor,
    multi_scores: Tensor,
    score_thr: float,
    iou_thr: float,
    max_num: int = -1,
    class_agnostic: bool = False,
    split_thr: int = 100000,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Non-maximum suppression with multiple classes.

    Args:
        multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
        multi_scores (Tensor): shape (n, #class), where the last column
            contains scores of the background class, but this will be
            ignored.
        score_thr (float): bbox threshold, bboxes with scores lower than it
            will not be considered.
        iou_thr (float): NMS IoU threshold
        max_num (int, optional): if there are more than max_num bboxes after
            NMS, only top max_num will be kept. Defaults to -1.
        class_agnostic (bool, optional): whether apply class_agnostic NMS.
            Defaults to False.
        split_thr (int, optional): If class_agnostic is True and the number
            of bboxes is below split_thr, a single class-agnostic NMS is
            run; otherwise per-class NMS is used. Defaults to 100000.

    Returns:
        tuple: (Tensor, Tensor, Tensor, Tensor): detections (k, 5), scores
            (k), classes (k) and indices (k).

    Raises:
        RuntimeError: If there is a onnx error,
    """
    num_classes = multi_scores.size(1) - 1
    # Per-class boxes: either given explicitly as (n, #class * 4) or a
    # single box per location shared across classes.
    if multi_bboxes.shape[1] > 4:
        bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
    else:
        bboxes = multi_bboxes[:, None].expand(
            multi_scores.size(0), num_classes, 4
        )
    # Exclude background category (last score column).
    scores = multi_scores[:, :-1]
    labels = torch.arange(num_classes, dtype=torch.long, device=scores.device)
    labels = labels.view(1, -1).expand_as(scores)

    # Flatten to one (box, score, label) triple per class prediction.
    bboxes = bboxes.view(-1, 4)
    scores = scores.reshape(-1)
    labels = labels.reshape(-1)

    # Previously this condition was checked twice back-to-back; the two
    # branches are merged here without behavior change.
    if not torch.onnx.is_in_onnx_export():
        # NonZero not supported in TensorRT: remove low scoring boxes only
        # outside of ONNX export.
        valid_mask = scores > score_thr
        inds = valid_mask.nonzero(as_tuple=False).squeeze(1)
        bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds]
    else:
        # TensorRT NMS plugin has invalid output filled with -1
        # add dummy data to make detection output correct.
        bboxes = torch.cat([bboxes, bboxes.new_zeros(1, 4)], dim=0)
        scores = torch.cat([scores, scores.new_zeros(1)], dim=0)
        labels = torch.cat([labels, labels.new_zeros(1)], dim=0)
        # Keep `inds` defined in the export path as well: no filtering
        # happened, so it is simply the identity. (Previously `inds` was
        # unbound here, raising NameError at the final return.)
        inds = torch.arange(scores.size(0), device=scores.device)

    if bboxes.numel() == 0:
        if torch.onnx.is_in_onnx_export():
            raise RuntimeError(
                "[ONNX Error] Can not record NMS "
                "as it has not been executed this time"
            )
        return bboxes, scores, labels, inds

    if class_agnostic and bboxes.shape[0] < split_thr:
        keep = nms(bboxes, scores, iou_thr)
    else:
        if class_agnostic:
            rank_zero_warn(
                f"Number of bboxes is larger than {split_thr}, "
                "using per-class NMS instead"
            )
        keep = batched_nms(bboxes, scores, labels, iou_thr)

    if max_num > 0:
        keep = keep[:max_num]
    return bboxes[keep], scores[keep], labels[keep], inds[keep]