# Source code for vis4d.data.transforms.pad

"""Pad transformation."""

from __future__ import annotations

import numpy as np
import torch
import torch.nn.functional as F

from vis4d.common.typing import NDArrayF32, NDArrayUI8
from vis4d.data.const import CommonKeys as K

from .base import Transform


@Transform(K.images, K.images)
class PadImages:
    """Pad batch of images at the bottom right."""

    def __init__(
        self,
        stride: int = 32,
        mode: str = "constant",
        value: float = 0.0,
        shape: tuple[int, int] | None = None,
        pad2square: bool = False,
    ) -> None:
        """Creates an instance of PadImage.

        Args:
            stride (int, optional): Chooses padding size so that the input
                will be divisible by stride. Defaults to 32.
            mode (str, optional): Padding mode. One of constant, reflect,
                replicate or circular. Defaults to "constant".
            value (float, optional): Value for constant padding.
                Defaults to 0.0.
            shape (tuple[int, int], optional): Shape of the padded image
                (H, W). Defaults to None.
            pad2square (bool, optional): Pad to square. Defaults to False.
        """
        if pad2square:
            # A fixed target shape and pad-to-square are mutually exclusive.
            assert (
                shape is None
            ), "Cannot specify shape when pad2square is True."
        self.stride = stride
        self.mode = mode
        self.value = value
        self.shape = shape
        self.pad2square = pad2square

    def __call__(self, images: list[NDArrayF32]) -> list[NDArrayF32]:
        """Pad images to consistent size.

        Each image is assumed to be laid out as (N, H, W, C); padding is
        applied at the bottom and right so all images share one (H, W).
        """
        hs = [img.shape[1] for img in images]
        ws = [img.shape[2] for img in images]
        target_h, target_w = _get_max_shape(
            hs, ws, self.stride, self.shape, self.pad2square
        )
        for idx, (img, height, width) in enumerate(zip(images, hs, ws)):
            # F.pad order for the last two dims: (left, right, top, bottom).
            padding = (0, target_w - width, 0, target_h - height)
            # Convert NHWC -> NCHW, as F.pad pads trailing dims (H, W).
            nchw = torch.from_numpy(img).permute(0, 3, 1, 2)
            padded = F.pad(  # pylint: disable=not-callable
                nchw, padding, self.mode, self.value
            )
            images[idx] = padded.permute(0, 2, 3, 1).numpy()
        return images
@Transform(K.seg_masks, K.seg_masks)
class PadSegMasks:
    """Pad batch of segmentation masks at the bottom right."""

    def __init__(
        self,
        stride: int = 32,
        mode: str = "constant",
        value: int = 255,
        shape: tuple[int, int] | None = None,
        pad2square: bool = False,
    ) -> None:
        """Creates an instance of PadSegMasks.

        Args:
            stride (int, optional): Chooses padding size so that the input
                will be divisible by stride. Defaults to 32.
            mode (str, optional): Padding mode. One of constant, reflect,
                replicate or circular. Defaults to "constant".
            value (int, optional): Value for constant padding. Defaults to
                255 (the conventional ignore index for segmentation).
            shape (tuple[int, int], optional): Shape of the padded image
                (H, W). Defaults to None.
            pad2square (bool, optional): Pad to square. Defaults to False.
        """
        if pad2square:
            # A fixed target shape and pad-to-square are mutually exclusive.
            assert (
                shape is None
            ), "Cannot specify shape when pad2square is True."
        self.stride = stride
        self.mode = mode
        self.value = value
        self.shape = shape
        self.pad2square = pad2square

    def __call__(self, masks: list[NDArrayUI8]) -> list[NDArrayUI8]:
        """Pad masks to consistent size.

        Each mask is assumed to be laid out as (H, W); padding is applied at
        the bottom and right so all masks share one (H, W).
        """
        heights = [mask.shape[0] for mask in masks]
        widths = [mask.shape[1] for mask in masks]
        max_hw = _get_max_shape(
            heights, widths, self.stride, self.shape, self.pad2square
        )
        for i, (mask, h, w) in enumerate(zip(masks, heights, widths)):
            # np.pad order per axis: ((before, after), ...) for (H, W).
            pad_param = ((0, max_hw[0] - h), (0, max_hw[1] - w))
            if self.mode == "constant":
                masks[i] = np.pad(  # type: ignore
                    mask, pad_param, mode=self.mode, constant_values=self.value
                )
            else:
                # np.pad rejects constant_values for non-constant modes, so
                # only forward it when it is actually applicable.
                masks[i] = np.pad(mask, pad_param, mode=self.mode)
        return masks
def _get_max_shape( heights: list[int], widths: list[int], stride: int, shape: tuple[int, int] | None, pad2square: bool, ) -> tuple[int, int]: """Get max shape for padding. Args: stride (int): Chooses padding size so that the input will be divisible by stride. shape (tuple[int, int], optional): Shape of the padded image (H, W). Defaults to None. heights (list[int]): List of heights of input. widths (list[int]): List of widths of input. pad2square (bool): Pad to square. Returns: tuple[int, int]: Max shape for padding. """ if pad2square: max_size = max(heights + widths) max_hw = (max_size, max_size) elif shape is not None: max_hw = shape else: max_hw = max(heights), max(widths) max_hw = tuple(_make_divisible(x, stride) for x in max_hw) # type: ignore # pylint: disable=line-too-long return max_hw def _make_divisible(x: int, stride: int) -> int: """Ensure divisibility by stride.""" return (x + (stride - 1)) // stride * stride