|
import collections |
|
import os |
|
import random |
|
from typing import Dict, List, Union |
|
|
|
import numpy as np |
|
import torch |
|
import tqdm |
|
from torch.utils.data import Dataset |
|
|
|
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor |
|
from TTS.utils.audio import AudioProcessor |
|
|
|
|
|
|
|
# Module-import side effect: switch torch's tensor sharing strategy for worker processes to "file_system".
# NOTE(review): presumably chosen because the default "file_descriptor" strategy can exhaust the open-file
# limit when many DataLoader workers pass tensors (see torch.multiprocessing docs) — confirm intent.
torch.multiprocessing.set_sharing_strategy("file_system")
|
|
|
|
|
def _parse_sample(item): |
|
language_name = None |
|
attn_file = None |
|
if len(item) == 5: |
|
text, wav_file, speaker_name, language_name, attn_file = item |
|
elif len(item) == 4: |
|
text, wav_file, speaker_name, language_name = item |
|
elif len(item) == 3: |
|
text, wav_file, speaker_name = item |
|
else: |
|
raise ValueError(" [!] Dataset cannot parse the sample.") |
|
return text, wav_file, speaker_name, language_name, attn_file |
|
|
|
|
|
def noise_augment_audio(wav):
    """Return ``wav`` plus small non-negative uniform noise.

    The noise is ``np.random.rand`` (uniform in [0, 1)) scaled by 1/32768, i.e. at most one
    16-bit quantization step for audio in [-1, 1]. The result is a new array whose dtype follows
    NumPy promotion with the float64 noise (float32 input comes back as float64).
    """
    noise = np.random.rand(*wav.shape) / 32768.0
    return wav + noise
