Source code for vis4d.op.detect3d.bevformer.temporal_self_attention

"""An attention module used in BEVFormer based on Deformable-Detr."""

from __future__ import annotations

import math

import torch
from torch import Tensor, nn

from vis4d.op.layer.ms_deform_attn import (
    MSDeformAttentionFunction,
    is_power_of_2,
    ms_deformable_attention_cpu,
)
from vis4d.op.layer.weight_init import constant_init, xavier_init


class TemporalSelfAttention(nn.Module):
    """Temporal Self Attention.

    Deformable attention over a queue of BEV feature maps. The query is
    concatenated with the first ``bs`` entries of ``value``; from that fused
    tensor the module predicts, per head and per BEV-queue entry, sampling
    offsets and attention weights. Features are sampled from every queue
    entry with multi-scale deformable attention and the per-entry results
    are averaged before the output projection.
    """

    def __init__(
        self,
        embed_dims: int = 256,
        num_heads: int = 8,
        num_levels: int = 4,
        num_points: int = 4,
        num_bev_queue: int = 2,
        im2col_step: int = 64,
        dropout: float = 0.1,
        batch_first: bool = True,
    ) -> None:
        """Init.

        Args:
            embed_dims (int): The embedding dimension of Attention.
                Default: 256.
            num_heads (int): Parallel attention heads. Default: 8.
            num_levels (int): The number of feature map used in Attention.
                Default: 4.
            num_points (int): The number of sampling points for each query
                in each head. Default: 4.
            num_bev_queue (int): Length of the BEV queue. ``forward`` asserts
                that this is 2 (one history BEV and one current BEV).
                Default: 2.
            im2col_step (int): The step used in image_to_column. Default: 64.
            dropout (float): Probability of the Dropout layer applied to the
                projected output before the residual addition. Default: 0.1.
            batch_first (bool): Whether query and value are laid out as
                (batch, n, embed_dim) instead of (n, batch, embed_dim).
                Default: True.

        Raises:
            ValueError: If ``embed_dims`` is not divisible by ``num_heads``.
        """
        super().__init__()
        if embed_dims % num_heads != 0:
            raise ValueError(
                f"embed_dims must be divisible by num_heads, "
                f"but got {embed_dims} and {num_heads}"
            )
        # NOTE(review): the return value is discarded, so this call only has
        # an effect if is_power_of_2 itself raises - confirm against its
        # implementation.
        is_power_of_2(embed_dims // num_heads)

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_levels = num_levels
        self.num_points = num_points
        self.num_bev_queue = num_bev_queue
        self.im2col_step = im2col_step
        self.batch_first = batch_first

        # Both prediction heads consume the query concatenated with one
        # value slice along the channel axis (see ``forward``).
        fused_dims = embed_dims * num_bev_queue
        samples_per_query = num_bev_queue * num_heads * num_levels * num_points

        # Sub-modules are created in a fixed order so that parameter
        # registration and random initialization order stay stable.
        self.dropout = nn.Dropout(dropout)
        self.sampling_offsets = nn.Linear(fused_dims, samples_per_query * 2)
        self.attention_weights = nn.Linear(fused_dims, samples_per_query)
        self.value_proj = nn.Linear(embed_dims, embed_dims)
        self.output_proj = nn.Linear(embed_dims, embed_dims)

        self.init_weights()

    def init_weights(self) -> None:
        """Default initialization for Parameters of Module."""
        constant_init(self.sampling_offsets, 0.0)

        # One direction per head, evenly spaced on the circle and rescaled so
        # that the larger absolute coordinate of each direction equals 1.
        angles = torch.arange(self.num_heads, dtype=torch.float32) * (
            2.0 * math.pi / self.num_heads
        )
        directions = torch.stack([angles.cos(), angles.sin()], -1)
        directions = directions / directions.abs().max(-1, keepdim=True).values

        offsets = directions.view(self.num_heads, 1, 1, 2).repeat(
            1, self.num_levels * self.num_bev_queue, self.num_points, 1
        )
        # The i-th sampling point (0-based) is scaled by i + 1.
        point_scale = torch.arange(
            1, self.num_points + 1, dtype=offsets.dtype
        ).view(1, 1, -1, 1)
        offsets = offsets * point_scale
        self.sampling_offsets.bias.data = offsets.view(-1)

        constant_init(self.attention_weights, val=0.0, bias=0.0)
        xavier_init(self.value_proj, distribution="uniform", bias=0.0)
        xavier_init(self.output_proj, distribution="uniform", bias=0.0)

    def _sampling_locations(
        self,
        reference_points: Tensor,
        sampling_offsets: Tensor,
        spatial_shapes: Tensor,
    ) -> Tensor:
        """Turn predicted offsets into sampling locations.

        Args:
            reference_points (Tensor): Reference points (last dim 2) or
                reference boxes (last dim 4).
            sampling_offsets (Tensor): Offsets with shape (bs * num_bev_queue,
                num_query, num_heads, num_levels, num_points, 2).
            spatial_shapes (Tensor): Per-level (h, w) of the feature maps.

        Returns:
            Tensor: Sampling locations, broadcast of the reference points
            against ``sampling_offsets``.

        Raises:
            ValueError: If the last dim of ``reference_points`` is not 2 or 4.
        """
        coord_dims = reference_points.shape[-1]
        if coord_dims == 2:
            # Offsets are (x, y), so normalize by (w, h) of each level.
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1
            )
            normalized = (
                sampling_offsets
                / offset_normalizer[None, None, None, :, None, :]
            )
            return reference_points[:, :, None, :, None, :] + normalized
        if coord_dims == 4:
            # Offsets are scaled by half of the trailing two box components.
            box_center = reference_points[:, :, None, :, None, :2]
            box_extent = reference_points[:, :, None, :, None, 2:]
            scaled = sampling_offsets / self.num_points * box_extent * 0.5
            return box_center + scaled
        raise ValueError(
            f"Last dim of reference_points must be"
            f" 2 or 4, but get {reference_points.shape[-1]} instead."
        )

    def forward(
        self,
        query: Tensor,
        reference_points: Tensor,
        value: Tensor | None,
        spatial_shapes: Tensor,
        level_start_index: Tensor,
        key_padding_mask: Tensor | None = None,
        identity: Tensor | None = None,
        query_pos: Tensor | None = None,
    ) -> Tensor:
        """Forward Function of TemporalSelfAttention.

        Args:
            query (Tensor): Query with shape (bs, num_query, embed_dims) if
                ``batch_first`` else (num_query, bs, embed_dims).
            reference_points (Tensor): Reference points with last dim 2, or
                reference boxes with last dim 4, laid out as
                (N, num_query, num_levels, 2 or 4). N must broadcast against
                bs * num_bev_queue.
            value (Tensor | None): Stacked BEV features with shape
                (bs * num_bev_queue, num_value, embed_dims) if ``batch_first``
                else (num_value, bs * num_bev_queue, embed_dims). If None
                (only allowed with ``batch_first``), the query is duplicated
                to form the value.
            spatial_shapes (Tensor): Spatial shape of features in different
                levels. With shape (num_levels, 2), last dimension represents
                (h, w).
            level_start_index (Tensor): The start index of each level,
                forwarded to the CUDA deformable attention kernel.
            key_padding_mask (Tensor | None): Boolean mask; positions that are
                True are zeroed in the projected value. It gets a trailing
                axis appended and must broadcast against the value tensor.
                Default: None.
            identity (Tensor | None): The tensor used for the residual
                addition. If None, the query (before ``query_pos`` is added)
                is used. Default: None.
            query_pos (Tensor | None): The positional encoding added to the
                query. Default: None.

        Returns:
            Tensor: ``dropout(output) + identity``, in the same layout as the
            incoming query.

        Raises:
            ValueError: If the last dim of ``reference_points`` is not 2 or 4.
        """
        if value is None:
            assert self.batch_first
            batch, bev_len, channels = query.shape
            value = torch.stack([query, query], 1).reshape(
                batch * 2, bev_len, channels
            )

        # The residual uses the query without positional encoding.
        residual = query if identity is None else identity
        if query_pos is not None:
            query = query + query_pos
        if not self.batch_first:
            # change to (bs, num_query, embed_dims)
            query, value = query.permute(1, 0, 2), value.permute(1, 0, 2)

        bs, num_query, embed_dims = query.shape
        _, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
        assert self.num_bev_queue == 2
        stacked_bs = bs * self.num_bev_queue

        # Fuse the first ``bs`` value entries with the query along channels;
        # this is the input of both prediction heads.
        fused_query = torch.cat([value[:bs], query], -1)

        projected = self.value_proj(value)
        assert isinstance(projected, Tensor)
        if key_padding_mask is not None:
            projected = projected.masked_fill(key_padding_mask[..., None], 0.0)
        projected = projected.reshape(
            stacked_bs, num_value, self.num_heads, -1
        )

        # Offsets: (bs, q, heads, queue, levels, points, 2)
        #       -> (bs * queue, q, heads, levels, points, 2)
        offsets = self.sampling_offsets(fused_query).view(
            bs,
            num_query,
            self.num_heads,
            self.num_bev_queue,
            self.num_levels,
            self.num_points,
            2,
        )
        offsets = offsets.permute(0, 3, 1, 2, 4, 5, 6).reshape(
            stacked_bs,
            num_query,
            self.num_heads,
            self.num_levels,
            self.num_points,
            2,
        )

        # Weights are normalized jointly over levels and points, separately
        # for every head and every BEV-queue entry.
        weights = (
            self.attention_weights(fused_query)
            .view(
                bs,
                num_query,
                self.num_heads,
                self.num_bev_queue,
                self.num_levels * self.num_points,
            )
            .softmax(-1)
            .view(
                bs,
                num_query,
                self.num_heads,
                self.num_bev_queue,
                self.num_levels,
                self.num_points,
            )
        )
        weights = (
            weights.permute(0, 3, 1, 2, 4, 5)
            .reshape(
                stacked_bs,
                num_query,
                self.num_heads,
                self.num_levels,
                self.num_points,
            )
            .contiguous()
        )

        locations = self._sampling_locations(
            reference_points, offsets, spatial_shapes
        )

        if torch.cuda.is_available() and projected.is_cuda:
            sampled = MSDeformAttentionFunction.apply(
                projected,
                spatial_shapes,
                level_start_index,
                locations,
                weights,
                self.im2col_step,
            )
        else:
            sampled = ms_deformable_attention_cpu(
                projected,
                spatial_shapes,
                locations,
                weights,
            )

        # (bs * queue, q, c) -> (q, c, bs * queue) -> (q, c, bs, queue)
        fused = sampled.permute(1, 2, 0).view(
            num_query, embed_dims, bs, self.num_bev_queue
        )
        # Fuse history and current value by averaging over the queue axis,
        # then go back to (bs, q, c).
        fused = fused.mean(-1).permute(2, 0, 1)

        out = self.output_proj(fused)
        if not self.batch_first:
            out = out.permute(1, 0, 2)

        return self.dropout(out) + residual