Source code for vis4d.op.detect3d.bevformer.transformer

"""BEVFormer transformer."""

from __future__ import annotations

import numpy as np
import torch
from torch import Tensor, nn
from torchvision.transforms.functional import rotate

from vis4d.op.layer.weight_init import xavier_init

from .decoder import BEVFormerDecoder
from .encoder import BEVFormerEncoder


class PerceptionTransformer(nn.Module):
    """Perception Transformer."""

    def __init__(
        self,
        num_cams: int = 6,
        encoder: BEVFormerEncoder | None = None,
        decoder: BEVFormerDecoder | None = None,
        embed_dims: int = 256,
        num_feature_levels: int = 4,
        rotate_center: tuple[int, int] = (100, 100),
    ) -> None:
        """Init."""
        super().__init__()
        self.num_cams = num_cams
        self.embed_dims = embed_dims
        self.num_feature_levels = num_feature_levels
        self.rotate_center = list(rotate_center)

        self.encoder = encoder or BEVFormerEncoder(embed_dims=self.embed_dims)
        self.decoder = decoder or BEVFormerDecoder(embed_dims=self.embed_dims)

        self._init_layers()
        self._init_weights()

    def _init_layers(self) -> None:
        """Initialize layers of the PerceptionTransformer."""
        self.level_embeds = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims)
        )
        self.cams_embeds = nn.Parameter(
            torch.Tensor(self.num_cams, self.embed_dims)
        )
        self.reference_points = nn.Linear(self.embed_dims, 3)
        self.can_bus_mlp = nn.Sequential(
            nn.Linear(18, self.embed_dims // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims // 2, self.embed_dims),
            nn.ReLU(inplace=True),
        )
        self.can_bus_mlp.add_module("norm", nn.LayerNorm(self.embed_dims))

    def _init_weights(self) -> None:
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        nn.init.normal_(self.level_embeds)
        nn.init.normal_(self.cams_embeds)
        xavier_init(self.reference_points, distribution="uniform", bias=0.0)
        xavier_init(self.can_bus_mlp, distribution="uniform", bias=0.0)
    def get_bev_features(
        self,
        mlvl_feats: list[Tensor],
        can_bus: Tensor,
        bev_queries: Tensor,
        bev_h: int,
        bev_w: int,
        images_hw: tuple[int, int],
        cam_intrinsics: list[Tensor],
        cam_extrinsics: list[Tensor],
        lidar_extrinsics: Tensor,
        grid_length: tuple[float, float],
        bev_pos: Tensor,
        prev_bev: Tensor | None = None,
    ) -> Tensor:
        """Obtain bev features."""
        batch_size = mlvl_feats[0].shape[0]
        bev_queries = bev_queries.unsqueeze(1).repeat(1, batch_size, 1)
        bev_pos = bev_pos.flatten(2).permute(2, 0, 1)

        # obtain rotation angle and shift with ego motion
        delta_x = can_bus[:, 0].unsqueeze(1)
        delta_y = can_bus[:, 1].unsqueeze(1)
        ego_angle = can_bus[:, -2] / np.pi * 180
        translation_length = torch.sqrt(delta_x**2 + delta_y**2)
        translation_angle = torch.arctan2(delta_y, delta_x) / np.pi * 180
        bev_angle = ego_angle - translation_angle
        shift_y = (
            translation_length
            * torch.cos(bev_angle / 180 * np.pi)
            / grid_length[0]
            / bev_h
        )
        shift_x = (
            translation_length
            * torch.sin(bev_angle / 180 * np.pi)
            / grid_length[1]
            / bev_w
        )

        # B, xy
        shift = torch.cat([shift_x, shift_y], dim=1)

        if prev_bev is not None:
            if prev_bev.shape[1] == bev_h * bev_w:
                prev_bev = prev_bev.permute(1, 0, 2)

            # rotate prev_bev
            for i in range(batch_size):
                rotation_angle = float(can_bus[i][-1])
                tmp_prev_bev = (
                    prev_bev[:, i].reshape(bev_h, bev_w, -1).permute(2, 0, 1)
                )
                tmp_prev_bev = rotate(
                    tmp_prev_bev, rotation_angle, center=self.rotate_center
                )
                tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
                    bev_h * bev_w, 1, -1
                )
                prev_bev[:, i] = tmp_prev_bev[:, 0]

        # add can bus signals
        bev_queries = bev_queries + self.can_bus_mlp(can_bus)[None, :, :]

        feat_flatten_list = []
        spatial_shapes_list = []
        for lvl, feat in enumerate(mlvl_feats):
            spatial_shape = feat.shape[-2:]
            feat = feat.flatten(3).permute(1, 0, 3, 2)

            # Add cams_embeds and level_embeds
            feat += self.cams_embeds[:, None, None, :].to(feat.dtype)
            feat += self.level_embeds[None, None, lvl : lvl + 1, :].to(
                feat.dtype
            )

            spatial_shapes_list.append(spatial_shape)
            feat_flatten_list.append(feat)

        feat_flatten = torch.cat(feat_flatten_list, 2)

        spatial_shapes = torch.as_tensor(
            spatial_shapes_list, dtype=torch.long, device=bev_pos.device
        )
        level_start_index = torch.cat(
            (
                spatial_shapes.new_zeros((1,)),
                spatial_shapes.prod(1).cumsum(0)[:-1],
            )
        )

        # (num_cam, H*W, bs, embed_dims)
        feat_flatten = feat_flatten.permute(0, 2, 1, 3)

        bev_embed = self.encoder(
            bev_queries,
            feat_flatten,
            bev_h=bev_h,
            bev_w=bev_w,
            bev_pos=bev_pos,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            prev_bev=prev_bev,
            shift=shift,
            images_hw=images_hw,
            cam_intrinsics=cam_intrinsics,
            cam_extrinsics=cam_extrinsics,
            lidar_extrinsics=lidar_extrinsics,
        )

        return bev_embed
    def forward(
        self,
        mlvl_feats: list[Tensor],
        can_bus: Tensor,
        bev_queries: Tensor,
        object_query_embed: Tensor,
        bev_h: int,
        bev_w: int,
        images_hw: tuple[int, int],
        cam_intrinsics: list[Tensor],
        cam_extrinsics: list[Tensor],
        lidar_extrinsics: Tensor,
        grid_length: tuple[float, float],
        bev_pos: Tensor,
        reg_branches: list[nn.Module],
        prev_bev: Tensor | None = None,
    ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """Forward function for BEVFormer transformer.

        Args:
            mlvl_feats (list[Tensor]): Multi-level image features. Each
                element has shape [bs, num_cams, embed_dims, h, w].
            can_bus (Tensor): The CAN bus signals, with shape [bs, 18].
            bev_queries (Tensor): BEV queries with shape
                [bev_h * bev_w, embed_dims].
            object_query_embed (Tensor): The query embedding for the decoder,
                with shape [num_query, embed_dims * 2].
            bev_h (int): The height of the BEV feature map.
            bev_w (int): The width of the BEV feature map.
            images_hw (tuple[int, int]): The height and width of the input
                images.
            cam_intrinsics (list[Tensor]): The camera intrinsics.
            cam_extrinsics (list[Tensor]): The camera extrinsics.
            lidar_extrinsics (Tensor): The LiDAR extrinsics.
            grid_length (tuple[float, float]): The length of a BEV grid cell
                in the x and y directions.
            bev_pos (Tensor): BEV positional encoding with shape
                [bs, embed_dims, bev_h, bev_w].
            reg_branches (list[nn.Module]): Regression heads for the feature
                maps from each decoder layer.
            prev_bev (Tensor, optional): The previous BEV feature map, with
                shape [bev_h * bev_w, bs, embed_dims]. Defaults to None.

        Returns:
            bev_embed (Tensor): BEV features with shape
                [bev_h * bev_w, bs, embed_dims].
            inter_states (Tensor): Outputs from the decoder, with shape
                [1, bs, num_query, embed_dims].
            reference_points (Tensor): The initial reference points, with
                shape [bs, num_query, 3].
            inter_references (Tensor): The reference points from each decoder
                layer, with shape [num_dec_layers, bs, num_query, embed_dims].
        """
        # bs, bev_h*bev_w, embed_dims
        bev_embed = self.get_bev_features(
            mlvl_feats,
            can_bus,
            bev_queries,
            bev_h,
            bev_w,
            images_hw=images_hw,
            cam_intrinsics=cam_intrinsics,
            cam_extrinsics=cam_extrinsics,
            lidar_extrinsics=lidar_extrinsics,
            grid_length=grid_length,
            bev_pos=bev_pos,
            prev_bev=prev_bev,
        )

        bs = mlvl_feats[0].shape[0]
        query_pos, query = torch.split(
            object_query_embed, self.embed_dims, dim=1
        )
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        query = query.unsqueeze(0).expand(bs, -1, -1)
        reference_points = self.reference_points(query_pos)
        reference_points = reference_points.sigmoid()

        query = query.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        bev_embed = bev_embed.permute(1, 0, 2)

        inter_states, inter_references = self.decoder(
            query=query,
            value=bev_embed,
            reference_points=reference_points,
            spatial_shapes=torch.tensor(
                [[bev_h, bev_w]], device=query.device
            ),
            level_start_index=torch.tensor([0], device=query.device),
            query_pos=query_pos,
            reg_branches=reg_branches,
        )

        return bev_embed, inter_states, reference_points, inter_references
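
# --- Illustrative sketch (not part of the class above) ----------------------
# A minimal, standalone reproduction of the ego-motion shift computed inside
# PerceptionTransformer.get_bev_features: the CAN bus translation is converted
# into a normalized (x, y) shift of the BEV grid. It reuses the imports at the
# top of this module; the grid_length / BEV size values below are made up for
# illustration, and a batch size of 1 is used as in temporal BEV inference.


def _bev_shift_sketch(
    can_bus: Tensor,
    grid_length: tuple[float, float],
    bev_h: int,
    bev_w: int,
) -> Tensor:
    """Mirror the shift computation from get_bev_features."""
    delta_x = can_bus[:, 0].unsqueeze(1)
    delta_y = can_bus[:, 1].unsqueeze(1)
    ego_angle = can_bus[:, -2] / np.pi * 180
    translation_length = torch.sqrt(delta_x**2 + delta_y**2)
    translation_angle = torch.arctan2(delta_y, delta_x) / np.pi * 180
    bev_angle = ego_angle - translation_angle
    shift_y = (
        translation_length
        * torch.cos(bev_angle / 180 * np.pi)
        / grid_length[0]
        / bev_h
    )
    shift_x = (
        translation_length
        * torch.sin(bev_angle / 180 * np.pi)
        / grid_length[1]
        / bev_w
    )
    return torch.cat([shift_x, shift_y], dim=1)  # [1, 2] for batch size 1


if __name__ == "__main__":
    dummy_can_bus = torch.zeros(1, 18)
    dummy_can_bus[0, 0] = 0.5  # delta_x in meters (made-up value)
    dummy_can_bus[0, 1] = 0.1  # delta_y in meters (made-up value)
    print(_bev_shift_sketch(dummy_can_bus, (0.512, 0.512), 200, 200))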
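
# A second sketch showing how PerceptionTransformer.forward derives its
# content queries and initial 3D reference points from the learned object
# query embedding: the embedding is split into positional and content halves,
# and a Linear(embed_dims, 3) head followed by a sigmoid yields normalized
# (x, y, z) reference points. The example sizes in the comment below are
# assumptions for illustration only.


def _initial_reference_points_sketch(
    object_query_embed: Tensor,  # [num_query, embed_dims * 2]
    reference_point_head: nn.Linear,  # mirrors self.reference_points
    embed_dims: int,
    batch_size: int,
) -> tuple[Tensor, Tensor, Tensor]:
    """Mirror the query / reference-point setup from forward."""
    query_pos, query = torch.split(object_query_embed, embed_dims, dim=1)
    query_pos = query_pos.unsqueeze(0).expand(batch_size, -1, -1)
    query = query.unsqueeze(0).expand(batch_size, -1, -1)
    # Normalized reference points in [0, 1], shape [bs, num_query, 3].
    reference_points = reference_point_head(query_pos).sigmoid()
    return query, query_pos, reference_points


# Example (made-up sizes): with embed_dims=256, num_query=900 and batch_size=1,
# _initial_reference_points_sketch(torch.rand(900, 512), nn.Linear(256, 3),
# 256, 1) returns tensors of shapes [1, 900, 256], [1, 900, 256], [1, 900, 3].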