# Source code for vis4d.zoo.base.datasets.shift.common (extracted from docs)

"""SHIFT data loading config for data augmentation."""

from __future__ import annotations

from collections.abc import Sequence

from ml_collections.config_dict import ConfigDict

from vis4d.config import class_config
from vis4d.data.const import CommonKeys as K
from vis4d.data.data_pipe import DataPipe
from vis4d.data.datasets.shift import SHIFT
from vis4d.data.loader import default_collate, multi_sensor_collate
from vis4d.data.transforms.base import RandomApply, compose
from vis4d.data.transforms.crop import (
    CropBoxes2D,
    CropDepthMaps,
    CropImages,
    CropOpticalFlows,
    CropSegMasks,
    GenCropParameters,
)
from vis4d.data.transforms.flip import (
    FlipBoxes2D,
    FlipDepthMaps,
    FlipImages,
    FlipInstanceMasks,
    FlipOpticalFlows,
    FlipSegMasks,
)
from vis4d.data.transforms.normalize import NormalizeImages
from vis4d.data.transforms.photometric import ColorJitter
from vis4d.data.transforms.resize import (
    GenResizeParameters,
    ResizeBoxes2D,
    ResizeDepthMaps,
    ResizeImages,
    ResizeInstanceMasks,
    ResizeOpticalFlows,
    ResizeSegMasks,
)
from vis4d.data.transforms.select_sensor import SelectSensor
from vis4d.data.transforms.to_tensor import ToTensor
from vis4d.zoo.base import (
    get_inference_dataloaders_cfg,
    get_train_dataloader_cfg,
)

# Per-channel image mean/std used by NormalizeImages below; the magnitudes
# indicate values on the 0-255 pixel scale (RGB order presumed — TODO confirm
# against vis4d.data.transforms.normalize.NormalizeImages).
IMAGE_MEAN = [122.884, 117.266, 110.287]
IMAGE_STD = [59.925, 59.466, 60.69]


