Source code for src.training.trainers.multilabel_trainer

from transformers import Trainer

from ..losses import AsymmetricLossOptimized


# Trainer specialisation for multi-label classification: overrides the loss
# computation so that AsymmetricLossOptimized is applied to the model logits.
class MultilabelTrainer(Trainer):
    # Reuse the parent's documentation verbatim. A regular class docstring
    # would be overwritten by this assignment, hence the comments above.
    __doc__ = Trainer.__doc__

    def __init__(self, **kwargs):
        """Build the trainer and its loss object.

        Args:
            **kwargs: Keyword arguments forwarded to ``Trainer.__init__``.
                Must additionally contain ``label_dist``, which is removed
                before forwarding and kept as ``self.label_distribution``.

        Raises:
            KeyError: If ``label_dist`` is not supplied.
        """
        # Pop before delegating so the parent constructor never sees this
        # extra key.
        self.label_distribution = kwargs.pop("label_dist")
        super().__init__(**kwargs)
        self.loss_class = AsymmetricLossOptimized()

    def compute_loss(
        self,
        model,
        inputs,
        return_outputs=False,
        num_items_in_batch=None,
    ):
        """Compute the asymmetric loss for one batch.

        Args:
            model: Model to run; called as ``model(**inputs)`` and expected
                to return an object exposing ``.logits``.
            inputs: Mapping of model inputs. The ``"labels"`` entry is popped
                (the mapping is mutated) and is not passed to the model.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                only the loss.
            num_items_in_batch: Accepted for compatibility with
                ``transformers`` releases whose training loop passes this
                keyword to ``compute_loss``; it is not used here.

        Returns:
            The loss, or a ``(loss, outputs)`` tuple when ``return_outputs``
            is true.

        Raises:
            KeyError: If ``inputs`` has no ``"labels"`` entry.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        # Flatten both tensors to (-1, num_labels); labels are cast to float
        # before being handed to the loss.
        num_labels = self.model.config.num_labels
        loss = self.loss_class.forward(
            x=logits.view(-1, num_labels),
            y=labels.float().view(-1, num_labels),
        )
        return (loss, outputs) if return_outputs else loss