"""Combined Sampler."""
from __future__ import annotations
import torch
from torch import Tensor
from vis4d.common import ArgsType
from ..box2d import non_intersection, random_choice
from ..matchers.base import MatchResult
from .base import Sampler, SamplingResult
class CombinedSampler(Sampler):
    """Combined sampler. Can have different strategies for pos/neg samples.

    Positive and negative candidates are sampled independently, each with
    one of the following strategies:

    - ``"instance_balanced"``: distribute the budget evenly over the matched
      ground-truth instances.
    - ``"iou_balanced"``: distribute the budget evenly over IoU intervals,
      optionally reserving part of it for a low-IoU "floor" set.
    """

    _STRATEGIES = frozenset({"instance_balanced", "iou_balanced"})

    def __init__(
        self,
        *args: ArgsType,
        pos_strategy: str,
        neg_strategy: str,
        neg_pos_ub: float = 3.0,
        floor_thr: float = -1.0,
        floor_fraction: float = 0.0,
        num_bins: int = 3,
        bg_label: int = 0,
        **kwargs: ArgsType,
    ):
        """Creates an instance of the class.

        Args:
            *args: Positional arguments forwarded to ``Sampler``.
            pos_strategy: Strategy for positive samples, one of
                ``"instance_balanced"`` or ``"iou_balanced"``.
            neg_strategy: Strategy for negative samples, same choices as
                ``pos_strategy``.
            neg_pos_ub: Upper bound on the number of negatives relative to
                the number of sampled positives. Negative disables the bound.
            floor_thr: IoU threshold for ``iou_balanced`` sampling.
                Candidates with IoU <= ``floor_thr`` form a separately
                sampled "floor" set. Negative disables the floor set.
            floor_fraction: Share of the ``iou_balanced`` budget that is NOT
                given to the IoU-interval sampling; with a floor set, that
                remainder is filled from the floor set.
            num_bins: Number of IoU intervals for ``iou_balanced`` sampling.
                Values below 2 fall back to plain random sampling.
            bg_label: Label that marks negative (background) candidates.
            **kwargs: Keyword arguments forwarded to ``Sampler``.

        Raises:
            ValueError: If either strategy name is not supported.
        """
        super().__init__(*args, **kwargs)
        self.neg_pos_ub = neg_pos_ub
        self.floor_thr = floor_thr
        self.floor_fraction = floor_fraction
        self.num_bins = num_bins
        self.bg_label = bg_label
        if (
            pos_strategy not in self._STRATEGIES
            or neg_strategy not in self._STRATEGIES
        ):
            raise ValueError(
                "strategies must be in [instance_balanced, iou_balanced]"
            )
        # Bind the strategy names to the matching "<name>_sampling" methods.
        self.pos_strategy = getattr(self, pos_strategy + "_sampling")
        self.neg_strategy = getattr(self, neg_strategy + "_sampling")

    @staticmethod
    def _top_up(
        idx_tensor: Tensor, sampled_inds: Tensor, sample_size: int
    ) -> Tensor:
        """Pad ``sampled_inds`` with unsampled ``idx_tensor`` entries.

        Extra indices are drawn at random from the elements of
        ``idx_tensor`` that are not already in ``sampled_inds`` until
        ``sample_size`` is reached (or the candidates run out).
        """
        if len(sampled_inds) >= sample_size:
            return sampled_inds
        num_extra = sample_size - len(sampled_inds)
        # presumably returns the elements of idx_tensor absent from
        # sampled_inds (set difference) — see box2d.non_intersection
        extra_inds = non_intersection(idx_tensor, sampled_inds)
        if len(extra_inds) > num_extra:
            extra_inds = random_choice(extra_inds, num_extra)
        return torch.cat([sampled_inds, extra_inds])

    @staticmethod
    def instance_balanced_sampling(
        idx_tensor: Tensor,
        assigned_gts: Tensor,
        assigned_gt_ious: Tensor,  # pylint: disable=unused-argument
        sample_size: int,
    ) -> Tensor:
        """Sample indices with balancing according to matched GT instance.

        Args:
            idx_tensor: Candidate box indices.
            assigned_gts: Matched GT index per candidate, aligned
                element-wise with ``idx_tensor``.
            assigned_gt_ious: Unused; kept for a uniform strategy signature.
            sample_size: Maximum number of indices to return.

        Returns:
            Tensor: Subset of ``idx_tensor`` with at most ``sample_size``
            elements. If ``idx_tensor`` is already small enough it is
            returned unchanged.
        """
        if idx_tensor.numel() <= sample_size:
            return idx_tensor
        unique_gt_inds = assigned_gts.unique()
        num_per_gt = int(sample_size / float(len(unique_gt_inds)))
        sampled_inds_list = []
        # sample specific amount per gt instance
        for gt_idx in unique_gt_inds:
            # Select through idx_tensor: positions inside assigned_gts are
            # offsets into the candidate subset, not box indices.
            inds = idx_tensor[assigned_gts == gt_idx]
            if len(inds) > num_per_gt:
                inds = random_choice(inds, num_per_gt)
            sampled_inds_list.append(inds)
        sampled_inds = torch.cat(sampled_inds_list)
        # Rounding down num_per_gt (and small instances) can leave a
        # shortfall; fill it with random unsampled candidates.
        return CombinedSampler._top_up(idx_tensor, sampled_inds, sample_size)

    def iou_balanced_sampling(
        self,
        idx_tensor: Tensor,
        assigned_gts: Tensor,  # pylint: disable=unused-argument
        assigned_gt_ious: Tensor,
        sample_size: int,
    ) -> Tensor:
        """Sample indices with balancing according to IoU with matched GT.

        Args:
            idx_tensor: Candidate box indices.
            assigned_gts: Unused; kept for a uniform strategy signature.
            assigned_gt_ious: IoU with the matched GT per candidate, aligned
                element-wise with ``idx_tensor``.
            sample_size: Maximum number of indices to return.

        Returns:
            Tensor: Subset of ``idx_tensor`` with at most ``sample_size``
            elements. If ``idx_tensor`` is already small enough it is
            returned unchanged.
        """
        if idx_tensor.numel() <= sample_size:
            return idx_tensor
        iou_mask = assigned_gt_ious > self.floor_thr
        iou_sampling_set = idx_tensor[iou_mask]
        # define 'floor' set - set with low iou samples
        if self.floor_thr >= 0:
            floor_set = idx_tensor[assigned_gt_ious <= self.floor_thr]
        else:
            floor_set = None

        num_iou_set_samples = int(sample_size * (1 - self.floor_fraction))
        if len(iou_sampling_set) > num_iou_set_samples:
            if self.num_bins >= 2:
                # Restrict interval sampling to the IoU-sampling subset so its
                # own top-up cannot pick floor-set candidates, which are
                # sampled separately below (that would create duplicates).
                iou_sampled_inds = self.sample_within_intervals(
                    iou_sampling_set,
                    assigned_gt_ious[iou_mask],
                    num_iou_set_samples,
                )
            else:
                iou_sampled_inds = random_choice(
                    iou_sampling_set, num_iou_set_samples
                )
        else:
            iou_sampled_inds = iou_sampling_set  # pragma: no cover

        if floor_set is None:
            sampled_inds = iou_sampled_inds
        else:
            num_floor_set_samples = sample_size - len(iou_sampled_inds)
            if len(floor_set) > num_floor_set_samples:
                sampled_floor_inds = random_choice(
                    floor_set, num_floor_set_samples
                )
            else:
                sampled_floor_inds = floor_set  # pragma: no cover
            sampled_inds = torch.cat([sampled_floor_inds, iou_sampled_inds])
        return self._top_up(idx_tensor, sampled_inds, sample_size)

    def forward(self, matching: MatchResult) -> SamplingResult:
        """Sample boxes with the configured positive/negative strategies.

        Args:
            matching: Matching result; ``assigned_labels``,
                ``assigned_gt_indices`` and ``assigned_gt_iou`` are read.

        Returns:
            SamplingResult: Sampled box indices (positives first, then
            negatives) with their matched target indices and labels.
        """
        # batch_size / positive_fraction presumably come from Sampler.
        pos_sample_size = int(self.batch_size * self.positive_fraction)
        # Label -1 is neither positive nor negative, so it is never sampled.
        positive_mask: Tensor = (matching.assigned_labels != -1) & (
            matching.assigned_labels != self.bg_label
        )
        negative_mask = torch.eq(matching.assigned_labels, self.bg_label)
        positive = positive_mask.nonzero()[:, 0]
        negative = negative_mask.nonzero()[:, 0]

        num_pos = min(positive.numel(), pos_sample_size)
        num_neg = self.batch_size - num_pos
        if self.neg_pos_ub >= 0:
            # NOTE(review): with zero positives this bound is 0, so no
            # negatives are sampled either; some reference implementations
            # use max(1, num_pos) here — confirm the intended behavior.
            neg_upper_bound = int(self.neg_pos_ub * num_pos)
            num_neg = min(num_neg, neg_upper_bound)

        pos_idx = self.pos_strategy(
            idx_tensor=positive,
            assigned_gts=matching.assigned_gt_indices[positive_mask],
            assigned_gt_ious=matching.assigned_gt_iou[positive_mask],
            sample_size=num_pos,
        )
        neg_idx = self.neg_strategy(
            idx_tensor=negative,
            assigned_gts=matching.assigned_gt_indices[negative_mask],
            assigned_gt_ious=matching.assigned_gt_iou[negative_mask],
            sample_size=num_neg,
        )
        sampled_idcs = torch.cat([pos_idx, neg_idx], dim=0)
        return SamplingResult(
            sampled_box_indices=sampled_idcs,
            sampled_target_indices=matching.assigned_gt_indices[sampled_idcs],
            sampled_labels=matching.assigned_labels[sampled_idcs],
        )

    def sample_within_intervals(
        self,
        idx_tensor: Tensor,
        assigned_gt_ious: Tensor,
        sample_size: int,
    ) -> Tensor:
        """Sample according to N iou intervals where N = num bins.

        The IoU range ``[max(floor_thr, 0), max IoU)`` is split into
        ``num_bins`` equal half-open intervals. Up to
        ``int(sample_size / num_bins)`` indices are drawn at random from
        each interval. Any shortfall is filled with random unsampled
        candidates.

        Args:
            idx_tensor: Candidate box indices.
            assigned_gt_ious: IoU per candidate, aligned element-wise with
                ``idx_tensor``.
            sample_size: Target number of indices.

        Returns:
            Tensor: Subset of ``idx_tensor`` with at most ``sample_size``
            elements (empty input is returned unchanged).
        """
        if idx_tensor.numel() == 0:
            # .max() below would raise on an empty tensor.
            return idx_tensor
        floor_thr = max(self.floor_thr, 0.0)
        max_iou = assigned_gt_ious.max()
        iou_interval = (max_iou - floor_thr) / self.num_bins
        per_bin_samples = int(sample_size / self.num_bins)
        sampled_inds_list = []
        for i in range(self.num_bins):
            start_iou = floor_thr + i * iou_interval
            end_iou = floor_thr + (i + 1) * iou_interval
            in_bin = (start_iou <= assigned_gt_ious) & (
                assigned_gt_ious < end_iou
            )
            bin_inds = idx_tensor[in_bin]
            if len(bin_inds) > per_bin_samples:
                bin_inds = random_choice(bin_inds, per_bin_samples)
            sampled_inds_list.append(bin_inds)
        sampled_inds = torch.cat(sampled_inds_list)
        return self._top_up(idx_tensor, sampled_inds, sample_size)