"""RetinaNet model implementation and runtime."""
from __future__ import annotations
from torch import Tensor, nn
from vis4d.common import LossesType
from vis4d.common.ckpt import load_model_checkpoint
from vis4d.op.base.resnet import ResNet
from vis4d.op.box.anchor import AnchorGenerator
from vis4d.op.box.box2d import scale_and_clip_boxes
from vis4d.op.box.encoder import DeltaXYWHBBoxEncoder
from vis4d.op.box.matchers import Matcher
from vis4d.op.box.samplers import Sampler
from vis4d.op.detect.common import DetOut
from vis4d.op.detect.retinanet import (
Dense2Det,
RetinaNetHead,
RetinaNetHeadLoss,
RetinaNetOut,
)
from vis4d.op.fpp.fpn import FPN, ExtraFPNBlock
# (regex pattern, replacement) pairs used to rename the parameter keys of an
# MMDetection RetinaNet checkpoint to the attribute names of the RetinaNet
# module in this file (basemodel / fpn / retinanet_head). They are handed to
# load_model_checkpoint as ``rev_keys`` when ``weights == "mmdet"``.
#
# Order matters: the two "fpn.layer_blocks.{3,4}" rules operate on keys that
# were produced by the preceding "neck.fpn_convs" rule.
#
# Every literal dot is escaped so that "." cannot match an arbitrary
# character (e.g. "neckXlateral_convs." or ".convXweight").
REV_KEYS: list[tuple[str, str]] = [
    (r"^backbone\.", "basemodel."),
    (r"^bbox_head\.", "retinanet_head."),
    (r"^neck\.lateral_convs\.", "fpn.inner_blocks."),
    (r"^neck\.fpn_convs\.", "fpn.layer_blocks."),
    (r"^fpn\.layer_blocks\.3\.", "fpn.extra_blocks.convs.0."),
    (r"^fpn\.layer_blocks\.4\.", "fpn.extra_blocks.convs.1."),
    (r"\.conv\.weight", ".weight"),
    (r"\.conv\.bias", ".bias"),
]
class RetinaNet(nn.Module):
    """RetinaNet wrapper class for checkpointing etc.

    Bundles a ResNet-50 base model, an FPN, the RetinaNet head and the
    transform that turns dense head outputs into detections, and optionally
    loads a checkpoint at construction time.
    """

    def __init__(self, num_classes: int, weights: None | str = None) -> None:
        """Creates an instance of the class.

        Args:
            num_classes (int): Number of classes.
            weights (None | str, optional): Weights to load for model. If
                set to "mmdet", will load MMDetection pre-trained weights
                (renaming their parameter keys with REV_KEYS). Any other
                string is passed to load_model_checkpoint unchanged.
                Defaults to None.
        """
        super().__init__()
        self.basemodel = ResNet(
            "resnet50", pretrained=True, trainable_layers=3
        )
        # The FPN consumes the base model outputs from index 3 onwards and
        # is extended by an extra block configured with "on_input" convs.
        self.fpn = FPN(
            self.basemodel.out_channels[3:],
            256,
            ExtraFPNBlock(2, 2048, 256, add_extra_convs="on_input"),
            start_index=3,
        )
        self.retinanet_head = RetinaNetHead(
            num_classes=num_classes, in_channels=256
        )
        # Post-processing shares the head's anchor generator and box decoder
        # so that decoding matches what the head predicts.
        self.transform_outs = Dense2Det(
            self.retinanet_head.anchor_generator,
            self.retinanet_head.box_decoder,
            num_pre_nms=1000,
            max_per_img=100,
            nms_threshold=0.5,
            score_thr=0.05,
        )

        if weights == "mmdet":
            weights = (
                "mmdet://retinanet/retinanet_r50_fpn_2x_coco/"
                "retinanet_r50_fpn_2x_coco_20200131-fdb43119.pth"
            )
            # MMDetection uses different parameter names; remap them.
            load_model_checkpoint(self, weights, rev_keys=REV_KEYS)
        elif weights is not None:
            load_model_checkpoint(self, weights)

    def forward(
        self,
        images: Tensor,
        input_hw: None | list[tuple[int, int]] = None,
        original_hw: None | list[tuple[int, int]] = None,
    ) -> RetinaNetOut | DetOut:
        """Forward pass.

        Dispatches to forward_train when the module is in training mode and
        to forward_test otherwise.

        Args:
            images (Tensor): Input images.
            input_hw (None | list[tuple[int, int]], optional): Input image
                resolutions. Required for testing. Defaults to None.
            original_hw (None | list[tuple[int, int]], optional): Original
                image resolutions (before padding and resizing). Required for
                testing. Defaults to None.

        Returns:
            RetinaNetOut | DetOut: Either raw model outputs (for training) or
                predicted outputs (for testing).

        Raises:
            AssertionError: If the module is not in training mode and
                input_hw or original_hw is None.
        """
        if self.training:
            return self.forward_train(images)
        assert input_hw is not None and original_hw is not None, (
            "input_hw and original_hw are required in evaluation mode."
        )
        return self.forward_test(images, input_hw, original_hw)

    def forward_train(self, images: Tensor) -> RetinaNetOut:
        """Forward training stage.

        Args:
            images (Tensor): Input images.

        Returns:
            RetinaNetOut: Raw model outputs.
        """
        features = self.fpn(self.basemodel(images))
        # Only the last five feature maps are fed to the head.
        return self.retinanet_head(features[-5:])

    def forward_test(
        self,
        images: Tensor,
        images_hw: list[tuple[int, int]],
        original_hw: list[tuple[int, int]],
    ) -> DetOut:
        """Forward testing stage.

        Args:
            images (Tensor): Input images.
            images_hw (list[tuple[int, int]]): Input image resolutions.
            original_hw (list[tuple[int, int]]): Original image resolutions
                (before padding and resizing).

        Returns:
            DetOut: Predicted outputs.
        """
        features = self.fpn(self.basemodel(images))
        # Only the last five feature maps are fed to the head.
        outs = self.retinanet_head(features[-5:])
        boxes, scores, class_ids = self.transform_outs(
            cls_outs=outs.cls_score,
            reg_outs=outs.bbox_pred,
            images_hw=images_hw,
        )
        # Rescale (and clip) each image's boxes from the network input
        # resolution to that image's original resolution, in place.
        for i, boxs in enumerate(boxes):
            boxes[i] = scale_and_clip_boxes(boxs, original_hw[i], images_hw[i])
        return DetOut(boxes, scores, class_ids)
