Source code for vis4d.zoo.base.datasets.shift.tasks

"""SHIFT data loading config for segmentation."""

from __future__ import annotations

from ml_collections.config_dict import ConfigDict

from vis4d.common.typing import ArgsType
from vis4d.data.const import CommonKeys as K
from vis4d.engine.connectors import data_key, pred_key

from .common import get_shift_config

CONN_SHIFT_DET_EVAL = {
    "frame_ids": data_key("frame_ids"),
    "sample_names": data_key("sample_names"),
    "sequence_names": data_key("sequence_names"),
    "pred_boxes": pred_key("boxes"),
    "pred_scores": pred_key("scores"),
    "pred_classes": pred_key("class_ids"),
}
CONN_SHIFT_INS_EVAL = {
    "frame_ids": data_key("frame_ids"),
    "sample_names": data_key("sample_names"),
    "sequence_names": data_key("sequence_names"),
    "pred_boxes": pred_key("boxes.boxes"),
    "pred_scores": pred_key("boxes.scores"),
    "pred_classes": pred_key("boxes.class_ids"),
    "pred_masks": pred_key("masks.masks"),
}


[docs] def get_shift_sem_seg_config(**kwargs: ArgsType) -> ConfigDict: """Get the config for the SHIFT segmentation task.""" keys_to_load = (K.images, K.input_hw, K.original_hw, K.seg_masks) cfg = get_shift_config( train_keys_to_load=keys_to_load, test_keys_to_load=keys_to_load, horizontal_flip_prob=0.5, color_jitter_prob=0.5, crop_size=kwargs.get("crop_size", (512, 1024)), **kwargs, ) return cfg
[docs] def get_shift_det_config(**kwargs: ArgsType) -> ConfigDict: """Get the config for the SHIFT detection task.""" keys_to_load = ( K.images, K.input_hw, K.original_hw, K.boxes2d, K.boxes2d_classes, ) cfg = get_shift_config( train_keys_to_load=keys_to_load, test_keys_to_load=keys_to_load, train_skip_empty_frames=True, test_skip_empty_frames=False, horizontal_flip_prob=0.5, color_jitter_prob=0.0, crop_size=kwargs.get("crop_size", None), **kwargs, ) return cfg
[docs] def get_shift_instance_seg_config(**kwargs: ArgsType) -> ConfigDict: """Get the config for the SHIFT instance segmentation task.""" keys_to_load = ( K.images, K.input_hw, K.original_hw, K.boxes2d, K.boxes2d_classes, K.instance_masks, ) cfg = get_shift_config( train_keys_to_load=keys_to_load, test_keys_to_load=keys_to_load, train_skip_empty_frames=True, test_skip_empty_frames=False, horizontal_flip_prob=0.5, color_jitter_prob=0.5, crop_size=kwargs.get("crop_size", None), **kwargs, ) return cfg
[docs] def get_shift_depth_est_config(**kwargs: ArgsType) -> ConfigDict: """Get the config for the SHIFT depth estimation task.""" keys_to_load = (K.images, K.input_hw, K.original_hw, K.depth_maps) cfg = get_shift_config( train_keys_to_load=keys_to_load, test_keys_to_load=keys_to_load, horizontal_flip_prob=0.5, crop_size=kwargs.get("crop_size", None), **kwargs, ) return cfg
[docs] def get_shift_optical_flow_config(**kwargs: ArgsType) -> ConfigDict: """Get the config for the SHIFT optical flow task.""" keys_to_load = (K.images, K.input_hw, K.original_hw, K.optical_flows) cfg = get_shift_config( train_keys_to_load=keys_to_load, test_keys_to_load=keys_to_load, horizontal_flip_prob=0.5, crop_size=kwargs.get("crop_size", None), **kwargs, ) return cfg
[docs] def get_shift_tracking_config(**kwargs: ArgsType) -> ConfigDict: """Get the config for the SHIFT tracking task.""" keys_to_load = ( K.images, K.input_hw, K.original_hw, K.boxes2d, K.boxes2d_classes, K.boxes2d_track_ids, ) cfg = get_shift_config( train_keys_to_load=keys_to_load, test_keys_to_load=keys_to_load, horizontal_flip_prob=0.5, crop_size=kwargs.get("crop_size", None), **kwargs, ) return cfg
[docs] def get_shift_multitask_2d_config(**kwargs: ArgsType) -> ConfigDict: """Get the config for the SHIFT multitask 2D task.""" keys_to_load = ( K.images, K.input_hw, K.original_hw, K.intrinsics, K.boxes2d, K.boxes2d_classes, K.boxes2d_track_ids, K.seg_masks, K.depth_maps, ) cfg = get_shift_config( train_keys_to_load=keys_to_load, test_keys_to_load=keys_to_load, horizontal_flip_prob=0.5, crop_size=kwargs.get("crop_size", None), **kwargs, ) return cfg
[docs] def get_shift_multitask_3d_config(**kwargs: ArgsType) -> ConfigDict: """Get the config for the SHIFT multitask 3D task.""" keys_to_load = ( K.images, K.input_hw, K.original_hw, K.intrinsics, K.boxes2d, K.boxes2d_classes, K.boxes2d_track_ids, K.boxes3d, K.boxes3d_classes, K.boxes3d_track_ids, K.seg_masks, K.depth_maps, ) cfg = get_shift_config( train_keys_to_load=keys_to_load, test_keys_to_load=keys_to_load, horizontal_flip_prob=0.5, crop_size=kwargs.get("crop_size", None), **kwargs, ) return cfg