"""Reference View Sampling.
These classes sample reference views from a dataset that contains videos.
This is typically used when a model needs multiple views of the same video
during training.
"""
from __future__ import annotations
from abc import abstractmethod
from typing import Callable, List
import numpy as np
from torch.utils.data import Dataset
from .const import CommonKeys as K
from .datasets import VideoDataset
from .typing import DictData
# Signature of a view-ordering function: it receives the key sample and the
# list of sampled reference samples and returns them as one ordered list
# (e.g. ``sort_key_first`` or ``sort_temporal``).
SortingFunc = Callable[[DictData, list[DictData]], list[DictData]]
def sort_key_first(
    cur_sample: DictData, ref_data: list[DictData]
) -> list[DictData]:
    """Sort views as key first.

    Args:
        cur_sample (DictData): The key view.
        ref_data (list[DictData]): The reference views.

    Returns:
        list[DictData]: A new list with ``cur_sample`` first, followed by the
        entries of ``ref_data`` in their original order. ``ref_data`` itself
        is not modified.
    """
    return [cur_sample, *ref_data]
def sort_temporal(
    cur_sample: DictData, ref_data: list[DictData]
) -> list[DictData]:
    """Sort views temporally.

    Args:
        cur_sample (DictData): The key view.
        ref_data (list[DictData]): The reference views.

    Returns:
        list[DictData]: A new list holding the key view and all reference
        views, ordered by ascending ``K.frame_ids``. ``sorted`` is stable, so
        views with equal frame ids keep their input order (key view first).
    """
    return sorted([cur_sample, *ref_data], key=lambda x: x[K.frame_ids])