def get_train_preprocessing(
    image_size: tuple[int, int] = (800, 1280),
    crop_size: tuple[int, int] | None = None,
    horizontal_flip_prob: float = 0.5,
    color_jitter_prob: float = 0.0,
    keys_to_load: Sequence[str] = (K.images, K.seg_masks),
    views_to_load: Sequence[str] = ("front",),
) -> tuple[ConfigDict, ConfigDict]:
    """Get the default training data preprocessing for the SHIFT dataset.

    Args:
        image_size: The image size to resize to. Defaults to (800, 1280).
        crop_size: The crop size to crop to randomly, if not None. Defaults
            to None. This step is applied after the resize step.
        horizontal_flip_prob: The probability of horizontal flipping.
            Defaults to 0.5.
        color_jitter_prob: The probability of color jittering. Defaults to
            0.0.
        keys_to_load: The keys to load from the dataset. Defaults to
            (K.images, K.seg_masks).
        views_to_load: The views to load from the dataset. Defaults to
            ("front",).

    Returns:
        The per-sample preprocessing config and the batch-level processing
        config, in that order.

    Raises:
        AssertionError: If a key in ``keys_to_load`` is not supported by the
            SHIFT dataset.
    """
    preprocess_transforms = []
    for key_to_load in keys_to_load:
        assert key_to_load in SHIFT.KEYS, f"Invalid key: {key_to_load}"

    views_arg = {}
    if len(views_to_load) == 1:
        # Single view: select that sensor so downstream transforms operate
        # on a flat (non multi-sensor) data dict.
        preprocess_transforms.append(
            class_config(
                SelectSensor,
                selected_sensor=views_to_load[0],
                sensors=views_to_load,
            )
        )
    elif len(views_to_load) > 1:
        # Multiple views: transforms are applied per sensor.
        views_arg["sensors"] = views_to_load

    # Resize. Skipped entirely when the target size equals the native SHIFT
    # resolution of (800, 1280).
    if image_size != (800, 1280):
        preprocess_transforms.append(
            class_config(
                GenResizeParameters,
                shape=image_size,
                keep_ratio=True,
                **views_arg,
            )
        )
        preprocess_transforms.append(class_config(ResizeImages, **views_arg))
        if K.seg_masks in keys_to_load:
            preprocess_transforms.append(
                class_config(ResizeSegMasks, **views_arg)
            )
        if K.boxes2d in keys_to_load:
            preprocess_transforms.append(
                class_config(ResizeBoxes2D, **views_arg)
            )
        if K.instance_masks in keys_to_load:
            preprocess_transforms.append(
                class_config(ResizeInstanceMasks, **views_arg)
            )
        if K.depth_maps in keys_to_load:
            preprocess_transforms.append(
                class_config(ResizeDepthMaps, **views_arg)
            )
        if K.optical_flows in keys_to_load:
            preprocess_transforms.append(
                class_config(
                    ResizeOpticalFlows, normalized_flow=False, **views_arg
                )
            )

    # Random crop (applied after resize).
    if crop_size is not None:
        preprocess_transforms.append(
            class_config(
                GenCropParameters,
                shape=crop_size,
                cat_max_ratio=0.75,
                **views_arg,
            ),
        )
        preprocess_transforms.append(class_config(CropImages, **views_arg))
        if K.seg_masks in keys_to_load:
            preprocess_transforms.append(
                class_config(CropSegMasks, **views_arg)
            )
        if K.boxes2d in keys_to_load:
            preprocess_transforms.append(
                class_config(CropBoxes2D, **views_arg)
            )
        if K.depth_maps in keys_to_load:
            preprocess_transforms.append(
                class_config(CropDepthMaps, **views_arg)
            )
        if K.optical_flows in keys_to_load:
            preprocess_transforms.append(
                class_config(CropOpticalFlows, **views_arg)
            )

    # Random horizontal flip, applied jointly to all loaded annotations so
    # they stay consistent with the flipped images.
    if horizontal_flip_prob > 0:
        flip_transforms = [class_config(FlipImages)]
        if K.seg_masks in keys_to_load:
            flip_transforms.append(class_config(FlipSegMasks))
        if K.boxes2d in keys_to_load:
            flip_transforms.append(class_config(FlipBoxes2D))
        if K.instance_masks in keys_to_load:
            flip_transforms.append(class_config(FlipInstanceMasks))
        if K.depth_maps in keys_to_load:
            flip_transforms.append(class_config(FlipDepthMaps))
        if K.optical_flows in keys_to_load:
            flip_transforms.append(class_config(FlipOpticalFlows))
        preprocess_transforms.append(
            class_config(
                RandomApply,
                transforms=flip_transforms,
                probability=horizontal_flip_prob,
                **views_arg,
            )
        )

    # Random color jitter (images only).
    if color_jitter_prob > 0:
        preprocess_transforms.append(
            class_config(
                RandomApply,
                transforms=[class_config(ColorJitter, **views_arg)],
                probability=color_jitter_prob,
            )
        )

    preprocess_transforms.append(
        class_config(
            NormalizeImages, mean=IMAGE_MEAN, std=IMAGE_STD, **views_arg
        )
    )

    train_preprocess_cfg = class_config(
        compose, transforms=preprocess_transforms
    )

    # Tensor conversion happens at the batch level.
    batchprocess_transforms = [class_config(ToTensor, **views_arg)]
    train_batchprocess_cfg = class_config(
        compose, transforms=batchprocess_transforms
    )

    return train_preprocess_cfg, train_batchprocess_cfg
def get_test_preprocessing(
    image_size: tuple[int, int] = (800, 1280),
    keys_to_load: Sequence[str] = (K.images, K.seg_masks),
    views_to_load: Sequence[str] = ("front",),
) -> tuple[ConfigDict, ConfigDict]:
    """Get the default test data preprocessing for the SHIFT dataset.

    Unlike the train pipeline, no random augmentation (crop, flip, color
    jitter) is applied.

    Args:
        image_size: The image size to resize to. Defaults to (800, 1280).
        keys_to_load: The keys to load from the dataset. Defaults to
            (K.images, K.seg_masks).
        views_to_load: The views to load from the dataset. Defaults to
            ("front",).

    Returns:
        The per-sample preprocessing config and the batch-level processing
        config, in that order.

    Raises:
        AssertionError: If a key in ``keys_to_load`` is not supported by the
            SHIFT dataset.
    """
    preprocess_transforms = []
    for key_to_load in keys_to_load:
        assert key_to_load in SHIFT.KEYS, f"Invalid key: {key_to_load}"

    views_arg = {}
    if len(views_to_load) == 1:
        # Single view: select that sensor so downstream transforms operate
        # on a flat (non multi-sensor) data dict.
        preprocess_transforms.append(
            class_config(
                SelectSensor,
                selected_sensor=views_to_load[0],
                sensors=views_to_load,
            )
        )
    elif len(views_to_load) > 1:
        # Multiple views: transforms are applied per sensor.
        views_arg["sensors"] = views_to_load

    # Resize. Skipped entirely when the target size equals the native SHIFT
    # resolution of (800, 1280).
    if image_size != (800, 1280):
        preprocess_transforms.append(
            class_config(
                GenResizeParameters,
                shape=image_size,
                keep_ratio=True,
                **views_arg,
            )
        )
        preprocess_transforms.append(class_config(ResizeImages, **views_arg))
        if K.seg_masks in keys_to_load:
            preprocess_transforms.append(
                class_config(ResizeSegMasks, **views_arg)
            )
        if K.boxes2d in keys_to_load:
            preprocess_transforms.append(
                class_config(ResizeBoxes2D, **views_arg)
            )
        if K.depth_maps in keys_to_load:
            preprocess_transforms.append(
                class_config(ResizeDepthMaps, **views_arg)
            )
        if K.optical_flows in keys_to_load:
            preprocess_transforms.append(
                class_config(ResizeOpticalFlows, **views_arg)
            )

    preprocess_transforms.append(
        class_config(
            NormalizeImages, mean=IMAGE_MEAN, std=IMAGE_STD, **views_arg
        )
    )

    test_preprocess_cfg = class_config(
        compose, transforms=preprocess_transforms
    )

    # Tensor conversion happens at the batch level.
    batchprocess_transforms = [class_config(ToTensor, **views_arg)]
    test_batchprocess_cfg = class_config(
        compose, transforms=batchprocess_transforms
    )

    return test_preprocess_cfg, test_batchprocess_cfg
