"""Mask RCNN detector."""

from __future__ import annotations

from typing import NamedTuple, Protocol

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torchvision.ops import roi_align

from import apply_mask
from import MultiScaleRoIAlign
from vis4d.op.mask.util import paste_masks_in_image, remove_overlap

from ..typing import Proposals, Targets

class MaskRCNNHeadOut(NamedTuple):
    """Mask R-CNN RoI head outputs.

    Attributes:
        mask_pred (list[torch.Tensor]): Mask logits, one tensor per entry.
    """

    # Logits for mask prediction. The dimension of each tensor is
    # number of masks x number of classes x H_mask x W_mask.
    mask_pred: list[torch.Tensor]
class MaskRCNNHead(nn.Module):
    """Mask R-CNN RoI head.

    Pools RoI features, runs them through a stack of convolutions, upsamples
    them with a transposed convolution and predicts per-class mask logits
    with a final 1x1 convolution.

    Args:
        num_classes (int, optional): Number of classes. Defaults to 80.
        num_convs (int, optional): Number of convolution layers. Defaults to
            4.
        roi_size (tuple[int, int], optional): Size of RoI after pooling.
            Defaults to (14, 14).
        in_channels (int, optional): Input feature channels. Defaults to 256.
        conv_kernel_size (int, optional): Kernel size of convolution. Defaults
            to 3.
        conv_out_channels (int, optional): Output channels of convolution.
            Defaults to 256.
        scale_factor (int, optional): Scaling factor of upsampling. Defaults
            to 2.
        class_agnostic (bool, optional): Whether to do class agnostic mask
            prediction. Defaults to False.
    """

    def __init__(
        self,
        num_classes: int = 80,
        num_convs: int = 4,
        roi_size: tuple[int, int] = (14, 14),
        in_channels: int = 256,
        conv_kernel_size: int = 3,
        conv_out_channels: int = 256,
        scale_factor: int = 2,
        class_agnostic: bool = False,
    ) -> None:
        """Creates an instance of the class."""
        super().__init__()
        self.roi_pooler = MultiScaleRoIAlign(
            sampling_ratio=0, resolution=roi_size, strides=[4, 8, 16, 32]
        )

        # "Same" padding for odd kernel sizes, so the convolution stack keeps
        # the pooled RoI resolution.
        padding = (conv_kernel_size - 1) // 2
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(
                    # Only the first layer sees the input feature width.
                    in_channels if layer_idx == 0 else conv_out_channels,
                    conv_out_channels,
                    conv_kernel_size,
                    padding=padding,
                )
                for layer_idx in range(num_convs)
            ]
        )

        # Without any convolution the upsampling layer consumes the pooled
        # features directly.
        upsample_in_channels = (
            conv_out_channels if num_convs > 0 else in_channels
        )
        self.upsample = nn.ConvTranspose2d(
            upsample_in_channels,
            conv_out_channels,
            scale_factor,
            stride=scale_factor,
        )

        out_channels = 1 if class_agnostic else num_classes
        self.conv_logits = nn.Conv2d(conv_out_channels, out_channels, 1)
        self.relu = nn.ReLU(inplace=True)

        self._init_weights(self.convs)
        self._init_weights(self.upsample, mode="fan_out")
        self._init_weights(self.conv_logits, mode="fan_out")

    @staticmethod
    def _init_weights(module: nn.Module, mode: str = "fan_in") -> None:
        """Initialize weights.

        Applies Kaiming-normal initialization (ReLU gain) to the weight and
        zeros to the bias of ``module`` and of every sub-module it contains.
        Containers such as ``nn.ModuleList`` have no ``weight`` / ``bias`` of
        their own, so they must be traversed; checking only the container
        itself would silently leave its layers at their default
        initialization.

        Args:
            module (nn.Module): Layer or container of layers to initialize.
            mode (str, optional): Fan mode passed to
                ``nn.init.kaiming_normal_``. Defaults to "fan_in".
        """
        # ``modules()`` yields ``module`` itself first, then all descendants.
        for layer in module.modules():
            weight = getattr(layer, "weight", None)
            bias = getattr(layer, "bias", None)
            # Layers without both a weight and a bias tensor are skipped.
            if not isinstance(weight, torch.Tensor) or not isinstance(
                bias, torch.Tensor
            ):
                continue
            nn.init.kaiming_normal_(weight, mode=mode, nonlinearity="relu")
            nn.init.constant_(bias, 0)

    def forward(
        self, features: list[torch.Tensor], boxes: list[torch.Tensor]
    ) -> MaskRCNNHeadOut:
        """Forward pass.

        Args:
            features (list[torch.Tensor]): Feature pyramid.
            boxes (list[torch.Tensor]): Proposal boxes.

        Returns:
            MaskRCNNHeadOut: Mask prediction outputs.
        """
        # Take stride 4, 8, 16, 32 features
        mask_feats = self.roi_pooler(features[2:6], boxes)
        for conv in self.convs:
            mask_feats = self.relu(conv(mask_feats))
        mask_feats = self.relu(self.upsample(mask_feats))
        mask_logits = self.conv_logits(mask_feats)

        # The logits are stacked over all entries of ``boxes``; split them
        # back so there is one tensor per entry. ``split`` returns a tuple,
        # which is converted to match the declared ``list`` field type.
        num_dets_per_img = [len(img_boxes) for img_boxes in boxes]
        mask_preds = list(mask_logits.split(num_dets_per_img, 0))
        return MaskRCNNHeadOut(mask_pred=mask_preds)
