# Source code for vis4d.data.transforms.base
"""Basic data augmentation class."""
from __future__ import annotations
from collections.abc import Callable, Sequence
from typing import TypeVar, no_type_check
import torch
from vis4d.common.dict import get_dict_nested, set_dict_nested
from vis4d.data.typing import DictData
# Type variable for the functor (class) that `Transform` decorates; the
# decorator returns the very same object it receives.
TFunctor = TypeVar("TFunctor", bound=object) # pylint: disable=invalid-name
# Signature of a batch-level transformation as returned by `compose`: it
# takes a list of data dictionaries and returns a list of data dictionaries.
TransformFunction = Callable[[list[DictData]], list[DictData]]
[docs]
class Transform:
    """Transforms Decorator.

    This class stores which `in_keys` are input to a transformation function
    and which `out_keys` are overwritten in the data dictionary by the output
    of this transformation.

    Nested keys in the data dictionary can be accessed via key.subkey1.subkey2
    If any of `in_keys` is 'data', the full data dictionary will be forwarded
    to the transformation.
    If the only entry in `out_keys` is 'data', the return value of the
    transformation replaces the batch of data dictionaries as a whole.

    For the case of multi-sensor data, the sensors that the transform should
    be applied to can be set via the 'sensors' attribute. By default, we
    assume a transformation is applied to all sensors.

    The decorated functor always operates on a batch: for every entry of
    `in_keys` it receives one list holding the value of that key for each
    sample of the batch (or None if the value is None for any sample), and
    for every entry of `out_keys` it returns one list with a value per sample.

    This class will add an 'apply_to_data' method to a given Functor which is
    used to call it on a list of DictData objects. NOTE: This is an issue for
    static checking and is not recognized by pylint. It will usually be called
    in the compose() function and will not be called directly.

    Example:
        >>> @Transform(in_keys="images", out_keys="images")
        ... class MyTransform:
        ...     def __call__(
        ...         self, images: list[np.array]
        ...     ) -> list[np.array]:
        ...         images = do_something(images)
        ...         return images
        >>> my_transform = MyTransform()
        >>> data = my_transform.apply_to_data(data)
    """

    def __init__(
        self,
        in_keys: Sequence[str] | str,
        out_keys: Sequence[str] | str,
        sensors: Sequence[str] | str | None = None,
        same_on_batch: bool = True,
    ) -> None:
        """Creates an instance of Transform.

        Args:
            in_keys (Sequence[str] | str): One or multiple keys of the data
                dictionary whose values are passed, in order, as positional
                arguments to the transformation. A single string is treated
                as a one-element list.
            out_keys (Sequence[str] | str): One or multiple keys of the data
                dictionary that are overwritten with the outputs of the
                transformation. A single string is treated as a one-element
                list.
            sensors (Sequence[str] | str | None, optional): Specifies the
                sensors this transformation should be applied to. If None, it
                will be applied to all available sensors. Defaults to None.
            same_on_batch (bool, optional): If True, the functor is called
                once with the whole batch; if False, it is called separately
                for every sample (and every sensor). Defaults to True.
        """
        if isinstance(in_keys, str):
            in_keys = [in_keys]
        self.in_keys = in_keys

        if isinstance(out_keys, str):
            out_keys = [out_keys]
        self.out_keys = out_keys

        if isinstance(sensors, str):
            sensors = [sensors]
        self.sensors = sensors

        self.same_on_batch = same_on_batch

    @no_type_check
    def __call__(self, transform: TFunctor) -> TFunctor:
        """Add in_keys / out_keys / sensors / apply_to_data attributes.

        The functor's `__init__` is replaced by a wrapper that accepts the
        additional keyword arguments `in_keys`, `out_keys`, `sensors` and
        `same_on_batch` (defaulting to the values given to this decorator),
        so they can be overridden per instance.

        Args:
            transform (TFunctor): A given Functor.

        Returns:
            TFunctor: The decorated Functor.
        """
        original_init = transform.__init__

        def apply_to_data(
            self_, input_batch: list[DictData]
        ) -> list[DictData]:
            """Wrap function with a handler for input / output keys.

            We use the specified in_keys in order to extract the positional
            input arguments of a function from the data dictionary, and the
            out_keys to replace the corresponding values in the output
            dictionary.
            """

            def _transform_fn(batch: list[DictData]) -> list[DictData]:
                """Call the functor on one batch of (per-sensor) dicts."""
                in_batch = []
                for key in self_.in_keys:
                    # Optionally allow the function to get the full data
                    # dict as aux input and set default value to None if
                    # key is not found.
                    key_data = [
                        (
                            data
                            if key == "data"
                            else get_dict_nested(
                                data, key.split("."), allow_missing=True
                            )
                        )
                        for data in batch
                    ]
                    if any(d is None for d in key_data):
                        # If any of the data in the batch is None, replace
                        # the input of the key with None.
                        in_batch.append(None)
                    else:
                        in_batch.append(key_data)

                result = self_(*in_batch)

                if len(self_.out_keys) == 1:
                    if self_.out_keys[0] == "data":
                        # The functor returned the complete, updated batch.
                        return result
                    # Single output: normalize to the multi-output layout.
                    result = [result]

                for key, values in zip(self_.out_keys, result):
                    if values is None:
                        continue
                    for data, value in zip(batch, values):
                        if value is not None:
                            set_dict_nested(data, key.split("."), value)
                return batch

            if self_.sensors is not None:
                if self_.same_on_batch:
                    for sensor in self_.sensors:
                        batch_sensor = _transform_fn(
                            [d[sensor] for d in input_batch]
                        )
                        for i, d in enumerate(batch_sensor):
                            input_batch[i][sensor] = d
                else:
                    for data in input_batch:
                        for sensor in self_.sensors:
                            # _transform_fn returns a list with one entry
                            # per sample; unwrap the single sample so the
                            # sensor entry stays a dictionary.
                            data[sensor] = _transform_fn([data[sensor]])[0]
            elif self_.same_on_batch:
                input_batch = _transform_fn(input_batch)
            else:
                for i, data in enumerate(input_batch):
                    input_batch[i] = _transform_fn([data])[0]

            return input_batch

        def init(
            *args,
            in_keys: Sequence[str] = self.in_keys,
            out_keys: Sequence[str] = self.out_keys,
            sensors: Sequence[str] | None = self.sensors,
            same_on_batch: bool = self.same_on_batch,
            **kwargs,
        ):
            """Run the original __init__, then attach the key handling."""
            self_ = args[0]
            original_init(*args, **kwargs)
            self_.in_keys = in_keys
            self_.out_keys = out_keys
            self_.sensors = sensors
            self_.same_on_batch = same_on_batch
            # Bound per instance so that attribute overrides are respected.
            self_.apply_to_data = lambda *call_args, **call_kwargs: (
                apply_to_data(self_, *call_args, **call_kwargs)
            )

        transform.__init__ = init
        return transform
