Source code for src.enums.supervised_flavours

from __future__ import annotations

# typing imports
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ..config import Config

import enum


[docs] class TrainingFlavours(str, enum.Enum): """ This enum is used to identify what type of training flavour you want to use. """ BERT: str = "bert" DISTIL_BERT: str = "distil_bert" SETFIT: str = "setfit"
[docs] @staticmethod def get_default_model_for_type(config: Config, model_flavour: TrainingFlavours) -> str: """ this function checks what the chosen flavour is and returns the defaulted value from the config for the model_id :param config: the global configuration object :param model_flavour: the enum value that is used. :return: format-able query in string format """ match model_flavour: case TrainingFlavours.BERT | TrainingFlavours.BERT.value: return config.run.training.default.bert_model_id case TrainingFlavours.DISTIL_BERT | TrainingFlavours.DISTIL_BERT.value: return config.run.training.default.distil_bert_model_id case TrainingFlavours.SETFIT | TrainingFlavours.SETFIT.value: return config.run.training.default.setfit_model_id