# Source code for vis4d.model.detect.mask_rcnn

"""Mask RCNN model implementation and runtime."""

from __future__ import annotations

from typing import NamedTuple

import torch
from torch import nn

from vis4d.common.ckpt import load_model_checkpoint
from vis4d.op.base import BaseModel, ResNet
from vis4d.op.box.box2d import apply_mask, scale_and_clip_boxes
from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder
from vis4d.op.detect.common import DetOut
from vis4d.op.detect.faster_rcnn import FasterRCNNHead, FRCNNOut
from vis4d.op.detect.mask_rcnn import (
    Det2Mask,
    MaskOut,
    MaskRCNNHead,
    MaskRCNNHeadOut,
)
from vis4d.op.detect.rcnn import RoI2Det
from vis4d.op.fpp.fpn import FPN


class MaskDetectionOut(NamedTuple):
    """Mask detection output.

    Bundles the final (post-processed) box and mask predictions returned
    by :meth:`MaskRCNN.forward_test`.
    """

    boxes: DetOut  # decoded 2D boxes with scores and class ids
    masks: MaskOut  # instance masks matching the boxes above
class MaskRCNNOut(NamedTuple):
    """Mask RCNN output.

    Raw (loss-ready) outputs of the Faster R-CNN head and the mask head,
    as returned by :meth:`MaskRCNN.forward_train`.
    """

    boxes: FRCNNOut  # raw Faster R-CNN head outputs (RPN + RoI)
    masks: MaskRCNNHeadOut  # raw mask head outputs for positive proposals