[docs]
def compose(transforms: list[TFunctor]) -> TransformFunction:
    """Chain several transforms into one batch-level function.

    Every element of `transforms` is expected to expose `apply_to_data`,
    i.e. to be a functor decorated with Transform. The returned callable
    feeds a batch through all of them in the given order and hands back
    whatever the last one produced.
    """

    def _run_pipeline(samples: list[DictData]) -> list[DictData]:
        current = samples
        for transform in transforms:
            current = transform.apply_to_data(current)  # type: ignore
        return current

    return _run_pipeline
[docs]
@Transform("data", "data")
class RandomApply:
    """Run a list of transformations on a batch only part of the time."""

    def __init__(
        self, transforms: list[TFunctor], probability: float = 0.5
    ) -> None:
        """Creates an instance of RandomApply.

        Args:
            transforms (list[TFunctor]): Transformations that are applied,
                in order, whenever the random draw succeeds.
            probability (float, optional): Probability of applying the
                transformations to a batch. Defaults to 0.5.
        """
        self.probability = probability
        self.transforms = transforms

    def __call__(self, batch: list[DictData]) -> list[DictData]:
        """Apply transforms with a given probability."""
        # One uniform sample per call decides for the whole batch.
        should_apply = bool(torch.rand(1) < self.probability)
        if not should_apply:
            return batch

        for transform in self.transforms:
            batch = transform.apply_to_data(batch)  # type: ignore
        return batch