class MaskOut(NamedTuple):
    """Output of the final detections from Mask RCNN.

    Each field is a list with one entry per batch element.
    """

    masks: list[torch.Tensor]  # N, H, W
    scores: list[torch.Tensor]
    class_ids: list[torch.Tensor]
class Det2Mask(nn.Module):
    """Post processing of mask predictions.

    Args:
        mask_threshold (float, optional): Positive threshold. Defaults to 0.5.
        no_overlap (bool, optional): Whether to remove overlapping pixels
            between masks. Defaults to False.
    """

    def __init__(
        self, mask_threshold: float = 0.5, no_overlap: bool = False
    ) -> None:
        """Creates an instance of the class."""
        super().__init__()
        self.mask_threshold = mask_threshold
        self.no_overlap = no_overlap

    def forward(
        self,
        mask_outs: list[torch.Tensor],
        det_boxes: list[torch.Tensor],
        det_scores: list[torch.Tensor],
        det_class_ids: list[torch.Tensor],
        original_hw: list[tuple[int, int]],
    ) -> MaskOut:
        """Paste mask predictions back into original image resolution.

        Args:
            mask_outs (list[torch.Tensor]): List of mask outputs for each
                batch element.
            det_boxes (list[torch.Tensor]): List of detection boxes for each
                batch element.
            det_scores (list[torch.Tensor]): List of detection scores for
                each batch element.
            det_class_ids (list[torch.Tensor]): List of detection classes for
                each batch element.
            original_hw (list[tuple[int, int]]): Original image resolution.

        Returns:
            MaskOut: Post-processed mask predictions.
        """
        all_masks: list[torch.Tensor] = []
        all_scores: list[torch.Tensor] = []
        all_class_ids: list[torch.Tensor] = []
        for mask_out, boxes, scores, class_ids, orig_hw in zip(
            mask_outs, det_boxes, det_scores, det_class_ids, original_hw
        ):
            # For every detection keep only the mask channel of its class.
            # The index is created on the mask's device to avoid an implicit
            # host-to-device copy.
            det_inds = torch.arange(len(mask_out), device=mask_out.device)
            pasted_masks = paste_masks_in_image(
                mask_out[det_inds, class_ids],
                boxes,
                # (H, W) is reversed to (W, H) here; presumably the order
                # paste_masks_in_image expects -- verify against its docs.
                orig_hw[::-1],
                self.mask_threshold,
            )
            if self.no_overlap:
                pasted_masks = remove_overlap(pasted_masks, scores)
            all_masks.append(pasted_masks)
            all_scores.append(scores)
            all_class_ids.append(class_ids)
        return MaskOut(
            masks=all_masks, scores=all_scores, class_ids=all_class_ids
        )

    def __call__(
        self,
        mask_outs: list[torch.Tensor],
        det_boxes: list[torch.Tensor],
        det_scores: list[torch.Tensor],
        det_class_ids: list[torch.Tensor],
        original_hw: list[tuple[int, int]],
    ) -> MaskOut:
        """Type definition for function call."""
        return self._call_impl(
            mask_outs, det_boxes, det_scores, det_class_ids, original_hw
        )
