Source code for src.models.embedding.child_labels

from __future__ import annotations

# typing imports
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ...config import Config
    from ...data_models import Taxonomy

    from logging import Logger

from .base import EmbeddingModel

from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity


class ChildLabelsEmbeddingModel(EmbeddingModel):
    """Embedding model that describes each label together with its linked labels.

    Instead of embedding a bare label, every taxonomy node is turned into a
    short Dutch sentence naming the node's own label and its other linked
    labels; that sentence is what gets embedded.
    """

    def _text_formatting(self, taxonomy_node: Taxonomy) -> str:
        """Build the sentence that is embedded for a single taxonomy node.

        Args:
            taxonomy_node: Node providing ``label`` and ``all_linked_labels``.

        Returns:
            ``"Tekst over <label> of meer specifiek <a> of <b> ..."`` when the
            node has linked labels other than its own label, otherwise just
            ``"Tekst over <label>"``.
        """
        parent_label = taxonomy_node.label
        sub_labels = [
            label
            for label in taxonomy_node.all_linked_labels
            if label != parent_label
        ]

        if sub_labels:
            # The separator needs a space on both sides; the previous " of"
            # glued the labels together ("a ofb ofc").
            specifics = " of ".join(sub_labels)
            custom_text = f"Tekst over {parent_label} of meer specifiek {specifics}"
        else:
            # Without sub-labels there is nothing "more specific" to list, so
            # avoid emitting a dangling "of meer specifiek ".
            custom_text = f"Tekst over {parent_label}"

        # Lazy %-formatting: the message is only rendered if DEBUG is enabled.
        self.logger.debug("Created custom text: %s", custom_text)
        return custom_text

    def _prep_labels(self, taxonomy: Taxonomy | list[str]) -> None:
        """Populate ``self.labels`` and ``self.embedding_matrix``.

        Args:
            taxonomy: Either a plain list of label strings, which is embedded
                as-is, or a taxonomy object whose ``children`` are each
                converted to a sentence with ``_text_formatting`` before
                embedding.
        """
        if isinstance(taxonomy, list):
            # Plain labels: no taxonomy structure to enrich them with.
            self.labels = taxonomy
            self.embedding_matrix = self._embed(taxonomy)
            return

        self.labels = taxonomy.get_labels(max_depth=1)
        # NOTE(review): assumes get_labels(max_depth=1) yields labels in the
        # same order as taxonomy.children, so rows of the embedding matrix
        # line up with self.labels -- confirm against Taxonomy.
        label_texts = [
            self._text_formatting(taxonomy_node=child_node)
            for child_node in taxonomy.children
        ]
        self.embedding_matrix = self._embed(label_texts)