class SegCrossEntropyLoss(Loss):
    """Segmentation cross entropy loss class.

    Wrapper for nn.CrossEntropyLoss that additionally clips the output to the
    target size and converts the target mask tensor to long.
    """

    def __init__(self, reducer: LossReducer = mean_loss) -> None:
        """Creates an instance of the class.

        Args:
            reducer (LossReducer): Reducer for the loss function. Defaults to
                mean_loss.
        """
        super().__init__(reducer)

    def forward(
        self, output: Tensor, target: Tensor, ignore_index: int = 255
    ) -> LossesType:
        """Forward pass.

        Args:
            output (Tensor): Model output. It is indexed with four axes and its
                last two axes are cropped to the target's spatial size. It is
                presumably (N, C, H, W) logits; TODO confirm against callers.
            target (Tensor): Assigned segmentation target mask. Its last two
                dimensions define the crop size. It is cast to int64 before
                the loss is computed, so integer masks of any width (e.g.
                uint8) are accepted.
            ignore_index (int): Ignore class id. Defaults to 255.

        Returns:
            LossesType: Mapping with a single entry, "loss_seg", holding the
                reducer applied to the cross entropy result.
        """
        losses: LossesType = {}
        tgt_h, tgt_w = target.shape[-2:]
        # The output may be larger than the target (e.g. padded inputs);
        # crop its spatial extent so the two shapes line up.
        cropped_output = output[:, :, :tgt_h, :tgt_w]
        # cross_entropy needs int64 class indices. Masks often arrive as
        # uint8/int32, which raise a dtype error, and a float target would be
        # misread as class probabilities. .long() is a no-op if already int64.
        class_indices = target.long()
        # NOTE(review): if cross_entropy is torch.nn.functional.cross_entropy,
        # its default reduction="mean" already yields a scalar before the
        # reducer runs. Confirm this is intended.
        losses["loss_seg"] = self.reducer(
            cross_entropy(
                cropped_output, class_indices, ignore_index=ignore_index
            )
        )
        return losses