Source code for vis4d.op.seg.fcn

"""FCN Head for semantic segmentation."""

from __future__ import annotations

from typing import NamedTuple

import torch
import torch.nn.functional as F
from torch import nn


class FCNOut(NamedTuple):
    """Output of the FCN prediction."""

    pred: torch.Tensor  # logits for final prediction, (N, C, H, W)
    outputs: list[torch.Tensor]  # transformed feature maps


class FCNHead(nn.Module):
    """FCN Head made with ResNet base model.

    This is based on the implementation in `torchvision
    <https://github.com/pytorch/vision/blob/torchvision/models/segmentation/
    fcn.py>`_.
    """

    def __init__(
        self,
        in_channels: list[int],
        out_channels: int,
        dropout_prob: float = 0.1,
        resize: tuple[int, int] | None = None,
    ) -> None:
        """Creates an instance of the class.

        Args:
            in_channels (list[int]): Number of channels in multi-level image
                feature.
            out_channels (int): Number of output channels. Usually the number
                of classes.
            dropout_prob (float, optional): Dropout probability. Defaults to
                0.1.
            resize (tuple[int, int], optional): Target shape to resize output.
                Defaults to None.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.resize = resize

        self.heads = nn.ModuleList()
        for in_channel in self.in_channels:
            self.heads.append(
                self._make_head(in_channel, self.out_channels, dropout_prob)
            )

    def _make_head(
        self, in_channels: int, channels: int, dropout_prob: float
    ) -> nn.Module:
        """Generate FCN segmentation head.

        Args:
            in_channels (int): Input feature channels.
            channels (int): Output segmentation channels.
            dropout_prob (float): Dropout probability.

        Returns:
            nn.Module: FCN segmentation head.
        """
        inter_channels = in_channels // 4
        layers = [
            nn.Conv2d(
                in_channels,
                inter_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Conv2d(inter_channels, channels, kernel_size=1),
        ]
        return nn.Sequential(*layers)

    def forward(self, feats: list[torch.Tensor]) -> FCNOut:
        """Transforms feature maps and returns segmentation prediction.

        Args:
            feats (list[torch.Tensor]): List of multi-level image features.

        Returns:
            output (list[torch.Tensor]): Each tensor has shape (batch_size,
                self.channels, H, W) and is the prediction of one FCN stage.
                E.g.,
                outputs[-1] ==> main output map
                outputs[-2] ==> aux output map (e.g., used for training)
                outputs[:-2] ==> feats[:-2] (passed through unchanged)
        """
        outputs = feats.copy()
        num_features = len(feats)
        for i in range(len(self.in_channels)):
            idx = num_features - len(self.in_channels) + i
            feat = feats[idx]
            output = self.heads[i](feat)
            if self.resize:
                output = F.interpolate(
                    output,
                    size=self.resize,
                    mode="bilinear",
                    align_corners=False,
                )
            outputs[idx] = F.log_softmax(output, dim=1)
        return FCNOut(pred=outputs[-1], outputs=outputs)

    def __call__(self, feats: list[torch.Tensor]) -> FCNOut:
        """Type definition for function call."""
        return super()._call_impl(feats)
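
A minimal usage sketch (not part of the module above): it assumes a four-level feature pyramid with 256/512/1024/2048 channels, 21 output classes, and a 512x512 target resolution. Because `in_channels` lists two entries, only the last two pyramid levels are transformed; the others are passed through unchanged.

import torch

from vis4d.op.seg.fcn import FCNHead, FCNOut

# Head for the last two pyramid levels; channel counts, class count, and
# resize target are illustrative assumptions, not values prescribed by vis4d.
head = FCNHead(
    in_channels=[1024, 2048],
    out_channels=21,
    resize=(512, 512),
)

# Dummy multi-level features, e.g., as produced by a ResNet-style backbone
# on a batch of two 512x512 images.
feats = [
    torch.randn(2, 256, 128, 128),
    torch.randn(2, 512, 64, 64),
    torch.randn(2, 1024, 32, 32),
    torch.randn(2, 2048, 16, 16),
]

out: FCNOut = head(feats)
assert out.pred.shape == (2, 21, 512, 512)  # main (last-level) prediction
assert len(out.outputs) == len(feats)       # first two levels are unchanged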