Source code for vis4d.data.datasets.base

"""Base dataset classes.

We implement a typed version of the PyTorch dataset class here. In addition, we
provide a number of Mixin classes which a dataset can inherit from to implement
additional functionality.
"""

from __future__ import annotations

from collections.abc import Sequence
from typing import TypedDict

from torch.utils.data import Dataset as TorchDataset

from vis4d.common import ArgsType
from vis4d.data.io.base import DataBackend
from vis4d.data.io.file import FileBackend
from vis4d.data.typing import DictData


[docs] class Dataset(TorchDataset[DictData]): """Basic pytorch dataset with defined return type.""" # Dataset metadata. DESCRIPTION = "" HOMEPAGE = "" PAPER = "" LICENSE = "" # List of all keys supported by this dataset. KEYS: Sequence[str] = [] def __init__( self, image_channel_mode: str = "RGB", data_backend: None | DataBackend = None, ) -> None: """Initialize dataset. Args: image_channel_mode (str): Image channel mode to use. Default: RGB. data_backend (None | DataBackend): Data backend to use. Default: None. """ self.image_channel_mode = image_channel_mode self.data_backend = ( data_backend if data_backend is not None else FileBackend() )
[docs] def __len__(self) -> int: """Return length of dataset.""" raise NotImplementedError
[docs] def __getitem__(self, idx: int) -> DictData: """Convert single element at given index into Vis4D data format.""" raise NotImplementedError
[docs] def validate_keys(self, keys_to_load: Sequence[str]) -> None: """Validate that all keys to load are supported. Args: keys_to_load (list[str]): List of keys to load. Raises: ValueError: Raise if any key is not defined in AVAILABLE_KEYS. """ for k in keys_to_load: if k not in self.KEYS: raise ValueError(f"Key '{k}' is not supported!")
[docs] class VideoMapping(TypedDict): """Grouped dataset sample indices and frame indices.""" video_to_indices: dict[str, list[int]] video_to_frame_ids: dict[str, list[int]]
[docs] class VideoDataset(Dataset): """Video datasets. Provides video_mapping attribute for video based interface and reference view samplers. """ def __init__(self, *args: ArgsType, **kwargs: ArgsType) -> None: """Initialize dataset.""" super().__init__(*args, **kwargs) self.video_mapping: VideoMapping = { "video_to_indices": {}, "video_to_frame_ids": {}, } def _sort_video_mapping(self, video_mapping: VideoMapping) -> VideoMapping: """Sort video mapping by frame ids.""" video_to_indices = video_mapping["video_to_indices"] video_to_frame_ids = video_mapping["video_to_frame_ids"] for seq in video_to_indices: sorted_zipped = sorted( list(zip(video_to_indices[seq], video_to_frame_ids[seq])), key=lambda x: x[1], ) sorted_indices, sorted_frame_ids = zip(*sorted_zipped) video_mapping["video_to_indices"][seq] = list(sorted_indices) video_mapping["video_to_frame_ids"][seq] = list(sorted_frame_ids) return video_mapping def _generate_video_mapping(self) -> VideoMapping: """Group dataset sample by their associated video ID. The sample index is an integer while video IDs are string. Returns: VideoMapping: Mapping of video IDs to sample indices and frame IDs. """ raise NotImplementedError