|
|
|
|
|
class TTSDataset(Dataset):
    def __init__(
        self,
        outputs_per_step: int = 1,
        compute_linear_spec: bool = False,
        ap: "AudioProcessor" = None,
        samples: List[Dict] = None,
        tokenizer: "TTSTokenizer" = None,
        compute_f0: bool = False,
        f0_cache_path: str = None,
        return_wav: bool = False,
        batch_group_size: int = 0,
        min_text_len: int = 0,
        max_text_len: int = float("inf"),
        min_audio_len: int = 0,
        max_audio_len: int = float("inf"),
        phoneme_cache_path: str = None,
        precompute_num_workers: int = 0,
        speaker_id_mapping: Dict = None,
        d_vector_mapping: Dict = None,
        language_id_mapping: Dict = None,
        use_noise_augment: bool = False,
        start_by_longest: bool = False,
        verbose: bool = False,
    ):
        """Generic data loader for `tts` models. It is configurable for different outputs and needs.

        If you need something different, you can subclass and override.

        Args:
            outputs_per_step (int): Number of time frames predicted per step.

            compute_linear_spec (bool): compute linear spectrogram if True.

            ap (TTS.tts.utils.AudioProcessor): Audio processor object.

            samples (list): List of dataset samples. Each sample is a dict; `load_data` reads the keys `text`,
                `audio_file`, `speaker_name`, `language` and, if present, `alignment_file`.

            tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. It must not be None:
                `tokenizer.use_phonemes` is read during initialization.

            compute_f0 (bool): compute f0 if True. Defaults to False.

            f0_cache_path (str): Path to store f0 cache. Defaults to None.

            return_wav (bool): Return the waveform of the sample. Defaults to False.

            batch_group_size (int): Range of batch randomization after sorting
                sequences by length. It shuffles each batch with bucketing to gather similar length sequences in a
                batch. Set 0 to disable. Defaults to 0.

            min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored.
                Defaults to 0.

            max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored.
                Defaults to float("inf").

            min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored.
                Defaults to 0.

            max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored.
                The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to
                this value if you encounter an OOM error in training. Defaults to float("inf").

            phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
                separate file. Defaults to None.

            precompute_num_workers (int): Number of workers to precompute features. Defaults to 0.

            speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
                embedding layer. Defaults to None.

            d_vector_mapping (dict): Mapping of wav file names to computed d-vectors. Each value is read as
                `value["embedding"]`. Defaults to None.

            language_id_mapping (dict): Mapping of language names to IDs. Defaults to None.

            use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.

            start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.

            verbose (bool): Print diagnostic information. Defaults to false.
        """
        super().__init__()
        self.batch_group_size = batch_group_size
        self._samples = samples
        self.outputs_per_step = outputs_per_step
        self.compute_linear_spec = compute_linear_spec
        self.return_wav = return_wav
        self.compute_f0 = compute_f0
        self.f0_cache_path = f0_cache_path
        self.min_audio_len = min_audio_len
        self.max_audio_len = max_audio_len
        self.min_text_len = min_text_len
        self.max_text_len = max_text_len
        self.ap = ap
        self.phoneme_cache_path = phoneme_cache_path
        self.speaker_id_mapping = speaker_id_mapping
        self.d_vector_mapping = d_vector_mapping
        self.language_id_mapping = language_id_mapping
        self.use_noise_augment = use_noise_augment
        self.start_by_longest = start_by_longest

        self.verbose = verbose
        # index of the next sample `load_data` substitutes when it rejects one
        self.rescue_item_idx = 1
        self.pitch_computed = False
        self.tokenizer = tokenizer

        if self.tokenizer.use_phonemes:
            self.phoneme_dataset = PhonemeDataset(
                self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers
            )

        if compute_f0:
            self.f0_dataset = F0Dataset(
                self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
            )

        if self.verbose:
            self.print_logs()

    @property
    def lengths(self):
        """Approximate audio length of every sample, derived from the file size on disk.

        The byte size is converted with ``size / 16 * 8`` (i.e. ``size / 2``), which counts samples only for
        headerless 16-bit audio. NOTE: assumes 16-bit audio; any file header bytes are counted too.
        Accepts both dict samples (``item["audio_file"]``) and legacy list samples (second field).
        """
        lens = []
        for item in self.samples:
            if isinstance(item, collections.abc.Mapping):
                wav_file = item["audio_file"]
            else:
                # legacy list layout: (text, wav_file, speaker_name[, language[, attn_file]])
                _, wav_file, *_ = _parse_sample(item)
            lens.append(os.path.getsize(wav_file) / 16 * 8)
        return lens

    @property
    def samples(self):
        """List of dataset samples."""
        return self._samples

    @samples.setter
    def samples(self, new_samples):
        # keep the phoneme/F0 sub-datasets pointing at the same list so indices stay aligned
        self._samples = new_samples
        if hasattr(self, "f0_dataset"):
            self.f0_dataset.samples = new_samples
        if hasattr(self, "phoneme_dataset"):
            self.phoneme_dataset.samples = new_samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.load_data(idx)

    def print_logs(self, level: int = 0) -> None:
        """Print a short summary of the dataset, indented by `level` tabs."""
        indent = "\t" * level
        print("\n")
        print(f"{indent}> DataLoader initialization")
        print(f"{indent}| > Tokenizer:")
        self.tokenizer.print_logs(level + 1)
        print(f"{indent}| > Number of instances : {len(self.samples)}")

    def load_wav(self, filename):
        """Load a waveform through the audio processor and assert it is non-empty."""
        waveform = self.ap.load_wav(filename)
        assert waveform.size > 0
        return waveform

    def get_phonemes(self, idx, text):
        """Return the phoneme-dataset entry for sample `idx`.

        Asserts that the entry's text equals `text` (presumably to catch the sub-dataset's samples drifting out
        of sync) and that it has at least one token.
        """
        out_dict = self.phoneme_dataset[idx]
        assert text == out_dict["text"], f"{text} != {out_dict['text']}"
        assert len(out_dict["token_ids"]) > 0
        return out_dict

    def get_f0(self, idx):
        """Return the F0-dataset entry for sample `idx`, asserting it refers to the same audio file."""
        out_dict = self.f0_dataset[idx]
        item = self.samples[idx]
        assert item["audio_file"] == out_dict["audio_file"]
        return out_dict

    @staticmethod
    def get_attn_mask(attn_file):
        """Load a pre-computed alignment from a `.npy` file."""
        return np.load(attn_file)

    def get_token_ids(self, idx, text):
        """Return token IDs for `text` as an int32 array (via the phoneme cache when phonemes are enabled)."""
        if self.tokenizer.use_phonemes:
            token_ids = self.get_phonemes(idx, text)["token_ids"]
        else:
            token_ids = self.tokenizer.text_to_ids(text)
        return np.array(token_ids, dtype=np.int32)

    def load_data(self, idx):
        """Load and assemble one training sample.

        If the sample violates `max_text_len` (on token count) or `min_audio_len` (on decoded wav length), another
        sample is returned instead, chosen by advancing `rescue_item_idx` cyclically through the dataset.

        Returns:
            Dict: keys `raw_text`, `token_ids`, `wav`, `pitch`, `attn`, `item_idx`, `speaker_name`,
            `language_name`, `wav_file_name`.
        """
        item = self.samples[idx]

        raw_text = item["text"]

        wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)

        # apply noise for augmentation
        if self.use_noise_augment:
            wav = noise_augment_audio(wav)

        # get token ids
        token_ids = self.get_token_ids(idx, item["text"])

        # get pre-computed attention maps
        attn = None
        if "alignment_file" in item:
            attn = self.get_attn_mask(item["alignment_file"])

        # `preprocess_samples` filters on raw text length and a file-size estimate of the audio length; the actual
        # token count (e.g. after phonemization) and the decoded wav length can differ, so re-check here.
        if len(token_ids) > self.max_text_len or len(wav) < self.min_audio_len:
            # wrap around so the rescue index never runs past the end of the dataset
            self.rescue_item_idx = (self.rescue_item_idx + 1) % len(self.samples)
            return self.load_data(self.rescue_item_idx)

        # get f0 values
        f0 = None
        if self.compute_f0:
            f0 = self.get_f0(idx)["f0"]

        sample = {
            "raw_text": raw_text,
            "token_ids": token_ids,
            "wav": wav,
            "pitch": f0,
            "attn": attn,
            "item_idx": item["audio_file"],
            "speaker_name": item["speaker_name"],
            "language_name": item["language"],
            "wav_file_name": os.path.basename(item["audio_file"]),
        }
        return sample

    @staticmethod
    def _compute_lengths(samples):
        """Add `audio_length` and `text_length` to every sample dict (mutating it in place).

        `audio_length` is ``file size / 16 * 8`` — NOTE: assumes 16-bit audio. `text_length` is ``len(text)``.
        Returns a new list holding the same dict objects.
        """
        new_samples = []
        for item in samples:
            item["audio_length"] = os.path.getsize(item["audio_file"]) / 16 * 8
            item["text_length"] = len(item["text"])
            new_samples.append(item)
        return new_samples

    @staticmethod
    def filter_by_length(lengths: List[int], min_len: int, max_len: int):
        """Split indices of `lengths` into those outside / inside ``[min_len, max_len]``.

        Both lists are in ascending order of length.

        Returns:
            Tuple[List[int], List[int]]: ``(ignore_idx, keep_idx)``.
        """
        idxs = np.argsort(lengths)
        ignore_idx = []
        keep_idx = []
        for idx in idxs:
            length = lengths[idx]
            if length < min_len or length > max_len:
                ignore_idx.append(idx)
            else:
                keep_idx.append(idx)
        return ignore_idx, keep_idx

    @staticmethod
    def sort_by_length(samples: List[Dict]):
        """Return the indices that sort `samples` by ascending `audio_length`."""
        audio_lengths = [s["audio_length"] for s in samples]
        idxs = np.argsort(audio_lengths)
        return idxs

    @staticmethod
    def create_buckets(samples, batch_group_size: int):
        """Shuffle `samples` in place within consecutive windows of `batch_group_size` items.

        A trailing partial window is left unshuffled. Returns the same list object.
        """
        assert batch_group_size > 0
        for i in range(len(samples) // batch_group_size):
            offset = i * batch_group_size
            end_offset = offset + batch_group_size
            temp_items = samples[offset:end_offset]
            random.shuffle(temp_items)
            samples[offset:end_offset] = temp_items
        return samples

    @staticmethod
    def _select_samples_by_idx(idxs, samples):
        """Return ``[samples[i] for i in idxs]``."""
        return [samples[idx] for idx in idxs]

    def preprocess_samples(self):
        r"""Sort `items` based on audio length in ascending order. Filter out samples out of the length
        range.

        Raises:
            RuntimeError: if no sample survives the length filters.
        """
        samples = self._compute_lengths(self.samples)

        # sort items based on the sequence length in ascending order
        text_lengths = [i["text_length"] for i in samples]
        audio_lengths = [i["audio_length"] for i in samples]
        text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len)
        audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len)
        keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
        ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))

        samples = self._select_samples_by_idx(keep_idx, samples)

        # check before indexing into the sort order: `start_by_longest` would otherwise fail with IndexError
        if len(samples) == 0:
            raise RuntimeError(" [!] No samples left")

        sorted_idxs = self.sort_by_length(samples)

        if self.start_by_longest:
            # swap the longest sample into the first position
            longest_idxs = sorted_idxs[-1]
            sorted_idxs[-1] = sorted_idxs[0]
            sorted_idxs[0] = longest_idxs

        samples = self._select_samples_by_idx(sorted_idxs, samples)

        # shuffle batch groups
        # create batches with similar length items
        # the larger the `batch_group_size`, the higher the length variety in a batch.
        if self.batch_group_size > 0:
            samples = self.create_buckets(samples, self.batch_group_size)

        # update items to the new sorted items
        audio_lengths = [s["audio_length"] for s in samples]
        text_lengths = [s["text_length"] for s in samples]
        self.samples = samples

        if self.verbose:
            print(" | > Preprocessing samples")
            print(" | > Max text length: {}".format(np.max(text_lengths)))
            print(" | > Min text length: {}".format(np.min(text_lengths)))
            print(" | > Avg text length: {}".format(np.mean(text_lengths)))
            print(" | ")
            print(" | > Max audio length: {}".format(np.max(audio_lengths)))
            print(" | > Min audio length: {}".format(np.min(audio_lengths)))
            print(" | > Avg audio length: {}".format(np.mean(audio_lengths)))
            print(f" | > Num. instances discarded samples: {len(ignore_idx)}")
            print(" | > Batch group size: {}.".format(self.batch_group_size))

    @staticmethod
    def _sort_batch(batch, text_lengths):
        """Sort the batch by the input text length for RNN efficiency.

        Args:
            batch (Dict): Batch returned by `__getitem__`.
            text_lengths (List[int]): Lengths of the input character sequences.

        Returns:
            Tuple: the reordered batch list, the sorted lengths as a LongTensor, and the sort permutation.
        """
        text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
        batch = [batch[idx] for idx in ids_sorted_decreasing]
        return batch, text_lengths, ids_sorted_decreasing

    def collate_fn(self, batch):
        r"""
        Perform preprocessing and create a final data batch:
        1. Sort batch instances by token-id length (descending).
        2. Convert Audio signal to features.
        3. PAD sequences wrt r.
        4. Load to Torch.

        Args:
            batch (List[Dict]): samples as returned by `__getitem__`.

        Returns:
            Dict: padded tensors and per-sample metadata, all in the sorted order.

        Raises:
            TypeError: if the batch elements are not mappings.
        """
        # Puts each data field into a tensor with outer dimension batch size
        if isinstance(batch[0], collections.abc.Mapping):
            token_ids_lengths = np.array([len(d["token_ids"]) for d in batch])

            # sort items with text input length for RNN efficiency
            batch, token_ids_lengths, _ = self._sort_batch(batch, token_ids_lengths)

            # convert list of dicts to dict of lists (already in sorted order)
            batch = {k: [dic[k] for dic in batch] for k in batch[0]}

            # get language ids from language names
            if self.language_id_mapping is not None:
                language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]]
            else:
                language_ids = None

            # get pre-computed d-vectors
            if self.d_vector_mapping is not None:
                wav_files_names = list(batch["wav_file_name"])
                d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
            else:
                d_vectors = None

            # get numerical speaker ids from speaker names
            if self.speaker_id_mapping:
                speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
            else:
                speaker_ids = None

            # compute features
            mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]

            mel_lengths = [m.shape[1] for m in mel]

            # lengths rounded up to a multiple of the reduction factor `outputs_per_step`
            mel_lengths_adjusted = [
                m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step))
                if m.shape[1] % self.outputs_per_step
                else m.shape[1]
                for m in mel
            ]

            # compute 'stop token' targets: 1.0 on the last frame, 0.0 elsewhere
            stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths]

            # PAD stop targets
            stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)

            # PAD sequences with longest instance in the batch
            token_ids = prepare_data(batch["token_ids"]).astype(np.int32)

            # PAD features with longest instance
            mel = prepare_tensor(mel, self.outputs_per_step)

            # swap the last two axes so time comes before the feature axis
            # (each mel has time on axis 1, see `mel_lengths`; assumes prepare_tensor stacks on a new axis 0)
            mel = mel.transpose(0, 2, 1)

            # convert things to pytorch
            token_ids_lengths = torch.LongTensor(token_ids_lengths)
            token_ids = torch.LongTensor(token_ids)
            mel = torch.FloatTensor(mel).contiguous()
            mel_lengths = torch.LongTensor(mel_lengths)
            stop_targets = torch.FloatTensor(stop_targets)

            # speaker vectors
            if d_vectors is not None:
                d_vectors = torch.FloatTensor(d_vectors)

            if speaker_ids is not None:
                speaker_ids = torch.LongTensor(speaker_ids)

            if language_ids is not None:
                language_ids = torch.LongTensor(language_ids)

            # compute linear spectrogram
            linear = None
            if self.compute_linear_spec:
                linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
                linear = prepare_tensor(linear, self.outputs_per_step)
                linear = linear.transpose(0, 2, 1)
                assert mel.shape[1] == linear.shape[1]
                linear = torch.FloatTensor(linear).contiguous()

            # format waveforms: pad/trim each wav to its adjusted mel length times hop_length
            wav_padded = None
            if self.return_wav:
                max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
                wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
                for i, w in enumerate(batch["wav"]):
                    mel_length = mel_lengths_adjusted[i]
                    w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
                    w = w[: mel_length * self.ap.hop_length]
                    wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
                wav_padded.transpose_(1, 2)

            # format F0
            if self.compute_f0:
                pitch = prepare_data(batch["pitch"])
                assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
                pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 xT
            else:
                pitch = None

            # format attention masks
            attns = None
            if batch["attn"][0] is not None:
                # `batch` was already reordered by `_sort_batch`; re-indexing with the sort permutation here
                # would apply it twice and pair alignment maps with the wrong samples.
                attns = [attn.T for attn in batch["attn"]]
                for idx, attn in enumerate(attns):
                    pad2 = mel.shape[1] - attn.shape[1]
                    pad1 = token_ids.shape[1] - attn.shape[0]
                    assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
                    attn = np.pad(attn, [[0, pad1], [0, pad2]])
                    attns[idx] = attn
                attns = prepare_tensor(attns, self.outputs_per_step)
                attns = torch.FloatTensor(attns).unsqueeze(1)

            return {
                "token_id": token_ids,
                "token_id_lengths": token_ids_lengths,
                "speaker_names": batch["speaker_name"],
                "linear": linear,
                "mel": mel,
                "mel_lengths": mel_lengths,
                "stop_targets": stop_targets,
                "item_idxs": batch["item_idx"],
                "d_vectors": d_vectors,
                "speaker_ids": speaker_ids,
                "attns": attns,
                "waveform": wav_padded,
                "raw_text": batch["raw_text"],
                "pitch": pitch,
                "language_ids": language_ids,
            }

        raise TypeError(
            "batch must contain tensors, numbers, dicts or lists; found {}".format(type(batch[0]))
        )
