import collections
import os
import random
from typing import Dict, List, Union

import numpy as np
import torch
import tqdm
from torch.utils.data import Dataset

from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.utils.audio import AudioProcessor

# to prevent too many open files error as suggested here
# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
torch.multiprocessing.set_sharing_strategy("file_system")


def _parse_sample(item):
    """Unpack a dataset sample into its five canonical fields.

    Args:
        item: Either a mapping with at least the ``"text"`` and ``"audio_file"`` keys, or a
            sequence of 3 to 5 elements ordered as
            ``(text, wav_file, speaker_name[, language_name[, attn_file]])``.

    Returns:
        Tuple ``(text, wav_file, speaker_name, language_name, attn_file)``. Fields that the
        sample does not provide are returned as ``None``.

    Raises:
        ValueError: If a sequence sample does not have 3, 4 or 5 elements.
    """
    if isinstance(item, collections.abc.Mapping):
        # Tuple-unpacking a dict yields its keys, not its values, so read the fields by name.
        return (
            item["text"],
            item["audio_file"],
            item.get("speaker_name"),
            item.get("language"),
            item.get("alignment_file"),
        )

    language_name = None
    attn_file = None
    if len(item) == 5:
        text, wav_file, speaker_name, language_name, attn_file = item
    elif len(item) == 4:
        text, wav_file, speaker_name, language_name = item
    elif len(item) == 3:
        text, wav_file, speaker_name = item
    else:
        raise ValueError(" [!] Dataset cannot parse the sample.")
    return text, wav_file, speaker_name, language_name, attn_file


def noise_augment_audio(wav):
    """Return `wav` plus uniform random noise in ``[0, 1/32768)`` of the same shape."""
    return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)


