"""Wrappers around timm's auto-augmentation transforms."""
from typing import Union
import numpy as np
from PIL import Image
from vis4d.common.imports import TIMM_AVAILABLE
from vis4d.common.typing import NDArrayUI8
from vis4d.data.const import CommonKeys as K
from .base import Transform
if TIMM_AVAILABLE:
from timm.data.auto_augment import (
_RAND_INCREASING_TRANSFORMS,
_RAND_TRANSFORMS,
AugMixAugment,
AutoAugment,
RandAugment,
augmix_ops,
auto_augment_policy,
rand_augment_ops,
)
else:
raise ImportError("timm is not installed.")
# Any of the timm augmentation callables this module builds and applies.
AugOp = Union[AutoAugment, RandAugment, AugMixAugment]


def _apply_aug(images: NDArrayUI8, aug_op: AugOp) -> NDArrayUI8:
    """Apply augmentation to a batch of images with shape [N, H, W, C].

    Each image is converted to an 8-bit PIL image, passed through ``aug_op``
    and converted back to a numpy array.

    Args:
        images (NDArrayUI8): Batch of RGB images with shape [N, H, W, 3].
            Inputs that are not uint8 are clipped to [0, 255] and cast to
            uint8 before augmentation.
        aug_op (AugOp): timm augmentation callable that takes and returns a
            PIL image.

    Returns:
        NDArrayUI8: The augmented images stacked along the first axis.
            NOTE: the returned array is float32 (values in [0, 255]) despite
            the annotation, which is kept as-is for caller compatibility.

    Raises:
        AssertionError: If the last axis of ``images`` does not have size 3.
    """
    assert images.shape[-1] == 3, "Images must be in RGB format."
    if len(images) == 0:
        # np.stack() rejects an empty sequence; return an empty float32 batch.
        return images.astype(np.float32)
    if images.dtype != np.uint8:
        # PIL needs uint8 data. Clip first so out-of-range values saturate
        # instead of wrapping around during the cast.
        images = np.clip(images, 0, 255).astype(np.uint8)
    augmented = [
        np.asarray(aug_op(Image.fromarray(img)), dtype=np.float32)
        for img in images
    ]
    return np.stack(augmented)
@Transform(K.images, K.images)
class _AutoAug:
    """Base class applying a timm AutoAugment policy to image arrays.

    Subclasses build the policy in their ``__init__`` by assigning
    ``self.aug_op`` (typically via :meth:`_create`).
    """

    def __init__(self) -> None:
        """Initialize without an augmentation op; subclasses must set one."""
        self.aug_op: Union[AugOp, None] = None

    def _create(self, policy: str, hparams: dict[str, float]) -> AugOp:
        """Create a timm AutoAugment op.

        Args:
            policy (str): Name of the timm AutoAugment policy, e.g. "v0" or
                "original".
            hparams (dict[str, float]): Hyper-parameters forwarded to timm's
                ``auto_augment_policy``.

        Returns:
            AugOp: An ``AutoAugment`` op built from the named policy.
        """
        aa_policy = auto_augment_policy(policy, hparams=hparams)
        return AutoAugment(aa_policy)

    def __call__(self, images: list[NDArrayUI8]) -> list[NDArrayUI8]:
        """Execute the transform.

        Args:
            images (list[NDArrayUI8]): Image batches, each of shape
                [N, H, W, 3].

        Returns:
            list[NDArrayUI8]: A new list holding one augmented batch per input
                batch; the input list itself is not modified.
        """
        assert self.aug_op is not None, "Augmentation op is not created."
        return [_apply_aug(img, self.aug_op) for img in images]
class AutoAugV0(_AutoAug):
    """Apply Timm's AutoAugment (policy=v0) to image arrays."""

    def __init__(self, magnitude_std: float = 0.5) -> None:
        """Create an instance of AutoAugV0.

        Args:
            magnitude_std (float, optional): Standard deviation of the
                magnitude for random autoaugment. Defaults to 0.5.
        """
        super().__init__()
        self.aug_op = self._create("v0", {"magnitude_std": magnitude_std})
class AutoAugOriginal(_AutoAug):
    """Apply Timm's AutoAugment (policy=original) to image arrays."""

    def __init__(self, magnitude_std: float = 0.5) -> None:
        """Create an instance of AutoAugOriginal.

        Args:
            magnitude_std (float, optional): Standard deviation of the
                magnitude for random autoaugment. Defaults to 0.5.
        """
        super().__init__()
        self.aug_op = self._create(
            "original", {"magnitude_std": magnitude_std}
        )
