def calc_bisoftmax_affinity(
    detection_embeddings: torch.Tensor,
    track_embeddings: torch.Tensor,
    detection_class_ids: torch.Tensor | None = None,
    track_class_ids: torch.Tensor | None = None,
    with_categories: bool = False,
) -> torch.Tensor:
    """Calculate affinity matrix using bisoftmax metric.

    The raw similarity is the matrix product of the detection embeddings
    with the transposed track embeddings. It is softmax-normalized once
    along the track axis and once along the detection axis, and the two
    results are averaged.

    Args:
        detection_embeddings (torch.Tensor): 2-D tensor of shape [N, D]
            (``torch.mm`` only accepts matrices).
        track_embeddings (torch.Tensor): 2-D tensor of shape [M, D].
        detection_class_ids (torch.Tensor | None, optional): Class ids of
            the detections; presumably N entries, one per detection.
            Only read if ``with_categories`` is True. Defaults to None.
        track_class_ids (torch.Tensor | None, optional): Class ids of the
            tracks; presumably M entries, one per track. Only read if
            ``with_categories`` is True. Defaults to None.
        with_categories (bool, optional): If True, zero the affinity of
            every detection / track pair whose class ids differ.
            Defaults to False.

    Raises:
        AssertionError: If ``with_categories`` is True but either class
            id tensor is None.

    Returns:
        torch.Tensor: Affinity matrix of shape [N, M].
    """
    feats = torch.mm(detection_embeddings, track_embeddings.t())
    d2t_scores = feats.softmax(dim=1)  # each detection over all tracks
    t2d_scores = feats.softmax(dim=0)  # each track over all detections
    similarity_scores = (d2t_scores + t2d_scores) / 2

    if with_categories:
        # Raised explicitly (instead of via an ``assert`` statement) so the
        # check is not stripped when Python runs with ``-O``. The exception
        # type and message are the same as before.
        if detection_class_ids is None or track_class_ids is None:
            raise AssertionError(
                "Please provide class ids if with_categories=True!"
            )
        # Column vector vs. row vector broadcasts to an [N, M] boolean mask.
        cat_same = detection_class_ids.view(-1, 1) == track_class_ids.view(
            1, -1
        )
        similarity_scores *= cat_same.float()

    return similarity_scores
def cosine_similarity(
    key_embeds: torch.Tensor,
    ref_embeds: torch.Tensor,
    normalize: bool = True,
    temperature: float = -1,
) -> torch.Tensor:
    """Calculate cosine similarity.

    Args:
        key_embeds (torch.Tensor): 2-D tensor of shape [N, D]
            (``torch.mm`` only accepts matrices).
        ref_embeds (torch.Tensor): 2-D tensor of shape [M, D].
        normalize (bool, optional): If True, L2-normalize every row of
            both inputs before the matrix product, so each entry is the
            cosine of the angle between two rows. If False, the result is
            the plain dot product of the rows. Defaults to True.
        temperature (float, optional): If greater than 0, the similarity
            matrix is divided by this value; any other value disables the
            scaling. Defaults to -1.

    Returns:
        torch.Tensor: Similarity matrix of shape [N, M].
    """
    if normalize:
        # F.normalize returns new tensors; the caller's inputs are untouched.
        key_embeds = F.normalize(key_embeds, p=2, dim=1)
        ref_embeds = F.normalize(ref_embeds, p=2, dim=1)

    dists = torch.mm(key_embeds, ref_embeds.t())

    if temperature > 0:
        dists /= temperature  # pragma: no cover

    return dists