import sys
from collections import Counter
from pathlib import Path
from typing import Callable, Dict, List, Tuple, Union

import numpy as np

# NOTE: `_get_formatter_by_name` resolves formatters with `getattr` on this module,
# so the wildcard imports below presumably provide the formatter functions and must stay.
from TTS.tts.datasets.dataset import *
from TTS.tts.datasets.formatters import *


def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
    """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.

    `items` is shuffled (with a fixed seed) and, in the multi-speaker case, shrunk in place.
    The global NumPy random state is re-seeded with 0 as a side effect.

    Args:
        items (List): A list of samples. Each sample must be indexable by the key
            `"speaker_name"`.

        eval_split_max_size (int):
            Number maximum of samples to be used for evaluation in proportion split.
            Defaults to None (Disabled).

        eval_split_size (float):
            If between 0.0 and 1.0 represents the proportion of the dataset to include in the
            evaluation set. If > 1, represents the absolute number of evaluation samples.
            Defaults to 0.01 (1%).

    Returns:
        Tuple[List, List]: `(eval_items, train_items)` - note that the evaluation split comes first.

    Raises:
        AssertionError: If the computed evaluation split size is not positive.
        ValueError: If, in multi-speaker mode, the requested evaluation size cannot be reached
            while keeping at least one training sample for every speaker.
    """
    speaker_counter = Counter(item["speaker_name"] for item in items)
    is_multi_speaker = len(speaker_counter) > 1

    if eval_split_size > 1:
        # absolute number of evaluation samples
        eval_split_size = int(eval_split_size)
    else:
        # proportion of the dataset, optionally capped by `eval_split_max_size`
        proportional_size = int(len(items) * eval_split_size)
        if eval_split_max_size:
            eval_split_size = min(eval_split_max_size, proportional_size)
        else:
            eval_split_size = proportional_size

    assert (
        eval_split_size > 0
    ), " [!] You do not have enough samples for the evaluation set. You can work around this setting the 'eval_split_size' parameter to a minimum of {}".format(
        1 / len(items)
    )

    if is_multi_speaker:
        # Every speaker keeps at least one training sample, so only this many samples can be
        # moved to the evaluation set. Without this check the sampling loop below never ends.
        max_eval_size = len(items) - len(speaker_counter)
        if eval_split_size > max_eval_size:
            raise ValueError(
                f" [!] Cannot select {eval_split_size} evaluation samples while keeping at least one "
                f"training sample per speaker. The maximum for this dataset is {max_eval_size}."
            )

    np.random.seed(0)
    np.random.shuffle(items)
    if is_multi_speaker:
        items_eval = []
        while len(items_eval) < eval_split_size:
            item_idx = np.random.randint(0, len(items))
            speaker_to_be_removed = items[item_idx]["speaker_name"]
            # only move the sample if its speaker still has another sample left for training
            if speaker_counter[speaker_to_be_removed] > 1:
                items_eval.append(items[item_idx])
                speaker_counter[speaker_to_be_removed] -= 1
                del items[item_idx]
        return items_eval, items
    return items[:eval_split_size], items[eval_split_size:]


