"""FCN Head for semantic segmentation."""from__future__importannotationsfromtypingimportNamedTupleimporttorchimporttorch.nn.functionalasFfromtorchimportnn
class FCNOut(NamedTuple):
    """Output of the FCN prediction."""

    # Final prediction map of shape (N, C, H, W); taken from the last head.
    pred: torch.Tensor
    # Transformed feature maps, one entry per input feature level.
    outputs: list[torch.Tensor]
class FCNHead(nn.Module):
    """FCN Head made with ResNet base model.

    This is based on the implementation in `torchvision
    <https://github.com/pytorch/vision/blob/torchvision/models/segmentation/
    fcn.py>`_.
    """

    def __init__(
        self,
        in_channels: list[int],
        out_channels: int,
        dropout_prob: float = 0.1,
        resize: tuple[int, int] | None = None,
    ) -> None:
        """Creates an instance of the class.

        Args:
            in_channels (list[int]): Number of channels in multi-level image
                feature.
            out_channels (int): Number of output channels. Usually the number
                of classes.
            dropout_prob (float, optional): Dropout probability. Defaults to
                0.1.
            resize (tuple(int, int), optional): Target shape to resize output.
                Defaults to None.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.resize = resize
        # One FCN head per input feature level.
        self.heads = nn.ModuleList(
            self._make_head(channels, out_channels, dropout_prob)
            for channels in in_channels
        )

    def _make_head(
        self, in_channels: int, channels: int, dropout_prob: float
    ) -> nn.Module:
        """Generate FCN segmentation head.

        Args:
            in_channels (int): Input feature channels.
            channels (int): Output segmentation channels.
            dropout_prob (float): Dropout probability.

        Returns:
            nn.Module: FCN segmentation head.
        """
        # Bottleneck: 3x3 conv to in_channels // 4, then 1x1 to the class map.
        inter_channels = in_channels // 4
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                inter_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Conv2d(inter_channels, channels, kernel_size=1),
        )

    def forward(self, feats: list[torch.Tensor]) -> FCNOut:
        """Transforms feature maps and returns segmentation prediction.

        Args:
            feats (list[torch.Tensor]): List of multi-level image features.

        Returns:
            output (list[torch.Tensor]): Each tensor has shape (batch_size,
                self.channels, H, W) which is prediction for each FCN stages.
                E.g., outputs[-1] ==> main output map
                outputs[-2] ==> aux output map (e.g., used for training)
                outputs[:-2] ==> x[:-2]
        """
        outputs = list(feats)
        # Heads consume the LAST len(in_channels) feature levels; earlier
        # levels are passed through unchanged.
        start = len(feats) - len(self.in_channels)
        for idx, head in zip(range(start, len(feats)), self.heads):
            transformed = head(feats[idx])
            if self.resize is not None:
                transformed = F.interpolate(
                    transformed,
                    size=self.resize,
                    mode="bilinear",
                    align_corners=False,
                )
            # Store log-probabilities over the class dimension.
            outputs[idx] = F.log_softmax(transformed, dim=1)
        return FCNOut(pred=outputs[-1], outputs=outputs)

    def __call__(self, feats: list[torch.Tensor]) -> FCNOut:
        """Type definition for function call."""
        return super()._call_impl(feats)