"""src.dataset: factory for building the dataset selected in the run configuration."""

from __future__ import annotations

# typing imports
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ..config import Config
    from logging import Logger
    from typing import Any
    from ..data_models import Taxonomy

# single label datasets
from .singlelabel import TrainingDataset, SingleTopLevel, \
    BasicDataset

# multilabel datasets
from .multilabel import MultiLabelTopLevelFullText, MultiLabelSecondLevelFullText, \
    MultiLabelTopLevelArticleBased, MultiLabelTopLevelDescriptionBased, MultiLabelTopLevelMotivationBased, \
    MultiLabelTopLevelShortTitleBased, MultiLabelTopLevelArticleSplit, SummaryStatisticDataset,\
    DynamicMultilabelTrainingDataset

from .builder import DatasetBuilder
from .base import TrainDataset
from ..enums import DatasetType


[docs] def create_dataset( config: Config, logger: Logger, dataset: list[dict[str, Any]], taxonomy: Taxonomy, tokenizer: object = None, sub_node: str = None, ) -> TrainDataset | list[dict]: """ Function that creates the dataset based on the configuration that is provided :param sub_node: sub_node to reselect input data for :param tokenizer: tokenizer to use in the dataset if provided :param config: configuration object :param logger: logger object :param dataset: the created dataset (list of dict) :param taxonomy: the taxonomy that is used for :return: """ match config.run.dataset.type: case DatasetType.MULTI_TOP_LEVEL_ALL_BASED.value | DatasetType.MULTI_TOP_LEVEL_ALL_BASED: logger.info("Selected MultiLabelTopLevelFullText") return MultiLabelTopLevelFullText( config=config, dataset=dataset, taxonomy=taxonomy, logger=logger, tokenizer=tokenizer, sub_node = sub_node ) case DatasetType.MULTI_SECOND_LEVEL_ALL_BASED.value | DatasetType.MULTI_SECOND_LEVEL_ALL_BASED: logger.info("Selected MultiLabelSecondLevelFullText") return MultiLabelSecondLevelFullText( config=config, dataset=dataset, taxonomy=taxonomy, logger=logger, tokenizer=tokenizer, sub_node=sub_node ) case DatasetType.MULTI_TOP_LEVEL_ARTICLE_BASED.value | DatasetType.MULTI_TOP_LEVEL_ARTICLE_BASED: logger.info("Selected MultiLabelTopLevelArticleBased") return MultiLabelTopLevelArticleBased( config=config, dataset=dataset, taxonomy=taxonomy, logger=logger, tokenizer=tokenizer, sub_node=sub_node ) case DatasetType.MULTI_TOP_LEVEL_DESCRIPTION_BASED.value | DatasetType.MULTI_TOP_LEVEL_DESCRIPTION_BASED: logger.info("Selected MultiLabelTopLevelDescriptionBased") return MultiLabelTopLevelDescriptionBased( config=config, dataset=dataset, taxonomy=taxonomy, logger=logger, tokenizer=tokenizer, sub_node=sub_node ) case DatasetType.MULTI_TOP_LEVEL_MOTIVATION_BASED.value | DatasetType.MULTI_TOP_LEVEL_MOTIVATION_BASED: logger.info("Selected MultiLabelTopLevelMotivationBased") return MultiLabelTopLevelMotivationBased( 
config=config, dataset=dataset, taxonomy=taxonomy, logger=logger, tokenizer=tokenizer, sub_node=sub_node ) case DatasetType.MULTI_TOP_LEVEL_SHORT_TITLE_BASED.value | DatasetType.MULTI_TOP_LEVEL_SHORT_TITLE_BASED: logger.info("Selected MultiLabelTopLevelShortTitleBased") return MultiLabelTopLevelShortTitleBased( config=config, dataset=dataset, taxonomy=taxonomy, logger=logger, tokenizer=tokenizer, sub_node=sub_node ) case DatasetType.MULTI_TOP_LEVEL_ARTICLE_SPLIT.value | DatasetType.MULTI_TOP_LEVEL_ARTICLE_SPLIT: logger.info("Selected MultiLabelTopLevelArticleSplit") return MultiLabelTopLevelArticleSplit( config=config, dataset=dataset, taxonomy=taxonomy, logger=logger, tokenizer=tokenizer, sub_node=sub_node ) case DatasetType.SUMMARY_STATISTIC_DATASET | DatasetType.SUMMARY_STATISTIC_DATASET.value: logger.info("Selected SummaryStatisticDataset") return SummaryStatisticDataset( config=config, dataset=dataset, taxonomy=taxonomy, logger=logger, tokenizer=tokenizer, sub_node=sub_node ) case DatasetType.DYNAMIC | DatasetType.DYNAMIC.value: logger.info(f"DynamicMultilabelTrainingDataset") return DynamicMultilabelTrainingDataset( config=config, dataset=dataset, taxonomy=taxonomy, logger=logger, tokenizer=tokenizer, sub_node=sub_node ) case DatasetType.UNPROCESSED | DatasetType.UNPROCESSED.value: return dataset case _: raise NotImplementedError("No such dataset available")