Source code for src.training.losses.asymetric

import torch
from torch import nn


class AsymmetricLossOptimized(nn.Module):
    """
    Optimized implementation of the asymmetric loss calculation: it is more
    memory efficient, allocates GPU memory more effectively, and favors
    in-place operations.

    Asymmetric loss is a loss function that penalizes different types of
    errors differently. This is useful in tasks where some errors are more
    costly than others. In multi-label classification, asymmetric loss can
    be used to penalize false positives more heavily than false negatives,
    because false positives can have a greater real-world impact. For
    example, in a medical diagnosis task, a false positive could result in
    a patient receiving unnecessary treatment.
    """

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8,
                 disable_torch_grad_focal_loss=False):
        super(AsymmetricLossOptimized, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

        # Prevent memory allocation and GPU upload every iteration,
        # and encourage in-place operations
        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = \
            self.asymmetric_w = self.loss = None
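
    # For reference, the forward pass below computes, per element, with
    # p = sigmoid(x) and the clip-shifted p_m = max(p - clip, 0):
    #     L+ = (1 - p)^gamma_pos * log(p)         for positive labels (y = 1)
    #     L- = p_m^gamma_neg * log(1 - p_m)       for negative labels (y = 0)
    # and returns the negated sum of these terms over all elements.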
    def forward(self, x, y):
        """
        Loss forward step function.

        :param x: input logits
        :param y: targets (multi-label binarized vector)
        :return: the loss value (scalar tensor)
        """
        self.targets = y
        self.anti_targets = 1 - y

        # Calculating probabilities
        self.xs_pos = torch.sigmoid(x)
        self.xs_neg = 1.0 - self.xs_pos

        # Asymmetric clipping: shift negative probabilities up by the margin
        if self.clip is not None and self.clip > 0:
            self.xs_neg.add_(self.clip).clamp_(max=1)

        # Basic CE calculation
        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

        # Asymmetric focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            # Optionally compute the focusing weights without tracking
            # gradients (using the public API instead of torch._C)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(False)
            self.xs_pos = self.xs_pos * self.targets
            self.xs_neg = self.xs_neg * self.anti_targets
            self.asymmetric_w = torch.pow(
                1 - self.xs_pos - self.xs_neg,
                self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            self.loss *= self.asymmetric_w

        return -self.loss.sum()
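
A minimal usage sketch, assuming batched logits from a multi-label classifier; the shapes, values, and hyperparameters below are illustrative only:

import torch

from src.training.losses.asymetric import AsymmetricLossOptimized

criterion = AsymmetricLossOptimized(gamma_neg=4, gamma_pos=1, clip=0.05)

# A batch of 4 samples over 10 labels: raw logits and binarized targets
logits = torch.randn(4, 10, requires_grad=True)
targets = torch.randint(0, 2, (4, 10)).float()

loss = criterion(logits, targets)   # scalar tensor
loss.backward()                     # gradients flow back to the logits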