# Source code for src.dataset.base

from __future__ import annotations

# typing imports
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ..data_models import Taxonomy

import abc

from torch.utils.data import Dataset


class TrainDataset(Dataset, abc.ABC):
    """Abstract training dataset class.

    Combines ``torch.utils.data.Dataset`` with ``abc.ABC`` and declares the
    members a concrete training dataset has to provide. Every member below is
    abstract: a subclass cannot be instantiated until all of them (including
    the property getters and setters) are overridden.
    """

    @abc.abstractmethod
    def __len__(self) -> int:
        """Return the size of the dataset."""

    @abc.abstractmethod
    def _get_label(self, idx: int) -> list[int]:
        """Return the label of the record at position ``idx`` as a list of ints.

        NOTE(review): the encoding of the ints (e.g. multi-hot vs. label ids)
        is not defined here -- confirm against the concrete implementations.
        """

    @abc.abstractmethod
    def _get_text(self, idx: int) -> str:
        """Return the text of the record at position ``idx``."""

    @abc.abstractmethod
    def __getitem__(self, idx) -> dict[str, str]:
        """Return the record at ``idx`` as a string-to-string mapping.

        NOTE(review): the keys of the mapping are chosen by the subclass.
        """

    @property
    @abc.abstractmethod
    def candid_labels(self) -> list[str]:
        """List of label strings (presumably the candidate labels -- confirm)."""

    @candid_labels.setter
    @abc.abstractmethod
    def candid_labels(self, value: list[str]) -> None:
        """Replace the list returned by ``candid_labels``."""

    @property
    @abc.abstractmethod
    def binarized_label_dictionary(self) -> dict[str, int]:
        """Read-only mapping from string keys to ints.

        NOTE(review): presumably maps each label to its binarized index or
        value -- verify against the concrete implementations.
        """

    @property
    @abc.abstractmethod
    def max_label_depth(self) -> int:
        """Maximum label depth, as an int.

        NOTE(review): presumably a depth within the taxonomy -- confirm.
        """

    @max_label_depth.setter
    @abc.abstractmethod
    def max_label_depth(self, value: int) -> None:
        """Set the value returned by ``max_label_depth``."""

    @property
    @abc.abstractmethod
    def taxonomy(self) -> Taxonomy:
        """The ``Taxonomy`` object associated with this dataset."""

    @taxonomy.setter
    @abc.abstractmethod
    def taxonomy(self, value: Taxonomy) -> None:
        """Set the ``Taxonomy`` object associated with this dataset."""

    @abc.abstractmethod
    def get_specific_record(self, idx: int, label_level: int) -> dict[str, str | list]:
        """Return the record at ``idx`` for the given ``label_level``.

        The result maps string keys to either a string or a list.

        NOTE(review): how ``label_level`` selects or filters labels is left to
        the subclass -- confirm against the concrete implementations.
        """