Source code for src.dataset.singlelabel.base

from __future__ import annotations

# typing imports
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ...config import Config
    from ...data_models import Taxonomy
    from logging import Logger

from torch import Tensor
from transformers import AutoTokenizer

import abc
from typing import Any
from ..base import TrainDataset


class TrainingDataset(TrainDataset):
    """
    Basic implementation of the single-label dataset concept. The class
    derives from ``TrainDataset``, which in turn derives from
    ``torch.utils.data.Dataset``.

    Typical usage:

    >>> from src.dataset.builder import DatasetBuilder
    >>> dataset_builder = DatasetBuilder(...)
    >>> ds = TrainingDataset(
    ...     config=Config(),
    ...     logger=logging.getLogger(__name__),
    ...     taxonomy=Taxonomy(...),
    ...     data=dataset_builder.train_dataset
    ... )
    >>> # getting the first custom formatted data example
    >>> print(ds[0])
    """

    def __init__(
        self,
        config: Config,
        logger: Logger,
        taxonomy: Taxonomy,
        data: Any,
        tokenizer: AutoTokenizer | None = None,
        device: str = "cpu"
    ):
        self.logger = logger
        self.config = config
        self.tokenizer = tokenizer
        self.taxonomy = taxonomy
        self.dataset = data
        self.device = device

    # Single leading underscore instead of a double one: name mangling on
    # ``__get_label``/``__get_text`` would prevent subclasses from
    # overriding these abstract methods.
    @abc.abstractmethod
    def _get_label(self, idx: int) -> Tensor | str:
        pass

    @abc.abstractmethod
    def _get_text(self, idx: int) -> dict[str, Tensor] | str:
        pass

    def __getitem__(self, idx: int) -> dict[str, Tensor | str]:
        """
        Retrieve a training/inference record for the given index. Besides
        the raw lookup, the text is tokenized and the resulting tensors are
        moved to the configured device.

        Example usage:

        >>> dataset = TrainingDataset(...)
        >>> record = dataset[10]
        >>> print(record)

        :param idx: the integer value of the input index
        :return: a dictionary containing the train/inference sample
            (adaptable with a config)
        """
        label = self._get_label(idx)
        text = self._get_text(idx)
        # Calling the tokenizer directly returns a ``BatchEncoding`` with
        # ``input_ids`` and ``attention_mask``; ``tokenize()`` would only
        # return a list of token strings.
        tokenized_text = self.tokenizer(text, return_tensors="pt")
        return dict(
            input_ids=tokenized_text["input_ids"].to(self.device),
            attention_mask=tokenized_text["attention_mask"].to(self.device),
            # string labels have no ``.to``; only move tensor labels
            label=label.to(self.device) if isinstance(label, Tensor) else label
        )

    def __len__(self) -> int:
        """
        Return the current length of the dataset.

        :return: integer value representing the length of the internal dataset.
        """
        return len(self.dataset)
    def iter_records(self):
        """
        Iterate over all the items in the dataset.

        :return: a generator yielding one record per index.
        """
        for i in range(len(self)):
            yield self[i]
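
The two abstract methods define the contract a concrete dataset has to fulfil. Below is a minimal sketch of such a subclass; the class name ``PairDataset`` and the record layout (a sequence of ``(text, label_index)`` pairs) are assumptions for illustration, not part of this module:

import torch

class PairDataset(TrainingDataset):
    """Hypothetical concrete dataset over (text, label_index) records."""

    def _get_label(self, idx: int) -> Tensor:
        # The second element of a record is assumed to be an integer
        # class index; wrap it in a tensor so ``__getitem__`` can move it
        # to the configured device.
        _, label_idx = self.dataset[idx]
        return torch.tensor(label_idx)

    def _get_text(self, idx: int) -> str:
        # The first element of a record is assumed to be the raw text.
        text, _ = self.dataset[idx]
        return text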
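
With such a subclass in place, records can be consumed by index or through ``iter_records``. A hedged usage sketch follows: the tokenizer checkpoint is only an example, and ``None`` stands in for real ``Config``/``Taxonomy`` instances, which ``__getitem__`` does not touch:

import logging
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example checkpoint
ds = PairDataset(
    config=None,       # a real Config instance in practice
    logger=logging.getLogger(__name__),
    taxonomy=None,     # a real Taxonomy instance in practice
    data=[("a short example sentence", 0)],
    tokenizer=tokenizer,
)
for record in ds.iter_records():
    print(record["input_ids"].shape, record["label"])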