"""Residual networks base model.
Modified from mmdetection (https://github.com/open-mmlab/mmdetection).
"""
from __future__ import annotations
from collections.abc import Sequence
import torchvision.models.resnet as _resnet
from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils.checkpoint import checkpoint
from vis4d.common.ckpt import load_model_checkpoint
from vis4d.common.typing import ArgsType
from vis4d.op.layer.util import build_conv_layer, build_norm_layer
from vis4d.op.layer.weight_init import constant_init, kaiming_init
from .base import BaseModel
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    This is the block type used by the ``resnet18`` and ``resnet34``
    variants. The block outputs ``planes * expansion`` channels.
    """

    # Ratio between the number of output channels and ``planes``.
    expansion = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        dilation: int = 1,
        downsample: nn.Module | None = None,
        style: str = "pytorch",
        use_checkpoint: bool = False,
        with_dcn: bool = False,
        norm: str = "BatchNorm2d",
    ) -> None:
        """Creates an instance of the class.

        Args:
            inplanes (int): Number of input channels.
            planes (int): Number of channels produced by both convolutions.
            stride (int): Stride of the first 3x3 convolution. Default: 1.
            dilation (int): Dilation and padding of the first 3x3
                convolution. The second convolution is built with padding 1
                and is not passed a dilation. Default: 1.
            downsample (nn.Module | None): Module applied to the block input
                to produce the shortcut branch. If None, the input itself is
                used as the shortcut. Default: None.
            style (str): `pytorch` or `caffe`. Only validated; it has no
                effect for this block. Default: pytorch.
            use_checkpoint (bool): Run the residual branch through
                `torch.utils.checkpoint` when the input requires grad.
                Default: False.
            with_dcn (bool): Must be False, deformable convolutions are not
                supported by this block. Default: False.
            norm (str): Name of the normalization layer handed to
                `build_norm_layer`. Default: BatchNorm2d.
        """
        super().__init__()
        assert style in {"pytorch", "caffe"}  # No effect for BasicBlock
        assert not with_dcn, "DCN is not supported for BasicBlock."

        # NOTE: layers are created in forward order so that parameter
        # registration (and hence state-dict key order) stays stable.
        self.conv1 = build_conv_layer(
            inplanes,
            planes,
            3,
            stride=stride,
            dilation=dilation,
            padding=dilation,
            bias=False,
        )
        self.bn1 = build_norm_layer(norm, planes)
        self.conv2 = build_conv_layer(planes, planes, 3, padding=1, bias=False)
        self.bn2 = build_norm_layer(norm, planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

        self.stride = stride
        self.dilation = dilation
        self.use_checkpoint = use_checkpoint

    def forward(self, x: Tensor) -> Tensor:
        """Forward function.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: ReLU of the residual branch plus the shortcut branch.
        """

        def _inner_forward(x: Tensor) -> Tensor:
            """Compute residual branch + shortcut (before the last ReLU)."""
            identity = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity
            return out

        # Reentrant checkpointing needs an input that requires grad,
        # otherwise no gradient would flow; fall back to a plain call.
        if self.use_checkpoint and x.requires_grad:
            # `use_reentrant` is passed explicitly as recommended by PyTorch;
            # True keeps the behaviour of the historical default.
            out = checkpoint(_inner_forward, x, use_reentrant=True)
        else:
            out = _inner_forward(x)

        out = self.relu(out)
        return out
class Bottleneck(nn.Module):
    """Residual bottleneck block (1x1 -> 3x3 -> 1x1 convolutions).

    This is the block type used by the ``resnet50``, ``resnet101`` and
    ``resnet152`` variants. The block outputs ``planes * expansion``
    channels.
    """

    # Ratio between the number of output channels and ``planes``.
    expansion = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        dilation: int = 1,
        downsample: nn.Module | None = None,
        style: str = "pytorch",
        use_checkpoint: bool = False,
        with_dcn: bool = False,
        norm: str = "BatchNorm2d",
    ) -> None:
        """Bottleneck block for ResNet.

        If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
        it is "caffe", the stride-two layer is the first 1x1 conv layer.

        Args:
            inplanes (int): Number of input channels.
            planes (int): Number of channels of the two inner convolutions.
                The last 1x1 convolution outputs ``planes * expansion``.
            stride (int): Stride of the block, applied to the layer selected
                by `style`. Default: 1.
            dilation (int): Dilation and padding of the 3x3 convolution.
                Default: 1.
            downsample (nn.Module | None): Module applied to the block input
                to produce the shortcut branch. If None, the input itself is
                used as the shortcut. Default: None.
            style (str): `pytorch` or `caffe`, see above. Default: pytorch.
            use_checkpoint (bool): Run the residual branch through
                `torch.utils.checkpoint` when the input requires grad.
                Default: False.
            with_dcn (bool): Forwarded as `use_dcn` to `build_conv_layer`
                for the 3x3 convolution. Default: False.
            norm (str): Name of the normalization layer handed to
                `build_norm_layer`. Default: BatchNorm2d.
        """
        super().__init__()
        assert style in {"pytorch", "caffe"}

        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride
        self.dilation = dilation
        self.use_checkpoint = use_checkpoint

        # Decide which convolution carries the stride.
        if style == "pytorch":
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        # NOTE: layers are created in forward order so that parameter
        # registration (and hence state-dict key order) stays stable.
        out_planes = planes * self.expansion
        self.conv1 = build_conv_layer(
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False,
        )
        self.bn1 = build_norm_layer(norm, planes)
        self.conv2 = build_conv_layer(
            planes,
            planes,
            kernel_size=3,
            stride=self.conv2_stride,
            padding=dilation,
            dilation=dilation,
            bias=False,
            use_dcn=with_dcn,
        )
        self.bn2 = build_norm_layer(norm, planes)
        self.conv3 = build_conv_layer(
            planes,
            out_planes,
            kernel_size=1,
            bias=False,
        )
        self.bn3 = build_norm_layer(norm, out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x: Tensor) -> Tensor:
        """Forward function.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: ReLU of the residual branch plus the shortcut branch.
        """

        def _inner_forward(x: Tensor) -> Tensor:
            """Compute residual branch + shortcut (before the last ReLU)."""
            identity = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.bn3(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity
            return out

        # Reentrant checkpointing needs an input that requires grad,
        # otherwise no gradient would flow; fall back to a plain call.
        if self.use_checkpoint and x.requires_grad:
            # `use_reentrant` is passed explicitly as recommended by PyTorch;
            # True keeps the behaviour of the historical default.
            out = checkpoint(_inner_forward, x, use_reentrant=True)
        else:
            out = _inner_forward(x)

        out = self.relu(out)
        return out
class ResNet(BaseModel):
    """ResNet BaseModel."""

    # name -> (depth, residual block class, number of blocks per stage)
    arch_settings = {
        "resnet18": (18, BasicBlock, (2, 2, 2, 2)),
        "resnet34": (34, BasicBlock, (3, 4, 6, 3)),
        "resnet50": (50, Bottleneck, (3, 4, 6, 3)),
        "resnet101": (101, Bottleneck, (3, 4, 23, 3)),
        "resnet152": (152, Bottleneck, (3, 8, 36, 3)),
    }

    def __init__(
        self,
        resnet_name: str,
        in_channels: int = 3,
        stem_channels: int | None = None,
        base_channels: int = 64,
        num_stages: int = 4,
        strides: Sequence[int] = (1, 2, 2, 2),
        dilations: Sequence[int] = (1, 1, 1, 1),
        style: str = "pytorch",
        deep_stem: bool = False,
        avg_down: bool = False,
        trainable_layers: int = 5,
        norm: str = "BatchNorm2d",
        norm_frozen: bool = True,
        stages_with_dcn: Sequence[bool] = (False, False, False, False),
        replace_stride_with_dilation: Sequence[bool] = (False, False, False),
        use_checkpoint: bool = False,
        zero_init_residual: bool = True,
        pretrained: bool = False,
        weights: None | str = None,
    ) -> None:
        """Create ResNet.

        Args:
            resnet_name (str): Name of the ResNet variant.
            in_channels (int): Number of input image channels. Default: 3.
            stem_channels (int | None): Number of stem channels. If not
                specified, it will be the same as `base_channels`. Default:
                None.
            base_channels (int): Number of base channels of res layer. Default:
                64.
            num_stages (int): Resnet stages. Default: 4.
            strides (Sequence[int]): Strides of the first block of each stage.
                Default: (1, 2, 2, 2).
            dilations (Sequence[int]): Dilation of each stage. Default: (1, 1,
                1, 1)
            style (str): `pytorch` or `caffe`. If set to "pytorch", the
                stride-two layer is the 3x3 conv layer, otherwise the
                stride-two layer is the first 1x1 conv layer. Default: pytorch.
            deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
                Default: False.
            avg_down (bool): Use AvgPool instead of stride conv when
                downsampling in the bottleneck. Default: False.
            trainable_layers (int, optional): Number layers for training or
                fine-tuning. 5 means all the layers can be fine-tuned. With a
                value below 5 the stem is frozen, together with the first
                ``4 - trainable_layers`` stages. Defaults to 5.
            norm (str): Normalization layer str. Default: BatchNorm2d, which
                means using `nn.BatchNorm2d`.
            norm_frozen (bool): Whether to set norm layers to eval mode. It
                freezes running stats (mean and var). Note: Effect on
                Batch Norm and its variants only.
            stages_with_dcn (Sequence[bool]): Whether each stage uses
                deformable convolutions. Default: (False, False, False,
                False).
            replace_stride_with_dilation (Sequence[bool]): Whether to replace
                the stride of stages 2, 3 and 4 with dilation. Default:
                (False, False, False).
            use_checkpoint (bool): Use checkpoint or not. Using checkpoint will
                save some memory while slowing down the training speed.
                Default: False.
            zero_init_residual (bool): Whether to use zero init for last norm
                layer in resblocks to let them behave as identity.
                Default: True.
            pretrained (bool): Whether to load pretrained weights. Default:
                False.
            weights (str, optional): model pretrained path. Default: None

        Raises:
            KeyError: If `resnet_name` is not a key of `arch_settings`.
        """
        super().__init__()
        self._norm = norm
        self.zero_init_residual = zero_init_residual
        if resnet_name not in self.arch_settings:
            raise KeyError(f"invalid architecture {resnet_name} for ResNet")

        self.name = resnet_name
        self.deep_stem = deep_stem
        self.trainable_layers = trainable_layers
        self.use_checkpoint = use_checkpoint
        self.norm_frozen = norm_frozen
        # Kept so that `out_channels` can report the real configuration.
        self._in_channels = in_channels
        self._base_channels = base_channels

        depth, self.block, stage_blocks = self.arch_settings[resnet_name]
        assert isinstance(depth, int)
        self.depth = depth

        stem_channels = stem_channels or base_channels
        assert 4 >= num_stages >= 1
        assert len(strides) == len(dilations) == num_stages
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = stem_channels

        self._make_stem_layer(in_channels, stem_channels)

        # Attribute names ("layer1", ...) of the stages that were built.
        self.res_layers: list[str] = []
        for i, num_blocks in enumerate(self.stage_blocks):
            if i > 0 and replace_stride_with_dilation[i - 1]:
                # Trade the spatial stride of this stage for dilation.
                dilation = strides[i]
                stride = 1
            else:
                stride = strides[i]
                dilation = dilations[i]

            planes = base_channels * 2**i
            res_layer = self._make_res_layer(
                block=self.block,  # type: ignore
                inplanes=self.inplanes,
                planes=planes,
                num_blocks=num_blocks,
                stride=stride,
                dilation=dilation,
                style=style,
                avg_down=avg_down,
                use_checkpoint=use_checkpoint,
                with_dcn=stages_with_dcn[i],
            )
            self.inplanes = planes * self.block.expansion  # type: ignore

            layer_name = f"layer{i + 1}"
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        if pretrained:
            if weights is None:
                # default loading the imagenet-1k v1 pre-trained model weights
                weights = _resnet.__dict__[
                    f"ResNet{depth}_Weights"
                ].IMAGENET1K_V1.url
            load_model_checkpoint(self, weights)
        else:
            self._init_weights()

        self._freeze_stages()

    def _init_weights(self) -> None:
        """Initialize the weights of module."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                constant_init(m, 1)

        if self.zero_init_residual:
            # Zero the scale of the last norm layer of every residual block
            # so that the block initially behaves like an identity mapping.
            for m in self.modules():
                if isinstance(m, Bottleneck) and isinstance(
                    m.bn3.weight, nn.Parameter
                ):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock) and isinstance(
                    m.bn2.weight, nn.Parameter
                ):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_stem_layer(self, in_channels: int, stem_channels: int) -> None:
        """Make stem layer for ResNet.

        Args:
            in_channels (int): Number of input image channels.
            stem_channels (int): Number of channels produced by the stem.
        """
        if self.deep_stem:
            half_channels = stem_channels // 2
            self.stem = nn.Sequential(
                build_conv_layer(
                    in_channels,
                    half_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=False,
                ),
                build_norm_layer(self._norm, half_channels),
                nn.ReLU(inplace=True),
                build_conv_layer(
                    half_channels,
                    half_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                ),
                build_norm_layer(self._norm, half_channels),
                nn.ReLU(inplace=True),
                build_conv_layer(
                    half_channels,
                    stem_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                ),
                build_norm_layer(self._norm, stem_channels),
                nn.ReLU(inplace=True),
            )
        else:
            self.conv1 = build_conv_layer(
                in_channels,
                stem_channels,
                kernel_size=7,
                stride=2,
                padding=3,
                bias=False,
            )
            self.bn1 = build_norm_layer(self._norm, stem_channels)
            self.relu = nn.ReLU(inplace=True)
        # Shared by both stem variants.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _make_res_layer(
        self,
        block: type[BasicBlock] | type[Bottleneck],
        inplanes: int,
        planes: int,
        num_blocks: int,
        stride: int,
        dilation: int,
        style: str,
        avg_down: bool,
        use_checkpoint: bool,
        with_dcn: bool,
    ) -> nn.Sequential:
        """Pack all blocks in a stage into a ``ResLayer``.

        Args:
            block (type[BasicBlock] | type[Bottleneck]): Block class to
                instantiate.
            inplanes (int): Number of input channels of the stage.
            planes (int): Base number of channels of the stage; every block
                outputs ``planes * block.expansion`` channels.
            num_blocks (int): Number of blocks in the stage.
            stride (int): Stride of the first block.
            dilation (int): Dilation used by all blocks.
            style (str): `pytorch` or `caffe`, forwarded to the blocks.
            avg_down (bool): Downsample the shortcut with an average pooling
                followed by a stride-one conv instead of a strided conv.
            use_checkpoint (bool): Forwarded to the blocks.
            with_dcn (bool): Forwarded to the blocks.

        Returns:
            nn.Sequential: The blocks of the stage.
        """
        out_planes = planes * block.expansion

        # A projection shortcut is needed for the first block whenever the
        # resolution or the number of channels changes.
        downsample: nn.Module | None = None
        if stride != 1 or inplanes != out_planes:
            downsample_list: list[nn.Module] = []
            conv_stride = stride
            if avg_down:
                conv_stride = 1
                downsample_list.append(
                    nn.AvgPool2d(
                        kernel_size=stride,
                        stride=stride,
                        ceil_mode=True,
                        count_include_pad=False,
                    )
                )
            downsample_list.extend(
                [
                    build_conv_layer(
                        inplanes,
                        out_planes,
                        kernel_size=1,
                        stride=conv_stride,
                        bias=False,
                    ),
                    build_norm_layer(self._norm, out_planes),
                ]
            )
            downsample = nn.Sequential(*downsample_list)

        layers: list[BasicBlock | Bottleneck] = [
            block(
                inplanes=inplanes,
                planes=planes,
                stride=stride,
                dilation=dilation,
                downsample=downsample,
                style=style,
                use_checkpoint=use_checkpoint,
                with_dcn=with_dcn,
                norm=self._norm,
            )
        ]
        # Remaining blocks keep resolution and channel count.
        for _ in range(1, num_blocks):
            layers.append(
                block(
                    inplanes=out_planes,
                    planes=planes,
                    stride=1,
                    dilation=dilation,
                    style=style,
                    use_checkpoint=use_checkpoint,
                    with_dcn=with_dcn,
                    norm=self._norm,
                )
            )
        return nn.Sequential(*layers)

    def _freeze_stages(self) -> None:
        """Freeze stages param and norm stats.

        With ``trainable_layers < 5`` the stem is frozen, and the first
        ``4 - trainable_layers`` stages are frozen as well. Only stages that
        were actually built are touched, so the method is safe when the
        model has fewer than four stages.
        """
        if self.trainable_layers < 5:
            if self.deep_stem:
                self.stem.eval()
                for param in self.stem.parameters():
                    param.requires_grad = False
            else:
                self.bn1.eval()
                for m in (self.conv1, self.bn1):
                    for param in m.parameters():
                        param.requires_grad = False

        num_frozen = max(0, 4 - self.trainable_layers)
        # Slicing clamps to the number of existing stages.
        for layer_name in self.res_layers[:num_frozen]:
            stage = getattr(self, layer_name)
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False

    def train(self, mode: bool = True) -> ResNet:
        """Override the train mode for the model.

        Frozen stages are put back into eval mode after the mode switch and,
        if `norm_frozen` is set, so are all batch norm layers.

        Args:
            mode (bool): Whether to set training mode. Default: True.

        Returns:
            ResNet: self.
        """
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_frozen:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self

    @property
    def out_channels(self) -> list[int]:
        """Get the number of channels for each level of feature pyramid.

        The first two entries describe the input images, the remaining ones
        the output of each residual stage, matching the list returned by
        `forward`.

        Returns:
            list[int]: number of channels
        """
        expansion = self.block.expansion  # type: ignore
        stage_channels = [
            self._base_channels * 2**i * expansion
            for i in range(len(self.stage_blocks))
        ]
        return [self._in_channels, self._in_channels, *stage_channels]

    def forward(self, images: Tensor) -> list[Tensor]:
        """Forward function.

        Args:
            images (Tensor[N, C, H, W]): Image input to process. Expected to
                be of type float32 with values ranging 0..255.

        Returns:
            fp (list[torch.Tensor]): The output feature pyramid. With the
                default strides, the list index represents the level, which
                has a downsampling ratio of 2^index. fp[0] and fp[1] are
                references to the input images, since the stem downsamples
                the feature maps by 4 directly. The remaining entries are
                the outputs of the residual stages, one per stage.
        """
        if self.deep_stem:
            x = self.stem(images)
        else:
            x = self.conv1(images)
            x = self.bn1(x)
            x = self.relu(x)
        x = self.maxpool(x)

        outs = [images, images]
        for layer_name in self.res_layers:
            x = getattr(self, layer_name)(x)
            outs.append(x)
        return outs
