Source code for vis4d.data.transforms.mosaic

"""Mosaic transformation.

Modified from mmdetection (https://github.com/open-mmlab/mmdetection).
"""

from __future__ import annotations

import random
from typing import TypedDict

import numpy as np

from vis4d.common.typing import NDArrayF32, NDArrayI64
from vis4d.data.const import CommonKeys as K

from .base import Transform
from .crop import _get_keep_mask
from .resize import resize_image


class MosaicParam(TypedDict):
    """Parameters for Mosaic."""

    out_shape: tuple[int, int]
    paste_coords: list[tuple[int, int, int, int]]
    crop_coords: list[tuple[int, int, int, int]]
    im_shapes: list[tuple[int, int]]
    im_scales: list[tuple[float, float]]

def mosaic_combine(
    index: int,
    center: tuple[int, int],
    im_hw: tuple[int, int],
    out_shape: tuple[int, int],
) -> tuple[tuple[int, int, int, int], tuple[int, int, int, int]]:
    """Compute the mosaic parameters for the image at the current index.

    Index: 0 = top_left, 1 = top_right, 2 = bottom_left, 3 = bottom_right
    """
    assert index in {0, 1, 2, 3}
    if index == 0:
        # index0 to top left part of image
        x1, y1, x2, y2 = (
            max(center[1] - im_hw[1], 0),
            max(center[0] - im_hw[0], 0),
            center[1],
            center[0],
        )
        crop_coord = (
            im_hw[1] - (x2 - x1),
            im_hw[0] - (y2 - y1),
            im_hw[1],
            im_hw[0],
        )
    elif index == 1:
        # index1 to top right part of image
        x1, y1, x2, y2 = (
            center[1],
            max(center[0] - im_hw[0], 0),
            min(center[1] + im_hw[1], out_shape[1] * 2),
            center[0],
        )
        crop_coord = (
            0,
            im_hw[0] - (y2 - y1),
            min(im_hw[1], x2 - x1),
            im_hw[0],
        )
    elif index == 2:
        # index2 to bottom left part of image
        x1, y1, x2, y2 = (
            max(center[1] - im_hw[1], 0),
            center[0],
            center[1],
            min(out_shape[0] * 2, center[0] + im_hw[0]),
        )
        crop_coord = (
            im_hw[1] - (x2 - x1),
            0,
            im_hw[1],
            min(y2 - y1, im_hw[0]),
        )
    else:
        # index3 to bottom right part of image
        x1, y1, x2, y2 = (
            center[1],
            center[0],
            min(center[1] + im_hw[1], out_shape[1] * 2),
            min(out_shape[0] * 2, center[0] + im_hw[0]),
        )
        crop_coord = 0, 0, min(im_hw[1], x2 - x1), min(y2 - y1, im_hw[0])

    paste_coord = x1, y1, x2, y2
    return paste_coord, crop_coord
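
# Worked example (illustrative, not part of the original module): with the
# mosaic center at (y=480, x=640), a resized sub-image of 480x640 and
# out_shape=(480, 640), the top-left slot (index 0) uses the full sub-image:
#
#   paste, crop = mosaic_combine(0, (480, 640), (480, 640), (480, 640))
#   # paste == (0, 0, 640, 480): (x1, y1, x2, y2) region on the 2x canvas
#   # crop == (0, 0, 640, 480): (x1, y1, x2, y2) region of the sub-image
#
# If the sub-image were larger than the space left of / above the center,
# the paste region would be clamped at 0 and the crop region would keep only
# the bottom-right part of the sub-image that still fits.
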
@Transform(K.input_hw, ["transforms.mosaic"])
class GenMosaicParameters:
    """Generate the parameters for a mosaic operation.

    Given 4 images, the mosaic transform combines them into one output
    image, which is composed of parts from each sub-image.

                        mosaic transform
                           center_x
                +------------------------------+
                |       pad        |  pad      |
                |      +-----------+           |
                |      |           |           |
                |      |  image1   |--------+  |
                |      |           |        |  |
                |      |           | image2 |  |
     center_y   |----+-------------+-----------|
                |    |   cropped   |           |
                |pad |   image3    |  image4   |
                |    |             |           |
                +----|-------------+-----------+
                     |             |
                     +-------------+

    The mosaic transform steps are as follows:

        1. Choose the mosaic center as the intersection of the 4 images.
        2. Get the top-left image according to the index, and randomly
           sample another 3 images from the dataset.
        3. A sub-image is cropped if it is larger than the mosaic patch.

    Args:
        out_shape (tuple[int, int]): The output shape of the mosaic
            transform.
        center_ratio_range (tuple[float, float]): The range of the ratio of
            the center of the mosaic patch to the output image size.
    """

    NUM_SAMPLES = 4

    def __init__(
        self,
        out_shape: tuple[int, int],
        center_ratio_range: tuple[float, float] = (0.5, 1.5),
    ) -> None:
        """Creates an instance of the class."""
        self.out_shape = out_shape
        self.center_ratio_range = center_ratio_range

    def __call__(self, input_hw: list[tuple[int, int]]) -> list[MosaicParam]:
        """Compute the parameters and put them in the data dict."""
        assert (
            len(input_hw) % self.NUM_SAMPLES == 0
        ), "Input number of images must be a multiple of 4 for Mosaic."

        h, w = self.out_shape
        # mosaic center x, y
        center_y = int(random.uniform(*self.center_ratio_range) * h)
        center_x = int(random.uniform(*self.center_ratio_range) * w)
        center = (center_y, center_x)

        mosaic_params = []
        for i in range(0, len(input_hw), self.NUM_SAMPLES):
            paste_coords, crop_coords, im_scales, im_shapes = [], [], [], []
            for idx, ori_hw in enumerate(input_hw[i : i + self.NUM_SAMPLES]):
                # compute the resize shape
                scale_ratio_i = min(h / ori_hw[0], w / ori_hw[1])
                h_i = int(ori_hw[0] * scale_ratio_i)
                w_i = int(ori_hw[1] * scale_ratio_i)

                # compute the combine parameters
                paste_coord, crop_coord = mosaic_combine(
                    idx, center, (h_i, w_i), self.out_shape
                )
                paste_coords.append(paste_coord)
                crop_coords.append(crop_coord)
                im_shapes.append((h_i, w_i))
                im_scales.append((scale_ratio_i, scale_ratio_i))

            mosaic_params += [
                MosaicParam(
                    out_shape=self.out_shape,
                    paste_coords=paste_coords,
                    crop_coords=crop_coords,
                    im_shapes=im_shapes,
                    im_scales=im_scales,
                )
                for _ in range(self.NUM_SAMPLES)
            ]
        return mosaic_params
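
# Worked example (illustrative, not part of the original module): with
# out_shape=(480, 640), a 720x1280 input is rescaled by
#
#   scale = min(480 / 720, 640 / 1280)                   # 0.5
#   resized_hw = (int(720 * scale), int(1280 * scale))   # (360, 640)
#
# before its paste/crop coordinates are computed, so each sub-image fits the
# mosaic patch while keeping its aspect ratio. For every group of four
# inputs the same MosaicParam entry is repeated four times, one per sample,
# so downstream transforms (MosaicImages, MosaicBoxes2D) can look up the
# mosaic parameters per image.
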
@Transform(
    in_keys=[
        K.images,
        "transforms.mosaic.out_shape",
        "transforms.mosaic.paste_coords",
        "transforms.mosaic.crop_coords",
        "transforms.mosaic.im_shapes",
    ],
    out_keys=[K.images, K.input_hw],
)
class MosaicImages:
    """Apply Mosaic to images."""

    NUM_SAMPLES = 4

    def __init__(
        self,
        pad_value: float = 114.0,
        interpolation: str = "bilinear",
        imresize_backend: str = "torch",
    ) -> None:
        """Creates an instance of the class.

        Args:
            pad_value (float): The value to pad the image with. Defaults to
                114.0.
            interpolation (str): Interpolation mode for resizing image.
                Defaults to bilinear.
            imresize_backend (str): One of torch, cv2. Defaults to torch.
        """
        self.pad_value = pad_value
        self.interpolation = interpolation
        self.imresize_backend = imresize_backend
        assert imresize_backend in {
            "torch",
            "cv2",
        }, f"Invalid imresize backend: {imresize_backend}"

    def __call__(
        self,
        images: list[NDArrayF32],
        out_shape: list[tuple[int, int]],
        paste_coords: list[list[tuple[int, int, int, int]]],
        crop_coords: list[list[tuple[int, int, int, int]]],
        im_shapes: list[list[tuple[int, int]]],
    ) -> tuple[list[NDArrayF32], list[tuple[int, int]]]:
        """Apply Mosaic to images of dimensions [N, H, W, C]."""
        h, w = out_shape[0]
        c = images[0].shape[-1]
        mosaic_imgs = []
        for i in range(0, len(images), self.NUM_SAMPLES):
            mosaic_img = np.full(
                (1, h * 2, w * 2, c), self.pad_value, dtype=np.float32
            )
            for idx, img in enumerate(images[i : i + self.NUM_SAMPLES]):
                # resize current image
                h_i, w_i = im_shapes[i][idx]
                img_ = resize_image(
                    img,
                    (h_i, w_i),
                    self.interpolation,
                    backend=self.imresize_backend,
                )

                x1_p, y1_p, x2_p, y2_p = paste_coords[i][idx]
                x1_c, y1_c, x2_c, y2_c = crop_coords[i][idx]

                # crop and paste image
                mosaic_img[:, y1_p:y2_p, x1_p:x2_p, :] = img_[
                    :, y1_c:y2_c, x1_c:x2_c, :
                ]
            mosaic_imgs += [mosaic_img for _ in range(self.NUM_SAMPLES)]
        return mosaic_imgs, [(m.shape[1], m.shape[2]) for m in mosaic_imgs]
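
# Illustrative sketch (not part of the original module): the paste step above
# is plain numpy slicing into a padded (2h, 2w) canvas. A standalone
# equivalent, assuming the sub-image has already been resized to the shape
# stored in im_shapes:
#
#   canvas = np.full((1, 2 * 480, 2 * 640, 3), 114.0, dtype=np.float32)
#   sub = np.zeros((1, 480, 640, 3), dtype=np.float32)  # resized sub-image
#   (x1_p, y1_p, x2_p, y2_p), (x1_c, y1_c, x2_c, y2_c) = mosaic_combine(
#       0, (480, 640), (480, 640), (480, 640)
#   )
#   canvas[:, y1_p:y2_p, x1_p:x2_p, :] = sub[:, y1_c:y2_c, x1_c:x2_c, :]
#
# Every group of four inputs yields the same mosaic image repeated four
# times, so the number of samples in the batch is preserved.
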
@Transform(
    in_keys=[
        K.boxes2d,
        K.boxes2d_classes,
        K.boxes2d_track_ids,
        "transforms.mosaic.paste_coords",
        "transforms.mosaic.crop_coords",
        "transforms.mosaic.im_scales",
    ],
    out_keys=[K.boxes2d, K.boxes2d_classes, K.boxes2d_track_ids],
)
class MosaicBoxes2D:
    """Apply Mosaic to a list of 2D bounding boxes."""

    NUM_SAMPLES = 4

    def __init__(
        self, clip_inside_image: bool = True, max_track_ids: int = 1000
    ) -> None:
        """Creates an instance of the class.

        Args:
            clip_inside_image (bool): Whether to clip the boxes to be inside
                the image. Defaults to True.
            max_track_ids (int): The maximum number of track ids. Defaults
                to 1000.
        """
        self.clip_inside_image = clip_inside_image
        self.max_track_ids = max_track_ids

    def __call__(
        self,
        boxes: list[NDArrayF32],
        classes: list[NDArrayI64],
        track_ids: list[NDArrayI64] | None,
        paste_coords: list[list[tuple[int, int, int, int]]],
        crop_coords: list[list[tuple[int, int, int, int]]],
        im_scales: list[list[tuple[float, float]]],
    ) -> tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None]:
        """Apply Mosaic to 2D bounding boxes."""
        new_boxes, new_classes = [], []
        new_track_ids: list[NDArrayI64] | None = (
            [] if track_ids is not None else None
        )
        for i in range(0, len(boxes), self.NUM_SAMPLES):
            for idx in range(self.NUM_SAMPLES):
                j = i + idx
                x1_p, y1_p, x2_p, y2_p = paste_coords[i][idx]
                x1_c, y1_c, _, _ = crop_coords[i][idx]
                pw = x1_p - x1_c
                ph = y1_p - y1_c

                boxes[j][:, [0, 2]] = (
                    im_scales[i][idx][1] * boxes[j][:, [0, 2]] + pw
                )
                boxes[j][:, [1, 3]] = (
                    im_scales[i][idx][0] * boxes[j][:, [1, 3]] + ph
                )

                keep_mask = _get_keep_mask(
                    boxes[j], np.array([x1_p, y1_p, x2_p, y2_p])
                )
                boxes[j] = boxes[j][keep_mask]
                classes[j] = classes[j][keep_mask]
                if track_ids is not None:
                    track_ids[j] = track_ids[j][keep_mask].copy()
                    if len(track_ids[j]) > 0:
                        if max(track_ids[j]) >= self.max_track_ids:
                            raise ValueError(
                                f"Track id exceeds maximum track id "
                                f"{self.max_track_ids}!"
                            )
                        track_ids[j] += self.max_track_ids * idx

                if self.clip_inside_image:
                    boxes[j][:, [0, 2]] = boxes[j][:, [0, 2]].clip(x1_p, x2_p)
                    boxes[j][:, [1, 3]] = boxes[j][:, [1, 3]].clip(y1_p, y2_p)

            new_boxes += [
                np.concatenate(boxes[i : i + self.NUM_SAMPLES])
                for _ in range(self.NUM_SAMPLES)
            ]
            new_classes += [
                np.concatenate(classes[i : i + self.NUM_SAMPLES])
                for _ in range(self.NUM_SAMPLES)
            ]
            if track_ids is not None:
                assert new_track_ids is not None
                new_track_ids += [
                    np.concatenate(track_ids[i : i + self.NUM_SAMPLES])
                    for _ in range(self.NUM_SAMPLES)
                ]
        return new_boxes, new_classes, new_track_ids
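

if __name__ == "__main__":
    # Illustrative demo (added for documentation, not part of the original
    # module): build a toy mosaic canvas from four constant-value images
    # using only mosaic_combine and numpy, mirroring the paste logic of
    # MosaicImages. All sizes below are arbitrary example values.
    out_hw = (480, 640)
    center = (480, 640)  # (center_y, center_x), here exactly 1.0 * out_hw
    canvas = np.full((2 * out_hw[0], 2 * out_hw[1], 3), 114.0, np.float32)
    for idx in range(4):
        # pretend this is a sub-image already resized to the mosaic patch
        sub = np.full((*out_hw, 3), float(50 * (idx + 1)), np.float32)
        (x1_p, y1_p, x2_p, y2_p), (x1_c, y1_c, x2_c, y2_c) = mosaic_combine(
            idx, center, out_hw, out_hw
        )
        canvas[y1_p:y2_p, x1_p:x2_p] = sub[y1_c:y2_c, x1_c:x2_c]
    # A box of sub-image idx would be shifted by (x1_p - x1_c, y1_p - y1_c)
    # after scaling, as done in MosaicBoxes2D; for the top-left slot with
    # this center the offset is (0, 0).
    print("mosaic canvas shape:", canvas.shape)  # (960, 1280, 3)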