def get_shift_dataloader_config(
    train_dataset_cfg: ConfigDict,
    test_dataset_cfg: ConfigDict,
    keys_to_load: Sequence[str] = (K.images, K.seg_masks),
    image_size: tuple[int, int] = (800, 1280),
    crop_size: tuple[int, int] | None = None,
    horizontal_flip_prob: float = 0.5,
    color_jitter_prob: float = 0.5,
    samples_per_gpu: int = 2,
    workers_per_gpu: int = 2,
    train_views_to_load: Sequence[str] = ("front",),
    test_views_to_load: Sequence[str] = ("front",),
) -> ConfigDict:
    """Get the default dataloader config for the SHIFT dataset.

    Args:
        train_dataset_cfg: Config of the training dataset.
        test_dataset_cfg: Config of the test dataset.
        keys_to_load: The keys to load from the dataset. Defaults to
            (K.images, K.seg_masks).
        image_size: The image size to resize to. Defaults to (800, 1280).
        crop_size: The crop size to crop to randomly in training, if not
            None. Defaults to None.
        horizontal_flip_prob: The probability of horizontal flipping in
            training. Defaults to 0.5.
        color_jitter_prob: The probability of color jittering in training.
            Defaults to 0.5.
        samples_per_gpu: Batch size per GPU. Defaults to 2.
        workers_per_gpu: Number of dataloader workers per GPU. Defaults to 2.
        train_views_to_load: The views to load for training. Defaults to
            ("front",).
        test_views_to_load: The views to load for testing. Defaults to
            ("front",).

    Returns:
        A ConfigDict with ``train_dataloader`` and ``test_dataloader``.
    """
    data = ConfigDict()

    train_preprocess_cfg, train_batchprocess_cfg = get_train_preprocessing(
        keys_to_load=keys_to_load,
        image_size=image_size,
        crop_size=crop_size,
        horizontal_flip_prob=horizontal_flip_prob,
        color_jitter_prob=color_jitter_prob,
        views_to_load=train_views_to_load,
    )
    test_preprocess_cfg, test_batchprocess_cfg = get_test_preprocessing(
        keys_to_load=keys_to_load,
        image_size=image_size,
        views_to_load=test_views_to_load,
    )

    data.train_dataloader = get_train_dataloader_cfg(
        preprocess_cfg=train_preprocess_cfg,
        dataset_cfg=train_dataset_cfg,
        batchprocess_cfg=train_batchprocess_cfg,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=workers_per_gpu,
        shuffle=True,
        # Multi-sensor samples need the dedicated collate function.
        collate_fn=(
            multi_sensor_collate
            if len(train_views_to_load) > 1
            else default_collate
        ),
    )

    # Test dataset config. A distinct local name avoids shadowing the
    # `test_dataset_cfg` parameter.
    test_datapipe_cfg = class_config(
        DataPipe, datasets=test_dataset_cfg, preprocess_fn=test_preprocess_cfg
    )
    data.test_dataloader = get_inference_dataloaders_cfg(
        datasets_cfg=test_datapipe_cfg,
        batchprocess_cfg=test_batchprocess_cfg,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=workers_per_gpu,
        collate_fn=(
            multi_sensor_collate
            if len(test_views_to_load) > 1
            else default_collate
        ),
    )
    return data
