vis4d.model.cls.vit

ViT for classification tasks.

Classes

ViTClassifer([variant, num_classes, ...])

ViT for classification tasks.

class ViTClassifer(variant='', num_classes=1000, use_global_pooling=False, weights=None, num_prefix_tokens=1, **kwargs)

ViT for classification tasks.

Initialize the classification ViT.

Parameters:
  • variant (str) – Name of the ViT variant. Defaults to “”. If the name starts with “timm://”, the variant is loaded from timm’s model zoo; otherwise, it is looked up in the ViT_PRESET dict. If the variant is empty, the default ViT variant is used. In all cases, the additional keyword arguments (**kwargs) override the variant’s default arguments.

  • num_classes (int, optional) – Number of classes. Defaults to 1000.

  • use_global_pooling (bool, optional) – Whether to use global pooling. Defaults to False. If True, the output of the ViT is averaged over the spatial dimensions; otherwise, the first token (typically the class token) is used for classification.

  • weights (str, optional) – Pretrained weights to load. If set to “timm”, the weights matching the variant are loaded from timm’s model zoo. If a URL is provided, the weights are downloaded from that URL. Defaults to None, which means no weights are loaded.

  • num_prefix_tokens (int, optional) – Number of prefix tokens (e.g., the class token) that precede the patch tokens. Defaults to 1.

  • **kwargs (Any) – Keyword arguments passed to the ViT model.
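The lookup order described for variant, together with the options for weights, can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the vis4d implementation: resolve_variant, weights_source, HYPOTHETICAL_PRESETS, and HYPOTHETICAL_DEFAULTS are hypothetical names, and the table entries are placeholders rather than the real contents of ViT_PRESET. Rejecting unrecognized weights values is also an assumption of this sketch.

```python
from typing import Any, Dict, Optional, Tuple

# Hypothetical stand-in for ViT_PRESET; entries are illustrative placeholders.
HYPOTHETICAL_PRESETS: Dict[str, Dict[str, Any]] = {
    "vit_tiny_patch16_224": {"patch_size": 16, "embed_dim": 192, "depth": 12, "num_heads": 3},
    "vit_small_patch16_224": {"patch_size": 16, "embed_dim": 384, "depth": 12, "num_heads": 6},
}
# Hypothetical defaults used when the variant is empty.
HYPOTHETICAL_DEFAULTS: Dict[str, Any] = {
    "patch_size": 16, "embed_dim": 768, "depth": 12, "num_heads": 12,
}

TIMM_PREFIX = "timm://"


def resolve_variant(variant: str = "", **kwargs: Any) -> Tuple[str, Dict[str, Any]]:
    """Return (source, config) following the documented lookup order."""
    if variant.startswith(TIMM_PREFIX):
        # timm variant: only the model name is extracted in this sketch;
        # the architecture itself would come from timm's model zoo.
        source = "timm:" + variant[len(TIMM_PREFIX):]
        config: Dict[str, Any] = {}
    elif variant:
        # Named preset: copy it so the shared table is never mutated.
        if variant not in HYPOTHETICAL_PRESETS:
            raise KeyError(f"Unknown ViT variant: {variant!r}")
        source = "preset:" + variant
        config = dict(HYPOTHETICAL_PRESETS[variant])
    else:
        # Empty name: fall back to the default variant.
        source = "default"
        config = dict(HYPOTHETICAL_DEFAULTS)
    # In all cases, explicit keyword arguments override the defaults.
    config.update(kwargs)
    return source, config


def weights_source(weights: Optional[str]) -> str:
    """Classify the weights argument (hypothetical helper)."""
    if weights is None:
        return "none"  # no pretrained weights are loaded
    if weights == "timm":
        return "timm"  # weights matching the variant, from timm's model zoo
    if weights.startswith(("http://", "https://")):
        return "url"  # weights downloaded from the given URL
    # Assumption: any other value is rejected in this sketch.
    raise ValueError(f"Unsupported weights value: {weights!r}")
```

For example, resolve_variant("vit_small_patch16_224", depth=6) starts from the preset and returns a config whose depth is 6, mirroring how **kwargs take precedence over the variant's defaults.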

forward(images)

Forward pass on the input images.

Return type:

ClsOut
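How use_global_pooling and num_prefix_tokens shape the classification feature can be sketched on a single (num_tokens, dim) token sequence in plain Python. This is a minimal sketch, not the vis4d forward pass: pool_tokens and linear_head are hypothetical names, the real model works on batched torch tensors, and two details are assumptions: that the average skips the prefix tokens, and that the head is a single linear layer mapping the feature to num_classes logits.

```python
from typing import List, Sequence


def pool_tokens(
    tokens: Sequence[Sequence[float]],
    use_global_pooling: bool = False,
    num_prefix_tokens: int = 1,
) -> List[float]:
    """Reduce a (num_tokens, dim) token sequence to one feature vector."""
    if use_global_pooling:
        # Assumption: average only the patch tokens, skipping prefix tokens.
        patch_tokens = tokens[num_prefix_tokens:]
        n = len(patch_tokens)
        dim = len(patch_tokens[0])
        return [sum(t[d] for t in patch_tokens) / n for d in range(dim)]
    # Otherwise use the first token (typically the class token).
    return list(tokens[0])


def linear_head(
    feature: Sequence[float],
    weight: Sequence[Sequence[float]],
    bias: Sequence[float],
) -> List[float]:
    """Map a feature to logits; weight has shape (num_classes, dim)."""
    return [sum(w * f for w, f in zip(row, feature)) + b for row, b in zip(weight, bias)]
```

With tokens [[1.0, 1.0], [2.0, 4.0], [4.0, 8.0]], the default path returns the first token [1.0, 1.0], while global pooling with one prefix token averages the last two tokens to [3.0, 6.0].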