Source code for src.models.embedding.base

from __future__ import annotations

# typing imports
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ...config import Config
    from ...data_models import Taxonomy

    from logging import Logger

from ..base import Model
from ...data_models import Taxonomy

from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity


[docs] class EmbeddingModel(Model): """ Custom class implementation for model. """ def __init__( self, config: Config, logger: Logger, model_id: str, taxonomy: Taxonomy ) -> None: super().__init__( config=config, logger=logger, model_id=model_id ) self.embedding_matrix = None self._load_model(model_id) self._prep_labels(taxonomy)
[docs] def _load_model(self, model_uri: str) -> None: self.model = SentenceTransformer(self.model_id)
_load_model.__doc__ = Model._load_model.__doc__
[docs] def _prep_labels(self, taxonomy: Taxonomy | list[str]) -> None: if not isinstance(taxonomy, list): self.labels = taxonomy.get_labels(max_depth=1) self.embedding_matrix = self._embed(self.labels) else: self.labels= taxonomy self.embedding_matrix = self._embed(taxonomy)
_prep_labels.__doc__ = Model._prep_labels.__doc__
[docs] def add_labels(self, labels: list[str]) -> None: self.embedding_matrix = self._embed(labels)
add_labels.__doc__ = Model.add_labels.__doc__
[docs] def _embed(self, text: str | list[str]) -> np.array: """ This internal function is used to create embeddings. :param text: the text to embed :return: """ if isinstance(text, list): return np.asarray([np.asarray(self.model.encode(t)) for t in text]) else: return np.asarray(self.model.encode(text))
[docs] def classify(self, text: str, multi_label, **kwargs) -> dict[str, float]: """ [Adaptation] customized for embedding similarity predictions """ text_embedding = np.asarray(self._embed(text)) if labels := kwargs.get("labels", None): self._prep_labels(labels) similarity = cosine_similarity(np.asarray([text_embedding]), self.embedding_matrix) return {k: v for k, v in zip(self.labels, similarity.tolist()[0])}
classify.__doc__ += Model.classify.__doc__