Source code for vis4d.op.layer.conv2d

"""Wrapper for conv2d."""

from __future__ import annotations

from typing import NamedTuple

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from vis4d.common.typing import ArgsType

from .weight_init import constant_init


class Conv2d(nn.Conv2d):
    """Wrapper around Conv2d to support empty inputs and norm/activation."""

    def __init__(
        self,
        *args: ArgsType,
        norm: nn.Module | None = None,
        activation: nn.Module | None = None,
        **kwargs: ArgsType,
    ) -> None:
        """Creates an instance of the class.

        If norm is specified, its weight is initialized to 1.0 and its bias
        to 0.0.
        """
        super().__init__(*args, **kwargs)
        self.norm = norm
        self.activation = activation
        if self.norm is not None:
            constant_init(self.norm, 1.0, bias=0.0)
    def forward(  # pylint: disable=arguments-renamed
        self, x: Tensor
    ) -> Tensor:
        """Forward pass."""
        if not torch.jit.is_scripting():  # type: ignore
            # https://github.com/pytorch/pytorch/issues/12013
            if (
                x.numel() == 0
                and self.training
                and isinstance(self.norm, nn.SyncBatchNorm)
            ):
                raise ValueError(
                    "SyncBatchNorm does not support empty inputs!"
                )

        x = F.conv2d(  # pylint: disable=not-callable
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
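# Example usage (illustrative sketch, not part of the module): wrap a 3x3
# convolution with a norm layer and an activation. The norm weight is set to
# 1.0 and its bias to 0.0 on construction.
#
#     conv = Conv2d(
#         16, 32, kernel_size=3, padding=1,
#         norm=nn.BatchNorm2d(32), activation=nn.ReLU(inplace=True),
#     )
#     out = conv(torch.rand(2, 16, 64, 64))  # shape (2, 32, 64, 64)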
def add_conv_branch(
    num_branch_convs: int,
    last_layer_dim: int,
    conv_out_dim: int,
    conv_has_bias: bool,
    norm_cfg: str | None,
    num_groups: int | None,
) -> tuple[nn.ModuleList, int]:
    """Init conv branch for head."""
    convs = nn.ModuleList()
    if norm_cfg is not None:
        norm = getattr(nn, norm_cfg)
    else:
        norm = None
    if norm == nn.GroupNorm:
        assert num_groups is not None, "num_groups must be specified"
        norm = lambda x: nn.GroupNorm(  # pylint: disable=unnecessary-lambda-assignment
            num_groups, x
        )
    if num_branch_convs > 0:
        for i in range(num_branch_convs):
            conv_in_dim = last_layer_dim if i == 0 else conv_out_dim
            convs.append(
                Conv2d(
                    conv_in_dim,
                    conv_out_dim,
                    kernel_size=3,
                    padding=1,
                    bias=conv_has_bias,
                    norm=norm(conv_out_dim) if norm is not None else norm,
                    activation=nn.ReLU(inplace=True),
                )
            )
        last_layer_dim = conv_out_dim
    return convs, last_layer_dim
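# Example usage (illustrative sketch): build a branch of two 3x3 convolutions
# with GroupNorm (8 groups) for a head operating on 256-channel features.
#
#     convs, out_dim = add_conv_branch(
#         num_branch_convs=2,
#         last_layer_dim=256,
#         conv_out_dim=256,
#         conv_has_bias=False,
#         norm_cfg="GroupNorm",
#         num_groups=8,
#     )
#     # convs is an nn.ModuleList of Conv2d layers, out_dim == 256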
class UnetDownConvOut(NamedTuple):
    """Output of the UnetDownConv operator.

    features: Features before applying the pooling operator.
    pooled_features: Features after applying the pooling operator.
    """

    features: Tensor
    pooled_features: Tensor
class UnetDownConv(nn.Module):
    """Downsamples a feature map by applying two convolutions and maxpool."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        pooling: bool = True,
        activation: str = "ReLU",
    ):
        """Creates a new downsampling convolution operator.

        This operator consists of two convolutions followed by a maxpool
        operator.

        Args:
            in_channels (int): Input channels.
            out_channels (int): Output channels.
            pooling (bool): If pooling should be applied.
            activation (str): Activation that should be applied after each
                convolution.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling
        self.activation = getattr(nn, activation)()

        self.conv1 = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=3,
            padding=1,
            stride=1,
            bias=True,
        )
        self.conv2 = nn.Conv2d(
            self.out_channels,
            self.out_channels,
            kernel_size=3,
            padding=1,
            stride=1,
            bias=True,
        )

        if self.pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
    def __call__(self, data: Tensor) -> UnetDownConvOut:
        """Applies the operator.

        Args:
            data (Tensor): Input data.

        Returns:
            UnetDownConvOut: Containing the features before the pooling
                operation (features) and after (pooled_features).
        """
        return self._call_impl(data)
    def forward(self, data: Tensor) -> UnetDownConvOut:
        """Applies the operator.

        Args:
            data (Tensor): Input data.

        Returns:
            UnetDownConvOut: Containing the features before the pooling
                operation (features) and after (pooled_features).
        """
        # Activation is applied after each convolution.
        x = self.activation(self.conv1(data))
        x = self.activation(self.conv2(x))
        before_pool = x
        if self.pooling:
            x = self.pool(x)
        return UnetDownConvOut(features=before_pool, pooled_features=x)
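# Example usage (illustrative sketch): one U-Net encoder stage that keeps the
# pre-pooling features for a later skip connection.
#
#     down = UnetDownConv(in_channels=3, out_channels=64, pooling=True)
#     out = down(torch.rand(1, 3, 128, 128))
#     # out.features:        shape (1, 64, 128, 128)
#     # out.pooled_features: shape (1, 64, 64, 64)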
class UnetUpConv(nn.Module):
    """An operator that performs 2 convolutions and 1 upconvolution.

    A ReLU activation follows each convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        merge_mode: str = "concat",
        up_mode: str = "transpose",
    ):
        """Creates a new UpConv operator.

        This operator merges two inputs by upsampling one and combining it
        with the other.

        Args:
            in_channels: Number of input channels (low res).
            out_channels: Number of output channels (high res).
            merge_mode: How to merge both input channels.
            up_mode: How to upsample the channel with lower resolution.

        Raises:
            ValueError: If upsampling mode is unknown.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        # Upsampling
        if self.up_mode == "transpose":
            self.upconv: nn.Module = nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=2, stride=2
            )
        elif self.up_mode == "upsample":
            self.upconv = nn.Sequential(
                nn.Upsample(mode="bilinear", scale_factor=2),
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
            )
        else:
            raise ValueError(f"Unknown upsampling mode: {up_mode}")

        if self.merge_mode == "concat":
            self.conv1 = nn.Conv2d(
                2 * self.out_channels, self.out_channels, 3, padding=1
            )
        else:
            # number of input channels to conv1 stays the same
            self.conv1 = nn.Conv2d(
                self.out_channels, self.out_channels, 3, padding=1
            )
        self.conv2 = nn.Conv2d(
            self.out_channels, self.out_channels, 3, padding=1
        )
    def __call__(self, from_down: Tensor, from_up: Tensor) -> Tensor:
        """Forward pass.

        Args:
            from_down (Tensor): Tensor from the encoder pathway. Assumed to
                have dimension 'out_channels'.
            from_up (Tensor): Upconv'd tensor from the decoder pathway.
                Assumed to have dimension 'in_channels'.
        """
        return self._call_impl(from_down, from_up)
    def forward(self, from_down: Tensor, from_up: Tensor) -> Tensor:
        """Forward pass.

        Args:
            from_down (Tensor): Tensor from the encoder pathway. Assumed to
                have dimension 'out_channels'.
            from_up (Tensor): Upconv'd tensor from the decoder pathway.
                Assumed to have dimension 'in_channels'.
        """
        from_up = self.upconv(from_up)
        if self.merge_mode == "concat":
            x = torch.cat((from_up, from_down), 1)
        else:
            x = from_up + from_down
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x
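# Example usage (illustrative sketch): one U-Net decoder stage that upsamples
# the low-resolution decoder features and merges them with the matching
# encoder features by concatenation.
#
#     up = UnetUpConv(in_channels=128, out_channels=64, merge_mode="concat")
#     from_down = torch.rand(1, 64, 64, 64)  # encoder features (high res)
#     from_up = torch.rand(1, 128, 32, 32)   # decoder features (low res)
#     out = up(from_down, from_up)           # shape (1, 64, 64, 64)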