@Transform(K.images, K.images)
class RandAug:
    """Apply Timm's RandAugment to image arrays."""

    def __init__(
        self,
        magnitude: int = 10,
        num_layers: int = 2,
        use_increasing: bool = False,
        magnitude_std: float = 0.5,
    ) -> None:
        """Create an instance of RandAug.

        Args:
            magnitude (int, optional): Magnitude level forwarded to timm's
                ``rand_augment_ops`` (presumably on timm's 0-10 scale --
                confirm against the installed timm version). Defaults to 10.
            num_layers (int, optional): Number of layers passed to timm's
                ``RandAugment``. Defaults to 2.
            use_increasing (bool, optional): Whether to use timm's
                "increasing" transform set instead of the default one.
                Defaults to False.
            magnitude_std (float, optional): Standard deviation of the
                magnitude for random augmentation. Defaults to 0.5.

        Example:
            RandAugment with magnitude 9
            (https://arxiv.org/abs/1909.13719)::

                transform = RandAug(magnitude=9)
        """
        super().__init__()
        assert TIMM_AVAILABLE, "timm is not installed."
        self.magnitude = magnitude
        self.num_layers = num_layers
        self.use_increasing = use_increasing
        self.magnitude_std = magnitude_std

        hparams = {"magnitude_std": self.magnitude_std}
        transforms = (
            _RAND_INCREASING_TRANSFORMS
            if self.use_increasing
            else _RAND_TRANSFORMS
        )
        ra_ops = rand_augment_ops(
            magnitude=self.magnitude, hparams=hparams, transforms=transforms
        )
        self.aug_op = RandAugment(ra_ops, self.num_layers)

    def __call__(self, images: list[NDArrayUI8]) -> list[NDArrayUI8]:
        """Execute the transform.

        Args:
            images (list[NDArrayUI8]): Image batches, each of shape
                [N, H, W, 3].

        Returns:
            list[NDArrayUI8]: A new list holding one augmented batch per input
                batch; the input list itself is not modified.
        """
        return [_apply_aug(img, self.aug_op) for img in images]
@Transform(K.images, K.images)
class AugMix:
    """Apply Timm's AugMix to image arrays."""

    def __init__(
        self,
        magnitude: int = 10,
        width: int = 3,
        alpha: float = 1.0,
        depth: int = -1,
        blended: bool = True,
        magnitude_std: float = 0.5,
    ) -> None:
        """Create an instance of AugMix.

        Args:
            magnitude (int, optional): Magnitude level forwarded to timm's
                ``augmix_ops`` (presumably on timm's 0-10 scale -- confirm
                against the installed timm version). Defaults to 10.
            width (int, optional): Width of the augmentation chain (number
                of chains mixed). Defaults to 3.
            alpha (float, optional): Alpha for the Beta/Dirichlet mixing
                distributions. Defaults to 1.0.
            depth (int, optional): Depth of each augmentation chain;
                non-positive values let timm pick the depth. Defaults to -1.
            blended (bool, optional): Forwarded to timm's ``AugMixAugment``
                as ``blended``, selecting its blended mixing implementation.
                Defaults to True.
            magnitude_std (float, optional): Standard deviation of the
                magnitude for random augmentation. Defaults to 0.5.

        Example:
            AugMix with magnitude 9 (https://arxiv.org/abs/1912.02781)::

                transform = AugMix(magnitude=9)
        """
        super().__init__()
        assert TIMM_AVAILABLE, "timm is not installed."
        self.magnitude = magnitude
        self.width = width
        self.alpha = alpha
        self.depth = depth
        self.blended = blended
        self.magnitude_std = magnitude_std

        hparams = {"magnitude_std": self.magnitude_std}
        am_ops = augmix_ops(magnitude=self.magnitude, hparams=hparams)
        self.aug_op = AugMixAugment(
            am_ops,
            alpha=self.alpha,
            width=self.width,
            depth=self.depth,
            blended=self.blended,
        )

    def __call__(self, images: list[NDArrayUI8]) -> list[NDArrayUI8]:
        """Execute the transform.

        Args:
            images (list[NDArrayUI8]): Image batches, each of shape
                [N, H, W, 3].

        Returns:
            list[NDArrayUI8]: A new list holding one augmented batch per input
                batch; the input list itself is not modified.
        """
        return [_apply_aug(img, self.aug_op) for img in images]