Source code for vis4d.op.base.vit

"""Residual networks for classification."""

from __future__ import annotations

import torch
from timm.models.helpers import named_apply
from torch import nn

from ..layer import PatchEmbed, TransformerBlock
from .base import BaseModel


def _init_weights_vit_timm(  # pylint: disable=unused-argument
    module: nn.Module, name: str
) -> None:
    """Weight initialization, original timm impl (for reproducibility)."""
    if isinstance(module, nn.Linear):
        nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, "init_weights"):
        module.init_weights()
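
# Note: ``named_apply`` from timm (used in ``VisionTransformer.init_weights``
# below) walks the module tree and calls ``_init_weights_vit_timm`` on every
# submodule, so ``nn.Linear`` layers receive truncated-normal weights and zero
# biases, while modules that define their own ``init_weights`` use that
# method instead.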


ViT_PRESET = {  # pylint: disable=consider-using-namedtuple-or-dataclass
    "vit_tiny_patch16_224": {
        "patch_size": 16,
        "embed_dim": 192,
        "depth": 12,
        "num_heads": 3,
    },
    "vit_small_patch16_224": {
        "patch_size": 16,
        "embed_dim": 384,
        "depth": 12,
        "num_heads": 6,
    },
    "vit_base_patch16_224": {
        "patch_size": 16,
        "embed_dim": 768,
        "depth": 12,
        "num_heads": 12,
    },
    "vit_large_patch16_224": {
        "patch_size": 16,
        "embed_dim": 1024,
        "depth": 24,
        "num_heads": 16,
    },
    "vit_huge_patch16_224": {
        "patch_size": 16,
        "embed_dim": 1280,
        "depth": 32,
        "num_heads": 16,
    },
    "vit_small_patch32_224": {
        "patch_size": 32,
        "embed_dim": 384,
        "depth": 12,
        "num_heads": 6,
    },
    "vit_base_patch32_224": {
        "patch_size": 32,
        "embed_dim": 768,
        "depth": 12,
        "num_heads": 12,
    },
    "vit_large_patch32_224": {
        "patch_size": 32,
        "embed_dim": 1024,
        "depth": 24,
        "num_heads": 16,
    },
    "vit_huge_patch32_224": {
        "patch_size": 32,
        "embed_dim": 1280,
        "depth": 32,
        "num_heads": 16,
    },
}
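# Note: each preset is a plain keyword dictionary whose keys match
# ``VisionTransformer.__init__`` arguments, so a configuration can be looked
# up and unpacked directly, e.g. (illustrative):
# ``VisionTransformer(img_size=224, **ViT_PRESET["vit_base_patch16_224"])``.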


class VisionTransformer(BaseModel):
    """Vision Transformer (ViT) model without classification head.

    A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image
    Recognition at Scale` - https://arxiv.org/abs/2010.11929

    Adapted from:
        - pytorch vision transformer impl
        - timm vision transformer impl
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        num_classes: int = 1000,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        init_values: float | None = None,
        class_token: bool = True,
        no_embed_class: bool = False,
        pre_norm: bool = False,
        pos_drop_rate: float = 0.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: nn.Module | None = None,
        act_layer: nn.Module = nn.GELU(),
    ) -> None:
        """Init VisionTransformer.

        Args:
            img_size (int, optional): Input image size. Defaults to 224.
            patch_size (int, optional): Patch size. Defaults to 16.
            in_channels (int, optional): Number of input channels. Defaults
                to 3.
            num_classes (int, optional): Number of classes. Defaults to 1000.
            embed_dim (int, optional): Embedding dimension. Defaults to 768.
            depth (int, optional): Depth. Defaults to 12.
            num_heads (int, optional): Number of attention heads. Defaults
                to 12.
            mlp_ratio (float, optional): Ratio of MLP hidden dim to embedding
                dim. Defaults to 4.0.
            qkv_bias (bool, optional): Whether to add a bias to qkv. Defaults
                to True.
            init_values (float, optional): Initial values for layer scale.
                Defaults to None.
            class_token (bool, optional): Whether to add a class token.
                Defaults to True.
            no_embed_class (bool, optional): Whether to exclude the class
                token from the positional embedding. Defaults to False.
            pre_norm (bool, optional): Whether to use pre-norm. Defaults to
                False.
            pos_drop_rate (float, optional): Positional dropout rate.
                Defaults to 0.0.
            drop_rate (float, optional): Dropout rate. Defaults to 0.0.
            attn_drop_rate (float, optional): Attention dropout rate.
                Defaults to 0.0.
            drop_path_rate (float, optional): Drop path rate. Defaults to
                0.0.
            norm_layer (nn.Module, optional): Normalization layer. If None,
                nn.LayerNorm is used. Defaults to None.
            act_layer (nn.Module, optional): Activation layer. Defaults to
                nn.GELU().
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = (
            embed_dim  # num_features for consistency with other models
        )
        self.num_depth = depth
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = (
            nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        )
        embed_len = (
            num_patches
            if no_embed_class
            else num_patches + self.num_prefix_tokens
        )
        self.pos_embed = nn.Parameter(torch.zeros(1, embed_len, embed_dim))
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        self.norm_pre = (
            nn.LayerNorm(embed_dim, eps=1e-6) if pre_norm else nn.Identity()
        )

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        blocks = [
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                init_values=init_values,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
            )
            for i in range(depth)
        ]
        self.blocks = nn.ModuleList(blocks)

        self.init_weights()
    def init_weights(self) -> None:
        """Init weights using timm's implementation."""
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(_init_weights_vit_timm, self)
    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional embeddings."""
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then
            # concat
            x = x + self.pos_embed
            if self.cls_token is not None:
                x = torch.cat(
                    (self.cls_token.expand(x.shape[0], -1, -1), x), dim=1
                )
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                x = torch.cat(
                    (self.cls_token.expand(x.shape[0], -1, -1), x), dim=1
                )
            x = x + self.pos_embed
        return self.pos_drop(x)

    @property
    def out_channels(self) -> list[int]:
        """Return the number of output channels per feature level."""
        return [self.embed_dim] * (self.num_depth + 1)
    def __call__(self, data: torch.Tensor) -> list[torch.Tensor]:
        """Applies the ViT encoder.

        Args:
            data (torch.Tensor): Input images into the network of shape
                [N, C, H, W].
        """
        return self._call_impl(data)
    def forward(self, images: torch.Tensor) -> list[torch.Tensor]:
        """Forward pass.

        Args:
            images (torch.Tensor): Input images tensor of shape (B, C, H, W).

        Returns:
            feats (list[torch.Tensor]): Features of the input images extracted
                by the ViT encoder. feats[0] is the input images, and feats[1]
                is the output of the patch embedding layer. The rest of the
                elements are the outputs of each transformer block, with the
                shape (B, N, dim), where N is the number of patches and dim is
                the embedding dimension. The final element is the output of
                the ViT encoder.
        """
        feats = [images]
        x = self.patch_embed(images)
        x = self.norm_pre(self._pos_embed(x))
        feats.append(x)
        for blk in self.blocks:
            x = blk(x)
            feats.append(x)
        return feats
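

if __name__ == "__main__":  # pragma: no cover
    # Illustrative usage sketch (an assumption-based example, not part of the
    # original API): build a ViT-Base/16 encoder from a preset and inspect
    # the returned multi-level features for a dummy 224x224 RGB batch.
    vit = VisionTransformer(
        img_size=224, **ViT_PRESET["vit_base_patch16_224"]
    )
    dummy_images = torch.rand(2, 3, 224, 224)
    features = vit(dummy_images)
    # features[0] is the input batch; features[1] is the patch embedding
    # output (with class token and position embedding added); features[2:]
    # are the per-block outputs. Each token tensor has shape (2, 197, 768):
    # 14 x 14 = 196 patches plus one class token, with embed_dim 768.
    assert len(features) == 2 + len(vit.blocks)
    print([tuple(f.shape) for f in features[1:]])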