# File size: 5,022 Bytes | 6127b48
import random
import torch
from torch.utils.data import Dataset
from TTS.encoder.utils.generic_utils import AugmentWAV
class EncoderDataset(Dataset):
    """Dataset of labelled utterances served as fixed-length voice segments.

    Each dataset item is a dict ``{"wav_file_path": ..., "class_name": ...}``.
    The audio itself is loaded, randomly cropped and (optionally) augmented
    in :meth:`collate_fn`, which is meant to be passed to a ``DataLoader``.
    """

    def __init__(
        self,
        config,
        ap,
        meta_data,
        voice_len=1.6,
        num_classes_in_batch=64,
        num_utter_per_class=10,
        verbose=False,
        augmentation_config=None,
        use_torch_spec=None,
    ):
        """
        Args:
            config: configuration object; ``config.class_name_key`` names the
                metadata field that holds the class label of each item.
            ap (TTS.tts.utils.AudioProcessor): audio processor object. Must
                provide ``sample_rate``, ``load_wav`` and ``melspectrogram``.
            meta_data (list): list of dataset instances (dicts with an
                ``"audio_file"`` entry and the class label entry).
            voice_len (float): voice segment length in seconds.
            num_classes_in_batch (int): only reported in the verbose printout.
            num_utter_per_class (int): minimum number of utterances a class
                needs in ``meta_data`` to be kept.
            verbose (bool): print diagnostic information.
            augmentation_config: optional mapping with the augmentation
                probability under ``"p"`` and optional ``"additive"``,
                ``"rir"`` and ``"gaussian"`` entries.
            use_torch_spec: when truthy, ``collate_fn`` returns raw waveform
                segments instead of mel spectrograms.
        """
        super().__init__()
        self.config = config
        self.items = meta_data
        self.sample_rate = ap.sample_rate
        # Segment length expressed in samples.
        self.seq_len = int(voice_len * self.sample_rate)
        self.num_utter_per_class = num_utter_per_class
        self.ap = ap
        self.verbose = verbose
        self.use_torch_spec = use_torch_spec
        self.classes, self.items = self.__parse_items()
        self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}

        # Data Augmentation
        self.augmentator = None
        self.gaussian_augmentation_config = None
        # Always defined, so the attribute exists even without augmentation.
        self.data_augmentation_p = None
        if augmentation_config:
            self.data_augmentation_p = augmentation_config["p"]
            if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config):
                self.augmentator = AugmentWAV(ap, augmentation_config)

            if "gaussian" in augmentation_config:
                self.gaussian_augmentation_config = augmentation_config["gaussian"]

        if self.verbose:
            print("\n > DataLoader initialization")
            print(f" | > Classes per Batch: {num_classes_in_batch}")
            print(f" | > Number of instances : {len(self.items)}")
            print(f" | > Sequence length: {self.seq_len}")
            print(f" | > Num Classes: {len(self.classes)}")
            print(f" | > Classes: {self.classes}")

    def load_wav(self, filename):
        """Load ``filename`` through the audio processor at its sample rate."""
        audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
        return audio

    def __parse_items(self):
        """Filter the raw metadata and build the class list.

        Returns:
            tuple: ``(classes, items)`` where ``classes`` is the sorted list of
            kept class names and ``items`` is a list of
            ``{"wav_file_path", "class_name"}`` dicts.
        """
        # The same key is used for grouping and for filtering, so the two
        # passes can never disagree about which class an item belongs to.
        label_key = self.config.class_name_key

        class_to_utters = {}
        for item in self.items:
            class_to_utters.setdefault(item[label_key], []).append(item["audio_file"])

        # keep only classes with number of samples >= self.num_utter_per_class
        classes = sorted(
            name for name, utters in class_to_utters.items() if len(utters) >= self.num_utter_per_class
        )
        kept_classes = set(classes)

        new_items = []
        for item in self.items:
            path_ = item["audio_file"]
            class_name = item[label_key]
            # ignore filtered classes
            if class_name not in kept_classes:
                continue
            # ignore audios that are not strictly longer than one segment
            if self.load_wav(path_).shape[0] - self.seq_len <= 0:
                continue

            new_items.append({"wav_file_path": path_, "class_name": class_name})

        return classes, new_items

    def __len__(self):
        """Return the number of utterances kept after filtering."""
        return len(self.items)

    def get_num_classes(self):
        """Return the number of classes."""
        return len(self.classes)

    def get_class_list(self):
        """Return the list of class names."""
        return self.classes

    def set_classes(self, classes):
        """Replace the class list and rebuild the name -> id mapping."""
        self.classes = classes
        self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}

    def get_map_classid_to_classname(self):
        """Return the inverse mapping, class id -> class name."""
        return {c_id: c_n for c_n, c_id in self.classname_to_classid.items()}

    def __getitem__(self, idx):
        """Return the metadata dict of item ``idx``; audio is loaded in ``collate_fn``."""
        return self.items[idx]

    def collate_fn(self, batch):
        """Turn a list of item dicts into a batch of features and labels.

        Args:
            batch (list): item dicts as returned by ``__getitem__``.

        Returns:
            tuple: ``(feats, labels)``; ``feats`` is the stacked float tensor
            of mel spectrograms (or waveform segments when ``use_torch_spec``
            is truthy) and ``labels`` is a ``LongTensor`` of class ids.
        """
        labels = []
        feats = []
        for item in batch:
            utter_path = item["wav_file_path"]
            class_name = item["class_name"]

            # get classid
            class_id = self.classname_to_classid[class_name]

            # load wav file and crop a random segment of seq_len samples
            # (random.randint includes both end points)
            wav = self.load_wav(utter_path)
            offset = random.randint(0, wav.shape[0] - self.seq_len)
            wav = wav[offset : offset + self.seq_len]

            # augment with probability data_augmentation_p
            if (
                self.augmentator is not None
                and self.data_augmentation_p
                and random.random() < self.data_augmentation_p
            ):
                wav = self.augmentator.apply_one(wav)

            if not self.use_torch_spec:
                mel = self.ap.melspectrogram(wav)
                feats.append(torch.FloatTensor(mel))
            else:
                feats.append(torch.FloatTensor(wav))

            labels.append(class_id)

        feats = torch.stack(feats)
        labels = torch.LongTensor(labels)

        return feats, labels
|