class TTSDataset(Dataset):
    def __init__(
        self,
        outputs_per_step: int = 1,
        compute_linear_spec: bool = False,
        ap: AudioProcessor = None,
        samples: List[Dict] = None,
        tokenizer: "TTSTokenizer" = None,
        compute_f0: bool = False,
        f0_cache_path: str = None,
        return_wav: bool = False,
        batch_group_size: int = 0,
        min_text_len: int = 0,
        max_text_len: int = float("inf"),
        min_audio_len: int = 0,
        max_audio_len: int = float("inf"),
        phoneme_cache_path: str = None,
        precompute_num_workers: int = 0,
        speaker_id_mapping: Dict = None,
        d_vector_mapping: Dict = None,
        language_id_mapping: Dict = None,
        use_noise_augment: bool = False,
        start_by_longest: bool = False,
        verbose: bool = False,
    ):
        """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.

        If you need something different, you can subclass and override.

        Args:
            outputs_per_step (int): Number of time frames predicted per step.

            compute_linear_spec (bool): compute linear spectrogram if True.

            ap (TTS.tts.utils.AudioProcessor): Audio processor object.

            samples (list): List of dataset samples.

            tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. It must be provided, since its
                `use_phonemes` flag is read at initialization. Defaults to None.

            compute_f0 (bool): compute f0 if True. Defaults to False.

            f0_cache_path (str): Path to store f0 cache. Defaults to None.

            return_wav (bool): Return the waveform of the sample. Defaults to False.

            batch_group_size (int): Range of batch randomization after sorting
                sequences by length. It shuffles each batch with bucketing to gather similar length sequences in a
                batch. Set 0 to disable. Defaults to 0.

            min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored.
                Defaults to 0.

            max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored.
                Defaults to float("inf").

            min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored.
                Defaults to 0.

            max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored.
                The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to
                this value if you encounter an OOM error in training. Defaults to float("inf").

            phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
                separate file. Defaults to None.

            precompute_num_workers (int): Number of workers to precompute features. Defaults to 0.

            speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
                embedding layer. Defaults to None.

            d_vector_mapping (dict): Mapping of wav files to computed d-vectors. Defaults to None.

            language_id_mapping (dict): Mapping of language names to IDs. Defaults to None.

            use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.

            start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.

            verbose (bool): Print diagnostic information. Defaults to false.
        """
        super().__init__()
        self.batch_group_size = batch_group_size
        self._samples = samples
        self.outputs_per_step = outputs_per_step
        self.compute_linear_spec = compute_linear_spec
        self.return_wav = return_wav
        self.compute_f0 = compute_f0
        self.f0_cache_path = f0_cache_path
        self.min_audio_len = min_audio_len
        self.max_audio_len = max_audio_len
        self.min_text_len = min_text_len
        self.max_text_len = max_text_len
        self.ap = ap
        self.phoneme_cache_path = phoneme_cache_path
        self.speaker_id_mapping = speaker_id_mapping
        self.d_vector_mapping = d_vector_mapping
        self.language_id_mapping = language_id_mapping
        self.use_noise_augment = use_noise_augment
        self.start_by_longest = start_by_longest

        self.verbose = verbose
        # index of the replacement sample used by `load_data` when the requested one is out of range
        self.rescue_item_idx = 1
        self.pitch_computed = False
        self.tokenizer = tokenizer

        if self.tokenizer.use_phonemes:
            self.phoneme_dataset = PhonemeDataset(
                self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers
            )

        if compute_f0:
            self.f0_dataset = F0Dataset(
                self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
            )

        if self.verbose:
            self.print_logs()

    @property
    def lengths(self):
        """Per-sample audio length estimated from the file size (half the byte count)."""
        lens = []
        for item in self.samples:
            _, wav_file, *_ = _parse_sample(item)
            audio_len = os.path.getsize(wav_file) / 16 * 8  # assuming 16bit audio
            lens.append(audio_len)
        return lens

    @property
    def samples(self):
        """List of samples currently served by the dataset."""
        return self._samples

    @samples.setter
    def samples(self, new_samples):
        # keep the helper datasets in sync so that indices refer to the same samples everywhere
        self._samples = new_samples
        if hasattr(self, "f0_dataset"):
            self.f0_dataset.samples = new_samples
        if hasattr(self, "phoneme_dataset"):
            self.phoneme_dataset.samples = new_samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.load_data(idx)

    def print_logs(self, level: int = 0) -> None:
        """Print the tokenizer logs and the number of samples, indented by `level` tabs."""
        indent = "\t" * level
        print("\n")
        print(f"{indent}> DataLoader initialization")
        print(f"{indent}| > Tokenizer:")
        self.tokenizer.print_logs(level + 1)
        print(f"{indent}| > Number of instances : {len(self.samples)}")

    def load_wav(self, filename):
        """Load `filename` with the audio processor and assert that it is not empty."""
        waveform = self.ap.load_wav(filename)
        assert waveform.size > 0
        return waveform

    def get_phonemes(self, idx, text):
        """Fetch the phoneme entry of sample `idx` and check that it belongs to `text`."""
        out_dict = self.phoneme_dataset[idx]
        assert text == out_dict["text"], f"{text} != {out_dict['text']}"
        assert len(out_dict["token_ids"]) > 0
        return out_dict

    def get_f0(self, idx):
        """Fetch the f0 entry of sample `idx` and check that it belongs to the same audio file."""
        out_dict = self.f0_dataset[idx]
        item = self.samples[idx]
        assert item["audio_file"] == out_dict["audio_file"]
        return out_dict

    @staticmethod
    def get_attn_mask(attn_file):
        """Load a pre-computed attention map saved with numpy."""
        return np.load(attn_file)

    def get_token_ids(self, idx, text):
        """Return the token IDs of `text` as an int32 array.

        Uses the phoneme dataset when the tokenizer works on phonemes, otherwise tokenizes `text` directly.
        """
        if self.tokenizer.use_phonemes:
            token_ids = self.get_phonemes(idx, text)["token_ids"]
        else:
            token_ids = self.tokenizer.text_to_ids(text)
        return np.array(token_ids, dtype=np.int32)

    def load_data(self, idx):
        """Load the sample at `idx` and return it as a dict.

        If the tokenized text is longer than `max_text_len` or the waveform is shorter than `min_audio_len`,
        another sample (selected by `rescue_item_idx`) is loaded and returned instead.

        Returns:
            Dict with keys `raw_text`, `token_ids`, `wav`, `pitch`, `attn`, `item_idx`, `speaker_name`,
            `language_name` and `wav_file_name`.
        """
        item = self.samples[idx]

        raw_text = item["text"]

        wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)

        # apply noise for augmentation
        if self.use_noise_augment:
            wav = noise_augment_audio(wav)

        # get token ids
        token_ids = self.get_token_ids(idx, item["text"])

        # get pre-computed attention maps
        attn = None
        if "alignment_file" in item:
            attn = self.get_attn_mask(item["alignment_file"])

        # after phonemization the text length may change
        # this is a shameful 🤭 hack to prevent longer phonemes
        # TODO: find a better fix
        if len(token_ids) > self.max_text_len or len(wav) < self.min_audio_len:
            self.rescue_item_idx += 1
            return self.load_data(self.rescue_item_idx)

        # get f0 values
        f0 = None
        if self.compute_f0:
            f0 = self.get_f0(idx)["f0"]

        sample = {
            "raw_text": raw_text,
            "token_ids": token_ids,
            "wav": wav,
            "pitch": f0,
            "attn": attn,
            "item_idx": item["audio_file"],
            "speaker_name": item["speaker_name"],
            "language_name": item["language"],
            "wav_file_name": os.path.basename(item["audio_file"]),
        }
        return sample

    @staticmethod
    def _compute_lengths(samples):
        """Annotate every sample in place with `audio_length` and `text_length` and return them as a new list."""
        new_samples = []
        for item in samples:
            audio_length = os.path.getsize(item["audio_file"]) / 16 * 8  # assuming 16bit audio
            text_length = len(item["text"])
            item["audio_length"] = audio_length
            item["text_length"] = text_length
            new_samples += [item]
        return new_samples

    @staticmethod
    def filter_by_length(lengths: List[int], min_len: int, max_len: int):
        """Split the indices of `lengths` into ignored and kept ones.

        Returns:
            Tuple `(ignore_idx, keep_idx)`, both ordered by ascending length. An index is kept when
            `min_len <= length <= max_len`.
        """
        idxs = np.argsort(lengths)  # ascending order
        ignore_idx = []
        keep_idx = []
        for idx in idxs:
            length = lengths[idx]
            if length < min_len or length > max_len:
                ignore_idx.append(idx)
            else:
                keep_idx.append(idx)
        return ignore_idx, keep_idx

    @staticmethod
    def sort_by_length(samples: List[List]):
        """Return the indices that sort `samples` by `audio_length` in ascending order."""
        audio_lengths = [s["audio_length"] for s in samples]
        idxs = np.argsort(audio_lengths)  # ascending order
        return idxs

    @staticmethod
    def create_buckets(samples, batch_group_size: int):
        """Shuffle `samples` in place inside consecutive groups of `batch_group_size` items.

        Trailing items that do not fill a whole group keep their order.
        """
        assert batch_group_size > 0
        for i in range(len(samples) // batch_group_size):
            offset = i * batch_group_size
            end_offset = offset + batch_group_size
            temp_items = samples[offset:end_offset]
            random.shuffle(temp_items)
            samples[offset:end_offset] = temp_items
        return samples

    @staticmethod
    def _select_samples_by_idx(idxs, samples):
        """Return the items of `samples` at positions `idxs`, in that order."""
        return [samples[idx] for idx in idxs]

    def preprocess_samples(self):
        r"""Sort `items` based on text length or audio length in ascending order. Filter out samples outside the
        length range.

        Raises:
            RuntimeError: If no sample is left after filtering.
        """
        samples = self._compute_lengths(self.samples)

        # sort items based on the sequence length in ascending order
        text_lengths = [i["text_length"] for i in samples]
        audio_lengths = [i["audio_length"] for i in samples]
        text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len)
        audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len)
        keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
        ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))

        samples = self._select_samples_by_idx(keep_idx, samples)

        # check before sorting: the `start_by_longest` swap below indexes the sorted ids and would
        # otherwise fail with an IndexError on an empty list instead of this explicit error
        if len(samples) == 0:
            raise RuntimeError(" [!] No samples left")

        sorted_idxs = self.sort_by_length(samples)

        if self.start_by_longest:
            # swap the shortest and the longest sample so that the longest one comes first
            longest_idxs = sorted_idxs[-1]
            sorted_idxs[-1] = sorted_idxs[0]
            sorted_idxs[0] = longest_idxs

        samples = self._select_samples_by_idx(sorted_idxs, samples)

        # shuffle batch groups
        # create batches with similar length items
        # the larger the `batch_group_size`, the higher the length variety in a batch.
        if self.batch_group_size > 0:
            samples = self.create_buckets(samples, self.batch_group_size)

        # update items to the new sorted items
        audio_lengths = [s["audio_length"] for s in samples]
        text_lengths = [s["text_length"] for s in samples]
        self.samples = samples

        if self.verbose:
            print(" | > Preprocessing samples")
            print(" | > Max text length: {}".format(np.max(text_lengths)))
            print(" | > Min text length: {}".format(np.min(text_lengths)))
            print(" | > Avg text length: {}".format(np.mean(text_lengths)))
            print(" | ")
            print(" | > Max audio length: {}".format(np.max(audio_lengths)))
            print(" | > Min audio length: {}".format(np.min(audio_lengths)))
            print(" | > Avg audio length: {}".format(np.mean(audio_lengths)))
            print(f" | > Num. instances discarded samples: {len(ignore_idx)}")
            print(" | > Batch group size: {}.".format(self.batch_group_size))

    @staticmethod
    def _sort_batch(batch, text_lengths):
        """Sort the batch by the input text length for RNN efficiency.

        Args:
            batch (Dict): Batch returned by `__getitem__`.
            text_lengths (List[int]): Lengths of the input character sequences.

        Returns:
            Tuple `(batch, text_lengths, ids_sorted_decreasing)` where `batch` and `text_lengths` are in
            descending length order and `ids_sorted_decreasing` holds the original positions.
        """
        text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
        batch = [batch[idx] for idx in ids_sorted_decreasing]
        return batch, text_lengths, ids_sorted_decreasing

    def collate_fn(self, batch):
        r"""
        Perform preprocessing and create a final data batch:
        1. Sort batch instances by text-length
        2. Convert Audio signal to features.
        3. PAD sequences wrt r.
        4. Load to Torch.

        Args:
            batch (list): Samples returned by `__getitem__`; each must be a mapping.

        Returns:
            Dict with keys `token_id`, `token_id_lengths`, `speaker_names`, `linear`, `mel`, `mel_lengths`,
            `stop_targets`, `item_idxs`, `d_vectors`, `speaker_ids`, `attns`, `waveform`, `raw_text`, `pitch`
            and `language_ids`. Optional entries are None when the matching feature is disabled.

        Raises:
            TypeError: If the batch items are not mappings.
        """
        if not isinstance(batch[0], collections.abc.Mapping):
            raise TypeError(f"batch must contain tensors, numbers, dicts or lists; found {type(batch[0])}")

        # Puts each data field into a tensor with outer dimension batch size
        token_ids_lengths = np.array([len(d["token_ids"]) for d in batch])

        # sort items with text input length for RNN efficiency
        batch, token_ids_lengths, _ = self._sort_batch(batch, token_ids_lengths)

        # convert list of dicts to dict of lists
        batch = {k: [dic[k] for dic in batch] for k in batch[0]}

        # get language ids from language names
        if self.language_id_mapping is not None:
            language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]]
        else:
            language_ids = None
        # get pre-computed d-vectors
        if self.d_vector_mapping is not None:
            wav_files_names = list(batch["wav_file_name"])
            d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
        else:
            d_vectors = None

        # get numerical speaker ids from speaker names
        if self.speaker_id_mapping:
            speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
        else:
            speaker_ids = None
        # compute features
        mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]

        mel_lengths = [m.shape[1] for m in mel]

        # lengths adjusted by the reduction factor
        mel_lengths_adjusted = [
            m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step))
            if m.shape[1] % self.outputs_per_step
            else m.shape[1]
            for m in mel
        ]

        # compute 'stop token' targets
        stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths]

        # PAD stop targets
        stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)

        # PAD sequences with longest instance in the batch
        token_ids = prepare_data(batch["token_ids"]).astype(np.int32)

        # PAD features with longest instance
        mel = prepare_tensor(mel, self.outputs_per_step)

        # B x D x T --> B x T x D
        mel = mel.transpose(0, 2, 1)

        # convert things to pytorch
        token_ids_lengths = torch.LongTensor(token_ids_lengths)
        token_ids = torch.LongTensor(token_ids)
        mel = torch.FloatTensor(mel).contiguous()
        mel_lengths = torch.LongTensor(mel_lengths)
        stop_targets = torch.FloatTensor(stop_targets)

        # speaker vectors
        if d_vectors is not None:
            d_vectors = torch.FloatTensor(d_vectors)

        if speaker_ids is not None:
            speaker_ids = torch.LongTensor(speaker_ids)

        if language_ids is not None:
            language_ids = torch.LongTensor(language_ids)

        # compute linear spectrogram
        linear = None
        if self.compute_linear_spec:
            linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
            linear = prepare_tensor(linear, self.outputs_per_step)
            linear = linear.transpose(0, 2, 1)
            assert mel.shape[1] == linear.shape[1]
            linear = torch.FloatTensor(linear).contiguous()

        # format waveforms
        wav_padded = None
        if self.return_wav:
            max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
            wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
            for i, w in enumerate(batch["wav"]):
                mel_length = mel_lengths_adjusted[i]
                # extend with edge values, then cut to the frame-aligned length
                w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
                w = w[: mel_length * self.ap.hop_length]
                wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
            wav_padded.transpose_(1, 2)

        # format F0
        if self.compute_f0:
            pitch = prepare_data(batch["pitch"])
            assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
            pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 xT
        else:
            pitch = None

        # format attention masks
        attns = None
        if batch["attn"][0] is not None:
            # `batch` is already in sorted order; indexing it again with the sort permutation would
            # pair attention maps with the wrong mel / token rows
            attns = [attn.T for attn in batch["attn"]]
            for idx, attn in enumerate(attns):
                pad2 = mel.shape[1] - attn.shape[1]
                pad1 = token_ids.shape[1] - attn.shape[0]
                assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
                attn = np.pad(attn, [[0, pad1], [0, pad2]])
                attns[idx] = attn
            attns = prepare_tensor(attns, self.outputs_per_step)
            attns = torch.FloatTensor(attns).unsqueeze(1)

        return {
            "token_id": token_ids,
            "token_id_lengths": token_ids_lengths,
            "speaker_names": batch["speaker_name"],
            "linear": linear,
            "mel": mel,
            "mel_lengths": mel_lengths,
            "stop_targets": stop_targets,
            "item_idxs": batch["item_idx"],
            "d_vectors": d_vectors,
            "speaker_ids": speaker_ids,
            "attns": attns,
            "waveform": wav_padded,
            "raw_text": batch["raw_text"],
            "pitch": pitch,
            "language_ids": language_ids,
        }