class ResNetV1c(ResNet):
    """ResNetV1c variant with a deeper stem.

    Compared with default ResNet, ResNetV1c replaces the 7x7 conv in the input
    stem with three 3x3 convs. For more details please refer to `Bag of Tricks
    for Image Classification with Convolutional Neural Networks
    <https://arxiv.org/abs/1812.01187>`.
    """

    # Download locations of the available pre-trained checkpoints.
    model_urls = {
        "resnet50_v1c": (
            "https://download.openmmlab.com/pretrain/third_party/"
            "resnet50_v1c-2cccc1ad.pth"
        ),
        "resnet101_v1c": (
            "https://download.openmmlab.com/pretrain/third_party/"
            "resnet101_v1c-e67eebb6.pth"
        ),
    }

    def __init__(
        self,
        resnet_name: str,
        pretrained: bool = False,
        weights: str | None = None,
        **kwargs: ArgsType,
    ) -> None:
        """Initialize ResNetV1c.

        Args:
            resnet_name (str): Name of the resnet model. One of
                `resnet18_v1c`, `resnet34_v1c`, `resnet50_v1c` and
                `resnet101_v1c`.
            pretrained (bool, optional): Whether to load ImageNet pre-trained
                weights. Defaults to False.
            weights (str, optional): Path to custom pretrained weights. If
                None and `pretrained` is set, the entry of `model_urls` for
                `resnet_name` is used.
            **kwargs: Arguments for ResNet.
        """
        assert resnet_name in {
            "resnet18_v1c",
            "resnet34_v1c",
            "resnet50_v1c",
            "resnet101_v1c",
        }

        if pretrained and weights is None:
            assert resnet_name in {
                "resnet50_v1c",
                "resnet101_v1c",
            }, "Only resnet50_v1c and resnet101_v1c have pretrained weights."
            weights = self.model_urls[resnet_name]

        # The parent class knows the architectures by their plain name,
        # i.e. without the "_v1c" suffix.
        base_name = resnet_name[: -len("_v1c")]
        super().__init__(
            base_name,
            deep_stem=True,
            pretrained=pretrained,
            weights=weights,
            **kwargs,
        )