"""Operations for PointNet.
Code taken from
https://github.com/timothylimyl/PointNet-Pytorch/blob/master/pointnet/model.py
and modified to allow for modular configuration.
"""
from __future__ import annotations
from collections.abc import Callable, Iterable
from typing import NamedTuple
import torch
from torch import nn
from vis4d.common.typing import ArgsType
class PointNetEncoderOut(NamedTuple):
"""Output of the PointNetEncoder.
    features: Global features with shape [N, feature_dim].
    pointwise_features: Pointwise features with shape
        [N, last_mlp_dim, n_pts].
    transformations: List with all transformation matrices that were
        used, each with shape [N, d, d].
"""
features: torch.Tensor
    pointwise_features: torch.Tensor  # [N, last_mlp_dim, n_pts]
transformations: list[ # list with all transformation matrices [[B, d, d]]
torch.Tensor
]
class PointNetSemanticsLoss(NamedTuple):
"""Losses for the pointnet semantic segmentation network."""
semantic_loss: torch.Tensor
regularization_loss: torch.Tensor
class PointNetSemanticsOut(NamedTuple):
"""Output of the PointNet Segmentation network."""
class_logits: torch.Tensor # B, n_classes, n_pts
transformations: list[ # list with all transformation matrices [[B, d, d]]
torch.Tensor
]
class LinearTransform(nn.Module):
"""Module that learns a linear transformation for a input pointcloud.
Code taken from
https://github.com/timothylimyl/PointNet-Pytorch/blob/master/pointnet/model.py
and modified to allow for modular configuration.
    See the T-Net in the PointNet publication
    (https://arxiv.org/pdf/1612.00593.pdf) for more information.
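
    Example (minimal usage sketch, assuming the default constructor
    arguments):
        >>> import torch
        >>> t_net = LinearTransform(in_dimension=3)
        >>> points = torch.rand(2, 3, 1024)  # [B, C, N]
        >>> trans = t_net(points)
        >>> trans.shape
        torch.Size([2, 3, 3])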
"""
def __init__(
self,
in_dimension: int = 3,
upsampling_dims: Iterable[int] = (64, 128, 1024),
downsampling_dims: Iterable[int] = (1024, 512, 256),
norm_cls: str | None = "BatchNorm1d",
activation_cls: str = "ReLU",
) -> None:
"""Creates a new LinearTransform.
This learns a transformation matrix from data.
Args:
            in_dimension (int): Input dimension.
            upsampling_dims (Iterable[int]): List of intermediate feature
                sizes for upsampling.
            downsampling_dims (Iterable[int]): List of intermediate feature
                sizes for downsampling. Make sure the first entry matches
                the last entry of upsampling_dims.
            norm_cls (Optional[str]): Class name for the norm
                (nn.<norm_cls>) or None.
            activation_cls (str): Class name for the activation
                (nn.<activation_cls>).
"""
super().__init__()
self.upsampling_dims = list(upsampling_dims)
self.downsampling_dims = list(downsampling_dims)
assert (
len(self.upsampling_dims) != 0 and len(self.downsampling_dims) != 0
)
assert self.upsampling_dims[-1] == self.downsampling_dims[0]
self.in_dimension_ = in_dimension
self.identity: torch.Tensor
self.register_buffer(
"identity", torch.eye(in_dimension).reshape(1, in_dimension**2)
)
# Create activation
self.activation_ = getattr(nn, activation_cls)()
# Create norms
        norm_fn: Callable[[int], nn.Module] | None = (
            getattr(nn, norm_cls) if norm_cls is not None else None
        )
        if norm_fn is not None:
            self.norms_ = nn.ModuleList(
                norm_fn(feature_size)
                for feature_size in (
                    *upsampling_dims,
                    *self.downsampling_dims[1:],
                )
            )
        else:
            # Keep the attribute defined so that forward() can check it.
            self.norms_ = None
# Create upsampling layers
self.upsampling_layers = nn.ModuleList(
[nn.Conv1d(in_dimension, self.upsampling_dims[0], 1)]
)
for i in range(len(self.upsampling_dims) - 1):
self.upsampling_layers.append(
nn.Conv1d(
self.upsampling_dims[i], self.upsampling_dims[i + 1], 1
)
)
# Create downsampling layers
self.downsampling_layers = nn.ModuleList(
[
nn.Linear(
self.downsampling_dims[i], self.downsampling_dims[i + 1]
)
for i in range(len(self.downsampling_dims) - 1)
]
)
self.downsampling_layers.append(
nn.Linear(self.downsampling_dims[-1], in_dimension**2)
)
def __call__(
self,
features: torch.Tensor,
) -> torch.Tensor:
"""Type definition for call implementation."""
return self._call_impl(features)
def forward(
self,
features: torch.Tensor,
) -> torch.Tensor:
"""Linear Transform forward.
Args:
features (Tensor[B, C, N]): Input features (e.g. points)
Returns:
            Learned canonical transformation matrix for this input.
            See the T-Net in the PointNet publication
            (https://arxiv.org/pdf/1612.00593.pdf) for further information.
"""
batchsize = features.shape[0]
# Upsample features
for idx, layer in enumerate(self.upsampling_layers):
features = layer(features)
if self.norms_ is not None:
features = self.norms_[idx](features)
features = self.activation_(features)
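        # Max-pool over the points dimension: a symmetric, order-invariant
        # aggregation over the point set.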
features = torch.max(features, 2, keepdim=True)[0]
features = features.view(-1, self.upsampling_dims[-1])
# Downsample features
for idx, layer in enumerate(self.downsampling_layers):
features = layer(features)
            # Do not apply norm and activation to the final layer
if idx != len(self.downsampling_layers) - 1:
if self.norms_ is not None:
norm_idx = idx + len(self.upsampling_layers)
features = self.norms_[norm_idx](features)
features = self.activation_(features)
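        # The layers predict an offset from the identity matrix, so the
        # learned transform starts close to the identity.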
identity_batch = self.identity.repeat(batchsize, 1)
transformations = features + identity_batch
return transformations.view(
batchsize, self.in_dimension_, self.in_dimension_
)
class PointNetEncoder(nn.Module):
"""PointNetEncoder.
    Encodes a point cloud and additional features into one feature
    description. See the PointNet publication for more information
    (https://arxiv.org/pdf/1612.00593.pdf).
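
    Example (minimal usage sketch, assuming the default arguments):
        >>> import torch
        >>> encoder = PointNetEncoder(in_dimensions=3, out_dimensions=1024)
        >>> points = torch.rand(4, 3, 2048)  # [B, C, N]
        >>> out = encoder(points)
        >>> out.features.shape  # global feature
        torch.Size([4, 1024])
        >>> out.pointwise_features.shape  # per-point features
        torch.Size([4, 64, 2048])
        >>> len(out.transformations)  # one learned transform per MLP block
        2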
"""
def __init__(
self,
in_dimensions: int = 3,
out_dimensions: int = 1024,
mlp_dimensions: Iterable[Iterable[int]] = ((64, 64), (64, 128)),
norm_cls: str | None = "BatchNorm1d",
activation_cls: str = "ReLU",
**kwargs: ArgsType,
):
"""Creates a new PointNetEncoder.
Args:
            in_dimensions (int): Input dimension (e.g. 3 for xyz,
                6 for xyzrgb).
            out_dimensions (int): Output dimensions.
            mlp_dimensions (Iterable[Iterable[int]]): Dimensions of the
                MLP layers.
            norm_cls (Optional[str]): Class name for the norm
                (nn.<norm_cls>) or None.
            activation_cls (str): Class name for the activation
                (nn.<activation_cls>).
            kwargs: See the arguments of LinearTransform.
"""
super().__init__()
self.out_dimension = out_dimensions
# Extend dimensions to upscale from input dimension
mlp_dim_list: list[list[int]] = [list(d) for d in mlp_dimensions]
mlp_dim_list[0].insert(0, in_dimensions)
mlp_dim_list[-1].append(out_dimensions)
self.mlp_dimensions = mlp_dim_list
# Learnable transformation layers.
self.trans_layers_ = nn.ModuleList(
[
LinearTransform(
in_dimension=dims[0],
norm_cls=norm_cls,
activation_cls=activation_cls,
**kwargs,
)
for dims in mlp_dim_list
]
)
# MLP layers
self.mlp_layers_ = nn.ModuleList()
# Create activation
activation = getattr(nn, activation_cls)()
# Create norms
        norm_fn: Callable[[int], nn.Module] | None = (
            getattr(nn, norm_cls) if norm_cls is not None else None
        )
for mlp_idx, mlp_dims in enumerate(mlp_dim_list):
layers: list[nn.Module] = []
for idx, (in_dim, out_dim) in enumerate(
zip(mlp_dims[:-1], mlp_dims[1:])
):
# Create MLP
layers.append(torch.nn.Conv1d(in_dim, out_dim, 1))
# Create BN if needed
if norm_fn is not None:
layers.append(norm_fn(out_dim))
# Only add activation if not last layer
if (
mlp_idx != len(mlp_dim_list) - 1
and idx != len(mlp_dims) - 2
):
layers.append(activation)
self.mlp_layers_.append(nn.Sequential(*layers))
def __call__(self, features: torch.Tensor) -> PointNetEncoderOut:
"""Type definition for call implementation."""
return self._call_impl(features)
def forward(self, features: torch.Tensor) -> PointNetEncoderOut:
"""Pointnet encoder forward.
Args:
            features (Tensor[B, C, N]): Input features stacked in channels,
                e.g. raw point inputs: [B, 3, N]; with color: [B, 3+3, N].
Returns:
Extracted feature representation for input and all
applied transformations.
"""
transforms: list[torch.Tensor] = []
for block_idx, trans_layer in enumerate(self.trans_layers_):
# Apply transformation
trans = trans_layer(features)
transforms.append(trans)
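            # Align the features with the learned transform:
            # [B, N, C] x [B, C, C] -> [B, N, C], then back to [B, C, N].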
features = features.transpose(2, 1)
features = torch.bmm(features, trans)
features = features.transpose(2, 1)
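            # Cache the per-point features from before the last MLP block;
            # segmentation heads concatenate them with the global feature.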
if block_idx == len(self.trans_layers_) - 1:
pointwise_features = features.clone()
# Apply MLP
features = self.mlp_layers_[block_idx](features)
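        # Max-pool over all points to obtain one global feature per sample.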
features = torch.max(features, 2, keepdim=True)[0]
features = features.view(-1, self.out_dimension)
return PointNetEncoderOut(
features=features,
transformations=transforms,
pointwise_features=pointwise_features, # pylint: disable=possibly-used-before-assignment, line-too-long
)
class PointNetSegmentation(nn.Module):
"""Segmentation network using a simple pointnet as encoder."""
def __init__(
self,
n_classes: int,
in_dimensions: int = 3,
feature_dimension: int = 1024,
        norm_cls: str | None = "BatchNorm1d",
activation_cls: str = "ReLU",
):
"""Creates a new Point Net segementation network.
Args:
n_classes (int): Number of semantic classes
            in_dimensions (int): Input dimension (3 for xyz, 6 for
                xyzrgb, ...).
            feature_dimension (int): Size of the feature from the encoder.
            norm_cls (Optional[str]): Class name for the norm
                (nn.<norm_cls>) or None.
            activation_cls (str): Class name for the activation
                (nn.<activation_cls>).
Raises:
ValueError: If dimensions are invalid
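
        Example (minimal usage sketch, assuming 10 semantic classes and
        the default settings):
            >>> import torch
            >>> net = PointNetSegmentation(n_classes=10, in_dimensions=3)
            >>> points = torch.rand(2, 3, 4096)  # [B, in_dim, n_pts]
            >>> out = net(points)
            >>> out.class_logits.shape
            torch.Size([2, 10, 4096])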
"""
super().__init__()
self.in_dimensions = in_dimensions
self.encoder = PointNetEncoder(
in_dimensions=in_dimensions,
out_dimensions=feature_dimension,
norm_cls=norm_cls,
activation_cls=activation_cls,
)
pc_feat_dim = self.encoder.mlp_dimensions[-1][0]
# Create activation
activation = getattr(nn, activation_cls)()
# Create norms
        norm_fn: Callable[[int], nn.Module] | None = (
            getattr(nn, norm_cls) if norm_cls is not None else None
        )
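        # Point-wise classification head operating on the concatenation of
        # the global feature and the per-point encoder features.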
self.classifier_dims = [feature_dimension + pc_feat_dim, 512, 256, 128]
# Build Model
self.classifier = nn.Sequential()
for in_dim, out_dim in zip(
self.classifier_dims[:-1], self.classifier_dims[1:]
):
self.classifier.append(nn.Conv1d(in_dim, out_dim, 1))
if norm_fn is not None:
self.classifier.append(norm_fn(out_dim))
self.classifier.append(activation)
self.classifier.append(
nn.Conv1d(
out_dim, # pylint: disable=undefined-loop-variable
n_classes,
1,
)
)
def __call__(self, points: torch.Tensor) -> PointNetSemanticsOut:
"""Call function."""
return self._call_impl(points)
def forward(self, points: torch.Tensor) -> PointNetSemanticsOut:
"""Pointnet Segmenter Forward.
Args:
            points (Tensor): Input points with shape [B, in_dim, n_pts].

        Returns:
            PointNetSemanticsOut: The semantic class logits with shape
                [B, n_classes, n_pts] and the linear transformation
                matrices that were used to transform the point clouds
                (see LinearTransform).
"""
assert points.size(-2) == self.in_dimensions
n_pts = points.size(-1)
bs = points.size(0)
encoder_out = self.encoder(points)
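        # Broadcast the global feature to every point and concatenate it
        # with the per-point features from the encoder.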
global_features = encoder_out.features.view(bs, -1, 1).repeat(
1, 1, n_pts
)
x = torch.cat([global_features, encoder_out.pointwise_features], 1)
x = self.classifier(x)
return PointNetSemanticsOut(
class_logits=x, transformations=encoder_out.transformations
)