class MaskRCNNHeadLosses(NamedTuple):
    """Mask RoI head loss container.

    Attributes:
        rcnn_loss_mask (torch.Tensor): Mask loss of the RoI head.
    """

    rcnn_loss_mask: torch.Tensor
class MaskRCNNHeadLoss(nn.Module):
    """Mask RoI head loss function.

    Args:
        num_classes (int): number of object categories.
    """

    def __init__(self, num_classes: int) -> None:
        """Creates an instance of the class."""
        super().__init__()
        self.num_classes = num_classes

    @staticmethod
    def _get_targets_per_image(
        boxes: Tensor,
        tgt_masks: Tensor,
        out_shape: tuple[int, int],
        binarize: bool = True,
    ) -> Tensor:
        """Get aligned mask targets for each proposal.

        Args:
            boxes (Tensor): proposal boxes.
            tgt_masks (Tensor): target masks.
            out_shape (tuple[int, int]): output shape.
            binarize (bool, optional): whether to convert target mask to
                binary. Defaults to True.

        Returns:
            Tensor: aligned mask targets.
        """
        # roi_align expects rois as (batch_index, box). Each target mask is
        # treated as its own "image", so proposal i crops mask i.
        fake_inds = torch.arange(
            len(boxes), device=boxes.device, dtype=boxes.dtype
        )[:, None]
        rois = torch.cat([fake_inds, boxes], dim=1)  # Nx5
        gt_masks_th = tgt_masks[:, None, :, :].type(rois.dtype)
        targets = roi_align(
            gt_masks_th,
            rois,
            out_shape,
            spatial_scale=1.0,
            sampling_ratio=0,
            aligned=True,
        ).squeeze(1)
        # Interpolated mask values are thresholded back to a binary mask.
        return targets >= 0.5 if binarize else targets

    @staticmethod
    def _mask_loss(
        mask_pred: Tensor, mask_targets: Tensor, mask_labels: Tensor
    ) -> Tensor:
        """Binary cross entropy between class-specific logits and targets.

        Args:
            mask_pred (Tensor): [M, C, H', W'] mask logits.
            mask_targets (Tensor): [M, H', W'] mask targets.
            mask_labels (Tensor): [M] class index selecting the logit channel
                of each prediction.

        Returns:
            Tensor: scalar loss. Zero if there are no targets.
        """
        if len(mask_targets) == 0:
            # Nothing to supervise: return a zero that is still connected to
            # the predictions so that backward() remains valid.
            return mask_pred.sum() * 0.0

        roi_inds = torch.arange(
            mask_pred.shape[0], dtype=torch.long, device=mask_pred.device
        )
        # Advanced indexing with two index tensors already yields
        # [M, H', W'], so no squeeze is needed.
        pred_slice = mask_pred[roi_inds, mask_labels.long()]
        return F.binary_cross_entropy_with_logits(
            pred_slice, mask_targets.float(), reduction="mean"
        )

    def forward(
        self,
        mask_preds: list[torch.Tensor],
        proposal_boxes: list[torch.Tensor],
        target_classes: list[torch.Tensor],
        target_masks: list[torch.Tensor],
    ) -> MaskRCNNHeadLosses:
        """Calculate losses of Mask RCNN head.

        Args:
            mask_preds (list[torch.Tensor]): [M, C, H', W'] mask outputs per
                batch element.
            proposal_boxes (list[torch.Tensor]): [M, 4] proposal boxes per
                batch element.
            target_classes (list[torch.Tensor]): list of [M] assigned target
                classes for each proposal.
            target_masks (list[torch.Tensor]): list of [M, H, W] assigned
                target masks for each proposal.

        Returns:
            MaskRCNNHeadLosses: mask loss.
        """
        mask_pred = torch.cat(mask_preds)
        mask_size = (mask_pred.shape[2], mask_pred.shape[3])

        # Resample every target mask to the prediction resolution, cropped
        # to its proposal box.
        targets = []
        for boxes, tgt_masks in zip(proposal_boxes, target_masks):
            if len(tgt_masks) == 0:
                targets.append(
                    torch.empty((0, *mask_size), device=tgt_masks.device)
                )
            else:
                targets.append(
                    self._get_targets_per_image(boxes, tgt_masks, mask_size)
                )
        mask_targets = torch.cat(targets)
        mask_labels = torch.cat(target_classes)

        loss_mask = self._mask_loss(mask_pred, mask_targets, mask_labels)
        return MaskRCNNHeadLosses(rcnn_loss_mask=loss_mask)
