SummerTime / dataset /st_dataset.py
aliabd
full demo working with old graido
7e3e85d
from abc import abstractmethod
from pprint import pformat
from time import sleep
from typing import List, Tuple, Optional, Union, Generator
from datasets import (
Dataset,
DatasetDict,
DatasetInfo,
concatenate_datasets,
load_dataset,
)
# Default values for retrying dataset download
DEFAULT_NUMBER_OF_RETRIES_ALLOWED = 5  # total download attempts made by _load_dataset_safe
DEFAULT_WAIT_SECONDS_BEFORE_RETRY = 5  # seconds slept after a failed attempt
# Default value for creating missing val/test splits: the fraction of the
# (remaining) train split held out for each split that has to be generated
TEST_OR_VAL_SPLIT_RATIO = 0.1
class SummInstance:
    """
    Basic instance for summarization tasks
    """

    def __init__(
        self, source: Union[List[str], str], summary: str, query: Optional[str] = None
    ):
        """
        Create a summarization instance
        :param source: either `List[str]` or `str`, depending on the dataset itself; string joining may be needed
            to fit into specific models. For example, the same document could be a single `str` or a
            `List[str]` of its sentences
        :param summary: a string summary that serves as ground truth
        :param query: Optional, applies when a string query is present
        """
        self.source = source
        self.summary = summary
        self.query = query

    def _as_dict(self) -> dict:
        """
        Build the dict shared by `__repr__` and `__str__`.
        The "query" key is only included when the query is truthy, so both
        `None` and an empty string are omitted.
        """
        instance_dict = {"source": self.source, "summary": self.summary}
        if self.query:
            instance_dict["query"] = self.query
        return instance_dict

    def __repr__(self):
        # Plain dict rendering, keys in insertion order
        return str(self._as_dict())

    def __str__(self):
        # Pretty-printed rendering (pformat sorts the keys)
        return pformat(self._as_dict(), indent=1)
