Source code for src.models.zeroshot.child_labels

from __future__ import annotations

# typing imports
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ...data_models import Taxonomy


import torch

from .base import ZeroshotModel


class ChildLabelsZeroshotModel(ZeroshotModel):
    """
    Zeroshot approach that merges all child labels of a parent node into a single
    hypothesis text per parent label
    """
    def _text_formatting(self, taxonomy_node: Taxonomy) -> str:
        """
        Custom text formatting logic: merge a parent label and its child labels
        into a single hypothesis text

        :param taxonomy_node: taxonomy node to format text for
        :return: merged hypothesis text for the node
        """
        parent_label = taxonomy_node.label
        sub_labels = [
            label for label in taxonomy_node.all_linked_labels if label != parent_label
        ]
        custom_text = f"Tekst over {parent_label} of meer specifiek {' of '.join(sub_labels)}"
        self.logger.debug(f"Created custom text: {custom_text}")
        return custom_text
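    # illustration with hypothetical labels: for a node labelled "Sport" whose
    # all_linked_labels are ["Sport", "voetbal", "tennis"], _text_formatting returns
    # "Tekst over Sport of meer specifiek voetbal of tennis"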
    def _prep_labels(self, taxonomy: Taxonomy) -> None:
        if not isinstance(taxonomy, list):
            self.labels = taxonomy.get_labels(max_depth=1)
            self.build_labels = [
                self._text_formatting(taxonomy_node=parent_node)
                for parent_node in taxonomy.children
            ]
        else:
            self.labels = taxonomy
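    # note: self.labels (the depth-1 parent label names) and self.build_labels (the
    # merged hypothesis texts, one per parent node) are expected to be parallel lists;
    # classify() below relies on this when zipping scores back onto label names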
    _prep_labels.__doc__ = ZeroshotModel._prep_labels.__doc__
    @torch.inference_mode()
    def classify(self, text: str, multi_label: bool, **kwargs) -> dict[str, float]:
        # custom adaptation: the merged label texts are long enough to trigger max-length
        # errors in the huggingface zero-shot pipeline implementation, so each
        # (text, label) pair is scored manually as an NLI premise/hypothesis pair;
        # every label is scored independently, regardless of multi_label
        scores = []
        labels = kwargs.get("labels", self.build_labels)
        for label in labels:
            inputs = self.tokenizer.encode(
                text,
                label,
                return_tensors='pt',
                truncation=True,
                truncation_strategy='only_first'
            )
            logits = self.model(inputs.to(self.device))[0]
            # keep the contradiction (0) and entailment (2) logits and softmax over them;
            # index 1 of the result is then the probability of entailment, i.e. the label score
            contradiction_logits = logits[:, [0, 2]]
            probs = contradiction_logits.softmax(dim=1)
            scores.append(probs[:, 1].item())
        return {k: v for k, v in zip(kwargs.get("labels", self.labels), scores)}
    classify.__doc__ = ZeroshotModel.classify.__doc__
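
# A minimal, self-contained sketch of the per-label NLI scoring that classify() relies
# on: keep the contradiction and entailment logits, softmax over them, and read off the
# entailment probability. The checkpoint name and the example premise/hypothesis are
# assumptions for illustration; any NLI model whose label order is
# [contradiction, neutral, entailment] behaves the same way.
if __name__ == "__main__":
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    nli_name = "joeddav/xlm-roberta-large-xnli"
    nli_tokenizer = AutoTokenizer.from_pretrained(nli_name)
    nli_model = AutoModelForSequenceClassification.from_pretrained(nli_name)

    premise = "De wedstrijd eindigde in een gelijkspel."
    hypothesis = "Tekst over Sport of meer specifiek voetbal of tennis"

    encoded = nli_tokenizer.encode(premise, hypothesis, return_tensors="pt", truncation="only_first")
    with torch.inference_mode():
        nli_logits = nli_model(encoded)[0]
    # indices 0 and 2 are the contradiction and entailment logits for this label order
    entailment_prob = nli_logits[:, [0, 2]].softmax(dim=1)[:, 1].item()
    print(f"score for '{hypothesis}': {entailment_prob:.3f}")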