from __future__ import annotations
# typing imports
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ..config import Config
from logging import Logger
from ..dataset import DatasetBuilder
from .base import Training
from .distilbert import DistilBertTraining
from .setfit import SetfitTraining
from .bert_multilabel import BertTraining
from ..enums import TrainingFlavours, TrainerTypes
[docs]
def get_training_module(
config: Config,
logger: Logger,
base_model_id: str,
dataset_builder: DatasetBuilder,
setfit_head=None,
sub_node=None,
nested_mlflow_run: bool = False,
trainer_type: TrainerTypes = "custom"
) -> Training:
"""
This function returns the model that is selected using the config file.
:param config: the global config to use
:param logger: the global logger to use
:param base_model_id: model id that is used as base model
:param dataset_builder: dataset builder object
:param setfit_head: setfit head to use (only relevant when pulling setfit models)
:param sub_node: specific node reference to train on
:return:
"""
match config.run.model.flavour:
case TrainingFlavours.SETFIT | TrainingFlavours.SETFIT.value:
logger.debug("Selected Setfit")
return SetfitTraining(
config=config,
logger=logger,
base_model_id=base_model_id,
dataset_builder=dataset_builder,
setfit_head=setfit_head,
sub_node=sub_node,
nested_mlflow_run=nested_mlflow_run,
trainer_flavour=trainer_type
)
case TrainingFlavours.DISTIL_BERT | TrainingFlavours.DISTIL_BERT.value:
logger.debug("Selected Distilbert")
return DistilBertTraining(
config=config,
logger=logger,
base_model_id=base_model_id,
dataset_builder=dataset_builder,
sub_node=sub_node,
nested_mlflow_run=nested_mlflow_run,
trainer_flavour=trainer_type
)
case TrainingFlavours.BERT | TrainingFlavours.BERT.value:
logger.debug("Selected Bert")
return BertTraining(
config=config,
logger=logger,
base_model_id=base_model_id,
dataset_builder=dataset_builder,
sub_node=sub_node,
nested_mlflow_run=nested_mlflow_run,
trainer_flavour=trainer_type
)
case _:
raise ValueError("Provided training module does not exists!")