class RetinaNetLoss(nn.Module):
    """RetinaNet Loss.

    Thin module wrapper around RetinaNetHeadLoss that accepts the raw
    RetinaNetOut of the model and returns the losses as a dictionary.
    """

    def __init__(
        self,
        anchor_generator: AnchorGenerator,
        box_encoder: DeltaXYWHBBoxEncoder,
        box_matcher: Matcher,
        box_sampler: Sampler,
    ) -> None:
        """Creates an instance of the class.

        Args:
            anchor_generator (AnchorGenerator): Anchor generator.
            box_encoder (DeltaXYWHBBoxEncoder): Bounding box encoder.
            box_matcher (Matcher): Bounding box matcher.
            box_sampler (Sampler): Bounding box sampler.
        """
        super().__init__()
        self.retinanet_loss = RetinaNetHeadLoss(
            anchor_generator, box_encoder, box_matcher, box_sampler
        )

    def forward(
        self,
        outputs: RetinaNetOut,
        images_hw: list[tuple[int, int]],
        target_boxes: list[Tensor],
        target_classes: list[Tensor],
    ) -> LossesType:
        """Forward of loss function.

        Args:
            outputs (RetinaNetOut): Raw model outputs.
            images_hw (list[tuple[int, int]]): Input image resolutions.
            target_boxes (list[Tensor]): Bounding box labels.
            target_classes (list[Tensor]): Class labels.

        Returns:
            LossesType: Dictionary of model losses.
        """
        # NOTE(review): the arguments are passed positionally with
        # target_boxes before images_hw, which differs from this method's
        # own parameter order -- confirm against RetinaNetHeadLoss.
        losses = self.retinanet_loss(
            outputs.cls_score,
            outputs.bbox_pred,
            target_boxes,
            images_hw,
            target_classes,
        )
        # The wrapped loss returns a named tuple; expose it as a dict.
        return losses._asdict()