"""Point generator for 2D bounding boxes.
Modified from:
https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/anchor/point_generator.py
"""
from __future__ import annotations
import numpy as np
import torch
from torch.nn.modules.utils import _pair
from .util import meshgrid
class MlvlPointGenerator:
"""Standard points generator for multi-level feature maps.
Used for 2D points-based detectors.
Args:
strides (list[int] | list[tuple[int, int]]): Strides of points in
multiple feature levels, in (w, h) order.
offset (float): The offset of points; the value is normalized by
the corresponding stride. Defaults to 0.5.
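Example:
A minimal usage sketch with two feature levels (strides 8 and 16):
>>> generator = MlvlPointGenerator(strides=[8, 16])
>>> generator.num_levels
2
>>> generator.num_base_priors
[1, 1]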
"""
def __init__(
self, strides: list[int] | list[tuple[int, int]], offset: float = 0.5
):
"""Init."""
self.strides = [_pair(stride) for stride in strides]
self.offset = offset
@property
def num_levels(self) -> int:
"""Number of feature levels."""
return len(self.strides)
@property
def num_base_priors(self) -> list[int]:
"""Number of points at a point on the feature grid."""
return [1 for _ in range(len(self.strides))]
def grid_priors(
self,
featmap_sizes: list[tuple[int, int]],
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cuda"),
with_stride: bool = False,
) -> list[torch.Tensor]:
"""Generate grid points of multiple feature levels.
Args:
featmap_sizes (list[tuple[int, int]]): List of feature map sizes in
multiple feature levels, each (H, W).
dtype (torch.dtype): Dtype of priors. Defaults to torch.float32.
device (torch.device): The device on which the points will be put.
Defaults to torch.device("cuda").
with_stride (bool): Whether to concatenate the stride to the last
dimension of points. Defaults to False.
Returns:
list[torch.Tensor]: Points of multiple feature levels.
The size of each tensor is (N, 2) when ``with_stride`` is
``False``, where N = width * height and width and height
are the sizes of the corresponding feature level; the last
dimension 2 represents (coord_x, coord_y). Otherwise the
shape is (N, 4) and the last dimension 4 represents
(coord_x, coord_y, stride_w, stride_h).
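Example:
A minimal sketch on CPU (so it runs without a GPU) with two levels:
>>> generator = MlvlPointGenerator(strides=[8, 16])
>>> priors = generator.grid_priors(
...     [(2, 2), (1, 1)], device=torch.device("cpu")
... )
>>> [p.shape for p in priors]
[torch.Size([4, 2]), torch.Size([1, 2])]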
"""
assert self.num_levels == len(featmap_sizes)
multi_level_priors = []
for i in range(self.num_levels):
priors = self.single_level_grid_priors(
featmap_sizes[i],
level_idx=i,
dtype=dtype,
device=device,
with_stride=with_stride,
)
multi_level_priors.append(priors)
return multi_level_priors
def single_level_grid_priors(
self,
featmap_size: tuple[int, int],
level_idx: int,
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cuda"),
with_stride: bool = False,
) -> torch.Tensor:
"""Generate grid Points of a single level.
Note:
This function is usually called by method ``self.grid_priors``.
Args:
featmap_size (tuple[int, int]): Size of the feature map, (H, W).
level_idx (int): The index of corresponding feature map level.
dtype (torch.dtype): Dtype of priors. Defaults to torch.float32.
device (torch.device): The device on which the points will be put.
Defaults to torch.device("cuda").
with_stride (bool): Whether to concatenate the stride to the last
dimension of points. Defaults to False.
Returns:
torch.Tensor: Points of a single feature level.
The shape of the tensor is (N, 2) when ``with_stride`` is
``False``, where N = width * height and width and height
are the sizes of the corresponding feature level; the last
dimension 2 represents (coord_x, coord_y). Otherwise the
shape is (N, 4) and the last dimension 4 represents
(coord_x, coord_y, stride_w, stride_h).
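Example:
A minimal sketch on CPU for a 2 x 3 feature map at stride 8, with
strides appended to each point:
>>> generator = MlvlPointGenerator(strides=[8])
>>> points = generator.single_level_grid_priors(
...     (2, 3), level_idx=0, device=torch.device("cpu"), with_stride=True
... )
>>> points.shape
torch.Size([6, 4])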
"""
feat_h, feat_w = featmap_size
stride_w, stride_h = self.strides[level_idx]
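# point centers are laid out at ((x + offset) * stride_w, (y + offset) * stride_h)
# for every grid cell, i.e. at the cell center for the default offset of 0.5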
shift_x = (
torch.arange(0, feat_w, device=device) + self.offset
) * stride_w
# keep shift_x as a Tensor (rather than a Python int) so that the
# model can be exported to ONNX correctly
shift_x = shift_x.to(dtype)
shift_y = (
torch.arange(0, feat_h, device=device) + self.offset
) * stride_h
# keep shift_y as a Tensor (rather than a Python int) so that the
# model can be exported to ONNX correctly
shift_y = shift_y.to(dtype)
shift_xx, shift_yy = meshgrid(shift_x, shift_y)
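# shift_xx and shift_yy are flattened to length feat_h * feat_w,
# one (x, y) pair per grid location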
if not with_stride:
shifts = torch.stack([shift_xx, shift_yy], dim=-1)
else:
# use `shape[0]` instead of `len(shift_xx)` for ONNX export
stride_w = shift_xx.new_full((shift_xx.shape[0],), stride_w).to(
dtype
)
stride_h = shift_xx.new_full((shift_yy.shape[0],), stride_h).to(
dtype
)
shifts = torch.stack(
[shift_xx, shift_yy, stride_w, stride_h], dim=-1
)
all_points = shifts.to(device)
return all_points
def valid_flags(
self,
featmap_sizes: list[tuple[int, int]],
pad_shape: tuple[int, int],
device: torch.device = torch.device("cuda"),
) -> list[torch.Tensor]:
"""Generate valid flags of points of multiple feature levels.
Args:
featmap_sizes (list[tuple[int, int]]): List of feature map sizes in
multiple feature levels, each (H, W).
pad_shape (tuple[int, int]): The padded shape of the image, (H, W).
device (torch.device): The device on which the flags will be put.
Defaults to torch.device("cuda").
Returns:
list[torch.Tensor]: Valid flags of points of multiple levels.
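Example:
A minimal sketch on CPU; the image's padded shape (20, 32) covers only
three of the four feature rows at stride 8, so the last row is invalid:
>>> generator = MlvlPointGenerator(strides=[8])
>>> flags = generator.valid_flags(
...     [(4, 4)], (20, 32), device=torch.device("cpu")
... )
>>> flags[0].shape
torch.Size([16])
>>> int(flags[0].sum())
12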
"""
assert self.num_levels == len(featmap_sizes)
multi_level_flags = []
for i in range(self.num_levels):
point_stride = self.strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w = pad_shape[:2]
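# number of feature-map rows / columns actually covered by this image's
# padded shape; locations beyond them are flagged invalid below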
valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h)
valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w)
flags = self.single_level_valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w), device=device
)
multi_level_flags.append(flags)
return multi_level_flags
def single_level_valid_flags(
self,
featmap_size: tuple[int, int],
valid_size: tuple[int, int],
device: torch.device = torch.device("cuda"),
) -> torch.Tensor:
"""Generate the valid flags of points of a single feature map.
Args:
featmap_size (tuple[int, int]): The size of the feature map, (H, W).
valid_size (tuple[int, int]): The valid size of the feature map,
(H, W).
device (torch.device): The device on which the flags will be put.
Defaults to torch.device("cuda").
Returns:
torch.Tensor: The valid flags of each point in a single-level
feature map.
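Example:
A minimal sketch on CPU where only the first of two rows is valid:
>>> generator = MlvlPointGenerator(strides=[8])
>>> flags = generator.single_level_valid_flags(
...     (2, 3), (1, 3), device=torch.device("cpu")
... )
>>> flags.shape
torch.Size([6])
>>> int(flags.sum())
3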
"""
feat_h, feat_w = featmap_size
valid_h, valid_w = valid_size
assert valid_h <= feat_h and valid_w <= feat_w
valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
valid_x[:valid_w] = True
valid_y[:valid_h] = True
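# a grid location is valid only if both its x and y indices fall
# inside the valid region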
valid_xx, valid_yy = meshgrid(valid_x, valid_y)
valid = valid_xx & valid_yy
return valid