from __future__ import annotations
import mlflow

# typing imports
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from ..config import Config
    from logging import Logger
    from ..sparql import RequestHandler

from .regular import BenchmarkWrapper
class EmbeddingSimilarityBenchmark(BenchmarkWrapper):
    """
    Wrapper class for benchmarking zero-shot embedding models; see the base class for more details.

    Example usage:

    >>> benchmark = EmbeddingSimilarityBenchmark(
    ...     config=Config(),
    ...     logger=logging.getLogger(__name__),
    ...     request_handler=RequestHandler(),
    ...     model_ids=["...", ...],
    ...     taxonomy_reference="..."
    ... )
    >>> benchmark()
    """
    def __init__(
        self,
        config: Config,
        logger: Logger,
        request_handler: RequestHandler,
        model_ids: list[str] | str,
        taxonomy_reference: str = "http://stad.gent/id/concepts/gent_words",
        nested_mlflow_run: bool = False,
        checkpoint_dir: str = "data"
    ):
        super().__init__(
            config=config,
            logger=logger,
            request_handler=request_handler,
            model_ids=model_ids,
            taxonomy_reference=taxonomy_reference,
            nested_mlflow_run=nested_mlflow_run,
            checkpoint_dir=checkpoint_dir
        )
        self._default_mlflow_tags = {"model_type": "Embedding"}
        self._default_description = "Running evaluation over all specified embedding models"
    @property
    def default_mlflow_tags(self):
        """
        Getter for the default MLflow tags determined by the selected benchmark class.

        Example usage:

        >>> benchmark = EmbeddingSimilarityBenchmark(...)
        >>> mlflow_tags = benchmark.default_mlflow_tags

        :return: default tags for the MLflow run
        """
        return self._default_mlflow_tags
    @property
    def default_description(self):
        """
        Getter for the default description used when logging the MLflow run.

        Example usage:

        >>> benchmark = EmbeddingSimilarityBenchmark(...)
        >>> description = benchmark.default_description

        :return: string description for the MLflow run
        """
        return self._default_description
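

# Usage sketch (illustrative only): a minimal, hedged example of wiring up the
# benchmark when this module is run as part of its package. Assumptions not
# taken from this module: Config() and RequestHandler() are assumed to be
# constructible without arguments, and the model id below is a hypothetical
# placeholder; the actual constructor signatures live in ..config and ..sparql.
if __name__ == "__main__":
    import logging

    from ..config import Config
    from ..sparql import RequestHandler

    benchmark = EmbeddingSimilarityBenchmark(
        config=Config(),
        logger=logging.getLogger(__name__),
        request_handler=RequestHandler(),
        model_ids=["hypothetical/embedding-model"],  # hypothetical model id
    )
    # Calling the instance runs the evaluation; the MLflow run is tagged with
    # {"model_type": "Embedding"} and described via default_description.
    benchmark()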