|
|
|
|
|
class PhonemeDataset(Dataset):
    """Phoneme Dataset for converting input text to phonemes and then token IDs

    At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data
    loading latency. If `cache_path` is already present, it skips the pre-computation. If `cache_path` is None,
    nothing is cached and token IDs are computed from the text on every access.

    Args:
        samples (Union[List[List], List[Dict]]):
            List of samples. Each sample is read as a dict with `text` and `audio_file` keys.

        tokenizer (TTSTokenizer):
            Tokenizer to convert input text to phonemes.

        cache_path (str):
            Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation.

        precompute_num_workers (int):
            Number of workers used for pre-computing the phonemes. Defaults to 0.
    """

    def __init__(
        self,
        samples: Union[List[Dict], List[List]],
        tokenizer: "TTSTokenizer",
        cache_path: str,
        precompute_num_workers=0,
    ):
        self.samples = samples
        self.tokenizer = tokenizer
        self.cache_path = cache_path
        if cache_path is not None and not os.path.exists(cache_path):
            os.makedirs(cache_path)
            self.precompute(precompute_num_workers)

    def __getitem__(self, index):
        """Return the text, the decoded phoneme string, the token IDs and their count for sample `index`."""
        item = self.samples[index]
        ids = self.compute_or_load(item["audio_file"], item["text"])
        ph_hat = self.tokenizer.ids_to_text(ids)
        return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}

    def __len__(self):
        return len(self.samples)

    def compute_or_load(self, wav_file, text):
        """Compute phonemes for the given text.

        If the phonemes are already cached, load them from cache. Without a `cache_path` the IDs are computed
        directly and nothing is written.

        NOTE: the cache file is keyed only by the wav file's base name, so wav files sharing a base name in
        different directories share one cache entry.
        """
        if self.cache_path is None:
            return self.tokenizer.text_to_ids(text)
        file_name = os.path.splitext(os.path.basename(wav_file))[0]
        cache_file = os.path.join(self.cache_path, file_name + "_phoneme.npy")
        try:
            ids = np.load(cache_file)
        except FileNotFoundError:
            ids = self.tokenizer.text_to_ids(text)
            np.save(cache_file, ids)
        return ids

    def get_pad_id(self):
        """Get pad token ID for sequence padding"""
        return self.tokenizer.pad_id

    def precompute(self, num_workers=1):
        """Precompute phonemes for all samples.

        We use pytorch dataloader because we are lazy: iterating it calls `__getitem__` (and thus
        `compute_or_load`) for every sample, optionally in worker processes.
        """
        print("[*] Pre-computing phonemes...")
        batch_size = num_workers if num_workers > 0 else 1
        dataloader = torch.utils.data.DataLoader(
            batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
        )
        with tqdm.tqdm(total=len(self)) as pbar:
            for batch in dataloader:
                # count real samples: the final batch may be smaller than batch_size
                pbar.update(len(batch["text"]))

    def collate_fn(self, batch):
        """Stack token IDs into a LongTensor right-padded with the tokenizer's pad ID."""
        ids = [item["token_ids"] for item in batch]
        ids_lens = [item["token_ids_len"] for item in batch]
        texts = [item["text"] for item in batch]
        texts_hat = [item["ph_hat"] for item in batch]
        ids_lens_max = max(ids_lens)
        ids_torch = torch.LongTensor(len(ids), ids_lens_max).fill_(self.get_pad_id())
        for i, ids_len in enumerate(ids_lens):
            ids_torch[i, :ids_len] = torch.LongTensor(ids[i])
        return {"text": texts, "ph_hat": texts_hat, "token_ids": ids_torch}

    def print_logs(self, level: int = 0) -> None:
        """Print a short summary of the dataset, indented by `level` tabs."""
        indent = "\t" * level
        print("\n")
        print(f"{indent}> PhonemeDataset ")
        print(f"{indent}| > Tokenizer:")
        self.tokenizer.print_logs(level + 1)
        print(f"{indent}| > Number of instances : {len(self.samples)}")
