[docs]classFCNResNet(nn.Module):"""FCN with ResNet basemodel for semantic segmentation."""def__init__(self,base_model:str="resnet50",num_classes:int=21,resize:None|tuple[int,int]=(520,520),)->None:"""FCN with ResNet basemodel, following torchvision implementation. <https://github.com/pytorch/vision/blob/main/torchvision/models/ segmentation/fcn.py>_. model: FCNResNet(base_model="resnet50") - dataset: Coco2017 - recipe: vis4d/model/segment/FCNResNet_coco_training.py - metrics: - mIoU: 62.52 - Acc: 90.50 """super().__init__()ifbase_model.startswith("resnet"):self.basemodel=ResNet(base_model,pretrained=True,replace_stride_with_dilation=[False,True,True],)else:raiseValueError("base model not supported!")self.fcn=FCNHead(self.basemodel.out_channels[4:],num_classes,resize=resize)
[docs]defforward_train(self,images:torch.Tensor)->FCNOut:"""Forward pass for training. Args: images (torch.Tensor): Input images. Returns: FCNOut: Raw model predictions. """returnself.forward(images)
[docs]defforward_test(self,images:torch.Tensor)->FCNOut:"""Forward pass for testing. Args: images (torch.Tensor): Input images. Returns: FCNOut: Raw model predictions. """returnself.forward(images)
[docs]defforward(self,images:torch.Tensor)->FCNOut:"""Forward pass. Args: images (torch.Tensor): Input images. Returns: FCNOut: Raw model predictions. """features=self.basemodel(images)out=self.fcn(features)returnout