Source code for vis4d.op.layer.deform_conv

"""Wrapper for deformable convolution."""

from __future__ import annotations

import torch
from torch import Tensor, nn
from torchvision.ops import DeformConv2d

from .weight_init import constant_init


class DeformConv(DeformConv2d):  # type: ignore
    """Deformable convolution that predicts its own offsets and mask.

    An internal convolution (``conv_offset``) is applied to the input and its
    output is split into offsets and a modulation mask, which are handed to
    the underlying ``DeformConv2d`` operator. An optional normalization layer
    and an optional activation layer are then applied, in that order.

    The offset convolution is zero-initialized. If a norm layer is given, it
    is initialized through ``constant_init`` with 1.0 and bias 0.0.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        norm: nn.Module | None = None,
        activation: nn.Module | None = None,
    ) -> None:
        """Build the deformable conv together with its offset predictor.

        Args:
            in_channels (int): Input channels.
            out_channels (int): Output channels.
            kernel_size (int): Size of convolutional kernel.
            stride (int, optional): Stride of convolutional layer. Defaults
                to 1.
            padding (int, optional): Padding of convolutional layer. Defaults
                to 0.
            dilation (int, optional): Dilation of convolutional layer.
                Defaults to 1.
            groups (int, optional): Forwarded as ``groups`` to
                ``DeformConv2d``; it also scales the number of channels
                produced by the offset convolution. Defaults to 1.
            bias (bool, optional): Whether to use bias in convolutional layer.
                Defaults to True.
            norm (nn.Module, optional): Normalization layer applied to the
                convolution output. Defaults to None.
            activation (nn.Module, optional): Activation layer applied after
                the normalization layer. Defaults to None.
        """
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        # Three values per kernel location and group: forward() splits the
        # prediction into three equal chunks (two for offsets, one for mask).
        kernel_area = self.kernel_size[0] * self.kernel_size[1]
        offset_channels = self.groups * 3 * kernel_area

        # The predictor reuses the geometry stored by the parent constructor.
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            offset_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            bias=True,
        )
        self.norm = norm
        self.activation = activation
        self.init_weights()

    def init_weights(self) -> None:
        """Zero the offset conv parameters and initialize the norm layer."""
        offset_conv = self.conv_offset
        # With all-zero weight and bias the predictor outputs zeros, i.e.
        # zero offsets and a mask of sigmoid(0) = 0.5 in forward().
        offset_conv.weight.data.zero_()
        offset_conv.bias.data.zero_()  # type: ignore
        if self.norm is None:
            return
        constant_init(self.norm, 1.0, bias=0.0)

    def forward(  # pylint: disable=arguments-differ
        self, input_x: Tensor
    ) -> Tensor:
        """Apply deformable conv, then the optional norm and activation.

        Args:
            input_x (Tensor): Input fed to both the offset predictor and the
                deformable convolution.

        Returns:
            Tensor: Output of the deformable convolution after the optional
                normalization and activation layers.
        """
        predicted = self.conv_offset(input_x)
        first, second, modulation = predicted.chunk(3, dim=1)
        # The first two chunks form the offsets; the last one is squashed
        # by a sigmoid and used as the mask.
        offsets = torch.cat((first, second), dim=1)
        result = super().forward(input_x, offsets, modulation.sigmoid())
        for layer in (self.norm, self.activation):
            if layer is not None:
                result = layer(result)
        return result