"""RetinaNet."""

from __future__ import annotations

from math import prod
from typing import NamedTuple

import torch
from torch import nn
from torchvision.ops import batched_nms, sigmoid_focal_loss

from vis4d.common.typing import TorchLossFunc
from vis4d.op.box.anchor import AnchorGenerator
from vis4d.op.box.box2d import bbox_clip, filter_boxes_by_area
from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder, DeltaXYWHBBoxEncoder
from vis4d.op.box.matchers import Matcher, MaxIoUMatcher
from vis4d.op.box.samplers import PseudoSampler, Sampler
from vis4d.op.loss.common import l1_loss

from .common import DetOut
from .dense_anchor import DenseAnchorHeadLoss


class RetinaNetOut(NamedTuple):
    """RetinaNet head outputs."""

    # Logits for box classification for each feature level. The logit
    # dimension is [batch_size, number of anchors * number of classes,
    # height, width].
    cls_score: list[torch.Tensor]
    # Box regression parameters for each feature level. Regression is shared
    # across classes, so the tensor dimension is
    # [batch_size, number of anchors * 4, height, width].
    bbox_pred: list[torch.Tensor]


def get_default_anchor_generator() -> AnchorGenerator:
    """Get default anchor generator."""
    return AnchorGenerator(
        octave_base_scale=4,
        scales_per_octave=3,
        ratios=[0.5, 1.0, 2.0],
        strides=[8, 16, 32, 64, 128],
    )


def get_default_box_codec() -> (
    tuple[DeltaXYWHBBoxEncoder, DeltaXYWHBBoxDecoder]
):
    """Get the default bounding box encoder and decoder."""
    return (
        DeltaXYWHBBoxEncoder(
            target_means=(0.0, 0.0, 0.0, 0.0),
            target_stds=(1.0, 1.0, 1.0, 1.0),
        ),
        DeltaXYWHBBoxDecoder(
            target_means=(0.0, 0.0, 0.0, 0.0),
            target_stds=(1.0, 1.0, 1.0, 1.0),
        ),
    )


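# Illustrative sketch, not part of the original module: the decoder returned
# by ``get_default_box_codec`` is called as ``decoder(anchors, deltas)``, as
# in ``decode_multi_level_outputs`` below. The anchor values are assumptions;
# with zero deltas and the default means/stds, the decoded boxes should match
# the anchors under the standard delta parameterization.
def _example_box_codec() -> torch.Tensor:  # pragma: no cover
    """Decode zero regression deltas for two anchors (sketch)."""
    _, decoder = get_default_box_codec()
    anchors = torch.tensor([[0.0, 0.0, 8.0, 8.0], [8.0, 8.0, 24.0, 24.0]])
    deltas = torch.zeros(2, 4)
    return decoder(anchors, deltas)  # [2, 4] decoded boxes

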
def get_default_box_matcher() -> MaxIoUMatcher:
    """Get default bounding box matcher."""
    return MaxIoUMatcher(
        thresholds=[0.4, 0.5],
        labels=[0, -1, 1],
        allow_low_quality_matches=True,
    )


def get_default_box_sampler() -> PseudoSampler:
    """Get default bounding box sampler."""
    return PseudoSampler()


class RetinaNetHead(nn.Module):  # TODO: Refactor to use the new API
    """RetinaNet Head."""

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        feat_channels: int = 256,
        stacked_convs: int = 4,
        use_sigmoid_cls: bool = True,
        anchor_generator: AnchorGenerator | None = None,
        box_decoder: DeltaXYWHBBoxDecoder | None = None,
        box_matcher: Matcher | None = None,
        box_sampler: Sampler | None = None,
    ) -> None:
        """Creates an instance of the class."""
        super().__init__()
        self.anchor_generator = (
            anchor_generator
            if anchor_generator is not None
            else get_default_anchor_generator()
        )
        if box_decoder is None:
            _, self.box_decoder = get_default_box_codec()
        else:
            self.box_decoder = box_decoder
        self.box_matcher = (
            box_matcher
            if box_matcher is not None
            else get_default_box_matcher()
        )
        self.box_sampler = (
            box_sampler
            if box_sampler is not None
            else get_default_box_sampler()
        )

        num_base_priors = self.anchor_generator.num_base_priors[0]
        if use_sigmoid_cls:
            cls_out_channels = num_classes
        else:
            cls_out_channels = num_classes + 1

        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(stacked_convs):
            chn = in_channels if i == 0 else feat_channels
            self.cls_convs.append(
                nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1)
            )
            self.reg_convs.append(
                nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1)
            )
        self.retina_cls = nn.Conv2d(
            feat_channels, num_base_priors * cls_out_channels, 3, padding=1
        )
        self.retina_reg = nn.Conv2d(
            feat_channels, num_base_priors * 4, 3, padding=1
        )

    def forward(self, features: list[torch.Tensor]) -> RetinaNetOut:
        """Forward pass of RetinaNet.

        Args:
            features (list[torch.Tensor]): Feature pyramid.

        Returns:
            RetinaNetOut: Classification scores and box predictions.
        """
        cls_scores, bbox_preds = [], []
        for feat in features:
            cls_feat = feat
            reg_feat = feat
            for cls_conv in self.cls_convs:
                cls_feat = self.relu(cls_conv(cls_feat))
            for reg_conv in self.reg_convs:
                reg_feat = self.relu(reg_conv(reg_feat))
            cls_scores.append(self.retina_cls(cls_feat))
            bbox_preds.append(self.retina_reg(reg_feat))
        return RetinaNetOut(cls_score=cls_scores, bbox_pred=bbox_preds)

    def __call__(self, features: list[torch.Tensor]) -> RetinaNetOut:
        """Type definition for call implementation."""
        return self._call_impl(features)


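# Illustrative sketch, not part of the original module: builds the head with
# its default components and runs it on a dummy five-level feature pyramid.
# The batch size, channel count, class count, and 512x512 input resolution
# are assumptions chosen for the example.
def _example_retinanet_head() -> RetinaNetOut:  # pragma: no cover
    """Run the head on random features (illustrative sketch)."""
    head = RetinaNetHead(num_classes=80, in_channels=256)
    # One feature map per stride of the default anchor generator
    # (8, 16, 32, 64, 128), e.g. for a 512x512 input image.
    feats = [
        torch.rand(2, 256, 512 // s, 512 // s) for s in (8, 16, 32, 64, 128)
    ]
    outs = head(feats)
    # With the default generator (3 scales x 3 ratios = 9 anchors per cell),
    # outs.cls_score[0] is [2, 9 * 80, 64, 64] and
    # outs.bbox_pred[0] is [2, 9 * 4, 64, 64].
    return outs

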
def get_params_per_level(
    cls_out: torch.Tensor,
    reg_out: torch.Tensor,
    anchors: torch.Tensor,
    num_pre_nms: int = 2000,
    score_thr: float = 0.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get topk params from feature output per level per image before NMS.

    Params include flattened classification scores, box energies, and
    corresponding anchors.

    Args:
        cls_out (torch.Tensor): [C, H, W] classification scores at a
            particular scale.
        reg_out (torch.Tensor): [C, H, W] regression parameters at a
            particular scale.
        anchors (torch.Tensor): [H * W * A, 4] anchor boxes at a particular
            scale, where A is the number of anchors per cell.
        num_pre_nms (int): Number of predictions kept before NMS.
        score_thr (float): Score threshold for filtering predictions.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: topk
            flattened classification scores, class labels, regression
            outputs, and corresponding anchors.
    """
    assert cls_out.size()[-2:] == reg_out.size()[-2:], (
        f"Shape mismatch: cls_out({cls_out.size()[-2:]}), reg_out("
        f"{reg_out.size()[-2:]})."
    )
    reg_out = reg_out.permute(1, 2, 0).reshape(-1, 4)
    cls_out = cls_out.permute(1, 2, 0).reshape(reg_out.size(0), -1).sigmoid()
    valid_mask = torch.greater(cls_out, score_thr)
    valid_idxs = torch.nonzero(valid_mask)
    num_topk = min(num_pre_nms, valid_idxs.size(0))

    cls_out_filt = cls_out[valid_mask]
    cls_out_ranked, rank_inds = cls_out_filt.sort(descending=True)
    topk_inds = valid_idxs[rank_inds[:num_topk]]
    keep_inds, labels = topk_inds.unbind(dim=1)

    cls_out = cls_out_ranked[:num_topk]
    reg_out = reg_out[keep_inds, :]
    anchors = anchors[keep_inds, :]
    return cls_out, labels, reg_out, anchors


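# Illustrative sketch, not part of the original module: exercising
# ``get_params_per_level`` on dummy single-image, single-level outputs. The
# anchor count per cell (9) and class count (80) are assumptions.
def _example_get_params_per_level() -> None:  # pragma: no cover
    """Select top-k predictions from random level outputs (sketch)."""
    num_classes, num_anchors, h, w = 80, 9, 32, 32
    cls_out = torch.randn(num_anchors * num_classes, h, w)
    reg_out = torch.randn(num_anchors * 4, h, w)
    anchors = torch.rand(h * w * num_anchors, 4) * 256
    scores, labels, deltas, kept_anchors = get_params_per_level(
        cls_out, reg_out, anchors, num_pre_nms=100
    )
    # All four outputs share the first dimension, at most num_pre_nms.
    assert scores.shape[0] <= 100
    assert deltas.shape == kept_anchors.shape

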
def decode_multi_level_outputs(
    cls_out_all: list[torch.Tensor],
    lbl_out_all: list[torch.Tensor],
    reg_out_all: list[torch.Tensor],
    anchors_all: list[torch.Tensor],
    image_hw: tuple[int, int],
    box_decoder: DeltaXYWHBBoxDecoder,
    max_per_img: int = 1000,
    nms_threshold: float = 0.7,
    min_box_size: tuple[int, int] = (0, 0),
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Decode box energies into detections for a single image.

    Detections are post-processed via NMS, which is applied class-wise
    across all levels. Afterwards, the topk detections are selected.

    Args:
        cls_out_all (list[torch.Tensor]): topk class scores per level.
        lbl_out_all (list[torch.Tensor]): topk class labels per level.
        reg_out_all (list[torch.Tensor]): topk regression params per level.
        anchors_all (list[torch.Tensor]): topk anchor boxes per level.
        image_hw (tuple[int, int]): Image size.
        box_decoder (DeltaXYWHBBoxDecoder): Bounding box decoder.
        max_per_img (int, optional): Maximum predictions per image.
            Defaults to 1000.
        nms_threshold (float, optional): IoU threshold for NMS. Defaults
            to 0.7.
        min_box_size (tuple[int, int], optional): Minimum box size.
            Defaults to (0, 0).

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Decoded detection
            boxes, scores, and class labels.
    """
    scores, labels = torch.cat(cls_out_all), torch.cat(lbl_out_all)
    boxes = bbox_clip(
        box_decoder(torch.cat(anchors_all), torch.cat(reg_out_all)),
        image_hw,
    )
    boxes, mask = filter_boxes_by_area(boxes, min_area=prod(min_box_size))
    scores, labels = scores[mask], labels[mask]

    if boxes.numel() > 0:
        keep = batched_nms(
            boxes, scores, labels, iou_threshold=nms_threshold
        )[:max_per_img]
        return boxes[keep], scores[keep], labels[keep]
    return boxes.new_zeros(0, 4), scores.new_zeros(0), labels.new_zeros(0)


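# Illustrative sketch, not part of the original module: decoding dummy
# per-level top-k outputs for a single image with the default decoder. The
# anchor construction only ensures valid corner ordering; all values are
# assumptions for the example.
def _example_decode_multi_level_outputs() -> None:  # pragma: no cover
    """Decode random two-level outputs into detections (sketch)."""
    _, decoder = get_default_box_codec()
    cls_out_all, lbl_out_all, reg_out_all, anchors_all = [], [], [], []
    for num in (10, 5):  # two levels with 10 and 5 kept predictions
        cls_out_all.append(torch.rand(num))
        lbl_out_all.append(torch.randint(0, 80, (num,)))
        reg_out_all.append(torch.randn(num, 4) * 0.1)
        top_left = torch.rand(num, 2) * 200
        anchors_all.append(torch.cat([top_left, top_left + 32], dim=1))
    boxes, scores, labels = decode_multi_level_outputs(
        cls_out_all,
        lbl_out_all,
        reg_out_all,
        anchors_all,
        image_hw=(256, 256),
        box_decoder=decoder,
        max_per_img=10,
    )
    assert boxes.shape[0] <= 10  # at most max_per_img detections survive NMS

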
class Dense2Det(nn.Module):
    """Compute detections from dense network outputs.

    This class acts as a stateless functor that does the following:

    1. Create the anchor grid for the feature grids (classification and
       regression outputs) at all scales.

    Then, for each image and each level:

    2. Get a topk pre-selection of flattened classification scores and box
       energies from the feature output before NMS.
    3. Decode class scores and box energies into detection boxes and
       apply NMS.

    Returns detection boxes for all images.
    """

    def __init__(
        self,
        anchor_generator: AnchorGenerator,
        box_decoder: DeltaXYWHBBoxDecoder,
        num_pre_nms: int = 2000,
        max_per_img: int = 1000,
        nms_threshold: float = 0.7,
        min_box_size: tuple[int, int] = (0, 0),
        score_thr: float = 0.0,
    ) -> None:
        """Creates an instance of the class."""
        super().__init__()
        self.anchor_generator = anchor_generator
        self.box_decoder = box_decoder
        self.num_pre_nms = num_pre_nms
        self.max_per_img = max_per_img
        self.nms_threshold = nms_threshold
        self.min_box_size = min_box_size
        self.score_thr = score_thr

    def forward(
        self,
        cls_outs: list[torch.Tensor],
        reg_outs: list[torch.Tensor],
        images_hw: list[tuple[int, int]],
    ) -> DetOut:
        """Compute detections from dense network outputs.

        Generates the anchor grid for all scales, then, for each batch
        element, computes classification, regression, and anchor pairs for
        all scales, decodes those pairs into proposals, and post-processes
        them with NMS.

        Args:
            cls_outs (list[torch.Tensor]): [N, C * A, H, W] per scale.
            reg_outs (list[torch.Tensor]): [N, 4 * A, H, W] per scale.
            images_hw (list[tuple[int, int]]): list of image sizes.

        Returns:
            DetOut: Detection outputs.
        """
        # Since the feature map sizes of all images in the batch are the
        # same, we compute the anchors only once.
        device = cls_outs[0].device
        featmap_sizes: list[tuple[int, int]] = [
            featmap.size()[-2:] for featmap in cls_outs  # type: ignore
        ]
        assert len(featmap_sizes) == self.anchor_generator.num_levels
        anchor_grids = self.anchor_generator.grid_priors(
            featmap_sizes, device=device
        )

        proposals, scores, labels = [], [], []
        for img_id, image_hw in enumerate(images_hw):
            cls_out_all, lbl_out_all, reg_out_all, anchors_all = (
                [],
                [],
                [],
                [],
            )
            for cls_out, reg_out, anchor_grid in zip(
                cls_outs, reg_outs, anchor_grids
            ):
                cls_out_, lbl_out, reg_out_, anchors = get_params_per_level(
                    cls_out[img_id],
                    reg_out[img_id],
                    anchor_grid,
                    self.num_pre_nms,
                    self.score_thr,
                )
                cls_out_all += [cls_out_]
                lbl_out_all += [lbl_out]
                reg_out_all += [reg_out_]
                anchors_all += [anchors]

            box, score, label = decode_multi_level_outputs(
                cls_out_all,
                lbl_out_all,
                reg_out_all,
                anchors_all,
                image_hw,
                self.box_decoder,
                self.max_per_img,
                self.nms_threshold,
                self.min_box_size,
            )
            proposals.append(box)
            scores.append(score)
            labels.append(label)
        return DetOut(proposals, scores, labels)

    def __call__(
        self,
        cls_outs: list[torch.Tensor],
        reg_outs: list[torch.Tensor],
        images_hw: list[tuple[int, int]],
    ) -> DetOut:
        """Type definition for function call."""
        return self._call_impl(cls_outs, reg_outs, images_hw)


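# Illustrative sketch, not part of the original module: chaining
# ``RetinaNetHead`` and ``Dense2Det`` on dummy inputs. The thresholds and
# the 512x512 image size are assumptions chosen for the example.
def _example_dense2det() -> DetOut:  # pragma: no cover
    """Turn random head outputs into per-image detections (sketch)."""
    head = RetinaNetHead(num_classes=80, in_channels=256)
    _, box_decoder = get_default_box_codec()
    dense2det = Dense2Det(
        get_default_anchor_generator(),
        box_decoder,
        num_pre_nms=1000,
        max_per_img=100,
        nms_threshold=0.5,
        score_thr=0.05,
    )
    images_hw = [(512, 512), (512, 512)]
    feats = [
        torch.rand(2, 256, 512 // s, 512 // s) for s in (8, 16, 32, 64, 128)
    ]
    outs = head(feats)
    return dense2det(outs.cls_score, outs.bbox_pred, images_hw)

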
class RetinaNetHeadLoss(DenseAnchorHeadLoss):
    """Loss of RetinaNet head."""

    def __init__(
        self,
        anchor_generator: AnchorGenerator,
        box_encoder: DeltaXYWHBBoxEncoder,
        box_matcher: None | Matcher = None,
        box_sampler: None | Sampler = None,
        loss_cls: TorchLossFunc = sigmoid_focal_loss,
        loss_bbox: TorchLossFunc = l1_loss,
    ) -> None:
        """Creates an instance of the class.

        Args:
            anchor_generator (AnchorGenerator): Generates anchor grid
                priors.
            box_encoder (DeltaXYWHBBoxEncoder): Encodes bounding boxes to
                the desired network output.
            box_matcher (None | Matcher, optional): Box matcher. Defaults
                to None.
            box_sampler (None | Sampler, optional): Box sampler. Defaults
                to None.
            loss_cls (TorchLossFunc, optional): Classification loss
                function. Defaults to sigmoid_focal_loss.
            loss_bbox (TorchLossFunc, optional): Regression loss function.
                Defaults to l1_loss.
        """
        matcher = (
            box_matcher
            if box_matcher is not None
            else get_default_box_matcher()
        )
        sampler = (
            box_sampler
            if box_sampler is not None
            else get_default_box_sampler()
        )
        super().__init__(
            anchor_generator,
            box_encoder,
            matcher,
            sampler,
            loss_cls,
            loss_bbox,
        )


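# Illustrative sketch, not part of the original module: constructing the loss
# with the default codec, matcher, and sampler. The forward signature is
# inherited from ``DenseAnchorHeadLoss`` and is not exercised here.
def _example_retinanet_loss_setup() -> RetinaNetHeadLoss:  # pragma: no cover
    """Build the loss from the default components (sketch)."""
    box_encoder, _ = get_default_box_codec()
    return RetinaNetHeadLoss(
        anchor_generator=get_default_anchor_generator(),
        box_encoder=box_encoder,
    )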