Source code for vis4d.op.detect3d.bevformer.spatial_cross_attention
"""Spatial Cross Attention Module for BEVFormer."""
from __future__ import annotations
import math
import torch
from torch import Tensor, nn
from vis4d.op.layer.ms_deform_attn import (
MSDeformAttentionFunction,
is_power_of_2,
ms_deformable_attention_cpu,
)
from vis4d.op.layer.weight_init import constant_init, xavier_init
class SpatialCrossAttention(nn.Module):
"""An attention module used in BEVFormer."""
def __init__(
self,
embed_dims: int = 256,
num_cams: int = 6,
dropout: float = 0.1,
deformable_attention: MSDeformableAttention3D | None = None,
) -> None:
"""Init.
Args:
embed_dims (int): The embedding dimension of Attention. Default:
256.
num_cams (int): The number of cameras. Default: 6.
            dropout (float): Dropout probability applied to the attention
                output before it is added to the residual. Default: 0.1.
deformable_attention (MSDeformableAttention3D, optional):
The deformable attention module. Default: None. If None,
we will use `MSDeformableAttention3D` with default
parameters.
"""
super().__init__()
self.dropout = nn.Dropout(dropout)
self.deformable_attention = (
deformable_attention or MSDeformableAttention3D()
)
self.embed_dims = embed_dims
self.num_cams = num_cams
self.output_proj = nn.Linear(embed_dims, embed_dims)
self.init_weight()
def init_weight(self) -> None:
"""Default initialization for Parameters of Module."""
xavier_init(self.output_proj, distribution="uniform", bias=0.0)
def forward(
self,
query: Tensor,
reference_points: Tensor,
value: Tensor,
spatial_shapes: Tensor,
level_start_index: Tensor,
bev_mask: Tensor,
query_pos: Tensor | None = None,
) -> Tensor:
"""Forward Function of Detr3DCrossAtten.
Args:
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
reference_points (Tensor): The normalized reference points with
shape (bs, num_query, 4), all elements is range in [0, 1],
top-left (0,0), bottom-right (1, 1), including padding area.
Or (N, Length_{query}, num_levels, 4), add additional two
dimensions is (w, h) to form reference boxes.
value (Tensor): The value tensor with shape `(num_key, bs,
embed_dims)`. (B, N, C, H, W)
spatial_shapes (Tensor): Spatial shape of features in different
level. With shape (num_levels, 2), last dimension represent
(h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape (num_levels) and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
bev_mask (Tensor): The mask of BEV features with shape
(num_query, bs, num_levels, h, w).
query_pos (Tensor): The positional encoding for `query`. Default
None.
Returns:
Tensor: Forwarded results with shape [num_query, bs, embed_dims].
"""
inp_residual = query
slots = torch.zeros_like(query)
if query_pos is not None:
query = query + query_pos
bs = query.shape[0]
d = reference_points.shape[3]
indexes = []
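        # For each camera, collect the indices of the BEV queries whose
        # reference points fall inside that camera image (based on the
        # first sample in the batch).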
for i, mask_per_img in enumerate(bev_mask):
index_query_per_img = mask_per_img[0].sum(-1).nonzero().squeeze(-1)
indexes.append(index_query_per_img)
max_len = max(len(each) for each in indexes)
# Each camera only interacts with its corresponding BEV queries.
# This step can greatly save GPU memory.
queries_rebatch = query.new_zeros(
[bs, self.num_cams, max_len, self.embed_dims]
)
reference_points_rebatch = reference_points.new_zeros(
[bs, self.num_cams, max_len, d, 2]
)
for j in range(bs):
for i, _reference_points in enumerate(reference_points):
index_query_per_img = indexes[i]
queries_rebatch[j, i, : len(index_query_per_img)] = query[
j, index_query_per_img
]
reference_points_rebatch[j, i, : len(index_query_per_img)] = (
_reference_points[j, index_query_per_img]
)
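        # queries_rebatch: (bs, num_cams, max_len, embed_dims) and
        # reference_points_rebatch: (bs, num_cams, max_len, D, 2); queries a
        # camera does not see remain zero-padded up to max_len.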
        _, num_value, bs, _ = value.shape
        value = value.permute(2, 0, 1, 3).reshape(
            bs * self.num_cams, num_value, self.embed_dims
        )
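        # Flatten the cameras into the batch dimension so that each camera
        # is attended independently by the deformable attention below.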
queries = self.deformable_attention(
query=queries_rebatch.view(
bs * self.num_cams, max_len, self.embed_dims
),
reference_points=reference_points_rebatch.view(
bs * self.num_cams, max_len, d, 2
),
value=value,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
).view(bs, self.num_cams, max_len, self.embed_dims)
for j in range(bs):
for i, index_query_per_img in enumerate(indexes):
slots[j, index_query_per_img] += queries[
j, i, : len(index_query_per_img)
]
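        # Each BEV query may be visible in several cameras; average the
        # accumulated features by the number of cameras that observe it
        # (clamped to avoid division by zero).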
count = bev_mask.sum(-1) > 0
count = count.permute(1, 2, 0).sum(-1)
count = torch.clamp(count, min=1.0)
slots = slots / count[..., None]
slots = self.output_proj(slots)
return self.dropout(slots) + inp_residual
class MSDeformableAttention3D(nn.Module):
"""An attention module used in BEVFormer based on Deformable-Detr.
`Deformable DETR: Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
"""
def __init__(
self,
embed_dims: int = 256,
num_heads: int = 8,
num_levels: int = 4,
num_points: int = 8,
im2col_step: int = 64,
batch_first: bool = True,
) -> None:
"""Init.
Args:
embed_dims (int): The embedding dimension of Attention. Default:
256.
            num_heads (int): Parallel attention heads. Default: 8.
            num_levels (int): The number of feature levels used in the
                attention. Default: 4.
            num_points (int): The number of sampling points for each query
                in each head. Default: 8.
im2col_step (int): The step used in image_to_column.
Default: 64.
batch_first (bool): Key, Query and Value are shape of (batch, n,
embed_dim) or (n, batch, embed_dim). Default to True.
"""
super().__init__()
if embed_dims % num_heads != 0:
raise ValueError(
f"embed_dims must be divisible by num_heads, "
f"but got {embed_dims} and {num_heads}"
)
self.batch_first = batch_first
is_power_of_2(embed_dims // num_heads)
self.im2col_step = im2col_step
self.embed_dims = embed_dims
self.num_levels = num_levels
self.num_heads = num_heads
self.num_points = num_points
self.sampling_offsets = nn.Linear(
embed_dims, num_heads * num_levels * num_points * 2
)
self.attention_weights = nn.Linear(
embed_dims, num_heads * num_levels * num_points
)
self.value_proj = nn.Linear(embed_dims, embed_dims)
self.init_weights()
def init_weights(self) -> None:
"""Default initialization for Parameters of Module."""
constant_init(self.sampling_offsets, 0.0)
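        # Initialize the sampling offset bias so that the heads point to
        # `num_heads` evenly spaced directions, with the i-th point of each
        # head scaled by (i + 1) to spread the initial samples radially.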
thetas = torch.mul(
torch.arange(self.num_heads, dtype=torch.float32),
(2.0 * math.pi / self.num_heads),
)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
.view(self.num_heads, 1, 1, 2)
.repeat(1, self.num_levels, self.num_points, 1)
)
for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1
self.sampling_offsets.bias.data = grid_init.view(-1)
constant_init(self.attention_weights, val=0.0, bias=0.0)
xavier_init(self.value_proj, distribution="uniform", bias=0.0)
def forward( # pylint: disable=duplicate-code
self,
query: Tensor,
reference_points: Tensor,
value: Tensor,
spatial_shapes: Tensor,
level_start_index: Tensor,
key_padding_mask: Tensor | None = None,
query_pos: Tensor | None = None,
) -> Tensor:
"""Forward.
Args:
query (Tensor): Query of Transformer with shape (bs, num_query,
embed_dims).
            reference_points (Tensor): The normalized reference points with
                shape (bs, num_query, num_z_anchors, 2). All elements are in
                the range [0, 1], with top-left (0, 0) and bottom-right
                (1, 1), including the padding area. Only a 2-dimensional
                last axis is supported; other shapes raise a ValueError.
value (Tensor): The value tensor with shape `(bs, num_key,
embed_dims)`.
spatial_shapes (Tensor): Spatial shape of features in different
levels. With shape (num_levels, 2), last dimension represents
(h, w).
level_start_index (Tensor): The start index of each level. A tensor
has shape ``(num_levels, )`` and can be represented as [0,
h_0*w_0, h_0*w_0+h_1*w_1, ...].
key_padding_mask (Tensor): ByteTensor for value, with shape [bs,
num_key].
query_pos (Tensor): The positional encoding for `query`.
Default: None.
Returns:
            Tensor: Forwarded results with shape (bs, num_query, embed_dims)
                if `batch_first` is True, otherwise (num_query, bs,
                embed_dims).
"""
if query_pos is not None:
query = query + query_pos
if not self.batch_first:
            # change to (bs, num_query, embed_dims)
query = query.permute(1, 0, 2)
value = value.permute(1, 0, 2)
bs, num_query, _ = query.shape
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
value = self.value_proj(value)
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], 0.0)
value = value.view(bs, num_value, self.num_heads, -1)
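        # The projected values are split into `num_heads` chunks of
        # embed_dims // num_heads channels each.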
sampling_offsets = self.sampling_offsets(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2
)
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_levels * self.num_points
)
attention_weights = attention_weights.softmax(-1)
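        # The softmax normalizes the weights jointly over all levels and
        # sampling points of each head.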
# bs, num_query, num_heads, num_levels, num_all_points
attention_weights = attention_weights.view(
bs, num_query, self.num_heads, self.num_levels, self.num_points
)
        # Each BEV query owns `num_z_anchors` points in 3D space at
        # different heights. After projecting, each BEV query has
        # `num_z_anchors` reference points in each 2D image. For each
        # reference point, we sample `num_points` sampling points, giving
        # `num_points * num_z_anchors` sampling points in total.
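        # Note: `self.num_points` is the total number of sampling points per
        # head and level; the reshape below splits it across the z anchors.
        # E.g., with the default num_points=8 and D=4 z anchors (illustrative
        # values), each reference point receives 8 / 4 = 2 learned offsets.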
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack(
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1
)
bs, num_query, num_z_anchors, xy = reference_points.shape
reference_points = reference_points[:, :, None, None, None, :, :]
sampling_offsets = (
sampling_offsets
/ offset_normalizer[None, None, None, :, None, :]
)
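            # The predicted offsets are interpreted in pixel units of each
            # level; dividing by (w, h) maps them into the same normalized
            # [0, 1] space as the reference points.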
(
bs,
num_query,
num_heads,
num_levels,
num_all_points,
xy,
) = sampling_offsets.shape
sampling_offsets = sampling_offsets.view(
bs,
num_query,
num_heads,
num_levels,
num_all_points // num_z_anchors,
num_z_anchors,
xy,
)
sampling_locations = reference_points + sampling_offsets
(
bs,
num_query,
num_heads,
num_levels,
num_points,
num_z_anchors,
xy,
) = sampling_locations.shape
assert num_all_points == num_points * num_z_anchors
# bs, num_query, num_heads, num_levels, num_all_points, 2
sampling_locations = sampling_locations.view(
bs, num_query, num_heads, num_levels, num_all_points, xy
)
else:
raise ValueError(
"Last dim of reference_points must be 2 , but get "
+ f"{reference_points.shape[-1]} instead."
)
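        # Use the custom CUDA kernel when the value tensor lives on the GPU;
        # otherwise fall back to the pure-PyTorch implementation.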
if torch.cuda.is_available() and value.is_cuda:
output = MSDeformAttentionFunction.apply(
value,
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
else:
output = ms_deformable_attention_cpu(
value, spatial_shapes, sampling_locations, attention_weights
)
if not self.batch_first:
output = output.permute(1, 0, 2)
return output
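

if __name__ == "__main__":
    # Minimal smoke-test sketch with random inputs shaped as described in
    # the docstrings above. The sizes below (6 cameras, one feature level,
    # 4 z anchors, a 10x10 BEV grid) are illustrative values, not defaults
    # prescribed by the module. All tensors stay on the CPU, so the
    # pure-PyTorch fallback of the deformable attention is exercised.
    num_cams, bs, embed_dims, num_z = 6, 1, 256, 4
    feat_h, feat_w = 16, 16
    bev_h, bev_w = 10, 10
    num_query = bev_h * bev_w
    num_value = feat_h * feat_w

    attention = SpatialCrossAttention(
        embed_dims=embed_dims,
        num_cams=num_cams,
        deformable_attention=MSDeformableAttention3D(
            embed_dims=embed_dims, num_levels=1
        ),
    )

    query = torch.rand(bs, num_query, embed_dims)
    value = torch.rand(num_cams, num_value, bs, embed_dims)
    reference_points = torch.rand(num_cams, bs, num_query, num_z, 2)
    bev_mask = torch.rand(num_cams, bs, num_query, num_z) > 0.3
    spatial_shapes = torch.tensor([[feat_h, feat_w]], dtype=torch.long)
    level_start_index = torch.tensor([0], dtype=torch.long)

    output = attention(
        query,
        reference_points,
        value,
        spatial_shapes,
        level_start_index,
        bev_mask,
    )
    print(output.shape)  # expected: torch.Size([1, 100, 256])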