# Source code for vis4d.op.geometry.transform
"""Vis4D geometric transformation functions."""
import torch
from torch import Tensor
def transform_points(points: Tensor, transform: Tensor) -> Tensor:
    """Applies transform to points.

    Points are treated as row vectors: each point ``p`` is extended to the
    homogeneous coordinate ``[p, 1]`` and multiplied by ``transform^T``. The
    first D entries of the result are returned (no perspective division).

    Args:
        points (Tensor): points of shape (N, D) or (B, N, D).
        transform (Tensor): transforms of shape (D+1, D+1) or (B, D+1, D+1).
            For (N, D) points a batched transform is accepted only if B == 1.

    Returns:
        Tensor: (N, D) / (B, N, D) transformed points.

    Raises:
        ValueError: Either points or transform have incorrect shape
    """
    if points.dim() == 2:
        if transform.dim() == 3:
            assert (
                transform.shape[0] == 1
            ), "Got multiple transforms for single point set!"
            transform = transform.squeeze(0)
        elif transform.dim() != 2:
            # Previously a 1D / 4D transform slipped through here and produced
            # silently wrong results instead of the documented ValueError.
            raise ValueError(f"Shape of transform invalid: {transform.shape}")
    elif points.dim() == 3:
        if transform.dim() not in (2, 3):
            raise ValueError(f"Shape of transform invalid: {transform.shape}")
    else:
        raise ValueError(f"Shape of input points invalid: {points.shape}")

    hom_coords = torch.cat([points, torch.ones_like(points[..., 0:1])], -1)
    # transpose(-1, -2) covers both the single (D+1, D+1) and the batched
    # (B, D+1, D+1) case; matmul broadcasts a single transform over the batch.
    # (Tensor.T is deprecated for tensors with ndim != 2.)
    points_transformed = hom_coords @ transform.transpose(-1, -2)
    return points_transformed[..., : points.shape[-1]]
def inverse_pinhole(intrinsic_matrix: Tensor) -> Tensor:
    """Calculate inverse of pinhole projection matrix.

    Uses the closed-form inverse of an upper-triangular intrinsic matrix
    ``[[fx, s, cx], [0, fy, cy], [0, 0, 1]]``:
    ``[[1/fx, -s/(fx*fy), -(cx - s*cy/fy)/fx], [0, 1/fy, -cy/fy], [0, 0, 1]]``.
    Entries other than (0,0), (0,1), (0,2), (1,1), (1,2) are copied from the
    input, so the input is assumed to have that upper-triangular layout.

    Args:
        intrinsic_matrix (Tensor): [..., 3, 3] intrinsics or single [3, 3]
            intrinsics. The input is not modified.

    Returns:
        Tensor: Inverse of input intrinsics, same shape as the input.
    """
    inv = intrinsic_matrix.clone()
    # Read every parameter from the untouched input so that the order of the
    # in-place writes below does not matter.
    fx_inv = 1.0 / intrinsic_matrix[..., 0, 0]
    fy_inv = 1.0 / intrinsic_matrix[..., 1, 1]
    skew = intrinsic_matrix[..., 0, 1]
    cx = intrinsic_matrix[..., 0, 2]
    cy = intrinsic_matrix[..., 1, 2]

    inv[..., 0, 0] = fx_inv
    inv[..., 1, 1] = fy_inv
    inv[..., 1, 2] = -cy * fy_inv
    # Skew terms; for s == 0 these reduce to 0 and -cx / fx respectively.
    inv[..., 0, 1] = -skew * fx_inv * fy_inv
    inv[..., 0, 2] = -(cx + skew * inv[..., 1, 2]) * fx_inv
    return inv
def inverse_rigid_transform(transformation: Tensor) -> Tensor:
    """Calculate inverse of rigid body transformation(s).

    Uses the closed form ``[R | t]^-1 = [R^T | -R^T t]`` instead of a general
    matrix inverse, so the input is assumed to be rigid (orthonormal rotation
    block). The last row is copied from the input unchanged.

    Args:
        transformation (Tensor): [..., 4, 4] transformations, e.g. [N, 4, 4]
            or a single [4, 4] transformation. Any number of leading batch
            dimensions is supported.

    Returns:
        Tensor: Inverse of input transformation(s), same shape as the input.
    """
    # transpose(-1, -2) of the rotation block == its inverse for rotations.
    rot_inv = transformation[..., :3, :3].transpose(-1, -2)
    # Keep translation as a column vector ([..., 3, 1]) for the matmul.
    trans_inv = -rot_inv @ transformation[..., :3, 3:4]
    top_rows = torch.cat([rot_inv, trans_inv], -1)
    return torch.cat([top_rows, transformation[..., 3:4, :]], -2)
def get_transform_matrix(rotation: Tensor, translation: Tensor) -> Tensor:
    """Build homogeneous 4x4 matrices from rotation and translation parts.

    The upper-left 3x3 block holds the rotation, the upper-right column the
    translation, and the bottom row is set to [0, 0, 0, 1].

    Args:
        rotation (Tensor): Single [3, 3] rotation or batch of [N, 3, 3].
        translation (Tensor): Single [3,] translation or batch of [N, 3].

    Returns:
        Tensor: [4, 4] matrix for unbatched input, otherwise [N, 4, 4]. The
            dtype and device follow ``rotation``.
    """
    unbatched = rotation.dim() == 2
    if unbatched:
        # A single rotation must come with a single translation vector.
        assert translation.dim() == 1
        rotation = rotation[None]
        translation = translation[None]
    else:
        assert rotation.dim() == 3 and translation.dim() == 2
        assert rotation.shape[0] == translation.shape[0]
    assert (
        rotation.shape[-2] == rotation.shape[-1] == translation.shape[-1] == 3
    )

    matrices = rotation.new_zeros((rotation.shape[0], 4, 4))
    matrices[:, :3, :3] = rotation
    matrices[:, :3, 3] = translation
    matrices[:, 3, 3] = 1.0
    return matrices[0] if unbatched else matrices