Training
Base
- class src.training.base.Training(config: Config, logger: Logger, base_model_id: str, dataset_builder: DatasetBuilder)
Bases: object
Base class for training scripts.
- abstract _create_dataset()
Internal method that creates the datasets from the provided dataset builder.
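The abstract-method contract described above can be sketched as follows. This is a minimal illustration and not the project's code: the stand-in `Training` class derives from `ABC` so that instantiation is actually blocked (an assumption, since the documented base lists only `object`), `ToyTraining` is a hypothetical subclass, and the dataset builder is assumed to be a plain callable because the real `DatasetBuilder` interface is not documented in this section.

```python
from abc import ABC, abstractmethod


class Training(ABC):
    """Hypothetical stand-in for the documented base class."""

    def __init__(self, config, logger, base_model_id, dataset_builder):
        self.config = config
        self.logger = logger
        self.base_model_id = base_model_id
        self.dataset_builder = dataset_builder

    @abstractmethod
    def _create_dataset(self):
        """Create the datasets from self.dataset_builder."""


class ToyTraining(Training):
    """Hypothetical subclass used only for illustration."""

    def _create_dataset(self):
        # Assumption: the builder is a plain callable that returns the data.
        return self.dataset_builder()


toy = ToyTraining(
    config={},
    logger=None,
    base_model_id="toy-model",
    dataset_builder=lambda: ["sample a", "sample b"],
)
dataset = toy._create_dataset()
```

In this sketch, instantiating `Training` directly raises `TypeError`, which is why every concrete training class has to provide its own `_create_dataset()`.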
Multilabel BERT
- class src.training.bert_multilabel.BertTraining(config: Config, logger: Logger, base_model_id: str, dataset_builder: DatasetBuilder, sub_node: str = None, nested_mlflow_run: bool = False, trainer_flavour: TrainerTypes = TrainerTypes.CUSTOM)
Bases: Training, ABC
Training implementation for the BERT class.
- _abc_impl = <_abc._abc_data object>
- _create_dataset()
Internal method that creates the datasets from the provided dataset builder.
Multilabel Distilbert
- class src.training.distilbert.DistilBertTraining(config: Config, logger: Logger, base_model_id: str, dataset_builder: DatasetBuilder, sub_node: str = None, nested_mlflow_run: bool = False, trainer_flavour: TrainerTypes = TrainerTypes.CUSTOM)
Bases: Training, ABC
Training implementation for the DistilBERT class.
- _abc_impl = <_abc._abc_data object>
- _create_dataset()
Internal method that creates the datasets from the provided dataset builder.
Setfit
- class src.training.setfit.SetfitTraining(config: Config, logger: Logger, base_model_id: str, dataset_builder: DatasetBuilder, setfit_head: SetfitClassifierHeads, sub_node: str = None, nested_mlflow_run: bool = False, trainer_flavour: TrainerTypes = TrainerTypes.SETFIT)
Bases: Training, ABC
SetFit implementation for training.
It performed worse than the DistilBERT and regular BERT implementations. Note that the model is not fully integrated into the retraining stack.
- _abc_impl = <_abc._abc_data object>
- _create_dataset()
Internal method that creates the datasets from the provided dataset builder.
Other
- src.training.get_training_module(config: Config, logger: Logger, base_model_id: str, dataset_builder: DatasetBuilder, setfit_head=None, sub_node=None, nested_mlflow_run: bool = False, trainer_type: TrainerTypes = 'custom') → Training
Returns the training module for the model selected in the config file.
- Parameters:
config – the global config to use
logger – the global logger to use
base_model_id – model id that is used as base model
dataset_builder – dataset builder object
setfit_head – setfit head to use (only relevant when pulling setfit models)
sub_node – specific node reference to train on
- Returns:
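To illustrate how a factory with this signature might pick a training class from the config, here is a hedged sketch. It is not the project's implementation: the `Config.model_flavour` field, the `_TRAINING_CLASSES` registry, and the stand-in classes are all hypothetical, since this section does not document how the config encodes the model choice. Only the parameter names are taken from the signature above.

```python
from dataclasses import dataclass


@dataclass
class Config:
    """Hypothetical config; the real Config fields are not documented here."""
    model_flavour: str


class Training:
    """Minimal stand-in that only stores its keyword arguments."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class BertTraining(Training):
    """Stand-in for the BERT training class."""


class DistilBertTraining(Training):
    """Stand-in for the DistilBERT training class."""


class SetfitTraining(Training):
    """Stand-in for the SetFit training class."""


# Hypothetical mapping from a config value to a training class.
_TRAINING_CLASSES = {
    "bert": BertTraining,
    "distilbert": DistilBertTraining,
    "setfit": SetfitTraining,
}


def get_training_module(config, logger, base_model_id, dataset_builder,
                        setfit_head=None, sub_node=None,
                        nested_mlflow_run=False, trainer_type="custom"):
    """Return an instance of the training class named in the config."""
    try:
        training_cls = _TRAINING_CLASSES[config.model_flavour]
    except KeyError:
        raise ValueError(
            f"Unknown model flavour: {config.model_flavour!r}"
        ) from None

    kwargs = dict(
        config=config,
        logger=logger,
        base_model_id=base_model_id,
        dataset_builder=dataset_builder,
        sub_node=sub_node,
        nested_mlflow_run=nested_mlflow_run,
        trainer_flavour=trainer_type,
    )
    # Only the SetFit class takes a classifier head in the documented signatures.
    if training_cls is SetfitTraining:
        kwargs["setfit_head"] = setfit_head
    return training_cls(**kwargs)


module = get_training_module(Config("distilbert"), None, "base-model", None)
```

A lookup table keeps the selection logic in one place, so adding another model type in this sketch means adding one registry entry rather than another branch.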