"""Dense anchor-based head."""

from __future__ import annotations

from typing import NamedTuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from vis4d.common import TorchLossFunc
from vis4d.op.box.anchor import AnchorGenerator, anchor_inside_image
from vis4d.op.box.encoder import DeltaXYWHBBoxEncoder
from vis4d.op.box.matchers import Matcher
from vis4d.op.box.samplers import Sampler
from vis4d.op.loss.reducer import SumWeightedLoss
from vis4d.op.util import unmap


class DetectorTargets(NamedTuple):
    """Targets for first-stage detection."""

    labels: Tensor
    label_weights: Tensor
    bbox_targets: Tensor
    bbox_weights: Tensor


def images_to_levels(
    targets: list[
        tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]]
    ]
) -> list[list[Tensor]]:
    """Convert targets by image to targets by feature level."""
    targets_per_level = []
    for lvl_id in range(len(targets[0][0])):
        targets_single_level = []
        for tgt_id in range(len(targets[0])):
            targets_single_level.append(
                torch.stack([tgt[tgt_id][lvl_id] for tgt in targets], 0)
            )
        targets_per_level.append(targets_single_level)
    return targets_per_level
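

# Illustrative sketch, not part of the original module: images_to_levels
# transposes per-image targets into per-level, batch-first stacks. The toy
# shapes below stand in for a real call, which passes 4-tuples of
# (bbox_targets, bbox_weights, labels, label_weights) lists per image.
def _example_images_to_levels() -> None:
    """Toy demo of images_to_levels with two images and two levels."""
    field = [torch.zeros(6), torch.zeros(24)]  # one tensor per level
    img = (field, field, field, field)  # 4 target fields per image
    per_level = images_to_levels([img, img])
    # per_level[lvl][field] is stacked with the batch dimension first:
    assert per_level[0][0].shape == (2, 6)  # level 0: 6 anchors
    assert per_level[1][0].shape == (2, 24)  # level 1: 24 anchors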


def get_targets_per_image(
    target_boxes: Tensor,
    anchors: Tensor,
    matcher: Matcher,
    sampler: Sampler,
    box_encoder: DeltaXYWHBBoxEncoder,
    image_hw: tuple[int, int],
    target_class: Tensor | float = 1.0,
    allowed_border: int = 0,
) -> tuple[DetectorTargets, int, int]:
    """Get targets per batch element, all scales.

    Args:
        target_boxes (Tensor): (N, 4) Tensor of target boxes for a single
            image.
        anchors (Tensor): (M, 4) box priors.
        matcher (Matcher): Box matcher matching anchors to targets.
        sampler (Sampler): Box sampler sub-sampling matches.
        box_encoder (DeltaXYWHBBoxEncoder): Encodes boxes into target
            regression parameters.
        image_hw (tuple[int, int]): Input image height and width.
        target_class (Tensor | float, optional): Class label(s) of target
            boxes. Defaults to 1.0.
        allowed_border (int, optional): Allowed border for sub-sampling
            anchors that lie inside the input image. Defaults to 0.

    Returns:
        tuple[DetectorTargets, int, int]: Targets, number of positives,
            number of negatives.
    """
    inside_flags = anchor_inside_image(
        anchors, image_hw, allowed_border=allowed_border
    )
    # assign gt and sample anchors
    anchors = anchors[inside_flags, :]
    matching = matcher(anchors, target_boxes)
    sampling_result = sampler(matching)

    num_valid_anchors = anchors.size(0)
    bbox_targets = torch.zeros_like(anchors)
    bbox_weights = torch.zeros_like(anchors)
    labels = anchors.new_zeros((num_valid_anchors,))
    label_weights = anchors.new_zeros(num_valid_anchors)

    positives = torch.eq(sampling_result.sampled_labels, 1)
    negatives = torch.eq(sampling_result.sampled_labels, 0)
    pos_inds = sampling_result.sampled_box_indices[positives]
    pos_target_inds = sampling_result.sampled_target_indices[positives]
    neg_inds = sampling_result.sampled_box_indices[negatives]
    if len(pos_inds) > 0:
        pos_bbox_targets = box_encoder(
            anchors[pos_inds], target_boxes[pos_target_inds]
        )
        bbox_targets[pos_inds] = pos_bbox_targets
        bbox_weights[pos_inds] = 1.0
        if isinstance(target_class, float):
            labels[pos_inds] = target_class
        else:
            labels[pos_inds] = target_class[pos_target_inds].float()
        label_weights[pos_inds] = 1.0
    if len(neg_inds) > 0:
        label_weights[neg_inds] = 1.0

    # map up to original set of anchors
    num_total_anchors = inside_flags.size(0)
    labels = unmap(labels, num_total_anchors, inside_flags)
    label_weights = unmap(label_weights, num_total_anchors, inside_flags)
    bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
    bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
    return (
        DetectorTargets(labels, label_weights, bbox_targets, bbox_weights),
        int(positives.sum()),
        int(negatives.sum()),
    )
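

# For intuition only, not part of the original module: a minimal
# re-implementation of the classic Faster R-CNN DeltaXYWH encoding that
# box_encoder applies to positive anchors above. The real
# DeltaXYWHBBoxEncoder may additionally normalize the deltas by configured
# means/stds; this sketch skips that.
def _delta_xywh_sketch(anchors: Tensor, boxes: Tensor) -> Tensor:
    """Encode boxes as (dx, dy, dw, dh) deltas relative to anchors."""
    aw = anchors[:, 2] - anchors[:, 0]
    ah = anchors[:, 3] - anchors[:, 1]
    ax = anchors[:, 0] + 0.5 * aw
    ay = anchors[:, 1] + 0.5 * ah
    gw = boxes[:, 2] - boxes[:, 0]
    gh = boxes[:, 3] - boxes[:, 1]
    gx = boxes[:, 0] + 0.5 * gw
    gy = boxes[:, 1] + 0.5 * gh
    # center offsets are scale-invariant; sizes are encoded in log-space
    return torch.stack(
        [(gx - ax) / aw, (gy - ay) / ah, (gw / aw).log(), (gh / ah).log()],
        dim=-1,
    )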


def get_targets_per_batch(
    featmap_sizes: list[tuple[int, int]],
    target_boxes: list[Tensor],
    target_class_ids: list[Tensor | float],
    images_hw: list[tuple[int, int]],
    anchor_generator: AnchorGenerator,
    box_encoder: DeltaXYWHBBoxEncoder,
    box_matcher: Matcher,
    box_sampler: Sampler,
    allowed_border: int = 0,
) -> tuple[list[list[Tensor]], int]:
    """Get targets for all batch elements, all scales."""
    device = target_boxes[0].device
    anchor_grids = anchor_generator.grid_priors(featmap_sizes, device=device)
    num_level_anchors = [anchors.size(0) for anchors in anchor_grids]
    anchors_all_levels = torch.cat(anchor_grids)

    targets: list[
        tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]]
    ] = []
    num_total_pos, num_total_neg = 0, 0
    for tgt_box, tgt_cls, image_hw in zip(
        target_boxes, target_class_ids, images_hw
    ):
        target, num_pos, num_neg = get_targets_per_image(
            tgt_box,
            anchors_all_levels,
            box_matcher,
            box_sampler,
            box_encoder,
            image_hw,
            tgt_cls,
            allowed_border,
        )
        num_total_pos += num_pos
        num_total_neg += num_neg
        bbox_targets_per_level = target.bbox_targets.split(num_level_anchors)
        bbox_weights_per_level = target.bbox_weights.split(num_level_anchors)
        labels_per_level = target.labels.split(num_level_anchors)
        label_weights_per_level = target.label_weights.split(
            num_level_anchors
        )
        targets.append(
            (
                bbox_targets_per_level,
                bbox_weights_per_level,
                labels_per_level,
                label_weights_per_level,
            )
        )
    targets_per_level = images_to_levels(targets)
    num_samples = num_total_pos + num_total_neg
    return targets_per_level, num_samples
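

# Illustrative sketch, not part of the original module: the per-level split
# above is a plain Tensor.split with per-level chunk sizes. With a toy grid
# of 6 + 24 anchors over two levels, flat per-anchor targets split back
# into their per-level pieces.
def _example_level_split() -> None:
    """Toy demo of splitting flat per-anchor targets by level."""
    num_level_anchors = [6, 24]
    labels = torch.arange(30)
    lvl0, lvl1 = labels.split(num_level_anchors)
    assert lvl0.shape == (6,) and lvl1.shape == (24,)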


class DenseAnchorHeadLosses(NamedTuple):
    """Dense anchor head loss container."""

    loss_cls: Tensor
    loss_bbox: Tensor


class DenseAnchorHeadLoss(nn.Module):
    """Loss of dense anchor heads.

    For a given set of multi-scale dense outputs, compute the desired target
    outputs and apply classification and regression losses. The targets are
    computed with the given target bounding boxes, the anchor grid defined
    by the anchor generator, and the given box encoder.
    """

    def __init__(
        self,
        anchor_generator: AnchorGenerator,
        box_encoder: DeltaXYWHBBoxEncoder,
        box_matcher: Matcher,
        box_sampler: Sampler,
        loss_cls: TorchLossFunc,
        loss_bbox: TorchLossFunc,
        allowed_border: int = 0,
    ) -> None:
        """Creates an instance of the class.

        Args:
            anchor_generator (AnchorGenerator): Generates anchor grid
                priors.
            box_encoder (DeltaXYWHBBoxEncoder): Encodes bounding boxes to
                the desired network output.
            box_matcher (Matcher): Box matcher.
            box_sampler (Sampler): Box sampler.
            loss_cls (TorchLossFunc): Classification loss.
            loss_bbox (TorchLossFunc): Bounding box regression loss.
            allowed_border (int, optional): Allowed border for sub-sampling
                anchors that lie inside the input image. Defaults to 0.
        """
        super().__init__()
        self.anchor_generator = anchor_generator
        self.box_encoder = box_encoder
        self.allowed_border = allowed_border
        self.matcher = box_matcher
        self.sampler = box_sampler
        self.loss_cls = loss_cls
        self.loss_bbox = loss_bbox

    def _loss_single_scale(
        self,
        cls_out: Tensor,
        reg_out: Tensor,
        bbox_targets: Tensor,
        bbox_weights: Tensor,
        labels: Tensor,
        label_weights: Tensor,
        num_total_samples: int,
    ) -> tuple[Tensor, Tensor]:
        """Compute losses per scale, all batch elements.

        Args:
            cls_out (Tensor): [N, C, H, W] tensor of class logits.
            reg_out (Tensor): [N, C, H, W] tensor of regression params.
            bbox_targets (Tensor): [N, A, 4] bounding box targets, where A
                is the number of anchors at this scale.
            bbox_weights (Tensor): [N, A, 4] per-sample weighting for the
                regression loss.
            labels (Tensor): [N, A] classification targets.
            label_weights (Tensor): [N, A] per-sample weighting for the
                classification loss.
            num_total_samples (int): Average factor for loss normalization.

        Returns:
            tuple[Tensor, Tensor]: Classification and regression losses.
        """
        # classification loss
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_out.permute(0, 2, 3, 1).reshape(labels.size(0), -1)
        if cls_score.size(1) > 1:
            labels = F.one_hot(  # pylint: disable=not-callable
                labels.long(), num_classes=cls_score.size(1) + 1
            )[:, : cls_score.size(1)].float()
            label_weights = label_weights.repeat(cls_score.size(1)).reshape(
                -1, cls_score.size(1)
            )
        else:
            cls_score = cls_score.squeeze(1)
        loss_cls = self.loss_cls(cls_score, labels, reduction="none")
        loss_cls = SumWeightedLoss(label_weights, num_total_samples)(
            loss_cls
        )

        # regression loss
        bbox_targets = bbox_targets.reshape(-1, 4)
        bbox_weights = bbox_weights.reshape(-1, 4)
        bbox_pred = reg_out.permute(0, 2, 3, 1).reshape(-1, 4)
        loss_bbox = self.loss_bbox(
            pred=bbox_pred,
            target=bbox_targets,
            reducer=SumWeightedLoss(bbox_weights, num_total_samples),
        )
        return loss_cls, loss_bbox
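
    # Note on the reduction above (illustrative comment, not in the
    # original source): SumWeightedLoss(weights, avg_factor)(loss) is
    # assumed to compute (loss * weights).sum() / avg_factor, i.e. a
    # weighted sum of per-element losses normalized by the total number of
    # sampled anchors, so anchors with weight 0 (un-sampled or outside the
    # image) do not contribute to the gradient.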

    def forward(
        self,
        cls_outs: list[Tensor],
        reg_outs: list[Tensor],
        target_boxes: list[Tensor],
        images_hw: list[tuple[int, int]],
        target_class_ids: list[Tensor | float] | None = None,
    ) -> DenseAnchorHeadLosses:
        """Compute dense anchor head classification and regression losses.

        Args:
            cls_outs (list[Tensor]): Network classification outputs at all
                scales.
            reg_outs (list[Tensor]): Network regression outputs at all
                scales.
            target_boxes (list[Tensor]): Target bounding boxes.
            images_hw (list[tuple[int, int]]): Image dimensions without
                padding.
            target_class_ids (list[Tensor | float] | None, optional):
                Target class labels. Defaults to None, which assigns the
                foreground label 1.0 to all targets.

        Returns:
            DenseAnchorHeadLosses: Classification and regression losses.
        """
        featmap_sizes = [
            (featmap.size()[-2], featmap.size()[-1]) for featmap in cls_outs
        ]
        assert len(featmap_sizes) == self.anchor_generator.num_levels
        if target_class_ids is None:
            target_class_ids = [1.0 for _ in range(len(target_boxes))]

        targets_per_level, num_samples = get_targets_per_batch(
            featmap_sizes,
            target_boxes,
            target_class_ids,
            images_hw,
            self.anchor_generator,
            self.box_encoder,
            self.matcher,
            self.sampler,
            self.allowed_border,
        )

        device = cls_outs[0].device
        loss_cls_all = torch.tensor(0.0, device=device)
        loss_bbox_all = torch.tensor(0.0, device=device)
        for level_id, (cls_out, reg_out) in enumerate(
            zip(cls_outs, reg_outs)
        ):
            box_tgt, box_wgt, lbl, lbl_wgt = targets_per_level[level_id]
            loss_cls, loss_bbox = self._loss_single_scale(
                cls_out, reg_out, box_tgt, box_wgt, lbl, lbl_wgt, num_samples
            )
            loss_cls_all += loss_cls
            loss_bbox_all += loss_bbox
        return DenseAnchorHeadLosses(
            loss_cls=loss_cls_all, loss_bbox=loss_bbox_all
        )

    def __call__(
        self,
        cls_outs: list[Tensor],
        reg_outs: list[Tensor],
        target_boxes: list[Tensor],
        images_hw: list[tuple[int, int]],
        target_class_ids: list[Tensor | float] | None = None,
    ) -> DenseAnchorHeadLosses:
        """Type definition."""
        return self._call_impl(
            cls_outs, reg_outs, target_boxes, images_hw, target_class_ids
        )
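

# Usage sketch, illustrative and not part of the original module. The
# concrete matcher/sampler classes and their constructor arguments are
# assumptions (check vis4d.op.box.matchers and vis4d.op.box.samplers for
# the real signatures); `my_l1_loss` is a hypothetical regression loss that
# must accept the pred=/target=/reducer= keywords used in
# _loss_single_scale, while the classification loss must accept a
# reduction= keyword, as F.binary_cross_entropy_with_logits does.
#
#     loss_fn = DenseAnchorHeadLoss(
#         anchor_generator=AnchorGenerator(...),
#         box_encoder=DeltaXYWHBBoxEncoder(),
#         box_matcher=MaxIoUMatcher(...),
#         box_sampler=RandomSampler(...),
#         loss_cls=F.binary_cross_entropy_with_logits,
#         loss_bbox=my_l1_loss,
#     )
#     losses = loss_fn(cls_outs, reg_outs, target_boxes, images_hw)
#     total_loss = losses.loss_cls + losses.loss_bbox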