def load_tts_samples(
    datasets: Union[List[Dict], Dict],
    eval_split=True,
    formatter: Callable = None,
    eval_split_max_size=None,
    eval_split_size=0.01,
) -> Tuple[List[List], List[List]]:
    """Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided.
    If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based
    on the dataset name.

    Args:
        datasets (List[Dict], Dict): A list of datasets or a single dataset dictionary. If multiple datasets are
            in the list, they are all merged. Each dataset is read through the keys `name`, `path`,
            `meta_file_train`, `meta_file_val`, `ignored_speakers`, `language` and `meta_file_attn_mask`.

        eval_split (bool, optional): If true, create a evaluation split. If no eval split is provided explicitly
            (`meta_file_val` is empty), generate an eval split automatically from the training samples.
            Defaults to True.

        formatter (Callable, optional): The preprocessing function to be applied to create the list of samples.
            It is called as `formatter(root_path, meta_file, ignored_speakers=...)` and must return a list of
            samples. Each sample is treated as a mapping; the keys `"speaker_name"` and `"audio_file"` are read
            by this module. See the available formatters in `TTS.tts.dataset.formatter` as example.
            Defaults to None.

        eval_split_max_size (int):
            Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled).

        eval_split_size (float):
            If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
            If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).

    Returns:
        Tuple[List[List], List[List]]: training and evaluation splits of the dataset. The evaluation split is
            `None` when `eval_split` is False.
    """
    meta_data_train_all = []
    meta_data_eval_all = [] if eval_split else None
    if not isinstance(datasets, list):
        datasets = [datasets]
    for dataset in datasets:
        name = dataset["name"]
        root_path = dataset["path"]
        meta_file_train = dataset["meta_file_train"]
        meta_file_val = dataset["meta_file_val"]
        ignored_speakers = dataset["ignored_speakers"]
        language = dataset["language"]

        # setup the right data processor
        if formatter is None:
            formatter = _get_formatter_by_name(name)

        # load train set
        meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
        meta_data_train = [{**item, "language": language} for item in meta_data_train]
        print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")

        # load evaluation split if set
        meta_data_eval = None
        if eval_split:
            if meta_file_val:
                meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
                meta_data_eval = [{**item, "language": language} for item in meta_data_eval]
            else:
                meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
            meta_data_eval_all += meta_data_eval
        meta_data_train_all += meta_data_train

        # load attention masks for the duration predictor training
        attn_mask_file = dataset["meta_file_attn_mask"]
        if attn_mask_file:
            meta_data = dict(load_attention_mask_meta_data(attn_mask_file))
            # Only annotate the samples of the current dataset: samples of previously merged
            # datasets are not listed in this dataset's attention mask file. The sample dicts are
            # shared with the `*_all` lists, so the merged results see the update.
            for sample in meta_data_train:
                sample.update({"alignment_file": meta_data[sample["audio_file"]].strip()})
            if meta_data_eval:
                for sample in meta_data_eval:
                    sample.update({"alignment_file": meta_data[sample["audio_file"]].strip()})

        # set none for the next iter
        # NOTE(review): this also drops a caller-supplied `formatter` after the first dataset, so
        # the following datasets fall back to the name-based lookup - confirm this is intended.
        formatter = None
    return meta_data_train_all, meta_data_eval_all


def load_attention_mask_meta_data(metafile_path):
    """Load meta data file created by compute_attention_masks.py

    Every line of the file must contain exactly two `|`-separated fields.

    Args:
        metafile_path (str): Path to the meta data file (read as UTF-8).

    Returns:
        List[List[str]]: `[wav_file, attn_file]` pairs in file order. `attn_file` keeps the
            trailing line break of its line, if any.
    """
    meta_data = []
    with open(metafile_path, "r", encoding="utf-8") as f:
        for line in f:
            wav_file, attn_file = line.split("|")
            meta_data.append([wav_file, attn_file])
    return meta_data


def _get_formatter_by_name(name):
    """Returns the respective preprocessing function.

    The function is looked up in this module's namespace under the lower-cased `name`.
    """
    thismodule = sys.modules[__name__]
    return getattr(thismodule, name.lower())


def find_unique_chars(data_samples, verbose=True):
    """Collect the set of characters used in the sample texts.

    Args:
        data_samples (List): Samples whose first element (`item[0]`) is the text.
            NOTE(review): other functions in this file index samples by string keys -
            confirm which sample layout callers pass here.
        verbose (bool): If True, print the character statistics. Defaults to True.

    Returns:
        set: All characters found in the texts after lower-casing each of them.
    """
    texts = "".join(item[0] for item in data_samples)
    chars = set(texts)
    chars_force_lower = {c.lower() for c in chars}

    if verbose:
        lower_chars = [c for c in chars if c.islower()]
        print(f" > Number of unique characters: {len(chars)}")
        print(f" > Unique characters: {''.join(sorted(chars))}")
        print(f" > Unique lower characters: {''.join(sorted(lower_chars))}")
        print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}")
    return chars_force_lower