# Regex rewrite rules (pattern, replacement) applied to checkpoint keys when
# loading MMDetection / BDD100K weights, mapping their module naming onto
# this model's attribute layout. Order matters: head-internal renames run
# before the heads are nested under ``faster_rcnn_head``.
REV_KEYS = [
    (r"^backbone\.", "basemodel."),
    (r"^rpn_head.rpn_reg\.", "rpn_head.rpn_box."),
    (r"^roi_head.bbox_head\.", "roi_head."),
    (r"^roi_head.mask_head\.", "mask_head."),
    (r"^convs\.", "mask_head.convs."),
    (r"^upsample\.", "mask_head.upsample."),
    (r"^conv_logits\.", "mask_head.conv_logits."),
    (r"^roi_head\.", "faster_rcnn_head.roi_head."),
    (r"^rpn_head\.", "faster_rcnn_head.rpn_head."),
    (r"^neck.lateral_convs\.", "fpn.inner_blocks."),
    (r"^neck.fpn_convs\.", "fpn.layer_blocks."),
    (r"\.conv.weight", ".weight"),
    (r"\.conv.bias", ".bias"),
]
class MaskRCNN(nn.Module):
    """Mask RCNN model.

    Args:
        num_classes (int): Number of classes.
        basemodel (BaseModel, optional): Base model network. Defaults to
            None. If None, will use ResNet50.
        faster_rcnn_head (FasterRCNNHead, optional): Faster RCNN head.
            Defaults to None. If None, will use default FasterRCNNHead.
        mask_head (MaskRCNNHead, optional): Mask RCNN head. Defaults to
            None. If None, will use default MaskRCNNHead.
        rcnn_box_decoder (DeltaXYWHBBoxDecoder, optional): Decoder for RCNN
            bounding boxes. Defaults to None.
        no_overlap (bool, optional): Whether to remove overlapping pixels
            between masks. Defaults to False.
        weights (None | str, optional): Weights to load for model. If set
            to "mmdet", will load MMDetection pre-trained weights.
            Defaults to None.
    """

    def __init__(
        self,
        num_classes: int,
        basemodel: BaseModel | None = None,
        faster_rcnn_head: FasterRCNNHead | None = None,
        mask_head: MaskRCNNHead | None = None,
        rcnn_box_decoder: DeltaXYWHBBoxDecoder | None = None,
        no_overlap: bool = False,
        weights: None | str = None,
    ) -> None:
        """Creates an instance of the class."""
        super().__init__()
        self.basemodel = (
            ResNet(
                resnet_name="resnet50", pretrained=True, trainable_layers=3
            )
            if basemodel is None
            else basemodel
        )
        # FPN over the last backbone stages; 256 output channels per level.
        self.fpn = FPN(self.basemodel.out_channels[2:], 256)
        if faster_rcnn_head is None:
            self.faster_rcnn_head = FasterRCNNHead(num_classes=num_classes)
        else:
            self.faster_rcnn_head = faster_rcnn_head
        if mask_head is None:
            self.mask_head = MaskRCNNHead(num_classes=num_classes)
        else:
            self.mask_head = mask_head
        # Converts raw RoI outputs to decoded detections at test time.
        self.transform_outs = RoI2Det(rcnn_box_decoder)
        self.det2mask = Det2Mask(no_overlap=no_overlap)

        if weights is not None:
            if weights == "mmdet":
                # Shortcut for the reference MMDetection checkpoint.
                weights = (
                    "mmdet://mask_rcnn/mask_rcnn_r50_fpn_2x_coco/"
                    "mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392__segm_mAP-0.354_"
                    "20200505_003907-3e542a40.pth"
                )
            if weights.startswith("mmdet://") or weights.startswith(
                "bdd100k://"
            ):
                # External checkpoints need their keys remapped (REV_KEYS).
                load_model_checkpoint(self, weights, rev_keys=REV_KEYS)
            else:
                load_model_checkpoint(self, weights)

    def forward(
        self,
        images: torch.Tensor,
        input_hw: list[tuple[int, int]],
        boxes2d: None | list[torch.Tensor] = None,
        boxes2d_classes: None | list[torch.Tensor] = None,
        original_hw: None | list[tuple[int, int]] = None,
    ) -> MaskRCNNOut | MaskDetectionOut:
        """Forward pass.

        Args:
            images (torch.Tensor): Input images.
            input_hw (list[tuple[int, int]]): Input image resolutions.
            boxes2d (None | list[torch.Tensor], optional): Bounding box
                labels. Required for training. Defaults to None.
            boxes2d_classes (None | list[torch.Tensor], optional): Class
                labels. Required for training. Defaults to None.
            original_hw (None | list[tuple[int, int]], optional): Original
                image resolutions (before padding and resizing). Required
                for testing. Defaults to None.

        Returns:
            MaskRCNNOut | MaskDetectionOut: Either raw model outputs (for
                training) or predicted outputs (for testing).
        """
        if self.training:
            assert boxes2d is not None and boxes2d_classes is not None
            return self.forward_train(
                images, input_hw, boxes2d, boxes2d_classes
            )
        assert original_hw is not None
        return self.forward_test(images, input_hw, original_hw)

    def forward_train(
        self,
        images: torch.Tensor,
        images_hw: list[tuple[int, int]],
        target_boxes: list[torch.Tensor],
        target_classes: list[torch.Tensor],
    ) -> MaskRCNNOut:
        """Forward training stage.

        Args:
            images (torch.Tensor): Input images.
            images_hw (list[tuple[int, int]]): Input image resolutions.
            target_boxes (list[torch.Tensor]): Bounding box labels.
            target_classes (list[torch.Tensor]): Class labels.

        Returns:
            MaskRCNNOut: Raw model outputs.
        """
        features = self.fpn(self.basemodel(images))
        outputs = self.faster_rcnn_head(
            features, images_hw, target_boxes, target_classes
        )
        assert outputs.sampled_proposals is not None
        assert outputs.sampled_targets is not None
        # The mask head is only trained on proposals sampled as positives
        # (label == 1); negatives carry no mask supervision.
        pos_proposals = apply_mask(
            [torch.eq(label, 1) for label in outputs.sampled_targets.labels],
            outputs.sampled_proposals.boxes,
        )[0]
        mask_outs = self.mask_head(features, pos_proposals)
        return MaskRCNNOut(outputs, mask_outs)

    def forward_test(
        self,
        images: torch.Tensor,
        images_hw: list[tuple[int, int]],
        original_hw: list[tuple[int, int]],
    ) -> MaskDetectionOut:
        """Forward testing stage.

        Args:
            images (torch.Tensor): Input images.
            images_hw (list[tuple[int, int]]): Input image resolutions.
            original_hw (list[tuple[int, int]]): Original image resolutions
                (before padding and resizing).

        Returns:
            MaskDetectionOut: Predicted outputs.
        """
        features = self.fpn(self.basemodel(images))
        outs = self.faster_rcnn_head(features, images_hw)
        boxes, scores, class_ids = self.transform_outs(
            *outs.roi, outs.proposals.boxes, images_hw
        )
        # Mask head runs on boxes in padded-input coordinates; boxes are
        # rescaled to the original resolution afterwards.
        mask_outs = self.mask_head(features, boxes)
        for i, boxs in enumerate(boxes):
            boxes[i] = scale_and_clip_boxes(
                boxs, original_hw[i], images_hw[i]
            )
        mask_preds = [m.sigmoid() for m in mask_outs.mask_pred]
        masks = self.det2mask(
            mask_preds, boxes, scores, class_ids, original_hw
        )
        return MaskDetectionOut(DetOut(boxes, scores, class_ids), masks)