# Source code for src.dataset.multilabel.summary_statistic_dataset

from __future__ import annotations

from abc import ABC
# typing imports
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ...config import Config
    from ...data_models import Taxonomy
    from logging import Logger

from transformers import AutoTokenizer
import torch
from torch import device
import abc

from .base import MultilabelTrainingDataset


class SummaryStatisticDataset(MultilabelTrainingDataset, ABC):
    """
    [Adapted from base class] This implementation uses lvl2 labels

    PARENT DOCS:
    ---
    """
    # Concatenate the parent docs onto the class docstring so generated
    # documentation shows both contracts in one place.
    __doc__ += MultilabelTrainingDataset.__doc__

    def _get_label(self, idx: int, label_level: int) -> list[str]:
        """
        [Adapted implementation] Overridden from the base class; responds with
        the string label values resolved at ``label_level`` instead.

        :param idx: index of the record to read from ``self.dataset``
        :param label_level: taxonomy level at which to resolve each label
        :return: de-duplicated list of label strings found at that level
        """
        labels: list[str] = []
        selected_record = self.dataset[idx]
        # Records may store ``None`` under "labels" or omit the key entirely;
        # either way, treat that as "no labels".
        record_labels = selected_record.get("labels") or []
        for label in record_labels:
            label_in_tree = self.taxonomy.find(label)
            # Lazy %-formatting: the message is only built if DEBUG is enabled.
            self.logger.debug("Label (%s) found in tree %s", label, label_in_tree)
            # Keep the label only when the taxonomy resolves it at this level.
            if selected_label := label_in_tree.get(label_level, None):
                labels.append(selected_label.get("label"))
        # dict.fromkeys de-duplicates while preserving first-seen order,
        # unlike set(), whose iteration order is arbitrary between runs.
        return list(dict.fromkeys(labels))

    # Append the base-class docs to the override's docstring for Sphinx.
    _get_label.__doc__ += MultilabelTrainingDataset._get_label.__doc__

    def get_specific_record(self, idx: int, label_level: int) -> dict[str, str | list]:
        """
        This function implements the functionality to retrieve what label is
        available at what level.

        :param idx: the index to take as integer value
        :param label_level: the label level as integer value
        :return: mapping with a single ``"labels"`` key holding the resolved labels
        """
        labels = self._get_label(idx=idx, label_level=label_level)
        return dict(labels=labels)