Source code for src.training.subclasses.multilabel_distilbert_for_sequence_classification
from torch import nn
from transformers import DistilBertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput
from ..losses import AsymmetricLossOptimized
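# AsymmetricLossOptimized is the project's implementation of the asymmetric
# loss (ASL) for multilabel classification; nn.BCEWithLogitsLoss is the
# standard binary cross-entropy alternative.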
class DistilBertForMultiLabelClassification(DistilBertForSequenceClassification):
    """
    DistilBERT sequence classifier adapted for multilabel classification.

    Replaces the parent's single-label loss with a multilabel criterion:
    BCE-with-logits (``"bce"``) or the asymmetric loss (``"asl"``).
    """

    mlb_losses = {
        "asl": AsymmetricLossOptimized,
        "bce": nn.BCEWithLogitsLoss,
    }

    def __init__(self, config, loss="bce", loss_args=None):
        super().__init__(config)
        if loss_args is None:
            loss_args = {}
        # Instantiate the selected loss once; extra keyword arguments are
        # forwarded to its constructor.
        self.loss_fct = self.mlb_losses[loss](**loss_args)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        labels (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_labels)`, `optional`):
            Multi-hot target vectors for computing the multilabel classification loss. If
            :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss);
            otherwise the loss chosen at construction time (BCE-with-logits or ASL) is used.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_state = distilbert_output[0]                 # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]                  # (bs, dim), [CLS] token
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = nn.ReLU()(pooled_output)            # (bs, dim)
        pooled_output = self.dropout(pooled_output)         # (bs, dim)
        logits = self.classifier(pooled_output)             # (bs, num_labels)
        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # Single-output case falls back to regression with MSE.
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                # BCE-with-logits and ASL expect float (multi-hot) targets.
                loss = self.loss_fct(
                    logits.view(-1, self.num_labels),
                    labels.view(-1, self.num_labels).float(),
                )

        if not return_dict:
            output = (logits,) + distilbert_output[1:]
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )
    # Reuse the parent's docstring for the shared (non-label) arguments.
    forward.__doc__ = DistilBertForSequenceClassification.forward.__doc__
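# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module). The checkpoint
# name and the label count are placeholders; ``from_pretrained`` forwards the
# ``loss``/``loss_args`` keywords to ``__init__`` because they are not config
# attributes (standard transformers behavior). Run as a module, e.g.
# ``python -m src.training.subclasses...``, so the relative import resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    model = DistilBertForMultiLabelClassification.from_pretrained(
        "distilbert-base-uncased",
        num_labels=5,  # number of independent binary labels (placeholder)
        loss="bce",    # or "asl", optionally with loss_args={...}
    )

    batch = tokenizer(["an example sentence"], return_tensors="pt")
    targets = torch.tensor([[1.0, 0.0, 1.0, 0.0, 0.0]])  # multi-hot floats
    output = model(**batch, labels=targets)
    print(output.loss.item())            # scalar multilabel loss
    print(torch.sigmoid(output.logits))  # per-label probabilities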