"""Attention layer."""
from __future__ import annotations
from torch import Tensor, nn
from vis4d.common.logging import rank_zero_warn
from vis4d.common.typing import ArgsType
class Attention(nn.Module):
"""ViT Attention Layer.
Modified from timm (https://github.com/huggingface/pytorch-image-models).
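
    Example:
        Minimal usage sketch (the shapes below are arbitrary):

        >>> import torch
        >>> attn = Attention(dim=64, num_heads=8)
        >>> attn(torch.rand(2, 16, 64)).shape
        torch.Size([2, 16, 64])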
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
) -> None:
"""Init attention layer.
Args:
dim (int): Input tensor's dimension.
num_heads (int, optional): Number of attention heads. Defaults to
8.
            qkv_bias (bool, optional): Whether to add a bias term to the qkv
                projection. Defaults to False.
attn_drop (float, optional): Dropout rate for attention. Defaults
to 0.0.
proj_drop (float, optional): Dropout rate for projection. Defaults
to 0.0.
"""
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
head_dim = dim // num_heads
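        # scale attention scores by 1 / sqrt(head_dim)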
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def __call__(self, data: Tensor) -> Tensor:
"""Applies the layer.
Args:
data (Tensor): Input tensor of shape (B, N, dim).
Returns:
Tensor: Output tensor of the same shape as input.
"""
return self._call_impl(data)
def forward(self, x: Tensor) -> Tensor:
"""Forward pass."""
batch_size, num_samples, dim = x.shape
qkv = (
self.qkv(x)
.reshape(
batch_size,
num_samples,
3,
self.num_heads,
dim // self.num_heads,
)
.permute(2, 0, 3, 1, 4)
)
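        # qkv: (3, batch_size, num_heads, num_samples, head_dim)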
q, k, v = qkv.unbind(
0
) # make torchscript happy (cannot use tensor as tuple)
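        # scaled dot-product attention weights:
        # (batch_size, num_heads, num_samples, num_samples)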
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(batch_size, num_samples, dim)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MultiheadAttention(nn.Module):
"""A wrapper for ``torch.nn.MultiheadAttention``.
    This module implements multi-head attention with an identity connection,
    and positional encodings can be passed as additional inputs.
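
    Example:
        Minimal usage sketch (with the default ``batch_first=False``, inputs
        have shape (num_queries, bs, embed_dims); shapes are arbitrary):

        >>> import torch
        >>> layer = MultiheadAttention(embed_dims=64, num_heads=8)
        >>> layer(torch.rand(10, 2, 64)).shape
        torch.Size([10, 2, 64])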
"""
def __init__(
self,
embed_dims: int,
num_heads: int,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
dropout_layer: nn.Module | None = None,
batch_first: bool = False,
**kwargs: ArgsType,
) -> None:
"""Init MultiheadAttention.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads.
            attn_drop (float): Dropout probability applied to the attention
                weights (attn_output_weights). Defaults to 0.0.
            proj_drop (float): Dropout probability applied after
                ``nn.MultiheadAttention``. Defaults to 0.0.
            dropout_layer (nn.Module | None, optional): The dropout layer
                applied to the output before adding the identity shortcut.
                Defaults to None, in which case ``nn.Identity`` is used.
            batch_first (bool): If True, key, query and value have shape
                (batch, n, embed_dim); otherwise (n, batch, embed_dim).
                Defaults to False.
"""
super().__init__()
self.batch_first = batch_first
self.embed_dims = embed_dims
self.attn = nn.MultiheadAttention(
embed_dims, num_heads, dropout=attn_drop, **kwargs
)
self.proj_drop = nn.Dropout(proj_drop)
self.dropout_layer = dropout_layer or nn.Identity()
def forward(
self,
query: Tensor,
key: Tensor | None = None,
value: Tensor | None = None,
identity: Tensor | None = None,
query_pos: Tensor | None = None,
key_pos: Tensor | None = None,
attn_mask: Tensor | None = None,
key_padding_mask: Tensor | None = None,
) -> Tensor:
"""Forward function for `MultiheadAttention`.
Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims]. If None, the ``query`` will be
                used. Defaults to None.
            value (Tensor): The value tensor with the same shape as ``key``,
                as in ``nn.MultiheadAttention.forward``. If None, the ``key``
                will be used. Defaults to None.
            identity (Tensor): The tensor, with the same shape as ``query``,
                used for the identity link. If None, ``query`` will be used.
                Defaults to None.
            query_pos (Tensor): The positional encoding for ``query``, with
                the same shape as ``query``. If not None, it will be added
                to ``query`` before the attention. Defaults to None.
            key_pos (Tensor): The positional encoding for ``key``, with the
                same shape as ``key``. If not None, it will be added to
                ``key`` before the attention. If None and ``query_pos`` has
                the same shape as ``key``, ``query_pos`` will be used as
                ``key_pos``. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same as in ``nn.MultiheadAttention.forward``.
                Defaults to None.
key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
Defaults to None.
Returns:
Tensor: forwarded results with shape [num_queries, bs, embed_dims]
if self.batch_first is False, else [bs, num_queries,
embed_dims].
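
        Example:
            Sketch with ``batch_first=True`` and a positional encoding for
            the query (values are illustrative):

            >>> import torch
            >>> layer = MultiheadAttention(
            ...     embed_dims=32, num_heads=4, batch_first=True
            ... )
            >>> query = torch.rand(2, 5, 32)
            >>> layer(query, query_pos=torch.rand(2, 5, 32)).shape
            torch.Size([2, 5, 32])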
"""
if key is None:
key = query
if value is None:
value = key
if identity is None:
identity = query
if key_pos is None and query_pos is not None:
# use query_pos if key_pos is not available
if query_pos.shape == key.shape:
key_pos = query_pos
else:
                rank_zero_warn(
                    "position encoding of key is missing "
                    f"in {self.__class__.__name__}."
                )
if query_pos is not None:
query = query + query_pos
if key_pos is not None:
key = key + key_pos
        # Because the dataflow ('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch, embed_dims),
        # we adjust the shape of the dataflow from batch first
        # (batch, num_query, embed_dims) to num_query first
        # (num_query, batch, embed_dims), and recover ``attn_output``
        # from num_query first back to batch first.
if self.batch_first:
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
out = self.attn(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
)[0]
if self.batch_first:
out = out.transpose(0, 1)
return identity + self.dropout_layer(self.proj_drop(out))