class MaskSampler(Protocol):
    """Type definition for mask sampler."""

    def __call__(
        self,
        target_masks: list[Tensor],
        sampled_target_indices: list[Tensor],
        sampled_targets: Targets,
        sampled_proposals: Proposals,
    ) -> tuple[list[Tensor], list[Tensor], list[Tensor]]:
        """Type definition for function call.

        Args:
            target_masks (list[Tensor]): list of [N, H, W] target masks per
                batch element.
            sampled_target_indices (list[Tensor]): list of [M] indices of
                sampled targets per batch element.
            sampled_targets (Targets): sampled targets.
            sampled_proposals (Proposals): sampled proposals.

        Returns:
            tuple[list[Tensor], list[Tensor], list[Tensor]]: sampled proposal
                boxes, their target classes and their target masks, each as
                a list with one entry per batch element.
        """
def positive_mask_sampler(
    target_masks: list[Tensor],
    sampled_target_indices: list[Tensor],
    sampled_targets: Targets,
    sampled_proposals: Proposals,
) -> tuple[list[Tensor], list[Tensor], list[Tensor]]:
    """Sample only positive masks from target masks.

    Args:
        target_masks (list[Tensor]): list of [N, H, W] target masks per batch
            element.
        sampled_target_indices (list[Tensor]): list of [M] indices of sampled
            targets per batch element.
        sampled_targets (Targets): sampled targets.
        sampled_proposals (Proposals): sampled proposals.

    Returns:
        tuple[list[Tensor], list[Tensor], list[Tensor]]: proposal boxes,
            target classes and target masks of the positive samples, each
            as a list with one entry per batch element.
    """
    # Gather the target mask assigned to every sampled proposal.
    sampled_masks = apply_mask(sampled_target_indices, target_masks)[0]

    # Samples whose label equals 1 are the positives; only those are kept.
    positive_flags = [torch.eq(label, 1) for label in sampled_targets.labels]
    pos_proposals, pos_classes, pos_mask_targets = apply_mask(
        positive_flags,
        sampled_proposals.boxes,
        sampled_targets.classes,
        sampled_masks,
    )
    return pos_proposals, pos_classes, pos_mask_targets
class SampledMaskLoss(nn.Module):
    """Sampled Mask RCNN head loss function.

    Runs a mask sampler on the sampled proposals and targets, then evaluates
    the mask loss on what the sampler returns.
    """

    def __init__(
        self,
        mask_sampler: MaskSampler,
        loss: MaskRCNNHeadLoss,
    ) -> None:
        """Initialize sampled mask loss.

        Args:
            mask_sampler (MaskSampler): mask sampler.
            loss (MaskRCNNHeadLoss): mask loss.
        """
        super().__init__()
        self.loss = loss
        self.mask_sampler = mask_sampler

    def forward(
        self,
        mask_preds: list[Tensor],
        target_masks: list[Tensor],
        sampled_target_indices: list[Tensor],
        sampled_targets: Targets,
        sampled_proposals: Proposals,
    ) -> MaskRCNNHeadLosses:
        """Calculate losses of Mask RCNN head.

        Args:
            mask_preds (list[torch.Tensor]): [M, C, H', W'] mask outputs per
                batch element.
            target_masks (list[torch.Tensor]): list of target masks per batch
                element, passed to the mask sampler.
            sampled_target_indices (list[Tensor]): indices of the sampled
                targets per batch element, passed to the mask sampler.
            sampled_targets (Targets): sampled targets, passed to the mask
                sampler.
            sampled_proposals (Proposals): sampled proposals, passed to the
                mask sampler.

        Returns:
            MaskRCNNHeadLosses: mask loss.
        """
        # The sampler returns boxes, classes and mask targets of the samples
        # that should be supervised.
        pos_proposals, pos_classes, pos_mask_targets = self.mask_sampler(
            target_masks,
            sampled_target_indices,
            sampled_targets,
            sampled_proposals,
        )
        return self.loss(
            mask_preds, pos_proposals, pos_classes, pos_mask_targets
        )