Source code for vis4d.data.reference

"""Reference View Sampling.

These classes sample reference views from a dataset that contains videos.
This is usually used when a model needs multiple samples of a video during
training.
"""

from __future__ import annotations

from abc import abstractmethod
from typing import Callable, List

import numpy as np
from torch.utils.data import Dataset

from .const import CommonKeys as K
from .datasets import VideoDataset
from .typing import DictData

# Signature shared by the view-sorting functions in this module: the first
# argument is the key sample, the second the list of reference samples, and
# the result is the list of all views in the order they should be returned.
SortingFunc = Callable[[DictData, list[DictData]], List[DictData]]


def sort_key_first(
    cur_sample: DictData, ref_data: list[DictData]
) -> list[DictData]:
    """Order the views so that the key view comes first.

    Args:
        cur_sample (DictData): The key view.
        ref_data (list[DictData]): The reference views.

    Returns:
        list[DictData]: A new list holding the key view followed by the
            reference views in the order they were given.
    """
    views = [cur_sample]
    views.extend(ref_data)
    return views
def sort_temporal(
    cur_sample: DictData, ref_data: list[DictData]
) -> list[DictData]:
    """Order the key view and the reference views by their frame id.

    Args:
        cur_sample (DictData): The key view.
        ref_data (list[DictData]): The reference views.

    Returns:
        list[DictData]: A new list with all views in ascending order of
            their ``K.frame_ids`` entry.
    """
    views = [cur_sample, *ref_data]
    # list.sort is stable, so views with equal frame ids keep the
    # key-first / given order.
    views.sort(key=lambda view: view[K.frame_ids])
    return views
