"""SHIFT data loading config for segmentation."""
from __future__ import annotations
from ml_collections.config_dict import ConfigDict
from vis4d.common.typing import ArgsType
from vis4d.data.const import CommonKeys as K
from vis4d.engine.connectors import data_key, pred_key
from .common import get_shift_config
CONN_SHIFT_DET_EVAL = {
"frame_ids": data_key("frame_ids"),
"sample_names": data_key("sample_names"),
"sequence_names": data_key("sequence_names"),
"pred_boxes": pred_key("boxes"),
"pred_scores": pred_key("scores"),
"pred_classes": pred_key("class_ids"),
}
CONN_SHIFT_INS_EVAL = {
"frame_ids": data_key("frame_ids"),
"sample_names": data_key("sample_names"),
"sequence_names": data_key("sequence_names"),
"pred_boxes": pred_key("boxes.boxes"),
"pred_scores": pred_key("boxes.scores"),
"pred_classes": pred_key("boxes.class_ids"),
"pred_masks": pred_key("masks.masks"),
}
[docs]
def get_shift_sem_seg_config(**kwargs: ArgsType) -> ConfigDict:
"""Get the config for the SHIFT segmentation task."""
keys_to_load = (K.images, K.input_hw, K.original_hw, K.seg_masks)
cfg = get_shift_config(
train_keys_to_load=keys_to_load,
test_keys_to_load=keys_to_load,
horizontal_flip_prob=0.5,
color_jitter_prob=0.5,
crop_size=kwargs.get("crop_size", (512, 1024)),
**kwargs,
)
return cfg
[docs]
def get_shift_det_config(**kwargs: ArgsType) -> ConfigDict:
"""Get the config for the SHIFT detection task."""
keys_to_load = (
K.images,
K.input_hw,
K.original_hw,
K.boxes2d,
K.boxes2d_classes,
)
cfg = get_shift_config(
train_keys_to_load=keys_to_load,
test_keys_to_load=keys_to_load,
train_skip_empty_frames=True,
test_skip_empty_frames=False,
horizontal_flip_prob=0.5,
color_jitter_prob=0.0,
crop_size=kwargs.get("crop_size", None),
**kwargs,
)
return cfg
[docs]
def get_shift_instance_seg_config(**kwargs: ArgsType) -> ConfigDict:
"""Get the config for the SHIFT instance segmentation task."""
keys_to_load = (
K.images,
K.input_hw,
K.original_hw,
K.boxes2d,
K.boxes2d_classes,
K.instance_masks,
)
cfg = get_shift_config(
train_keys_to_load=keys_to_load,
test_keys_to_load=keys_to_load,
train_skip_empty_frames=True,
test_skip_empty_frames=False,
horizontal_flip_prob=0.5,
color_jitter_prob=0.5,
crop_size=kwargs.get("crop_size", None),
**kwargs,
)
return cfg
[docs]
def get_shift_depth_est_config(**kwargs: ArgsType) -> ConfigDict:
"""Get the config for the SHIFT depth estimation task."""
keys_to_load = (K.images, K.input_hw, K.original_hw, K.depth_maps)
cfg = get_shift_config(
train_keys_to_load=keys_to_load,
test_keys_to_load=keys_to_load,
horizontal_flip_prob=0.5,
crop_size=kwargs.get("crop_size", None),
**kwargs,
)
return cfg
[docs]
def get_shift_optical_flow_config(**kwargs: ArgsType) -> ConfigDict:
"""Get the config for the SHIFT optical flow task."""
keys_to_load = (K.images, K.input_hw, K.original_hw, K.optical_flows)
cfg = get_shift_config(
train_keys_to_load=keys_to_load,
test_keys_to_load=keys_to_load,
horizontal_flip_prob=0.5,
crop_size=kwargs.get("crop_size", None),
**kwargs,
)
return cfg
[docs]
def get_shift_tracking_config(**kwargs: ArgsType) -> ConfigDict:
"""Get the config for the SHIFT tracking task."""
keys_to_load = (
K.images,
K.input_hw,
K.original_hw,
K.boxes2d,
K.boxes2d_classes,
K.boxes2d_track_ids,
)
cfg = get_shift_config(
train_keys_to_load=keys_to_load,
test_keys_to_load=keys_to_load,
horizontal_flip_prob=0.5,
crop_size=kwargs.get("crop_size", None),
**kwargs,
)
return cfg
[docs]
def get_shift_multitask_2d_config(**kwargs: ArgsType) -> ConfigDict:
"""Get the config for the SHIFT multitask 2D task."""
keys_to_load = (
K.images,
K.input_hw,
K.original_hw,
K.intrinsics,
K.boxes2d,
K.boxes2d_classes,
K.boxes2d_track_ids,
K.seg_masks,
K.depth_maps,
)
cfg = get_shift_config(
train_keys_to_load=keys_to_load,
test_keys_to_load=keys_to_load,
horizontal_flip_prob=0.5,
crop_size=kwargs.get("crop_size", None),
**kwargs,
)
return cfg
[docs]
def get_shift_multitask_3d_config(**kwargs: ArgsType) -> ConfigDict:
"""Get the config for the SHIFT multitask 3D task."""
keys_to_load = (
K.images,
K.input_hw,
K.original_hw,
K.intrinsics,
K.boxes2d,
K.boxes2d_classes,
K.boxes2d_track_ids,
K.boxes3d,
K.boxes3d_classes,
K.boxes3d_track_ids,
K.seg_masks,
K.depth_maps,
)
cfg = get_shift_config(
train_keys_to_load=keys_to_load,
test_keys_to_load=keys_to_load,
horizontal_flip_prob=0.5,
crop_size=kwargs.get("crop_size", None),
**kwargs,
)
return cfg