Source code for src.models.zeroshot.base

from __future__ import annotations

# typing imports
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ...config import Config
    from ...data_models import Taxonomy
    from logging import Logger

from ..base import Model

from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import torch


class ZeroshotModel(Model):
    """
    Base model for zeroshot approaches
    """

    def __init__(
        self, config: Config, logger: Logger, model_id: str, taxonomy: Taxonomy
    ):
        super().__init__(config=config, logger=logger, model_id=model_id)
        self._load_model(model_id)
        self._prep_labels(taxonomy)
    def _load_model(self, model_id: str) -> None:
        """
        [Adaptation] Logic for loading zeroshot models
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id, token=self.config.run.model.pull_token
        )
        # Some checkpoints ship without max-length metadata, so force the
        # tokenizer to 512; ideally this would only be applied to models
        # that actually lack the information.
        self.tokenizer.model_max_length = 512
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_id, token=self.config.run.model.pull_token
        ).to(self.device)
        self.pipe = pipeline(
            task="zero-shot-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            # https://github.com/huggingface/transformers/issues/4501
            device=self.device,
        )
    _load_model.__doc__ += Model._load_model.__doc__
    @torch.inference_mode()
    def nli_infer(self, premise: str, hypothesis: str) -> dict[str, float]:
        """
        Low-level NLI interface through which zeroshot models can be used as well

        :param premise: input text
        :param hypothesis: parsed label text
        :return: scores for the prediction over the three hardcoded labels
        """
        inputs = self.tokenizer(
            premise, hypothesis, truncation=True, return_tensors="pt"
        ).to(self.device)
        # NOTE: assumes the checkpoint emits logits in
        # (entailment, neutral, contradiction) order
        output = self.model(**inputs)
        prediction = torch.softmax(output["logits"][0], -1).tolist()
        return {
            name: round(float(pred) * 100, 1)
            for pred, name in zip(
                prediction, ["entailment", "neutral", "contradiction"]
            )
        }
    @torch.inference_mode()
    def classify(self, text: str, multi_label: bool, **kwargs) -> dict[str, float]:
        labels = kwargs.get("labels", self.labels)
        result = self.pipe(text, labels, multi_label=multi_label)
        return dict(zip(result["labels"], result["scores"]))
    classify.__doc__ = Model.classify.__doc__
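

For reference, a minimal standalone sketch of the NLI call that nli_infer wraps. The checkpoint name facebook/bart-large-mnli, the premise, and the hypothesis are assumptions for illustration; model.config.id2label is used to read the checkpoint's own label order rather than hardcoding one:

# Hedged sketch mirroring nli_infer outside the class; the checkpoint and
# example texts are assumptions, not taken from this repo.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "facebook/bart-large-mnli"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

with torch.inference_mode():
    inputs = tokenizer(
        "The committee approved the new budget.",  # premise
        "This text is about finance.",             # hypothesis
        truncation=True,
        return_tensors="pt",
    )
    scores = torch.softmax(model(**inputs).logits[0], dim=-1).tolist()

# id2label gives the checkpoint's own (contradiction/neutral/entailment) order.
print({model.config.id2label[i]: round(s * 100, 1) for i, s in enumerate(scores)})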
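Similarly, a hedged usage sketch of the zero-shot-classification pipeline that classify delegates to; the candidate labels and input text are invented for illustration:

# Hedged sketch of the pipeline call behind classify; labels are made up.
from transformers import pipeline

pipe = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
result = pipe(
    "The committee approved the new budget.",
    ["finance", "sports", "politics"],
    multi_label=True,  # score each label independently instead of softmaxing
)
print(dict(zip(result["labels"], result["scores"])))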