"""SHIFT dataset."""

from __future__ import annotations

import json
import multiprocessing
import os
from collections.abc import Sequence
from functools import partial

import numpy as np
from tqdm import tqdm

from vis4d.common.imports import SCALABEL_AVAILABLE
from vis4d.common.logging import rank_zero_info
from vis4d.common.typing import NDArrayF32, NDArrayI64, NDArrayNumber
from vis4d.data.const import CommonKeys as K
from vis4d.data.datasets.util import im_decode, npy_decode
from vis4d.data.io import DataBackend, FileBackend, HDF5Backend, ZipBackend
from vis4d.data.typing import DictData

from .base import VideoDataset, VideoMapping
from .scalabel import Scalabel

shift_det_map = {
    "pedestrian": 0,
    "car": 1,
    "truck": 2,
    "bus": 3,
    "motorcycle": 4,
    "bicycle": 5,
}
shift_track_map = {
    "pedestrian": 0,
    "car": 1,
    "truck": 2,
    "bus": 3,
    "motorcycle": 4,
    "bicycle": 5,
}
shift_seg_map = {
    "unlabeled": 0,
    "building": 1,
    "fence": 2,
    "other": 3,
    "pedestrian": 4,
    "pole": 5,
    "road line": 6,
    "road": 7,
    "sidewalk": 8,
    "vegetation": 9,
    "vehicle": 10,
    "wall": 11,
    "traffic sign": 12,
    "sky": 13,
    "ground": 14,
    "bridge": 15,
    "rail track": 16,
    "guard rail": 17,
    "traffic light": 18,
    "static": 19,
    "dynamic": 20,
    "water": 21,
    "terrain": 22,
}
shift_seg_ignore = [
    "unlabeled",
    "other",
    "ground",
    "bridge",
    "rail track",
    "guard rail",
    "static",
    "dynamic",
    "water",
]
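
# NOTE: Illustrative sketch, not part of the original module. One common way
# to consume ``shift_seg_map`` together with ``shift_seg_ignore`` is to remap
# the ignored classes to an ignore index (255 is an assumption here, following
# the Cityscapes convention) and renumber the kept classes densely. The name
# ``shift_seg_train_map`` is hypothetical.
_kept_seg_classes = [n for n in shift_seg_map if n not in shift_seg_ignore]
shift_seg_train_map = {
    name: _kept_seg_classes.index(name) if name in _kept_seg_classes else 255
    for name in shift_seg_map
}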

if SCALABEL_AVAILABLE:
    from scalabel.label.io import parse
    from scalabel.label.typing import Config
    from scalabel.label.typing import Dataset as ScalabelData
else:
    raise ImportError("scalabel is not installed.")


def _get_extension(backend: DataBackend) -> str:
    """Get the appropriate file extension for the given backend."""
    if isinstance(backend, HDF5Backend):
        return ".hdf5"
    if isinstance(backend, ZipBackend):
        return ".zip"
    if isinstance(backend, FileBackend):  # pragma: no cover
        return ""
    raise ValueError(f"Unsupported backend {backend}.")  # pragma: no cover


class _SHIFTScalabelLabels(Scalabel):
    """Helper class for labels in SHIFT that are stored in Scalabel format."""

    VIEWS = [
        "front",
        "center",
        "left_45",
        "left_90",
        "right_45",
        "right_90",
        "left_stereo",
    ]

    def __init__(
        self,
        data_root: str,
        split: str,
        data_file: str = "",
        keys_to_load: Sequence[str] = (K.images, K.boxes2d),
        attributes_to_load: Sequence[dict[str, str | float]] | None = None,
        annotation_file: str = "",
        view: str = "front",
        framerate: str = "images",
        shift_type: str = "discrete",
        skip_empty_frames: bool = False,
        backend: DataBackend = HDF5Backend(),
        verbose: bool = False,
        num_workers: int = 1,
    ) -> None:
        """Initialize SHIFT dataset for one view.

        Args:
            data_root (str): Path to the root directory of the dataset.
            split (str): Which data split to load.
            data_file (str): Path to the data archive file. Default: "".
            keys_to_load (Sequence[str]): List of keys to load.
                Default: (K.images, K.boxes2d).
            attributes_to_load (Sequence[dict[str, str | float]] | None):
                List of attributes to load. Default: None.
            annotation_file (str): Path to the annotation file. Default: "".
            view (str): Which view to load. Default: "front". Options: "front",
                "center", "left_45", "left_90", "right_45", "right_90", and
                "left_stereo".
            framerate (str): Which framerate to load. Default: "images".
            shift_type (str): Which shift type to load. Default: "discrete".
                Options: "discrete", "continuous/1x", "continuous/10x", and
                "continuous/100x".
            skip_empty_frames (bool): Whether to skip frames with no
                instance annotations. Default: False.
            backend (DataBackend): Backend to use for loading data. Default:
                HDF5Backend().
            verbose (bool): Whether to print verbose logs. Default: False.
            num_workers (int): Number of workers to use for loading data.
                Default: 1.
        """
        self.verbose = verbose
        self.num_workers = num_workers

        # Validate input
        assert split in {"train", "val", "test"}, f"Invalid split '{split}'"
        assert view in _SHIFTScalabelLabels.VIEWS, f"Invalid view '{view}'"

        # Set attributes
        ext = _get_extension(backend)
        if shift_type.startswith("continuous"):
            shift_speed = shift_type.split("/")[-1]
            annotation_path = os.path.join(
                data_root,
                "continuous",
                framerate,
                shift_speed,
                split,
                view,
                annotation_file,
            )
            data_path = os.path.join(
                data_root,
                "continuous",
                framerate,
                shift_speed,
                split,
                view,
                f"{data_file}{ext}",
            )
        else:
            annotation_path = os.path.join(
                data_root, "discrete", framerate, split, view, annotation_file
            )
            data_path = os.path.join(
                data_root,
                "discrete",
                framerate,
                split,
                view,
                f"{data_file}{ext}",
            )
        super().__init__(
            data_path,
            annotation_path,
            data_backend=backend,
            keys_to_load=keys_to_load,
            attributes_to_load=attributes_to_load,
            skip_empty_samples=skip_empty_frames,
        )
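
    # NOTE: Illustrative sketch, not part of the original module. As an
    # example of the path construction above, data_root="/data/shift",
    # shift_type="continuous/10x", framerate="videos", split="train",
    # view="front", annotation_file="det_2d.json", data_file="img", and an
    # HDF5 backend would resolve to (hypothetical local paths):
    #     /data/shift/continuous/videos/10x/train/front/det_2d.json
    #     /data/shift/continuous/videos/10x/train/front/img.hdf5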

    def _generate_mapping(self) -> ScalabelData:
        """Generate data mapping."""
        # Skipping validation for much faster loading
        if self.verbose:
            rank_zero_info(
                "Loading annotation from '%s' ...", self.annotation_path
            )
        return self._load(self.annotation_path)

    def _load(self, filepath: str) -> ScalabelData:
        """Load labels from a single JSON file."""
        raw_frames: list[DictData] = []
        raw_groups: list[DictData] = []
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"{filepath} does not exist.")

        def process_file(path: str) -> DictData | None:
            """Parse one annotation file, returning its raw config (if any)."""
            raw_cfg = None
            with open(path, mode="r", encoding="utf-8") as fp:
                content = json.load(fp)
            if isinstance(content, dict):
                raw_frames.extend(content["frames"])
                if "groups" in content and content["groups"] is not None:
                    raw_groups.extend(content["groups"])
                if "config" in content and content["config"] is not None:
                    raw_cfg = content["config"]
            elif isinstance(content, list):
                raw_frames.extend(content)
            else:
                raise TypeError(
                    "The input file contains neither dict nor list."
                )

            rank_zero_info("Loaded SHIFT annotation from '%s'.", path)
            return raw_cfg

        if not (os.path.isfile(filepath) and filepath.endswith(".json")):
            raise TypeError("Input must be a single JSON file.")
        cfg = process_file(filepath)

        config = None
        if cfg is not None:
            config = Config(**cfg)

        parse_func = partial(parse, validate_frames=False)
        if self.num_workers > 1:
            with multiprocessing.Pool(self.num_workers) as pool:
                frames = []
                with tqdm(total=len(raw_frames)) as pbar:
                    for result in pool.imap_unordered(
                        parse_func, raw_frames, chunksize=1000
                    ):
                        frames.append(result)
                        pbar.update()
        else:
            frames = [parse_func(frame) for frame in raw_frames]
        return ScalabelData(frames=frames, config=config, groups=None)
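

# NOTE: Illustrative sketch, not part of the original module. Judging from
# the parsing logic in ``_SHIFTScalabelLabels._load`` (an assumption, not an
# official schema), an annotation file is either a Scalabel-style dict with
# "frames" plus optional "groups" and "config" entries, or a bare list of
# frame dicts. A minimal frame could look like the following; all names are
# made up, and ``parse(_EXAMPLE_FRAME, validate_frames=False)`` would turn it
# into a ``scalabel.label.typing.Frame``.
_EXAMPLE_FRAME = {
    "name": "00000000_img_front.jpg",
    "videoName": "0123-abc",
    "frameIndex": 0,
    "labels": [],
}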


class SHIFT(VideoDataset):
    """SHIFT dataset class, supporting multiple tasks and views."""

    DESCRIPTION = (
        "SHIFT Dataset, a synthetic driving dataset for continuous "
        "multi-task domain adaptation"
    )
    HOMEPAGE = "https://www.vis.xyz/shift/"
    PAPER = "https://arxiv.org/abs/2206.08367"
    LICENSE = "CC BY-NC-SA 4.0"

    KEYS = [
        # Inputs
        K.images,
        K.original_hw,
        K.input_hw,
        K.points3d,
        # Scalabel formatted annotations
        K.intrinsics,
        K.extrinsics,
        K.timestamp,
        K.axis_mode,
        K.boxes2d,
        K.boxes2d_classes,
        K.boxes2d_track_ids,
        K.instance_masks,
        K.boxes3d,
        K.boxes3d_classes,
        K.boxes3d_track_ids,
        # Bit masks
        K.seg_masks,
        K.depth_maps,
        K.optical_flows,
    ]

    VIEWS = [
        "front",
        "center",
        "left_45",
        "left_90",
        "right_45",
        "right_90",
        "left_stereo",
    ]

    DATA_GROUPS = {
        "img": [
            K.images,
            K.original_hw,
            K.input_hw,
            K.intrinsics,
        ],
        "det_2d": [
            K.timestamp,
            K.axis_mode,
            K.extrinsics,
            K.boxes2d,
            K.boxes2d_classes,
            K.boxes2d_track_ids,
        ],
        "det_3d": [
            K.boxes3d,
            K.boxes3d_classes,
            K.boxes3d_track_ids,
        ],
        "det_insseg_2d": [
            K.instance_masks,
        ],
        "semseg": [
            K.seg_masks,
        ],
        "depth": [
            K.depth_maps,
        ],
        "flow": [
            K.optical_flows,
        ],
        "lidar": [
            K.points3d,
        ],
    }

    GROUPS_IN_SCALABEL = ["det_2d", "det_3d", "det_insseg_2d"]

    def __init__(
        self,
        data_root: str,
        split: str,
        keys_to_load: Sequence[str] = (K.images, K.boxes2d),
        views_to_load: Sequence[str] = ("front",),
        attributes_to_load: Sequence[dict[str, str | float]] | None = None,
        framerate: str = "images",
        shift_type: str = "discrete",
        skip_empty_frames: bool = False,
        backend: DataBackend = HDF5Backend(),
        num_workers: int = 1,
        verbose: bool = False,
    ) -> None:
        """Initialize SHIFT dataset."""
        super().__init__(data_backend=backend)

        # Validate input
        assert split in {"train", "val", "test"}, f"Invalid split '{split}'."
        assert framerate in {
            "images",
            "videos",
        }, f"Invalid framerate '{framerate}'. Must be 'images' or 'videos'."
        assert shift_type in {
            "discrete",
            "continuous/1x",
            "continuous/10x",
            "continuous/100x",
        }, (
            f"Invalid shift_type '{shift_type}'. "
            "Must be one of 'discrete', 'continuous/1x', 'continuous/10x', "
            "or 'continuous/100x'."
        )
        self.validate_keys(keys_to_load)

        # Set attributes
        self.data_root = data_root
        self.split = split
        self.keys_to_load = keys_to_load
        self.views_to_load = views_to_load
        self.attributes_to_load = attributes_to_load
        self.framerate = framerate
        self.shift_type = shift_type
        self.backend = backend
        self.verbose = verbose
        self.ext = _get_extension(backend)

        if self.shift_type.startswith("continuous"):
            shift_speed = self.shift_type.split("/")[-1]
            self.annotation_base = os.path.join(
                self.data_root,
                "continuous",
                self.framerate,
                shift_speed,
                self.split,
            )
        else:
            self.annotation_base = os.path.join(
                self.data_root, self.shift_type, self.framerate, self.split
            )
        if self.verbose:
            print(f"Base: {self.annotation_base}. Backend: {self.backend}")

        # Get the data groups' classes that need to be loaded
        self._data_groups_to_load = self._get_data_groups(keys_to_load)
        if "det_2d" not in self._data_groups_to_load:
            raise ValueError(
                "In the current implementation, the 'det_2d' data group "
                "must be loaded in order to load any other data group."
            )

        self.scalabel_datasets = {}
        for view in self.views_to_load:
            if view == "center":
                # Load lidar data, which is only available for the center view
                self.scalabel_datasets["center/lidar"] = _SHIFTScalabelLabels(
                    data_root=self.data_root,
                    split=self.split,
                    data_file="lidar",
                    annotation_file="det_3d.json",
                    view=view,
                    framerate=self.framerate,
                    shift_type=self.shift_type,
                    keys_to_load=(K.points3d, *self.DATA_GROUPS["det_3d"]),
                    attributes_to_load=self.attributes_to_load,
                    skip_empty_frames=skip_empty_frames,
                    backend=backend,
                    num_workers=num_workers,
                    verbose=verbose,
                )
            else:
                # Skip the lidar data group, which is loaded separately
                image_loaded = False
                for group in self._data_groups_to_load:
                    name = f"{view}/{group}"
                    keys_to_load = list(self.DATA_GROUPS[group])
                    # Load the image data group only once
                    if not image_loaded:
                        keys_to_load.extend(self.DATA_GROUPS["img"])
                        image_loaded = True
                    self.scalabel_datasets[name] = _SHIFTScalabelLabels(
                        data_root=self.data_root,
                        split=self.split,
                        data_file="img",
                        annotation_file=f"{group}.json",
                        view=view,
                        framerate=self.framerate,
                        shift_type=self.shift_type,
                        keys_to_load=keys_to_load,
                        attributes_to_load=self.attributes_to_load,
                        skip_empty_frames=skip_empty_frames,
                        backend=backend,
                        num_workers=num_workers,
                        verbose=verbose,
                    )

        self.video_mapping = self._generate_video_mapping()
    def validate_keys(self, keys_to_load: Sequence[str]) -> None:
        """Validate that all keys to load are supported."""
        for k in keys_to_load:
            if k not in self.KEYS:
                raise ValueError(f"Key '{k}' is not supported!")
    def _get_data_groups(self, keys_to_load: Sequence[str]) -> list[str]:
        """Get the data groups that need to be loaded from Scalabel."""
        data_groups = ["det_2d"]
        for data_group, group_keys in self.DATA_GROUPS.items():
            if data_group in self.GROUPS_IN_SCALABEL:
                # If the data group is loaded by Scalabel, add it to the list
                if any(key in group_keys for key in keys_to_load):
                    data_groups.append(data_group)
        return list(set(data_groups))

    def _load(
        self,
        view: str,
        data_group: str,
        file_ext: str,
        video: str,
        frame: str,
    ) -> NDArrayNumber:
        """Load data from the given data group."""
        frame_number = frame.split("_")[0]
        filepath = os.path.join(
            self.annotation_base,
            view,
            f"{data_group}{self.ext}",
            video,
            f"{frame_number}_{data_group}_{view}.{file_ext}",
        )
        if data_group == "semseg":
            return self._load_semseg(filepath)
        if data_group == "depth":
            return self._load_depth(filepath)
        if data_group == "flow":
            return self._load_flow(filepath)
        raise ValueError(
            f"Invalid data group '{data_group}'"
        )  # pragma: no cover

    def _load_semseg(self, filepath: str) -> NDArrayI64:
        """Load semantic segmentation data."""
        im_bytes = self.backend.get(filepath)
        image = im_decode(im_bytes)[..., 0]
        return image.astype(np.int64)

    def _load_depth(
        self,
        filepath: str,
        depth_factor: float = 16777.216,  # 256 ** 3 / 1000
    ) -> NDArrayF32:
        """Load depth data."""
        assert depth_factor > 0, "Depth factor must be greater than 0."
        im_bytes = self.backend.get(filepath)
        image = im_decode(im_bytes)
        if image.shape[2] > 3:  # pragma: no cover
            image = image[:, :, :3]
        image = image.astype(np.float32)

        # Convert the packed 24-bit encoding (channel 2 most significant) to
        # depth. For example, a pixel (c0, c1, c2) = (216, 119, 1) decodes to
        # (1 * 65536 + 119 * 256 + 216) / 16777.216 ≈ 5.73, presumably in
        # meters given the 256 ** 3 / 1000 factor.
        depth = (
            image[:, :, 2] * 256 * 256 + image[:, :, 1] * 256 + image[:, :, 0]
        )
        return np.ascontiguousarray(depth / depth_factor, dtype=np.float32)

    def _load_flow(self, filepath: str) -> NDArrayF32:
        """Load optical flow data."""
        npy_bytes = self.backend.get(filepath)
        flow = npy_decode(npy_bytes, key="flow")
        flow = flow[:, :, [1, 0]]  # Convert to (u, v) format
        flow *= flow.shape[1]  # Scale to image size (1280)
        if self.framerate == "images":
            flow *= 10.0  # NOTE: Scale to 1 fps approximately
        return flow.astype(np.float32)

    def _get_frame_key(self, idx: int) -> tuple[str, str]:
        """Get the frame identifier (video name, frame name) by index."""
        if len(self.scalabel_datasets) > 0:
            frames = self.scalabel_datasets[
                list(self.scalabel_datasets.keys())[0]
            ].frames
            return frames[idx].videoName, frames[idx].name
        raise ValueError("No Scalabel file has been loaded.")
    def __len__(self) -> int:
        """Get the number of samples in the dataset."""
        if len(self.scalabel_datasets) > 0:
            return len(
                self.scalabel_datasets[list(self.scalabel_datasets.keys())[0]]
            )
        raise ValueError(
            "No Scalabel file has been loaded."
        )  # pragma: no cover
    def _generate_video_mapping(self) -> VideoMapping:
        """Group all dataset sample indices (int) by their video ID (str).

        Returns:
            VideoMapping: Mapping of video IDs to sample indices and frame
                IDs.

        Raises:
            ValueError: If no Scalabel file has been loaded.
        """
        if len(self.scalabel_datasets) > 0:
            return self.scalabel_datasets[
                list(self.scalabel_datasets.keys())[0]
            ].video_mapping
        raise ValueError("No Scalabel file has been loaded.")
    def __getitem__(self, idx: int) -> DictData:
        """Get single sample.

        Args:
            idx (int): Index of sample.

        Returns:
            DictData: sample at index in Vis4D input format.
        """
        # load camera frames
        data_dict = {}

        # metadata
        video_name, frame_name = self._get_frame_key(idx)
        data_dict[K.sample_names] = frame_name
        data_dict[K.sequence_names] = video_name
        data_dict[K.frame_ids] = frame_name.split("_")[0]

        for view in self.views_to_load:
            data_dict_view = {}
            if view == "center":
                # Lidar is only available in the center view
                if K.points3d in self.keys_to_load:
                    data_dict_view.update(
                        self.scalabel_datasets["center/lidar"][idx]
                    )
            else:
                # Load data from Scalabel
                for group in self._data_groups_to_load:
                    data_dict_view.update(
                        self.scalabel_datasets[f"{view}/{group}"][idx]
                    )
                # Load data from bit masks
                if K.seg_masks in self.keys_to_load:
                    data_dict_view[K.seg_masks] = self._load(
                        view, "semseg", "png", video_name, frame_name
                    )
                if K.depth_maps in self.keys_to_load:
                    data_dict_view[K.depth_maps] = self._load(
                        view, "depth", "png", video_name, frame_name
                    )
                if K.optical_flows in self.keys_to_load:
                    data_dict_view[K.optical_flows] = self._load(
                        view, "flow", "npz", video_name, frame_name
                    )
            data_dict[view] = data_dict_view  # type: ignore

        return data_dict
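

# NOTE: Illustrative usage, not part of the original module. A minimal sketch
# of loading front-view images, 2D boxes, and semantic masks from a local
# SHIFT download; "/data/shift" is a hypothetical path.
if __name__ == "__main__":  # pragma: no cover
    dataset = SHIFT(
        data_root="/data/shift",
        split="train",
        keys_to_load=(K.images, K.boxes2d, K.seg_masks),
        views_to_load=("front",),
        framerate="images",
        shift_type="discrete",
    )
    print(f"Loaded {len(dataset)} samples.")
    sample = dataset[0]
    # Each view maps to its own sub-dictionary of loaded keys.
    front = sample["front"]
    print(front[K.boxes2d].shape, front[K.seg_masks].shape)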