Source code for src.models.zeroshot.sentence

from __future__ import annotations

# typing imports
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ...config import Config
    from logging import Logger

from ..base import Model

from .base import ZeroshotModel
from nltk.tokenize import sent_tokenize
import numpy as np
import torch


[docs] class SentenceZeroshotModel(ZeroshotModel): """ Zeroshot model that has the sentence based approach (predict for each sentence separately) """
[docs] @torch.inference_mode() def classify(self, text: str, multi_label: bool, **kwargs) -> dict[str, float]: zeroshot_scores = [] result = None labels = kwargs.get("labels", self.labels) self.logger.debug(f"Input text: {text}") self.logger.debug(f"predicting for: {labels}") for sentence in sent_tokenize(text, language="dutch"): self.logger.debug(f"predicting for: '{sentence}'") result = self.pipe(sentence, labels, multi_label=multi_label) zeroshot_scores.append(result.get("scores")) scores = np.asarray(zeroshot_scores).mean(axis=0) return {k: v for k, v in zip(result.get("labels"), scores.tolist())}
classify.__doc__ = ZeroshotModel.classify.__doc__