vis4d.model.detect.retinanet
RetinaNet model implementation and runtime.
Classes
| RetinaNet | RetinaNet wrapper class for checkpointing etc. |
| RetinaNetLoss | RetinaNet Loss. |
- class RetinaNet(num_classes, weights=None)
RetinaNet wrapper class for checkpointing etc.
Creates an instance of the class.
- Parameters:
num_classes (int) – Number of classes.
weights (None | str, optional) – Weights to load for the model. If set to “mmdet”, MMDetection pre-trained weights are loaded. Defaults to None.
- forward(images, input_hw=None, original_hw=None)
Forward pass.
- Parameters:
images (Tensor) – Input images.
input_hw (None | list[tuple[int, int]], optional) – Input image resolutions. Defaults to None.
original_hw (None | list[tuple[int, int]], optional) – Original image resolutions (before padding and resizing). Required for testing. Defaults to None.
- Returns:
Either raw model outputs (for training) or predicted outputs (for testing).
- Return type:
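The training/testing switch described above can be pictured with the minimal sketch below. It assumes, as with `torch.nn.Module`, that the mode is read from a `training` attribute, and that both `input_hw` and `original_hw` must be supplied at test time because `forward_test` takes both. `HypotheticalDetector` and its placeholder outputs are illustrative only and are not part of vis4d.

```python
from dataclasses import dataclass


@dataclass
class HypotheticalDetector:
    """Stand-in whose forward() switches on a training flag (assumption)."""

    training: bool = True

    def raw_outputs(self, images):
        # Placeholder for the raw head outputs that the loss consumes.
        return {"raw": images}

    def forward_test(self, images, images_hw, original_hw):
        # Placeholder for decoding predictions and mapping them back
        # to the original image resolution.
        return {"pred": images, "images_hw": images_hw, "original_hw": original_hw}

    def forward(self, images, input_hw=None, original_hw=None):
        if self.training:
            # Training: return raw outputs; resolutions are not needed.
            return self.raw_outputs(images)
        # Testing: both resolutions are needed to rescale predictions.
        if input_hw is None or original_hw is None:
            raise ValueError("input_hw and original_hw are required for testing")
        return self.forward_test(images, input_hw, original_hw)


model = HypotheticalDetector(training=False)
print(model.forward("imgs", [(512, 512)], [(480, 640)])["pred"])
```

Keeping the mode decision inside `forward` lets a single call site serve both the training loop and evaluation.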
- forward_test(images, images_hw, original_hw)
Forward pass for the testing stage.
- Parameters:
images (Tensor) – Input images.
images_hw (list[tuple[int, int]]) – Input image resolutions.
original_hw (list[tuple[int, int]]) – Original image resolutions (before padding and resizing).
- Returns:
Predicted outputs.
- Return type:
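`forward_test` receives both resolutions so that predicted boxes can be mapped back onto the original image. Below is a minimal sketch of that mapping. It assumes boxes are `(x1, y1, x2, y2)` in pixels, that `images_hw` is the size after resizing but before padding, and that padding is only added on the bottom and right, so it introduces no offset. `rescale_boxes` and `rescale_batch` are hypothetical helpers, not vis4d functions.

```python
def rescale_boxes(boxes, input_hw, original_hw):
    """Map (x1, y1, x2, y2) boxes from resized-input pixels to original pixels.

    Assumes bottom/right padding only and independent per-axis resizing.
    """
    in_h, in_w = input_hw
    orig_h, orig_w = original_hw
    sx = orig_w / in_w  # horizontal scale factor
    sy = orig_h / in_h  # vertical scale factor
    return [(x1 * sx, y1 * sy, x2 * sx, y2 * sy) for x1, y1, x2, y2 in boxes]


def rescale_batch(batch_boxes, images_hw, original_hw):
    """Apply rescale_boxes per image, mirroring the per-image list arguments."""
    return [
        rescale_boxes(boxes, in_hw, orig_hw)
        for boxes, in_hw, orig_hw in zip(batch_boxes, images_hw, original_hw)
    ]


# A 480x640 image downscaled to 240x320: every coordinate doubles.
print(rescale_boxes([(10, 20, 30, 40)], (240, 320), (480, 640)))
```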
- class RetinaNetLoss(anchor_generator, box_encoder, box_matcher, box_sampler)
RetinaNet Loss.
Creates an instance of the class.
- Parameters:
anchor_generator (AnchorGenerator) – Anchor generator for the RetinaNet head.
box_encoder (BoxEncoder2D) – Bounding box encoder.
box_matcher (BaseMatcher) – Bounding box matcher.
box_sampler (BaseSampler) – Bounding box sampler.
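In the original RetinaNet formulation, an anchor whose best IoU with a ground-truth box is at least 0.5 becomes a positive, an anchor with best IoU below 0.4 becomes background, and the rest are ignored. Regression targets are the standard center/size deltas relative to the matched anchor, and all anchors are used rather than a subsampled set. The sketch below shows that assignment in plain Python. Whether the `box_matcher`, `box_encoder`, and `box_sampler` passed here are configured this way is an assumption; `iou`, `match`, and `encode` are hypothetical helpers, and the per-coordinate normalization that many encoders apply is omitted.

```python
import math


def iou(a, b):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def match(anchors, gt_boxes, pos_thr=0.5, neg_thr=0.4):
    """Label anchors: matched GT index (>= 0), -1 background, -2 ignored."""
    labels = []
    for anchor in anchors:
        ious = [iou(anchor, gt) for gt in gt_boxes]
        best = max(range(len(ious)), key=ious.__getitem__) if ious else None
        best_iou = ious[best] if ious else 0.0
        if best_iou >= pos_thr:
            labels.append(best)  # positive: regress towards this GT box
        elif best_iou < neg_thr:
            labels.append(-1)  # background
        else:
            labels.append(-2)  # in-between IoU: excluded from the loss
    return labels


def encode(anchor, gt):
    """(dx, dy, dw, dh) deltas of a GT box relative to an anchor."""
    aw, ah = anchor[2] - anchor[0], anchor[3] - anchor[1]
    ax, ay = anchor[0] + 0.5 * aw, anchor[1] + 0.5 * ah
    gw, gh = gt[2] - gt[0], gt[3] - gt[1]
    gx, gy = gt[0] + 0.5 * gw, gt[1] + 0.5 * gh
    return ((gx - ax) / aw, (gy - ay) / ah, math.log(gw / aw), math.log(gh / ah))


anchors = [(0, 0, 10, 10), (100, 100, 110, 110), (0, 0, 10, 22)]
print(match(anchors, [(0, 0, 10, 10)]))
```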
- forward(outputs, images_hw, target_boxes, target_classes)
Forward pass of the loss function.
- Parameters:
outputs (RetinaNetOut) – Raw model outputs.
images_hw (list[tuple[int, int]]) – Input image resolutions.
target_boxes (list[Tensor]) – Ground-truth bounding boxes, one tensor per image.
target_classes (list[Tensor]) – Ground-truth class labels, one tensor per image.
- Returns:
Dictionary of model losses.
- Return type:
LossesType
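The classification term that characterizes RetinaNet is the focal loss, FL(p_t) = -α_t (1 - p_t)^γ log(p_t), which down-weights easy, well-classified anchors so that the many background anchors do not dominate. Below is a minimal per-logit sketch using the defaults from the RetinaNet paper, α = 0.25 and γ = 2. It is not vis4d's implementation: `sigmoid_focal_loss` is a hypothetical helper, and the box-regression term and the normalization by the number of positive anchors are left out.

```python
import math


def _log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def sigmoid_focal_loss(logit, target, alpha=0.25, gamma=2.0):
    """Focal loss for a single per-class logit; target is 0 or 1."""
    # 1 - sigmoid(x) == sigmoid(-x), so p_t is sigmoid of a signed logit.
    z = logit if target == 1 else -logit
    log_pt = _log_sigmoid(z)
    p_t = math.exp(log_pt)
    alpha_t = alpha if target == 1 else 1.0 - alpha
    # The (1 - p_t) ** gamma factor shrinks the loss of confident predictions.
    return -alpha_t * (1.0 - p_t) ** gamma * log_pt


# An easy negative contributes almost nothing compared with a hard positive.
print(sigmoid_focal_loss(-5.0, 0), sigmoid_focal_loss(-5.0, 1))
```

With γ = 0 and α = 0.5 the expression reduces to half the binary cross-entropy, which is a quick sanity check on the formula.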