"""Mask RCNN model implementation and runtime."""
from __future__ import annotations
from typing import NamedTuple
import torch
from torch import nn
from vis4d.common.ckpt import load_model_checkpoint
from vis4d.op.base import BaseModel, ResNet
from vis4d.op.box.box2d import apply_mask, scale_and_clip_boxes
from vis4d.op.box.encoder import DeltaXYWHBBoxDecoder
from vis4d.op.detect.common import DetOut
from vis4d.op.detect.faster_rcnn import FasterRCNNHead, FRCNNOut
from vis4d.op.detect.mask_rcnn import (
Det2Mask,
MaskOut,
MaskRCNNHead,
MaskRCNNHeadOut,
)
from vis4d.op.detect.rcnn import RoI2Det
from vis4d.op.fpp.fpn import FPN
class MaskDetectionOut(NamedTuple):
    """Mask detection output.

    Returned by ``MaskRCNN.forward_test`` (i.e. when the model is in eval
    mode).

    Attributes:
        boxes (DetOut): Post-processed detections built from the boxes,
            scores and class ids produced at test time, with boxes rescaled
            to the original image resolution.
        masks (MaskOut): Instance masks produced by ``Det2Mask`` for those
            detections.
    """

    # Final detections (boxes / scores / class ids).
    boxes: DetOut
    # Masks associated with the detections above.
    masks: MaskOut
class MaskRCNNOut(NamedTuple):
    """Mask RCNN output.

    Raw outputs returned by ``MaskRCNN.forward_train`` (i.e. when the model
    is in training mode), intended for loss computation.

    Attributes:
        boxes (FRCNNOut): Raw outputs of the Faster RCNN head (RPN and RoI
            head outputs together with sampled proposals / targets).
        masks (MaskRCNNHeadOut): Raw outputs of the mask head, computed on
            the sampled proposals whose target label equals 1.
    """

    # Faster RCNN head outputs.
    boxes: FRCNNOut
    # Mask head outputs.
    masks: MaskRCNNHeadOut
