Source code for vis4d.op.loss.orthogonal_transform_loss
"""Orthogonal Transform Loss."""
from __future__ import annotations
import torch
from .base import Loss
[docs]
class OrthogonalTransformRegularizationLoss(Loss):
    """Loss that punishes linear transformations that are not orthogonal.

    For every batched transformation ``X`` the Frobenius norm of
    ``X @ X^T - I`` is computed per batch element and averaged over the
    batch; the averages of all given transformations are summed. For square
    matrices this equals ``norm(X^T @ X - I)``.
    """

    def __call__(self, transforms: list[torch.Tensor]) -> torch.Tensor:
        """Type-annotated entry point; dispatches to ``forward`` with hooks.

        Args:
            transforms (list[torch.Tensor]): List with batched transformation
                matrices, e.g. ``[N, 3, 3]``, ``[N, x, x]``, ...

        Returns:
            torch.Tensor: Scalar loss value, see ``forward``.
        """
        return self._call_impl(transforms)

    # This method was originally published under the misspelled name
    # ``__call___``; keep that name as an alias so explicit callers still work.
    __call___ = __call__

    def forward(self, transforms: list[torch.Tensor]) -> torch.Tensor:
        """Calculates the loss.

        Calculates the deviation of ``X @ X^T`` from the identity matrix using
        the Frobenius norm ``norm(X @ X^T - I)`` for each transformation.

        Args:
            transforms (list[torch.Tensor]): List with batched transformation
                matrices, e.g. ``[N, 3, 3]``, ``[N, x, x]``, ...

        Returns:
            torch.Tensor: Sum over transformations of the batch-mean loss
                value ``mean(norm(X @ X^T - I))``. The result has the dtype and
                device of the inputs; an empty list yields ``tensor(0.0)``.
        """
        loss: torch.Tensor | None = None
        for trans in transforms:
            # Build the identity per call with the input's dtype/device: this
            # avoids device/dtype mismatches and keeps helper tensors out of
            # the module's state_dict (lazily registered buffers broke strict
            # checkpoint loading into freshly constructed modules).
            identity = torch.eye(
                trans.shape[-2], dtype=trans.dtype, device=trans.device
            )
            deviation = trans @ trans.transpose(-2, -1) - identity
            # Frobenius norm over the last two dims, averaged over the batch.
            term = torch.linalg.matrix_norm(deviation, ord="fro").mean()
            # Out-of-place accumulation keeps the inputs' dtype and device.
            loss = term if loss is None else loss + term
        if loss is None:
            # No transformations given: nothing to regularize.
            return torch.tensor(0.0)
        return loss