Source code for vis4d.op.layer.ms_deform_attn

# pylint: disable=no-name-in-module, abstract-method, arguments-differ
"""Multi-Scale Deformable Attention Module.

Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py) # pylint: disable=line-too-long
"""
from __future__ import annotations

import math

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.init import constant_, xavier_uniform_

from vis4d.common.imports import VIS4D_CUDA_OPS_AVAILABLE
from vis4d.common.logging import rank_zero_warn

if VIS4D_CUDA_OPS_AVAILABLE:
    from vis4d_cuda_ops import ms_deform_attn_backward, ms_deform_attn_forward


class MSDeformAttentionFunction(Function):  # pragma: no cover
    """Multi-Scale Deformable Attention Function module."""

    @staticmethod
    def forward(  # type: ignore
        ctx,
        value: Tensor,
        value_spatial_shapes: Tensor,
        value_level_start_index: Tensor,
        sampling_locations: Tensor,
        attention_weights: Tensor,
        im2col_step: int,
    ) -> Tensor:
        """Forward pass."""
        if not VIS4D_CUDA_OPS_AVAILABLE:
            raise RuntimeError(
                "MSDeformAttentionFunction requires vis4d cuda ops to run."
            )
        ctx.im2col_step = im2col_step
        output = ms_deform_attn_forward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            ctx.im2col_step,
        )
        ctx.save_for_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
        )
        return output

    @staticmethod
    @once_differentiable  # type: ignore
    def backward(
        ctx, grad_output: Tensor
    ) -> tuple[Tensor, None, None, Tensor, Tensor, None]:
        """Backward pass."""
        if not VIS4D_CUDA_OPS_AVAILABLE:
            raise RuntimeError(
                "MSDeformAttentionFunction requires vis4d cuda ops to run."
            )
        (
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
        ) = ctx.saved_tensors
        (
            grad_value,
            grad_sampling_loc,
            grad_attn_weight,
        ) = ms_deform_attn_backward(
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
            grad_output,
            ctx.im2col_step,
        )
        return (
            grad_value,
            None,
            None,
            grad_sampling_loc,
            grad_attn_weight,
            None,
        )
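

# Illustrative usage sketch (not part of the original module): how the custom
# autograd Function above is typically invoked via ``.apply``. It needs both a
# GPU and the vis4d_cuda_ops extension, so it returns early otherwise. The
# helper name ``_example_ms_deform_attn_function`` and the concrete shapes are
# hypothetical choices for illustration only.
def _example_ms_deform_attn_function() -> None:
    """Run MSDeformAttentionFunction on random inputs if CUDA ops exist."""
    if not (VIS4D_CUDA_OPS_AVAILABLE and torch.cuda.is_available()):
        return
    bs, num_heads, head_dim = 2, 8, 32
    num_queries, num_levels, num_points = 100, 2, 4
    spatial_shapes = torch.tensor([[32, 32], [16, 16]], device="cuda")
    level_start_index = torch.cat(
        [
            spatial_shapes.new_zeros(1),
            (spatial_shapes[:, 0] * spatial_shapes[:, 1]).cumsum(0)[:-1],
        ]
    )
    num_keys = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
    value = torch.rand(bs, num_keys, num_heads, head_dim, device="cuda")
    sampling_locations = torch.rand(
        bs, num_queries, num_heads, num_levels, num_points, 2, device="cuda"
    )
    attention_weights = torch.rand(
        bs, num_queries, num_heads, num_levels, num_points, device="cuda"
    )
    # Normalize so the weights over all levels and points sum to one,
    # mirroring the softmax applied in MSDeformAttention.forward.
    attention_weights /= attention_weights.sum((-2, -1), keepdim=True)
    output = MSDeformAttentionFunction.apply(
        value,
        spatial_shapes,
        level_start_index,
        sampling_locations,
        attention_weights,
        64,  # im2col_step
    )
    assert output.shape == (bs, num_queries, num_heads * head_dim)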


def ms_deformable_attention_cpu(
    value: Tensor,
    value_spatial_shapes: Tensor,
    sampling_locations: Tensor,
    attention_weights: Tensor,
) -> Tensor:
    """CPU version of multi-scale deformable attention.

    Args:
        value (Tensor): The value has shape
            (bs, num_keys, num_heads, embed_dims // num_heads).
        value_spatial_shapes (Tensor): Spatial shape of each feature map,
            has shape (num_levels, 2), last dimension 2 represent (h, w).
        sampling_locations (Tensor): The location of sampling points, has
            shape (bs, num_queries, num_heads, num_levels, num_points, 2),
            the last dimension 2 represent (x, y).
        attention_weights (Tensor): The weight of sampling points used when
            calculating the attention, has shape
            (bs, num_queries, num_heads, num_levels, num_points).

    Returns:
        Tensor: has shape (bs, num_queries, embed_dims).
    """
    bs, _, num_heads, embed_dims = value.shape
    (
        _,
        num_queries,
        num_heads,
        num_levels,
        num_points,
        _,
    ) = sampling_locations.shape
    value_list = value.split([h * w for h, w in value_spatial_shapes], dim=1)
    sampling_grids: Tensor = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (h, w) in enumerate(value_spatial_shapes):
        # bs, h*w, num_heads, embed_dims ->
        # bs, h*w, num_heads*embed_dims ->
        # bs, num_heads*embed_dims, h*w ->
        # bs*num_heads, embed_dims, h, w
        value_l_ = (
            value_list[level]
            .flatten(2)
            .transpose(1, 2)
            .reshape(bs * num_heads, embed_dims, h, w)
        )
        # bs, num_queries, num_heads, num_points, 2 ->
        # bs, num_heads, num_queries, num_points, 2 ->
        # bs*num_heads, num_queries, num_points, 2
        sampling_grid_l_ = (
            sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        )
        # bs*num_heads, embed_dims, num_queries, num_points
        sampling_value_l_ = F.grid_sample(
            value_l_,
            sampling_grid_l_,
            mode="bilinear",
            padding_mode="zeros",
            align_corners=False,
        )
        sampling_value_list.append(sampling_value_l_)
    # (bs, num_queries, num_heads, num_levels, num_points) ->
    # (bs, num_heads, num_queries, num_levels, num_points) ->
    # (bs*num_heads, 1, num_queries, num_levels*num_points)
    attention_weights = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points
    )
    output = (
        (
            torch.stack(sampling_value_list, dim=-2).flatten(-2)
            * attention_weights
        )
        .sum(-1)
        .view(bs, num_heads * embed_dims, num_queries)
    )
    return output.transpose(1, 2).contiguous()
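

# Illustrative usage sketch (not part of the original module): exercising the
# CPU reference implementation above with the tensor shapes documented in its
# docstring. The helper name ``_example_ms_deformable_attention_cpu`` and the
# concrete shapes are hypothetical choices for illustration only.
def _example_ms_deformable_attention_cpu() -> Tensor:
    """Run ms_deformable_attention_cpu on random inputs."""
    bs, num_heads, head_dim = 2, 8, 32
    num_queries, num_levels, num_points = 100, 2, 4
    spatial_shapes = torch.tensor([[32, 32], [16, 16]])  # (num_levels, 2)
    num_keys = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
    value = torch.rand(bs, num_keys, num_heads, head_dim)
    # Sampling locations are normalized (x, y) coordinates in [0, 1]; the
    # weights are normalized over all levels and points, mirroring the
    # softmax in MSDeformAttention.forward.
    sampling_locations = torch.rand(
        bs, num_queries, num_heads, num_levels, num_points, 2
    )
    attention_weights = torch.rand(
        bs, num_queries, num_heads, num_levels, num_points
    )
    attention_weights /= attention_weights.sum((-2, -1), keepdim=True)
    output = ms_deformable_attention_cpu(
        value, spatial_shapes, sampling_locations, attention_weights
    )
    assert output.shape == (bs, num_queries, num_heads * head_dim)
    return output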


def is_power_of_2(number: int) -> None:
    """Check if a number is a power of 2."""
    if (not isinstance(number, int)) or (number < 0):
        raise ValueError(
            f"invalid input for is_power_of_2: {number} (type: {type(number)})"
        )
    if not ((number & (number - 1) == 0) and number != 0):
        rank_zero_warn(
            "It is recommended to set the hidden dimensions in "
            "MultiScaleDeformAttention so that the dimension of each "
            "attention head is a power of 2, which is more efficient in "
            "our CUDA implementation."
        )
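

# Illustrative usage sketch (not part of the original module): the helper
# above is a soft check, warning when the per-head dimension is not a power
# of two and raising only for invalid input. The helper name
# ``_example_is_power_of_2`` is hypothetical.
def _example_is_power_of_2() -> None:
    """Demonstrate is_power_of_2 on typical per-head dimensions."""
    is_power_of_2(256 // 8)  # 32 is a power of two -> silent
    is_power_of_2(96 // 8)  # 12 is not -> emits the warning above
    try:
        is_power_of_2(-1)  # negative input -> ValueError
    except ValueError:
        pass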


class MSDeformAttention(nn.Module):
    """Multi-Scale Deformable Attention Module."""

    def __init__(
        self,
        d_model: int = 256,
        n_levels: int = 4,
        n_heads: int = 8,
        n_points: int = 4,
        im2col_step: int = 64,
    ) -> None:
        """Creates an instance of the class.

        Args:
            d_model (int): Hidden dimensions.
            n_levels (int): Number of feature levels.
            n_heads (int): Number of attention heads.
            n_points (int): Number of sampling points per attention head per
                feature level.
            im2col_step (int): The step used in image_to_column. Default: 64.
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                "d_model must be divisible by n_heads, but got "
                + f"{d_model} and {n_heads}."
            )
        is_power_of_2(d_model // n_heads)

        self.d_model = d_model
        self.embed_dims = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points
        self.im2col_step = im2col_step

        self.sampling_offsets = nn.Linear(
            d_model, n_heads * n_levels * n_points * 2
        )
        self.attention_weights = nn.Linear(
            d_model, n_heads * n_levels * n_points
        )
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Reset parameters."""
        constant_(self.sampling_offsets.weight.data, 0.0)
        thetas = torch.mul(
            torch.arange(self.n_heads, dtype=torch.float32),
            (2.0 * math.pi / self.n_heads),
        )
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (
            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
            .view(self.n_heads, 1, 1, 2)
            .repeat(1, self.n_levels, self.n_points, 1)
        )
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_(self.attention_weights.weight.data, 0.0)
        constant_(self.attention_weights.bias.data, 0.0)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.0)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.0)

    def forward(
        self,
        query: Tensor,
        reference_points: Tensor,
        input_flatten: Tensor,
        input_spatial_shapes: Tensor,
        input_level_start_index: Tensor,
        input_padding_mask: Tensor | None = None,
    ) -> Tensor:
        r"""Forward function.

        Args:
            query (Tensor): (n, length_{query}, C).
            reference_points (Tensor): (n, length_{query}, n_levels, 2),
                range in [0, 1], top-left (0, 0), bottom-right (1, 1),
                including padding area, or
                (n, length_{query}, n_levels, 4), with additional (w, h) to
                form reference boxes.
            input_flatten (Tensor): (n, \sum_{l=0}^{L-1} H_l \cdot W_l, C).
            input_spatial_shapes (Tensor): (n_levels, 2),
                [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})].
            input_level_start_index (Tensor): (n_levels, ),
                [0, H_0*W_0, H_0*W_0+H_1*W_1, ...,
                H_0*W_0+H_1*W_1+...+H_{L-2}*W_{L-2}].
            input_padding_mask (Tensor, optional): (n, \sum_{l=0}^{L-1} H_l
                \cdot W_l), True for padding elements, False for non-padding
                elements.

        Returns:
            output (Tensor): (n, length_{query}, C).
        """
        n, len_q, _ = query.shape
        n, len_in, _ = input_flatten.shape
        assert (
            input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]
        ).sum() == len_in

        value = self.value_proj(input_flatten)
        if input_padding_mask is not None:
            value = value.masked_fill(input_padding_mask[..., None], float(0))
        value = value.view(
            n, len_in, self.n_heads, self.d_model // self.n_heads
        )
        sampling_offsets = self.sampling_offsets(query).view(
            n, len_q, self.n_heads, self.n_levels, self.n_points, 2
        )
        attention_weights = self.attention_weights(query).view(
            n, len_q, self.n_heads, self.n_levels * self.n_points
        )
        attention_weights = F.softmax(attention_weights, -1).view(
            n, len_q, self.n_heads, self.n_levels, self.n_points
        )
        # n, len_q, n_heads, n_levels, n_points, 2
        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack(
                [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]],
                -1,
            )
            sampling_locations = (
                reference_points[:, :, None, :, None, :]
                + sampling_offsets
                / offset_normalizer[None, None, None, :, None, :]
            )
        elif reference_points.shape[-1] == 4:
            sampling_locations = (
                reference_points[:, :, None, :, None, :2]
                + sampling_offsets
                / self.n_points
                * reference_points[:, :, None, :, None, 2:]
                * 0.5
            )
        else:
            raise ValueError(
                "Last dim of reference_points must be 2 or 4, "
                + f"but get {reference_points.shape[-1]} instead."
            )

        if torch.cuda.is_available() and value.is_cuda:
            output = MSDeformAttentionFunction.apply(
                value,
                input_spatial_shapes,
                input_level_start_index,
                sampling_locations,
                attention_weights,
                self.im2col_step,
            )
        else:
            output = ms_deformable_attention_cpu(
                value,
                input_spatial_shapes,
                sampling_locations,
                attention_weights,
            )
        output = self.output_proj(output)
        return output
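

# Illustrative usage sketch (not part of the original module): building the
# flattened multi-level inputs that MSDeformAttention.forward expects and
# running it on CPU tensors, which dispatches to ms_deformable_attention_cpu.
# The helper name ``_example_ms_deform_attention`` and the concrete shapes
# are hypothetical choices for illustration only.
def _example_ms_deform_attention() -> Tensor:
    """Run MSDeformAttention on random two-level features."""
    d_model, n_levels = 256, 2
    attn = MSDeformAttention(d_model=d_model, n_levels=n_levels)
    bs, num_queries = 2, 100
    spatial_shapes = torch.tensor([[32, 32], [16, 16]])  # (n_levels, 2)
    level_start_index = torch.cat(
        [
            spatial_shapes.new_zeros(1),
            (spatial_shapes[:, 0] * spatial_shapes[:, 1]).cumsum(0)[:-1],
        ]
    )
    num_keys = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
    query = torch.rand(bs, num_queries, d_model)
    input_flatten = torch.rand(bs, num_keys, d_model)
    # Normalized (x, y) reference points in [0, 1], one per query per level.
    reference_points = torch.rand(bs, num_queries, n_levels, 2)
    output = attn(
        query,
        reference_points,
        input_flatten,
        spatial_shapes,
        level_start_index,
    )
    assert output.shape == (bs, num_queries, d_model)
    return output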