class PhonemeDataset(Dataset):
    """Phoneme Dataset for converting input text to phonemes and then token IDs

    At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data
    loading latency. If `cache_path` is already present, it skips the pre-computation.

    Args:
        samples (Union[List[List], List[Dict]]):
            List of samples. Each sample is a list or a dict.

        tokenizer (TTSTokenizer):
            Tokenizer to convert input text to phonemes.

        cache_path (str):
            Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation.
            With None, token IDs are computed on every access and nothing is written to disk.

        precompute_num_workers (int):
            Number of workers used for pre-computing the phonemes. Defaults to 0.
    """

    def __init__(
        self,
        samples: Union[List[Dict], List[List]],
        tokenizer: "TTSTokenizer",
        cache_path: str,
        precompute_num_workers=0,
    ):
        self.samples = samples
        self.tokenizer = tokenizer
        self.cache_path = cache_path
        if cache_path is not None and not os.path.exists(cache_path):
            os.makedirs(cache_path)
            self.precompute(precompute_num_workers)

    def __getitem__(self, index):
        item = self.samples[index]
        ids = self.compute_or_load(item["audio_file"], item["text"])
        ph_hat = self.tokenizer.ids_to_text(ids)
        return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}

    def __len__(self):
        return len(self.samples)

    def compute_or_load(self, wav_file, text):
        """Compute phonemes for the given text.

        If the phonemes are already cached, load them from cache. The cache file is named after the base name of
        `wav_file`. Without a cache path the IDs are computed and returned without being stored.
        """
        if self.cache_path is None:
            # no cache directory to read from or write to
            return self.tokenizer.text_to_ids(text)
        file_name = os.path.splitext(os.path.basename(wav_file))[0]
        file_ext = "_phoneme.npy"
        cache_path = os.path.join(self.cache_path, file_name + file_ext)
        try:
            ids = np.load(cache_path)
        except FileNotFoundError:
            ids = self.tokenizer.text_to_ids(text)
            np.save(cache_path, ids)
        return ids

    def get_pad_id(self):
        """Get pad token ID for sequence padding"""
        return self.tokenizer.pad_id

    def precompute(self, num_workers=1):
        """Precompute phonemes for all samples.

        We use pytorch dataloader because we are lazy.
        """
        print("[*] Pre-computing phonemes...")
        with tqdm.tqdm(total=len(self)) as pbar:
            batch_size = num_workers if num_workers > 0 else 1
            dataloader = torch.utils.data.DataLoader(
                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
            )
            for batch in dataloader:
                # the last batch may hold fewer than `batch_size` items
                pbar.update(len(batch["text"]))

    def collate_fn(self, batch):
        """Pad the token IDs of `batch` with the pad ID into a single LongTensor."""
        ids = [item["token_ids"] for item in batch]
        ids_lens = [item["token_ids_len"] for item in batch]
        texts = [item["text"] for item in batch]
        texts_hat = [item["ph_hat"] for item in batch]
        ids_lens_max = max(ids_lens)
        ids_torch = torch.LongTensor(len(ids), ids_lens_max).fill_(self.get_pad_id())
        for i, ids_len in enumerate(ids_lens):
            ids_torch[i, :ids_len] = torch.LongTensor(ids[i])
        return {"text": texts, "ph_hat": texts_hat, "token_ids": ids_torch}

    def print_logs(self, level: int = 0) -> None:
        """Print the tokenizer logs and the number of samples, indented by `level` tabs."""
        indent = "\t" * level
        print("\n")
        print(f"{indent}> PhonemeDataset ")
        print(f"{indent}| > Tokenizer:")
        self.tokenizer.print_logs(level + 1)
        print(f"{indent}| > Number of instances : {len(self.samples)}")


