# Source code for vis4d.op.fpp.fpn
"""Feature Pyramid Network.
This is based on `"Feature Pyramid Network for Object Detection"
<https://arxiv.org/abs/1612.03144>`_.
"""
from __future__ import annotations

from collections import OrderedDict

import torch.nn.functional as F
from torch import Tensor, nn
from torchvision.ops import FeaturePyramidNetwork as _FPN
from torchvision.ops.feature_pyramid_network import (
    ExtraFPNBlock as _ExtraFPNBlock,
)
from torchvision.ops.feature_pyramid_network import LastLevelMaxPool

from .base import FeaturePyramidProcessing
class FPN(_FPN, FeaturePyramidProcessing):  # type: ignore
    """Feature Pyramid Network.

    This is a wrapper of the torchvision implementation.
    """

    def __init__(
        self,
        in_channels_list: list[int],
        out_channels: int,
        extra_blocks: _ExtraFPNBlock | None = LastLevelMaxPool(),
        start_index: int = 2,
    ) -> None:
        """Init without additional components.

        Args:
            in_channels_list (list[int]): List of input channels.
            out_channels (int): Output channels.
            extra_blocks (_ExtraFPNBlock, optional): Extra block. Defaults to
                LastLevelMaxPool().
            start_index (int, optional): Start index of base model feature
                maps. Defaults to 2.
        """
        super().__init__(
            in_channels_list, out_channels, extra_blocks=extra_blocks
        )
        self.start_index = start_index
    def forward(self, x: list[Tensor]) -> list[Tensor]:
        """Process the input features with FPN.

        By default, the FPN does not process the first ``start_index``
        feature maps in the pyramid, so they are passed through unchanged.

        Args:
            x (list[Tensor]): Feature pyramid as outputs of the base model.

        Returns:
            list[Tensor]: Feature pyramid after FPN processing.
        """
        feat_dict = OrderedDict(
            (k, v)
            for k, v in zip(
                [str(i) for i in range(len(x) - self.start_index)],
                x[self.start_index :],
            )
        )
        outs = super().forward(feat_dict)  # type: ignore
        return [*x[: self.start_index], *outs.values()]  # type: ignore
    def __call__(self, x: list[Tensor]) -> list[Tensor]:
        """Type definition for call implementation."""
        return self._call_impl(x)
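# Illustrative usage sketch (not part of the original module): construct the
# FPN wrapper and run it on dummy backbone features. The channel sizes and
# spatial shapes below are assumptions chosen only to make the example run;
# they are not prescribed by vis4d.
def _demo_fpn() -> None:
    """Run the FPN wrapper on a hypothetical feature pyramid."""
    import torch  # local import to keep the sketch self-contained

    # Two low-level maps (kept intact because start_index=2), followed by
    # three maps that are processed by the FPN.
    feats = [
        torch.rand(1, 3, 256, 256),
        torch.rand(1, 64, 128, 128),
        torch.rand(1, 256, 64, 64),
        torch.rand(1, 512, 32, 32),
        torch.rand(1, 1024, 16, 16),
    ]
    fpn = FPN(in_channels_list=[256, 512, 1024], out_channels=256)
    outs = fpn(feats)
    # The first two maps are returned unchanged; the processed maps all have
    # 256 channels, and the default LastLevelMaxPool appends one extra level.
    assert len(outs) == len(feats) + 1
    assert all(f.shape[1] == 256 for f in outs[2:])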
class ExtraFPNBlock(_ExtraFPNBlock):  # type: ignore
    """Extra block in the FPN.

    This is a wrapper of the torchvision implementation.
    """

    def __init__(
        self,
        extra_levels: int,
        in_channels: int,
        out_channels: int,
        add_extra_convs: str = "on_output",
        extra_relu: bool = False,
    ) -> None:
        """Create an instance of the class.

        Args:
            extra_levels (int): Number of extra pyramid levels to append.
            in_channels (int): Input channels of the first extra convolution
                when ``add_extra_convs`` is "on_input".
            out_channels (int): Output channels of the extra convolutions.
            add_extra_convs (str, optional): Source of the first extra level,
                either "on_input" or "on_output". Defaults to "on_output".
            extra_relu (bool, optional): Whether to apply ReLU before each
                extra convolution after the first one. Defaults to False.
        """
        super().__init__()
        self.extra_levels = extra_levels
        self.add_extra_convs = add_extra_convs
        self.extra_relu = extra_relu
        self.convs = nn.ModuleList()
        if extra_levels >= 1:
            for i in range(extra_levels):
                if i == 0 and self.add_extra_convs == "on_input":
                    _in_channels = in_channels
                else:
                    _in_channels = out_channels
                extra_fpn_conv = nn.Conv2d(
                    _in_channels,
                    out_channels,
                    3,
                    stride=2,
                    padding=1,
                )
                self.convs.append(extra_fpn_conv)
    def forward(
        self, results: list[Tensor], x: list[Tensor], names: list[str]
    ) -> tuple[list[Tensor], list[str]]:
        """Compute the extra pyramid levels and append them to the outputs.

        Args:
            results (list[Tensor]): FPN output feature maps.
            x (list[Tensor]): Input feature maps of the FPN.
            names (list[str]): Names of the FPN output feature maps.

        Returns:
            tuple[list[Tensor], list[str]]: Extended feature maps and names.
        """
        if self.add_extra_convs == "on_input":
            extra_source = x[-1]
        elif self.add_extra_convs == "on_output":
            extra_source = results[-1]
        else:
            raise NotImplementedError(
                f"Unknown add_extra_convs mode: {self.add_extra_convs}"
            )
        results.append(self.convs[0](extra_source))
        names.append(str(int(names[-1]) + 1))
        for i in range(1, self.extra_levels):
            if self.extra_relu:
                results.append(self.convs[i](F.relu(results[-1])))
            else:
                results.append(self.convs[i](results[-1]))
            names.append(str(int(names[-1]) + 1))
        return results, names
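# Illustrative usage sketch (not part of the original module): pair the FPN
# wrapper with ExtraFPNBlock to append two strided-conv levels on top of the
# FPN outputs. Channel sizes and spatial shapes are assumptions made for the
# example only.
def _demo_extra_fpn_block() -> None:
    """Run the FPN wrapper with two extra convolutional levels."""
    import torch  # local import to keep the sketch self-contained

    feats = [
        torch.rand(1, 3, 256, 256),
        torch.rand(1, 64, 128, 128),
        torch.rand(1, 256, 64, 64),
        torch.rand(1, 512, 32, 32),
        torch.rand(1, 1024, 16, 16),
    ]
    extra = ExtraFPNBlock(
        extra_levels=2,
        in_channels=1024,  # channels of the last input map, used if "on_input"
        out_channels=256,
        add_extra_convs="on_output",
    )
    fpn = FPN(
        in_channels_list=[256, 512, 1024],
        out_channels=256,
        extra_blocks=extra,
    )
    outs = fpn(feats)
    # Two pass-through maps + three FPN levels + two extra strided levels.
    assert len(outs) == 7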