"""Crop transformation."""
from __future__ import annotations
import math
from collections.abc import Callable
from typing import List, Tuple, TypedDict, Union
import numpy as np
import torch
from vis4d.common.logging import rank_zero_warn
from vis4d.common.typing import (
NDArrayBool,
NDArrayF32,
NDArrayI32,
NDArrayI64,
NDArrayUI8,
)
from vis4d.data.const import CommonKeys as K
from vis4d.op.box.box2d import bbox_intersection
from .base import Transform
CropShape = Union[
Tuple[float, float],
Tuple[int, int],
List[Tuple[float, float]],
List[Tuple[int, int]],
]
CropFunc = Callable[[int, int, CropShape], Tuple[int, int]]
class CropParam(TypedDict):
"""Parameters for Crop."""
crop_box: NDArrayI32
keep_mask: NDArrayBool
def absolute_crop(im_h: int, im_w: int, shape: CropShape) -> tuple[int, int]:
"""Absolute crop."""
assert isinstance(shape, tuple)
assert shape[0] > 0 and shape[1] > 0
return (min(int(shape[0]), im_h), min(int(shape[1]), im_w))
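# Illustrative examples (not part of the original module); requested sizes
# are clamped to the image size:
#     absolute_crop(480, 640, (256, 256))  # -> (256, 256)
#     absolute_crop(480, 640, (600, 800))  # -> (480, 640)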
def absolute_range_crop(
im_h: int, im_w: int, shape: CropShape
) -> tuple[int, int]:
"""Absolute range crop."""
assert isinstance(shape, list)
assert len(shape) == 2
assert shape[1][0] >= shape[0][0]
assert shape[1][1] >= shape[0][1]
for crop in shape:
assert crop[0] > 0 and crop[1] > 0
shape_min: tuple[int, int] = (int(shape[0][0]), int(shape[0][1]))
shape_max: tuple[int, int] = (int(shape[1][0]), int(shape[1][1]))
crop_h = np.random.randint(
min(im_h, shape_min[0]), min(im_h, shape_max[0]) + 1
)
crop_w = np.random.randint(
min(im_w, shape_min[1]), min(im_w, shape_max[1]) + 1
)
return int(crop_h), int(crop_w)
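# Illustrative example: absolute_range_crop(480, 640, [(200, 300), (400, 500)])
# samples crop_h uniformly from [200, 400] and crop_w from [300, 500], each
# bound clamped to the image size first.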
def relative_crop(im_h: int, im_w: int, shape: CropShape) -> tuple[int, int]:
"""Relative crop."""
assert isinstance(shape, tuple)
assert 0 < shape[0] <= 1 and 0 < shape[1] <= 1
crop_h, crop_w = shape
return int(im_h * crop_h + 0.5), int(im_w * crop_w + 0.5)
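# Illustrative example:
#     relative_crop(480, 640, (0.5, 0.25))  # -> (240, 160)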
def relative_range_crop(
im_h: int, im_w: int, shape: CropShape
) -> tuple[int, int]:
"""Relative range crop."""
assert isinstance(shape, list)
assert len(shape) == 2
assert shape[1][0] >= shape[0][0]
assert shape[1][1] >= shape[0][1]
for crop in shape:
assert 0 < crop[0] <= 1 and 0 < crop[1] <= 1
scale_min: tuple[float, float] = shape[0]
scale_max: tuple[float, float] = shape[1]
crop_h = np.random.rand() * (scale_max[0] - scale_min[0]) + scale_min[0]
crop_w = np.random.rand() * (scale_max[1] - scale_min[1]) + scale_min[1]
return int(im_h * crop_h + 0.5), int(im_w * crop_w + 0.5)
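# Illustrative example: relative_range_crop(480, 640, [(0.5, 0.5), (1.0, 1.0)])
# samples a per-axis scale in [0.5, 1.0], yielding crop_h in [240, 480] and
# crop_w in [320, 640].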
@Transform(
in_keys=[K.input_hw, K.boxes2d, K.seg_masks],
out_keys="transforms.crop",
)
class GenCropParameters:
"""Generate the parameters for a crop operation."""
def __init__(
self,
shape: CropShape,
crop_func: CropFunc = absolute_crop,
allow_empty_crops: bool = True,
cat_max_ratio: float = 1.0,
ignore_index: int = 255,
) -> None:
"""Creates an instance of the class.
Args:
shape (CropShape): Image shape to be cropped to in [H, W].
crop_func (CropFunc, optional): Function used to generate the size
of the crop. Defaults to absolute_crop.
allow_empty_crops (bool, optional): Allow crops which result in
empty labels. Defaults to True.
cat_max_ratio (float, optional): Maximum ratio of a particular
class in segmentation masks after cropping. Defaults to 1.0.
ignore_index (int, optional): The index to ignore. Defaults to 255.
"""
self.shape = shape
self.crop_func = crop_func
self.allow_empty_crops = allow_empty_crops
self.cat_max_ratio = cat_max_ratio
self.ignore_index = ignore_index
def _get_crop(
self, im_h: int, im_w: int, boxes: NDArrayF32 | None = None
) -> tuple[NDArrayI32, NDArrayBool]:
"""Get the crop parameters."""
crop_size = self.crop_func(im_h, im_w, self.shape)
crop_box = _sample_crop(im_h, im_w, crop_size)
keep_mask = _get_keep_mask(boxes, crop_box)
return crop_box, keep_mask
def __call__(
self,
input_hw_list: list[tuple[int, int]],
boxes_list: list[NDArrayF32] | None,
masks_list: list[NDArrayUI8] | None,
) -> list[CropParam]:
"""Compute the parameters and put them in the data dict."""
im_h, im_w = input_hw_list[0]
boxes = boxes_list[0] if boxes_list is not None else None
masks = masks_list[0] if masks_list is not None else None
crop_box, keep_mask = self._get_crop(im_h, im_w, boxes)
if (boxes is not None and len(boxes) > 0) or masks is not None:
# resample crop if conditions not satisfied
found_crop = False
for _ in range(10):
# try resampling 10 times, otherwise use last crop
if (self.allow_empty_crops or keep_mask.sum() != 0) and (
_check_seg_max_cat(
masks, crop_box, self.cat_max_ratio, self.ignore_index
)
):
found_crop = True
break
crop_box, keep_mask = self._get_crop(im_h, im_w, boxes)
if not found_crop:
rank_zero_warn("Random crop not found within 10 resamples.")
crop_params = [
CropParam(crop_box=crop_box, keep_mask=keep_mask)
] * len(input_hw_list)
return crop_params
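# Usage sketch (illustrative, not part of the original module). The
# parameter generator is meant to run before the Crop* ops below so that
# they all consume the same "transforms.crop" entry; assuming vis4d's
# compose helper from .base, a crop pipeline could look like:
#
#     crop = compose(
#         [
#             GenCropParameters((512, 512), cat_max_ratio=0.75),
#             CropImages(),
#             CropBoxes2D(),
#             CropSegMasks(),
#         ]
#     )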
@Transform([K.input_hw, K.boxes2d], "transforms.crop")
class GenCentralCropParameters:
"""Generate the parameters for a central crop operation."""
def __init__(
self,
shape: CropShape,
crop_func: CropFunc = absolute_crop,
) -> None:
"""Creates an instance of the class.
Args:
            shape (CropShape): Image shape to be cropped to in [H, W].
crop_func (CropFunc, optional): Function used to generate the size
of the crop. Defaults to absolute_crop.
"""
self.shape = shape
self.crop_func = crop_func
def __call__(
self,
input_hw_list: list[tuple[int, int]],
boxes_list: list[NDArrayF32] | None,
) -> list[CropParam]:
"""Compute the parameters and put them in the data dict."""
im_h, im_w = input_hw_list[0]
boxes = boxes_list[0] if boxes_list is not None else None
crop_size = self.crop_func(im_h, im_w, self.shape)
crop_box = _get_central_crop(im_h, im_w, crop_size)
keep_mask = _get_keep_mask(boxes, crop_box)
crop_params = [
CropParam(crop_box=crop_box, keep_mask=keep_mask)
] * len(input_hw_list)
return crop_params
@Transform([K.input_hw, K.boxes2d], "transforms.crop")
class GenRandomSizeCropParameters:
"""Generate the parameters for a random size crop operation.
A crop of the original image is made: the crop has a random area (H * W)
and a random aspect ratio. Code adapted from torchvision.
"""
def __init__(
self,
scale: tuple[float, float] = (0.08, 1.0),
ratio: tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
):
"""Creates an instance of the class.
Args:
scale (tuple[float, float], optional): Scale range of the cropped
area. Defaults to (0.08, 1.0).
ratio (tuple[float, float], optional): Aspect ratio range of the
cropped area. Defaults to (3.0 / 4.0, 4.0 / 3.0).
"""
self.scale = scale
self.ratio = np.array(ratio)
self.log_ratio = np.log(self.ratio)
def get_params(self, height: int, width: int) -> NDArrayI32:
"""Get parameters for the random size crop."""
area = height * width
for _ in range(10):
target_area = area * np.random.uniform(
self.scale[0], self.scale[1]
)
aspect_ratio = np.exp(
np.random.uniform(self.log_ratio[0], self.log_ratio[1])
)
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if 0 < w <= width and 0 < h <= height:
i = np.random.randint(0, height - h + 1)
j = np.random.randint(0, width - w + 1)
                # i is the row (y) offset and j the column (x) offset.
                crop_x1, crop_y1, crop_x2, crop_y2 = j, i, j + w, i + h
return np.array([crop_x1, crop_y1, crop_x2, crop_y2])
# Fallback to central crop
in_ratio = float(width) / float(height)
if in_ratio < min(self.ratio):
w = width
h = int(round(w / min(self.ratio)))
elif in_ratio > max(self.ratio):
h = height
w = int(round(h * max(self.ratio)))
else: # whole image
w = width
h = height
i = (height - h) // 2
j = (width - w) // 2
        # i is the row (y) offset and j the column (x) offset.
        crop_x1, crop_y1, crop_x2, crop_y2 = j, i, j + w, i + h
return np.array([crop_x1, crop_y1, crop_x2, crop_y2])
def __call__(
self,
input_hw_list: list[tuple[int, int]],
boxes_list: list[NDArrayF32] | None,
) -> list[CropParam]:
"""Compute the parameters and put them in the data dict."""
im_h, im_w = input_hw_list[0]
boxes = boxes_list[0] if boxes_list is not None else None
crop_box = self.get_params(im_h, im_w)
keep_mask = _get_keep_mask(boxes, crop_box)
crop_params = [
CropParam(crop_box=crop_box, keep_mask=keep_mask)
] * len(input_hw_list)
return crop_params
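# Worked sketch (illustrative): crop boxes are returned in the same
# [x1, y1, x2, y2] pixel format used throughout this module. If all ten
# sampling attempts fail for a square 480x480 image, the fallback keeps
# the whole image (aspect ratio 1.0 lies inside the default (3/4, 4/3)
# range) and returns np.array([0, 0, 480, 480]).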
@Transform([K.images, "transforms.crop.crop_box"], [K.images, K.input_hw])
class CropImages:
"""Crop Images."""
def __call__(
self, images: list[NDArrayF32], crop_box_list: list[NDArrayI32]
) -> tuple[list[NDArrayF32], list[tuple[int, int]]]:
"""Crop a list of image of dimensions [N, H, W, C].
Args:
images (list[NDArrayF32]): The list of image.
crop_box (list[NDArrayI32]): The list of box to crop.
Returns:
list[NDArrayF32]: List of cropped image according to parameters.
"""
input_hw_list = []
for i, (image, crop_box) in enumerate(zip(images, crop_box_list)):
h, w = image.shape[1], image.shape[2]
x1, y1, x2, y2 = crop_box
crop_w, crop_h = x2 - x1, y2 - y1
image = image[:, y1:y2, x1:x2, :]
input_hw = (min(crop_h, h), min(crop_w, w))
images[i] = image
input_hw_list.append(input_hw)
return images, input_hw_list
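# Shape sketch (illustrative): for an image of shape (1, 480, 640, 3) and
# crop_box np.array([100, 50, 400, 250]), the slice image[:, 50:250,
# 100:400, :] has shape (1, 200, 300, 3) and input_hw becomes (200, 300).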
@Transform(
in_keys=[
K.boxes2d,
K.boxes2d_classes,
K.boxes2d_track_ids,
"transforms.crop.crop_box",
"transforms.crop.keep_mask",
],
out_keys=[K.boxes2d, K.boxes2d_classes, K.boxes2d_track_ids],
)
class CropBoxes2D:
"""Crop 2D bounding boxes."""
def __call__(
self,
boxes_list: list[NDArrayF32],
classes_list: list[NDArrayI64],
track_ids_list: list[NDArrayI64] | None,
crop_box_list: list[NDArrayI32],
keep_mask_list: list[NDArrayBool],
) -> tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None]:
"""Crop 2D bounding boxes.
Args:
boxes_list (list[NDArrayF32]): The list of bounding boxes to be
cropped.
classes_list (list[NDArrayI64]): The list of the corresponding
classes.
            track_ids_list (list[NDArrayI64] | None): The list of
                corresponding tracking IDs, if available.
            crop_box_list (list[NDArrayI32]): The list of crop boxes.
keep_mask_list (list[NDArrayBool]): Which boxes to keep.
Returns:
            tuple[list[NDArrayF32], list[NDArrayI64], list[NDArrayI64] | None]:
                The cropped bounding boxes, classes, and optional track IDs.
"""
for i, (boxes, classes, crop_box, keep_mask) in enumerate(
zip(
boxes_list,
classes_list,
crop_box_list,
keep_mask_list,
)
):
x1, y1 = crop_box[:2]
boxes -= np.array([x1, y1, x1, y1])
boxes_list[i] = boxes[keep_mask]
classes_list[i] = classes[keep_mask]
if track_ids_list is not None:
track_ids_list[i] = track_ids_list[i][keep_mask]
return boxes_list, classes_list, track_ids_list
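# Worked sketch (illustrative): with crop_box [100, 50, 400, 250], a kept
# box [120., 60., 200., 100.] is shifted to the crop origin, becoming
# [20., 10., 100., 50.]; entries with keep_mask False are dropped together
# with their classes and track IDs.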
@Transform([K.seg_masks, "transforms.crop.crop_box"], K.seg_masks)
class CropSegMasks:
"""Crop segmentation masks."""
def __call__(
self, masks_list: list[NDArrayUI8], crop_box_list: list[NDArrayI32]
) -> list[NDArrayUI8]:
"""Crop masks."""
for i, (masks, crop_box) in enumerate(zip(masks_list, crop_box_list)):
x1, y1, x2, y2 = crop_box
masks_list[i] = masks[y1:y2, x1:x2]
return masks_list
@Transform(
in_keys=[
K.instance_masks,
"transforms.crop.crop_box",
"transforms.crop.keep_mask",
],
out_keys=[K.instance_masks],
)
class CropInstanceMasks:
"""Crop instance segmentation masks."""
def __call__(
self,
masks_list: list[NDArrayUI8],
crop_box_list: list[NDArrayI32],
keep_mask_list: list[NDArrayBool],
) -> list[NDArrayUI8]:
"""Crop masks."""
for i, (masks, crop_box) in enumerate(zip(masks_list, crop_box_list)):
x1, y1, x2, y2 = crop_box
masks = masks[:, y1:y2, x1:x2]
masks_list[i] = masks[keep_mask_list[i]]
return masks_list
@Transform([K.depth_maps, "transforms.crop.crop_box"], K.depth_maps)
class CropDepthMaps:
"""Crop depth maps."""
def __call__(
self, depth_maps: list[NDArrayF32], crop_box_list: list[NDArrayI32]
) -> list[NDArrayF32]:
"""Crop depth maps."""
for i, (depth_map, crop_box) in enumerate(
zip(depth_maps, crop_box_list)
):
x1, y1, x2, y2 = crop_box
depth_maps[i] = depth_map[y1:y2, x1:x2]
return depth_maps
@Transform([K.optical_flows, "transforms.crop.crop_box"], K.optical_flows)
class CropOpticalFlows:
"""Crop optical flows."""
def __call__(
        self, optical_flows: list[NDArrayF32], crop_box_list: list[NDArrayI32]
) -> list[NDArrayF32]:
"""Crop optical flows."""
for i, (optical_flow, crop_box) in enumerate(
zip(optical_flows, crop_box_list)
):
x1, y1, x2, y2 = crop_box
optical_flows[i] = optical_flow[y1:y2, x1:x2]
return optical_flows
@Transform([K.intrinsics, "transforms.crop.crop_box"], K.intrinsics)
class CropIntrinsics:
"""Crop Intrinsics."""
def __call__(
self,
intrinsics_list: list[NDArrayF32],
crop_box_list: list[NDArrayI32],
) -> list[NDArrayF32]:
"""Crop camera intrinsics."""
for i, crop_box in enumerate(crop_box_list):
x1, y1 = crop_box[:2]
intrinsics_list[i][0, 2] -= x1
intrinsics_list[i][1, 2] -= y1
return intrinsics_list
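# Worked example (illustrative): cropping only shifts the principal point,
# leaving the focal lengths untouched. For a crop box with x1=100, y1=50
# and intrinsics with (cx, cy) = (320, 240), the cropped intrinsics have
# (cx, cy) = (220, 190).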
def _sample_crop(
im_h: int, im_w: int, crop_size: tuple[int, int]
) -> NDArrayI32:
"""Sample crop parameters according to config."""
margin_h = max(im_h - crop_size[0], 0)
margin_w = max(im_w - crop_size[1], 0)
offset_h = np.random.randint(0, margin_h + 1)
offset_w = np.random.randint(0, margin_w + 1)
crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]
return np.array([crop_x1, crop_y1, crop_x2, crop_y2])
def _get_central_crop(
im_h: int, im_w: int, crop_size: tuple[int, int]
) -> NDArrayI32:
"""Get central crop parameters."""
margin_h = max(im_h - crop_size[0], 0)
margin_w = max(im_w - crop_size[1], 0)
offset_h = margin_h // 2
offset_w = margin_w // 2
crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]
return np.array([crop_x1, crop_y1, crop_x2, crop_y2])
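# Worked example (illustrative): for a 480x640 image and a (256, 256) crop,
# the margins are (224, 384), the central offsets (112, 192), and the
# returned box is np.array([192, 112, 448, 368]).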
def _get_keep_mask(
boxes: NDArrayF32 | None, crop_box: NDArrayI32
) -> NDArrayBool:
"""Get mask for 2D annotations to keep."""
if boxes is None or len(boxes) == 0:
return np.array([], dtype=bool)
    # It would be better to compute the mask intersection (if masks exist).
overlap = bbox_intersection(
torch.tensor(boxes), torch.tensor(crop_box).unsqueeze(0)
).numpy()
return overlap.squeeze(-1) > 0
def _check_seg_max_cat(
masks: NDArrayUI8 | None,
crop_box: NDArrayI32,
cat_max_ratio: float,
ignore_index: int = 255,
) -> bool:
"""Check if any category occupies more than cat_max_ratio.
Args:
masks (NDArrayUI8 | None): Segmentation masks.
crop_box (NDArrayI32): The box to crop.
cat_max_ratio (float): Maximum category ratio.
ignore_index (int, optional): The index to ignore. Defaults to 255.
Returns:
bool: True if no category occupies more than cat_max_ratio.
"""
if cat_max_ratio >= 1.0 or masks is None:
return True
x1, y1, x2, y2 = crop_box
crop_masks = masks[y1:y2, x1:x2]
cls_ids, cnts = np.unique(crop_masks, return_counts=True)
cnts = cnts[cls_ids != ignore_index]
    if len(cnts) == 0:
        # The crop contains only ignored pixels; trivially satisfied.
        return True
    return (cnts.max() / cnts.sum()) < cat_max_ratio
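# Worked example (illustrative): with cat_max_ratio=0.75, a crop whose
# largest class covers 90% of the non-ignored pixels gives
# cnts.max() / cnts.sum() == 0.9, so the check fails (returns False) and
# GenCropParameters resamples the crop box.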