vis4d.model.cls

Common classes and functions for classification models.

class ClsOut(logits: Tensor, probs: Tensor)

Output of a classification model: the logits and the class probabilities.

Create new instance of ClsOut(logits, probs)

logits: Tensor

Alias for field number 0

probs: Tensor

Alias for field number 1
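The field aliases above mean that ClsOut behaves like a named tuple: each field can be read by name or by position. The following is a minimal sketch of that access pattern, not the library's implementation. `ClsOutSketch` and `softmax` are hypothetical stand-ins, plain Python lists replace `Tensor`, and deriving `probs` from `logits` with a softmax is an assumption that the documentation above does not confirm.

```python
import math
from typing import List, NamedTuple


class ClsOutSketch(NamedTuple):
    """Hypothetical stand-in for ClsOut; lists replace torch Tensors."""

    logits: List[float]  # field number 0
    probs: List[float]   # field number 1


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax (assumed relation between the two fields)."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 0.5, -1.0]
out = ClsOutSketch(logits=logits, probs=softmax(logits))

# Access by name and by position refer to the same objects.
same_logits = out.logits is out[0]
same_probs = out.probs is out[1]

# Tuple unpacking follows the field order: logits first, then probs.
unpacked_logits, unpacked_probs = out

# Index of the most probable class.
predicted = max(range(len(out.probs)), key=out.probs.__getitem__)
```

Reading fields by name is usually the safer choice, since it does not depend on the field order.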

class ViTClassifer(variant='', num_classes=1000, use_global_pooling=False, weights=None, num_prefix_tokens=1, **kwargs)

ViT for classification tasks.

Initialize the classification ViT.

Parameters:
  • variant (str) – Name of the ViT variant. Defaults to “”. If the name starts with “timm://”, the variant will be loaded from timm’s model zoo. Otherwise, the variant will be loaded from the ViT_PRESET dict. If the variant is empty, the default ViT variant will be used. In all cases, the additional keyword arguments will override the default arguments.

  • num_classes (int, optional) – Number of classes. Defaults to 1000.

  • use_global_pooling (bool, optional) – Whether to use global pooling. Defaults to False. If set to True, the output of the ViT will be averaged over the spatial dimensions. Otherwise, the first token will be used for classification.

  • weights (str, optional) – Which pretrained weights to load, if any. If set to “timm”, the weights will be loaded from timm’s model zoo that matches the variant. If a URL is provided, the weights will be downloaded from the URL. Defaults to None, which means no weights will be loaded.

  • num_prefix_tokens (int, optional) – Number of prefix tokens (such as the class token) placed before the patch tokens. Defaults to 1.

  • **kwargs (Any) – Keyword arguments passed to the ViT model.
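The rules for the variant parameter can be sketched as a small dispatch function. This is an illustration under stated assumptions, not the library's code: `resolve_variant`, `PRESET_SKETCH`, and its entries are hypothetical, and the real contents of ViT_PRESET are not shown in this reference. Only the three documented cases are modeled (a “timm://” prefix, a preset name, an empty string), with the keyword arguments overriding the defaults in each case.

```python
from typing import Any, Dict, Tuple

# Hypothetical preset table; names and values are made up for illustration.
PRESET_SKETCH: Dict[str, Dict[str, Any]] = {
    "sketch_tiny": {"embed_dim": 192, "depth": 12, "num_heads": 3},
    "sketch_base": {"embed_dim": 768, "depth": 12, "num_heads": 12},
}

TIMM_PREFIX = "timm://"


def resolve_variant(variant: str = "", **kwargs: Any) -> Tuple[str, Dict[str, Any]]:
    """Return (source, arguments) for a variant string.

    source is "timm", "preset", or "default", mirroring the three
    documented cases. An unknown preset name raises KeyError.
    """
    if variant.startswith(TIMM_PREFIX):
        # Everything after the prefix is treated as the timm model name.
        source = "timm"
        args: Dict[str, Any] = {"model_name": variant[len(TIMM_PREFIX):]}
    elif variant == "":
        # Empty string: fall back to the default variant's arguments.
        source = "default"
        args = {}
    else:
        source = "preset"
        args = dict(PRESET_SKETCH[variant])  # copy so the table stays unchanged
    # In all cases, explicit keyword arguments override the defaults.
    args.update(kwargs)
    return source, args


timm_case = resolve_variant("timm://vit_base_patch16_224")
preset_case = resolve_variant("sketch_tiny", depth=6)
default_case = resolve_variant()
```

Copying the preset entry before applying overrides keeps one call's keyword arguments from leaking into later calls.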

forward(images)

Forward pass.

Return type:

ClsOut
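Before the classification head runs, the token sequence has to be reduced to a single feature vector, which is what use_global_pooling controls. The sketch below shows the two documented choices on plain Python lists. `pool_tokens` is a hypothetical helper, and excluding the first num_prefix_tokens tokens from the average is an assumption (it matches a common ViT convention, but this reference does not state it).

```python
from typing import List


def pool_tokens(
    tokens: List[List[float]],
    use_global_pooling: bool = False,
    num_prefix_tokens: int = 1,
) -> List[float]:
    """Reduce a [num_tokens][dim] token sequence to one [dim] feature vector."""
    if use_global_pooling:
        # Assumption: prefix tokens are skipped, only patch tokens are averaged.
        patches = tokens[num_prefix_tokens:]
        dim = len(patches[0])
        return [sum(t[d] for t in patches) / len(patches) for d in range(dim)]
    # Default: use the first token (typically the class token).
    return tokens[0]


# One prefix (class) token followed by two patch tokens, feature dimension 2.
tokens = [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0]]

first_token_feature = pool_tokens(tokens)
averaged_feature = pool_tokens(tokens, use_global_pooling=True)
```

With the sample tokens, the default path returns the class token unchanged, while the pooled path averages the two patch tokens element-wise.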

Modules

vis4d.model.cls.common

Common types for classification models.

vis4d.model.cls.vit

ViT for classification tasks.