Models
Base class
- class src.models.base.Model(config: Config, logger: Logger, model_id: str)
Bases: object
Base class for all models.
- _load_model(model_id: str) → None
This function enables custom model preparation before the classification is executed.
- Parameters:
model_id – ID of the model to pull
- Returns:
nothing
- _prep_labels(taxonomy: Taxonomy | list[str]) → None
Prepares the labels by converting them to the format the model requires for further processing.
- Parameters:
taxonomy – Taxonomy object (or list of label strings) whose labels are used
- Returns:
nothing
- add_labels(labels: list[str]) → None
Adds extra labels to the model's setup.
- Parameters:
labels – list of new labels to add or set in place
- Returns:
nothing
- classify(text: str, multi_label: bool, **kwargs) → dict[str, float]
Abstract method that executes the text classification.
- Parameters:
text – the text to classify
multi_label – whether the problem is a multi-label problem
kwargs – optional extra keyword arguments
- Returns:
the classification results, as a dict[str, float]
- property device
This property returns the device that the model is running on.
- Returns:
torch device in use
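The interface above can be illustrated with a minimal, hypothetical subclass. This is a sketch, not the project's implementation: Config is replaced by a plain dict, Taxonomy by a small dataclass with a `labels` attribute, and `KeywordModel` with its word-matching "logits" is invented for illustration. Two further behaviours are assumptions: `add_labels` appends new labels while skipping duplicates, and `multi_label` switches between independent sigmoid scores and a softmax over all labels, a common convention for multi-label versus single-label classification. The `device` property is left out so the sketch runs without torch.

```python
from __future__ import annotations

import logging
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class Taxonomy:
    # Hypothetical stand-in for the project's Taxonomy class.
    labels: list[str] = field(default_factory=list)


class Model(ABC):
    """Simplified stand-in for src.models.base.Model (config is a plain dict here)."""

    def __init__(self, config: dict, logger: logging.Logger, model_id: str):
        self.config = config
        self.logger = logger
        self.labels: list[str] = []
        self._load_model(model_id)

    @abstractmethod
    def _load_model(self, model_id: str) -> None:
        """Load weights or other resources for model_id."""

    def _prep_labels(self, taxonomy: Taxonomy | list[str]) -> None:
        # Accept either a Taxonomy or a plain list of label strings.
        labels = taxonomy.labels if isinstance(taxonomy, Taxonomy) else taxonomy
        self.labels = list(labels)

    def add_labels(self, labels: list[str]) -> None:
        # Assumption: new labels are appended, skipping ones already present.
        merged = list(self.labels)
        for label in labels:
            if label not in merged:
                merged.append(label)
        self._prep_labels(merged)

    @abstractmethod
    def classify(self, text: str, multi_label: bool, **kwargs) -> dict[str, float]:
        """Return a score per label."""


class KeywordModel(Model):
    """Hypothetical subclass that scores labels by simple word matching."""

    def _load_model(self, model_id: str) -> None:
        # A real subclass would load weights here; this toy only records the ID.
        self.model_id = model_id
        self.logger.debug("Loaded %s", model_id)

    def classify(self, text: str, multi_label: bool, **kwargs) -> dict[str, float]:
        words = set(text.lower().split())
        # Toy logit: +2 if the label occurs as a word in the text, -2 otherwise.
        logits = {lab: 2.0 if lab.lower() in words else -2.0 for lab in self.labels}
        if multi_label:
            # Independent sigmoid per label: several labels can score high at once.
            return {lab: 1.0 / (1.0 + math.exp(-z)) for lab, z in logits.items()}
        # Single label: softmax, so the scores sum to 1 and compete with each other.
        total = sum(math.exp(z) for z in logits.values())
        return {lab: math.exp(z) / total for lab, z in logits.items()}


model = KeywordModel({}, logging.getLogger(__name__), "toy-model")
model._prep_labels(Taxonomy(labels=["sports", "politics"]))
model.add_labels(["weather", "sports"])
print(model.classify("Sports news today", multi_label=False))
```

Putting the shared label handling in the base class and leaving only `_load_model` and `classify` abstract keeps each concrete model focused on loading weights and scoring text.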
Other
- src.models.get_model(config: Config, logger: Logger, model_id: str, taxonomy: Taxonomy, specific_model_type: ModelType = None, **kwargs) → Model
Returns an instantiated model based on the provided arguments and config.
- Parameters:
config – the global config object
logger – the global logger object
model_id – model ID to use as the base model weights
taxonomy – the taxonomy to use for label predictions
specific_model_type – explicitly defined model flavour (overrules the config)
kwargs – extra keyword arguments used to initialize models
- Returns:
An instance of the Model class
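The precedence rule for `specific_model_type` (it overrules the config) can be illustrated with a small factory. This is a hedged sketch: the `ModelType` members, the two model classes, `MODEL_REGISTRY`, `get_model_sketch` and the `config["model_type"]` key are all hypothetical, and the `logger` and `taxonomy` arguments of the real `get_model` are omitted for brevity.

```python
from __future__ import annotations

from enum import Enum


class ModelType(Enum):
    # Hypothetical members; the real ModelType values are not documented here.
    ZERO_SHOT = "zero_shot"
    EMBEDDING = "embedding"


class ZeroShotModel:
    """Hypothetical model class; stores its constructor arguments."""

    def __init__(self, model_id: str, **kwargs):
        self.model_id = model_id
        self.kwargs = kwargs


class EmbeddingModel(ZeroShotModel):
    """Second hypothetical model flavour."""


# Hypothetical registry mapping each model type to the class that implements it.
MODEL_REGISTRY: dict[ModelType, type] = {
    ModelType.ZERO_SHOT: ZeroShotModel,
    ModelType.EMBEDDING: EmbeddingModel,
}


def get_model_sketch(config: dict, model_id: str,
                     specific_model_type: ModelType | None = None, **kwargs):
    # An explicitly passed model type overrules the one in the config.
    if specific_model_type is not None:
        model_type = specific_model_type
    else:
        model_type = ModelType(config["model_type"])
    model_cls = MODEL_REGISTRY[model_type]
    # Extra kwargs are forwarded to the model's constructor.
    return model_cls(model_id, **kwargs)


config = {"model_type": "zero_shot"}
default_model = get_model_sketch(config, "some-model-id")
override_model = get_model_sketch(config, "some-model-id",
                                  specific_model_type=ModelType.EMBEDDING, batch_size=8)
```

A registry dict keeps the dispatch in one place, so adding a model flavour means adding one enum member and one registry entry.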
- src.models.get_topic_model(model_type: ModelType, config: Config, logger: Logger, dataset_builder: DatasetBuilder)
Provides a model specifically for topic modeling.
- Parameters:
model_type – the specific model type requested
config – the global config object
logger – the global logger object
dataset_builder – the dataset builder object containing all the relevant information that could be used for the topic modeling
- Returns:
An instance of the requested topic model