class F0Dataset:
    """F0 Dataset for computing F0 from wav files in CPU

    Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. It
    also computes the mean and std of F0 values if `normalize_f0` is True.

    Args:
        samples (Union[List[List], List[Dict]]):
            List of samples. Each sample is a list or a dict.

        ap (AudioProcessor):
            AudioProcessor to compute F0 from wav files.

        verbose (bool):
            Stored on the instance. Defaults to False.

        cache_path (str):
            Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation.
            With None, F0 values are computed on every access and nothing is written to disk. Defaults to None.

        precompute_num_workers (int):
            Number of workers used for pre-computing the F0 values. Defaults to 0.

        normalize_f0 (bool):
            Whether to normalize F0 values by mean and std. The statistics are read from `pitch_stats.npy` under
            `cache_path`, so normalization needs a cache path. Defaults to True.
    """

    def __init__(
        self,
        samples: Union[List[List], List[Dict]],
        ap: "AudioProcessor",
        verbose=False,
        cache_path: str = None,
        precompute_num_workers=0,
        normalize_f0=True,
    ):
        self.samples = samples
        self.ap = ap
        self.verbose = verbose
        self.cache_path = cache_path
        self.normalize_f0 = normalize_f0
        self.pad_id = 0.0
        self.mean = None
        self.std = None
        if cache_path is not None and not os.path.exists(cache_path):
            os.makedirs(cache_path)
            self.precompute(precompute_num_workers)
        if normalize_f0:
            self.load_stats(cache_path)

    def __getitem__(self, idx):
        item = self.samples[idx]
        f0 = self.compute_or_load(item["audio_file"])
        if self.normalize_f0:
            assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
            f0 = self.normalize(f0)
        return {"audio_file": item["audio_file"], "f0": f0}

    def __len__(self):
        return len(self.samples)

    def precompute(self, num_workers=0):
        """Compute and cache the F0 of every sample; also store the pitch statistics if normalization is enabled."""
        print("[*] Pre-computing F0s...")
        # we do not normalize at preprocessing: the statistics must come from the raw values
        normalize_f0 = self.normalize_f0
        self.normalize_f0 = False
        computed_data = []
        try:
            with tqdm.tqdm(total=len(self)) as pbar:
                batch_size = num_workers if num_workers > 0 else 1
                dataloader = torch.utils.data.DataLoader(
                    batch_size=batch_size,
                    dataset=self,
                    shuffle=False,
                    num_workers=num_workers,
                    collate_fn=self.collate_fn,
                )
                for batch in dataloader:
                    # one padded row per sample; the zero padding is ignored by `compute_pitch_stats`
                    computed_data.extend(batch["f0"])
                    # the last batch may hold fewer than `batch_size` items
                    pbar.update(len(batch["audio_file"]))
        finally:
            # restore the flag even if the pre-computation fails midway
            self.normalize_f0 = normalize_f0

        if self.normalize_f0:
            pitch_mean, pitch_std = self.compute_pitch_stats(computed_data)
            pitch_stats = {"mean": pitch_mean, "std": pitch_std}
            np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)

    def get_pad_id(self):
        """Get the value used to pad F0 sequences."""
        return self.pad_id

    @staticmethod
    def create_pitch_file_path(wav_file, cache_path):
        """Return the cache file path of `wav_file`: `<cache_path>/<wav base name>_pitch.npy`."""
        file_name = os.path.splitext(os.path.basename(wav_file))[0]
        pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
        return pitch_file

    @staticmethod
    def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
        """Compute the F0 of `wav_file` with `ap` and save it to `pitch_file` when one is given."""
        wav = ap.load_wav(wav_file)
        pitch = ap.compute_f0(wav)
        if pitch_file:
            np.save(pitch_file, pitch)
        return pitch

    @staticmethod
    def compute_pitch_stats(pitch_vecs):
        """Return the mean and std over all non-zero values of `pitch_vecs`."""
        nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs])
        mean, std = np.mean(nonzeros), np.std(nonzeros)
        return mean, std

    def load_stats(self, cache_path):
        """Load the pitch mean and std from `pitch_stats.npy` under `cache_path`."""
        stats_path = os.path.join(cache_path, "pitch_stats.npy")
        # NOTE(review): `allow_pickle=True` unpickles the file content; only load trusted caches.
        stats = np.load(stats_path, allow_pickle=True).item()
        self.mean = stats["mean"].astype(np.float32)
        self.std = stats["std"].astype(np.float32)

    def normalize(self, pitch):
        """Return `(pitch - mean) / std` with the originally zero entries kept at zero."""
        zero_idxs = np.where(pitch == 0.0)[0]
        pitch = pitch - self.mean
        pitch = pitch / self.std
        pitch[zero_idxs] = 0.0
        return pitch

    def denormalize(self, pitch):
        """Return `pitch * std + mean` with the originally zero entries kept at zero.

        The input is left untouched, matching `normalize`.
        """
        zero_idxs = np.where(pitch == 0.0)[0]
        pitch = pitch * self.std
        pitch = pitch + self.mean
        pitch[zero_idxs] = 0.0
        return pitch

    def compute_or_load(self, wav_file):
        """
        compute pitch and return a numpy array of pitch values

        The values are loaded from the cache when available, otherwise computed and cached. Without a cache path
        they are computed and returned without being stored.
        """
        if self.cache_path is None:
            # no cache directory to read from or write to
            pitch = self._compute_and_save_pitch(self.ap, wav_file)
            return pitch.astype(np.float32)
        pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
        if not os.path.exists(pitch_file):
            pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
        else:
            pitch = np.load(pitch_file)
        return pitch.astype(np.float32)

    def collate_fn(self, batch):
        """Pad the F0 sequences of `batch` with the pad value into a single FloatTensor."""
        audio_file = [item["audio_file"] for item in batch]
        f0s = [item["f0"] for item in batch]
        f0_lens = [len(item["f0"]) for item in batch]
        f0_lens_max = max(f0_lens)
        # F0 values are real numbers: an integer tensor would truncate them
        f0s_torch = torch.FloatTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id())
        for i, f0_len in enumerate(f0_lens):
            f0s_torch[i, :f0_len] = torch.FloatTensor(f0s[i])
        return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens}

    def print_logs(self, level: int = 0) -> None:
        """Print the number of samples, indented by `level` tabs."""
        indent = "\t" * level
        print("\n")
        print(f"{indent}> F0Dataset ")
        print(f"{indent}| > Number of instances : {len(self.samples)}")