vis4d.model.detect.retinanet

RetinaNet model implementation and runtime.

Classes

RetinaNet(num_classes[, weights])

RetinaNet wrapper class for checkpointing etc.

RetinaNetLoss(anchor_generator, box_encoder, ...)

RetinaNet Loss.

class RetinaNet(num_classes, weights=None)

RetinaNet wrapper class for checkpointing etc.

Creates an instance of the class.

Parameters:
  • num_classes (int) – Number of classes.

  • weights (None | str, optional) – Weights to load into the model. If set to “mmdet”, MMDetection pre-trained weights are loaded. Defaults to None.

forward(images, input_hw=None, original_hw=None)

Forward pass.

Parameters:
  • images (Tensor) – Input images.

  • input_hw (None | list[tuple[int, int]], optional) – Input image resolutions. Defaults to None.

  • original_hw (None | list[tuple[int, int]], optional) – Original image resolutions (before padding and resizing). Required for testing. Defaults to None.

Returns:

Either raw model outputs (for training) or predicted outputs (for testing).

Return type:

RetinaNetOut | DetOut
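
The train/test routing described for forward can be pictured as a small dispatcher. This is a minimal sketch, assuming the mode is selected by a flag like torch.nn.Module.training; the class DispatchSketch is hypothetical and does not reproduce vis4d's implementation, and the returned tuples only stand in for RetinaNetOut and DetOut.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DispatchSketch:
    """Hypothetical stand-in for the train/test routing of RetinaNet.forward."""

    # Assumption: the mode is read from a flag like torch.nn.Module.training.
    training: bool = True

    def forward_train(self, images: list) -> tuple:
        # Placeholder for raw head outputs (RetinaNetOut in the docs above).
        return ("raw", len(images))

    def forward_test(self, images: list, images_hw: list, original_hw: list) -> tuple:
        # Placeholder for post-processed detections (DetOut in the docs above).
        return ("det", len(images))

    def forward(
        self,
        images: list,
        input_hw: Optional[list] = None,
        original_hw: Optional[list] = None,
    ) -> tuple:
        if self.training:
            # Training only needs the image batch.
            return self.forward_train(images)
        # Testing needs both resolution lists to map boxes back to the original images.
        if input_hw is None or original_hw is None:
            raise ValueError("input_hw and original_hw are required for testing")
        return self.forward_test(images, input_hw, original_hw)
```

In this sketch, calling forward without resolutions while training is False raises an error, which mirrors the note that original_hw is required for testing.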

forward_test(images, images_hw, original_hw)

Forward testing stage.

Parameters:
  • images (Tensor) – Input images.

  • images_hw (list[tuple[int, int]]) – Input image resolutions.

  • original_hw (list[tuple[int, int]]) – Original image resolutions (before padding and resizing).

Returns:

Predicted outputs.

Return type:

DetOut
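
Because forward_test receives both the network-input and the original resolutions, one step of post-processing is presumably mapping predicted boxes back to original-image pixels. The sketch below shows that step under stated assumptions: boxes are (x1, y1, x2, y2), resolutions are (height, width), and padding was added only on the bottom/right so it does not shift coordinates. The function rescale_boxes is hypothetical, not part of vis4d.

```python
def rescale_boxes(
    boxes: list[list[float]],
    input_hw: tuple[int, int],
    original_hw: tuple[int, int],
) -> list[list[float]]:
    """Map (x1, y1, x2, y2) boxes from network-input pixels to original-image pixels.

    Assumes padding was added only on the bottom/right, so only the resize
    has to be undone.
    """
    in_h, in_w = input_hw
    orig_h, orig_w = original_hw
    scale_x, scale_y = orig_w / in_w, orig_h / in_h
    rescaled = []
    for x1, y1, x2, y2 in boxes:
        # Undo the resize, then clip to the original image bounds.
        rescaled.append([
            min(max(x1 * scale_x, 0.0), orig_w),
            min(max(y1 * scale_y, 0.0), orig_h),
            min(max(x2 * scale_x, 0.0), orig_w),
            min(max(y2 * scale_y, 0.0), orig_h),
        ])
    return rescaled


# Example: a 200x400 (h x w) image resized to 100x200 before inference.
print(rescale_boxes([[10, 20, 30, 40]], input_hw=(100, 200), original_hw=(200, 400)))
# [[20.0, 40.0, 60.0, 80.0]]
```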

forward_train(images)

Forward training stage.

Parameters:
  • images (Tensor) – Input images.

Returns:

Raw model outputs.

Return type:

RetinaNetOut
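
To give a feel for what "raw model outputs" means in a RetinaNet-style head, the sketch below computes per-level tensor shapes. It assumes, as in the original RetinaNet design, pyramid levels P3 to P7 (strides 8 to 128), 9 anchors per location, one logit per anchor and class, and 4 box deltas per anchor; the exact fields and defaults of vis4d's RetinaNetOut may differ, and head_output_shapes is hypothetical.

```python
import math

# Assumed defaults from the RetinaNet paper: pyramid levels P3-P7 (strides 8..128)
# and 9 anchors (3 scales x 3 aspect ratios) per feature-map location.
STRIDES = (8, 16, 32, 64, 128)
ANCHORS_PER_LOCATION = 9


def head_output_shapes(batch: int, image_hw: tuple[int, int], num_classes: int):
    """Return (cls_shape, box_shape) in NCHW layout for every pyramid level."""
    h, w = image_hw
    shapes = []
    for stride in STRIDES:
        # Assumes each stride-2 stage rounds up, so a level is ceil(size / stride).
        fh, fw = math.ceil(h / stride), math.ceil(w / stride)
        cls_shape = (batch, ANCHORS_PER_LOCATION * num_classes, fh, fw)  # logit per anchor and class
        box_shape = (batch, ANCHORS_PER_LOCATION * 4, fh, fw)  # 4 box deltas per anchor
        shapes.append((cls_shape, box_shape))
    return shapes


for cls_shape, box_shape in head_output_shapes(2, (512, 512), num_classes=80):
    print(cls_shape, box_shape)
```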

class RetinaNetLoss(anchor_generator, box_encoder, box_matcher, box_sampler)

RetinaNet Loss.

Creates an instance of the class.

Parameters:
  • anchor_generator (AnchorGenerator) – Anchor generator for the RetinaNet head.

  • box_encoder (BoxEncoder2D) – Bounding box encoder.

  • box_matcher (BaseMatcher) – Bounding box matcher.

  • box_sampler (BaseSampler) – Bounding box sampler.
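
The four constructor arguments correspond to the usual anchor-based target-assignment pipeline: generate anchors, match them to ground-truth boxes, encode box regression targets, and select which anchors contribute to the loss. The plain-Python sketch below illustrates that pipeline under stated assumptions: the IoU thresholds (positive at 0.5 or above, negative below 0.4, ignored in between) follow the RetinaNet paper, which also uses all non-ignored anchors rather than sampling a subset. All function names (make_anchors, iou, match, encode, build_targets) are hypothetical and are not the APIs of vis4d's AnchorGenerator, BaseMatcher, BoxEncoder2D, or BaseSampler.

```python
import math

Box = tuple[float, float, float, float]  # (x1, y1, x2, y2) in pixels


def make_anchors(feat_hw: tuple[int, int], stride: int, size: float = 32.0,
                 ratios: tuple[float, ...] = (0.5, 1.0, 2.0)) -> list[Box]:
    """Hypothetical single-scale anchor generator centred on each feature cell."""
    fh, fw = feat_hw
    anchors = []
    for y in range(fh):
        for x in range(fw):
            cx, cy = (x + 0.5) * stride, (y + 0.5) * stride
            for r in ratios:  # r = height / width, area kept at size**2
                w, h = size / math.sqrt(r), size * math.sqrt(r)
                anchors.append((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2))
    return anchors


def iou(a: Box, b: Box) -> float:
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def match(anchors: list[Box], gt_boxes: list[Box],
          pos_thr: float = 0.5, neg_thr: float = 0.4) -> list[int]:
    """Max-IoU matching: gt index if positive, -1 for background, -2 for ignored."""
    labels = []
    for anchor in anchors:
        ious = [iou(anchor, gt) for gt in gt_boxes]
        best = max(range(len(ious)), key=ious.__getitem__, default=-1)
        best_iou = ious[best] if best >= 0 else 0.0
        if best_iou >= pos_thr:
            labels.append(best)
        elif best_iou < neg_thr:
            labels.append(-1)
        else:
            labels.append(-2)  # between thresholds: excluded from the loss
    return labels


def encode(anchor: Box, gt: Box) -> tuple[float, float, float, float]:
    """(dx, dy, dw, dh) deltas of a ground-truth box relative to an anchor."""
    aw, ah = anchor[2] - anchor[0], anchor[3] - anchor[1]
    gw, gh = gt[2] - gt[0], gt[3] - gt[1]
    dx = (gt[0] + gw / 2 - (anchor[0] + aw / 2)) / aw
    dy = (gt[1] + gh / 2 - (anchor[1] + ah / 2)) / ah
    return (dx, dy, math.log(gw / aw), math.log(gh / ah))


def build_targets(anchors: list[Box], gt_boxes: list[Box], gt_classes: list[int]):
    """No real sampling: keep every non-ignored anchor, as in the RetinaNet paper."""
    targets = []
    for anchor, m in zip(anchors, match(anchors, gt_boxes)):
        if m == -2:
            continue
        if m == -1:
            targets.append((anchor, -1, None))  # background: classification loss only
        else:
            targets.append((anchor, gt_classes[m], encode(anchor, gt_boxes[m])))
    return targets


anchors = make_anchors((2, 2), stride=16)
for anchor, cls, deltas in build_targets(anchors, [(-8.0, -8.0, 24.0, 24.0)], [3]):
    print(cls, deltas)
```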

forward(outputs, images_hw, target_boxes, target_classes)

Computes the RetinaNet losses from the raw model outputs and the targets.

Parameters:
  • outputs (RetinaNetOut) – Raw model outputs.

  • images_hw (list[tuple[int, int]]) – Input image resolutions.

  • target_boxes (list[Tensor]) – Ground-truth bounding boxes, one tensor per image.

  • target_classes (list[Tensor]) – Ground-truth class labels, one tensor per image.

Returns:

Dictionary of model losses.

Return type:

LossesType
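
A minimal sketch of what such a loss dictionary might contain, assuming the standard RetinaNet objective: sigmoid focal loss (alpha = 0.25, gamma = 2) on every (anchor, class) logit, an L1 regression loss on the box deltas of positive anchors, and both terms normalised by the number of positive anchors. The key names loss_cls and loss_bbox and the functions below are hypothetical; the documentation above only states that LossesType is a dictionary of losses.

```python
import math


def _log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)); avoids log(0) for large |x|.
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def sigmoid_focal_loss(logit: float, target: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Focal loss for one (anchor, class) logit with binary target 0 or 1."""
    log_p_t = _log_sigmoid(logit) if target == 1 else _log_sigmoid(-logit)
    p_t = math.exp(log_p_t)
    alpha_t = alpha if target == 1 else 1.0 - alpha
    # FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t): easy examples are down-weighted.
    return -alpha_t * (1.0 - p_t) ** gamma * log_p_t


def retinanet_losses(cls_logits, cls_targets, box_preds, box_targets, num_pos):
    """Hypothetical loss dictionary; both terms are normalised by the positive count."""
    norm = max(num_pos, 1)
    loss_cls = sum(
        sigmoid_focal_loss(logit, tgt) for logit, tgt in zip(cls_logits, cls_targets)
    ) / norm
    # L1 on box deltas of positive anchors (some implementations use smooth L1 instead).
    loss_bbox = sum(
        abs(p - t) for pred, tgt in zip(box_preds, box_targets) for p, t in zip(pred, tgt)
    ) / norm
    return {"loss_cls": loss_cls, "loss_bbox": loss_bbox}


losses = retinanet_losses(
    cls_logits=[2.0, -3.0, 0.0],  # flattened (anchor, class) logits
    cls_targets=[1, 0, 0],
    box_preds=[[0.1, 0.0, -0.2, 0.0]],  # deltas for the single positive anchor
    box_targets=[[0.0, 0.0, 0.0, 0.0]],
    num_pos=1,
)
print(losses)
```

Normalising by the number of positives rather than the number of anchors keeps the classification term on a stable scale, since the vast majority of anchors are easy negatives whose focal loss is close to zero.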