# Source code for the timm augmentation transform wrappers.

"""Wrappers around timm's data augmentation transforms."""

from typing import Union

import numpy as np
from PIL import Image

from vis4d.common.imports import TIMM_AVAILABLE
from vis4d.common.typing import NDArrayUI8
from import CommonKeys as K

from .base import Transform

    from import (
    raise ImportError("timm is not installed.")

# Any of the timm augmentation callables wrapped by this module.
AugOp = Union[
    AutoAugment,
    RandAugment,
    AugMixAugment,
]

def _apply_aug(images: NDArrayUI8, aug_op: AugOp) -> NDArrayUI8:
    """Apply augmentation to a batch of images with shape [N, H, W, C].

    timm's augmentation ops work on PIL images. Each image of the batch is
    therefore cast to ``uint8``, wrapped in a PIL image, passed through
    ``aug_op`` and converted back to an array.

    Args:
        images (NDArrayUI8): Batch of RGB images with shape [N, H, W, C].
            Images that are not ``uint8`` are clipped to [0, 255] and cast
            to ``uint8`` before augmentation.
        aug_op (AugOp): The timm augmentation op to apply to every image.

    Returns:
        NDArrayUI8: The augmented batch stacked to [N, H, W, C]. NOTE: the
            values are returned as ``float32`` despite the annotation.

    Raises:
        AssertionError: If the last axis does not have 3 (RGB) channels.
    """
    assert images.shape[-1] == 3, "Images must be in RGB format."
    if len(images) == 0:
        # np.stack cannot stack an empty list; nothing to augment.
        return images.astype(np.float32)

    augmented: list[Image.Image] = []
    for img in images:
        # PIL expects uint8 RGB data; clip first so that out-of-range values
        # do not wrap around when cast.
        if img.dtype != np.uint8:
            img = np.clip(img, 0, 255).astype(np.uint8)
        augmented.append(aug_op(Image.fromarray(img)))
    return np.stack([np.asarray(img).astype(np.float32) for img in augmented])

@Transform(K.images, K.images)
class _AutoAug:
    """Base transform applying Timm's AutoAugment to image arrays.

    Subclasses build ``self.aug_op`` in their ``__init__`` via
    :meth:`_create`; calling the transform before that fails.
    """

    def __init__(self) -> None:
        """Initialize without an augmentation op."""
        # Annotated explicitly: an unannotated ``None`` assignment is inferred
        # as type ``None`` and subclasses could not assign an AugOp here.
        self.aug_op: Union[AugOp, None] = None

    def _create(self, policy: str, hparams: dict[str, float]) -> AugOp:
        """Create an AutoAugment op for the given timm policy.

        Args:
            policy (str): Name of the timm AutoAugment policy, e.g. "v0" or
                "original".
            hparams (dict[str, float]): Hyper-parameters forwarded to
                timm's ``auto_augment_policy``.

        Returns:
            AugOp: The constructed AutoAugment op.
        """
        aa_policy = auto_augment_policy(policy, hparams=hparams)
        return AutoAugment(aa_policy)

    def __call__(self, images: list[NDArrayUI8]) -> list[NDArrayUI8]:
        """Execute the transform.

        Args:
            images (list[NDArrayUI8]): Image batches, each of shape
                [N, H, W, C].

        Returns:
            list[NDArrayUI8]: The same list, with every batch replaced in
                place by its augmented version.

        Raises:
            AssertionError: If no augmentation op has been created.
        """
        assert self.aug_op is not None, "Augmentation op is not created."
        for i, img in enumerate(images):
            images[i] = _apply_aug(img, self.aug_op)
        return images

class AutoAugV0(_AutoAug):
    """Apply Timm's AutoAugment (policy=v0) to an image array."""

    def __init__(self, magnitude_std: float = 0.5) -> None:
        """Create an instance of AutoAugV0.

        Args:
            magnitude_std (float, optional): Standard deviation of the
                magnitude for random autoaugment. Defaults to 0.5.
        """
        super().__init__()
        self.aug_op = self._create("v0", {"magnitude_std": magnitude_std})
class AutoAugOriginal(_AutoAug):
    """Apply Timm's AutoAugment (policy=original) to an image array."""

    def __init__(self, magnitude_std: float = 0.5) -> None:
        """Create an instance of AutoAugOriginal.

        Args:
            magnitude_std (float, optional): Standard deviation of the
                magnitude for random autoaugment. Defaults to 0.5.
        """
        super().__init__()
        self.aug_op = self._create(
            "original", {"magnitude_std": magnitude_std}
        )
@Transform(K.images, K.images)
class RandAug:
    """Apply Timm's RandAugment to an image array."""

    def __init__(
        self,
        magnitude: int = 10,
        num_layers: int = 2,
        use_increasing: bool = False,
        magnitude_std: float = 0.5,
    ) -> None:
        """Create an instance of RandAug.

        Args:
            magnitude (int, optional): Magnitude level of the augmentations,
                on timm's magnitude scale (10 is timm's default maximum).
                Defaults to 10.
            num_layers (int, optional): Number of augmentation ops applied
                to each image. Defaults to 2.
            use_increasing (bool, optional): Whether to use timm's
                "increasing" transform set instead of the default one.
                Defaults to False.
            magnitude_std (float, optional): Standard deviation of the
                magnitude for random autoaugment. Defaults to 0.5.

        Example:
            RandAugment with magnitude 9:

            >>> rand_aug = RandAug(magnitude=9)
        """
        super().__init__()
        assert TIMM_AVAILABLE, "timm is not installed."
        self.magnitude = magnitude
        self.num_layers = num_layers
        self.use_increasing = use_increasing
        self.magnitude_std = magnitude_std

        hparams = {"magnitude_std": self.magnitude_std}
        if self.use_increasing:
            transforms = _RAND_INCREASING_TRANSFORMS
        else:
            transforms = _RAND_TRANSFORMS
        ra_ops = rand_augment_ops(
            magnitude=self.magnitude, hparams=hparams, transforms=transforms
        )
        self.aug_op = RandAugment(ra_ops, self.num_layers)

    def __call__(self, images: list[NDArrayUI8]) -> list[NDArrayUI8]:
        """Execute the transform.

        Args:
            images (list[NDArrayUI8]): Image batches, each of shape
                [N, H, W, C].

        Returns:
            list[NDArrayUI8]: The same list, with every batch replaced in
                place by its augmented version.
        """
        for i, img in enumerate(images):
            images[i] = _apply_aug(img, self.aug_op)
        return images
@Transform(K.images, K.images)
class AugMix:
    """Apply Timm's AugMix to an image array."""

    def __init__(
        self,
        magnitude: int = 10,
        width: int = 3,
        alpha: float = 1.0,
        depth: int = -1,
        blended: bool = True,
        magnitude_std: float = 0.5,
    ) -> None:
        """Create an instance of AugMix.

        Args:
            magnitude (int, optional): Magnitude level of the augmentations,
                on timm's magnitude scale (10 is timm's default maximum).
                Defaults to 10.
            width (int, optional): Width of the augmentation chain, i.e. the
                number of augmentation chains that are mixed. Defaults to 3.
            alpha (float, optional): Alpha for beta distribution.
                Defaults to 1.0.
            depth (int, optional): Depth of the augmentation chain. A
                non-positive value lets timm choose the depth randomly.
                Defaults to -1.
            blended (bool, optional): Whether to blend the original image
                with the augmented image. Defaults to True.
            magnitude_std (float, optional): Standard deviation of the
                magnitude for random autoaugment. Defaults to 0.5.

        Example:
            AugMix with magnitude 9:

            >>> augmix = AugMix(magnitude=9)
        """
        super().__init__()
        assert TIMM_AVAILABLE, "timm is not installed."
        self.magnitude = magnitude
        self.width = width
        self.alpha = alpha
        self.depth = depth
        self.blended = blended
        self.magnitude_std = magnitude_std

        hparams = {"magnitude_std": self.magnitude_std}
        am_ops = augmix_ops(magnitude=self.magnitude, hparams=hparams)
        self.aug_op = AugMixAugment(
            am_ops,
            alpha=self.alpha,
            width=self.width,
            depth=self.depth,
            blended=self.blended,
        )

    def __call__(self, images: list[NDArrayUI8]) -> list[NDArrayUI8]:
        """Execute the transform.

        Args:
            images (list[NDArrayUI8]): Image batches, each of shape
                [N, H, W, C].

        Returns:
            list[NDArrayUI8]: The same list, with every batch replaced in
                place by its augmented version.
        """
        for i, img in enumerate(images):
            images[i] = _apply_aug(img, self.aug_op)
        return images