def get_shift_config(  # pylint: disable=too-many-arguments
    data_root: str = "data/shift/images",
    train_split: str = "train",
    train_framerate: str = "images",
    train_shift_type: str = "discrete",
    train_views_to_load: Sequence[str] = ("front",),
    train_keys_to_load: Sequence[str] = (K.images, K.seg_masks),
    train_attributes_to_load: Sequence[dict[str, str | float]] | None = None,
    train_skip_empty_frames: bool = False,
    test_split: str = "val",
    test_framerate: str = "images",
    test_shift_type: str = "discrete",
    test_views_to_load: Sequence[str] = ("front",),
    test_keys_to_load: Sequence[str] = (K.images, K.seg_masks),
    test_attributes_to_load: Sequence[dict[str, str | float]] | None = None,
    test_skip_empty_frames: bool = False,
    data_backend: None | ConfigDict = None,
    image_size: tuple[int, int] = (800, 1280),
    crop_size: tuple[int, int] | None = None,
    horizontal_flip_prob: float = 0.5,
    color_jitter_prob: float = 0.0,
    samples_per_gpu: int = 2,
    workers_per_gpu: int = 2,
) -> ConfigDict:
    """Get the default data config for the SHIFT dataset.

    Builds train/test SHIFT dataset configs and wires them into the
    dataloader config via :func:`get_shift_dataloader_config`.

    Args:
        data_root: Root directory of the SHIFT data. Defaults to
            "data/shift/images".
        train_split: Split name for training. Defaults to "train".
        train_framerate: Framerate for training data. Defaults to "images".
        train_shift_type: SHIFT type for training. Defaults to "discrete".
        train_views_to_load: Views to load for training. Defaults to
            ("front",).
        train_keys_to_load: Keys to load for training. Defaults to
            (K.images, K.seg_masks).
        train_attributes_to_load: Attribute filters for training samples,
            if any. Defaults to None.
        train_skip_empty_frames: Whether to skip empty frames in training.
            Defaults to False.
        test_split: Split name for testing. Defaults to "val".
        test_framerate: Framerate for test data. Defaults to "images".
        test_shift_type: SHIFT type for testing. Defaults to "discrete".
        test_views_to_load: Views to load for testing. Defaults to
            ("front",).
        test_keys_to_load: Keys to load for testing. Defaults to
            (K.images, K.seg_masks).
        test_attributes_to_load: Attribute filters for test samples, if any.
            Defaults to None.
        test_skip_empty_frames: Whether to skip empty frames in testing.
            Defaults to False.
        data_backend: Optional data backend config. Defaults to None.
        image_size: The image size to resize to. Defaults to (800, 1280).
        crop_size: The crop size to crop to randomly in training, if not
            None. Defaults to None.
        horizontal_flip_prob: The probability of horizontal flipping in
            training. Defaults to 0.5.
        color_jitter_prob: The probability of color jittering in training.
            Defaults to 0.0.
        samples_per_gpu: Batch size per GPU. Defaults to 2.
        workers_per_gpu: Number of dataloader workers per GPU. Defaults to 2.

    Returns:
        A ConfigDict with ``train_dataloader`` and ``test_dataloader``.
    """
    train_dataset_cfg = class_config(
        SHIFT,
        data_root=data_root,
        split=train_split,
        framerate=train_framerate,
        shift_type=train_shift_type,
        views_to_load=train_views_to_load,
        keys_to_load=train_keys_to_load,
        attributes_to_load=train_attributes_to_load,
        skip_empty_frames=train_skip_empty_frames,
        backend=data_backend,
    )
    test_dataset_cfg = class_config(
        SHIFT,
        data_root=data_root,
        split=test_split,
        framerate=test_framerate,
        shift_type=test_shift_type,
        views_to_load=test_views_to_load,
        keys_to_load=test_keys_to_load,
        attributes_to_load=test_attributes_to_load,
        skip_empty_frames=test_skip_empty_frames,
        backend=data_backend,
    )

    # NOTE: the preprocessing for both splits is keyed on the *train* keys.
    return get_shift_dataloader_config(
        train_dataset_cfg=train_dataset_cfg,
        test_dataset_cfg=test_dataset_cfg,
        keys_to_load=train_keys_to_load,
        image_size=image_size,
        crop_size=crop_size,
        horizontal_flip_prob=horizontal_flip_prob,
        color_jitter_prob=color_jitter_prob,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=workers_per_gpu,
        train_views_to_load=train_views_to_load,
        test_views_to_load=test_views_to_load,
    )