class SummDataset:
    """
    Dataset class for summarization, which takes into account the following tasks:
        * Single document summarization
        * Multi-document/Dialogue summarization
        * Query-based summarization

    Subclasses supply the class attributes read here (`huggingface_dataset`,
    `builder_script_path`, `dataset_name`, `is_query_based`, `is_dialogue_based`,
    `is_multi_document`) and implement `_process_data`.

    NOTE(review): this class does not derive from `abc.ABC`, so the
    `@abstractmethod` on `_process_data` is not enforced at instantiation time.
    """

    def __init__(
        self, dataset_args: Optional[Tuple[str]] = None, splitseed: Optional[int] = None
    ):
        """Create dataset information from the huggingface Dataset class
        :param dataset_args: a tuple containing arguments to be passed on to the '_load_dataset_safe' method.
            Only required for datasets loaded from the Huggingface library.
            The arguments for each dataset are different and comprise a string or multiple strings
        :param splitseed: a number to instantiate the random generator used to generate val/test splits
            for the datasets without them
        """
        # Load dataset from huggingface, use default huggingface arguments
        if self.huggingface_dataset:
            dataset = self._load_dataset_safe(*dataset_args)
        # Load non-huggingface dataset, use custom dataset builder
        else:
            dataset = self._load_dataset_safe(path=self.builder_script_path)

        info_set = self._get_dataset_info(dataset)

        # Standardise a "val" or "dev" split to the name "validation" ("val" wins if both exist).
        # DatasetDict is a dict, so the split is renamed with pop(); dicts have no remove().
        for alias in ("val", "dev"):
            if alias in dataset:
                dataset["validation"] = dataset.pop(alias)
                break

        assert (
            "train" in dataset or "validation" in dataset or "test" in dataset
        ), "At least one of train/validation test needs to be not empty!"

        # Generate whichever of validation/test is missing. The helper only touches
        # missing splits (carving them out of train when train exists) and marks them
        # None when there is no train split to draw from.
        if "validation" not in dataset or "test" not in dataset:
            dataset = self._generate_missing_val_test_splits(dataset, splitseed)

        self.description = info_set.description
        self.citation = info_set.citation
        self.homepage = info_set.homepage

        # Extract the dataset entries; absent/None splits are stored as None so the
        # public properties fall back to an empty list instead of raising KeyError.
        self._train_set = self._process_split(dataset, "train")
        self._validation_set = self._process_split(dataset, "validation")
        self._test_set = self._process_split(dataset, "test")

    def _process_split(
        self, dataset_dict: DatasetDict, split: str
    ) -> Optional[Generator[SummInstance, None, None]]:
        """
        Run `_process_data` on one split, or return None when the split is absent or None.
        :param dataset_dict: DatasetDict (or dict) holding the splits
        :param split: split name, e.g. "train"
        """
        split_data = dataset_dict.get(split)
        if split_data is None:
            return None
        return self._process_data(split_data)

    def _split_or_empty(
        self, split_set, split_label: str
    ) -> Union[Generator[SummInstance, None, None], List]:
        """
        Return `split_set`, or print a notice and return an empty list when it is None.
        :param split_set: processed split stored on the instance (may be None)
        :param split_label: split name used in the notice ("train", "validation", "test")
        """
        if split_set is not None:
            return split_set
        print(
            f"{self.dataset_name} does not contain a {split_label} set, empty list returned"
        )
        return list()

    @property
    def train_set(self) -> Union[Generator[SummInstance, None, None], List]:
        return self._split_or_empty(self._train_set, "train")

    @property
    def validation_set(self) -> Union[Generator[SummInstance, None, None], List]:
        return self._split_or_empty(self._validation_set, "validation")

    @property
    def test_set(self) -> Union[Generator[SummInstance, None, None], List]:
        return self._split_or_empty(self._test_set, "test")

    def _load_dataset_safe(self, *args, **kwargs) -> Union[Dataset, DatasetDict]:
        """
        Wrap the huggingface 'load_dataset()' function in a retry loop for a more robust download.
        The original 'load_dataset()' occasionally fails when it cannot reach a server, especially
        after multiple requests, so the download is attempted up to
        DEFAULT_NUMBER_OF_RETRIES_ALLOWED times, sleeping DEFAULT_WAIT_SECONDS_BEFORE_RETRY
        seconds after each failed attempt.
        All arguments and keyword arguments are passed to 'load_dataset' unaltered.
        :rtype: Dataset or DatasetDict, whatever 'load_dataset' returns
        :param args: non-keyword arguments to be passed on to the 'load_dataset' function
        :param kwargs: keyword arguments to be passed on to the 'load_dataset' function
        :raises RuntimeError: when every attempt fails with ConnectionError (the last error is chained)
        """
        tries = DEFAULT_NUMBER_OF_RETRIES_ALLOWED
        wait_time = DEFAULT_WAIT_SECONDS_BEFORE_RETRY
        for attempt in range(tries):
            try:
                return load_dataset(*args, **kwargs)
            except ConnectionError as err:
                if attempt == tries - 1:  # attempt is zero-indexed; this was the last one
                    raise RuntimeError(
                        "Wait for a minute and attempt downloading the dataset again. "
                        "The server hosting the dataset occasionally times out."
                    ) from err
                sleep(wait_time)

    def _get_dataset_info(self, data_dict: DatasetDict) -> DatasetInfo:
        """
        Get the information set from the dataset.
        The information set contains: dataset name, description, version, citation and licence.
        The "train" split is used when present; otherwise the first available split is used,
        so datasets shipping only validation/test splits are supported.
        :param data_dict: DatasetDict
        :rtype: DatasetInfo
        :raises KeyError: when the DatasetDict contains no splits at all
        """
        if "train" not in data_dict and len(data_dict) > 0:
            return next(iter(data_dict.values())).info
        return data_dict["train"].info

    @abstractmethod
    def _process_data(self, dataset: Dataset) -> Generator[SummInstance, None, None]:
        """
        Abstract method to process the data contained within each dataset.
        Each dataset class processes its own information differently due to the diversity in domains.
        This method processes the data contained in the dataset
        and puts each data instance into a SummInstance object;
        the SummInstance has the following properties [source, summary, query[optional]]
        :param dataset: a train/validation/test dataset
        :rtype: a generator yielding SummInstance objects
        """
        return

    def _generate_missing_val_test_splits(
        self, dataset_dict: DatasetDict, seed: Optional[int]
    ) -> DatasetDict:
        """
        Create any missing validation/test split from the train split.
        A missing 'test' split takes TEST_OR_VAL_SPLIT_RATIO of 'train'; a missing 'validation'
        split then takes the same ratio of the remaining 'train' (roughly .81/.09/.10 when both
        are generated). Existing splits are left untouched. Without a 'train' split, missing
        splits are set to None. The splits are randomized unless a seed is provided.
        :param dataset_dict: Arrow DatasetDict, usually containing at least the train set
        :param seed: seed for the random generator to shuffle the dataset (None for random)
        :rtype: Arrow DatasetDict containing the three splits
        """
        # Nothing to split from: mark missing splits as None and return
        if "train" not in dataset_dict:
            if "validation" not in dataset_dict:
                dataset_dict["validation"] = None
            if "test" not in dataset_dict:
                dataset_dict["test"] = None
            return dataset_dict

        # Create a 'test' split from 'train' if no 'test' set is available
        if "test" not in dataset_dict:
            dataset_traintest_split = dataset_dict["train"].train_test_split(
                test_size=TEST_OR_VAL_SPLIT_RATIO, seed=seed
            )
            dataset_dict["train"] = dataset_traintest_split["train"]
            dataset_dict["test"] = dataset_traintest_split["test"]

        # Create a 'validation' split from the remaining 'train' set if no 'validation' set is available
        if "validation" not in dataset_dict:
            dataset_trainval_split = dataset_dict["train"].train_test_split(
                test_size=TEST_OR_VAL_SPLIT_RATIO, seed=seed
            )
            dataset_dict["train"] = dataset_trainval_split["train"]
            dataset_dict["validation"] = dataset_trainval_split["test"]

        return dataset_dict

    def _concatenate_dataset_dicts(
        self, dataset_dicts: List[DatasetDict]
    ) -> DatasetDict:
        """
        Concatenate dataset dicts with matching splits and columns into one.
        Split order does not matter; the output follows the first dict's split order.
        :param dataset_dicts: A non-empty list of DatasetDicts
        :rtype: DatasetDict containing the combined data
        :raises ValueError: if the list is empty or the dicts do not share the same splits
        """
        if not dataset_dicts:
            raise ValueError("At least one DatasetDict is required for concatenation")

        # Ensure all dataset dicts have the same splits (compared as sets, order-insensitive)
        splits = list(dataset_dicts[0].keys())
        expected_splits = set(splits)
        if any(set(dataset_dict.keys()) != expected_splits for dataset_dict in dataset_dicts):
            raise ValueError("Splits must match for all datasets")

        # Concatenate all datasets into one according to the splits
        return DatasetDict(
            {
                split: concatenate_datasets(
                    [dataset_dict[split] for dataset_dict in dataset_dicts]
                )
                for split in splits
            }
        )

    @classmethod
    def generate_basic_description(cls) -> str:
        """
        Automatically generate the basic description string based on the class attributes
        :rtype: string containing the description
        :param cls: class object
        """
        basic_description = (
            f": {cls.dataset_name} is a "
            f"{'query-based ' if cls.is_query_based else ''}"
            f"{'dialogue ' if cls.is_dialogue_based else ''}"
            f"{'multi-document' if cls.is_multi_document else 'single-document'} "
            f"summarization dataset."
        )
        return basic_description

    def show_description(self):
        """
        Print the description of the dataset.
        """
        print(self.dataset_name, ":\n", self.description)