"""Faster R-CNN RoI head."""
from __future__ import annotations
from math import prod
from typing import NamedTuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from vis4d.common.typing import TorchLossFunc
from vis4d.op.box.box2d import bbox_clip, multiclass_nms
from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder, DeltaXYWHBBoxEncoder
from vis4d.op.box.poolers import MultiScaleRoIAlign
from vis4d.op.detect.common import DetOut
from vis4d.op.layer import add_conv_branch
from vis4d.op.layer.weight_init import kaiming_init, normal_init, xavier_init
from vis4d.op.loss.common import l1_loss
from vis4d.op.loss.reducer import SumWeightedLoss
class RCNNOut(NamedTuple):
"""Faster R-CNN RoI head outputs."""
# Logits for box classification. The logit dimension is the number of
# classes plus 1 for the background.
cls_score: torch.Tensor
# Each box has regression parameters for every class, so the tensor
# dimension is [batch_size, number of boxes, number of classes x 4].
bbox_pred: torch.Tensor
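# Shape sketch for RCNNOut (illustrative, assuming 2 images x 512 sampled
# boxes and 80 classes): cls_score is [1024, 81], bbox_pred is [1024, 320].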
def get_default_rcnn_box_codec(
target_means: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0),
target_stds: tuple[float, float, float, float] = (0.1, 0.1, 0.2, 0.2),
) -> tuple[DeltaXYWHBBoxEncoder, DeltaXYWHBBoxDecoder]:
"""Get the default bounding box encoder and decoder for RCNN."""
return (
DeltaXYWHBBoxEncoder(target_means, target_stds),
DeltaXYWHBBoxDecoder(target_means, target_stds),
)
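# A minimal round-trip sketch (the box values here are arbitrary assumptions,
# not part of the API):
#
#   encoder, decoder = get_default_rcnn_box_codec()
#   rois = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
#   targets = torch.tensor([[2.0, 2.0, 12.0, 12.0]])
#   deltas = encoder(rois, targets)  # normalized delta-xywh targets
#   decoded = decoder(rois, deltas)  # recovers `targets` up to numerics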
class RCNNHead(nn.Module):
"""Faster R-CNN RoI head.
This head pools the RoIs from a set of feature maps and processes them
into classification / regression outputs.
Args:
num_shared_convs (int, optional): number of shared conv layers.
Defaults to 0.
num_shared_fcs (int, optional): number of shared fc layers. Defaults
to 2.
conv_out_channels (int, optional): number of output channels for
shared conv layers. Defaults to 256.
in_channels (int, optional): Number of channels in input feature maps.
Defaults to 256.
fc_out_channels (int, optional): Output channels of shared linear
layers. Defaults to 1024.
num_classes (int, optional): number of categories. Defaults to 80.
roi_size (tuple[int, int], optional): size of pooled RoIs. Defaults
to (7, 7).
"""
def __init__(
self,
num_shared_convs: int = 0,
num_shared_fcs: int = 2,
conv_out_channels: int = 256,
in_channels: int = 256,
fc_out_channels: int = 1024,
num_classes: int = 80,
roi_size: tuple[int, int] = (7, 7),
start_level: int = 2,
) -> None:
"""Creates an instance of the class."""
super().__init__()
self.roi_pooler = MultiScaleRoIAlign(
sampling_ratio=0, resolution=roi_size, strides=[4, 8, 16, 32]
)
# Used feature layers are [start_level, end_level)
self.start_level = start_level
self.end_level = start_level + len(self.roi_pooler.scales)
self.num_shared_convs = num_shared_convs
self.num_shared_fcs = num_shared_fcs
self.conv_out_channels = conv_out_channels
self.fc_out_channels = fc_out_channels
# add shared convs and fcs
(
self.shared_convs,
self.shared_fcs,
last_layer_dim,
) = self._add_conv_fc_branch(
self.num_shared_convs, self.num_shared_fcs, in_channels, True
)
self.shared_out_channels = last_layer_dim
in_channels *= prod(roi_size)
self.fc_cls = nn.Linear(
in_features=fc_out_channels, out_features=num_classes + 1
)
self.fc_reg = nn.Linear(
in_features=fc_out_channels, out_features=4 * num_classes
)
self.relu = nn.ReLU(inplace=True)
self._init_weights()
def _add_conv_fc_branch(
self,
num_branch_convs: int = 0,
num_branch_fcs: int = 0,
in_channels: int = 0,
is_shared: bool = False,
) -> tuple[nn.ModuleList, nn.ModuleList, int]:
"""Add shared or separable branch."""
convs, last_layer_dim = add_conv_branch(
num_branch_convs,
in_channels,
self.conv_out_channels,
True,
None,
None,
)
fcs = nn.ModuleList()
if num_branch_fcs > 0:
# for the shared branch, the first fc layer takes the flattened RoI
# features, so scale the input dim by the pooled resolution
if is_shared:
last_layer_dim *= int(np.prod(self.roi_pooler.resolution))
for i in range(num_branch_fcs):
fc_in_dim = last_layer_dim if i == 0 else self.fc_out_channels
fcs.append(nn.Linear(fc_in_dim, self.fc_out_channels))
return convs, fcs, last_layer_dim
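# E.g. with the defaults (0 shared convs, 2 shared fcs, in_channels=256,
# roi_size=(7, 7)), the shared branch computes:
#   flatten: 256 * 7 * 7 = 12544 -> fc1 -> 1024 -> fc2 -> 1024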
def _init_weights(self) -> None:
"""Init weights."""
for m in self.shared_convs.modules():
kaiming_init(m)
for m in self.shared_fcs.modules():
xavier_init(m, distribution="uniform")
normal_init(self.fc_cls, std=0.01)
normal_init(self.fc_reg, std=0.001)
def forward(
self, features: list[torch.Tensor], boxes: list[torch.Tensor]
) -> RCNNOut:
"""Forward pass during training stage."""
bbox_feats = self.roi_pooler(
features[self.start_level : self.end_level], boxes
)
if self.num_shared_convs > 0:
for conv in self.shared_convs:
bbox_feats = conv(bbox_feats)
bbox_feats = bbox_feats.flatten(start_dim=1)
for fc in self.shared_fcs:
bbox_feats = self.relu(fc(bbox_feats))
cls_score = self.fc_cls(bbox_feats)
bbox_pred = self.fc_reg(bbox_feats)
return RCNNOut(cls_score, bbox_pred)
def __call__(
self, features: list[torch.Tensor], boxes: list[torch.Tensor]
) -> RCNNOut:
"""Type definition for function call."""
return self._call_impl(features, boxes)
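# Usage sketch for RCNNHead (feature and box shapes are illustrative
# assumptions). The pooler reads features[2:6] with the default
# start_level, so levels 0 and 1 are placeholders here:
#
#   head = RCNNHead(num_classes=80)
#   feats = [torch.empty(0)] * 2 + [
#       torch.rand(1, 256, 256 // s, 256 // s) for s in (4, 8, 16, 32)
#   ]
#   rois = [torch.tensor([[4.0, 4.0, 60.0, 60.0]])]  # one box, one image
#   cls_score, bbox_pred = head(feats, rois)
#   # cls_score: [1, 81], bbox_pred: [1, 320]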
class RoI2Det(nn.Module):
"""Post processing of RCNN results and detection generation.
It does the following:
1. Take the classification and regression outputs from the RCNN heads.
2. Take the proposal boxes that are RCNN inputs.
3. Determine the final box classes and select the corresponding box
regression parameters.
4. Adjust the box sizes and offsets according to the regression
parameters.
5. Return the final boxes.
"""
def __init__(
self,
box_decoder: None | DeltaXYWHBBoxDecoder = None,
score_threshold: float = 0.05,
iou_threshold: float = 0.5,
max_per_img: int = 100,
class_agnostic_nms: bool = False,
) -> None:
"""Creates an instance of the class.
Args:
box_decoder (DeltaXYWHBBoxDecoder, optional): Decodes regression
parameters to detected boxes. Defaults to None. If None, it
will use the default decoder.
score_threshold (float, optional): Minimum score of a detection.
Defaults to 0.05.
iou_threshold (float, optional): IoU threshold of NMS
post-processing step. Defaults to 0.5.
max_per_img (int, optional): Maximum number of detections per
image. Defaults to 100.
class_agnostic_nms (bool, optional): Whether to use class agnostic
NMS. Defaults to False.
"""
super().__init__()
if box_decoder is None:
_, self.box_decoder = get_default_rcnn_box_codec()
else:
self.box_decoder = box_decoder
self.score_threshold = score_threshold
self.max_per_img = max_per_img
self.iou_threshold = iou_threshold
self.class_agnostic_nms = class_agnostic_nms
def forward(
self,
class_outs: torch.Tensor,
regression_outs: torch.Tensor,
boxes: list[torch.Tensor],
images_hw: list[tuple[int, int]],
) -> DetOut:
"""Convert RCNN network outputs to detections.
Args:
class_outs (torch.Tensor): [B, num_classes + 1] batched tensor of
classification scores, including the background class.
regression_outs (torch.Tensor): [B, num_classes * 4] predicted
box offsets.
boxes (list[torch.Tensor]): Initial boxes (RoIs).
images_hw (list[tuple[int, int]]): Image sizes.
Returns:
DetOut: boxes, scores and class ids of detections per image.
"""
num_proposals_per_img = tuple(len(p) for p in boxes)
regression_outs = regression_outs.split(num_proposals_per_img, 0)
class_outs = class_outs.split(num_proposals_per_img, 0)
all_det_boxes = []
all_det_scores = []
all_det_class_ids = []
for cls_out, reg_out, boxs, image_hw in zip(
class_outs, regression_outs, boxes, images_hw
):
scores = F.softmax(cls_out, dim=-1)
bboxes = bbox_clip(
self.box_decoder(boxs[:, :4], reg_out).view(-1, 4),
image_hw,
).view(reg_out.shape)
det_bbox, det_scores, det_label, _ = multiclass_nms(
bboxes,
scores,
self.score_threshold,
self.iou_threshold,
self.max_per_img,
self.class_agnostic_nms,
)
all_det_boxes.append(det_bbox)
all_det_scores.append(det_scores)
all_det_class_ids.append(det_label)
return DetOut(
boxes=all_det_boxes,
scores=all_det_scores,
class_ids=all_det_class_ids,
)
def __call__(
self,
class_outs: torch.Tensor,
regression_outs: torch.Tensor,
boxes: list[torch.Tensor],
images_hw: list[tuple[int, int]],
) -> DetOut:
"""Type definition for function call."""
return self._call_impl(class_outs, regression_outs, boxes, images_hw)
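# Usage sketch for RoI2Det (all shapes and values are illustrative
# assumptions):
#
#   roi2det = RoI2Det()
#   cls_outs = torch.rand(1, 81)       # 80 classes + background
#   reg_outs = torch.rand(1, 80 * 4)   # per-class box deltas
#   rois = [torch.tensor([[4.0, 4.0, 60.0, 60.0]])]
#   dets = roi2det(cls_outs, reg_outs, rois, [(512, 512)])
#   # dets.boxes, dets.scores, dets.class_ids: one tensor per image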
class RCNNTargets(NamedTuple):
"""Target container."""
labels: Tensor
label_weights: Tensor
bbox_targets: Tensor
bbox_weights: Tensor
class RCNNLosses(NamedTuple):
"""RCNN loss container."""
rcnn_loss_cls: torch.Tensor
rcnn_loss_bbox: torch.Tensor
class RCNNLoss(nn.Module):
"""RCNN loss in Faster R-CNN.
This class computes the loss of RCNN given proposal boxes and their
corresponding target boxes with the given box encoder.
"""
def __init__(
self,
box_encoder: DeltaXYWHBBoxEncoder,
num_classes: int = 80,
loss_cls: TorchLossFunc = F.cross_entropy,
loss_bbox: TorchLossFunc = l1_loss,
) -> None:
"""Creates an instance of the class.
Args:
box_encoder (DeltaXYWHBBoxEncoder): Encodes target boxes into
regression parameters w.r.t. the proposal boxes.
num_classes (int, optional): number of object categories. Defaults
to 80.
loss_cls (TorchLossFunc, optional): Classification loss function.
Defaults to F.cross_entropy.
loss_bbox (TorchLossFunc, optional): Regression loss function.
Defaults to l1_loss.
"""
super().__init__()
self.num_classes = num_classes
self.box_encoder = box_encoder
self.loss_cls = loss_cls
self.loss_bbox = loss_bbox
def _get_targets_per_image(
self,
boxes: Tensor,
labels: Tensor,
target_boxes: Tensor,
target_classes: Tensor,
) -> RCNNTargets:
"""Generate targets per image.
Args:
boxes (Tensor): [N, 4] tensor of proposal boxes
labels (Tensor): [N,] tensor of positive / negative / ignore labels
target_boxes (Tensor): [N, 4] Assigned target boxes.
target_classes (Tensor): [N,] Assigned target class labels.
Returns:
RCNNTargets: Box / class label tensors and weights.
"""
pos_mask, neg_mask = torch.eq(labels, 1), torch.eq(labels, 0)
num_pos, num_neg = int(pos_mask.sum()), int(neg_mask.sum())
num_samples = num_pos + num_neg
# The original implementation uses new_zeros since BG is class 0. Here
# we use new_full because BG cat_id = num_classes and FG cat_ids are in
# [0, num_classes - 1].
labels = boxes.new_full(
(num_samples,), self.num_classes, dtype=torch.long
)
label_weights = boxes.new_zeros(num_samples)
box_targets = boxes.new_zeros(num_samples, 4)
box_weights = boxes.new_zeros(num_samples, 4)
if num_pos > 0:
pos_target_boxes = target_boxes[pos_mask]
pos_target_classes = target_classes[pos_mask]
labels[:num_pos] = pos_target_classes
label_weights[:num_pos] = 1.0
pos_box_targets = self.box_encoder(
boxes[pos_mask], pos_target_boxes
)
box_targets[:num_pos, :] = pos_box_targets
box_weights[:num_pos, :] = 1
if num_neg > 0:
label_weights[-num_neg:] = 1.0
return RCNNTargets(labels, label_weights, box_targets, box_weights)
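# Layout sketch (illustrative): the sampler puts positives first, so with
# 2 positives and 3 negatives, labels = [cls_0, cls_1, 80, 80, 80] (BG id
# equals num_classes), all 5 label_weights are 1.0, and only the first 2
# rows of box_targets / box_weights are non-zero.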
def forward(
self,
class_outs: torch.Tensor,
regression_outs: torch.Tensor,
boxes: list[torch.Tensor],
boxes_mask: list[torch.Tensor],
target_boxes: list[torch.Tensor],
target_classes: list[torch.Tensor],
) -> RCNNLosses:
"""Calculate losses of RCNN head.
Args:
class_outs (torch.Tensor): [M*B, num_classes + 1] classification
outputs, including the background class.
regression_outs (torch.Tensor): Tensor[M*B, regression_params]
regression outputs.
boxes (list[torch.Tensor]): [M, 4] proposal boxes per batch
element.
boxes_mask (list[torch.Tensor]): positive (1), ignore (-1),
negative (0).
target_boxes (list[torch.Tensor]): list of [M, 4] assigned target
boxes for each proposal.
target_classes (list[torch.Tensor]): list of [M,] assigned target
classes for each proposal.
Returns:
RCNNLosses: classification and regression losses.
"""
# get targets
targets = []
for boxs, boxs_mask, tgt_boxs, tgt_cls in zip(
boxes, boxes_mask, target_boxes, target_classes
):
targets.append(
self._get_targets_per_image(boxs, boxs_mask, tgt_boxs, tgt_cls)
)
labels = torch.cat([tgt.labels for tgt in targets], 0)
label_weights = torch.cat([tgt.label_weights for tgt in targets], 0)
bbox_targets = torch.cat([tgt.bbox_targets for tgt in targets], 0)
bbox_weights = torch.cat([tgt.bbox_weights for tgt in targets], 0)
# compute losses
avg_factor = torch.sum(torch.greater(label_weights, 0)).clamp(1.0)
if class_outs.numel() > 0:
loss_cls = SumWeightedLoss(label_weights, avg_factor)(
self.loss_cls(class_outs, labels, reduction="none")
)
else:
loss_cls = class_outs.sum()
bg_class_ind = self.num_classes
# 0~self.num_classes-1 are FG, self.num_classes is BG
pos_inds = torch.logical_and(
torch.greater_equal(labels, 0), torch.less(labels, bg_class_ind)
)
# Do not perform bounding box regression for BG.
if pos_inds.any():
pos_reg_outs = regression_outs.view(
regression_outs.size(0), -1, 4
)[pos_inds.type(torch.bool), labels[pos_inds.type(torch.bool)]]
loss_bbox = self.loss_bbox(
pred=pos_reg_outs,
target=bbox_targets[pos_inds.type(torch.bool)],
reducer=SumWeightedLoss(
bbox_weights[pos_inds.type(torch.bool)],
bbox_targets.size(0),
),
)
else:
loss_bbox = regression_outs[pos_inds].sum()
return RCNNLosses(rcnn_loss_cls=loss_cls, rcnn_loss_bbox=loss_bbox)
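# Usage sketch for RCNNLoss (the proposals, masks and targets below are
# illustrative assumptions; class id 3 is hypothetical):
#
#   encoder, _ = get_default_rcnn_box_codec()
#   loss_fn = RCNNLoss(encoder, num_classes=80)
#   cls_outs = torch.rand(2, 81)
#   reg_outs = torch.rand(2, 80 * 4)
#   boxes = [torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])]
#   mask = [torch.tensor([1, 0])]  # first proposal positive, second negative
#   tgt_boxes = [torch.tensor([[1.0, 1.0, 11.0, 11.0], [0.0, 0.0, 0.0, 0.0]])]
#   tgt_classes = [torch.tensor([3, 80])]
#   losses = loss_fn(cls_outs, reg_outs, boxes, mask, tgt_boxes, tgt_classes)
#   # losses.rcnn_loss_cls and losses.rcnn_loss_bbox are scalar tensors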