Source code for src.models.hybrid.semi_supervised_selective

from __future__ import annotations

# typing imports
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ...config import Config
    from logging import Logger
    from ...data_models import Taxonomy

from ...enums import TaxonomyFindTypes
import torch

from ..base import Model


class SelectiveHybridModel(Model):
    """
    Hybrid model that refines supervised predictions with an unsupervised model,
    selectively keeping the sub-node scores.

    The supervised model is run first. Each label it returns is looked up among
    the direct children of ``taxonomy``, and the unsupervised model then scores
    that child's own sub-labels. The sub-label scores are kept only when the
    parent's supervised score exceeds
    ``config.run.benchmark.hybrid.minimum_threshold``; otherwise they are
    reported as ``0.0``. This thresholding is the difference from the regular
    hybrid model.
    """

    def __init__(
            self,
            config: Config,
            logger: Logger,
            taxonomy: Taxonomy,
            supervised_model: Model,
            unsupervised_model: Model,
    ) -> None:
        """
        :param config: run configuration, forwarded to the base ``Model``; its
            ``run.benchmark.hybrid.minimum_threshold`` is read in ``classify``.
        :param logger: logger forwarded to the base ``Model``.
        :param taxonomy: taxonomy whose root children are matched against the
            labels returned by ``supervised_model``.
        :param supervised_model: model whose ``classify(text=..., multi_label=...)``
            produces the first-level scores.
        :param unsupervised_model: model that must provide ``add_labels(list)``
            and ``classify(text=..., multi_label=..., labels=...)``; used to
            score the sub-labels of each predicted first-level label.
        """
        super().__init__(
            config=config,
            logger=logger,
            model_id=""
        )

        self.taxonomy = taxonomy
        self.supervised_model = supervised_model
        self.unsupervised_model = unsupervised_model

    def _prep_labels(self, taxonomy: Taxonomy) -> None:
        # Docstring is inherited from Model._prep_labels (assigned below).
        # Labels come from level 2 of the taxonomy given at construction time.
        # NOTE(review): the ``taxonomy`` argument is ignored in favour of
        # ``self.taxonomy`` -- confirm this matches the base-class contract.
        self.labels = self.taxonomy.get_level_specific_labels(level=2)

    _prep_labels.__doc__ = Model._prep_labels.__doc__

    @torch.inference_mode()
    def classify(self, text: str, multi_label: bool, **kwargs) -> dict[str, float]:
        # Docstring is inherited from Model.classify (assigned below).
        # ``kwargs`` is accepted for interface compatibility but not used here.
        prediction = self.supervised_model.classify(
            text=text,
            multi_label=multi_label
        )

        # Constant for the whole call, so read it once instead of once per
        # scored sub-label.
        threshold = self.config.run.benchmark.hybrid.minimum_threshold

        # Index the root's children by label once instead of scanning the list
        # for every predicted label. setdefault keeps the FIRST child for a
        # duplicated label, matching the previous "[...][0]" lookup.
        child_by_label: dict[str, Taxonomy] = {}
        for child in self.taxonomy.children:
            child_by_label.setdefault(child.label, child)

        # Accumulate into a separate dict instead of rebinding ``prediction``
        # while iterating over its items.
        combined: dict[str, float] = dict(prediction)

        for label, score in prediction.items():
            child_taxo = child_by_label.get(label)
            if child_taxo is None:
                # Same exception type as the previous list-index lookup, but
                # with a message that says what actually went wrong.
                raise IndexError(
                    f"label {label!r} returned by the supervised model is not "
                    "a direct child of the taxonomy root"
                )

            child_labels: list[str] = [taxo.label for taxo in child_taxo.children]
            self.unsupervised_model.add_labels(child_labels)

            child_scores = self.unsupervised_model.classify(
                text=text,
                multi_label=multi_label,
                labels=child_labels
            )

            # Parent not confident enough: keep the sub-label keys but zero
            # their scores.
            if score <= threshold:
                child_scores = dict.fromkeys(child_scores, 0.0)

            # Later results overwrite existing keys, same as ``{**a, **b}``.
            combined.update(child_scores)

        return combined

    classify.__doc__ = Model.classify.__doc__