Source code for vis4d.model.detect.retinanet

"""RetinaNet model implementation and runtime."""

from __future__ import annotations

from torch import Tensor, nn

from vis4d.common import LossesType
from vis4d.common.ckpt import load_model_checkpoint
from vis4d.op.base.resnet import ResNet
from vis4d.op.box.anchor import AnchorGenerator
from vis4d.op.box.box2d import scale_and_clip_boxes
from vis4d.op.box.encoder import DeltaXYWHBBoxEncoder
from vis4d.op.box.matchers import Matcher
from vis4d.op.box.samplers import Sampler
from vis4d.op.detect.common import DetOut
from vis4d.op.detect.retinanet import (
    Dense2Det,
    RetinaNetHead,
    RetinaNetHeadLoss,
    RetinaNetOut,
)
from vis4d.op.fpp.fpn import FPN, ExtraFPNBlock

REV_KEYS = [
    (r"^backbone\.", "basemodel."),
    (r"^bbox_head\.", "retinanet_head."),
    (r"^neck.lateral_convs\.", "fpn.inner_blocks."),
    (r"^neck.fpn_convs\.", "fpn.layer_blocks."),
    (r"^fpn.layer_blocks.3\.", "fpn.extra_blocks.convs.0."),
    (r"^fpn.layer_blocks.4\.", "fpn.extra_blocks.convs.1."),
    (r"\.conv.weight", ".weight"),
    (r"\.conv.bias", ".bias"),
]


[docs] class RetinaNet(nn.Module): """RetinaNet wrapper class for checkpointing etc.""" def __init__(self, num_classes: int, weights: None | str = None) -> None: """Creates an instance of the class. Args: num_classes (int): Number of classes. weights (None | str, optional): Weights to load for model. If set to "mmdet", will load MMDetection pre-trained weights. Defaults to None. """ super().__init__() self.basemodel = ResNet( "resnet50", pretrained=True, trainable_layers=3 ) self.fpn = FPN( self.basemodel.out_channels[3:], 256, ExtraFPNBlock(2, 2048, 256, add_extra_convs="on_input"), start_index=3, ) self.retinanet_head = RetinaNetHead( num_classes=num_classes, in_channels=256 ) self.transform_outs = Dense2Det( self.retinanet_head.anchor_generator, self.retinanet_head.box_decoder, num_pre_nms=1000, max_per_img=100, nms_threshold=0.5, score_thr=0.05, ) if weights == "mmdet": weights = ( "mmdet://retinanet/retinanet_r50_fpn_2x_coco/" "retinanet_r50_fpn_2x_coco_20200131-fdb43119.pth" ) load_model_checkpoint(self, weights, rev_keys=REV_KEYS) elif weights is not None: load_model_checkpoint(self, weights)
[docs] def forward( self, images: Tensor, input_hw: None | list[tuple[int, int]] = None, original_hw: None | list[tuple[int, int]] = None, ) -> RetinaNetOut | DetOut: """Forward pass. Args: images (Tensor): Input images. input_hw (None | list[tuple[int, int]], optional): Input image resolutions. Defaults to None. original_hw (None | list[tuple[int, int]], optional): Original image resolutions (before padding and resizing). Required for testing. Defaults to None. Returns: RetinaNetOut | DetOut: Either raw model outputs (for training) or predicted outputs (for testing). """ if self.training: return self.forward_train(images) assert input_hw is not None and original_hw is not None return self.forward_test(images, input_hw, original_hw)
[docs] def forward_train(self, images: Tensor) -> RetinaNetOut: """Forward training stage. Args: images (Tensor): Input images. Returns: RetinaNetOut: Raw model outputs. """ features = self.fpn(self.basemodel(images)) return self.retinanet_head(features[-5:])
[docs] def forward_test( self, images: Tensor, images_hw: list[tuple[int, int]], original_hw: list[tuple[int, int]], ) -> DetOut: """Forward testing stage. Args: images (Tensor): Input images. images_hw (list[tuple[int, int]]): Input image resolutions. original_hw (list[tuple[int, int]]): Original image resolutions (before padding and resizing). Returns: DetOut: Predicted outputs. """ features = self.fpn(self.basemodel(images)) outs = self.retinanet_head(features[-5:]) boxes, scores, class_ids = self.transform_outs( cls_outs=outs.cls_score, reg_outs=outs.bbox_pred, images_hw=images_hw, ) for i, boxs in enumerate(boxes): boxes[i] = scale_and_clip_boxes(boxs, original_hw[i], images_hw[i]) return DetOut(boxes, scores, class_ids)
[docs] class RetinaNetLoss(nn.Module): """RetinaNet Loss.""" def __init__( self, anchor_generator: AnchorGenerator, box_encoder: DeltaXYWHBBoxEncoder, box_matcher: Matcher, box_sampler: Sampler, ) -> None: """Creates an instance of the class. Args: anchor_generator (AnchorGenerator): Anchor generator for RPN. box_encoder (BoxEncoder2D): Bounding box encoder. box_matcher (BaseMatcher): Bounding box matcher. box_sampler (BaseSampler): Bounding box sampler. """ super().__init__() self.retinanet_loss = RetinaNetHeadLoss( anchor_generator, box_encoder, box_matcher, box_sampler )
[docs] def forward( self, outputs: RetinaNetOut, images_hw: list[tuple[int, int]], target_boxes: list[Tensor], target_classes: list[Tensor], ) -> LossesType: """Forward of loss function. Args: outputs (RetinaNetOut): Raw model outputs. images_hw (list[tuple[int, int]]): Input image resolutions. target_boxes (list[Tensor]): Bounding box labels. target_classes (list[Tensor]): Class labels. Returns: LossesType: Dictionary of model losses. """ losses = self.retinanet_loss( outputs.cls_score, outputs.bbox_pred, target_boxes, images_hw, target_classes, ) return losses._asdict()