Source code for vis4d.zoo.base.models.mask_rcnn

"""Mask RCNN base model config."""

from __future__ import annotations

from ml_collections import ConfigDict, FieldReference

from vis4d.config import class_config
from vis4d.data.const import CommonKeys as K
from vis4d.engine.connectors import (
    LossConnector,
    data_key,
    pred_key,
    remap_pred_keys,
)
from vis4d.engine.loss_module import LossModule
from vis4d.model.detect.mask_rcnn import MaskRCNN
from vis4d.op.box.anchor import AnchorGenerator
from vis4d.op.box.matchers import MaxIoUMatcher
from vis4d.op.box.samplers import RandomSampler
from vis4d.op.detect.faster_rcnn import FasterRCNNHead
from vis4d.op.detect.mask_rcnn import (
    MaskRCNNHead,
    MaskRCNNHeadLoss,
    SampledMaskLoss,
    positive_mask_sampler,
)
from vis4d.op.detect.rcnn import RCNNHead, RCNNLoss
from vis4d.op.detect.rpn import RPNLoss
from vis4d.zoo.base import get_callable_cfg
from vis4d.zoo.base.models.faster_rcnn import (
    CONN_ROI_LOSS_2D as _CONN_ROI_LOSS_2D,
)
from vis4d.zoo.base.models.faster_rcnn import (
    CONN_RPN_LOSS_2D as _CONN_RPN_LOSS_2D,
)
from vis4d.zoo.base.models.faster_rcnn import (
    get_default_rcnn_box_codec_cfg,
    get_default_rpn_box_codec_cfg,
)

# Data connectors
CONN_MASK_HEAD_LOSS_2D = {
    "mask_preds": pred_key("masks.mask_pred"),
    "target_masks": data_key(K.instance_masks),
    "sampled_target_indices": pred_key("boxes.sampled_target_indices"),
    "sampled_targets": pred_key("boxes.sampled_targets"),
    "sampled_proposals": pred_key("boxes.sampled_proposals"),
}

CONN_RPN_LOSS_2D = remap_pred_keys(_CONN_RPN_LOSS_2D, "boxes")

CONN_ROI_LOSS_2D = remap_pred_keys(_CONN_ROI_LOSS_2D, "boxes")


[docs] def get_mask_rcnn_cfg( num_classes: FieldReference | int, basemodel: ConfigDict, no_overlap: bool = False, weights: str | None = None, ) -> tuple[ConfigDict, ConfigDict]: """Return default config for mask_rcnn model and loss. This is an example for setting every component of the model and loss. Everything is the same as the default args. Args: num_classes (FieldReference | int): Number of classes. basemodel (ConfigDict): Base model config. no_overlap (bool, optional): Whether to remove overlapping pixels between masks. Defaults to False. weights (str | None, optional): Weights to load. Defaults to None. """ ###################################################### ## MODEL ## ###################################################### anchor_generator = class_config( AnchorGenerator, scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64], ) rpn_box_encoder, rpn_box_decoder = get_default_rpn_box_codec_cfg() rcnn_box_encoder, rcnn_box_decoder = get_default_rcnn_box_codec_cfg() box_matcher = class_config( MaxIoUMatcher, thresholds=[0.5], labels=[0, 1], allow_low_quality_matches=False, ) box_sampler = class_config( RandomSampler, batch_size=512, positive_fraction=0.25 ) roi_head = class_config(RCNNHead, num_classes=num_classes) mask_head = class_config(MaskRCNNHead, num_classes=num_classes) faster_rcnn_head = class_config( FasterRCNNHead, num_classes=num_classes, anchor_generator=anchor_generator, rpn_box_decoder=rpn_box_decoder, box_matcher=box_matcher, box_sampler=box_sampler, roi_head=roi_head, ) model = class_config( MaskRCNN, num_classes=num_classes, basemodel=basemodel, faster_rcnn_head=faster_rcnn_head, mask_head=mask_head, rcnn_box_decoder=rcnn_box_decoder, no_overlap=no_overlap, weights=weights, ) ###################################################### ## LOSS ## ###################################################### rpn_loss = class_config( RPNLoss, anchor_generator=anchor_generator, box_encoder=rpn_box_encoder, ) rcnn_loss = class_config( RCNNLoss, box_encoder=rcnn_box_encoder, num_classes=num_classes ) mask_loss = class_config( SampledMaskLoss, mask_sampler=get_callable_cfg(positive_mask_sampler), loss=class_config(MaskRCNNHeadLoss, num_classes=num_classes), ) loss = class_config( LossModule, losses=[ { "loss": rpn_loss, "connector": class_config( LossConnector, key_mapping=CONN_RPN_LOSS_2D ), }, { "loss": rcnn_loss, "connector": class_config( LossConnector, key_mapping=CONN_ROI_LOSS_2D ), }, { "loss": mask_loss, "connector": class_config( LossConnector, key_mapping=CONN_MASK_HEAD_LOSS_2D ), }, ], ) return model, loss