# Source code for vis4d.op.fpp.yolox_pafpn

"""YOLOX PAFPN.

Modified from mmdetection (https://github.com/open-mmlab/mmdetection).
"""

from __future__ import annotations

import math

import torch
from torch import nn

from vis4d.op.layer import Conv2d, CSPLayer

from .base import FeaturePyramidProcessing


class YOLOXPAFPN(FeaturePyramidProcessing):
    """Path Aggregation Network used in YOLOX.

    The module first runs a top-down path (the last input level is channel
    reduced, upsampled by a factor of 2 and fused with the previous level)
    and then a bottom-up path (each fused level is downsampled with a
    stride-2 convolution and fused with the next level). Finally every level
    is projected to ``out_channels`` with a 1x1 convolution.

    Args:
        in_channels (list[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        num_csp_blocks (int, optional): Number of bottlenecks in CSPLayer.
            Defaults to 3.
        start_index (int, optional): Index of the first input feature map.
            Defaults to 2.
    """

    def __init__(
        self,
        in_channels: list[int],
        out_channels: int,
        num_csp_blocks: int = 3,
        start_index: int = 2,
    ) -> None:
        """Build the top-down, bottom-up and output projection layers."""
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.start_index = start_index

        num_levels = len(in_channels)

        # build top-down blocks (stored from the last level to the first)
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        self.reduce_layers = nn.ModuleList()
        self.top_down_blocks = nn.ModuleList()
        for idx in range(num_levels - 1, 0, -1):
            # 1x1 conv bringing level ``idx`` down to the channel count of
            # level ``idx - 1`` so both can be concatenated after upsampling.
            self.reduce_layers.append(
                self._conv_bn_silu(in_channels[idx], in_channels[idx - 1], 1)
            )
            self.top_down_blocks.append(
                CSPLayer(
                    in_channels[idx - 1] * 2,
                    in_channels[idx - 1],
                    num_blocks=num_csp_blocks,
                    add_identity=False,
                )
            )

        # build bottom-up blocks (stored from the first level to the last)
        self.downsamples = nn.ModuleList()
        self.bottom_up_blocks = nn.ModuleList()
        for idx in range(num_levels - 1):
            self.downsamples.append(
                self._conv_bn_silu(
                    in_channels[idx],
                    in_channels[idx],
                    3,
                    stride=2,
                    padding=1,
                )
            )
            self.bottom_up_blocks.append(
                CSPLayer(
                    in_channels[idx] * 2,
                    in_channels[idx + 1],
                    num_blocks=num_csp_blocks,
                    add_identity=False,
                )
            )

        # one 1x1 projection per level to the shared output width
        self.out_convs = nn.ModuleList()
        for inc in in_channels:
            self.out_convs.append(self._conv_bn_silu(inc, out_channels, 1))

        self._init_weights()

    @staticmethod
    def _conv_bn_silu(
        in_ch: int, out_ch: int, kernel_size: int, **conv_kwargs: int
    ) -> Conv2d:
        """Create a bias-free Conv2d followed by BatchNorm2d and SiLU.

        Extra keyword arguments (e.g. ``stride``, ``padding``) are forwarded
        to ``Conv2d`` untouched.
        """
        return Conv2d(
            in_ch,
            out_ch,
            kernel_size,
            bias=False,
            norm=nn.BatchNorm2d(out_ch, eps=0.001, momentum=0.03),
            activation=nn.SiLU(inplace=True),
            **conv_kwargs,
        )

    def _init_weights(self) -> None:
        """Initialize the weights of every ``nn.Conv2d`` submodule.

        Uses Kaiming-uniform initialization with ``a=sqrt(5)``, ``fan_in``
        mode and the ``leaky_relu`` gain.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(
                    m.weight,
                    a=math.sqrt(5),
                    mode="fan_in",
                    nonlinearity="leaky_relu",
                )

    def forward(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        """Forward pass.

        Args:
            features (list[Tensor]): Input features. The first
                ``start_index`` entries are passed through unchanged; the
                remaining entries must match ``in_channels`` one to one.
                A tuple is accepted as well.

        Returns:
            list[Tensor]: The untouched leading entries followed by the
                YOLOXPAFPN features, one per entry of ``in_channels``.
        """
        # Copy the pass-through part into a list so the final concatenation
        # also works when the caller hands in a tuple.
        images = list(features[: self.start_index])
        pyramid = features[self.start_index :]

        num_levels = len(self.in_channels)
        # Kept as an assertion so the exception type seen by callers does not
        # change; the message makes a mismatch easier to diagnose.
        assert len(pyramid) == num_levels, (
            f"Expected {num_levels} feature maps after start_index="
            f"{self.start_index}, got {len(pyramid)}."
        )

        # top-down path
        inner_outs = [pyramid[-1]]
        for block_idx, (reduce, fuse) in enumerate(
            zip(self.reduce_layers, self.top_down_blocks)
        ):
            # Blocks are stored from the last level towards the first, so
            # block ``block_idx`` fuses into level ``num_levels - 2 - idx``.
            feat_high = reduce(inner_outs[0])
            # The reduced map replaces the current head: the bottom-up path
            # below consumes the reduced version, not the raw one.
            inner_outs[0] = feat_high
            feat_low = pyramid[num_levels - 2 - block_idx]
            inner_out = fuse(
                torch.cat([self.upsample(feat_high), feat_low], 1)
            )
            inner_outs.insert(0, inner_out)

        # bottom-up path
        outs = [inner_outs[0]]
        for idx, (downsample, fuse) in enumerate(
            zip(self.downsamples, self.bottom_up_blocks)
        ):
            downsample_feat = downsample(outs[-1])
            feat_high = inner_outs[idx + 1]
            outs.append(fuse(torch.cat([downsample_feat, feat_high], 1)))

        # out convs
        outs = [conv(out) for conv, out in zip(self.out_convs, outs)]
        return images + outs

    def __call__(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        """Type definition for call implementation."""
        return self._call_impl(features)