"""ViT for classification tasks."""from__future__importannotationsimporttimm.models.vision_transformeras_vision_transformerimporttorchfromtorchimportnnfromvis4d.commonimportArgsTypefromvis4d.common.ckptimportload_model_checkpointfromvis4d.op.base.vitimportVisionTransformer,ViT_PRESETfrom.commonimportClsOut
class ViTClassifer(nn.Module):
    """ViT for classification tasks.

    Wraps a ``VisionTransformer`` backbone with an optional LayerNorm +
    Linear classification head and supports loading pretrained weights
    either from timm's model zoo or from an arbitrary checkpoint URL/path.
    """

    def __init__(
        self,
        variant: str = "",
        num_classes: int = 1000,
        use_global_pooling: bool = False,
        weights: str | None = None,
        num_prefix_tokens: int = 1,
        **kwargs: ArgsType,
    ) -> None:
        """Initialize the classification ViT.

        Args:
            variant (str): Name of the ViT variant, looked up in the
                ViT_PRESET dict. If empty, the ViT is built purely from
                **kwargs. In all cases the additional keyword arguments
                override the preset/default arguments. Defaults to "".
            num_classes (int, optional): Number of classes. Defaults to
                1000. If 0, no classification head is attached.
            use_global_pooling (bool, optional): If to use global pooling.
                Defaults to False. If set to True, the output of the ViT
                will be averaged over the spatial dimensions. Otherwise,
                the first token will be used for classification.
            weights (str, optional): If to load pretrained weights. If
                prefixed with "timm://", the remainder is interpreted as a
                timm model name with an optional ".pretrain_tag" suffix and
                the weights are fetched from timm's model zoo. Otherwise it
                is treated as a checkpoint URL/path. Defaults to None,
                which means no weights will be loaded.
            num_prefix_tokens (int, optional): Number of prefix tokens.
                Defaults to 1.
            **kwargs: Keyword arguments passed to the ViT model.
        """
        super().__init__()
        self.num_classes = num_classes
        self.use_global_pooling = use_global_pooling
        self.num_prefix_tokens = num_prefix_tokens

        if variant != "":
            assert variant in ViT_PRESET, (
                f"Unknown ViT variant: {variant}. "
                f"Available ViT variants are: {list(ViT_PRESET.keys())}"
            )
            # Copy the preset before overriding it — mutating the dict
            # returned by ViT_PRESET[variant] in place would corrupt the
            # shared registry for every later instantiation.
            preset_kwargs = dict(ViT_PRESET[variant])
            preset_kwargs["num_classes"] = num_classes
            preset_kwargs.update(kwargs)
            self.vit = VisionTransformer(**preset_kwargs)  # type: ignore
        else:
            # Build ViT from scratch using kwargs
            preset_kwargs = {}
            self.vit = VisionTransformer(num_classes=num_classes, **kwargs)

        # Classification head. Explicit kwargs take precedence over the
        # preset's embed_dim; fall back to the ViT default of 768.
        embed_dim = kwargs.get(
            "embed_dim", preset_kwargs.get("embed_dim", 768)
        )
        self.norm = (
            nn.LayerNorm(embed_dim) if use_global_pooling else nn.Identity()
        )
        self.head = (
            nn.Linear(embed_dim, num_classes)
            if num_classes > 0
            else nn.Identity()
        )

        # Load pretrained weights.
        if weights is not None:
            if weights.startswith("timm://"):
                weights = weights.removeprefix("timm://")
                # Split "<model_name>.<pretrain_tag>" at the FIRST dot only,
                # so tags that themselves contain dots do not raise.
                model_name, _, pretrain_tag_str = weights.partition(".")
                pretrain_tag = pretrain_tag_str if pretrain_tag_str else None
                assert model_name in _vision_transformer.__dict__, (
                    f"Unknown Timm ViT weights: {model_name}. "
                    f"Available Timm ViT weights are: "
                    f"{list(_vision_transformer.__dict__.keys())}"
                )
                _model = _vision_transformer.__dict__[model_name](
                    pretrained=True, pretrained_cfg=pretrain_tag, **kwargs
                )
                # strict=False: this module's layout differs slightly from
                # timm's (norm/head live outside self.vit).
                self.vit.load_state_dict(_model.state_dict(), strict=False)
                self.norm.load_state_dict(
                    _model.norm.state_dict(), strict=False
                )
                self.head.load_state_dict(
                    _model.head.state_dict(), strict=False
                )
            else:
                load_model_checkpoint(self, weights)