class ReferenceViewSampler:
    """Base class for reference view samplers.

    A sampler is called with the dataset index of a key view together with
    the dataset indices and frame ids of its video, and returns the dataset
    indices of the reference views to load alongside the key view.
    """

    def __init__(self, num_ref_samples: int) -> None:
        """Creates an instance of the class.

        Args:
            num_ref_samples (int): Number of reference views to sample.
        """
        self.num_ref_samples = num_ref_samples

    # NOTE: this class does not use an ABC metaclass, so the decorator below
    # does not prevent instantiation; calling the base implementation raises.
    @abstractmethod
    def __call__(
        self,
        key_dataset_index: int,
        indices_in_video: list[int],
        frame_ids: list[int],
    ) -> list[int]:
        """Sample num_ref_samples reference view indices.

        Args:
            key_dataset_index (int): Dataset index of the key view.
            indices_in_video (list[int]): All dataset indices in the video.
            frame_ids (list[int]): Frame ids of all views in the video.

        Returns:
            list[int]: Dataset indices of the reference views.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
class SequentialViewSampler(ReferenceViewSampler):
    """Sampler that picks the views adjacent to the key view in its video."""

    def __call__(
        self,
        key_dataset_index: int,
        indices_in_video: list[int],
        frame_ids: list[int],
    ) -> list[int]:
        """Sample sequential reference views.

        The views directly following the key view are used. If the video
        ends before enough views are found, the missing ones are taken from
        the views directly preceding the key view.

        Args:
            key_dataset_index (int): Dataset index of the key view.
            indices_in_video (list[int]): All dataset indices in the video.
            frame_ids (list[int]): Frame ids of all views in the video. Only
                its length is used here.

        Returns:
            list[int]: Dataset indices of the reference views, in the order
                they appear in ``indices_in_video``.
        """
        assert len(frame_ids) >= self.num_ref_samples + 1

        key_pos = indices_in_video.index(key_dataset_index)
        num_views = len(indices_in_video)
        window_end = key_pos + 1 + self.num_ref_samples

        if window_end > num_views:
            # Too few views after the key view: take everything that follows
            # it and fill the shortfall with the views just before it.
            shortfall = window_end - num_views
            preceding = indices_in_video[key_pos - shortfall : key_pos]
            following = indices_in_video[key_pos + 1 :]
            return preceding + following

        return indices_in_video[key_pos + 1 : window_end]
class UniformViewSampler(ReferenceViewSampler):
    """Sampler that draws reference views uniformly at random.

    Candidates are the views of the same video whose frame id lies within
    ``scope`` of the key view's frame id.
    """

    def __init__(self, scope: int, num_ref_samples: int) -> None:
        """Creates an instance of the class.

        Args:
            scope (int): Maximum frame id distance to the key view for a
                view to be a sampling candidate. With a scope of 0 the key
                view index itself is returned as every reference view.
            num_ref_samples (int): Number of reference views to sample.

        Raises:
            ValueError: If scope is non-zero and smaller than
                ``num_ref_samples // 2``.
        """
        super().__init__(num_ref_samples)
        if scope != 0 and scope < num_ref_samples // 2:
            raise ValueError("Scope must be higher than num_ref_imgs / 2.")
        self.scope = scope

    def _get_valid_indices(
        self, key_index: int, indices_in_video: list[int], frame_ids: list[int]
    ) -> list[int]:
        """Collect the dataset indices that may serve as reference views.

        Args:
            key_index (int): Position of the key view inside
                ``indices_in_video`` / ``frame_ids``.
            indices_in_video (list[int]): All dataset indices in the video.
            frame_ids (list[int]): Frame ids of all views in the video.

        Returns:
            list[int]: Dataset indices of all views, except the key view,
                whose frame id lies inside the clipped scope window.
        """
        key_fid = frame_ids[key_index]
        lower_fid = max(0, key_fid - self.scope)
        # The upper bound is clipped to the last entry of frame_ids;
        # presumably frame_ids is sorted ascending so that this is the
        # largest frame id of the video -- TODO confirm.
        upper_fid = min(key_fid + self.scope, frame_ids[-1])

        candidates = []
        for pos, dataset_index in enumerate(indices_in_video):
            if pos == key_index:
                continue
            if lower_fid <= frame_ids[pos] <= upper_fid:
                candidates.append(dataset_index)
        return candidates

    def __call__(
        self,
        key_dataset_index: int,
        indices_in_video: list[int],
        frame_ids: list[int],
    ) -> list[int]:
        """Uniformly sample reference views.

        Args:
            key_dataset_index (int): Dataset index of the key view.
            indices_in_video (list[int]): All dataset indices in the video.
            frame_ids (list[int]): Frame ids of all views in the video.

        Returns:
            list[int]: ``num_ref_samples`` distinct dataset indices drawn
                without replacement from the candidates. If the scope is
                not positive or there is no candidate at all, the key
                dataset index repeated ``num_ref_samples`` times.
        """
        if self.scope <= 0:
            return [key_dataset_index] * self.num_ref_samples

        key_index = indices_in_video.index(key_dataset_index)
        candidates = self._get_valid_indices(
            key_index, indices_in_video, frame_ids
        )
        if not candidates:
            return [key_dataset_index] * self.num_ref_samples

        # Sampling without replacement needs at least as many candidates as
        # requested reference views.
        assert len(candidates) >= self.num_ref_samples
        drawn = np.random.choice(
            candidates, self.num_ref_samples, replace=False
        )
        return drawn.tolist()
class MultiViewDataset(Dataset[list[DictData]]):
    """Dataset that samples reference views from a video dataset.

    Every item is a list of views: the key view taken from the wrapped
    dataset plus the reference views chosen by the sampler.
    """

    def __init__(
        self,
        dataset: VideoDataset,
        sampler: ReferenceViewSampler,
        sort_fn: SortingFunc = sort_key_first,
        num_retry: int = 3,
        match_key: str = K.boxes2d_track_ids,
        skip_nomatch_samples: bool = False,
    ) -> None:
        """Creates an instance of the class.

        Args:
            dataset (VideoDataset): Video dataset to sample from. Its
                ``video_mapping`` must provide the "video_to_indices" and
                "video_to_frame_ids" lookups.
            sampler (ReferenceViewSampler): Sampler that samples reference
                views.
            sort_fn (SortingFunc, optional): Function that sorts key and
                reference views. Defaults to sort_key_first.
            num_retry (int, optional): Number of sampling attempts made
                before falling back to copies of the key view. Defaults
                to 3.
            match_key (str, optional): Key to match reference views with
                key view. Defaults to K.boxes2d_track_ids.
            skip_nomatch_samples (bool, optional): Whether to reject sampled
                reference views that have no match with the key view and
                sample again. Defaults to False.
        """
        self.dataset = dataset
        self.sampler = sampler
        self.sort_fn = sort_fn
        self.num_retry = num_retry
        self.match_key = match_key
        self.skip_nomatch_samples = skip_nomatch_samples

    def has_matches(
        self, key_data: DictData, ref_data: list[DictData]
    ) -> bool:
        """Check if key / ref data have matches.

        Args:
            key_data (DictData): The key view.
            ref_data (list[DictData]): The reference views.

        Returns:
            bool: True if any value stored under ``match_key`` in the key
                view equals any value stored under it in at least one
                reference view.
        """
        key_ids = key_data[self.match_key]
        for ref_view in ref_data:
            ref_ids = ref_view[self.match_key]
            # Broadcast the key values along a new axis 1 against the
            # reference values along a new axis 0 to compare all pairs.
            pairwise = np.equal(np.expand_dims(key_ids, axis=1), ref_ids[None])
            if pairwise.any():
                return True
        return False  # pragma: no cover

    def __len__(self) -> int:
        """Get length of dataset, which is that of the wrapped dataset."""
        return len(self.dataset)

    def _load_ref_view(self, ref_index: int) -> DictData:
        """Load one sample and flag it as a non-key frame."""
        ref_view = self.dataset[ref_index]
        ref_view["keyframes"] = False
        return ref_view

    def get_ref_data(self, ref_indices: list[int]) -> list[DictData]:
        """Get reference data from dataset.

        Args:
            ref_indices (list[int]): Dataset indices of the reference views.

        Returns:
            list[DictData]: The loaded reference views, each with its
                "keyframes" entry set to False.
        """
        ref_views = [self._load_ref_view(ref_index) for ref_index in ref_indices]
        assert len(ref_views) == self.sampler.num_ref_samples
        return ref_views

    def __getitem__(self, index: int) -> list[DictData]:
        """Get item from dataset.

        Args:
            index (int): Dataset index of the key view.

        Returns:
            list[DictData]: The key view (with "keyframes" set to True) and
                its reference views, ordered by ``sort_fn``. If the sampler
                requests no reference views, only the key view is returned.
        """
        key_view = self.dataset[index]
        key_view["keyframes"] = True

        indices_in_video = self.dataset.video_mapping["video_to_indices"][
            key_view[K.sequence_names]
        ]
        frame_ids = self.dataset.video_mapping["video_to_frame_ids"][
            key_view[K.sequence_names]
        ]

        if self.sampler.num_ref_samples <= 0:
            return [key_view]

        for _ in range(self.num_retry):
            ref_indices = self.sampler(index, indices_in_video, frame_ids)
            ref_views = self.get_ref_data(ref_indices)
            # Accept the draw unless matches are required and none exist.
            if not self.skip_nomatch_samples or self.has_matches(
                key_view, ref_views
            ):
                return self.sort_fn(key_view, ref_views)

        # All attempts were rejected: use freshly loaded copies of the key
        # view as reference views. This result is key-first and is not
        # passed through sort_fn.
        ref_views = self.get_ref_data([index] * self.sampler.num_ref_samples)
        return [key_view, *ref_views]