|
|
|
|
|
class F0Dataset:
    """F0 Dataset for computing F0 from wav files in CPU

    At initialization, F0 values for all samples are pre-computed (and cached) only when `cache_path` is not None
    and does not exist yet. If `normalize_f0` is True, the mean and std of the F0 values are computed during
    pre-computation and loaded from `<cache_path>/pitch_stats.npy` at initialization.

    Args:
        samples (Union[List[List], List[Dict]]):
            List of samples. Each sample is read as a dict with an `audio_file` key.

        ap (AudioProcessor):
            AudioProcessor to compute F0 from wav files.

        verbose (bool):
            Stored on the instance; not otherwise used here. Defaults to False.

        cache_path (str):
            Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation.
            Without a cache path F0 is computed on every access. Defaults to None.

        precompute_num_workers (int):
            Number of workers used for pre-computing the F0 values. Defaults to 0.

        normalize_f0 (bool):
            Whether to normalize F0 values by mean and std. Requires the stats file in `cache_path`.
            Defaults to True.
    """

    def __init__(
        self,
        samples: Union[List[List], List[Dict]],
        ap: "AudioProcessor",
        verbose=False,
        cache_path: str = None,
        precompute_num_workers=0,
        normalize_f0=True,
    ):
        self.samples = samples
        self.ap = ap
        self.verbose = verbose
        self.cache_path = cache_path
        self.normalize_f0 = normalize_f0
        self.pad_id = 0.0
        self.mean = None
        self.std = None
        if cache_path is not None and not os.path.exists(cache_path):
            os.makedirs(cache_path)
            self.precompute(precompute_num_workers)
        if normalize_f0:
            self.load_stats(cache_path)

    def __getitem__(self, idx):
        """Return ``{"audio_file": ..., "f0": ...}`` for sample `idx`, normalized if `normalize_f0` is set."""
        item = self.samples[idx]
        f0 = self.compute_or_load(item["audio_file"])
        if self.normalize_f0:
            assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
            f0 = self.normalize(f0)
        return {"audio_file": item["audio_file"], "f0": f0}

    def __len__(self):
        return len(self.samples)

    def precompute(self, num_workers=0):
        """Compute (and cache) F0 for every sample; if `normalize_f0`, also compute the F0 mean/std.

        The stats are kept on the instance (`mean`, `std` as float32) and, when `cache_path` is set, saved to
        `<cache_path>/pitch_stats.npy`.
        """
        print("[*] Pre-computing F0s...")
        batch_size = num_workers if num_workers > 0 else 1
        # collect raw values: normalization needs the very stats computed from them
        normalize_f0 = self.normalize_f0
        self.normalize_f0 = False
        computed_data = []
        try:
            dataloader = torch.utils.data.DataLoader(
                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
            )
            with tqdm.tqdm(total=len(self)) as pbar:
                for batch in dataloader:
                    # strip the padding so only real frames feed the statistics
                    for f0, f0_len in zip(batch["f0"], batch["f0_lens"]):
                        computed_data.append(f0[:f0_len].numpy())
                    pbar.update(len(batch["f0_lens"]))
        finally:
            self.normalize_f0 = normalize_f0

        if self.normalize_f0:
            pitch_mean, pitch_std = self.compute_pitch_stats(computed_data)
            self.mean = np.float32(pitch_mean)
            self.std = np.float32(pitch_std)
            if self.cache_path is not None:
                pitch_stats = {"mean": pitch_mean, "std": pitch_std}
                np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)

    def get_pad_id(self):
        return self.pad_id

    @staticmethod
    def create_pitch_file_path(wav_file, cache_path):
        """Return ``<cache_path>/<wav base name without extension>_pitch.npy``."""
        file_name = os.path.splitext(os.path.basename(wav_file))[0]
        pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
        return pitch_file

    @staticmethod
    def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
        """Load `wav_file`, compute its F0 with `ap`, optionally save it to `pitch_file`, and return it."""
        wav = ap.load_wav(wav_file)
        pitch = ap.compute_f0(wav)
        if pitch_file:
            np.save(pitch_file, pitch)
        return pitch

    @staticmethod
    def compute_pitch_stats(pitch_vecs):
        """Return mean and std over all non-zero values of the given 1-D pitch vectors.

        Zeros are excluded (presumably unvoiced frames — consistent with `normalize` leaving zeros at zero).
        """
        nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs])
        mean, std = np.mean(nonzeros), np.std(nonzeros)
        return mean, std

    def load_stats(self, cache_path):
        """Load F0 mean/std (as float32) from `<cache_path>/pitch_stats.npy`."""
        stats_path = os.path.join(cache_path, "pitch_stats.npy")
        stats = np.load(stats_path, allow_pickle=True).item()
        self.mean = stats["mean"].astype(np.float32)
        self.std = stats["std"].astype(np.float32)

    def normalize(self, pitch):
        """Return ``(pitch - mean) / std``, keeping zero entries at zero. Works for inputs of any rank."""
        zero_mask = pitch == 0.0
        pitch = (pitch - self.mean) / self.std
        pitch[zero_mask] = 0.0
        return pitch

    def denormalize(self, pitch):
        """Inverse of `normalize`: return ``pitch * std + mean`` with zero entries kept at zero.

        Returns a new array; the argument is not modified.
        """
        zero_mask = pitch == 0.0
        pitch = pitch * self.std + self.mean
        pitch[zero_mask] = 0.0
        return pitch

    def compute_or_load(self, wav_file):
        """
        compute pitch and return a numpy array of pitch values (float32), using the cache when `cache_path` is set
        """
        if self.cache_path is None:
            pitch = self._compute_and_save_pitch(self.ap, wav_file)
        else:
            pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
            if os.path.exists(pitch_file):
                pitch = np.load(pitch_file)
            else:
                pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
        return pitch.astype(np.float32)

    def collate_fn(self, batch):
        """Stack F0 vectors into a float32 tensor right-padded with `pad_id`; also return the true lengths."""
        audio_file = [item["audio_file"] for item in batch]
        f0s = [item["f0"] for item in batch]
        f0_lens = [len(item["f0"]) for item in batch]
        f0_lens_max = max(f0_lens)
        # F0 values are fractional; an integer tensor would truncate them (and corrupt the pitch stats)
        f0s_torch = torch.FloatTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id())
        for i, f0_len in enumerate(f0_lens):
            f0s_torch[i, :f0_len] = torch.FloatTensor(f0s[i])
        return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens}

    def print_logs(self, level: int = 0) -> None:
        """Print a short summary of the dataset, indented by `level` tabs."""
        indent = "\t" * level
        print("\n")
        print(f"{indent}> F0Dataset ")
        print(f"{indent}| > Number of instances : {len(self.samples)}")
|
|