vis4d.engine.loss_module

Loss module that maps input keys to loss functions and controls the loss weights.

Classes

LossDefinition

Loss definition.

LossModule(losses[, exclude_attributes])

Loss module maps input keys and combines losses with weights.

class LossDefinition

Loss definition.

loss

Loss function to use.

Type:

Loss | nn.Module

connector

Connector to use for the loss.

Type:

LossConnector

weight

Weight to use for the loss.

Type:

float | dict[str, float], optional

name

Name to use for the loss.

Type:

str, optional
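The `weight` field accepts either a single float or a per-key dictionary. The sketch below shows one plausible reading of that field. It assumes that a float weight scales every value a loss returns, and that a dictionary weight scales each named component of a dict-valued loss individually. The helper `apply_weight` is hypothetical and not part of vis4d, and plain floats stand in for tensors.

```python
from typing import Dict, Union

LossValue = Union[float, Dict[str, float]]
Weight = Union[float, Dict[str, float]]


def apply_weight(name: str, value: LossValue, weight: Weight = 1.0) -> Dict[str, float]:
    """Hypothetical helper: scale one loss output and return it as a flat dict.

    Assumption: a float weight scales every component, while a dict weight
    must provide a factor for each component key returned by the loss.
    """
    # A scalar loss is stored under the loss name; a dict loss keeps its keys.
    components = value if isinstance(value, dict) else {name: value}
    if isinstance(weight, dict):
        # Per-component weighting (raises KeyError if a factor is missing).
        return {key: val * weight[key] for key, val in components.items()}
    # One factor for all components; defaults to 1.0 as in LossModule.
    return {key: val * weight for key, val in components.items()}
```

For example, `apply_weight("mse", 2.0, 0.5)` yields `{"mse": 1.0}`, while a dict-valued loss such as `{"cls": 1.0, "box": 4.0}` with weight `{"cls": 1.0, "box": 0.25}` yields `{"cls": 1.0, "box": 1.0}`.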

class LossModule(losses, exclude_attributes=None)

Loss module maps input keys and combines losses with weights.

This module combines multiple losses with weights. Each loss value is multiplied by its corresponding weight, and the weighted values are returned as a dictionary.

Creates an instance of the class.

Each loss is called through its connector, which supplies keyword arguments whose names must match the parameters of the loss function. If no weight is given, it defaults to 1.0.
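A connector turns the model output and the batch into the keyword arguments of one loss function. The plain-Python sketch below mimics that mapping under stated assumptions: `pred_key` and `data_key` here are hypothetical stand-ins that only record where a value comes from, and `connect` is a hypothetical replacement for `LossConnector`, not its actual implementation.

```python
from typing import Any, Dict, Tuple

# Hypothetical source marker: (origin, key), origin is "prediction" or "data".
Source = Tuple[str, str]


def pred_key(key: str) -> Source:
    """Hypothetical stand-in: read `key` from the model output."""
    return ("prediction", key)


def data_key(key: str) -> Source:
    """Hypothetical stand-in: read `key` from the batch."""
    return ("data", key)


def connect(mapping: Dict[str, Source], output: Dict[str, Any], batch: Dict[str, Any]) -> Dict[str, Any]:
    """Build loss kwargs: each kwarg name maps to a value from output or batch."""
    kwargs: Dict[str, Any] = {}
    for arg_name, (origin, key) in mapping.items():
        kwargs[arg_name] = output[key] if origin == "prediction" else batch[key]
    return kwargs


def squared_error(input: float, target: float) -> float:
    """Toy loss whose parameter names match the connector keys."""
    return (input - target) ** 2


mapping = {"input": pred_key("input"), "target": data_key("target")}
kwargs = connect(mapping, output={"input": 3.0}, batch={"target": 1.0})
loss_value = squared_error(**kwargs)
```

Because the loss is invoked as `loss(**kwargs)`, a connector key that does not match a parameter name fails immediately with a `TypeError`, which is why the keys must mirror the loss signature.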

Parameters:
  • losses (list[LossDefinition]) – List of loss definitions.

  • exclude_attributes (list[str] | None) – List of attributes returned by the losses that should be excluded from the total loss computation. Use it to log metrics that should not be optimised. Defaults to None.
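The effect of `exclude_attributes` can be pictured as a filter applied only when the total is formed. The sketch below assumes that the total loss is the sum of all weighted entries whose key is not excluded, while every entry, excluded or not, is still reported for logging. `split_total` is a hypothetical helper operating on plain floats.

```python
from typing import Dict, Iterable, Optional, Tuple


def split_total(
    weighted: Dict[str, float], exclude_attributes: Optional[Iterable[str]] = None
) -> Tuple[float, Dict[str, float]]:
    """Hypothetical helper: sum non-excluded losses, keep all for logging."""
    excluded = set(exclude_attributes or ())
    # Only non-excluded entries contribute to the optimised total.
    total = sum(value for key, value in weighted.items() if key not in excluded)
    # All entries stay in the metrics dict, so excluded ones can still be logged.
    metrics = dict(weighted)
    return total, metrics


total, metrics = split_total(
    {"mse": 0.5, "l1": 0.25, "accuracy": 0.9}, exclude_attributes=["accuracy"]
)
```

Here `total` is 0.75, and `metrics` still contains `"accuracy"` so it can be logged without being optimised.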

Example

>>> from torch import nn
>>> from vis4d.engine.connectors import LossConnector, data_key, pred_key
>>> from vis4d.engine.loss_module import LossModule
>>> loss = LossModule(
...     [
...         {
...             "loss": nn.MSELoss(),
...             "weight": 0.7,
...             "connector": LossConnector(
...                 {
...                     "input": pred_key("input"),
...                     "target": data_key("target"),
...                 }
...             ),
...         },
...         {
...             "loss": nn.L1Loss(),
...             "weight": 0.3,
...             "connector": LossConnector(
...                 {
...                     "input": pred_key("input"),
...                     "target": data_key("target"),
...                 }
...             ),
...         },
...     ]
... )
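Assuming the total loss is the weighted sum of the individual terms, the example above computes 0.7 · MSE + 0.3 · L1. The plain-Python sketch below reproduces that arithmetic on a tiny, made-up prediction and target, using mean reduction as in the PyTorch defaults for `nn.MSELoss` and `nn.L1Loss`. The names `mse`, `l1`, `pred`, and `target` are illustrative only.

```python
from typing import List


def mse(pred: List[float], target: List[float]) -> float:
    """Mean squared error with mean reduction (nn.MSELoss default)."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def l1(pred: List[float], target: List[float]) -> float:
    """Mean absolute error with mean reduction (nn.L1Loss default)."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


pred = [1.0, 2.0, 4.0]
target = [1.0, 3.0, 2.0]

# Scale each term by the weight from its loss definition, then add them up.
weighted = {"mse": 0.7 * mse(pred, target), "l1": 0.3 * l1(pred, target)}
total_loss = sum(weighted.values())
```

With these inputs the MSE is 5/3 and the L1 loss is 1.0, so the total is 0.7 · 5/3 + 0.3 ≈ 1.467.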

forward(output, batch)

Forward of loss module.

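This function calls every loss function with the inputs produced by its connector, multiplies each loss value by its corresponding weight, and returns the total loss together with a dictionary containing the loss values.

If two losses have the same name, two underscores are appended to the duplicate name so that every key in the dictionary stays unique.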
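The renaming rule for duplicate names can be sketched as follows. The sketch assumes that a loss whose name is already taken gets `__` appended, repeatedly if needed, until the key is unique. `unique_name` is a hypothetical helper, not the vis4d implementation.

```python
from typing import Dict


def unique_name(name: str, existing: Dict[str, float]) -> str:
    """Hypothetical helper: append '__' until `name` no longer collides."""
    while name in existing:
        name += "__"
    return name


losses: Dict[str, float] = {}
for name, value in [("l1", 0.5), ("l1", 0.2), ("l1", 0.1)]:
    # Each duplicate receives one more '__' suffix than the previous one.
    losses[unique_name(name, losses)] = value
```

After the loop, `losses` holds the keys `"l1"`, `"l1__"`, and `"l1____"`, so no loss value is silently overwritten.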

Parameters:
  • output (DictData) – Output of the model.

  • batch (DictData) – Batch data.

Returns:
  • total_loss – The total loss value.

  • metrics – The metrics dictionary.

Return type:

tuple