# Source code for vis4d.op.layer.attention

"""Attention layer."""

from __future__ import annotations

from torch import Tensor, nn

from vis4d.common.logging import rank_zero_warn
from vis4d.common.typing import ArgsType

class Attention(nn.Module):
    """Multi-head self-attention block as used in Vision Transformers.

    Adapted from the ViT attention layer of the timm library.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        """Create the attention layer.

        Args:
            dim (int): Channel dimension of the input tensor. Must be a
                multiple of ``num_heads``.
            num_heads (int, optional): Number of attention heads. Defaults
                to 8.
            qkv_bias (bool, optional): Whether the joint query/key/value
                projection has a bias term. Defaults to False.
            attn_drop (float, optional): Dropout probability applied to the
                attention weights. Defaults to 0.0.
            proj_drop (float, optional): Dropout probability applied after
                the output projection. Defaults to 0.0.
        """
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        # Attention logits are multiplied by 1 / sqrt(per-head dimension).
        self.scale = (dim // num_heads) ** -0.5

        # A single linear layer yields queries, keys and values at once.
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def __call__(self, data: Tensor) -> Tensor:
        """Apply the layer to ``data``.

        Args:
            data (Tensor): Input tensor of shape (B, N, dim).

        Returns:
            Tensor: Output tensor with the same shape as the input.
        """
        return self._call_impl(data)

    def forward(self, x: Tensor) -> Tensor:
        """Compute self-attention over ``x`` of shape (B, N, dim)."""
        batch_size, num_tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3 * dim) -> (3, B, num_heads, N, head_dim)
        stacked = self.qkv(x).reshape(
            batch_size, num_tokens, 3, self.num_heads, head_dim
        )
        stacked = stacked.permute(2, 0, 3, 1, 4)
        # unbind instead of tuple-unpacking the tensor keeps torchscript happy
        queries, keys, values = stacked.unbind(0)

        weights = (queries @ keys.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        # Merge the heads back: (B, num_heads, N, head_dim) -> (B, N, dim)
        merged = (weights @ values).transpose(1, 2)
        merged = merged.reshape(batch_size, num_tokens, channels)
        return self.proj_drop(self.proj(merged))
class MultiheadAttention(nn.Module):
    """A wrapper for ``torch.nn.MultiheadAttention``.

    This module implements MultiheadAttention with identity connection, and
    positional encoding is also passed as input.
    """

    def __init__(
        self,
        embed_dims: int,
        num_heads: int,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        dropout_layer: nn.Module | None = None,
        batch_first: bool = False,
        **kwargs: ArgsType,
    ) -> None:
        """Init MultiheadAttention.

        Args:
            embed_dims (int): The embedding dimension.
            num_heads (int): Parallel attention heads.
            attn_drop (float): Dropout rate on attn_output_weights.
                Default: 0.0.
            proj_drop (float): Dropout rate applied to the output of
                `nn.MultiheadAttention`. Default: 0.0.
            dropout_layer (nn.Module | None, optional): The dropout_layer
                applied to the attention branch before it is added to the
                shortcut. Defaults to None, in which case `nn.Identity` is
                used.
            batch_first (bool): When it is True, Key, Query and Value are
                shape of (batch, n, embed_dim), otherwise (n, batch,
                embed_dim). Default to False.
            **kwargs (ArgsType): Additional keyword arguments forwarded to
                `nn.MultiheadAttention`.
        """
        super().__init__()
        self.batch_first = batch_first
        self.embed_dims = embed_dims
        # ``batch_first`` is deliberately not forwarded: the wrapped module
        # always runs sequence-first and ``forward`` transposes around it.
        self.attn = nn.MultiheadAttention(
            embed_dims, num_heads, dropout=attn_drop, **kwargs
        )
        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = dropout_layer or nn.Identity()

    def forward(
        self,
        query: Tensor,
        key: Tensor | None = None,
        value: Tensor | None = None,
        identity: Tensor | None = None,
        query_pos: Tensor | None = None,
        key_pos: Tensor | None = None,
        attn_mask: Tensor | None = None,
        key_padding_mask: Tensor | None = None,
    ) -> Tensor:
        """Forward function for `MultiheadAttention`.

        Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims] if self.batch_first is False, else [bs,
                num_queries, embed_dims].
            key (Tensor | None): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else [bs, num_keys,
                embed_dims]. If None, the `query` will be used. Defaults to
                None.
            value (Tensor | None): The value tensor with same shape as `key`.
                If None, the `key` (before its positional encoding is added)
                will be used. Defaults to None.
            identity (Tensor | None): This tensor, with the same shape as
                `query`, will be used for the identity link. If None, the
                `query` (before its positional encoding is added) will be
                used. Defaults to None.
            query_pos (Tensor | None): The positional encoding for `query`,
                with the same shape as `query`. If not None, it will be added
                to `query` before the attention. Defaults to None.
            key_pos (Tensor | None): The positional encoding for `key`, with
                the same shape as `key`. If not None, it will be added to
                `key` before the attention. If None, and `query_pos` has the
                same shape as `key`, then `query_pos` will be used for
                `key_pos`; if the shapes differ a warning is emitted and no
                encoding is added to `key`. Defaults to None.
            attn_mask (Tensor | None): Mask with shape [num_queries,
                num_keys], passed unchanged to
                `nn.MultiheadAttention.forward`. Defaults to None.
            key_padding_mask (Tensor | None): Mask with shape [bs, num_keys],
                passed unchanged to `nn.MultiheadAttention.forward`. Defaults
                to None.

        Returns:
            Tensor: forwarded results with shape [num_queries, bs,
                embed_dims] if self.batch_first is False, else [bs,
                num_queries, embed_dims].
        """
        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None and query_pos is not None:
            # use query_pos if key_pos is not available
            if query_pos.shape == key.shape:
                key_pos = query_pos
            else:
                rank_zero_warn(
                    "position encoding of key is "
                    + f"missing in {self.__class__.__name__}."
                )
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos

        # Because the dataflow('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
        # embed_dims), We should adjust the shape of dataflow from
        # batch_first (batch, num_query, embed_dims) to num_query_first
        # (num_query, batch, embed_dims), and recover ``attn_output``
        # from num_query_first to batch_first.
        if self.batch_first:
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)

        # The wrapped module returns (attn_output, attn_weights); only the
        # output is needed here.
        out = self.attn(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
        )[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))