class RandomSampler(Sampler):
    """Sampler that draws positive and negative examples uniformly at random."""

    def __init__(
        self,
        *args: ArgsType,
        bg_label: int = 0,
        **kwargs: ArgsType,
    ):
        """Build the sampler.

        Args:
            *args: Positional arguments forwarded unchanged to ``Sampler``.
            bg_label: Label value identifying background (negative) entries.
            **kwargs: Keyword arguments forwarded unchanged to ``Sampler``.
        """
        super().__init__(*args, **kwargs)
        self.bg_label = bg_label

    def _sample_labels(
        self, labels: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pick a random subset of foreground and background indices.

        Entries equal to ``-1`` are skipped entirely. Entries equal to
        ``self.bg_label`` are negative candidates, and every other entry is a
        positive candidate. ``self.batch_size`` and ``self.positive_fraction``
        are read here but not set in this class; presumably the ``Sampler``
        base class provides them (TODO confirm).

        Args:
            labels: Label tensor. It is indexed along its first dimension
                (presumably 1-D, one label per candidate; verify against
                callers).

        Returns:
            ``(pos_idx, neg_idx)``: randomly ordered index tensors selecting
            the sampled positives and negatives from ``labels``.
        """

        def _draw(candidates: torch.Tensor, limit: int) -> torch.Tensor:
            # Shuffle candidate positions, then keep the first ``limit`` ones.
            shuffled = torch.randperm(
                candidates.numel(), device=candidates.device
            )
            return candidates[shuffled[:limit]]

        is_bg = labels == self.bg_label
        is_fg = labels.ne(-1) & ~is_bg
        fg_candidates = torch.nonzero(is_fg)[:, 0]
        bg_candidates = torch.nonzero(is_bg)[:, 0]

        # Desired positive count from the configured fraction, but never
        # more than the positives that actually exist.
        wanted_fg = int(self.batch_size * self.positive_fraction)
        n_fg = min(fg_candidates.numel(), wanted_fg)
        # Negatives fill the remainder of the batch, again capped by what
        # is available.
        n_bg = min(bg_candidates.numel(), self.batch_size - n_fg)

        # Tuple elements evaluate left to right, so positives are shuffled
        # before negatives (keeps the random-number draw order stable).
        return _draw(fg_candidates, n_fg), _draw(bg_candidates, n_bg)