"""Anchor generator for 2D bounding boxes.
Modified from:
https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/anchor/anchor_generator.py
"""
from __future__ import annotations
import numpy as np
import torch
from torch import Tensor
from torch.nn.modules.utils import _pair
from .util import meshgrid
def anchor_inside_image(
flat_anchors: Tensor, img_shape: tuple[int, int], allowed_border: int = 0
) -> Tensor:
"""Check whether the anchors are inside the border.
Args:
flat_anchors (Tensor): Flatten anchors, shape (n, 4).
img_shape (tuple(int)): Shape of current image.
allowed_border (int): The border to allow the valid anchor.
Defaults to 0.
Returns:
Tensor: Flags indicating whether the anchors are inside a valid range.
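    Example:
        A minimal illustrative check with hand-picked boxes; the second
        anchor crosses the left image border, so its flag is False.

        >>> import torch
        >>> anchors = torch.tensor(
        ...     [[0.0, 0.0, 10.0, 10.0], [-5.0, 0.0, 10.0, 10.0]]
        ... )
        >>> flags = anchor_inside_image(anchors, (20, 20))  # [True, False]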
"""
img_h, img_w = img_shape
inside_flags = (
(flat_anchors[:, 0] >= -allowed_border)
& (flat_anchors[:, 1] >= -allowed_border)
& (flat_anchors[:, 2] < img_w + allowed_border)
& (flat_anchors[:, 3] < img_h + allowed_border)
)
return inside_flags
class AnchorGenerator:
"""Standard anchor generator for 2D anchor-based detectors.
Examples:
>>> from vis4d.op.box.anchor import AnchorGenerator
>>> self = AnchorGenerator([16], [1.], [1.], [9])
>>> all_anchors = self.grid_priors([(2, 2)], device='cpu')
>>> print(all_anchors)
[tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
[11.5000, -4.5000, 20.5000, 4.5000],
[-4.5000, 11.5000, 4.5000, 20.5000],
[11.5000, 11.5000, 20.5000, 20.5000]])]
>>> self = AnchorGenerator([16, 32], [1.], [1.], [9, 18])
>>> all_anchors = self.grid_priors([(2, 2), (1, 1)], device='cpu')
>>> print(all_anchors)
[tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
[11.5000, -4.5000, 20.5000, 4.5000],
[-4.5000, 11.5000, 4.5000, 20.5000],
[11.5000, 11.5000, 20.5000, 20.5000]]), \
tensor([[-9., -9., 9., 9.]])]
"""
def __init__(
self,
strides: list[int] | list[tuple[int, int]],
ratios: list[float],
scales: list[int] | None = None,
base_sizes: list[int] | None = None,
scale_major: bool = True,
octave_base_scale: None | int = None,
scales_per_octave: None | int = None,
centers: list[tuple[float, float]] | None = None,
center_offset: float = 0.0,
) -> None:
"""Creates an instance of the class.
Args:
strides (list[int] | list[tuple[int, int]]): Strides of anchors
in multiple feature levels in order (w, h).
ratios (list[float]): The list of ratios between the height and
width of anchors in a single level.
scales (list[int] | None): Anchor scales for anchors in a single
                level. It cannot be set at the same time as
                `octave_base_scale` and `scales_per_octave`.
base_sizes (list[int] | None): The basic sizes
of anchors in multiple levels.
If None is given, strides will be used as base_sizes.
(If strides are non square, the shortest stride is taken.)
            scale_major (bool): Whether to multiply scales first when
                generating base anchors. If True, the anchors in the same
                row will have the same scales. Defaults to True.
octave_base_scale (int): The base scale of octave.
scales_per_octave (int): Number of scales for each octave.
                `octave_base_scale` and `scales_per_octave` are usually used
                in RetinaNet, and `scales` should be None when they are set.
centers (list[tuple[float, float]] | None): The centers of the
anchor relative to the feature grid center in multiple feature
levels. By default it is set to be None and not used. If a list
of tuple of float is given, they will be used to shift the
centers of anchors.
            center_offset (float): The offset of the center in proportion to
                the anchors' width and height. Defaults to 0.
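        Example:
            A minimal RetinaNet-style sketch; the strides, ratios, and octave
            settings are hand-picked for illustration. Three octave scales
            times three ratios give nine base anchors per level.

            >>> from vis4d.op.box.anchor import AnchorGenerator
            >>> gen = AnchorGenerator(
            ...     strides=[8, 16],
            ...     ratios=[0.5, 1.0, 2.0],
            ...     octave_base_scale=4,
            ...     scales_per_octave=3,
            ... )
            >>> gen.num_base_priors
            [9, 9]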
"""
# check center and center_offset
if center_offset != 0:
assert centers is None, (
"center cannot be set when center_offset"
f"!=0, {centers} is given."
)
if not 0 <= center_offset <= 1:
raise ValueError(
"center_offset should be in range [0, 1], "
f"{center_offset} is given."
)
if centers is not None:
assert len(centers) == len(strides), (
"The number of strides should be the same as centers, got "
f"{strides} and {centers}"
)
# calculate base sizes of anchors
self.strides = [_pair(stride) for stride in strides]
self.base_sizes = (
[min(stride) for stride in self.strides]
if base_sizes is None
else base_sizes
)
assert len(self.base_sizes) == len(self.strides), (
"The number of strides should be the same as base sizes, got "
f"{self.strides} and {self.base_sizes}"
)
# calculate scales of anchors
assert (
octave_base_scale is not None and scales_per_octave is not None
) ^ (scales is not None), (
"scales and octave_base_scale with scales_per_octave cannot"
" be set at the same time"
)
if scales is not None:
self.scales = torch.Tensor(scales)
elif octave_base_scale is not None and scales_per_octave is not None:
octave_scales = np.array(
[
2 ** (i / scales_per_octave)
for i in range(scales_per_octave)
]
)
scales = octave_scales * octave_base_scale # type: ignore
self.scales = torch.Tensor(scales)
else:
raise ValueError(
"Either scales or octave_base_scale with "
"scales_per_octave should be set"
)
self.octave_base_scale = octave_base_scale
self.scales_per_octave = scales_per_octave
self.ratios = torch.Tensor(ratios)
self.scale_major = scale_major
self.centers = centers
self.center_offset = center_offset
self.base_anchors = self.gen_base_anchors()
@property
def num_base_priors(self) -> list[int]:
"""list[int]: The number of priors at a point on the feature grid."""
return [base_anchors.size(0) for base_anchors in self.base_anchors]
@property
def num_levels(self) -> int:
"""int: number of feature levels that the generator will be applied."""
return len(self.strides)
def gen_base_anchors(self) -> list[Tensor]:
"""Generate base anchors.
Returns:
list(torch.Tensor): Base anchors of a feature grid in multiple \
feature levels.
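        Example:
            An illustrative two-level setup with hand-picked values; one
            scale and three ratios yield three base anchors per level.

            >>> from vis4d.op.box.anchor import AnchorGenerator
            >>> gen = AnchorGenerator([4, 8], [0.5, 1.0, 2.0], scales=[1])
            >>> [anchors.shape for anchors in gen.gen_base_anchors()]
            [torch.Size([3, 4]), torch.Size([3, 4])]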
"""
multi_level_base_anchors = []
for i, base_size in enumerate(self.base_sizes):
center = None
if self.centers is not None:
center = self.centers[i]
multi_level_base_anchors.append(
self.gen_single_level_base_anchors(
base_size,
scales=self.scales,
ratios=self.ratios,
center=center,
)
)
return multi_level_base_anchors
def gen_single_level_base_anchors(
self,
base_size: int,
scales: Tensor,
ratios: Tensor,
center: tuple[float, float] | None = None,
) -> Tensor:
"""Generate base anchors of a single level.
Args:
base_size (int): Basic size of an anchor.
scales (Tensor): Scales of the anchor.
            ratios (Tensor): The ratio between the height and width of
                anchors in a single level.
center (tuple[float], optional): The center of the base anchor
related to a single feature grid. Defaults to None.
Returns:
            Tensor: Anchors in a single-level feature map.
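        Example:
            A minimal sketch with hand-picked values; a unit scale and ratio
            produce a single 8x8 box centered at the origin.

            >>> import torch
            >>> from vis4d.op.box.anchor import AnchorGenerator
            >>> gen = AnchorGenerator([8], [1.0], scales=[1])
            >>> base = gen.gen_single_level_base_anchors(
            ...     8, scales=torch.Tensor([1.0]), ratios=torch.Tensor([1.0])
            ... )  # a single box: [[-4., -4., 4., 4.]]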
"""
width, height = base_size, base_size
if center is None:
x_center = self.center_offset * width
y_center = self.center_offset * height
else:
x_center, y_center = center
h_ratios = torch.sqrt(ratios)
w_ratios = 1 / h_ratios
if self.scale_major:
ws = (width * w_ratios[:, None] * scales[None, :]).view(-1)
hs = (height * h_ratios[:, None] * scales[None, :]).view(-1)
else:
ws = (width * scales[:, None] * w_ratios[None, :]).view(-1)
hs = (height * scales[:, None] * h_ratios[None, :]).view(-1)
# use float anchor and the anchor's center is aligned with the
# pixel center
base_anchors = [
x_center - 0.5 * ws,
y_center - 0.5 * hs,
x_center + 0.5 * ws,
y_center + 0.5 * hs,
]
return torch.stack(base_anchors, dim=-1)
def grid_priors(
self,
featmap_sizes: list[tuple[int, int]],
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
) -> list[Tensor]:
"""Generate grid anchors in multiple feature levels.
Args:
featmap_sizes (list[tuple]): List of feature map sizes in
multiple feature levels.
dtype (torch.dtype): Dtype of priors. Default: torch.float32.
            device (torch.device): The device the anchors will be put on.
        Returns:
list[Tensor]: Anchors in multiple feature levels. The sizes of each
tensor should be [N, 4], where
N = width * height * num_base_anchors, width and height
are the sizes of the corresponding feature level,
num_base_anchors is the number of anchors for that level.
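        Example:
            An illustrative two-level call with hand-picked feature map
            sizes; with one base anchor per location the levels yield 16
            and 4 anchors, respectively.

            >>> from vis4d.op.box.anchor import AnchorGenerator
            >>> gen = AnchorGenerator([8, 16], [1.0], scales=[1])
            >>> anchors = gen.grid_priors([(4, 4), (2, 2)])
            >>> [level.shape for level in anchors]
            [torch.Size([16, 4]), torch.Size([4, 4])]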
"""
assert self.num_levels == len(featmap_sizes)
multi_level_anchors = []
for i in range(self.num_levels):
anchors = self.single_level_grid_priors(
featmap_sizes[i], level_idx=i, dtype=dtype, device=device
)
multi_level_anchors.append(anchors)
return multi_level_anchors
def single_level_grid_priors(
self,
featmap_size: tuple[int, int],
level_idx: int,
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
) -> Tensor:
"""Generate grid anchors of a single level.
Args:
featmap_size (tuple[int, int]): Size of the feature maps.
level_idx (int): The index of corresponding feature map level.
dtype (torch.dtype, optional): Data type of points. Defaults to
torch.float32.
device (torch.device): The device the tensor will be put on.
Returns:
Tensor: Anchors in the overall feature maps.
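        Example:
            A minimal sketch with hand-picked values; a 2x2 grid at stride 16
            with one base anchor gives four anchors.

            >>> from vis4d.op.box.anchor import AnchorGenerator
            >>> gen = AnchorGenerator([16], [1.0], scales=[1])
            >>> gen.single_level_grid_priors((2, 2), level_idx=0).shape
            torch.Size([4, 4])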
"""
base_anchors = self.base_anchors[level_idx].to(device).to(dtype)
feat_h, feat_w = featmap_size
stride_w, stride_h = self.strides[level_idx]
        # First create the ranges with the default dtype, then convert to
        # the target `dtype` for ONNX export.
shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w
shift_y = torch.arange(0, feat_h, device=device).to(dtype) * stride_h
shift_xx, shift_yy = meshgrid(shift_x, shift_y)
shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
# first feat_w elements correspond to the first row of shifts
# add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
# shifted anchors (K, A, 4), reshape to (K*A, 4)
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
all_anchors = all_anchors.view(-1, 4)
# first A rows correspond to A anchors of (0, 0) in feature map,
# then (0, 1), (0, 2), ...
return all_anchors
def __repr__(self) -> str:
"""str: a string that describes the module."""
indent_str = " "
repr_str = self.__class__.__name__ + "(\n"
repr_str += f"{indent_str}strides={self.strides},\n"
repr_str += f"{indent_str}ratios={self.ratios},\n"
repr_str += f"{indent_str}scales={self.scales},\n"
repr_str += f"{indent_str}base_sizes={self.base_sizes},\n"
repr_str += f"{indent_str}scale_major={self.scale_major},\n"
repr_str += f"{indent_str}octave_base_scale="
repr_str += f"{self.octave_base_scale},\n"
repr_str += f"{indent_str}scales_per_octave="
repr_str += f"{self.scales_per_octave},\n"
repr_str += f"{indent_str}num_levels={self.num_levels}\n"
repr_str += f"{indent_str}centers={self.centers},\n"
repr_str += f"{indent_str}center_offset={self.center_offset})"
return repr_str