from __future__ import annotations
import traceback
from abc import ABC
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ..config import Config
    from logging import Logger
    from ..dataset import DatasetBuilder

from ..enums import SetfitClassifierHeads, TrainerTypes
from setfit import SetFitModel
from uuid import uuid4
from .trainers import CustomSetFitTrainer
from .base import Training
import torch
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    jaccard_score,
    hamming_loss,
    classification_report,
    multilabel_confusion_matrix,
)
import mlflow
import os
from shutil import rmtree
class SetfitTraining(Training, ABC):
    """
    SetFit implementation for training.

    Compared to the DistilBERT and regular BERT models, this one showed inferior
    performance; it is therefore not fully integrated into the retraining stack.
    """
    def __init__(
        self,
        config: Config,
        logger: Logger,
        base_model_id: str,
        dataset_builder: DatasetBuilder,
        setfit_head: SetfitClassifierHeads,
        sub_node: str | None = None,
        nested_mlflow_run: bool = False,
        trainer_flavour: TrainerTypes = TrainerTypes.SETFIT
    ) -> None:
        super().__init__(
            config=config,
            logger=logger,
            base_model_id=base_model_id,
            dataset_builder=dataset_builder
        )
        self.sub_node = sub_node
        self.nested_mlflow_run = nested_mlflow_run
        # Fall back to the multi-output sklearn head when no head is provided.
        self.setfit_head = setfit_head or SetfitClassifierHeads.SKLEARN_MULTI_OUTPUT
        mlflow.autolog()
        # Scratch directory for checkpoints and report artifacts; removed after training.
        self.train_folder = f"/tmp/training_{uuid4().hex}"
        # Incremented on every metrics computation; used as the MLflow step counter.
        self.count_flag = 0
        os.makedirs(self.train_folder, exist_ok=True)
        self._create_dataset()
        self._create_model()
        self.trainer_flavour = trainer_flavour
    def compute_metrics(self, pred, labels):
        """Compute multilabel metrics, write report artifacts, and log to MLflow."""
        self.count_flag += 1
        # Turn raw logits into hard multilabel predictions at a 0.5 threshold.
        probs = torch.sigmoid(torch.tensor(pred)).cpu()
        preds = torch.where(probs < 0.5, 0, 1).int().numpy()
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='micro')
        acc = accuracy_score(labels, preds)
        jaccard = jaccard_score(labels, preds, average='micro')
        hamming = hamming_loss(labels, preds)
        clsf_report = classification_report(
            labels,
            preds,
            target_names=self.target_names
        )
        multilabel_conf_matrix = multilabel_confusion_matrix(labels, preds)
        classification_report_file = 'Classification_report.txt'
        with open(os.path.join(self.train_folder, classification_report_file), 'w+') as f:
            f.write(clsf_report)
        confusion_matrix_file = 'Confusion_matrix.txt'
        with open(os.path.join(self.train_folder, confusion_matrix_file), 'w+') as f:
            f.write(str(multilabel_conf_matrix))
        metrics = {
            'accuracy': acc,
            'f1': f1,
            'precision': precision,
            'recall': recall,
            'jaccard_score': jaccard,
            'hamming_loss': hamming
        }
        mlflow.log_metrics(metrics, step=self.count_flag)
        return metrics
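
    # Worked example of the thresholding above (illustrative values): for logits
    # [[2.0, -1.0]], torch.sigmoid gives roughly [[0.88, 0.27]], and the 0.5
    # threshold yields the multilabel prediction [[1, 0]].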
    def _create_dataset(self):
        """Build the HuggingFace train and eval datasets from the dataset builder."""
        from ..dataset import create_dataset
        from datasets import Dataset
        train_dataset = create_dataset(
            config=self.config,
            logger=self.logger,
            dataset=self.dataset_builder.train_dataset,
            taxonomy=self.dataset_builder.taxonomy,
            sub_node=self.sub_node
        )
        # Label names in the order used by the binarized label vectors.
        self.target_names = list(train_dataset.binarized_label_dictionary.keys())
        self.train_ds = Dataset.from_list(train_dataset)
        eval_dataset = create_dataset(
            config=self.config,
            logger=self.logger,
            dataset=self.dataset_builder.test_dataset,
            taxonomy=self.dataset_builder.taxonomy,
            sub_node=self.sub_node
        )
        self.eval_ds = Dataset.from_list(eval_dataset)
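
    # Assumed record shape (not confirmed in this module): items yielded by
    # `create_dataset` look like {"text": ..., "labels": [...]}, matching the
    # column_mapping passed to CustomSetFitTrainer in `train` below.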
    def _create_model(self):
        """Load the base SetFit model with the configured classification head."""
        self.model = SetFitModel.from_pretrained(
            self.base_model_id,
            # `SetfitClassifierHeads.match` resolves the head-specific kwargs
            # for `from_pretrained` from the configured head type.
            **SetfitClassifierHeads.match(self.config, self.setfit_head)
        )
    def train(self):
        """Run SetFit training, log artifacts and metrics, and clean up."""
        try:
            trainer = CustomSetFitTrainer(
                model=self.model,
                train_dataset=self.train_ds,
                eval_dataset=self.eval_ds,
                batch_size=self.config.run.training.arguments.per_device_train_batch_size,
                # Reuses the configured epoch count as SetFit's `num_iterations`.
                num_iterations=self.config.run.training.arguments.num_train_epochs,
                metric=self.compute_metrics,
                # Map dataset columns to the field names SetFit expects.
                column_mapping={"text": "text", "labels": "label"},
            )
            trainer.train()
            model_checkpoint_dir = f"{self.train_folder}/ModelCheckpoint"
            os.makedirs(model_checkpoint_dir, exist_ok=True)
            trainer.model.save_pretrained(model_checkpoint_dir)
            mlflow.log_artifacts(self.train_folder)
            metrics = trainer.evaluate()
            mlflow.log_metrics(metrics)
        except Exception as ex:
            traceback.print_exception(ex)
            self.logger.error(f"The following error occurred during training: {ex}")
            mlflow.set_tag("LOG_STATUS", "FAILED")
        finally:
            # Remove the scratch directory whether training succeeded or failed.
            rmtree(self.train_folder)
    def __call__(self):
        self.train()
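
# Minimal usage sketch (illustrative only; the surrounding application is
# assumed to provide a concrete Config, Logger, and DatasetBuilder, and the
# model id below is just an example):
#
#     training = SetfitTraining(
#         config=config,
#         logger=logger,
#         base_model_id="sentence-transformers/paraphrase-mpnet-base-v2",
#         dataset_builder=dataset_builder,
#         setfit_head=SetfitClassifierHeads.SKLEARN_MULTI_OUTPUT,
#     )
#     training()  # __call__ delegates to train()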