# (regex pattern, replacement) pairs passed as ``rev_keys`` to
# ``load_model_checkpoint`` when loading "mmdet://" or "bdd100k://"
# checkpoints. They rename the checkpoint's parameter keys to the attribute
# names used by ``MaskRCNN`` (basemodel, fpn, faster_rcnn_head, mask_head).
# NOTE(review): the order matters if the pairs are applied one after another
# with ``re.sub`` (e.g. ``roi_head.bbox_head.`` must be rewritten before the
# generic ``roi_head.`` rule) -- confirm in vis4d.common.ckpt.
# Literal dots are escaped so that ``.`` does not match arbitrary characters.
REV_KEYS = [
    (r"^backbone\.", "basemodel."),
    (r"^rpn_head\.rpn_reg\.", "rpn_head.rpn_box."),
    (r"^roi_head\.bbox_head\.", "roi_head."),
    (r"^roi_head\.mask_head\.", "mask_head."),
    (r"^convs\.", "mask_head.convs."),
    (r"^upsample\.", "mask_head.upsample."),
    (r"^conv_logits\.", "mask_head.conv_logits."),
    (r"^roi_head\.", "faster_rcnn_head.roi_head."),
    (r"^rpn_head\.", "faster_rcnn_head.rpn_head."),
    (r"^neck\.lateral_convs\.", "fpn.inner_blocks."),
    (r"^neck\.fpn_convs\.", "fpn.layer_blocks."),
    # Drop the intermediate ``.conv`` level: ``x.conv.weight`` -> ``x.weight``.
    (r"\.conv\.weight", ".weight"),
    (r"\.conv\.bias", ".bias"),
]
class MaskRCNN(nn.Module):
    """Mask RCNN model.

    Args:
        num_classes (int): Number of classes.
        basemodel (BaseModel, optional): Base model network. Defaults to
            None. If None, will use a pretrained ResNet50 with 3 trainable
            layers.
        faster_rcnn_head (FasterRCNNHead, optional): Faster RCNN head.
            Defaults to None. If None, will use default FasterRCNNHead.
        mask_head (MaskRCNNHead, optional): Mask RCNN head. Defaults to
            None. If None, will use default MaskRCNNHead.
        rcnn_box_decoder (DeltaXYWHBBoxDecoder, optional): Decoder for RCNN
            bounding boxes. Defaults to None.
        no_overlap (bool, optional): Whether to remove overlapping pixels
            between masks. Defaults to False.
        weights (None | str, optional): Weights to load for model. If set
            to "mmdet", will load MMDetection pre-trained weights. Paths
            starting with "mmdet://" or "bdd100k://" are loaded with key
            remapping (REV_KEYS); any other value is loaded as-is.
            Defaults to None.
    """

    def __init__(
        self,
        num_classes: int,
        basemodel: BaseModel | None = None,
        faster_rcnn_head: FasterRCNNHead | None = None,
        mask_head: MaskRCNNHead | None = None,
        rcnn_box_decoder: DeltaXYWHBBoxDecoder | None = None,
        no_overlap: bool = False,
        weights: None | str = None,
    ) -> None:
        """Creates an instance of the class."""
        super().__init__()
        self.basemodel = (
            ResNet(resnet_name="resnet50", pretrained=True, trainable_layers=3)
            if basemodel is None
            else basemodel
        )
        # FPN on the base model feature maps from index 2 onwards, with 256
        # output channels.
        self.fpn = FPN(self.basemodel.out_channels[2:], 256)

        if faster_rcnn_head is None:
            self.faster_rcnn_head = FasterRCNNHead(num_classes=num_classes)
        else:
            self.faster_rcnn_head = faster_rcnn_head

        if mask_head is None:
            self.mask_head = MaskRCNNHead(num_classes=num_classes)
        else:
            self.mask_head = mask_head

        # Post-processing: RoI head outputs -> detections, detections -> masks.
        self.transform_outs = RoI2Det(rcnn_box_decoder)
        self.det2mask = Det2Mask(no_overlap=no_overlap)

        if weights is not None:
            if weights == "mmdet":
                # Shorthand for the MMDetection R50-FPN 2x COCO checkpoint.
                weights = (
                    "mmdet://mask_rcnn/mask_rcnn_r50_fpn_2x_coco/"
                    "mask_rcnn_r50_fpn_2x_coco_bbox_mAP-0.392__segm_mAP-0.354_"
                    "20200505_003907-3e542a40.pth"
                )
            if weights.startswith(("mmdet://", "bdd100k://")):
                # These checkpoints use different parameter names; remap
                # them to this module's attribute names.
                load_model_checkpoint(self, weights, rev_keys=REV_KEYS)
            else:
                load_model_checkpoint(self, weights)

    def forward(
        self,
        images: torch.Tensor,
        input_hw: list[tuple[int, int]],
        boxes2d: None | list[torch.Tensor] = None,
        boxes2d_classes: None | list[torch.Tensor] = None,
        original_hw: None | list[tuple[int, int]] = None,
    ) -> MaskRCNNOut | MaskDetectionOut:
        """Forward pass.

        Dispatches to ``forward_train`` when the module is in training mode
        and to ``forward_test`` otherwise.

        Args:
            images (torch.Tensor): Input images.
            input_hw (list[tuple[int, int]]): Input image resolutions.
            boxes2d (None | list[torch.Tensor], optional): Bounding box
                labels. Required for training. Defaults to None.
            boxes2d_classes (None | list[torch.Tensor], optional): Class
                labels. Required for training. Defaults to None.
            original_hw (None | list[tuple[int, int]], optional): Original
                image resolutions (before padding and resizing). Required for
                testing. Defaults to None.

        Returns:
            MaskRCNNOut | MaskDetectionOut: Either raw model
            outputs (for training) or predicted outputs (for testing).
        """
        if self.training:
            assert boxes2d is not None and boxes2d_classes is not None
            return self.forward_train(
                images, input_hw, boxes2d, boxes2d_classes
            )
        assert original_hw is not None
        return self.forward_test(images, input_hw, original_hw)

    def forward_train(
        self,
        images: torch.Tensor,
        images_hw: list[tuple[int, int]],
        target_boxes: list[torch.Tensor],
        target_classes: list[torch.Tensor],
    ) -> MaskRCNNOut:
        """Forward training stage.

        Args:
            images (torch.Tensor): Input images.
            images_hw (list[tuple[int, int]]): Input image resolutions.
            target_boxes (list[torch.Tensor]): Ground-truth bounding boxes.
            target_classes (list[torch.Tensor]): Ground-truth class labels.

        Returns:
            MaskRCNNOut: Raw model outputs.
        """
        features = self.fpn(self.basemodel(images))
        outputs = self.faster_rcnn_head(
            features, images_hw, target_boxes, target_classes
        )
        assert outputs.sampled_proposals is not None
        assert outputs.sampled_targets is not None
        # Feed only the sampled proposals whose target label equals 1
        # (presumably the positive / foreground samples) to the mask head.
        pos_proposals = apply_mask(
            [torch.eq(label, 1) for label in outputs.sampled_targets.labels],
            outputs.sampled_proposals.boxes,
        )[0]
        mask_outs = self.mask_head(features, pos_proposals)
        return MaskRCNNOut(outputs, mask_outs)

    def forward_test(
        self,
        images: torch.Tensor,
        images_hw: list[tuple[int, int]],
        original_hw: list[tuple[int, int]],
    ) -> MaskDetectionOut:
        """Forward testing stage.

        Args:
            images (torch.Tensor): Input images.
            images_hw (list[tuple[int, int]]): Input image resolutions.
            original_hw (list[tuple[int, int]]): Original image resolutions
                (before padding and resizing).

        Returns:
            MaskDetectionOut: Predicted outputs, with boxes rescaled and
            clipped to ``original_hw``.
        """
        features = self.fpn(self.basemodel(images))
        outs = self.faster_rcnn_head(features, images_hw)
        boxes, scores, class_ids = self.transform_outs(
            *outs.roi, outs.proposals.boxes, images_hw
        )
        # The mask head pools features with boxes in input resolution, so it
        # must run before the boxes are rescaled.
        mask_outs = self.mask_head(features, boxes)
        # Build a new list rather than mutating the one handed to the mask
        # head; indexing keeps the original IndexError on length mismatch.
        boxes = [
            scale_and_clip_boxes(box, original_hw[i], images_hw[i])
            for i, box in enumerate(boxes)
        ]
        mask_preds = [m.sigmoid() for m in mask_outs.mask_pred]
        masks = self.det2mask(
            mask_preds, boxes, scores, class_ids, original_hw
        )
        return MaskDetectionOut(DetOut(boxes, scores, class_ids), masks)