# Source code for src.utils.zeroshot

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ..data_models import Taxonomy

from transformers import pipeline


class ZeroshotMatching:
    """Zero-shot classification wrapper that matches text against a taxonomy.

    On construction the ``zero-shot-classification`` pipeline is created for
    ``model_id`` and the preferred labels of ``taxonomy`` are collected once,
    so that :meth:`match` only has to call the model.
    """

    def __init__(
        self, model_id: str, taxonomy: Taxonomy, verbose: bool = False
    ) -> None:
        """Store the configuration, load the model and prepare the labels.

        :param model_id: model identifier passed to ``pipeline(model=...)``
        :param taxonomy: taxonomy whose labels become the candidate classes
        :param verbose: stored on the instance; not read by this class itself
        """
        self.verbose = verbose
        self.taxonomy = taxonomy
        self.model_id = model_id
        # Initialization order is kept as-is: model first, then labels.
        self._init_model()
        self._prep_taxonomy()

    def _prep_taxonomy(self) -> None:
        """Collect the ``prefLabel`` of every label in the taxonomy.

        The labels are stored in ``self.candid_labels`` in the order returned
        by ``self.taxonomy.get_labels()``.

        :return: nothing
        """
        self.candid_labels = [
            label.prefLabel for label in self.taxonomy.get_labels()
        ]

    def _init_model(self) -> None:
        """Create the zero-shot classification pipeline for ``self.model_id``.

        The pipeline object is stored in ``self.model``.

        :return: nothing
        """
        self.model = pipeline("zero-shot-classification", model=self.model_id)

    def match(self, text: str):
        """Classify ``text`` against the taxonomy's candidate labels.

        :param text: input text to classify
        :return: the pipeline's response, returned unchanged. NOTE(review):
            per the Transformers documentation this is a dict with
            ``sequence``, ``labels`` and ``scores`` for a single input
            string -- confirm against the installed version.
        """
        return self.model(text, self.candid_labels)