Training


Base

class src.training.base.Training(config: Config, logger: Logger, base_model_id: str, dataset_builder: DatasetBuilder)

Bases: object

Base class for the training scripts.

abstract _create_dataset()

Internal method that creates the datasets from the provided dataset builder.

abstract _create_model()

Instantiates the model used for training.

abstract compute_metrics(pred)

Computes the custom metrics during the training process.

Parameters:

pred – input values from training

abstract train()

Runs the training code.

Returns:

None
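The base class only defines the contract. The documented subclasses (`BertTraining`, `DistilBertTraining`, `SetfitTraining`) each supply the four methods. The following is a minimal, self-contained sketch of that contract. `TrainingSketch` and `ToyTraining` are hypothetical names, and the constructor arguments (`config`, `logger`, and so on) are omitted. The order in which `train()` calls the helpers is an assumption, not the project's actual flow. The stand-in derives from `ABC` so that a subclass missing one of the methods fails at instantiation.

```python
from abc import ABC, abstractmethod
from typing import Any


class TrainingSketch(ABC):
    """Hypothetical stand-in mirroring the abstract methods of
    src.training.base.Training; constructor arguments are omitted."""

    @abstractmethod
    def _create_dataset(self) -> Any:
        """Create the datasets from the dataset builder."""

    @abstractmethod
    def _create_model(self) -> Any:
        """Instantiate the model used for training."""

    @abstractmethod
    def compute_metrics(self, pred: Any) -> dict:
        """Compute custom metrics from training outputs."""

    @abstractmethod
    def train(self) -> None:
        """Run training; returns None."""


class ToyTraining(TrainingSketch):
    """Hypothetical concrete subclass that records which steps ran."""

    def __init__(self) -> None:
        self.calls: list = []

    def _create_dataset(self) -> Any:
        self.calls.append("dataset")
        return [("example text", [1, 0])]  # toy (text, multi-hot labels) pair

    def _create_model(self) -> Any:
        self.calls.append("model")
        return object()  # placeholder for a real model

    def compute_metrics(self, pred: Any) -> dict:
        return {"n_predictions": len(pred)}

    def train(self) -> None:
        # Assumed order: build the data first, then the model, then train.
        self.dataset = self._create_dataset()
        self.model = self._create_model()
        self.calls.append("train")


toy = ToyTraining()
toy.train()
print(toy.calls)  # ['dataset', 'model', 'train']
```

Calling `TrainingSketch()` directly raises `TypeError` because its abstract methods are not implemented.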

Multilabel BERT

class src.training.bert_multilabel.BertTraining(config: Config, logger: Logger, base_model_id: str, dataset_builder: DatasetBuilder, sub_node: str = None, nested_mlflow_run: bool = False, trainer_flavour: TrainerTypes = TrainerTypes.CUSTOM)

Bases: Training, ABC

Training implementation for multilabel BERT models.

_create_dataset()

Internal method that creates the datasets from the provided dataset builder.

_create_model()

Instantiates the model used for training.

compute_metrics(pred)

Computes the custom metrics during the training process.

Parameters:

pred – input values from training

train()

Runs the training code.

Returns:

None
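The documentation does not list which metrics `BertTraining.compute_metrics` reports. For multilabel classification, a common approach is to apply an independent sigmoid per label, threshold the probabilities, and report micro-averaged scores. The sketch below assumes `pred` has the shape of a Hugging Face `EvalPrediction`, with `predictions` holding logits and `label_ids` holding multi-hot targets. The namedtuple stand-in, the function name `multilabel_metrics`, the 0.5 threshold, and the metric names are all hypothetical.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for transformers.EvalPrediction (predictions, label_ids).
EvalPrediction = namedtuple("EvalPrediction", ["predictions", "label_ids"])


def multilabel_metrics(pred, threshold: float = 0.5) -> dict:
    """Sketch of multilabel metrics from raw logits.

    Assumes pred.predictions holds logits of shape (n_samples, n_labels)
    and pred.label_ids holds multi-hot targets of the same shape.
    """
    logits = np.asarray(pred.predictions, dtype=float)
    targets = np.asarray(pred.label_ids).astype(bool)

    # Independent sigmoid per label: each label is its own yes/no decision.
    probs = 1.0 / (1.0 + np.exp(-logits))
    preds = probs >= threshold

    # Micro averaging: pool true/false positives and negatives over all labels.
    tp = np.logical_and(preds, targets).sum()
    fp = np.logical_and(preds, ~targets).sum()
    fn = np.logical_and(~preds, targets).sum()
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

    # Subset accuracy: a sample counts only if every one of its labels is correct.
    subset_accuracy = np.all(preds == targets, axis=1).mean()

    return {
        "micro_f1": float(f1),
        "micro_precision": float(precision),
        "micro_recall": float(recall),
        "subset_accuracy": float(subset_accuracy),
    }


example = EvalPrediction(
    predictions=[[2.0, -1.0, 0.5], [-2.0, 3.0, -0.5]],
    label_ids=[[1, 0, 0], [0, 1, 0]],
)
# subset_accuracy is 0.5 here: only the second sample is fully correct.
print(multilabel_metrics(example))
```

Micro averaging weights every label decision equally. That is often preferred when label frequencies are very uneven, but it is a design choice, not something the documentation specifies.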

Multilabel DistilBERT

class src.training.distilbert.DistilBertTraining(config: Config, logger: Logger, base_model_id: str, dataset_builder: DatasetBuilder, sub_node: str = None, nested_mlflow_run: bool = False, trainer_flavour: TrainerTypes = TrainerTypes.CUSTOM)

Bases: Training, ABC

Training implementation for multilabel DistilBERT models.

_create_dataset()

Internal method that creates the datasets from the provided dataset builder.

_create_model()

Instantiates the model used for training.

compute_metrics(pred)

Computes the custom metrics during the training process.

Parameters:

pred – input values from training

train()

Runs the training code.

Returns:

None

Setfit

class src.training.setfit.SetfitTraining(config: Config, logger: Logger, base_model_id: str, dataset_builder: DatasetBuilder, setfit_head: SetfitClassifierHeads, sub_node: str = None, nested_mlflow_run: bool = False, trainer_flavour: TrainerTypes = TrainerTypes.SETFIT)

Bases: Training, ABC

Training implementation for SetFit models.

This model performed worse than DistilBERT or standard BERT and is not fully integrated into the retraining stack.

_create_dataset()

Internal method that creates the datasets from the provided dataset builder.

_create_model()

Instantiates the model used for training.

compute_metrics(pred, labels)

Computes the custom metrics during the training process.

Parameters:

pred – input values from training

train()

Runs the training code.

Returns:

None
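Unlike the BERT variants, `SetfitTraining.compute_metrics` receives the predictions and the labels as separate arguments. The following is a minimal sketch. It assumes both arguments are equal-length sequences of single class labels, because the documentation does not state whether SetFit is used in a single-label or multilabel setup here. `setfit_style_metrics` and the reported metric names are hypothetical.

```python
from collections import Counter
from typing import Hashable, Sequence


def setfit_style_metrics(pred: Sequence[Hashable], labels: Sequence[Hashable]) -> dict:
    """Sketch: accuracy and macro-F1 from predicted vs. true class labels.

    Assumes single-label classification (one class per sample).
    """
    if len(pred) != len(labels):
        raise ValueError("pred and labels must have the same length")
    if not labels:
        raise ValueError("no samples to score")

    accuracy = sum(p == y for p, y in zip(pred, labels)) / len(labels)

    # Per-class counts; a wrong prediction is a false positive for the
    # predicted class and a false negative for the true class.
    tp, fp, fn = Counter(), Counter(), Counter()
    for p, y in zip(pred, labels):
        if p == y:
            tp[y] += 1
        else:
            fp[p] += 1
            fn[y] += 1

    # F1 per class as 2TP / (2TP + FP + FN), then an unweighted mean.
    f1_scores = []
    for cls in set(pred) | set(labels):
        denom = 2 * tp[cls] + fp[cls] + fn[cls]
        f1_scores.append(2 * tp[cls] / denom if denom else 0.0)
    return {"accuracy": accuracy, "macro_f1": sum(f1_scores) / len(f1_scores)}


# accuracy is 0.75 here: 3 of the 4 predictions are correct.
print(setfit_style_metrics(["a", "b", "b", "a"], ["a", "b", "a", "a"]))
```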

Other

src.training.get_training_module(config: Config, logger: Logger, base_model_id: str, dataset_builder: DatasetBuilder, setfit_head=None, sub_node=None, nested_mlflow_run: bool = False, trainer_type: TrainerTypes = 'custom') Training

Returns the Training implementation selected by the config file.
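This page does not document how the config drives the selection. The sketch below shows one plausible approach, a registry lookup, under two assumptions. First, `TrainerTypes` is assumed to be a string-valued enum, which would explain why the signature can default `trainer_type` to the plain string `'custom'`. Second, the config is assumed to expose a model-type string. `TRAINING_REGISTRY`, `select_training_class`, and the `*Stub` classes are hypothetical placeholders, not the project's code.

```python
from enum import Enum


class TrainerTypes(str, Enum):
    """Hypothetical mirror of TrainerTypes, assumed to be str-valued so the
    plain string 'custom' and TrainerTypes.CUSTOM compare equal."""

    CUSTOM = "custom"
    SETFIT = "setfit"


# Hypothetical placeholders standing in for the documented training classes.
class BertTrainingStub:
    pass


class DistilBertTrainingStub:
    pass


class SetfitTrainingStub:
    pass


# Hypothetical registry: model-type string (assumed config value) -> class.
TRAINING_REGISTRY = {
    "bert": BertTrainingStub,
    "distilbert": DistilBertTrainingStub,
    "setfit": SetfitTrainingStub,
}


def select_training_class(model_type: str, trainer_type="custom"):
    """Look up the training class for a model type and normalize trainer_type."""
    # TrainerTypes(...) accepts either the enum member or its string value.
    flavour = TrainerTypes(trainer_type)
    try:
        cls = TRAINING_REGISTRY[model_type.lower()]
    except KeyError:
        raise ValueError(f"unknown model type: {model_type!r}") from None
    return cls, flavour


cls, flavour = select_training_class("DistilBERT")
print(cls.__name__, flavour is TrainerTypes.CUSTOM)  # DistilBertTrainingStub True
```

A registry keeps the selection logic in one lookup table, so supporting another model family means adding a single entry rather than another branch.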

Parameters:
  • config – the global config to use

  • logger – the global logger to use

  • base_model_id – model id that is used as base model

  • dataset_builder – dataset builder object

  • setfit_head – SetFit classifier head to use (only relevant for SetFit models)

  • sub_node – specific node reference to train on

Returns: