Toggle Light / Dark / Auto color theme
Toggle table of contents sidebar
Source code for vis4d.op.box.samplers.pseudo
"""Pseudo Sampler."""
from __future__ import annotations
import torch
from ..matchers.base import MatchResult
from .base import Sampler , SamplingResult
[docs]
class PseudoSampler(Sampler):
    """Sampler that performs no sub-sampling.

    Every box labelled as positive or negative by the matcher is kept; no
    random selection takes place. Boxes labelled ``-1`` are left out.
    """

    def __init__(self) -> None:
        """Set up the sampler; it takes no configuration."""
        # ``super(Sampler, self)`` starts the method lookup *after* ``Sampler``
        # in the MRO, so ``Sampler.__init__`` itself is not executed here.
        # NOTE(review): presumably because ``Sampler.__init__`` expects
        # arguments this class has no use for -- confirm against ``.base``.
        super(Sampler, self).__init__()

    def forward(self, matching: MatchResult) -> SamplingResult:
        """Keep all positive and negative boxes of a matching result.

        Args:
            matching (MatchResult): Matcher output; its ``assigned_labels``
                and ``assigned_gt_indices`` fields are read.

        Returns:
            SamplingResult: Indices of the kept boxes (positives first, then
                negatives) together with the target indices and labels
                gathered at those positions.
        """
        labels = matching.assigned_labels
        positives, negatives = self._sample_labels(labels)
        # Positives come first, negatives are appended after them.
        keep = torch.cat((positives, negatives), dim=0)
        return SamplingResult(
            sampled_box_indices=keep,
            sampled_target_indices=matching.assigned_gt_indices[keep],
            sampled_labels=labels[keep],
        )

    @staticmethod
    def _sample_labels(
        labels: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Split label positions into positive and negative index tensors.

        Args:
            labels (torch.Tensor): Label tensor. Entries equal to ``0`` count
                as negative, entries equal to ``-1`` are excluded, and every
                other value counts as positive.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Indices along the first
                dimension of the positive entries and of the negative entries.
        """
        is_negative = labels == 0
        is_positive = ~is_negative & (labels != -1)
        # ``nonzero`` yields one row per hit; column 0 is the index along the
        # first dimension.
        positive = torch.nonzero(is_positive)[:, 0]
        negative = torch.nonzero(is_negative)[:, 0]
        return positive, negative