class ReferenceViewSampler:
    """Base reference view sampler.

    Subclasses implement ``__call__`` to choose reference views for a key
    view from the dataset indices of the same video.
    """

    def __init__(self, num_ref_samples: int) -> None:
        """Creates an instance of the class.

        Args:
            num_ref_samples (int): Number of reference views to sample.
        """
        self.num_ref_samples = num_ref_samples

    # NOTE(review): this class does not use ``abc.ABCMeta``, so
    # ``@abstractmethod`` is not enforced at instantiation time; calling the
    # base implementation raises NotImplementedError instead.
    @abstractmethod
    def __call__(
        self,
        key_dataset_index: int,
        indices_in_video: list[int],
        frame_ids: list[int],
    ) -> list[int]:
        """Sample num_ref_samples reference view indices.

        Args:
            key_dataset_index (int): Dataset index of the key view.
            indices_in_video (list[int]): All dataset indices in the video.
            frame_ids (list[int]): Frame ids of all views in the video.

        Returns:
            list[int]: Dataset indices of the reference views.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError
class SequentialViewSampler(ReferenceViewSampler):
    """Sequential View Sampler.

    Picks the ``num_ref_samples`` views that directly follow the key view in
    ``indices_in_video``. If fewer than that remain after the key view, all
    remaining ones are taken and the rest is filled with the views directly
    before the key view. The key view itself is never returned.
    """

    def __call__(
        self,
        key_dataset_index: int,
        indices_in_video: list[int],
        frame_ids: list[int],
    ) -> list[int]:
        """Sample sequential reference views.

        Args:
            key_dataset_index (int): Dataset index of the key view; must be
                contained in ``indices_in_video``.
            indices_in_video (list[int]): All dataset indices in the video,
                presumably in temporal order (TODO confirm with the dataset's
                video mapping).
            frame_ids (list[int]): Frame ids of all views in the video. Not
                used for selection.

        Returns:
            list[int]: ``num_ref_samples`` dataset indices, in the order they
            appear in ``indices_in_video``.

        Raises:
            AssertionError: If the video has fewer than
                ``num_ref_samples + 1`` views.
            ValueError: If ``key_dataset_index`` is not in
                ``indices_in_video``.
        """
        num_views = len(indices_in_video)
        # Validate the list that is actually sliced below: a too-short list
        # would make ``left`` negative and silently wrap around the video.
        assert num_views >= self.num_ref_samples + 1, (
            f"Video has {num_views} views, but {self.num_ref_samples} "
            "reference views plus the key view are required."
        )
        key_index = indices_in_video.index(key_dataset_index)
        right = key_index + 1 + self.num_ref_samples
        if right <= num_views:
            # Enough views after the key view: take the next ones.
            return indices_in_video[key_index + 1 : right]
        # Too few views after the key view: take all of them and fill up
        # with the views immediately preceding the key view.
        left = key_index - (right - num_views)
        return (
            indices_in_video[left:key_index]
            + indices_in_video[key_index + 1 :]
        )
class MultiViewDataset(Dataset[list[DictData]]):
    """Dataset that samples reference views from a video dataset.

    Every item is a list holding the key view (flagged with
    ``"keyframes": True``) and, if ``sampler.num_ref_samples > 0``, that many
    reference views from the same video (flagged with ``"keyframes": False``),
    ordered by ``sort_fn``.
    """

    def __init__(
        self,
        dataset: VideoDataset,
        sampler: ReferenceViewSampler,
        sort_fn: SortingFunc = sort_key_first,
        num_retry: int = 3,
        match_key: str = K.boxes2d_track_ids,
        skip_nomatch_samples: bool = False,
    ) -> None:
        """Creates an instance of the class.

        Args:
            dataset (Dataset): Video dataset to sample from.
            sampler (ReferenceViewSampler): Sampler that samples reference
                views.
            sort_fn (SortingFunc, optional): Function that sorts key and
                reference views. Defaults to sort_key_first.
            num_retry (int, optional): Maximum number of times reference
                views are sampled for one item. With ``skip_nomatch_samples``
                every attempt without a match triggers a new one; if all
                attempts fail, the key view is reused as every reference
                view. Defaults to 3.
            match_key (str, optional): Key to match reference views with key
                view. Defaults to K.boxes2d_track_ids.
            skip_nomatch_samples (bool, optional): Whether to reject sampled
                reference views that share no ``match_key`` value with the
                key view and sample again. Defaults to False.
        """
        self.dataset = dataset
        self.sampler = sampler
        self.sort_fn = sort_fn
        self.num_retry = num_retry
        self.match_key = match_key
        self.skip_nomatch_samples = skip_nomatch_samples

    def has_matches(
        self, key_data: DictData, ref_data: list[DictData]
    ) -> bool:
        """Check if key / ref data have matches.

        Returns:
            bool: True as soon as any reference view has at least one
            ``self.match_key`` value equal to one of the key view's values,
            False otherwise (including when ``ref_data`` is empty).
        """
        key_target = key_data[self.match_key]
        for ref_view in ref_data:
            ref_target = ref_view[self.match_key]
            # Compare every key value with every ref value via broadcasting;
            # assuming 1-D targets this is (N, 1) vs (1, M) -> (N, M).
            match = np.equal(
                np.expand_dims(key_target, axis=1), ref_target[None]
            )
            if match.any():
                return True
        return False  # pragma: no cover

    def __len__(self) -> int:
        """Get length of dataset."""
        return len(self.dataset)

    def get_ref_data(self, ref_indices: list[int]) -> list[DictData]:
        """Load reference views from the dataset.

        Each loaded sample is flagged with ``"keyframes": False``.

        Args:
            ref_indices (list[int]): Dataset indices of the reference views.

        Returns:
            list[DictData]: The reference samples, in ``ref_indices`` order.
        """
        ref_data = []
        for ref_index in ref_indices:
            ref_sample = self.dataset[ref_index]
            ref_sample["keyframes"] = False
            ref_data.append(ref_sample)
        # The sampler must return exactly the configured number of views.
        assert self.sampler.num_ref_samples == len(ref_data)
        return ref_data

    def __getitem__(self, index: int) -> list[DictData]:
        """Get the key view at ``index`` together with its reference views.

        Args:
            index (int): Dataset index of the key view.

        Returns:
            list[DictData]: ``[cur_sample]`` if the sampler requests no
            reference views, otherwise the key view and its reference views
            ordered by ``self.sort_fn``.
        """
        cur_sample = self.dataset[index]
        cur_sample["keyframes"] = True
        video_name = cur_sample[K.sequence_names]
        video_mapping = self.dataset.video_mapping
        indices_in_video = video_mapping["video_to_indices"][video_name]
        frame_ids = video_mapping["video_to_frame_ids"][video_name]

        if self.sampler.num_ref_samples <= 0:
            return [cur_sample]

        for _ in range(self.num_retry):
            ref_indices = self.sampler(index, indices_in_video, frame_ids)
            ref_data = self.get_ref_data(ref_indices)
            if self.skip_nomatch_samples and not self.has_matches(
                cur_sample, ref_data
            ):
                # No shared match_key value: draw a new set of references.
                continue
            return self.sort_fn(cur_sample, ref_data)

        # Every attempt was rejected (or num_retry <= 0): reuse the key view
        # as each reference view so the item still has the expected length,
        # and order it with sort_fn like the regular path does.
        ref_data = self.get_ref_data(
            [index] * self.sampler.num_ref_samples
        )
        return self.sort_fn(cur_sample, ref_data)