import torch
from torch import nn
class AsymmetricLossOptimized(nn.Module):
    """
    Asymmetric Loss (ASL) for multi-label classification, written to favor in-place ops.

    With ``p = sigmoid(x)`` and targets ``y``, the element-wise loss is

    * positives: ``y * (1 - p) ** gamma_pos * log(p)``
    * negatives: ``(1 - y) * p_m ** gamma_neg * log(1 - p_m)`` with the shifted
      probability ``p_m = max(p - clip, 0)`` (no shift when ``clip`` is ``None`` or ``<= 0``),

    where both logarithm arguments are clamped to at least ``eps``. ``forward`` returns
    the negated sum of these terms.

    With the default exponents (``gamma_neg=4`` > ``gamma_pos=1``) the focusing factor
    shrinks the contribution of easy negatives (small ``p``) more strongly than that of
    easy positives. A negative whose ``p <= clip`` contributes exactly zero loss.
    Setting both exponents to 0 and ``clip`` to 0 reduces the loss to summed binary
    cross-entropy on the clamped probabilities.

    In-place tensor operations are used where autograd permits, to limit the number of
    temporaries created per call.
    """

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
        """
        :param gamma_neg: focusing exponent applied to negative labels.
        :param gamma_pos: focusing exponent applied to positive labels.
        :param clip: probability margin used to shift negatives (``p_m = max(p - clip, 0)``);
            ``None`` or a value ``<= 0`` disables the shift.
        :param eps: lower bound applied to probabilities before taking the logarithm.
        :param disable_torch_grad_focal_loss: if True, the focusing weight is computed
            without autograd tracking, so it acts as a constant during backpropagation.
        """
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps
        # Intermediates of the most recent forward() call, kept as attributes for
        # inspection/compatibility. Each call rebinds them to newly computed tensors, and
        # they hold references to the last batch's tensors until the next call.
        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None

    def forward(self, x, y):
        """
        Loss forward step function.

        :param x: input logits.
        :param y: targets (multi-label binarized tensor); combined element-wise with ``x``,
            so it must broadcast against it.
        :return: the negated *sum* (not mean) of the element-wise loss, as a 0-dim tensor.
        """
        self.targets = y
        self.anti_targets = 1 - y

        # Calculating probabilities
        self.xs_pos = torch.sigmoid(x)
        self.xs_neg = 1.0 - self.xs_pos

        # Asymmetric clipping (probability shifting): 1 - p -> min(1 - p + clip, 1),
        # which equals 1 - p_m with p_m = max(p - clip, 0).
        if self.clip is not None and self.clip > 0:
            self.xs_neg.add_(self.clip).clamp_(max=1)

        # Basic CE calculation; clamping to eps keeps log() finite.
        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

        # Asymmetric focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            # torch.set_grad_enabled used as a context manager restores the caller's previous
            # grad mode on exit (also on exceptions). Grad is never switched ON here, so
            # calling the loss under torch.no_grad() keeps grad disabled afterwards.
            keep_grad = torch.is_grad_enabled() and not self.disable_torch_grad_focal_loss
            with torch.set_grad_enabled(keep_grad):
                self.xs_pos = self.xs_pos * self.targets
                self.xs_neg = self.xs_neg * self.anti_targets
                # Base is (1 - p) for positives and p_m for negatives.
                self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
                                              self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
            self.loss *= self.asymmetric_w

        return -self.loss.sum()