Source code for vis4d.op.detect3d.bevformer.encoder
"""BEVFormer Encoder."""
from __future__ import annotations
from collections.abc import Sequence
import torch
from torch import Tensor, nn
from vis4d.op.geometry.transform import inverse_rigid_transform
from vis4d.op.layer.transformer import FFN, get_clones
from .spatial_cross_attention import SpatialCrossAttention
from .temporal_self_attention import TemporalSelfAttention
class BEVFormerEncoder(nn.Module):
"""Attention with both self and cross attention."""
def __init__(
self,
num_layers: int = 6,
layer: BEVFormerEncoderLayer | None = None,
embed_dims: int = 256,
num_points_in_pillar: int = 4,
point_cloud_range: Sequence[float] = (
-51.2,
-51.2,
-5.0,
51.2,
51.2,
3.0,
),
return_intermediate: bool = False,
) -> None:
"""Init.
Args:
num_layers (int): Number of layers in the encoder.
layer (BEVFormerEncoderLayer, optional): Encoder layer. Defaults to
None. If None, a default layer will be used.
embed_dims (int): Embedding dimension.
num_points_in_pillar (int): Number of points in each pillar.
point_cloud_range (Sequence[float]): Range of the point cloud.
Defaults to (-51.2, -51.2, -5.0, 51.2, 51.2, 3.0).
return_intermediate (bool): Whether to return intermediate outputs.
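
        Example:
            A minimal construction sketch (illustrative, not from the original
            source); it only assumes the default attention and FFN sub-modules
            build without extra arguments:

            >>> encoder = BEVFormerEncoder(num_layers=2, embed_dims=256)
            >>> len(encoder.layers)
            2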
"""
super().__init__()
self.num_layers = num_layers
self.embed_dims = embed_dims
self.num_points_in_pillar = num_points_in_pillar
self.pc_range = point_cloud_range
self.return_intermediate = return_intermediate
layer = layer or BEVFormerEncoderLayer(embed_dims=embed_dims)
self.layers = get_clones(layer, num=self.num_layers)
self.eps = 1e-5
def get_reference_points(
self,
bev_h: int,
bev_w: int,
dim: int,
batch_size: int,
device: torch.device,
dtype: torch.dtype,
) -> Tensor:
"""Get the reference points used in SCA and TSA.
Args:
bev_h (int): Height of the BEV feature map.
bev_w (int): Width of the BEV feature map.
dim (int): Dimension of the reference points.
batch_size (int): Batch size.
device (torch.device): The device where reference_points should be.
dtype (torch.dtype): The dtype of reference_points.
        Returns:
            Tensor: Reference points used in SCA and TSA. For dim=2 the shape
                is (batch_size, bev_h * bev_w, 1, 2); for dim=3 it is
                (batch_size, num_points_in_pillar, bev_h * bev_w, 3).
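
        Example:
            A shape-only sketch (illustrative, not part of the original
            source), assuming a 50x50 BEV grid, batch size 1 and the default
            num_points_in_pillar of 4:

            >>> encoder = BEVFormerEncoder(num_layers=1)
            >>> ref_2d = encoder.get_reference_points(
            ...     50, 50, dim=2, batch_size=1,
            ...     device=torch.device("cpu"), dtype=torch.float32,
            ... )
            >>> ref_2d.shape
            torch.Size([1, 2500, 1, 2])
            >>> ref_3d = encoder.get_reference_points(
            ...     50, 50, dim=3, batch_size=1,
            ...     device=torch.device("cpu"), dtype=torch.float32,
            ... )
            >>> ref_3d.shape
            torch.Size([1, 4, 2500, 3])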
"""
assert dim in {2, 3}, f"Unknown dim {dim}."
# Reference points in 3D space for spatial cross-attention (SCA)
if dim == 3:
height_z = self.pc_range[5] - self.pc_range[2]
zs = (
torch.linspace(
0.5,
height_z - 0.5,
self.num_points_in_pillar,
dtype=dtype,
device=device,
)
.view(-1, 1, 1)
.expand(self.num_points_in_pillar, bev_h, bev_w)
/ height_z
)
xs = (
torch.linspace(
0.5, bev_w - 0.5, bev_w, dtype=dtype, device=device
)
.view(1, 1, bev_w)
.expand(self.num_points_in_pillar, bev_h, bev_w)
/ bev_w
)
ys = (
torch.linspace(
0.5, bev_h - 0.5, bev_h, dtype=dtype, device=device
)
.view(1, bev_h, 1)
.expand(self.num_points_in_pillar, bev_h, bev_w)
/ bev_h
)
ref_3d = torch.stack((xs, ys, zs), -1)
ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
ref_3d = ref_3d[None].repeat(batch_size, 1, 1, 1)
return ref_3d
# Reference points on 2D bev plane for temporal self-attention (TSA)
ref_y, ref_x = torch.meshgrid(
torch.linspace(
0.5, bev_h - 0.5, bev_h, dtype=dtype, device=device
),
torch.linspace(
0.5, bev_w - 0.5, bev_w, dtype=dtype, device=device
),
indexing="ij",
)
ref_y = ref_y.reshape(-1)[None] / bev_h
ref_x = ref_x.reshape(-1)[None] / bev_w
ref_2d = torch.stack((ref_x, ref_y), -1)
ref_2d = ref_2d.repeat(batch_size, 1, 1).unsqueeze(2)
return ref_2d
def point_sampling(
self,
reference_points: Tensor,
images_hw: tuple[int, int],
cam_intrinsics: list[Tensor],
cam_extrinsics: list[Tensor],
lidar_extrinsics: Tensor,
) -> tuple[Tensor, Tensor]:
"""Sample points from reference points."""
lidar2img_list = []
for i, _cam_intrinsics in enumerate(cam_intrinsics):
viewpad = torch.eye(4, device=_cam_intrinsics.device)
viewpad[:3, :3] = _cam_intrinsics
lidar2img = (
viewpad
@ inverse_rigid_transform(cam_extrinsics[i])
@ lidar_extrinsics
)
lidar2img_list.append(lidar2img)
lidar2img = torch.stack(lidar2img_list, dim=1) # (B, N, 4, 4)
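        # De-normalize the reference points from [0, 1] back to metric
        # coordinates within the configured point cloud range.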
reference_points = reference_points.clone()
reference_points[..., 0:1] = (
reference_points[..., 0:1] * (self.pc_range[3] - self.pc_range[0])
+ self.pc_range[0]
)
reference_points[..., 1:2] = (
reference_points[..., 1:2] * (self.pc_range[4] - self.pc_range[1])
+ self.pc_range[1]
)
reference_points[..., 2:3] = (
reference_points[..., 2:3] * (self.pc_range[5] - self.pc_range[2])
+ self.pc_range[2]
)
reference_points = torch.cat(
(reference_points, torch.ones_like(reference_points[..., :1])), -1
)
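        # Project the homogeneous LiDAR-frame points into every camera view.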
reference_points = reference_points.permute(1, 0, 2, 3)
d, b, num_query, _ = reference_points.shape
num_cam = lidar2img.size(1)
reference_points = (
reference_points.view(d, b, 1, num_query, 4)
.repeat(1, 1, num_cam, 1, 1)
.unsqueeze(-1)
)
lidar2img = lidar2img.view(1, b, num_cam, 1, 4, 4).repeat(
d, 1, 1, num_query, 1, 1
)
reference_points_cam = torch.matmul(
lidar2img, reference_points
).squeeze(-1)
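        # Normalize by depth and image size, keeping only points in front of
        # the camera and inside the image bounds.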
bev_mask = reference_points_cam[..., 2:3] > self.eps
reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
reference_points_cam[..., 2:3],
torch.mul(
torch.ones_like(reference_points_cam[..., 2:3]), self.eps
),
)
reference_points_cam[..., 0] /= images_hw[1]
reference_points_cam[..., 1] /= images_hw[0]
bev_mask = (
bev_mask
& (reference_points_cam[..., 1:2] > 0.0)
& (reference_points_cam[..., 1:2] < 1.0)
& (reference_points_cam[..., 0:1] < 1.0)
& (reference_points_cam[..., 0:1] > 0.0)
)
reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)
bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)
return reference_points_cam, bev_mask
def forward(
self,
bev_query: Tensor,
value: Tensor,
bev_h: int,
bev_w: int,
bev_pos: Tensor,
spatial_shapes: Tensor,
level_start_index: Tensor,
prev_bev: Tensor | None,
shift: Tensor,
images_hw: tuple[int, int],
cam_intrinsics: list[Tensor],
cam_extrinsics: list[Tensor],
lidar_extrinsics: Tensor,
) -> Tensor:
"""Forward.
Args:
bev_query (Tensor): Input BEV query with shape (num_query,
batch_size, embed_dims).
            value (Tensor): Input multi-camera features with shape (num_cam,
num_value, batch_size, embed_dims).
bev_h (int): BEV height.
bev_w (int): BEV width.
            bev_pos (Tensor): BEV positional encoding with shape (num_query,
                batch_size, embed_dims).
spatial_shapes (Tensor): Spatial shapes of multi-level
features with shape (num_levels, 2).
level_start_index (Tensor): Start index of each level with shape
(num_levels, ).
            prev_bev (Tensor | None): Previous BEV features with shape
                (num_query, batch_size, embed_dims), or None for the first
                frame.
            shift (Tensor): BEV shift with shape (batch_size, 2).
            images_hw (tuple[int, int]): Image height and width.
cam_intrinsics (list[Tensor]): List of camera intrinsics. In shape
(num_cam, batch_size, 3, 3)
cam_extrinsics (list[Tensor]): List of camera extrinsics. In shape
(num_cam, batch_size, 4, 4)
lidar_extrinsics (Tensor): LiDAR extrinsics. In shape (batch_size,
4, 4)
Returns:
Tensor: Results with shape [batch_size, num_query, embed_dims]
when return_intermediate is False, otherwise it has shape
[num_layers, batch_size, num_query, embed_dims].
"""
intermediate = []
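        # 3D reference points for spatial cross-attention and 2D reference
        # points on the BEV plane for temporal self-attention.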
ref_3d = self.get_reference_points(
bev_h,
bev_w,
dim=3,
batch_size=bev_query.size(1),
device=bev_query.device,
dtype=bev_query.dtype,
)
ref_2d = self.get_reference_points(
bev_h,
bev_w,
dim=2,
batch_size=bev_query.size(1),
device=bev_query.device,
dtype=bev_query.dtype,
)
reference_points_img, bev_mask = self.point_sampling(
ref_3d,
images_hw,
cam_intrinsics,
cam_extrinsics,
lidar_extrinsics,
)
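        # Shift the 2D reference points so that they align with the previous
        # BEV features.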
shift_ref_2d = ref_2d.clone()
shift_ref_2d += shift[:, None, None, :]
bev_query = bev_query.permute(1, 0, 2)
bev_pos = bev_pos.permute(1, 0, 2)
batch_size, len_bev, num_bev_level, _ = ref_2d.shape
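        # Stack previous and current BEV (and their reference points) along
        # the batch dimension so temporal self-attention can attend to both.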
if prev_bev is not None:
prev_bev = prev_bev.permute(1, 0, 2)
prev_bev = torch.stack([prev_bev, bev_query], 1).reshape(
batch_size * 2, len_bev, -1
)
            hybrid_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
batch_size * 2, len_bev, num_bev_level, 2
)
else:
            hybrid_ref_2d = torch.stack([ref_2d, ref_2d], 1).reshape(
batch_size * 2, len_bev, num_bev_level, 2
)
for _, layer in enumerate(self.layers):
output = layer(
bev_query,
value,
bev_pos=bev_pos,
                ref_2d=hybrid_ref_2d,
bev_h=bev_h,
bev_w=bev_w,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
reference_points_img=reference_points_img,
bev_mask=bev_mask,
prev_bev=prev_bev,
)
bev_query = output
if self.return_intermediate:
intermediate.append(output)
if self.return_intermediate:
return torch.stack(intermediate)
return output
class BEVFormerEncoderLayer(nn.Module):
"""BEVFormer encoder layer."""
def __init__(
self,
embed_dims: int = 256,
self_attn: TemporalSelfAttention | None = None,
cross_attn: SpatialCrossAttention | None = None,
feedforward_channels: int = 512,
drop_out: float = 0.1,
) -> None:
"""Init."""
super().__init__()
self.attentions = nn.ModuleList()
self_attn = self_attn or TemporalSelfAttention(
embed_dims=embed_dims, num_levels=1
)
self.attentions.append(self_attn)
cross_attn = cross_attn or SpatialCrossAttention(embed_dims=embed_dims)
self.attentions.append(cross_attn)
self.embed_dims = embed_dims
self.ffns = nn.ModuleList()
self.ffns.append(
FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
dropout=drop_out,
)
)
self.norms = nn.ModuleList()
for _ in range(3):
self.norms.append(nn.LayerNorm(self.embed_dims))
def forward(
self,
query: Tensor,
value: Tensor,
bev_pos: Tensor,
ref_2d: Tensor,
bev_h: int,
bev_w: int,
spatial_shapes: Tensor,
level_start_index: Tensor,
reference_points_img: Tensor,
bev_mask: Tensor,
prev_bev: Tensor | None = None,
) -> Tensor:
"""Forward function.
self_attn -> norm -> cross_attn -> norm -> ffn -> norm
Returns:
            Tensor: Forwarded results with shape [batch_size, num_queries,
                embed_dims].
"""
# Temporal self attention
query = self.attentions[0](
query,
ref_2d,
prev_bev,
spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
level_start_index=torch.tensor([0], device=query.device),
query_pos=bev_pos,
)
query = self.norms[0](query)
        # Spatial cross attention
query = self.attentions[1](
query,
reference_points_img,
value,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
bev_mask=bev_mask,
)
query = self.norms[1](query)
# FFN
query = self.ffns[0](query)
query = self.norms[2](query)
return query
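

# Illustrative usage (not part of the original module): a minimal, hypothetical
# sketch of how get_reference_points and point_sampling fit together. It
# assumes batch size 1, two cameras with identity intrinsics and extrinsics,
# an identity LiDAR pose and a small 10x10 BEV grid; the intrinsics are kept
# unbatched for simplicity and only output shapes are checked.
def _demo_point_sampling() -> None:
    """Hypothetical, illustration-only demo of BEV reference point sampling."""
    encoder = BEVFormerEncoder(num_layers=1)
    batch_size, num_cams, bev_h, bev_w = 1, 2, 10, 10
    num_points = encoder.num_points_in_pillar
    # 3D reference points in normalized BEV space: (B, P, bev_h * bev_w, 3).
    ref_3d = encoder.get_reference_points(
        bev_h,
        bev_w,
        dim=3,
        batch_size=batch_size,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )
    cam_intrinsics = [torch.eye(3) for _ in range(num_cams)]
    cam_extrinsics = [
        torch.eye(4).unsqueeze(0).repeat(batch_size, 1, 1)
        for _ in range(num_cams)
    ]
    lidar_extrinsics = torch.eye(4).unsqueeze(0).repeat(batch_size, 1, 1)
    ref_cam, bev_mask = encoder.point_sampling(
        ref_3d, (900, 1600), cam_intrinsics, cam_extrinsics, lidar_extrinsics
    )
    # Per-camera normalized image coordinates and their visibility mask.
    assert ref_cam.shape == (
        num_cams, batch_size, bev_h * bev_w, num_points, 2
    )
    assert bev_mask.shape == (num_cams, batch_size, bev_h * bev_w, num_points)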