Spaces:
Runtime error
Runtime error
akhaliq3
commited on
Commit
•
24829a1
1
Parent(s):
f01b5b9
spaces demo
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- 8230_00000.mp3 +0 -0
- demo_cli.py +225 -0
- demo_toolbox.py +43 -0
- encoder/__init__.py +0 -0
- encoder/audio.py +117 -0
- encoder/config.py +45 -0
- encoder/data_objects/__init__.py +2 -0
- encoder/data_objects/random_cycler.py +37 -0
- encoder/data_objects/speaker.py +40 -0
- encoder/data_objects/speaker_batch.py +12 -0
- encoder/data_objects/speaker_verification_dataset.py +56 -0
- encoder/data_objects/utterance.py +26 -0
- encoder/inference.py +178 -0
- encoder/model.py +135 -0
- encoder/params_data.py +29 -0
- encoder/params_model.py +11 -0
- encoder/preprocess.py +175 -0
- encoder/train.py +123 -0
- encoder/visualizations.py +178 -0
- encoder_preprocess.py +70 -0
- encoder_train.py +47 -0
- requirements.txt +16 -0
- synthesizer/LICENSE.txt +24 -0
- synthesizer/__init__.py +1 -0
- synthesizer/audio.py +206 -0
- synthesizer/hparams.py +92 -0
- synthesizer/inference.py +171 -0
- synthesizer/models/tacotron.py +519 -0
- synthesizer/preprocess.py +259 -0
- synthesizer/synthesize.py +97 -0
- synthesizer/synthesizer_dataset.py +92 -0
- synthesizer/train.py +269 -0
- synthesizer/utils/__init__.py +45 -0
- synthesizer/utils/_cmudict.py +62 -0
- synthesizer/utils/cleaners.py +88 -0
- synthesizer/utils/numbers.py +68 -0
- synthesizer/utils/plot.py +76 -0
- synthesizer/utils/symbols.py +17 -0
- synthesizer/utils/text.py +74 -0
- synthesizer_preprocess_audio.py +59 -0
- synthesizer_preprocess_embeds.py +25 -0
- synthesizer_train.py +35 -0
- toolbox/__init__.py +357 -0
- toolbox/ui.py +611 -0
- toolbox/utterance.py +5 -0
- utils/__init__.py +0 -0
- utils/argutils.py +40 -0
- utils/logmmse.py +247 -0
- utils/modelutils.py +17 -0
- utils/profiler.py +45 -0
8230_00000.mp3
ADDED
Binary file (16.1 kB). View file
|
|
demo_cli.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from encoder.params_model import model_embedding_size as speaker_embedding_size
|
2 |
+
from utils.argutils import print_args
|
3 |
+
from utils.modelutils import check_model_paths
|
4 |
+
from synthesizer.inference import Synthesizer
|
5 |
+
from encoder import inference as encoder
|
6 |
+
from vocoder import inference as vocoder
|
7 |
+
from pathlib import Path
|
8 |
+
import numpy as np
|
9 |
+
import soundfile as sf
|
10 |
+
import librosa
|
11 |
+
import argparse
|
12 |
+
import torch
|
13 |
+
import sys
|
14 |
+
import os
|
15 |
+
from audioread.exceptions import NoBackendError
|
16 |
+
|
17 |
+
if __name__ == '__main__':
|
18 |
+
## Info & args
|
19 |
+
parser = argparse.ArgumentParser(
|
20 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
21 |
+
)
|
22 |
+
parser.add_argument("-e", "--enc_model_fpath", type=Path,
|
23 |
+
default="encoder/saved_models/pretrained.pt",
|
24 |
+
help="Path to a saved encoder")
|
25 |
+
parser.add_argument("-s", "--syn_model_fpath", type=Path,
|
26 |
+
default="synthesizer/saved_models/pretrained/pretrained.pt",
|
27 |
+
help="Path to a saved synthesizer")
|
28 |
+
parser.add_argument("-v", "--voc_model_fpath", type=Path,
|
29 |
+
default="vocoder/saved_models/pretrained/pretrained.pt",
|
30 |
+
help="Path to a saved vocoder")
|
31 |
+
parser.add_argument("--cpu", action="store_true", help=\
|
32 |
+
"If True, processing is done on CPU, even when a GPU is available.")
|
33 |
+
parser.add_argument("--no_sound", action="store_true", help=\
|
34 |
+
"If True, audio won't be played.")
|
35 |
+
parser.add_argument("--seed", type=int, default=None, help=\
|
36 |
+
"Optional random number seed value to make toolbox deterministic.")
|
37 |
+
parser.add_argument("--no_mp3_support", action="store_true", help=\
|
38 |
+
"If True, disallows loading mp3 files to prevent audioread errors when ffmpeg is not installed.")
|
39 |
+
args = parser.parse_args()
|
40 |
+
print_args(args, parser)
|
41 |
+
if not args.no_sound:
|
42 |
+
import sounddevice as sd
|
43 |
+
|
44 |
+
if args.cpu:
|
45 |
+
# Hide GPUs from Pytorch to force CPU processing
|
46 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
47 |
+
|
48 |
+
if not args.no_mp3_support:
|
49 |
+
try:
|
50 |
+
librosa.load("samples/1320_00000.mp3")
|
51 |
+
except NoBackendError:
|
52 |
+
print("Librosa will be unable to open mp3 files if additional software is not installed.\n"
|
53 |
+
"Please install ffmpeg or add the '--no_mp3_support' option to proceed without support for mp3 files.")
|
54 |
+
exit(-1)
|
55 |
+
|
56 |
+
print("Running a test of your configuration...\n")
|
57 |
+
|
58 |
+
if torch.cuda.is_available():
|
59 |
+
device_id = torch.cuda.current_device()
|
60 |
+
gpu_properties = torch.cuda.get_device_properties(device_id)
|
61 |
+
## Print some environment information (for debugging purposes)
|
62 |
+
print("Found %d GPUs available. Using GPU %d (%s) of compute capability %d.%d with "
|
63 |
+
"%.1fGb total memory.\n" %
|
64 |
+
(torch.cuda.device_count(),
|
65 |
+
device_id,
|
66 |
+
gpu_properties.name,
|
67 |
+
gpu_properties.major,
|
68 |
+
gpu_properties.minor,
|
69 |
+
gpu_properties.total_memory / 1e9))
|
70 |
+
else:
|
71 |
+
print("Using CPU for inference.\n")
|
72 |
+
|
73 |
+
## Remind the user to download pretrained models if needed
|
74 |
+
check_model_paths(encoder_path=args.enc_model_fpath,
|
75 |
+
synthesizer_path=args.syn_model_fpath,
|
76 |
+
vocoder_path=args.voc_model_fpath)
|
77 |
+
|
78 |
+
## Load the models one by one.
|
79 |
+
print("Preparing the encoder, the synthesizer and the vocoder...")
|
80 |
+
encoder.load_model(args.enc_model_fpath)
|
81 |
+
synthesizer = Synthesizer(args.syn_model_fpath)
|
82 |
+
vocoder.load_model(args.voc_model_fpath)
|
83 |
+
|
84 |
+
|
85 |
+
## Run a test
|
86 |
+
print("Testing your configuration with small inputs.")
|
87 |
+
# Forward an audio waveform of zeroes that lasts 1 second. Notice how we can get the encoder's
|
88 |
+
# sampling rate, which may differ.
|
89 |
+
# If you're unfamiliar with digital audio, know that it is encoded as an array of floats
|
90 |
+
# (or sometimes integers, but mostly floats in this projects) ranging from -1 to 1.
|
91 |
+
# The sampling rate is the number of values (samples) recorded per second, it is set to
|
92 |
+
# 16000 for the encoder. Creating an array of length <sampling_rate> will always correspond
|
93 |
+
# to an audio of 1 second.
|
94 |
+
print("\tTesting the encoder...")
|
95 |
+
encoder.embed_utterance(np.zeros(encoder.sampling_rate))
|
96 |
+
|
97 |
+
# Create a dummy embedding. You would normally use the embedding that encoder.embed_utterance
|
98 |
+
# returns, but here we're going to make one ourselves just for the sake of showing that it's
|
99 |
+
# possible.
|
100 |
+
embed = np.random.rand(speaker_embedding_size)
|
101 |
+
# Embeddings are L2-normalized (this isn't important here, but if you want to make your own
|
102 |
+
# embeddings it will be).
|
103 |
+
embed /= np.linalg.norm(embed)
|
104 |
+
# The synthesizer can handle multiple inputs with batching. Let's create another embedding to
|
105 |
+
# illustrate that
|
106 |
+
embeds = [embed, np.zeros(speaker_embedding_size)]
|
107 |
+
texts = ["test 1", "test 2"]
|
108 |
+
print("\tTesting the synthesizer... (loading the model will output a lot of text)")
|
109 |
+
mels = synthesizer.synthesize_spectrograms(texts, embeds)
|
110 |
+
|
111 |
+
# The vocoder synthesizes one waveform at a time, but it's more efficient for long ones. We
|
112 |
+
# can concatenate the mel spectrograms to a single one.
|
113 |
+
mel = np.concatenate(mels, axis=1)
|
114 |
+
# The vocoder can take a callback function to display the generation. More on that later. For
|
115 |
+
# now we'll simply hide it like this:
|
116 |
+
no_action = lambda *args: None
|
117 |
+
print("\tTesting the vocoder...")
|
118 |
+
# For the sake of making this test short, we'll pass a short target length. The target length
|
119 |
+
# is the length of the wav segments that are processed in parallel. E.g. for audio sampled
|
120 |
+
# at 16000 Hertz, a target length of 8000 means that the target audio will be cut in chunks of
|
121 |
+
# 0.5 seconds which will all be generated together. The parameters here are absurdly short, and
|
122 |
+
# that has a detrimental effect on the quality of the audio. The default parameters are
|
123 |
+
# recommended in general.
|
124 |
+
vocoder.infer_waveform(mel, target=200, overlap=50, progress_callback=no_action)
|
125 |
+
|
126 |
+
print("All test passed! You can now synthesize speech.\n\n")
|
127 |
+
|
128 |
+
|
129 |
+
## Interactive speech generation
|
130 |
+
print("This is a GUI-less example of interface to SV2TTS. The purpose of this script is to "
|
131 |
+
"show how you can interface this project easily with your own. See the source code for "
|
132 |
+
"an explanation of what is happening.\n")
|
133 |
+
|
134 |
+
print("Interactive generation loop")
|
135 |
+
num_generated = 0
|
136 |
+
while True:
|
137 |
+
try:
|
138 |
+
# Get the reference audio filepath
|
139 |
+
message = "Reference voice: enter an audio filepath of a voice to be cloned (mp3, " \
|
140 |
+
"wav, m4a, flac, ...):\n"
|
141 |
+
in_fpath = Path(input(message).replace("\"", "").replace("\'", ""))
|
142 |
+
|
143 |
+
if in_fpath.suffix.lower() == ".mp3" and args.no_mp3_support:
|
144 |
+
print("Can't Use mp3 files please try again:")
|
145 |
+
continue
|
146 |
+
## Computing the embedding
|
147 |
+
# First, we load the wav using the function that the speaker encoder provides. This is
|
148 |
+
# important: there is preprocessing that must be applied.
|
149 |
+
|
150 |
+
# The following two methods are equivalent:
|
151 |
+
# - Directly load from the filepath:
|
152 |
+
preprocessed_wav = encoder.preprocess_wav(in_fpath)
|
153 |
+
# - If the wav is already loaded:
|
154 |
+
original_wav, sampling_rate = librosa.load(str(in_fpath))
|
155 |
+
preprocessed_wav = encoder.preprocess_wav(original_wav, sampling_rate)
|
156 |
+
print("Loaded file succesfully")
|
157 |
+
|
158 |
+
# Then we derive the embedding. There are many functions and parameters that the
|
159 |
+
# speaker encoder interfaces. These are mostly for in-depth research. You will typically
|
160 |
+
# only use this function (with its default parameters):
|
161 |
+
embed = encoder.embed_utterance(preprocessed_wav)
|
162 |
+
print("Created the embedding")
|
163 |
+
|
164 |
+
|
165 |
+
## Generating the spectrogram
|
166 |
+
text = input("Write a sentence (+-20 words) to be synthesized:\n")
|
167 |
+
|
168 |
+
# If seed is specified, reset torch seed and force synthesizer reload
|
169 |
+
if args.seed is not None:
|
170 |
+
torch.manual_seed(args.seed)
|
171 |
+
synthesizer = Synthesizer(args.syn_model_fpath)
|
172 |
+
|
173 |
+
# The synthesizer works in batch, so you need to put your data in a list or numpy array
|
174 |
+
texts = [text]
|
175 |
+
embeds = [embed]
|
176 |
+
# If you know what the attention layer alignments are, you can retrieve them here by
|
177 |
+
# passing return_alignments=True
|
178 |
+
specs = synthesizer.synthesize_spectrograms(texts, embeds)
|
179 |
+
spec = specs[0]
|
180 |
+
print("Created the mel spectrogram")
|
181 |
+
|
182 |
+
|
183 |
+
## Generating the waveform
|
184 |
+
print("Synthesizing the waveform:")
|
185 |
+
|
186 |
+
# If seed is specified, reset torch seed and reload vocoder
|
187 |
+
if args.seed is not None:
|
188 |
+
torch.manual_seed(args.seed)
|
189 |
+
vocoder.load_model(args.voc_model_fpath)
|
190 |
+
|
191 |
+
# Synthesizing the waveform is fairly straightforward. Remember that the longer the
|
192 |
+
# spectrogram, the more time-efficient the vocoder.
|
193 |
+
generated_wav = vocoder.infer_waveform(spec)
|
194 |
+
|
195 |
+
|
196 |
+
## Post-generation
|
197 |
+
# There's a bug with sounddevice that makes the audio cut one second earlier, so we
|
198 |
+
# pad it.
|
199 |
+
generated_wav = np.pad(generated_wav, (0, synthesizer.sample_rate), mode="constant")
|
200 |
+
|
201 |
+
# Trim excess silences to compensate for gaps in spectrograms (issue #53)
|
202 |
+
generated_wav = encoder.preprocess_wav(generated_wav)
|
203 |
+
|
204 |
+
# Play the audio (non-blocking)
|
205 |
+
if not args.no_sound:
|
206 |
+
try:
|
207 |
+
sd.stop()
|
208 |
+
sd.play(generated_wav, synthesizer.sample_rate)
|
209 |
+
except sd.PortAudioError as e:
|
210 |
+
print("\nCaught exception: %s" % repr(e))
|
211 |
+
print("Continuing without audio playback. Suppress this message with the \"--no_sound\" flag.\n")
|
212 |
+
except:
|
213 |
+
raise
|
214 |
+
|
215 |
+
# Save it on the disk
|
216 |
+
filename = "demo_output_%02d.wav" % num_generated
|
217 |
+
print(generated_wav.dtype)
|
218 |
+
sf.write(filename, generated_wav.astype(np.float32), synthesizer.sample_rate)
|
219 |
+
num_generated += 1
|
220 |
+
print("\nSaved output as %s\n\n" % filename)
|
221 |
+
|
222 |
+
|
223 |
+
except Exception as e:
|
224 |
+
print("Caught exception: %s" % repr(e))
|
225 |
+
print("Restarting\n")
|
demo_toolbox.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from toolbox import Toolbox
|
3 |
+
from utils.argutils import print_args
|
4 |
+
from utils.modelutils import check_model_paths
|
5 |
+
import argparse
|
6 |
+
import os
|
7 |
+
|
8 |
+
|
9 |
+
if __name__ == '__main__':
|
10 |
+
parser = argparse.ArgumentParser(
|
11 |
+
description="Runs the toolbox",
|
12 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
13 |
+
)
|
14 |
+
|
15 |
+
parser.add_argument("-d", "--datasets_root", type=Path, help= \
|
16 |
+
"Path to the directory containing your datasets. See toolbox/__init__.py for a list of "
|
17 |
+
"supported datasets.", default=None)
|
18 |
+
parser.add_argument("-e", "--enc_models_dir", type=Path, default="encoder/saved_models",
|
19 |
+
help="Directory containing saved encoder models")
|
20 |
+
parser.add_argument("-s", "--syn_models_dir", type=Path, default="synthesizer/saved_models",
|
21 |
+
help="Directory containing saved synthesizer models")
|
22 |
+
parser.add_argument("-v", "--voc_models_dir", type=Path, default="vocoder/saved_models",
|
23 |
+
help="Directory containing saved vocoder models")
|
24 |
+
parser.add_argument("--cpu", action="store_true", help=\
|
25 |
+
"If True, processing is done on CPU, even when a GPU is available.")
|
26 |
+
parser.add_argument("--seed", type=int, default=None, help=\
|
27 |
+
"Optional random number seed value to make toolbox deterministic.")
|
28 |
+
parser.add_argument("--no_mp3_support", action="store_true", help=\
|
29 |
+
"If True, no mp3 files are allowed.")
|
30 |
+
args = parser.parse_args()
|
31 |
+
print_args(args, parser)
|
32 |
+
|
33 |
+
if args.cpu:
|
34 |
+
# Hide GPUs from Pytorch to force CPU processing
|
35 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
36 |
+
del args.cpu
|
37 |
+
|
38 |
+
## Remind the user to download pretrained models if needed
|
39 |
+
check_model_paths(encoder_path=args.enc_models_dir, synthesizer_path=args.syn_models_dir,
|
40 |
+
vocoder_path=args.voc_models_dir)
|
41 |
+
|
42 |
+
# Launch the toolbox
|
43 |
+
Toolbox(**vars(args))
|
encoder/__init__.py
ADDED
File without changes
|
encoder/audio.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from scipy.ndimage.morphology import binary_dilation
|
2 |
+
from encoder.params_data import *
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Optional, Union
|
5 |
+
from warnings import warn
|
6 |
+
import numpy as np
|
7 |
+
import librosa
|
8 |
+
import struct
|
9 |
+
|
10 |
+
try:
|
11 |
+
import webrtcvad
|
12 |
+
except:
|
13 |
+
warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
|
14 |
+
webrtcvad=None
|
15 |
+
|
16 |
+
int16_max = (2 ** 15) - 1
|
17 |
+
|
18 |
+
|
19 |
+
def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
|
20 |
+
source_sr: Optional[int] = None,
|
21 |
+
normalize: Optional[bool] = True,
|
22 |
+
trim_silence: Optional[bool] = True):
|
23 |
+
"""
|
24 |
+
Applies the preprocessing operations used in training the Speaker Encoder to a waveform
|
25 |
+
either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
|
26 |
+
|
27 |
+
:param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
|
28 |
+
just .wav), either the waveform as a numpy array of floats.
|
29 |
+
:param source_sr: if passing an audio waveform, the sampling rate of the waveform before
|
30 |
+
preprocessing. After preprocessing, the waveform's sampling rate will match the data
|
31 |
+
hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
|
32 |
+
this argument will be ignored.
|
33 |
+
"""
|
34 |
+
# Load the wav from disk if needed
|
35 |
+
if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
|
36 |
+
wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
|
37 |
+
else:
|
38 |
+
wav = fpath_or_wav
|
39 |
+
|
40 |
+
# Resample the wav if needed
|
41 |
+
if source_sr is not None and source_sr != sampling_rate:
|
42 |
+
wav = librosa.resample(wav, source_sr, sampling_rate)
|
43 |
+
|
44 |
+
# Apply the preprocessing: normalize volume and shorten long silences
|
45 |
+
if normalize:
|
46 |
+
wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
|
47 |
+
if webrtcvad and trim_silence:
|
48 |
+
wav = trim_long_silences(wav)
|
49 |
+
|
50 |
+
return wav
|
51 |
+
|
52 |
+
|
53 |
+
def wav_to_mel_spectrogram(wav):
|
54 |
+
"""
|
55 |
+
Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
|
56 |
+
Note: this not a log-mel spectrogram.
|
57 |
+
"""
|
58 |
+
frames = librosa.feature.melspectrogram(
|
59 |
+
wav,
|
60 |
+
sampling_rate,
|
61 |
+
n_fft=int(sampling_rate * mel_window_length / 1000),
|
62 |
+
hop_length=int(sampling_rate * mel_window_step / 1000),
|
63 |
+
n_mels=mel_n_channels
|
64 |
+
)
|
65 |
+
return frames.astype(np.float32).T
|
66 |
+
|
67 |
+
|
68 |
+
def trim_long_silences(wav):
|
69 |
+
"""
|
70 |
+
Ensures that segments without voice in the waveform remain no longer than a
|
71 |
+
threshold determined by the VAD parameters in params.py.
|
72 |
+
|
73 |
+
:param wav: the raw waveform as a numpy array of floats
|
74 |
+
:return: the same waveform with silences trimmed away (length <= original wav length)
|
75 |
+
"""
|
76 |
+
# Compute the voice detection window size
|
77 |
+
samples_per_window = (vad_window_length * sampling_rate) // 1000
|
78 |
+
|
79 |
+
# Trim the end of the audio to have a multiple of the window size
|
80 |
+
wav = wav[:len(wav) - (len(wav) % samples_per_window)]
|
81 |
+
|
82 |
+
# Convert the float waveform to 16-bit mono PCM
|
83 |
+
pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
|
84 |
+
|
85 |
+
# Perform voice activation detection
|
86 |
+
voice_flags = []
|
87 |
+
vad = webrtcvad.Vad(mode=3)
|
88 |
+
for window_start in range(0, len(wav), samples_per_window):
|
89 |
+
window_end = window_start + samples_per_window
|
90 |
+
voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
|
91 |
+
sample_rate=sampling_rate))
|
92 |
+
voice_flags = np.array(voice_flags)
|
93 |
+
|
94 |
+
# Smooth the voice detection with a moving average
|
95 |
+
def moving_average(array, width):
|
96 |
+
array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
|
97 |
+
ret = np.cumsum(array_padded, dtype=float)
|
98 |
+
ret[width:] = ret[width:] - ret[:-width]
|
99 |
+
return ret[width - 1:] / width
|
100 |
+
|
101 |
+
audio_mask = moving_average(voice_flags, vad_moving_average_width)
|
102 |
+
audio_mask = np.round(audio_mask).astype(np.bool)
|
103 |
+
|
104 |
+
# Dilate the voiced regions
|
105 |
+
audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
|
106 |
+
audio_mask = np.repeat(audio_mask, samples_per_window)
|
107 |
+
|
108 |
+
return wav[audio_mask == True]
|
109 |
+
|
110 |
+
|
111 |
+
def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
|
112 |
+
if increase_only and decrease_only:
|
113 |
+
raise ValueError("Both increase only and decrease only are set")
|
114 |
+
dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
|
115 |
+
if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
|
116 |
+
return wav
|
117 |
+
return wav * (10 ** (dBFS_change / 20))
|
encoder/config.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
librispeech_datasets = {
|
2 |
+
"train": {
|
3 |
+
"clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
|
4 |
+
"other": ["LibriSpeech/train-other-500"]
|
5 |
+
},
|
6 |
+
"test": {
|
7 |
+
"clean": ["LibriSpeech/test-clean"],
|
8 |
+
"other": ["LibriSpeech/test-other"]
|
9 |
+
},
|
10 |
+
"dev": {
|
11 |
+
"clean": ["LibriSpeech/dev-clean"],
|
12 |
+
"other": ["LibriSpeech/dev-other"]
|
13 |
+
},
|
14 |
+
}
|
15 |
+
libritts_datasets = {
|
16 |
+
"train": {
|
17 |
+
"clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
|
18 |
+
"other": ["LibriTTS/train-other-500"]
|
19 |
+
},
|
20 |
+
"test": {
|
21 |
+
"clean": ["LibriTTS/test-clean"],
|
22 |
+
"other": ["LibriTTS/test-other"]
|
23 |
+
},
|
24 |
+
"dev": {
|
25 |
+
"clean": ["LibriTTS/dev-clean"],
|
26 |
+
"other": ["LibriTTS/dev-other"]
|
27 |
+
},
|
28 |
+
}
|
29 |
+
voxceleb_datasets = {
|
30 |
+
"voxceleb1" : {
|
31 |
+
"train": ["VoxCeleb1/wav"],
|
32 |
+
"test": ["VoxCeleb1/test_wav"]
|
33 |
+
},
|
34 |
+
"voxceleb2" : {
|
35 |
+
"train": ["VoxCeleb2/dev/aac"],
|
36 |
+
"test": ["VoxCeleb2/test_wav"]
|
37 |
+
}
|
38 |
+
}
|
39 |
+
|
40 |
+
other_datasets = [
|
41 |
+
"LJSpeech-1.1",
|
42 |
+
"VCTK-Corpus/wav48",
|
43 |
+
]
|
44 |
+
|
45 |
+
anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
|
encoder/data_objects/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
|
2 |
+
from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
|
encoder/data_objects/random_cycler.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
class RandomCycler:
|
4 |
+
"""
|
5 |
+
Creates an internal copy of a sequence and allows access to its items in a constrained random
|
6 |
+
order. For a source sequence of n items and one or several consecutive queries of a total
|
7 |
+
of m items, the following guarantees hold (one implies the other):
|
8 |
+
- Each item will be returned between m // n and ((m - 1) // n) + 1 times.
|
9 |
+
- Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, source):
|
13 |
+
if len(source) == 0:
|
14 |
+
raise Exception("Can't create RandomCycler from an empty collection")
|
15 |
+
self.all_items = list(source)
|
16 |
+
self.next_items = []
|
17 |
+
|
18 |
+
def sample(self, count: int):
|
19 |
+
shuffle = lambda l: random.sample(l, len(l))
|
20 |
+
|
21 |
+
out = []
|
22 |
+
while count > 0:
|
23 |
+
if count >= len(self.all_items):
|
24 |
+
out.extend(shuffle(list(self.all_items)))
|
25 |
+
count -= len(self.all_items)
|
26 |
+
continue
|
27 |
+
n = min(count, len(self.next_items))
|
28 |
+
out.extend(self.next_items[:n])
|
29 |
+
count -= n
|
30 |
+
self.next_items = self.next_items[n:]
|
31 |
+
if len(self.next_items) == 0:
|
32 |
+
self.next_items = shuffle(list(self.all_items))
|
33 |
+
return out
|
34 |
+
|
35 |
+
def __next__(self):
|
36 |
+
return self.sample(1)[0]
|
37 |
+
|
encoder/data_objects/speaker.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from encoder.data_objects.random_cycler import RandomCycler
|
2 |
+
from encoder.data_objects.utterance import Utterance
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
# Contains the set of utterances of a single speaker
|
6 |
+
class Speaker:
|
7 |
+
def __init__(self, root: Path):
|
8 |
+
self.root = root
|
9 |
+
self.name = root.name
|
10 |
+
self.utterances = None
|
11 |
+
self.utterance_cycler = None
|
12 |
+
|
13 |
+
def _load_utterances(self):
|
14 |
+
with self.root.joinpath("_sources.txt").open("r") as sources_file:
|
15 |
+
sources = [l.split(",") for l in sources_file]
|
16 |
+
sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
|
17 |
+
self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
|
18 |
+
self.utterance_cycler = RandomCycler(self.utterances)
|
19 |
+
|
20 |
+
def random_partial(self, count, n_frames):
|
21 |
+
"""
|
22 |
+
Samples a batch of <count> unique partial utterances from the disk in a way that all
|
23 |
+
utterances come up at least once every two cycles and in a random order every time.
|
24 |
+
|
25 |
+
:param count: The number of partial utterances to sample from the set of utterances from
|
26 |
+
that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
|
27 |
+
the number of utterances available.
|
28 |
+
:param n_frames: The number of frames in the partial utterance.
|
29 |
+
:return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
|
30 |
+
frames are the frames of the partial utterances and range is the range of the partial
|
31 |
+
utterance with regard to the complete utterance.
|
32 |
+
"""
|
33 |
+
if self.utterances is None:
|
34 |
+
self._load_utterances()
|
35 |
+
|
36 |
+
utterances = self.utterance_cycler.sample(count)
|
37 |
+
|
38 |
+
a = [(u,) + u.random_partial(n_frames) for u in utterances]
|
39 |
+
|
40 |
+
return a
|
encoder/data_objects/speaker_batch.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from typing import List
|
3 |
+
from encoder.data_objects.speaker import Speaker
|
4 |
+
|
5 |
+
class SpeakerBatch:
|
6 |
+
def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
|
7 |
+
self.speakers = speakers
|
8 |
+
self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
|
9 |
+
|
10 |
+
# Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
|
11 |
+
# 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
|
12 |
+
self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
|
encoder/data_objects/speaker_verification_dataset.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from encoder.data_objects.random_cycler import RandomCycler
|
2 |
+
from encoder.data_objects.speaker_batch import SpeakerBatch
|
3 |
+
from encoder.data_objects.speaker import Speaker
|
4 |
+
from encoder.params_data import partials_n_frames
|
5 |
+
from torch.utils.data import Dataset, DataLoader
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
# TODO: improve with a pool of speakers for data efficiency
|
9 |
+
|
10 |
+
class SpeakerVerificationDataset(Dataset):
|
11 |
+
def __init__(self, datasets_root: Path):
|
12 |
+
self.root = datasets_root
|
13 |
+
speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
|
14 |
+
if len(speaker_dirs) == 0:
|
15 |
+
raise Exception("No speakers found. Make sure you are pointing to the directory "
|
16 |
+
"containing all preprocessed speaker directories.")
|
17 |
+
self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
|
18 |
+
self.speaker_cycler = RandomCycler(self.speakers)
|
19 |
+
|
20 |
+
def __len__(self):
|
21 |
+
return int(1e10)
|
22 |
+
|
23 |
+
def __getitem__(self, index):
|
24 |
+
return next(self.speaker_cycler)
|
25 |
+
|
26 |
+
def get_logs(self):
|
27 |
+
log_string = ""
|
28 |
+
for log_fpath in self.root.glob("*.txt"):
|
29 |
+
with log_fpath.open("r") as log_file:
|
30 |
+
log_string += "".join(log_file.readlines())
|
31 |
+
return log_string
|
32 |
+
|
33 |
+
|
34 |
+
class SpeakerVerificationDataLoader(DataLoader):
|
35 |
+
def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
|
36 |
+
batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
|
37 |
+
worker_init_fn=None):
|
38 |
+
self.utterances_per_speaker = utterances_per_speaker
|
39 |
+
|
40 |
+
super().__init__(
|
41 |
+
dataset=dataset,
|
42 |
+
batch_size=speakers_per_batch,
|
43 |
+
shuffle=False,
|
44 |
+
sampler=sampler,
|
45 |
+
batch_sampler=batch_sampler,
|
46 |
+
num_workers=num_workers,
|
47 |
+
collate_fn=self.collate,
|
48 |
+
pin_memory=pin_memory,
|
49 |
+
drop_last=False,
|
50 |
+
timeout=timeout,
|
51 |
+
worker_init_fn=worker_init_fn
|
52 |
+
)
|
53 |
+
|
54 |
+
def collate(self, speakers):
|
55 |
+
return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
|
56 |
+
|
encoder/data_objects/utterance.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class Utterance:
|
5 |
+
def __init__(self, frames_fpath, wave_fpath):
|
6 |
+
self.frames_fpath = frames_fpath
|
7 |
+
self.wave_fpath = wave_fpath
|
8 |
+
|
9 |
+
def get_frames(self):
|
10 |
+
return np.load(self.frames_fpath)
|
11 |
+
|
12 |
+
def random_partial(self, n_frames):
|
13 |
+
"""
|
14 |
+
Crops the frames into a partial utterance of n_frames
|
15 |
+
|
16 |
+
:param n_frames: The number of frames of the partial utterance
|
17 |
+
:return: the partial utterance frames and a tuple indicating the start and end of the
|
18 |
+
partial utterance in the complete utterance.
|
19 |
+
"""
|
20 |
+
frames = self.get_frames()
|
21 |
+
if frames.shape[0] == n_frames:
|
22 |
+
start = 0
|
23 |
+
else:
|
24 |
+
start = np.random.randint(0, frames.shape[0] - n_frames)
|
25 |
+
end = start + n_frames
|
26 |
+
return frames[start:end], (start, end)
|
encoder/inference.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from encoder.params_data import *
|
2 |
+
from encoder.model import SpeakerEncoder
|
3 |
+
from encoder.audio import preprocess_wav # We want to expose this function from here
|
4 |
+
from matplotlib import cm
|
5 |
+
from encoder import audio
|
6 |
+
from pathlib import Path
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
_model = None # type: SpeakerEncoder
|
12 |
+
_device = None # type: torch.device
|
13 |
+
|
14 |
+
|
15 |
+
def load_model(weights_fpath: Path, device=None):
|
16 |
+
"""
|
17 |
+
Loads the model in memory. If this function is not explicitely called, it will be run on the
|
18 |
+
first call to embed_frames() with the default weights file.
|
19 |
+
|
20 |
+
:param weights_fpath: the path to saved model weights.
|
21 |
+
:param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
|
22 |
+
model will be loaded and will run on this device. Outputs will however always be on the cpu.
|
23 |
+
If None, will default to your GPU if it"s available, otherwise your CPU.
|
24 |
+
"""
|
25 |
+
# TODO: I think the slow loading of the encoder might have something to do with the device it
|
26 |
+
# was saved on. Worth investigating.
|
27 |
+
global _model, _device
|
28 |
+
if device is None:
|
29 |
+
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
+
elif isinstance(device, str):
|
31 |
+
_device = torch.device(device)
|
32 |
+
_model = SpeakerEncoder(_device, torch.device("cpu"))
|
33 |
+
checkpoint = torch.load(weights_fpath, _device)
|
34 |
+
_model.load_state_dict(checkpoint["model_state"])
|
35 |
+
_model.eval()
|
36 |
+
print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
|
37 |
+
|
38 |
+
|
39 |
+
def is_loaded():
|
40 |
+
return _model is not None
|
41 |
+
|
42 |
+
|
43 |
+
def embed_frames_batch(frames_batch):
|
44 |
+
"""
|
45 |
+
Computes embeddings for a batch of mel spectrogram.
|
46 |
+
|
47 |
+
:param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
|
48 |
+
(batch_size, n_frames, n_channels)
|
49 |
+
:return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
|
50 |
+
"""
|
51 |
+
if _model is None:
|
52 |
+
raise Exception("Model was not loaded. Call load_model() before inference.")
|
53 |
+
|
54 |
+
frames = torch.from_numpy(frames_batch).to(_device)
|
55 |
+
embed = _model.forward(frames).detach().cpu().numpy()
|
56 |
+
return embed
|
57 |
+
|
58 |
+
|
59 |
+
def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
|
60 |
+
min_pad_coverage=0.75, overlap=0.5):
|
61 |
+
"""
|
62 |
+
Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
|
63 |
+
partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
|
64 |
+
spectrogram slices are returned, so as to make each partial utterance waveform correspond to
|
65 |
+
its spectrogram. This function assumes that the mel spectrogram parameters used are those
|
66 |
+
defined in params_data.py.
|
67 |
+
|
68 |
+
The returned ranges may be indexing further than the length of the waveform. It is
|
69 |
+
recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
|
70 |
+
|
71 |
+
:param n_samples: the number of samples in the waveform
|
72 |
+
:param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
|
73 |
+
utterance
|
74 |
+
:param min_pad_coverage: when reaching the last partial utterance, it may or may not have
|
75 |
+
enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
|
76 |
+
then the last partial utterance will be considered, as if we padded the audio. Otherwise,
|
77 |
+
it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
|
78 |
+
utterance, this parameter is ignored so that the function always returns at least 1 slice.
|
79 |
+
:param overlap: by how much the partial utterance should overlap. If set to 0, the partial
|
80 |
+
utterances are entirely disjoint.
|
81 |
+
:return: the waveform slices and mel spectrogram slices as lists of array slices. Index
|
82 |
+
respectively the waveform and the mel spectrogram with these slices to obtain the partial
|
83 |
+
utterances.
|
84 |
+
"""
|
85 |
+
assert 0 <= overlap < 1
|
86 |
+
assert 0 < min_pad_coverage <= 1
|
87 |
+
|
88 |
+
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
89 |
+
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
90 |
+
frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
|
91 |
+
|
92 |
+
# Compute the slices
|
93 |
+
wav_slices, mel_slices = [], []
|
94 |
+
steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
|
95 |
+
for i in range(0, steps, frame_step):
|
96 |
+
mel_range = np.array([i, i + partial_utterance_n_frames])
|
97 |
+
wav_range = mel_range * samples_per_frame
|
98 |
+
mel_slices.append(slice(*mel_range))
|
99 |
+
wav_slices.append(slice(*wav_range))
|
100 |
+
|
101 |
+
# Evaluate whether extra padding is warranted or not
|
102 |
+
last_wav_range = wav_slices[-1]
|
103 |
+
coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
|
104 |
+
if coverage < min_pad_coverage and len(mel_slices) > 1:
|
105 |
+
mel_slices = mel_slices[:-1]
|
106 |
+
wav_slices = wav_slices[:-1]
|
107 |
+
|
108 |
+
return wav_slices, mel_slices
|
109 |
+
|
110 |
+
|
111 |
+
def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
|
112 |
+
"""
|
113 |
+
Computes an embedding for a single utterance.
|
114 |
+
|
115 |
+
# TODO: handle multiple wavs to benefit from batching on GPU
|
116 |
+
:param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
|
117 |
+
:param using_partials: if True, then the utterance is split in partial utterances of
|
118 |
+
<partial_utterance_n_frames> frames and the utterance embedding is computed from their
|
119 |
+
normalized average. If False, the utterance is instead computed from feeding the entire
|
120 |
+
spectogram to the network.
|
121 |
+
:param return_partials: if True, the partial embeddings will also be returned along with the
|
122 |
+
wav slices that correspond to the partial embeddings.
|
123 |
+
:param kwargs: additional arguments to compute_partial_splits()
|
124 |
+
:return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
|
125 |
+
<return_partials> is True, the partial utterances as a numpy array of float32 of shape
|
126 |
+
(n_partials, model_embedding_size) and the wav partials as a list of slices will also be
|
127 |
+
returned. If <using_partials> is simultaneously set to False, both these values will be None
|
128 |
+
instead.
|
129 |
+
"""
|
130 |
+
# Process the entire utterance if not using partials
|
131 |
+
if not using_partials:
|
132 |
+
frames = audio.wav_to_mel_spectrogram(wav)
|
133 |
+
embed = embed_frames_batch(frames[None, ...])[0]
|
134 |
+
if return_partials:
|
135 |
+
return embed, None, None
|
136 |
+
return embed
|
137 |
+
|
138 |
+
# Compute where to split the utterance into partials and pad if necessary
|
139 |
+
wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
|
140 |
+
max_wave_length = wave_slices[-1].stop
|
141 |
+
if max_wave_length >= len(wav):
|
142 |
+
wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
|
143 |
+
|
144 |
+
# Split the utterance into partials
|
145 |
+
frames = audio.wav_to_mel_spectrogram(wav)
|
146 |
+
frames_batch = np.array([frames[s] for s in mel_slices])
|
147 |
+
partial_embeds = embed_frames_batch(frames_batch)
|
148 |
+
|
149 |
+
# Compute the utterance embedding from the partial embeddings
|
150 |
+
raw_embed = np.mean(partial_embeds, axis=0)
|
151 |
+
embed = raw_embed / np.linalg.norm(raw_embed, 2)
|
152 |
+
|
153 |
+
if return_partials:
|
154 |
+
return embed, partial_embeds, wave_slices
|
155 |
+
return embed
|
156 |
+
|
157 |
+
|
158 |
+
def embed_speaker(wavs, **kwargs):
|
159 |
+
raise NotImplemented()
|
160 |
+
|
161 |
+
|
162 |
+
def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
|
163 |
+
if ax is None:
|
164 |
+
ax = plt.gca()
|
165 |
+
|
166 |
+
if shape is None:
|
167 |
+
height = int(np.sqrt(len(embed)))
|
168 |
+
shape = (height, -1)
|
169 |
+
embed = embed.reshape(shape)
|
170 |
+
|
171 |
+
cmap = cm.get_cmap()
|
172 |
+
mappable = ax.imshow(embed, cmap=cmap)
|
173 |
+
cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
|
174 |
+
sm = cm.ScalarMappable(cmap=cmap)
|
175 |
+
sm.set_clim(*color_range)
|
176 |
+
|
177 |
+
ax.set_xticks([]), ax.set_yticks([])
|
178 |
+
ax.set_title(title)
|
encoder/model.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from encoder.params_model import *
|
2 |
+
from encoder.params_data import *
|
3 |
+
from scipy.interpolate import interp1d
|
4 |
+
from sklearn.metrics import roc_curve
|
5 |
+
from torch.nn.utils import clip_grad_norm_
|
6 |
+
from scipy.optimize import brentq
|
7 |
+
from torch import nn
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
class SpeakerEncoder(nn.Module):
|
13 |
+
def __init__(self, device, loss_device):
|
14 |
+
super().__init__()
|
15 |
+
self.loss_device = loss_device
|
16 |
+
|
17 |
+
# Network defition
|
18 |
+
self.lstm = nn.LSTM(input_size=mel_n_channels,
|
19 |
+
hidden_size=model_hidden_size,
|
20 |
+
num_layers=model_num_layers,
|
21 |
+
batch_first=True).to(device)
|
22 |
+
self.linear = nn.Linear(in_features=model_hidden_size,
|
23 |
+
out_features=model_embedding_size).to(device)
|
24 |
+
self.relu = torch.nn.ReLU().to(device)
|
25 |
+
|
26 |
+
# Cosine similarity scaling (with fixed initial parameter values)
|
27 |
+
self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
|
28 |
+
self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
|
29 |
+
|
30 |
+
# Loss
|
31 |
+
self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
|
32 |
+
|
33 |
+
def do_gradient_ops(self):
|
34 |
+
# Gradient scale
|
35 |
+
self.similarity_weight.grad *= 0.01
|
36 |
+
self.similarity_bias.grad *= 0.01
|
37 |
+
|
38 |
+
# Gradient clipping
|
39 |
+
clip_grad_norm_(self.parameters(), 3, norm_type=2)
|
40 |
+
|
41 |
+
def forward(self, utterances, hidden_init=None):
|
42 |
+
"""
|
43 |
+
Computes the embeddings of a batch of utterance spectrograms.
|
44 |
+
|
45 |
+
:param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
|
46 |
+
(batch_size, n_frames, n_channels)
|
47 |
+
:param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
|
48 |
+
batch_size, hidden_size). Will default to a tensor of zeros if None.
|
49 |
+
:return: the embeddings as a tensor of shape (batch_size, embedding_size)
|
50 |
+
"""
|
51 |
+
# Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
|
52 |
+
# and the final cell state.
|
53 |
+
out, (hidden, cell) = self.lstm(utterances, hidden_init)
|
54 |
+
|
55 |
+
# We take only the hidden state of the last layer
|
56 |
+
embeds_raw = self.relu(self.linear(hidden[-1]))
|
57 |
+
|
58 |
+
# L2-normalize it
|
59 |
+
embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)
|
60 |
+
|
61 |
+
return embeds
|
62 |
+
|
63 |
+
def similarity_matrix(self, embeds):
|
64 |
+
"""
|
65 |
+
Computes the similarity matrix according the section 2.1 of GE2E.
|
66 |
+
|
67 |
+
:param embeds: the embeddings as a tensor of shape (speakers_per_batch,
|
68 |
+
utterances_per_speaker, embedding_size)
|
69 |
+
:return: the similarity matrix as a tensor of shape (speakers_per_batch,
|
70 |
+
utterances_per_speaker, speakers_per_batch)
|
71 |
+
"""
|
72 |
+
speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
|
73 |
+
|
74 |
+
# Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
|
75 |
+
centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
|
76 |
+
centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)
|
77 |
+
|
78 |
+
# Exclusive centroids (1 per utterance)
|
79 |
+
centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
|
80 |
+
centroids_excl /= (utterances_per_speaker - 1)
|
81 |
+
centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)
|
82 |
+
|
83 |
+
# Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
|
84 |
+
# product of these vectors (which is just an element-wise multiplication reduced by a sum).
|
85 |
+
# We vectorize the computation for efficiency.
|
86 |
+
sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
|
87 |
+
speakers_per_batch).to(self.loss_device)
|
88 |
+
mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
|
89 |
+
for j in range(speakers_per_batch):
|
90 |
+
mask = np.where(mask_matrix[j])[0]
|
91 |
+
sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
|
92 |
+
sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
|
93 |
+
|
94 |
+
## Even more vectorized version (slower maybe because of transpose)
|
95 |
+
# sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
|
96 |
+
# ).to(self.loss_device)
|
97 |
+
# eye = np.eye(speakers_per_batch, dtype=np.int)
|
98 |
+
# mask = np.where(1 - eye)
|
99 |
+
# sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
|
100 |
+
# mask = np.where(eye)
|
101 |
+
# sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
|
102 |
+
# sim_matrix2 = sim_matrix2.transpose(1, 2)
|
103 |
+
|
104 |
+
sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
|
105 |
+
return sim_matrix
|
106 |
+
|
107 |
+
def loss(self, embeds):
|
108 |
+
"""
|
109 |
+
Computes the softmax loss according the section 2.1 of GE2E.
|
110 |
+
|
111 |
+
:param embeds: the embeddings as a tensor of shape (speakers_per_batch,
|
112 |
+
utterances_per_speaker, embedding_size)
|
113 |
+
:return: the loss and the EER for this batch of embeddings.
|
114 |
+
"""
|
115 |
+
speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
|
116 |
+
|
117 |
+
# Loss
|
118 |
+
sim_matrix = self.similarity_matrix(embeds)
|
119 |
+
sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
|
120 |
+
speakers_per_batch))
|
121 |
+
ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
|
122 |
+
target = torch.from_numpy(ground_truth).long().to(self.loss_device)
|
123 |
+
loss = self.loss_fn(sim_matrix, target)
|
124 |
+
|
125 |
+
# EER (not backpropagated)
|
126 |
+
with torch.no_grad():
|
127 |
+
inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
|
128 |
+
labels = np.array([inv_argmax(i) for i in ground_truth])
|
129 |
+
preds = sim_matrix.detach().cpu().numpy()
|
130 |
+
|
131 |
+
# Snippet from https://yangcha.github.io/EER-ROC/
|
132 |
+
fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
|
133 |
+
eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
|
134 |
+
|
135 |
+
return loss, eer
|
encoder/params_data.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
## Mel-filterbank
|
3 |
+
mel_window_length = 25 # In milliseconds
|
4 |
+
mel_window_step = 10 # In milliseconds
|
5 |
+
mel_n_channels = 40
|
6 |
+
|
7 |
+
|
8 |
+
## Audio
|
9 |
+
sampling_rate = 16000
|
10 |
+
# Number of spectrogram frames in a partial utterance
|
11 |
+
partials_n_frames = 160 # 1600 ms
|
12 |
+
# Number of spectrogram frames at inference
|
13 |
+
inference_n_frames = 80 # 800 ms
|
14 |
+
|
15 |
+
|
16 |
+
## Voice Activation Detection
|
17 |
+
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
|
18 |
+
# This sets the granularity of the VAD. Should not need to be changed.
|
19 |
+
vad_window_length = 30 # In milliseconds
|
20 |
+
# Number of frames to average together when performing the moving average smoothing.
|
21 |
+
# The larger this value, the larger the VAD variations must be to not get smoothed out.
|
22 |
+
vad_moving_average_width = 8
|
23 |
+
# Maximum number of consecutive silent frames a segment can have.
|
24 |
+
vad_max_silence_length = 6
|
25 |
+
|
26 |
+
|
27 |
+
## Audio volume normalization
|
28 |
+
audio_norm_target_dBFS = -30
|
29 |
+
|
encoder/params_model.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
## Model parameters
|
3 |
+
model_hidden_size = 256
|
4 |
+
model_embedding_size = 256
|
5 |
+
model_num_layers = 3
|
6 |
+
|
7 |
+
|
8 |
+
## Training parameters
|
9 |
+
learning_rate_init = 1e-4
|
10 |
+
speakers_per_batch = 64
|
11 |
+
utterances_per_speaker = 10
|
encoder/preprocess.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from multiprocess.pool import ThreadPool
|
2 |
+
from encoder.params_data import *
|
3 |
+
from encoder.config import librispeech_datasets, anglophone_nationalites
|
4 |
+
from datetime import datetime
|
5 |
+
from encoder import audio
|
6 |
+
from pathlib import Path
|
7 |
+
from tqdm import tqdm
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
|
11 |
+
class DatasetLog:
|
12 |
+
"""
|
13 |
+
Registers metadata about the dataset in a text file.
|
14 |
+
"""
|
15 |
+
def __init__(self, root, name):
|
16 |
+
self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
|
17 |
+
self.sample_data = dict()
|
18 |
+
|
19 |
+
start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
|
20 |
+
self.write_line("Creating dataset %s on %s" % (name, start_time))
|
21 |
+
self.write_line("-----")
|
22 |
+
self._log_params()
|
23 |
+
|
24 |
+
def _log_params(self):
|
25 |
+
from encoder import params_data
|
26 |
+
self.write_line("Parameter values:")
|
27 |
+
for param_name in (p for p in dir(params_data) if not p.startswith("__")):
|
28 |
+
value = getattr(params_data, param_name)
|
29 |
+
self.write_line("\t%s: %s" % (param_name, value))
|
30 |
+
self.write_line("-----")
|
31 |
+
|
32 |
+
def write_line(self, line):
|
33 |
+
self.text_file.write("%s\n" % line)
|
34 |
+
|
35 |
+
def add_sample(self, **kwargs):
|
36 |
+
for param_name, value in kwargs.items():
|
37 |
+
if not param_name in self.sample_data:
|
38 |
+
self.sample_data[param_name] = []
|
39 |
+
self.sample_data[param_name].append(value)
|
40 |
+
|
41 |
+
def finalize(self):
|
42 |
+
self.write_line("Statistics:")
|
43 |
+
for param_name, values in self.sample_data.items():
|
44 |
+
self.write_line("\t%s:" % param_name)
|
45 |
+
self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
|
46 |
+
self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
|
47 |
+
self.write_line("-----")
|
48 |
+
end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
|
49 |
+
self.write_line("Finished on %s" % end_time)
|
50 |
+
self.text_file.close()
|
51 |
+
|
52 |
+
|
53 |
+
def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
|
54 |
+
dataset_root = datasets_root.joinpath(dataset_name)
|
55 |
+
if not dataset_root.exists():
|
56 |
+
print("Couldn\'t find %s, skipping this dataset." % dataset_root)
|
57 |
+
return None, None
|
58 |
+
return dataset_root, DatasetLog(out_dir, dataset_name)
|
59 |
+
|
60 |
+
|
61 |
+
def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
|
62 |
+
skip_existing, logger):
|
63 |
+
print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
|
64 |
+
|
65 |
+
# Function to preprocess utterances for one speaker
|
66 |
+
def preprocess_speaker(speaker_dir: Path):
|
67 |
+
# Give a name to the speaker that includes its dataset
|
68 |
+
speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
|
69 |
+
|
70 |
+
# Create an output directory with that name, as well as a txt file containing a
|
71 |
+
# reference to each source file.
|
72 |
+
speaker_out_dir = out_dir.joinpath(speaker_name)
|
73 |
+
speaker_out_dir.mkdir(exist_ok=True)
|
74 |
+
sources_fpath = speaker_out_dir.joinpath("_sources.txt")
|
75 |
+
|
76 |
+
# There's a possibility that the preprocessing was interrupted earlier, check if
|
77 |
+
# there already is a sources file.
|
78 |
+
if sources_fpath.exists():
|
79 |
+
try:
|
80 |
+
with sources_fpath.open("r") as sources_file:
|
81 |
+
existing_fnames = {line.split(",")[0] for line in sources_file}
|
82 |
+
except:
|
83 |
+
existing_fnames = {}
|
84 |
+
else:
|
85 |
+
existing_fnames = {}
|
86 |
+
|
87 |
+
# Gather all audio files for that speaker recursively
|
88 |
+
sources_file = sources_fpath.open("a" if skip_existing else "w")
|
89 |
+
for in_fpath in speaker_dir.glob("**/*.%s" % extension):
|
90 |
+
# Check if the target output file already exists
|
91 |
+
out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
|
92 |
+
out_fname = out_fname.replace(".%s" % extension, ".npy")
|
93 |
+
if skip_existing and out_fname in existing_fnames:
|
94 |
+
continue
|
95 |
+
|
96 |
+
# Load and preprocess the waveform
|
97 |
+
wav = audio.preprocess_wav(in_fpath)
|
98 |
+
if len(wav) == 0:
|
99 |
+
continue
|
100 |
+
|
101 |
+
# Create the mel spectrogram, discard those that are too short
|
102 |
+
frames = audio.wav_to_mel_spectrogram(wav)
|
103 |
+
if len(frames) < partials_n_frames:
|
104 |
+
continue
|
105 |
+
|
106 |
+
out_fpath = speaker_out_dir.joinpath(out_fname)
|
107 |
+
np.save(out_fpath, frames)
|
108 |
+
logger.add_sample(duration=len(wav) / sampling_rate)
|
109 |
+
sources_file.write("%s,%s\n" % (out_fname, in_fpath))
|
110 |
+
|
111 |
+
sources_file.close()
|
112 |
+
|
113 |
+
# Process the utterances for each speaker
|
114 |
+
with ThreadPool(8) as pool:
|
115 |
+
list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
|
116 |
+
unit="speakers"))
|
117 |
+
logger.finalize()
|
118 |
+
print("Done preprocessing %s.\n" % dataset_name)
|
119 |
+
|
120 |
+
|
121 |
+
def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
|
122 |
+
for dataset_name in librispeech_datasets["train"]["other"]:
|
123 |
+
# Initialize the preprocessing
|
124 |
+
dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
|
125 |
+
if not dataset_root:
|
126 |
+
return
|
127 |
+
|
128 |
+
# Preprocess all speakers
|
129 |
+
speaker_dirs = list(dataset_root.glob("*"))
|
130 |
+
_preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac",
|
131 |
+
skip_existing, logger)
|
132 |
+
|
133 |
+
|
134 |
+
def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
|
135 |
+
# Initialize the preprocessing
|
136 |
+
dataset_name = "VoxCeleb1"
|
137 |
+
dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
|
138 |
+
if not dataset_root:
|
139 |
+
return
|
140 |
+
|
141 |
+
# Get the contents of the meta file
|
142 |
+
with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
|
143 |
+
metadata = [line.split("\t") for line in metafile][1:]
|
144 |
+
|
145 |
+
# Select the ID and the nationality, filter out non-anglophone speakers
|
146 |
+
nationalities = {line[0]: line[3] for line in metadata}
|
147 |
+
keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
|
148 |
+
nationality.lower() in anglophone_nationalites]
|
149 |
+
print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
|
150 |
+
(len(keep_speaker_ids), len(nationalities)))
|
151 |
+
|
152 |
+
# Get the speaker directories for anglophone speakers only
|
153 |
+
speaker_dirs = dataset_root.joinpath("wav").glob("*")
|
154 |
+
speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
|
155 |
+
speaker_dir.name in keep_speaker_ids]
|
156 |
+
print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
|
157 |
+
(len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
|
158 |
+
|
159 |
+
# Preprocess all speakers
|
160 |
+
_preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
|
161 |
+
skip_existing, logger)
|
162 |
+
|
163 |
+
|
164 |
+
def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
|
165 |
+
# Initialize the preprocessing
|
166 |
+
dataset_name = "VoxCeleb2"
|
167 |
+
dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
|
168 |
+
if not dataset_root:
|
169 |
+
return
|
170 |
+
|
171 |
+
# Get the speaker directories
|
172 |
+
# Preprocess all speakers
|
173 |
+
speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
|
174 |
+
_preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a",
|
175 |
+
skip_existing, logger)
|
encoder/train.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from encoder.visualizations import Visualizations
|
2 |
+
from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
|
3 |
+
from encoder.params_model import *
|
4 |
+
from encoder.model import SpeakerEncoder
|
5 |
+
from utils.profiler import Profiler
|
6 |
+
from pathlib import Path
|
7 |
+
import torch
|
8 |
+
|
9 |
+
def sync(device: torch.device):
|
10 |
+
# For correct profiling (cuda operations are async)
|
11 |
+
if device.type == "cuda":
|
12 |
+
torch.cuda.synchronize(device)
|
13 |
+
|
14 |
+
|
15 |
+
def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
|
16 |
+
backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
|
17 |
+
no_visdom: bool):
|
18 |
+
# Create a dataset and a dataloader
|
19 |
+
dataset = SpeakerVerificationDataset(clean_data_root)
|
20 |
+
loader = SpeakerVerificationDataLoader(
|
21 |
+
dataset,
|
22 |
+
speakers_per_batch,
|
23 |
+
utterances_per_speaker,
|
24 |
+
num_workers=8,
|
25 |
+
)
|
26 |
+
|
27 |
+
# Setup the device on which to run the forward pass and the loss. These can be different,
|
28 |
+
# because the forward pass is faster on the GPU whereas the loss is often (depending on your
|
29 |
+
# hyperparameters) faster on the CPU.
|
30 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
31 |
+
# FIXME: currently, the gradient is None if loss_device is cuda
|
32 |
+
loss_device = torch.device("cpu")
|
33 |
+
|
34 |
+
# Create the model and the optimizer
|
35 |
+
model = SpeakerEncoder(device, loss_device)
|
36 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
|
37 |
+
init_step = 1
|
38 |
+
|
39 |
+
# Configure file path for the model
|
40 |
+
state_fpath = models_dir.joinpath(run_id + ".pt")
|
41 |
+
backup_dir = models_dir.joinpath(run_id + "_backups")
|
42 |
+
|
43 |
+
# Load any existing model
|
44 |
+
if not force_restart:
|
45 |
+
if state_fpath.exists():
|
46 |
+
print("Found existing model \"%s\", loading it and resuming training." % run_id)
|
47 |
+
checkpoint = torch.load(state_fpath)
|
48 |
+
init_step = checkpoint["step"]
|
49 |
+
model.load_state_dict(checkpoint["model_state"])
|
50 |
+
optimizer.load_state_dict(checkpoint["optimizer_state"])
|
51 |
+
optimizer.param_groups[0]["lr"] = learning_rate_init
|
52 |
+
else:
|
53 |
+
print("No model \"%s\" found, starting training from scratch." % run_id)
|
54 |
+
else:
|
55 |
+
print("Starting the training from scratch.")
|
56 |
+
model.train()
|
57 |
+
|
58 |
+
# Initialize the visualization environment
|
59 |
+
vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
|
60 |
+
vis.log_dataset(dataset)
|
61 |
+
vis.log_params()
|
62 |
+
device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
|
63 |
+
vis.log_implementation({"Device": device_name})
|
64 |
+
|
65 |
+
# Training loop
|
66 |
+
profiler = Profiler(summarize_every=10, disabled=False)
|
67 |
+
for step, speaker_batch in enumerate(loader, init_step):
|
68 |
+
profiler.tick("Blocking, waiting for batch (threaded)")
|
69 |
+
|
70 |
+
# Forward pass
|
71 |
+
inputs = torch.from_numpy(speaker_batch.data).to(device)
|
72 |
+
sync(device)
|
73 |
+
profiler.tick("Data to %s" % device)
|
74 |
+
embeds = model(inputs)
|
75 |
+
sync(device)
|
76 |
+
profiler.tick("Forward pass")
|
77 |
+
embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
|
78 |
+
loss, eer = model.loss(embeds_loss)
|
79 |
+
sync(loss_device)
|
80 |
+
profiler.tick("Loss")
|
81 |
+
|
82 |
+
# Backward pass
|
83 |
+
model.zero_grad()
|
84 |
+
loss.backward()
|
85 |
+
profiler.tick("Backward pass")
|
86 |
+
model.do_gradient_ops()
|
87 |
+
optimizer.step()
|
88 |
+
profiler.tick("Parameter update")
|
89 |
+
|
90 |
+
# Update visualizations
|
91 |
+
# learning_rate = optimizer.param_groups[0]["lr"]
|
92 |
+
vis.update(loss.item(), eer, step)
|
93 |
+
|
94 |
+
# Draw projections and save them to the backup folder
|
95 |
+
if umap_every != 0 and step % umap_every == 0:
|
96 |
+
print("Drawing and saving projections (step %d)" % step)
|
97 |
+
backup_dir.mkdir(exist_ok=True)
|
98 |
+
projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
|
99 |
+
embeds = embeds.detach().cpu().numpy()
|
100 |
+
vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
|
101 |
+
vis.save()
|
102 |
+
|
103 |
+
# Overwrite the latest version of the model
|
104 |
+
if save_every != 0 and step % save_every == 0:
|
105 |
+
print("Saving the model (step %d)" % step)
|
106 |
+
torch.save({
|
107 |
+
"step": step + 1,
|
108 |
+
"model_state": model.state_dict(),
|
109 |
+
"optimizer_state": optimizer.state_dict(),
|
110 |
+
}, state_fpath)
|
111 |
+
|
112 |
+
# Make a backup
|
113 |
+
if backup_every != 0 and step % backup_every == 0:
|
114 |
+
print("Making a backup (step %d)" % step)
|
115 |
+
backup_dir.mkdir(exist_ok=True)
|
116 |
+
backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
|
117 |
+
torch.save({
|
118 |
+
"step": step + 1,
|
119 |
+
"model_state": model.state_dict(),
|
120 |
+
"optimizer_state": optimizer.state_dict(),
|
121 |
+
}, backup_fpath)
|
122 |
+
|
123 |
+
profiler.tick("Extras (visualizations, saving)")
|
encoder/visualizations.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
|
2 |
+
from datetime import datetime
|
3 |
+
from time import perf_counter as timer
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
# import webbrowser
|
7 |
+
import visdom
|
8 |
+
import umap
|
9 |
+
|
10 |
+
colormap = np.array([
|
11 |
+
[76, 255, 0],
|
12 |
+
[0, 127, 70],
|
13 |
+
[255, 0, 0],
|
14 |
+
[255, 217, 38],
|
15 |
+
[0, 135, 255],
|
16 |
+
[165, 0, 165],
|
17 |
+
[255, 167, 255],
|
18 |
+
[0, 255, 255],
|
19 |
+
[255, 96, 38],
|
20 |
+
[142, 76, 0],
|
21 |
+
[33, 0, 127],
|
22 |
+
[0, 0, 0],
|
23 |
+
[183, 183, 183],
|
24 |
+
], dtype=np.float) / 255
|
25 |
+
|
26 |
+
|
27 |
+
class Visualizations:
|
28 |
+
def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
|
29 |
+
# Tracking data
|
30 |
+
self.last_update_timestamp = timer()
|
31 |
+
self.update_every = update_every
|
32 |
+
self.step_times = []
|
33 |
+
self.losses = []
|
34 |
+
self.eers = []
|
35 |
+
print("Updating the visualizations every %d steps." % update_every)
|
36 |
+
|
37 |
+
# If visdom is disabled TODO: use a better paradigm for that
|
38 |
+
self.disabled = disabled
|
39 |
+
if self.disabled:
|
40 |
+
return
|
41 |
+
|
42 |
+
# Set the environment name
|
43 |
+
now = str(datetime.now().strftime("%d-%m %Hh%M"))
|
44 |
+
if env_name is None:
|
45 |
+
self.env_name = now
|
46 |
+
else:
|
47 |
+
self.env_name = "%s (%s)" % (env_name, now)
|
48 |
+
|
49 |
+
# Connect to visdom and open the corresponding window in the browser
|
50 |
+
try:
|
51 |
+
self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
|
52 |
+
except ConnectionError:
|
53 |
+
raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
|
54 |
+
"start it.")
|
55 |
+
# webbrowser.open("http://localhost:8097/env/" + self.env_name)
|
56 |
+
|
57 |
+
# Create the windows
|
58 |
+
self.loss_win = None
|
59 |
+
self.eer_win = None
|
60 |
+
# self.lr_win = None
|
61 |
+
self.implementation_win = None
|
62 |
+
self.projection_win = None
|
63 |
+
self.implementation_string = ""
|
64 |
+
|
65 |
+
def log_params(self):
|
66 |
+
if self.disabled:
|
67 |
+
return
|
68 |
+
from encoder import params_data
|
69 |
+
from encoder import params_model
|
70 |
+
param_string = "<b>Model parameters</b>:<br>"
|
71 |
+
for param_name in (p for p in dir(params_model) if not p.startswith("__")):
|
72 |
+
value = getattr(params_model, param_name)
|
73 |
+
param_string += "\t%s: %s<br>" % (param_name, value)
|
74 |
+
param_string += "<b>Data parameters</b>:<br>"
|
75 |
+
for param_name in (p for p in dir(params_data) if not p.startswith("__")):
|
76 |
+
value = getattr(params_data, param_name)
|
77 |
+
param_string += "\t%s: %s<br>" % (param_name, value)
|
78 |
+
self.vis.text(param_string, opts={"title": "Parameters"})
|
79 |
+
|
80 |
+
def log_dataset(self, dataset: SpeakerVerificationDataset):
|
81 |
+
if self.disabled:
|
82 |
+
return
|
83 |
+
dataset_string = ""
|
84 |
+
dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
|
85 |
+
dataset_string += "\n" + dataset.get_logs()
|
86 |
+
dataset_string = dataset_string.replace("\n", "<br>")
|
87 |
+
self.vis.text(dataset_string, opts={"title": "Dataset"})
|
88 |
+
|
89 |
+
def log_implementation(self, params):
|
90 |
+
if self.disabled:
|
91 |
+
return
|
92 |
+
implementation_string = ""
|
93 |
+
for param, value in params.items():
|
94 |
+
implementation_string += "<b>%s</b>: %s\n" % (param, value)
|
95 |
+
implementation_string = implementation_string.replace("\n", "<br>")
|
96 |
+
self.implementation_string = implementation_string
|
97 |
+
self.implementation_win = self.vis.text(
|
98 |
+
implementation_string,
|
99 |
+
opts={"title": "Training implementation"}
|
100 |
+
)
|
101 |
+
|
102 |
+
def update(self, loss, eer, step):
|
103 |
+
# Update the tracking data
|
104 |
+
now = timer()
|
105 |
+
self.step_times.append(1000 * (now - self.last_update_timestamp))
|
106 |
+
self.last_update_timestamp = now
|
107 |
+
self.losses.append(loss)
|
108 |
+
self.eers.append(eer)
|
109 |
+
print(".", end="")
|
110 |
+
|
111 |
+
# Update the plots every <update_every> steps
|
112 |
+
if step % self.update_every != 0:
|
113 |
+
return
|
114 |
+
time_string = "Step time: mean: %5dms std: %5dms" % \
|
115 |
+
(int(np.mean(self.step_times)), int(np.std(self.step_times)))
|
116 |
+
print("\nStep %6d Loss: %.4f EER: %.4f %s" %
|
117 |
+
(step, np.mean(self.losses), np.mean(self.eers), time_string))
|
118 |
+
if not self.disabled:
|
119 |
+
self.loss_win = self.vis.line(
|
120 |
+
[np.mean(self.losses)],
|
121 |
+
[step],
|
122 |
+
win=self.loss_win,
|
123 |
+
update="append" if self.loss_win else None,
|
124 |
+
opts=dict(
|
125 |
+
legend=["Avg. loss"],
|
126 |
+
xlabel="Step",
|
127 |
+
ylabel="Loss",
|
128 |
+
title="Loss",
|
129 |
+
)
|
130 |
+
)
|
131 |
+
self.eer_win = self.vis.line(
|
132 |
+
[np.mean(self.eers)],
|
133 |
+
[step],
|
134 |
+
win=self.eer_win,
|
135 |
+
update="append" if self.eer_win else None,
|
136 |
+
opts=dict(
|
137 |
+
legend=["Avg. EER"],
|
138 |
+
xlabel="Step",
|
139 |
+
ylabel="EER",
|
140 |
+
title="Equal error rate"
|
141 |
+
)
|
142 |
+
)
|
143 |
+
if self.implementation_win is not None:
|
144 |
+
self.vis.text(
|
145 |
+
self.implementation_string + ("<b>%s</b>" % time_string),
|
146 |
+
win=self.implementation_win,
|
147 |
+
opts={"title": "Training implementation"},
|
148 |
+
)
|
149 |
+
|
150 |
+
# Reset the tracking
|
151 |
+
self.losses.clear()
|
152 |
+
self.eers.clear()
|
153 |
+
self.step_times.clear()
|
154 |
+
|
155 |
+
def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
|
156 |
+
max_speakers=10):
|
157 |
+
max_speakers = min(max_speakers, len(colormap))
|
158 |
+
embeds = embeds[:max_speakers * utterances_per_speaker]
|
159 |
+
|
160 |
+
n_speakers = len(embeds) // utterances_per_speaker
|
161 |
+
ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
|
162 |
+
colors = [colormap[i] for i in ground_truth]
|
163 |
+
|
164 |
+
reducer = umap.UMAP()
|
165 |
+
projected = reducer.fit_transform(embeds)
|
166 |
+
plt.scatter(projected[:, 0], projected[:, 1], c=colors)
|
167 |
+
plt.gca().set_aspect("equal", "datalim")
|
168 |
+
plt.title("UMAP projection (step %d)" % step)
|
169 |
+
if not self.disabled:
|
170 |
+
self.projection_win = self.vis.matplot(plt, win=self.projection_win)
|
171 |
+
if out_fpath is not None:
|
172 |
+
plt.savefig(out_fpath)
|
173 |
+
plt.clf()
|
174 |
+
|
175 |
+
def save(self):
|
176 |
+
if not self.disabled:
|
177 |
+
self.vis.save([self.env_name])
|
178 |
+
|
encoder_preprocess.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2
|
2 |
+
from utils.argutils import print_args
|
3 |
+
from pathlib import Path
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
if __name__ == "__main__":
|
7 |
+
class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
|
8 |
+
pass
|
9 |
+
|
10 |
+
parser = argparse.ArgumentParser(
|
11 |
+
description="Preprocesses audio files from datasets, encodes them as mel spectrograms and "
|
12 |
+
"writes them to the disk. This will allow you to train the encoder. The "
|
13 |
+
"datasets required are at least one of VoxCeleb1, VoxCeleb2 and LibriSpeech. "
|
14 |
+
"Ideally, you should have all three. You should extract them as they are "
|
15 |
+
"after having downloaded them and put them in a same directory, e.g.:\n"
|
16 |
+
"-[datasets_root]\n"
|
17 |
+
" -LibriSpeech\n"
|
18 |
+
" -train-other-500\n"
|
19 |
+
" -VoxCeleb1\n"
|
20 |
+
" -wav\n"
|
21 |
+
" -vox1_meta.csv\n"
|
22 |
+
" -VoxCeleb2\n"
|
23 |
+
" -dev",
|
24 |
+
formatter_class=MyFormatter
|
25 |
+
)
|
26 |
+
parser.add_argument("datasets_root", type=Path, help=\
|
27 |
+
"Path to the directory containing your LibriSpeech/TTS and VoxCeleb datasets.")
|
28 |
+
parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
|
29 |
+
"Path to the output directory that will contain the mel spectrograms. If left out, "
|
30 |
+
"defaults to <datasets_root>/SV2TTS/encoder/")
|
31 |
+
parser.add_argument("-d", "--datasets", type=str,
|
32 |
+
default="librispeech_other,voxceleb1,voxceleb2", help=\
|
33 |
+
"Comma-separated list of the name of the datasets you want to preprocess. Only the train "
|
34 |
+
"set of these datasets will be used. Possible names: librispeech_other, voxceleb1, "
|
35 |
+
"voxceleb2.")
|
36 |
+
parser.add_argument("-s", "--skip_existing", action="store_true", help=\
|
37 |
+
"Whether to skip existing output files with the same name. Useful if this script was "
|
38 |
+
"interrupted.")
|
39 |
+
parser.add_argument("--no_trim", action="store_true", help=\
|
40 |
+
"Preprocess audio without trimming silences (not recommended).")
|
41 |
+
args = parser.parse_args()
|
42 |
+
|
43 |
+
# Verify webrtcvad is available
|
44 |
+
if not args.no_trim:
|
45 |
+
try:
|
46 |
+
import webrtcvad
|
47 |
+
except:
|
48 |
+
raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
|
49 |
+
"noise removal and is recommended. Please install and try again. If installation fails, "
|
50 |
+
"use --no_trim to disable this error message.")
|
51 |
+
del args.no_trim
|
52 |
+
|
53 |
+
# Process the arguments
|
54 |
+
args.datasets = args.datasets.split(",")
|
55 |
+
if not hasattr(args, "out_dir"):
|
56 |
+
args.out_dir = args.datasets_root.joinpath("SV2TTS", "encoder")
|
57 |
+
assert args.datasets_root.exists()
|
58 |
+
args.out_dir.mkdir(exist_ok=True, parents=True)
|
59 |
+
|
60 |
+
# Preprocess the datasets
|
61 |
+
print_args(args, parser)
|
62 |
+
preprocess_func = {
|
63 |
+
"librispeech_other": preprocess_librispeech,
|
64 |
+
"voxceleb1": preprocess_voxceleb1,
|
65 |
+
"voxceleb2": preprocess_voxceleb2,
|
66 |
+
}
|
67 |
+
args = vars(args)
|
68 |
+
for dataset in args.pop("datasets"):
|
69 |
+
print("Preprocessing %s" % dataset)
|
70 |
+
preprocess_func[dataset](**args)
|
encoder_train.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils.argutils import print_args
|
2 |
+
from encoder.train import train
|
3 |
+
from pathlib import Path
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
|
7 |
+
if __name__ == "__main__":
|
8 |
+
parser = argparse.ArgumentParser(
|
9 |
+
description="Trains the speaker encoder. You must have run encoder_preprocess.py first.",
|
10 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
11 |
+
)
|
12 |
+
|
13 |
+
parser.add_argument("run_id", type=str, help= \
|
14 |
+
"Name for this model instance. If a model state from the same run ID was previously "
|
15 |
+
"saved, the training will restart from there. Pass -f to overwrite saved states and "
|
16 |
+
"restart from scratch.")
|
17 |
+
parser.add_argument("clean_data_root", type=Path, help= \
|
18 |
+
"Path to the output directory of encoder_preprocess.py. If you left the default "
|
19 |
+
"output directory when preprocessing, it should be <datasets_root>/SV2TTS/encoder/.")
|
20 |
+
parser.add_argument("-m", "--models_dir", type=Path, default="encoder/saved_models/", help=\
|
21 |
+
"Path to the output directory that will contain the saved model weights, as well as "
|
22 |
+
"backups of those weights and plots generated during training.")
|
23 |
+
parser.add_argument("-v", "--vis_every", type=int, default=10, help= \
|
24 |
+
"Number of steps between updates of the loss and the plots.")
|
25 |
+
parser.add_argument("-u", "--umap_every", type=int, default=100, help= \
|
26 |
+
"Number of steps between updates of the umap projection. Set to 0 to never update the "
|
27 |
+
"projections.")
|
28 |
+
parser.add_argument("-s", "--save_every", type=int, default=500, help= \
|
29 |
+
"Number of steps between updates of the model on the disk. Set to 0 to never save the "
|
30 |
+
"model.")
|
31 |
+
parser.add_argument("-b", "--backup_every", type=int, default=7500, help= \
|
32 |
+
"Number of steps between backups of the model. Set to 0 to never make backups of the "
|
33 |
+
"model.")
|
34 |
+
parser.add_argument("-f", "--force_restart", action="store_true", help= \
|
35 |
+
"Do not load any saved model.")
|
36 |
+
parser.add_argument("--visdom_server", type=str, default="http://localhost")
|
37 |
+
parser.add_argument("--no_visdom", action="store_true", help= \
|
38 |
+
"Disable visdom.")
|
39 |
+
args = parser.parse_args()
|
40 |
+
|
41 |
+
# Process the arguments
|
42 |
+
args.models_dir.mkdir(exist_ok=True)
|
43 |
+
|
44 |
+
# Run the training
|
45 |
+
print_args(args, parser)
|
46 |
+
train(**vars(args))
|
47 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
umap-learn
|
2 |
+
visdom
|
3 |
+
librosa>=0.8.0
|
4 |
+
matplotlib>=3.3.0
|
5 |
+
numpy==1.19.3; platform_system == "Windows"
|
6 |
+
numpy==1.19.4; platform_system != "Windows"
|
7 |
+
scipy>=1.0.0
|
8 |
+
tqdm
|
9 |
+
sounddevice
|
10 |
+
SoundFile
|
11 |
+
Unidecode
|
12 |
+
inflect
|
13 |
+
PyQt5
|
14 |
+
multiprocess
|
15 |
+
numba
|
16 |
+
webrtcvad; platform_system != "Windows"
|
synthesizer/LICENSE.txt
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
|
4 |
+
Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
|
5 |
+
Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
|
6 |
+
Modified work Copyright (c) 2020 blue-fish (https://github.com/blue-fish)
|
7 |
+
|
8 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
9 |
+
of this software and associated documentation files (the "Software"), to deal
|
10 |
+
in the Software without restriction, including without limitation the rights
|
11 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
12 |
+
copies of the Software, and to permit persons to whom the Software is
|
13 |
+
furnished to do so, subject to the following conditions:
|
14 |
+
|
15 |
+
The above copyright notice and this permission notice shall be included in all
|
16 |
+
copies or substantial portions of the Software.
|
17 |
+
|
18 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
19 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
20 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
21 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
22 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
23 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
24 |
+
SOFTWARE.
|
synthesizer/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
#
|
synthesizer/audio.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import librosa
|
2 |
+
import librosa.filters
|
3 |
+
import numpy as np
|
4 |
+
from scipy import signal
|
5 |
+
from scipy.io import wavfile
|
6 |
+
import soundfile as sf
|
7 |
+
|
8 |
+
|
9 |
+
def load_wav(path, sr):
|
10 |
+
return librosa.core.load(path, sr=sr)[0]
|
11 |
+
|
12 |
+
def save_wav(wav, path, sr):
|
13 |
+
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
14 |
+
#proposed by @dsmiller
|
15 |
+
wavfile.write(path, sr, wav.astype(np.int16))
|
16 |
+
|
17 |
+
def save_wavenet_wav(wav, path, sr):
|
18 |
+
sf.write(path, wav.astype(np.float32), sr)
|
19 |
+
|
20 |
+
def preemphasis(wav, k, preemphasize=True):
|
21 |
+
if preemphasize:
|
22 |
+
return signal.lfilter([1, -k], [1], wav)
|
23 |
+
return wav
|
24 |
+
|
25 |
+
def inv_preemphasis(wav, k, inv_preemphasize=True):
|
26 |
+
if inv_preemphasize:
|
27 |
+
return signal.lfilter([1], [1, -k], wav)
|
28 |
+
return wav
|
29 |
+
|
30 |
+
#From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py
|
31 |
+
def start_and_end_indices(quantized, silence_threshold=2):
|
32 |
+
for start in range(quantized.size):
|
33 |
+
if abs(quantized[start] - 127) > silence_threshold:
|
34 |
+
break
|
35 |
+
for end in range(quantized.size - 1, 1, -1):
|
36 |
+
if abs(quantized[end] - 127) > silence_threshold:
|
37 |
+
break
|
38 |
+
|
39 |
+
assert abs(quantized[start] - 127) > silence_threshold
|
40 |
+
assert abs(quantized[end] - 127) > silence_threshold
|
41 |
+
|
42 |
+
return start, end
|
43 |
+
|
44 |
+
def get_hop_size(hparams):
|
45 |
+
hop_size = hparams.hop_size
|
46 |
+
if hop_size is None:
|
47 |
+
assert hparams.frame_shift_ms is not None
|
48 |
+
hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
|
49 |
+
return hop_size
|
50 |
+
|
51 |
+
def linearspectrogram(wav, hparams):
|
52 |
+
D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
|
53 |
+
S = _amp_to_db(np.abs(D), hparams) - hparams.ref_level_db
|
54 |
+
|
55 |
+
if hparams.signal_normalization:
|
56 |
+
return _normalize(S, hparams)
|
57 |
+
return S
|
58 |
+
|
59 |
+
def melspectrogram(wav, hparams):
|
60 |
+
D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
|
61 |
+
S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams.ref_level_db
|
62 |
+
|
63 |
+
if hparams.signal_normalization:
|
64 |
+
return _normalize(S, hparams)
|
65 |
+
return S
|
66 |
+
|
67 |
+
def inv_linear_spectrogram(linear_spectrogram, hparams):
|
68 |
+
"""Converts linear spectrogram to waveform using librosa"""
|
69 |
+
if hparams.signal_normalization:
|
70 |
+
D = _denormalize(linear_spectrogram, hparams)
|
71 |
+
else:
|
72 |
+
D = linear_spectrogram
|
73 |
+
|
74 |
+
S = _db_to_amp(D + hparams.ref_level_db) #Convert back to linear
|
75 |
+
|
76 |
+
if hparams.use_lws:
|
77 |
+
processor = _lws_processor(hparams)
|
78 |
+
D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
|
79 |
+
y = processor.istft(D).astype(np.float32)
|
80 |
+
return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
|
81 |
+
else:
|
82 |
+
return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
|
83 |
+
|
84 |
+
def inv_mel_spectrogram(mel_spectrogram, hparams):
|
85 |
+
"""Converts mel spectrogram to waveform using librosa"""
|
86 |
+
if hparams.signal_normalization:
|
87 |
+
D = _denormalize(mel_spectrogram, hparams)
|
88 |
+
else:
|
89 |
+
D = mel_spectrogram
|
90 |
+
|
91 |
+
S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams) # Convert back to linear
|
92 |
+
|
93 |
+
if hparams.use_lws:
|
94 |
+
processor = _lws_processor(hparams)
|
95 |
+
D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
|
96 |
+
y = processor.istft(D).astype(np.float32)
|
97 |
+
return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
|
98 |
+
else:
|
99 |
+
return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
|
100 |
+
|
101 |
+
def _lws_processor(hparams):
|
102 |
+
import lws
|
103 |
+
return lws.lws(hparams.n_fft, get_hop_size(hparams), fftsize=hparams.win_size, mode="speech")
|
104 |
+
|
105 |
+
def _griffin_lim(S, hparams):
|
106 |
+
"""librosa implementation of Griffin-Lim
|
107 |
+
Based on https://github.com/librosa/librosa/issues/434
|
108 |
+
"""
|
109 |
+
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
|
110 |
+
S_complex = np.abs(S).astype(np.complex)
|
111 |
+
y = _istft(S_complex * angles, hparams)
|
112 |
+
for i in range(hparams.griffin_lim_iters):
|
113 |
+
angles = np.exp(1j * np.angle(_stft(y, hparams)))
|
114 |
+
y = _istft(S_complex * angles, hparams)
|
115 |
+
return y
|
116 |
+
|
117 |
+
def _stft(y, hparams):
|
118 |
+
if hparams.use_lws:
|
119 |
+
return _lws_processor(hparams).stft(y).T
|
120 |
+
else:
|
121 |
+
return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
|
122 |
+
|
123 |
+
def _istft(y, hparams):
|
124 |
+
return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
|
125 |
+
|
126 |
+
##########################################################
|
127 |
+
#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
|
128 |
+
def num_frames(length, fsize, fshift):
|
129 |
+
"""Compute number of time frames of spectrogram
|
130 |
+
"""
|
131 |
+
pad = (fsize - fshift)
|
132 |
+
if length % fshift == 0:
|
133 |
+
M = (length + pad * 2 - fsize) // fshift + 1
|
134 |
+
else:
|
135 |
+
M = (length + pad * 2 - fsize) // fshift + 2
|
136 |
+
return M
|
137 |
+
|
138 |
+
|
139 |
+
def pad_lr(x, fsize, fshift):
|
140 |
+
"""Compute left and right padding
|
141 |
+
"""
|
142 |
+
M = num_frames(len(x), fsize, fshift)
|
143 |
+
pad = (fsize - fshift)
|
144 |
+
T = len(x) + 2 * pad
|
145 |
+
r = (M - 1) * fshift + fsize - T
|
146 |
+
return pad, pad + r
|
147 |
+
##########################################################
|
148 |
+
#Librosa correct padding
|
149 |
+
def librosa_pad_lr(x, fsize, fshift):
|
150 |
+
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
|
151 |
+
|
152 |
+
# Conversions
|
153 |
+
_mel_basis = None
|
154 |
+
_inv_mel_basis = None
|
155 |
+
|
156 |
+
def _linear_to_mel(spectogram, hparams):
|
157 |
+
global _mel_basis
|
158 |
+
if _mel_basis is None:
|
159 |
+
_mel_basis = _build_mel_basis(hparams)
|
160 |
+
return np.dot(_mel_basis, spectogram)
|
161 |
+
|
162 |
+
def _mel_to_linear(mel_spectrogram, hparams):
|
163 |
+
global _inv_mel_basis
|
164 |
+
if _inv_mel_basis is None:
|
165 |
+
_inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
|
166 |
+
return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
|
167 |
+
|
168 |
+
def _build_mel_basis(hparams):
|
169 |
+
assert hparams.fmax <= hparams.sample_rate // 2
|
170 |
+
return librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels,
|
171 |
+
fmin=hparams.fmin, fmax=hparams.fmax)
|
172 |
+
|
173 |
+
def _amp_to_db(x, hparams):
|
174 |
+
min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
|
175 |
+
return 20 * np.log10(np.maximum(min_level, x))
|
176 |
+
|
177 |
+
def _db_to_amp(x):
|
178 |
+
return np.power(10.0, (x) * 0.05)
|
179 |
+
|
180 |
+
def _normalize(S, hparams):
|
181 |
+
if hparams.allow_clipping_in_normalization:
|
182 |
+
if hparams.symmetric_mels:
|
183 |
+
return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
|
184 |
+
-hparams.max_abs_value, hparams.max_abs_value)
|
185 |
+
else:
|
186 |
+
return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value)
|
187 |
+
|
188 |
+
assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
|
189 |
+
if hparams.symmetric_mels:
|
190 |
+
return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
|
191 |
+
else:
|
192 |
+
return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))
|
193 |
+
|
194 |
+
def _denormalize(D, hparams):
|
195 |
+
if hparams.allow_clipping_in_normalization:
|
196 |
+
if hparams.symmetric_mels:
|
197 |
+
return (((np.clip(D, -hparams.max_abs_value,
|
198 |
+
hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
|
199 |
+
+ hparams.min_level_db)
|
200 |
+
else:
|
201 |
+
return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
|
202 |
+
|
203 |
+
if hparams.symmetric_mels:
|
204 |
+
return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
|
205 |
+
else:
|
206 |
+
return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
|
synthesizer/hparams.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import pprint
|
3 |
+
|
4 |
+
class HParams(object):
|
5 |
+
def __init__(self, **kwargs): self.__dict__.update(kwargs)
|
6 |
+
def __setitem__(self, key, value): setattr(self, key, value)
|
7 |
+
def __getitem__(self, key): return getattr(self, key)
|
8 |
+
def __repr__(self): return pprint.pformat(self.__dict__)
|
9 |
+
|
10 |
+
def parse(self, string):
|
11 |
+
# Overrides hparams from a comma-separated string of name=value pairs
|
12 |
+
if len(string) > 0:
|
13 |
+
overrides = [s.split("=") for s in string.split(",")]
|
14 |
+
keys, values = zip(*overrides)
|
15 |
+
keys = list(map(str.strip, keys))
|
16 |
+
values = list(map(str.strip, values))
|
17 |
+
for k in keys:
|
18 |
+
self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
|
19 |
+
return self
|
20 |
+
|
21 |
+
hparams = HParams(
|
22 |
+
### Signal Processing (used in both synthesizer and vocoder)
|
23 |
+
sample_rate = 16000,
|
24 |
+
n_fft = 800,
|
25 |
+
num_mels = 80,
|
26 |
+
hop_size = 200, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
|
27 |
+
win_size = 800, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
|
28 |
+
fmin = 55,
|
29 |
+
min_level_db = -100,
|
30 |
+
ref_level_db = 20,
|
31 |
+
max_abs_value = 4., # Gradient explodes if too big, premature convergence if too small.
|
32 |
+
preemphasis = 0.97, # Filter coefficient to use if preemphasize is True
|
33 |
+
preemphasize = True,
|
34 |
+
|
35 |
+
### Tacotron Text-to-Speech (TTS)
|
36 |
+
tts_embed_dims = 512, # Embedding dimension for the graphemes/phoneme inputs
|
37 |
+
tts_encoder_dims = 256,
|
38 |
+
tts_decoder_dims = 128,
|
39 |
+
tts_postnet_dims = 512,
|
40 |
+
tts_encoder_K = 5,
|
41 |
+
tts_lstm_dims = 1024,
|
42 |
+
tts_postnet_K = 5,
|
43 |
+
tts_num_highways = 4,
|
44 |
+
tts_dropout = 0.5,
|
45 |
+
tts_cleaner_names = ["english_cleaners"],
|
46 |
+
tts_stop_threshold = -3.4, # Value below which audio generation ends.
|
47 |
+
# For example, for a range of [-4, 4], this
|
48 |
+
# will terminate the sequence at the first
|
49 |
+
# frame that has all values < -3.4
|
50 |
+
|
51 |
+
### Tacotron Training
|
52 |
+
tts_schedule = [(2, 1e-3, 20_000, 12), # Progressive training schedule
|
53 |
+
(2, 5e-4, 40_000, 12), # (r, lr, step, batch_size)
|
54 |
+
(2, 2e-4, 80_000, 12), #
|
55 |
+
(2, 1e-4, 160_000, 12), # r = reduction factor (# of mel frames
|
56 |
+
(2, 3e-5, 320_000, 12), # synthesized for each decoder iteration)
|
57 |
+
(2, 1e-5, 640_000, 12)], # lr = learning rate
|
58 |
+
|
59 |
+
tts_clip_grad_norm = 1.0, # clips the gradient norm to prevent explosion - set to None if not needed
|
60 |
+
tts_eval_interval = 500, # Number of steps between model evaluation (sample generation)
|
61 |
+
# Set to -1 to generate after completing epoch, or 0 to disable
|
62 |
+
|
63 |
+
tts_eval_num_samples = 1, # Makes this number of samples
|
64 |
+
|
65 |
+
### Data Preprocessing
|
66 |
+
max_mel_frames = 900,
|
67 |
+
rescale = True,
|
68 |
+
rescaling_max = 0.9,
|
69 |
+
synthesis_batch_size = 16, # For vocoder preprocessing and inference.
|
70 |
+
|
71 |
+
### Mel Visualization and Griffin-Lim
|
72 |
+
signal_normalization = True,
|
73 |
+
power = 1.5,
|
74 |
+
griffin_lim_iters = 60,
|
75 |
+
|
76 |
+
### Audio processing options
|
77 |
+
fmax = 7600, # Should not exceed (sample_rate // 2)
|
78 |
+
allow_clipping_in_normalization = True, # Used when signal_normalization = True
|
79 |
+
clip_mels_length = True, # If true, discards samples exceeding max_mel_frames
|
80 |
+
use_lws = False, # "Fast spectrogram phase recovery using local weighted sums"
|
81 |
+
symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True,
|
82 |
+
# and [0, max_abs_value] if False
|
83 |
+
trim_silence = True, # Use with sample_rate of 16000 for best results
|
84 |
+
|
85 |
+
### SV2TTS
|
86 |
+
speaker_embedding_size = 256, # Dimension for the speaker embedding
|
87 |
+
silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split
|
88 |
+
utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded
|
89 |
+
)
|
90 |
+
|
91 |
+
def hparams_debug_string():
|
92 |
+
return str(hparams)
|
synthesizer/inference.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from synthesizer import audio
|
3 |
+
from synthesizer.hparams import hparams
|
4 |
+
from synthesizer.models.tacotron import Tacotron
|
5 |
+
from synthesizer.utils.symbols import symbols
|
6 |
+
from synthesizer.utils.text import text_to_sequence
|
7 |
+
from vocoder.display import simple_table
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import Union, List
|
10 |
+
import numpy as np
|
11 |
+
import librosa
|
12 |
+
|
13 |
+
|
14 |
+
class Synthesizer:
|
15 |
+
sample_rate = hparams.sample_rate
|
16 |
+
hparams = hparams
|
17 |
+
|
18 |
+
def __init__(self, model_fpath: Path, verbose=True):
|
19 |
+
"""
|
20 |
+
The model isn't instantiated and loaded in memory until needed or until load() is called.
|
21 |
+
|
22 |
+
:param model_fpath: path to the trained model file
|
23 |
+
:param verbose: if False, prints less information when using the model
|
24 |
+
"""
|
25 |
+
self.model_fpath = model_fpath
|
26 |
+
self.verbose = verbose
|
27 |
+
|
28 |
+
# Check for GPU
|
29 |
+
if torch.cuda.is_available():
|
30 |
+
self.device = torch.device("cuda")
|
31 |
+
else:
|
32 |
+
self.device = torch.device("cpu")
|
33 |
+
if self.verbose:
|
34 |
+
print("Synthesizer using device:", self.device)
|
35 |
+
|
36 |
+
# Tacotron model will be instantiated later on first use.
|
37 |
+
self._model = None
|
38 |
+
|
39 |
+
def is_loaded(self):
|
40 |
+
"""
|
41 |
+
Whether the model is loaded in memory.
|
42 |
+
"""
|
43 |
+
return self._model is not None
|
44 |
+
|
45 |
+
def load(self):
|
46 |
+
"""
|
47 |
+
Instantiates and loads the model given the weights file that was passed in the constructor.
|
48 |
+
"""
|
49 |
+
self._model = Tacotron(embed_dims=hparams.tts_embed_dims,
|
50 |
+
num_chars=len(symbols),
|
51 |
+
encoder_dims=hparams.tts_encoder_dims,
|
52 |
+
decoder_dims=hparams.tts_decoder_dims,
|
53 |
+
n_mels=hparams.num_mels,
|
54 |
+
fft_bins=hparams.num_mels,
|
55 |
+
postnet_dims=hparams.tts_postnet_dims,
|
56 |
+
encoder_K=hparams.tts_encoder_K,
|
57 |
+
lstm_dims=hparams.tts_lstm_dims,
|
58 |
+
postnet_K=hparams.tts_postnet_K,
|
59 |
+
num_highways=hparams.tts_num_highways,
|
60 |
+
dropout=hparams.tts_dropout,
|
61 |
+
stop_threshold=hparams.tts_stop_threshold,
|
62 |
+
speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
|
63 |
+
|
64 |
+
self._model.load(self.model_fpath)
|
65 |
+
self._model.eval()
|
66 |
+
|
67 |
+
if self.verbose:
|
68 |
+
print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath.name, self._model.state_dict()["step"]))
|
69 |
+
|
70 |
+
def synthesize_spectrograms(self, texts: List[str],
|
71 |
+
embeddings: Union[np.ndarray, List[np.ndarray]],
|
72 |
+
return_alignments=False):
|
73 |
+
"""
|
74 |
+
Synthesizes mel spectrograms from texts and speaker embeddings.
|
75 |
+
|
76 |
+
:param texts: a list of N text prompts to be synthesized
|
77 |
+
:param embeddings: a numpy array or list of speaker embeddings of shape (N, 256)
|
78 |
+
:param return_alignments: if True, a matrix representing the alignments between the
|
79 |
+
characters
|
80 |
+
and each decoder output step will be returned for each spectrogram
|
81 |
+
:return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the
|
82 |
+
sequence length of spectrogram i, and possibly the alignments.
|
83 |
+
"""
|
84 |
+
# Load the model on the first request.
|
85 |
+
if not self.is_loaded():
|
86 |
+
self.load()
|
87 |
+
|
88 |
+
# Print some info about the model when it is loaded
|
89 |
+
tts_k = self._model.get_step() // 1000
|
90 |
+
|
91 |
+
simple_table([("Tacotron", str(tts_k) + "k"),
|
92 |
+
("r", self._model.r)])
|
93 |
+
|
94 |
+
# Preprocess text inputs
|
95 |
+
inputs = [text_to_sequence(text.strip(), hparams.tts_cleaner_names) for text in texts]
|
96 |
+
if not isinstance(embeddings, list):
|
97 |
+
embeddings = [embeddings]
|
98 |
+
|
99 |
+
# Batch inputs
|
100 |
+
batched_inputs = [inputs[i:i+hparams.synthesis_batch_size]
|
101 |
+
for i in range(0, len(inputs), hparams.synthesis_batch_size)]
|
102 |
+
batched_embeds = [embeddings[i:i+hparams.synthesis_batch_size]
|
103 |
+
for i in range(0, len(embeddings), hparams.synthesis_batch_size)]
|
104 |
+
|
105 |
+
specs = []
|
106 |
+
for i, batch in enumerate(batched_inputs, 1):
|
107 |
+
if self.verbose:
|
108 |
+
print(f"\n| Generating {i}/{len(batched_inputs)}")
|
109 |
+
|
110 |
+
# Pad texts so they are all the same length
|
111 |
+
text_lens = [len(text) for text in batch]
|
112 |
+
max_text_len = max(text_lens)
|
113 |
+
chars = [pad1d(text, max_text_len) for text in batch]
|
114 |
+
chars = np.stack(chars)
|
115 |
+
|
116 |
+
# Stack speaker embeddings into 2D array for batch processing
|
117 |
+
speaker_embeds = np.stack(batched_embeds[i-1])
|
118 |
+
|
119 |
+
# Convert to tensor
|
120 |
+
chars = torch.tensor(chars).long().to(self.device)
|
121 |
+
speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
|
122 |
+
|
123 |
+
# Inference
|
124 |
+
_, mels, alignments = self._model.generate(chars, speaker_embeddings)
|
125 |
+
mels = mels.detach().cpu().numpy()
|
126 |
+
for m in mels:
|
127 |
+
# Trim silence from end of each spectrogram
|
128 |
+
while np.max(m[:, -1]) < hparams.tts_stop_threshold:
|
129 |
+
m = m[:, :-1]
|
130 |
+
specs.append(m)
|
131 |
+
|
132 |
+
if self.verbose:
|
133 |
+
print("\n\nDone.\n")
|
134 |
+
return (specs, alignments) if return_alignments else specs
|
135 |
+
|
136 |
+
@staticmethod
|
137 |
+
def load_preprocess_wav(fpath):
|
138 |
+
"""
|
139 |
+
Loads and preprocesses an audio file under the same conditions the audio files were used to
|
140 |
+
train the synthesizer.
|
141 |
+
"""
|
142 |
+
wav = librosa.load(str(fpath), hparams.sample_rate)[0]
|
143 |
+
if hparams.rescale:
|
144 |
+
wav = wav / np.abs(wav).max() * hparams.rescaling_max
|
145 |
+
return wav
|
146 |
+
|
147 |
+
@staticmethod
|
148 |
+
def make_spectrogram(fpath_or_wav: Union[str, Path, np.ndarray]):
|
149 |
+
"""
|
150 |
+
Creates a mel spectrogram from an audio file in the same manner as the mel spectrograms that
|
151 |
+
were fed to the synthesizer when training.
|
152 |
+
"""
|
153 |
+
if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
|
154 |
+
wav = Synthesizer.load_preprocess_wav(fpath_or_wav)
|
155 |
+
else:
|
156 |
+
wav = fpath_or_wav
|
157 |
+
|
158 |
+
mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
|
159 |
+
return mel_spectrogram
|
160 |
+
|
161 |
+
@staticmethod
|
162 |
+
def griffin_lim(mel):
|
163 |
+
"""
|
164 |
+
Inverts a mel spectrogram using Griffin-Lim. The mel spectrogram is expected to have been built
|
165 |
+
with the same parameters present in hparams.py.
|
166 |
+
"""
|
167 |
+
return audio.inv_mel_spectrogram(mel, hparams)
|
168 |
+
|
169 |
+
|
170 |
+
def pad1d(x, max_len, pad_value=0):
|
171 |
+
return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
|
synthesizer/models/tacotron.py
ADDED
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Union
|
8 |
+
|
9 |
+
|
10 |
+
class HighwayNetwork(nn.Module):
|
11 |
+
def __init__(self, size):
|
12 |
+
super().__init__()
|
13 |
+
self.W1 = nn.Linear(size, size)
|
14 |
+
self.W2 = nn.Linear(size, size)
|
15 |
+
self.W1.bias.data.fill_(0.)
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
x1 = self.W1(x)
|
19 |
+
x2 = self.W2(x)
|
20 |
+
g = torch.sigmoid(x2)
|
21 |
+
y = g * F.relu(x1) + (1. - g) * x
|
22 |
+
return y
|
23 |
+
|
24 |
+
|
25 |
+
class Encoder(nn.Module):
|
26 |
+
def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
|
27 |
+
super().__init__()
|
28 |
+
prenet_dims = (encoder_dims, encoder_dims)
|
29 |
+
cbhg_channels = encoder_dims
|
30 |
+
self.embedding = nn.Embedding(num_chars, embed_dims)
|
31 |
+
self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
|
32 |
+
dropout=dropout)
|
33 |
+
self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
|
34 |
+
proj_channels=[cbhg_channels, cbhg_channels],
|
35 |
+
num_highways=num_highways)
|
36 |
+
|
37 |
+
def forward(self, x, speaker_embedding=None):
|
38 |
+
x = self.embedding(x)
|
39 |
+
x = self.pre_net(x)
|
40 |
+
x.transpose_(1, 2)
|
41 |
+
x = self.cbhg(x)
|
42 |
+
if speaker_embedding is not None:
|
43 |
+
x = self.add_speaker_embedding(x, speaker_embedding)
|
44 |
+
return x
|
45 |
+
|
46 |
+
def add_speaker_embedding(self, x, speaker_embedding):
|
47 |
+
# SV2TTS
|
48 |
+
# The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
|
49 |
+
# When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
|
50 |
+
# (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
|
51 |
+
# This concats the speaker embedding for each char in the encoder output
|
52 |
+
|
53 |
+
# Save the dimensions as human-readable names
|
54 |
+
batch_size = x.size()[0]
|
55 |
+
num_chars = x.size()[1]
|
56 |
+
|
57 |
+
if speaker_embedding.dim() == 1:
|
58 |
+
idx = 0
|
59 |
+
else:
|
60 |
+
idx = 1
|
61 |
+
|
62 |
+
# Start by making a copy of each speaker embedding to match the input text length
|
63 |
+
# The output of this has size (batch_size, num_chars * tts_embed_dims)
|
64 |
+
speaker_embedding_size = speaker_embedding.size()[idx]
|
65 |
+
e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
|
66 |
+
|
67 |
+
# Reshape it and transpose
|
68 |
+
e = e.reshape(batch_size, speaker_embedding_size, num_chars)
|
69 |
+
e = e.transpose(1, 2)
|
70 |
+
|
71 |
+
# Concatenate the tiled speaker embedding with the encoder output
|
72 |
+
x = torch.cat((x, e), 2)
|
73 |
+
return x
|
74 |
+
|
75 |
+
|
76 |
+
class BatchNormConv(nn.Module):
|
77 |
+
def __init__(self, in_channels, out_channels, kernel, relu=True):
|
78 |
+
super().__init__()
|
79 |
+
self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
|
80 |
+
self.bnorm = nn.BatchNorm1d(out_channels)
|
81 |
+
self.relu = relu
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
x = self.conv(x)
|
85 |
+
x = F.relu(x) if self.relu is True else x
|
86 |
+
return self.bnorm(x)
|
87 |
+
|
88 |
+
|
89 |
+
class CBHG(nn.Module):
|
90 |
+
def __init__(self, K, in_channels, channels, proj_channels, num_highways):
|
91 |
+
super().__init__()
|
92 |
+
|
93 |
+
# List of all rnns to call `flatten_parameters()` on
|
94 |
+
self._to_flatten = []
|
95 |
+
|
96 |
+
self.bank_kernels = [i for i in range(1, K + 1)]
|
97 |
+
self.conv1d_bank = nn.ModuleList()
|
98 |
+
for k in self.bank_kernels:
|
99 |
+
conv = BatchNormConv(in_channels, channels, k)
|
100 |
+
self.conv1d_bank.append(conv)
|
101 |
+
|
102 |
+
self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
|
103 |
+
|
104 |
+
self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
|
105 |
+
self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
|
106 |
+
|
107 |
+
# Fix the highway input if necessary
|
108 |
+
if proj_channels[-1] != channels:
|
109 |
+
self.highway_mismatch = True
|
110 |
+
self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
|
111 |
+
else:
|
112 |
+
self.highway_mismatch = False
|
113 |
+
|
114 |
+
self.highways = nn.ModuleList()
|
115 |
+
for i in range(num_highways):
|
116 |
+
hn = HighwayNetwork(channels)
|
117 |
+
self.highways.append(hn)
|
118 |
+
|
119 |
+
self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
|
120 |
+
self._to_flatten.append(self.rnn)
|
121 |
+
|
122 |
+
# Avoid fragmentation of RNN parameters and associated warning
|
123 |
+
self._flatten_parameters()
|
124 |
+
|
125 |
+
def forward(self, x):
|
126 |
+
# Although we `_flatten_parameters()` on init, when using DataParallel
|
127 |
+
# the model gets replicated, making it no longer guaranteed that the
|
128 |
+
# weights are contiguous in GPU memory. Hence, we must call it again
|
129 |
+
self._flatten_parameters()
|
130 |
+
|
131 |
+
# Save these for later
|
132 |
+
residual = x
|
133 |
+
seq_len = x.size(-1)
|
134 |
+
conv_bank = []
|
135 |
+
|
136 |
+
# Convolution Bank
|
137 |
+
for conv in self.conv1d_bank:
|
138 |
+
c = conv(x) # Convolution
|
139 |
+
conv_bank.append(c[:, :, :seq_len])
|
140 |
+
|
141 |
+
# Stack along the channel axis
|
142 |
+
conv_bank = torch.cat(conv_bank, dim=1)
|
143 |
+
|
144 |
+
# dump the last padding to fit residual
|
145 |
+
x = self.maxpool(conv_bank)[:, :, :seq_len]
|
146 |
+
|
147 |
+
# Conv1d projections
|
148 |
+
x = self.conv_project1(x)
|
149 |
+
x = self.conv_project2(x)
|
150 |
+
|
151 |
+
# Residual Connect
|
152 |
+
x = x + residual
|
153 |
+
|
154 |
+
# Through the highways
|
155 |
+
x = x.transpose(1, 2)
|
156 |
+
if self.highway_mismatch is True:
|
157 |
+
x = self.pre_highway(x)
|
158 |
+
for h in self.highways: x = h(x)
|
159 |
+
|
160 |
+
# And then the RNN
|
161 |
+
x, _ = self.rnn(x)
|
162 |
+
return x
|
163 |
+
|
164 |
+
def _flatten_parameters(self):
|
165 |
+
"""Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
|
166 |
+
to improve efficiency and avoid PyTorch yelling at us."""
|
167 |
+
[m.flatten_parameters() for m in self._to_flatten]
|
168 |
+
|
169 |
+
class PreNet(nn.Module):
|
170 |
+
def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
|
171 |
+
super().__init__()
|
172 |
+
self.fc1 = nn.Linear(in_dims, fc1_dims)
|
173 |
+
self.fc2 = nn.Linear(fc1_dims, fc2_dims)
|
174 |
+
self.p = dropout
|
175 |
+
|
176 |
+
def forward(self, x):
|
177 |
+
x = self.fc1(x)
|
178 |
+
x = F.relu(x)
|
179 |
+
x = F.dropout(x, self.p, training=True)
|
180 |
+
x = self.fc2(x)
|
181 |
+
x = F.relu(x)
|
182 |
+
x = F.dropout(x, self.p, training=True)
|
183 |
+
return x
|
184 |
+
|
185 |
+
|
186 |
+
class Attention(nn.Module):
|
187 |
+
def __init__(self, attn_dims):
|
188 |
+
super().__init__()
|
189 |
+
self.W = nn.Linear(attn_dims, attn_dims, bias=False)
|
190 |
+
self.v = nn.Linear(attn_dims, 1, bias=False)
|
191 |
+
|
192 |
+
def forward(self, encoder_seq_proj, query, t):
|
193 |
+
|
194 |
+
# print(encoder_seq_proj.shape)
|
195 |
+
# Transform the query vector
|
196 |
+
query_proj = self.W(query).unsqueeze(1)
|
197 |
+
|
198 |
+
# Compute the scores
|
199 |
+
u = self.v(torch.tanh(encoder_seq_proj + query_proj))
|
200 |
+
scores = F.softmax(u, dim=1)
|
201 |
+
|
202 |
+
return scores.transpose(1, 2)
|
203 |
+
|
204 |
+
|
205 |
+
class LSA(nn.Module):
|
206 |
+
def __init__(self, attn_dim, kernel_size=31, filters=32):
|
207 |
+
super().__init__()
|
208 |
+
self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
|
209 |
+
self.L = nn.Linear(filters, attn_dim, bias=False)
|
210 |
+
self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
|
211 |
+
self.v = nn.Linear(attn_dim, 1, bias=False)
|
212 |
+
self.cumulative = None
|
213 |
+
self.attention = None
|
214 |
+
|
215 |
+
def init_attention(self, encoder_seq_proj):
|
216 |
+
device = next(self.parameters()).device # use same device as parameters
|
217 |
+
b, t, c = encoder_seq_proj.size()
|
218 |
+
self.cumulative = torch.zeros(b, t, device=device)
|
219 |
+
self.attention = torch.zeros(b, t, device=device)
|
220 |
+
|
221 |
+
def forward(self, encoder_seq_proj, query, t, chars):
|
222 |
+
|
223 |
+
if t == 0: self.init_attention(encoder_seq_proj)
|
224 |
+
|
225 |
+
processed_query = self.W(query).unsqueeze(1)
|
226 |
+
|
227 |
+
location = self.cumulative.unsqueeze(1)
|
228 |
+
processed_loc = self.L(self.conv(location).transpose(1, 2))
|
229 |
+
|
230 |
+
u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
|
231 |
+
u = u.squeeze(-1)
|
232 |
+
|
233 |
+
# Mask zero padding chars
|
234 |
+
u = u * (chars != 0).float()
|
235 |
+
|
236 |
+
# Smooth Attention
|
237 |
+
# scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
|
238 |
+
scores = F.softmax(u, dim=1)
|
239 |
+
self.attention = scores
|
240 |
+
self.cumulative = self.cumulative + self.attention
|
241 |
+
|
242 |
+
return scores.unsqueeze(-1).transpose(1, 2)
|
243 |
+
|
244 |
+
|
245 |
+
class Decoder(nn.Module):
|
246 |
+
# Class variable because its value doesn't change between classes
|
247 |
+
# yet ought to be scoped by class because its a property of a Decoder
|
248 |
+
max_r = 20
|
249 |
+
def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
|
250 |
+
dropout, speaker_embedding_size):
|
251 |
+
super().__init__()
|
252 |
+
self.register_buffer("r", torch.tensor(1, dtype=torch.int))
|
253 |
+
self.n_mels = n_mels
|
254 |
+
prenet_dims = (decoder_dims * 2, decoder_dims * 2)
|
255 |
+
self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
|
256 |
+
dropout=dropout)
|
257 |
+
self.attn_net = LSA(decoder_dims)
|
258 |
+
self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
|
259 |
+
self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
|
260 |
+
self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
|
261 |
+
self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
|
262 |
+
self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
|
263 |
+
self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
|
264 |
+
|
265 |
+
def zoneout(self, prev, current, p=0.1):
|
266 |
+
device = next(self.parameters()).device # Use same device as parameters
|
267 |
+
mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
|
268 |
+
return prev * mask + current * (1 - mask)
|
269 |
+
|
270 |
+
def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
|
271 |
+
hidden_states, cell_states, context_vec, t, chars):
|
272 |
+
|
273 |
+
# Need this for reshaping mels
|
274 |
+
batch_size = encoder_seq.size(0)
|
275 |
+
|
276 |
+
# Unpack the hidden and cell states
|
277 |
+
attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
|
278 |
+
rnn1_cell, rnn2_cell = cell_states
|
279 |
+
|
280 |
+
# PreNet for the Attention RNN
|
281 |
+
prenet_out = self.prenet(prenet_in)
|
282 |
+
|
283 |
+
# Compute the Attention RNN hidden state
|
284 |
+
attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
|
285 |
+
attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
|
286 |
+
|
287 |
+
# Compute the attention scores
|
288 |
+
scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
|
289 |
+
|
290 |
+
# Dot product to create the context vector
|
291 |
+
context_vec = scores @ encoder_seq
|
292 |
+
context_vec = context_vec.squeeze(1)
|
293 |
+
|
294 |
+
# Concat Attention RNN output w. Context Vector & project
|
295 |
+
x = torch.cat([context_vec, attn_hidden], dim=1)
|
296 |
+
x = self.rnn_input(x)
|
297 |
+
|
298 |
+
# Compute first Residual RNN
|
299 |
+
rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
|
300 |
+
if self.training:
|
301 |
+
rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
|
302 |
+
else:
|
303 |
+
rnn1_hidden = rnn1_hidden_next
|
304 |
+
x = x + rnn1_hidden
|
305 |
+
|
306 |
+
# Compute second Residual RNN
|
307 |
+
rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
|
308 |
+
if self.training:
|
309 |
+
rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
|
310 |
+
else:
|
311 |
+
rnn2_hidden = rnn2_hidden_next
|
312 |
+
x = x + rnn2_hidden
|
313 |
+
|
314 |
+
# Project Mels
|
315 |
+
mels = self.mel_proj(x)
|
316 |
+
mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
|
317 |
+
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
|
318 |
+
cell_states = (rnn1_cell, rnn2_cell)
|
319 |
+
|
320 |
+
# Stop token prediction
|
321 |
+
s = torch.cat((x, context_vec), dim=1)
|
322 |
+
s = self.stop_proj(s)
|
323 |
+
stop_tokens = torch.sigmoid(s)
|
324 |
+
|
325 |
+
return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
|
326 |
+
|
327 |
+
|
328 |
+
class Tacotron(nn.Module):
|
329 |
+
def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
|
330 |
+
fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
|
331 |
+
dropout, stop_threshold, speaker_embedding_size):
|
332 |
+
super().__init__()
|
333 |
+
self.n_mels = n_mels
|
334 |
+
self.lstm_dims = lstm_dims
|
335 |
+
self.encoder_dims = encoder_dims
|
336 |
+
self.decoder_dims = decoder_dims
|
337 |
+
self.speaker_embedding_size = speaker_embedding_size
|
338 |
+
self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
|
339 |
+
encoder_K, num_highways, dropout)
|
340 |
+
self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
|
341 |
+
self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
|
342 |
+
dropout, speaker_embedding_size)
|
343 |
+
self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
|
344 |
+
[postnet_dims, fft_bins], num_highways)
|
345 |
+
self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
|
346 |
+
|
347 |
+
self.init_model()
|
348 |
+
self.num_params()
|
349 |
+
|
350 |
+
self.register_buffer("step", torch.zeros(1, dtype=torch.long))
|
351 |
+
self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
|
352 |
+
|
353 |
+
@property
|
354 |
+
def r(self):
|
355 |
+
return self.decoder.r.item()
|
356 |
+
|
357 |
+
@r.setter
|
358 |
+
def r(self, value):
|
359 |
+
self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
|
360 |
+
|
361 |
+
def forward(self, x, m, speaker_embedding):
|
362 |
+
device = next(self.parameters()).device # use same device as parameters
|
363 |
+
|
364 |
+
self.step += 1
|
365 |
+
batch_size, _, steps = m.size()
|
366 |
+
|
367 |
+
# Initialise all hidden states and pack into tuple
|
368 |
+
attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
|
369 |
+
rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
370 |
+
rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
371 |
+
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
|
372 |
+
|
373 |
+
# Initialise all lstm cell states and pack into tuple
|
374 |
+
rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
375 |
+
rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
376 |
+
cell_states = (rnn1_cell, rnn2_cell)
|
377 |
+
|
378 |
+
# <GO> Frame for start of decoder loop
|
379 |
+
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
380 |
+
|
381 |
+
# Need an initial context vector
|
382 |
+
context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
|
383 |
+
|
384 |
+
# SV2TTS: Run the encoder with the speaker embedding
|
385 |
+
# The projection avoids unnecessary matmuls in the decoder loop
|
386 |
+
encoder_seq = self.encoder(x, speaker_embedding)
|
387 |
+
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
388 |
+
|
389 |
+
# Need a couple of lists for outputs
|
390 |
+
mel_outputs, attn_scores, stop_outputs = [], [], []
|
391 |
+
|
392 |
+
# Run the decoder loop
|
393 |
+
for t in range(0, steps, self.r):
|
394 |
+
prenet_in = m[:, :, t - 1] if t > 0 else go_frame
|
395 |
+
mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
|
396 |
+
self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
|
397 |
+
hidden_states, cell_states, context_vec, t, x)
|
398 |
+
mel_outputs.append(mel_frames)
|
399 |
+
attn_scores.append(scores)
|
400 |
+
stop_outputs.extend([stop_tokens] * self.r)
|
401 |
+
|
402 |
+
# Concat the mel outputs into sequence
|
403 |
+
mel_outputs = torch.cat(mel_outputs, dim=2)
|
404 |
+
|
405 |
+
# Post-Process for Linear Spectrograms
|
406 |
+
postnet_out = self.postnet(mel_outputs)
|
407 |
+
linear = self.post_proj(postnet_out)
|
408 |
+
linear = linear.transpose(1, 2)
|
409 |
+
|
410 |
+
# For easy visualisation
|
411 |
+
attn_scores = torch.cat(attn_scores, 1)
|
412 |
+
# attn_scores = attn_scores.cpu().data.numpy()
|
413 |
+
stop_outputs = torch.cat(stop_outputs, 1)
|
414 |
+
|
415 |
+
return mel_outputs, linear, attn_scores, stop_outputs
|
416 |
+
|
417 |
+
def generate(self, x, speaker_embedding=None, steps=2000):
|
418 |
+
self.eval()
|
419 |
+
device = next(self.parameters()).device # use same device as parameters
|
420 |
+
|
421 |
+
batch_size, _ = x.size()
|
422 |
+
|
423 |
+
# Need to initialise all hidden states and pack into tuple for tidyness
|
424 |
+
attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
|
425 |
+
rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
426 |
+
rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
427 |
+
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
|
428 |
+
|
429 |
+
# Need to initialise all lstm cell states and pack into tuple for tidyness
|
430 |
+
rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
431 |
+
rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
432 |
+
cell_states = (rnn1_cell, rnn2_cell)
|
433 |
+
|
434 |
+
# Need a <GO> Frame for start of decoder loop
|
435 |
+
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
436 |
+
|
437 |
+
# Need an initial context vector
|
438 |
+
context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
|
439 |
+
|
440 |
+
# SV2TTS: Run the encoder with the speaker embedding
|
441 |
+
# The projection avoids unnecessary matmuls in the decoder loop
|
442 |
+
encoder_seq = self.encoder(x, speaker_embedding)
|
443 |
+
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
444 |
+
|
445 |
+
# Need a couple of lists for outputs
|
446 |
+
mel_outputs, attn_scores, stop_outputs = [], [], []
|
447 |
+
|
448 |
+
# Run the decoder loop
|
449 |
+
for t in range(0, steps, self.r):
|
450 |
+
prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
|
451 |
+
mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
|
452 |
+
self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
|
453 |
+
hidden_states, cell_states, context_vec, t, x)
|
454 |
+
mel_outputs.append(mel_frames)
|
455 |
+
attn_scores.append(scores)
|
456 |
+
stop_outputs.extend([stop_tokens] * self.r)
|
457 |
+
# Stop the loop when all stop tokens in batch exceed threshold
|
458 |
+
if (stop_tokens > 0.5).all() and t > 10: break
|
459 |
+
|
460 |
+
# Concat the mel outputs into sequence
|
461 |
+
mel_outputs = torch.cat(mel_outputs, dim=2)
|
462 |
+
|
463 |
+
# Post-Process for Linear Spectrograms
|
464 |
+
postnet_out = self.postnet(mel_outputs)
|
465 |
+
linear = self.post_proj(postnet_out)
|
466 |
+
|
467 |
+
|
468 |
+
linear = linear.transpose(1, 2)
|
469 |
+
|
470 |
+
# For easy visualisation
|
471 |
+
attn_scores = torch.cat(attn_scores, 1)
|
472 |
+
stop_outputs = torch.cat(stop_outputs, 1)
|
473 |
+
|
474 |
+
self.train()
|
475 |
+
|
476 |
+
return mel_outputs, linear, attn_scores
|
477 |
+
|
478 |
+
def init_model(self):
|
479 |
+
for p in self.parameters():
|
480 |
+
if p.dim() > 1: nn.init.xavier_uniform_(p)
|
481 |
+
|
482 |
+
def get_step(self):
|
483 |
+
return self.step.data.item()
|
484 |
+
|
485 |
+
def reset_step(self):
|
486 |
+
# assignment to parameters or buffers is overloaded, updates internal dict entry
|
487 |
+
self.step = self.step.data.new_tensor(1)
|
488 |
+
|
489 |
+
def log(self, path, msg):
|
490 |
+
with open(path, "a") as f:
|
491 |
+
print(msg, file=f)
|
492 |
+
|
493 |
+
def load(self, path, optimizer=None):
|
494 |
+
# Use device of model params as location for loaded state
|
495 |
+
device = next(self.parameters()).device
|
496 |
+
checkpoint = torch.load(str(path), map_location=device)
|
497 |
+
self.load_state_dict(checkpoint["model_state"])
|
498 |
+
|
499 |
+
if "optimizer_state" in checkpoint and optimizer is not None:
|
500 |
+
optimizer.load_state_dict(checkpoint["optimizer_state"])
|
501 |
+
|
502 |
+
def save(self, path, optimizer=None):
|
503 |
+
if optimizer is not None:
|
504 |
+
torch.save({
|
505 |
+
"model_state": self.state_dict(),
|
506 |
+
"optimizer_state": optimizer.state_dict(),
|
507 |
+
}, str(path))
|
508 |
+
else:
|
509 |
+
torch.save({
|
510 |
+
"model_state": self.state_dict(),
|
511 |
+
}, str(path))
|
512 |
+
|
513 |
+
|
514 |
+
def num_params(self, print_out=True):
|
515 |
+
parameters = filter(lambda p: p.requires_grad, self.parameters())
|
516 |
+
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
|
517 |
+
if print_out:
|
518 |
+
print("Trainable Parameters: %.3fM" % parameters)
|
519 |
+
return parameters
|
synthesizer/preprocess.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from multiprocessing.pool import Pool
|
2 |
+
from synthesizer import audio
|
3 |
+
from functools import partial
|
4 |
+
from itertools import chain
|
5 |
+
from encoder import inference as encoder
|
6 |
+
from pathlib import Path
|
7 |
+
from utils import logmmse
|
8 |
+
from tqdm import tqdm
|
9 |
+
import numpy as np
|
10 |
+
import librosa
|
11 |
+
|
12 |
+
|
13 |
+
def preprocess_dataset(datasets_root: Path, out_dir: Path, n_processes: int,
|
14 |
+
skip_existing: bool, hparams, no_alignments: bool,
|
15 |
+
datasets_name: str, subfolders: str):
|
16 |
+
# Gather the input directories
|
17 |
+
dataset_root = datasets_root.joinpath(datasets_name)
|
18 |
+
input_dirs = [dataset_root.joinpath(subfolder.strip()) for subfolder in subfolders.split(",")]
|
19 |
+
print("\n ".join(map(str, ["Using data from:"] + input_dirs)))
|
20 |
+
assert all(input_dir.exists() for input_dir in input_dirs)
|
21 |
+
|
22 |
+
# Create the output directories for each output file type
|
23 |
+
out_dir.joinpath("mels").mkdir(exist_ok=True)
|
24 |
+
out_dir.joinpath("audio").mkdir(exist_ok=True)
|
25 |
+
|
26 |
+
# Create a metadata file
|
27 |
+
metadata_fpath = out_dir.joinpath("train.txt")
|
28 |
+
metadata_file = metadata_fpath.open("a" if skip_existing else "w", encoding="utf-8")
|
29 |
+
|
30 |
+
# Preprocess the dataset
|
31 |
+
speaker_dirs = list(chain.from_iterable(input_dir.glob("*") for input_dir in input_dirs))
|
32 |
+
func = partial(preprocess_speaker, out_dir=out_dir, skip_existing=skip_existing,
|
33 |
+
hparams=hparams, no_alignments=no_alignments)
|
34 |
+
job = Pool(n_processes).imap(func, speaker_dirs)
|
35 |
+
for speaker_metadata in tqdm(job, datasets_name, len(speaker_dirs), unit="speakers"):
|
36 |
+
for metadatum in speaker_metadata:
|
37 |
+
metadata_file.write("|".join(str(x) for x in metadatum) + "\n")
|
38 |
+
metadata_file.close()
|
39 |
+
|
40 |
+
# Verify the contents of the metadata file
|
41 |
+
with metadata_fpath.open("r", encoding="utf-8") as metadata_file:
|
42 |
+
metadata = [line.split("|") for line in metadata_file]
|
43 |
+
mel_frames = sum([int(m[4]) for m in metadata])
|
44 |
+
timesteps = sum([int(m[3]) for m in metadata])
|
45 |
+
sample_rate = hparams.sample_rate
|
46 |
+
hours = (timesteps / sample_rate) / 3600
|
47 |
+
print("The dataset consists of %d utterances, %d mel frames, %d audio timesteps (%.2f hours)." %
|
48 |
+
(len(metadata), mel_frames, timesteps, hours))
|
49 |
+
print("Max input length (text chars): %d" % max(len(m[5]) for m in metadata))
|
50 |
+
print("Max mel frames length: %d" % max(int(m[4]) for m in metadata))
|
51 |
+
print("Max audio timesteps length: %d" % max(int(m[3]) for m in metadata))
|
52 |
+
|
53 |
+
|
54 |
+
def preprocess_speaker(speaker_dir, out_dir: Path, skip_existing: bool, hparams, no_alignments: bool):
|
55 |
+
metadata = []
|
56 |
+
for book_dir in speaker_dir.glob("*"):
|
57 |
+
if no_alignments:
|
58 |
+
# Gather the utterance audios and texts
|
59 |
+
# LibriTTS uses .wav but we will include extensions for compatibility with other datasets
|
60 |
+
extensions = ["*.wav", "*.flac", "*.mp3"]
|
61 |
+
for extension in extensions:
|
62 |
+
wav_fpaths = book_dir.glob(extension)
|
63 |
+
|
64 |
+
for wav_fpath in wav_fpaths:
|
65 |
+
# Load the audio waveform
|
66 |
+
wav, _ = librosa.load(str(wav_fpath), hparams.sample_rate)
|
67 |
+
if hparams.rescale:
|
68 |
+
wav = wav / np.abs(wav).max() * hparams.rescaling_max
|
69 |
+
|
70 |
+
# Get the corresponding text
|
71 |
+
# Check for .txt (for compatibility with other datasets)
|
72 |
+
text_fpath = wav_fpath.with_suffix(".txt")
|
73 |
+
if not text_fpath.exists():
|
74 |
+
# Check for .normalized.txt (LibriTTS)
|
75 |
+
text_fpath = wav_fpath.with_suffix(".normalized.txt")
|
76 |
+
assert text_fpath.exists()
|
77 |
+
with text_fpath.open("r") as text_file:
|
78 |
+
text = "".join([line for line in text_file])
|
79 |
+
text = text.replace("\"", "")
|
80 |
+
text = text.strip()
|
81 |
+
|
82 |
+
# Process the utterance
|
83 |
+
metadata.append(process_utterance(wav, text, out_dir, str(wav_fpath.with_suffix("").name),
|
84 |
+
skip_existing, hparams))
|
85 |
+
else:
|
86 |
+
# Process alignment file (LibriSpeech support)
|
87 |
+
# Gather the utterance audios and texts
|
88 |
+
try:
|
89 |
+
alignments_fpath = next(book_dir.glob("*.alignment.txt"))
|
90 |
+
with alignments_fpath.open("r") as alignments_file:
|
91 |
+
alignments = [line.rstrip().split(" ") for line in alignments_file]
|
92 |
+
except StopIteration:
|
93 |
+
# A few alignment files will be missing
|
94 |
+
continue
|
95 |
+
|
96 |
+
# Iterate over each entry in the alignments file
|
97 |
+
for wav_fname, words, end_times in alignments:
|
98 |
+
wav_fpath = book_dir.joinpath(wav_fname + ".flac")
|
99 |
+
assert wav_fpath.exists()
|
100 |
+
words = words.replace("\"", "").split(",")
|
101 |
+
end_times = list(map(float, end_times.replace("\"", "").split(",")))
|
102 |
+
|
103 |
+
# Process each sub-utterance
|
104 |
+
wavs, texts = split_on_silences(wav_fpath, words, end_times, hparams)
|
105 |
+
for i, (wav, text) in enumerate(zip(wavs, texts)):
|
106 |
+
sub_basename = "%s_%02d" % (wav_fname, i)
|
107 |
+
metadata.append(process_utterance(wav, text, out_dir, sub_basename,
|
108 |
+
skip_existing, hparams))
|
109 |
+
|
110 |
+
return [m for m in metadata if m is not None]
|
111 |
+
|
112 |
+
|
113 |
+
def split_on_silences(wav_fpath, words, end_times, hparams):
|
114 |
+
# Load the audio waveform
|
115 |
+
wav, _ = librosa.load(str(wav_fpath), hparams.sample_rate)
|
116 |
+
if hparams.rescale:
|
117 |
+
wav = wav / np.abs(wav).max() * hparams.rescaling_max
|
118 |
+
|
119 |
+
words = np.array(words)
|
120 |
+
start_times = np.array([0.0] + end_times[:-1])
|
121 |
+
end_times = np.array(end_times)
|
122 |
+
assert len(words) == len(end_times) == len(start_times)
|
123 |
+
assert words[0] == "" and words[-1] == ""
|
124 |
+
|
125 |
+
# Find pauses that are too long
|
126 |
+
mask = (words == "") & (end_times - start_times >= hparams.silence_min_duration_split)
|
127 |
+
mask[0] = mask[-1] = True
|
128 |
+
breaks = np.where(mask)[0]
|
129 |
+
|
130 |
+
# Profile the noise from the silences and perform noise reduction on the waveform
|
131 |
+
silence_times = [[start_times[i], end_times[i]] for i in breaks]
|
132 |
+
silence_times = (np.array(silence_times) * hparams.sample_rate).astype(np.int)
|
133 |
+
noisy_wav = np.concatenate([wav[stime[0]:stime[1]] for stime in silence_times])
|
134 |
+
if len(noisy_wav) > hparams.sample_rate * 0.02:
|
135 |
+
profile = logmmse.profile_noise(noisy_wav, hparams.sample_rate)
|
136 |
+
wav = logmmse.denoise(wav, profile, eta=0)
|
137 |
+
|
138 |
+
# Re-attach segments that are too short
|
139 |
+
segments = list(zip(breaks[:-1], breaks[1:]))
|
140 |
+
segment_durations = [start_times[end] - end_times[start] for start, end in segments]
|
141 |
+
i = 0
|
142 |
+
while i < len(segments) and len(segments) > 1:
|
143 |
+
if segment_durations[i] < hparams.utterance_min_duration:
|
144 |
+
# See if the segment can be re-attached with the right or the left segment
|
145 |
+
left_duration = float("inf") if i == 0 else segment_durations[i - 1]
|
146 |
+
right_duration = float("inf") if i == len(segments) - 1 else segment_durations[i + 1]
|
147 |
+
joined_duration = segment_durations[i] + min(left_duration, right_duration)
|
148 |
+
|
149 |
+
# Do not re-attach if it causes the joined utterance to be too long
|
150 |
+
if joined_duration > hparams.hop_size * hparams.max_mel_frames / hparams.sample_rate:
|
151 |
+
i += 1
|
152 |
+
continue
|
153 |
+
|
154 |
+
# Re-attach the segment with the neighbour of shortest duration
|
155 |
+
j = i - 1 if left_duration <= right_duration else i
|
156 |
+
segments[j] = (segments[j][0], segments[j + 1][1])
|
157 |
+
segment_durations[j] = joined_duration
|
158 |
+
del segments[j + 1], segment_durations[j + 1]
|
159 |
+
else:
|
160 |
+
i += 1
|
161 |
+
|
162 |
+
# Split the utterance
|
163 |
+
segment_times = [[end_times[start], start_times[end]] for start, end in segments]
|
164 |
+
segment_times = (np.array(segment_times) * hparams.sample_rate).astype(np.int)
|
165 |
+
wavs = [wav[segment_time[0]:segment_time[1]] for segment_time in segment_times]
|
166 |
+
texts = [" ".join(words[start + 1:end]).replace(" ", " ") for start, end in segments]
|
167 |
+
|
168 |
+
# # DEBUG: play the audio segments (run with -n=1)
|
169 |
+
# import sounddevice as sd
|
170 |
+
# if len(wavs) > 1:
|
171 |
+
# print("This sentence was split in %d segments:" % len(wavs))
|
172 |
+
# else:
|
173 |
+
# print("There are no silences long enough for this sentence to be split:")
|
174 |
+
# for wav, text in zip(wavs, texts):
|
175 |
+
# # Pad the waveform with 1 second of silence because sounddevice tends to cut them early
|
176 |
+
# # when playing them. You shouldn't need to do that in your parsers.
|
177 |
+
# wav = np.concatenate((wav, [0] * 16000))
|
178 |
+
# print("\t%s" % text)
|
179 |
+
# sd.play(wav, 16000, blocking=True)
|
180 |
+
# print("")
|
181 |
+
|
182 |
+
return wavs, texts
|
183 |
+
|
184 |
+
|
185 |
+
def process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
|
186 |
+
skip_existing: bool, hparams):
|
187 |
+
## FOR REFERENCE:
|
188 |
+
# For you not to lose your head if you ever wish to change things here or implement your own
|
189 |
+
# synthesizer.
|
190 |
+
# - Both the audios and the mel spectrograms are saved as numpy arrays
|
191 |
+
# - There is no processing done to the audios that will be saved to disk beyond volume
|
192 |
+
# normalization (in split_on_silences)
|
193 |
+
# - However, pre-emphasis is applied to the audios before computing the mel spectrogram. This
|
194 |
+
# is why we re-apply it on the audio on the side of the vocoder.
|
195 |
+
# - Librosa pads the waveform before computing the mel spectrogram. Here, the waveform is saved
|
196 |
+
# without extra padding. This means that you won't have an exact relation between the length
|
197 |
+
# of the wav and of the mel spectrogram. See the vocoder data loader.
|
198 |
+
|
199 |
+
|
200 |
+
# Skip existing utterances if needed
|
201 |
+
mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename)
|
202 |
+
wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename)
|
203 |
+
if skip_existing and mel_fpath.exists() and wav_fpath.exists():
|
204 |
+
return None
|
205 |
+
|
206 |
+
# Trim silence
|
207 |
+
if hparams.trim_silence:
|
208 |
+
wav = encoder.preprocess_wav(wav, normalize=False, trim_silence=True)
|
209 |
+
|
210 |
+
# Skip utterances that are too short
|
211 |
+
if len(wav) < hparams.utterance_min_duration * hparams.sample_rate:
|
212 |
+
return None
|
213 |
+
|
214 |
+
# Compute the mel spectrogram
|
215 |
+
mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
|
216 |
+
mel_frames = mel_spectrogram.shape[1]
|
217 |
+
|
218 |
+
# Skip utterances that are too long
|
219 |
+
if mel_frames > hparams.max_mel_frames and hparams.clip_mels_length:
|
220 |
+
return None
|
221 |
+
|
222 |
+
# Write the spectrogram, embed and audio to disk
|
223 |
+
np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False)
|
224 |
+
np.save(wav_fpath, wav, allow_pickle=False)
|
225 |
+
|
226 |
+
# Return a tuple describing this training example
|
227 |
+
return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text
|
228 |
+
|
229 |
+
|
230 |
+
def embed_utterance(fpaths, encoder_model_fpath):
|
231 |
+
if not encoder.is_loaded():
|
232 |
+
encoder.load_model(encoder_model_fpath)
|
233 |
+
|
234 |
+
# Compute the speaker embedding of the utterance
|
235 |
+
wav_fpath, embed_fpath = fpaths
|
236 |
+
wav = np.load(wav_fpath)
|
237 |
+
wav = encoder.preprocess_wav(wav)
|
238 |
+
embed = encoder.embed_utterance(wav)
|
239 |
+
np.save(embed_fpath, embed, allow_pickle=False)
|
240 |
+
|
241 |
+
|
242 |
+
def create_embeddings(synthesizer_root: Path, encoder_model_fpath: Path, n_processes: int):
|
243 |
+
wav_dir = synthesizer_root.joinpath("audio")
|
244 |
+
metadata_fpath = synthesizer_root.joinpath("train.txt")
|
245 |
+
assert wav_dir.exists() and metadata_fpath.exists()
|
246 |
+
embed_dir = synthesizer_root.joinpath("embeds")
|
247 |
+
embed_dir.mkdir(exist_ok=True)
|
248 |
+
|
249 |
+
# Gather the input wave filepath and the target output embed filepath
|
250 |
+
with metadata_fpath.open("r") as metadata_file:
|
251 |
+
metadata = [line.split("|") for line in metadata_file]
|
252 |
+
fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
|
253 |
+
|
254 |
+
# TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
|
255 |
+
# Embed the utterances in separate threads
|
256 |
+
func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
|
257 |
+
job = Pool(n_processes).imap(func, fpaths)
|
258 |
+
list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
|
259 |
+
|
synthesizer/synthesize.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
from synthesizer.hparams import hparams_debug_string
|
4 |
+
from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
|
5 |
+
from synthesizer.models.tacotron import Tacotron
|
6 |
+
from synthesizer.utils.text import text_to_sequence
|
7 |
+
from synthesizer.utils.symbols import symbols
|
8 |
+
import numpy as np
|
9 |
+
from pathlib import Path
|
10 |
+
from tqdm import tqdm
|
11 |
+
import platform
|
12 |
+
|
13 |
+
def run_synthesis(in_dir, out_dir, model_dir, hparams):
|
14 |
+
# This generates ground truth-aligned mels for vocoder training
|
15 |
+
synth_dir = Path(out_dir).joinpath("mels_gta")
|
16 |
+
synth_dir.mkdir(exist_ok=True)
|
17 |
+
print(hparams_debug_string())
|
18 |
+
|
19 |
+
# Check for GPU
|
20 |
+
if torch.cuda.is_available():
|
21 |
+
device = torch.device("cuda")
|
22 |
+
if hparams.synthesis_batch_size % torch.cuda.device_count() != 0:
|
23 |
+
raise ValueError("`hparams.synthesis_batch_size` must be evenly divisible by n_gpus!")
|
24 |
+
else:
|
25 |
+
device = torch.device("cpu")
|
26 |
+
print("Synthesizer using device:", device)
|
27 |
+
|
28 |
+
# Instantiate Tacotron model
|
29 |
+
model = Tacotron(embed_dims=hparams.tts_embed_dims,
|
30 |
+
num_chars=len(symbols),
|
31 |
+
encoder_dims=hparams.tts_encoder_dims,
|
32 |
+
decoder_dims=hparams.tts_decoder_dims,
|
33 |
+
n_mels=hparams.num_mels,
|
34 |
+
fft_bins=hparams.num_mels,
|
35 |
+
postnet_dims=hparams.tts_postnet_dims,
|
36 |
+
encoder_K=hparams.tts_encoder_K,
|
37 |
+
lstm_dims=hparams.tts_lstm_dims,
|
38 |
+
postnet_K=hparams.tts_postnet_K,
|
39 |
+
num_highways=hparams.tts_num_highways,
|
40 |
+
dropout=0., # Use zero dropout for gta mels
|
41 |
+
stop_threshold=hparams.tts_stop_threshold,
|
42 |
+
speaker_embedding_size=hparams.speaker_embedding_size).to(device)
|
43 |
+
|
44 |
+
# Load the weights
|
45 |
+
model_dir = Path(model_dir)
|
46 |
+
model_fpath = model_dir.joinpath(model_dir.stem).with_suffix(".pt")
|
47 |
+
print("\nLoading weights at %s" % model_fpath)
|
48 |
+
model.load(model_fpath)
|
49 |
+
print("Tacotron weights loaded from step %d" % model.step)
|
50 |
+
|
51 |
+
# Synthesize using same reduction factor as the model is currently trained
|
52 |
+
r = np.int32(model.r)
|
53 |
+
|
54 |
+
# Set model to eval mode (disable gradient and zoneout)
|
55 |
+
model.eval()
|
56 |
+
|
57 |
+
# Initialize the dataset
|
58 |
+
in_dir = Path(in_dir)
|
59 |
+
metadata_fpath = in_dir.joinpath("train.txt")
|
60 |
+
mel_dir = in_dir.joinpath("mels")
|
61 |
+
embed_dir = in_dir.joinpath("embeds")
|
62 |
+
|
63 |
+
dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
|
64 |
+
data_loader = DataLoader(dataset,
|
65 |
+
collate_fn=lambda batch: collate_synthesizer(batch, r, hparams),
|
66 |
+
batch_size=hparams.synthesis_batch_size,
|
67 |
+
num_workers=2 if platform.system() != "Windows" else 0,
|
68 |
+
shuffle=False,
|
69 |
+
pin_memory=True)
|
70 |
+
|
71 |
+
# Generate GTA mels
|
72 |
+
meta_out_fpath = Path(out_dir).joinpath("synthesized.txt")
|
73 |
+
with open(meta_out_fpath, "w") as file:
|
74 |
+
for i, (texts, mels, embeds, idx) in tqdm(enumerate(data_loader), total=len(data_loader)):
|
75 |
+
texts = texts.to(device)
|
76 |
+
mels = mels.to(device)
|
77 |
+
embeds = embeds.to(device)
|
78 |
+
|
79 |
+
# Parallelize model onto GPUS using workaround due to python bug
|
80 |
+
if device.type == "cuda" and torch.cuda.device_count() > 1:
|
81 |
+
_, mels_out, _ = data_parallel_workaround(model, texts, mels, embeds)
|
82 |
+
else:
|
83 |
+
_, mels_out, _, _ = model(texts, mels, embeds)
|
84 |
+
|
85 |
+
for j, k in enumerate(idx):
|
86 |
+
# Note: outputs mel-spectrogram files and target ones have same names, just different folders
|
87 |
+
mel_filename = Path(synth_dir).joinpath(dataset.metadata[k][1])
|
88 |
+
mel_out = mels_out[j].detach().cpu().numpy().T
|
89 |
+
|
90 |
+
# Use the length of the ground truth mel to remove padding from the generated mels
|
91 |
+
mel_out = mel_out[:int(dataset.metadata[k][4])]
|
92 |
+
|
93 |
+
# Write the spectrogram to disk
|
94 |
+
np.save(mel_filename, mel_out, allow_pickle=False)
|
95 |
+
|
96 |
+
# Write metadata into the synthesized file
|
97 |
+
file.write("|".join(dataset.metadata[k]))
|
synthesizer/synthesizer_dataset.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
import numpy as np
|
4 |
+
from pathlib import Path
|
5 |
+
from synthesizer.utils.text import text_to_sequence
|
6 |
+
|
7 |
+
|
8 |
+
class SynthesizerDataset(Dataset):
|
9 |
+
def __init__(self, metadata_fpath: Path, mel_dir: Path, embed_dir: Path, hparams):
|
10 |
+
print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, embed_dir))
|
11 |
+
|
12 |
+
with metadata_fpath.open("r") as metadata_file:
|
13 |
+
metadata = [line.split("|") for line in metadata_file]
|
14 |
+
|
15 |
+
mel_fnames = [x[1] for x in metadata if int(x[4])]
|
16 |
+
mel_fpaths = [mel_dir.joinpath(fname) for fname in mel_fnames]
|
17 |
+
embed_fnames = [x[2] for x in metadata if int(x[4])]
|
18 |
+
embed_fpaths = [embed_dir.joinpath(fname) for fname in embed_fnames]
|
19 |
+
self.samples_fpaths = list(zip(mel_fpaths, embed_fpaths))
|
20 |
+
self.samples_texts = [x[5].strip() for x in metadata if int(x[4])]
|
21 |
+
self.metadata = metadata
|
22 |
+
self.hparams = hparams
|
23 |
+
|
24 |
+
print("Found %d samples" % len(self.samples_fpaths))
|
25 |
+
|
26 |
+
def __getitem__(self, index):
|
27 |
+
# Sometimes index may be a list of 2 (not sure why this happens)
|
28 |
+
# If that is the case, return a single item corresponding to first element in index
|
29 |
+
if index is list:
|
30 |
+
index = index[0]
|
31 |
+
|
32 |
+
mel_path, embed_path = self.samples_fpaths[index]
|
33 |
+
mel = np.load(mel_path).T.astype(np.float32)
|
34 |
+
|
35 |
+
# Load the embed
|
36 |
+
embed = np.load(embed_path)
|
37 |
+
|
38 |
+
# Get the text and clean it
|
39 |
+
text = text_to_sequence(self.samples_texts[index], self.hparams.tts_cleaner_names)
|
40 |
+
|
41 |
+
# Convert the list returned by text_to_sequence to a numpy array
|
42 |
+
text = np.asarray(text).astype(np.int32)
|
43 |
+
|
44 |
+
return text, mel.astype(np.float32), embed.astype(np.float32), index
|
45 |
+
|
46 |
+
def __len__(self):
|
47 |
+
return len(self.samples_fpaths)
|
48 |
+
|
49 |
+
|
50 |
+
def collate_synthesizer(batch, r, hparams):
|
51 |
+
# Text
|
52 |
+
x_lens = [len(x[0]) for x in batch]
|
53 |
+
max_x_len = max(x_lens)
|
54 |
+
|
55 |
+
chars = [pad1d(x[0], max_x_len) for x in batch]
|
56 |
+
chars = np.stack(chars)
|
57 |
+
|
58 |
+
# Mel spectrogram
|
59 |
+
spec_lens = [x[1].shape[-1] for x in batch]
|
60 |
+
max_spec_len = max(spec_lens) + 1
|
61 |
+
if max_spec_len % r != 0:
|
62 |
+
max_spec_len += r - max_spec_len % r
|
63 |
+
|
64 |
+
# WaveRNN mel spectrograms are normalized to [0, 1] so zero padding adds silence
|
65 |
+
# By default, SV2TTS uses symmetric mels, where -1*max_abs_value is silence.
|
66 |
+
if hparams.symmetric_mels:
|
67 |
+
mel_pad_value = -1 * hparams.max_abs_value
|
68 |
+
else:
|
69 |
+
mel_pad_value = 0
|
70 |
+
|
71 |
+
mel = [pad2d(x[1], max_spec_len, pad_value=mel_pad_value) for x in batch]
|
72 |
+
mel = np.stack(mel)
|
73 |
+
|
74 |
+
# Speaker embedding (SV2TTS)
|
75 |
+
embeds = [x[2] for x in batch]
|
76 |
+
|
77 |
+
# Index (for vocoder preprocessing)
|
78 |
+
indices = [x[3] for x in batch]
|
79 |
+
|
80 |
+
|
81 |
+
# Convert all to tensor
|
82 |
+
chars = torch.tensor(chars).long()
|
83 |
+
mel = torch.tensor(mel)
|
84 |
+
embeds = torch.tensor(embeds)
|
85 |
+
|
86 |
+
return chars, mel, embeds, indices
|
87 |
+
|
88 |
+
def pad1d(x, max_len, pad_value=0):
|
89 |
+
return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
|
90 |
+
|
91 |
+
def pad2d(x, max_len, pad_value=0):
|
92 |
+
return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant", constant_values=pad_value)
|
synthesizer/train.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import optim
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from synthesizer import audio
|
6 |
+
from synthesizer.models.tacotron import Tacotron
|
7 |
+
from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
|
8 |
+
from synthesizer.utils import ValueWindow, data_parallel_workaround
|
9 |
+
from synthesizer.utils.plot import plot_spectrogram
|
10 |
+
from synthesizer.utils.symbols import symbols
|
11 |
+
from synthesizer.utils.text import sequence_to_text
|
12 |
+
from vocoder.display import *
|
13 |
+
from datetime import datetime
|
14 |
+
import numpy as np
|
15 |
+
from pathlib import Path
|
16 |
+
import sys
|
17 |
+
import time
|
18 |
+
import platform
|
19 |
+
|
20 |
+
|
21 |
+
def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
|
22 |
+
|
23 |
+
def time_string():
|
24 |
+
return datetime.now().strftime("%Y-%m-%d %H:%M")
|
25 |
+
|
26 |
+
def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
27 |
+
backup_every: int, force_restart:bool, hparams):
|
28 |
+
|
29 |
+
syn_dir = Path(syn_dir)
|
30 |
+
models_dir = Path(models_dir)
|
31 |
+
models_dir.mkdir(exist_ok=True)
|
32 |
+
|
33 |
+
model_dir = models_dir.joinpath(run_id)
|
34 |
+
plot_dir = model_dir.joinpath("plots")
|
35 |
+
wav_dir = model_dir.joinpath("wavs")
|
36 |
+
mel_output_dir = model_dir.joinpath("mel-spectrograms")
|
37 |
+
meta_folder = model_dir.joinpath("metas")
|
38 |
+
model_dir.mkdir(exist_ok=True)
|
39 |
+
plot_dir.mkdir(exist_ok=True)
|
40 |
+
wav_dir.mkdir(exist_ok=True)
|
41 |
+
mel_output_dir.mkdir(exist_ok=True)
|
42 |
+
meta_folder.mkdir(exist_ok=True)
|
43 |
+
|
44 |
+
weights_fpath = model_dir.joinpath(run_id).with_suffix(".pt")
|
45 |
+
metadata_fpath = syn_dir.joinpath("train.txt")
|
46 |
+
|
47 |
+
print("Checkpoint path: {}".format(weights_fpath))
|
48 |
+
print("Loading training data from: {}".format(metadata_fpath))
|
49 |
+
print("Using model: Tacotron")
|
50 |
+
|
51 |
+
# Book keeping
|
52 |
+
step = 0
|
53 |
+
time_window = ValueWindow(100)
|
54 |
+
loss_window = ValueWindow(100)
|
55 |
+
|
56 |
+
|
57 |
+
# From WaveRNN/train_tacotron.py
|
58 |
+
if torch.cuda.is_available():
|
59 |
+
device = torch.device("cuda")
|
60 |
+
|
61 |
+
for session in hparams.tts_schedule:
|
62 |
+
_, _, _, batch_size = session
|
63 |
+
if batch_size % torch.cuda.device_count() != 0:
|
64 |
+
raise ValueError("`batch_size` must be evenly divisible by n_gpus!")
|
65 |
+
else:
|
66 |
+
device = torch.device("cpu")
|
67 |
+
print("Using device:", device)
|
68 |
+
|
69 |
+
# Instantiate Tacotron Model
|
70 |
+
print("\nInitialising Tacotron Model...\n")
|
71 |
+
model = Tacotron(embed_dims=hparams.tts_embed_dims,
|
72 |
+
num_chars=len(symbols),
|
73 |
+
encoder_dims=hparams.tts_encoder_dims,
|
74 |
+
decoder_dims=hparams.tts_decoder_dims,
|
75 |
+
n_mels=hparams.num_mels,
|
76 |
+
fft_bins=hparams.num_mels,
|
77 |
+
postnet_dims=hparams.tts_postnet_dims,
|
78 |
+
encoder_K=hparams.tts_encoder_K,
|
79 |
+
lstm_dims=hparams.tts_lstm_dims,
|
80 |
+
postnet_K=hparams.tts_postnet_K,
|
81 |
+
num_highways=hparams.tts_num_highways,
|
82 |
+
dropout=hparams.tts_dropout,
|
83 |
+
stop_threshold=hparams.tts_stop_threshold,
|
84 |
+
speaker_embedding_size=hparams.speaker_embedding_size).to(device)
|
85 |
+
|
86 |
+
# Initialize the optimizer
|
87 |
+
optimizer = optim.Adam(model.parameters())
|
88 |
+
|
89 |
+
# Load the weights
|
90 |
+
if force_restart or not weights_fpath.exists():
|
91 |
+
print("\nStarting the training of Tacotron from scratch\n")
|
92 |
+
model.save(weights_fpath)
|
93 |
+
|
94 |
+
# Embeddings metadata
|
95 |
+
char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv")
|
96 |
+
with open(char_embedding_fpath, "w", encoding="utf-8") as f:
|
97 |
+
for symbol in symbols:
|
98 |
+
if symbol == " ":
|
99 |
+
symbol = "\\s" # For visual purposes, swap space with \s
|
100 |
+
|
101 |
+
f.write("{}\n".format(symbol))
|
102 |
+
|
103 |
+
else:
|
104 |
+
print("\nLoading weights at %s" % weights_fpath)
|
105 |
+
model.load(weights_fpath, optimizer)
|
106 |
+
print("Tacotron weights loaded from step %d" % model.step)
|
107 |
+
|
108 |
+
# Initialize the dataset
|
109 |
+
metadata_fpath = syn_dir.joinpath("train.txt")
|
110 |
+
mel_dir = syn_dir.joinpath("mels")
|
111 |
+
embed_dir = syn_dir.joinpath("embeds")
|
112 |
+
dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
|
113 |
+
test_loader = DataLoader(dataset,
|
114 |
+
batch_size=1,
|
115 |
+
shuffle=True,
|
116 |
+
pin_memory=True)
|
117 |
+
|
118 |
+
for i, session in enumerate(hparams.tts_schedule):
|
119 |
+
current_step = model.get_step()
|
120 |
+
|
121 |
+
r, lr, max_step, batch_size = session
|
122 |
+
|
123 |
+
training_steps = max_step - current_step
|
124 |
+
|
125 |
+
# Do we need to change to the next session?
|
126 |
+
if current_step >= max_step:
|
127 |
+
# Are there no further sessions than the current one?
|
128 |
+
if i == len(hparams.tts_schedule) - 1:
|
129 |
+
# We have completed training. Save the model and exit
|
130 |
+
model.save(weights_fpath, optimizer)
|
131 |
+
break
|
132 |
+
else:
|
133 |
+
# There is a following session, go to it
|
134 |
+
continue
|
135 |
+
|
136 |
+
model.r = r
|
137 |
+
|
138 |
+
# Begin the training
|
139 |
+
simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
|
140 |
+
("Batch Size", batch_size),
|
141 |
+
("Learning Rate", lr),
|
142 |
+
("Outputs/Step (r)", model.r)])
|
143 |
+
|
144 |
+
for p in optimizer.param_groups:
|
145 |
+
p["lr"] = lr
|
146 |
+
|
147 |
+
data_loader = DataLoader(dataset,
|
148 |
+
collate_fn=lambda batch: collate_synthesizer(batch, r, hparams),
|
149 |
+
batch_size=batch_size,
|
150 |
+
num_workers=2 if platform.system() != "Windows" else 0,
|
151 |
+
shuffle=True,
|
152 |
+
pin_memory=True)
|
153 |
+
|
154 |
+
total_iters = len(dataset)
|
155 |
+
steps_per_epoch = np.ceil(total_iters / batch_size).astype(np.int32)
|
156 |
+
epochs = np.ceil(training_steps / steps_per_epoch).astype(np.int32)
|
157 |
+
|
158 |
+
for epoch in range(1, epochs+1):
|
159 |
+
for i, (texts, mels, embeds, idx) in enumerate(data_loader, 1):
|
160 |
+
start_time = time.time()
|
161 |
+
|
162 |
+
# Generate stop tokens for training
|
163 |
+
stop = torch.ones(mels.shape[0], mels.shape[2])
|
164 |
+
for j, k in enumerate(idx):
|
165 |
+
stop[j, :int(dataset.metadata[k][4])-1] = 0
|
166 |
+
|
167 |
+
texts = texts.to(device)
|
168 |
+
mels = mels.to(device)
|
169 |
+
embeds = embeds.to(device)
|
170 |
+
stop = stop.to(device)
|
171 |
+
|
172 |
+
# Forward pass
|
173 |
+
# Parallelize model onto GPUS using workaround due to python bug
|
174 |
+
if device.type == "cuda" and torch.cuda.device_count() > 1:
|
175 |
+
m1_hat, m2_hat, attention, stop_pred = data_parallel_workaround(model, texts,
|
176 |
+
mels, embeds)
|
177 |
+
else:
|
178 |
+
m1_hat, m2_hat, attention, stop_pred = model(texts, mels, embeds)
|
179 |
+
|
180 |
+
# Backward pass
|
181 |
+
m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
|
182 |
+
m2_loss = F.mse_loss(m2_hat, mels)
|
183 |
+
stop_loss = F.binary_cross_entropy(stop_pred, stop)
|
184 |
+
|
185 |
+
loss = m1_loss + m2_loss + stop_loss
|
186 |
+
|
187 |
+
optimizer.zero_grad()
|
188 |
+
loss.backward()
|
189 |
+
|
190 |
+
if hparams.tts_clip_grad_norm is not None:
|
191 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.tts_clip_grad_norm)
|
192 |
+
if np.isnan(grad_norm.cpu()):
|
193 |
+
print("grad_norm was NaN!")
|
194 |
+
|
195 |
+
optimizer.step()
|
196 |
+
|
197 |
+
time_window.append(time.time() - start_time)
|
198 |
+
loss_window.append(loss.item())
|
199 |
+
|
200 |
+
step = model.get_step()
|
201 |
+
k = step // 1000
|
202 |
+
|
203 |
+
msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | {1./time_window.average:#.2} steps/s | Step: {k}k | "
|
204 |
+
stream(msg)
|
205 |
+
|
206 |
+
# Backup or save model as appropriate
|
207 |
+
if backup_every != 0 and step % backup_every == 0 :
|
208 |
+
backup_fpath = Path("{}/{}_{}k.pt".format(str(weights_fpath.parent), run_id, k))
|
209 |
+
model.save(backup_fpath, optimizer)
|
210 |
+
|
211 |
+
if save_every != 0 and step % save_every == 0 :
|
212 |
+
# Must save latest optimizer state to ensure that resuming training
|
213 |
+
# doesn't produce artifacts
|
214 |
+
model.save(weights_fpath, optimizer)
|
215 |
+
|
216 |
+
# Evaluate model to generate samples
|
217 |
+
epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch # If epoch is done
|
218 |
+
step_eval = hparams.tts_eval_interval > 0 and step % hparams.tts_eval_interval == 0 # Every N steps
|
219 |
+
if epoch_eval or step_eval:
|
220 |
+
for sample_idx in range(hparams.tts_eval_num_samples):
|
221 |
+
# At most, generate samples equal to number in the batch
|
222 |
+
if sample_idx + 1 <= len(texts):
|
223 |
+
# Remove padding from mels using frame length in metadata
|
224 |
+
mel_length = int(dataset.metadata[idx[sample_idx]][4])
|
225 |
+
mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
|
226 |
+
target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
|
227 |
+
attention_len = mel_length // model.r
|
228 |
+
|
229 |
+
eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
|
230 |
+
mel_prediction=mel_prediction,
|
231 |
+
target_spectrogram=target_spectrogram,
|
232 |
+
input_seq=np_now(texts[sample_idx]),
|
233 |
+
step=step,
|
234 |
+
plot_dir=plot_dir,
|
235 |
+
mel_output_dir=mel_output_dir,
|
236 |
+
wav_dir=wav_dir,
|
237 |
+
sample_num=sample_idx + 1,
|
238 |
+
loss=loss,
|
239 |
+
hparams=hparams)
|
240 |
+
|
241 |
+
# Break out of loop to update training schedule
|
242 |
+
if step >= max_step:
|
243 |
+
break
|
244 |
+
|
245 |
+
# Add line break after every epoch
|
246 |
+
print("")
|
247 |
+
|
248 |
+
def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
|
249 |
+
plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams):
|
250 |
+
# Save some results for evaluation
|
251 |
+
attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num)))
|
252 |
+
save_attention(attention, attention_path)
|
253 |
+
|
254 |
+
# save predicted mel spectrogram to disk (debug)
|
255 |
+
mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num))
|
256 |
+
np.save(str(mel_output_fpath), mel_prediction, allow_pickle=False)
|
257 |
+
|
258 |
+
# save griffin lim inverted wav for debug (mel -> wav)
|
259 |
+
wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
|
260 |
+
wav_fpath = wav_dir.joinpath("step-{}-wave-from-mel_sample_{}.wav".format(step, sample_num))
|
261 |
+
audio.save_wav(wav, str(wav_fpath), sr=hparams.sample_rate)
|
262 |
+
|
263 |
+
# save real and predicted mel-spectrogram plot to disk (control purposes)
|
264 |
+
spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num))
|
265 |
+
title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss)
|
266 |
+
plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
|
267 |
+
target_spectrogram=target_spectrogram,
|
268 |
+
max_len=target_spectrogram.size // hparams.num_mels)
|
269 |
+
print("Input at step {}: {}".format(step, sequence_to_text(input_seq)))
|
synthesizer/utils/__init__.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
_output_ref = None
|
5 |
+
_replicas_ref = None
|
6 |
+
|
7 |
+
def data_parallel_workaround(model, *input):
|
8 |
+
global _output_ref
|
9 |
+
global _replicas_ref
|
10 |
+
device_ids = list(range(torch.cuda.device_count()))
|
11 |
+
output_device = device_ids[0]
|
12 |
+
replicas = torch.nn.parallel.replicate(model, device_ids)
|
13 |
+
# input.shape = (num_args, batch, ...)
|
14 |
+
inputs = torch.nn.parallel.scatter(input, device_ids)
|
15 |
+
# inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
|
16 |
+
replicas = replicas[:len(inputs)]
|
17 |
+
outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
|
18 |
+
y_hat = torch.nn.parallel.gather(outputs, output_device)
|
19 |
+
_output_ref = outputs
|
20 |
+
_replicas_ref = replicas
|
21 |
+
return y_hat
|
22 |
+
|
23 |
+
|
24 |
+
class ValueWindow():
|
25 |
+
def __init__(self, window_size=100):
|
26 |
+
self._window_size = window_size
|
27 |
+
self._values = []
|
28 |
+
|
29 |
+
def append(self, x):
|
30 |
+
self._values = self._values[-(self._window_size - 1):] + [x]
|
31 |
+
|
32 |
+
@property
|
33 |
+
def sum(self):
|
34 |
+
return sum(self._values)
|
35 |
+
|
36 |
+
@property
|
37 |
+
def count(self):
|
38 |
+
return len(self._values)
|
39 |
+
|
40 |
+
@property
|
41 |
+
def average(self):
|
42 |
+
return self.sum / max(1, self.count)
|
43 |
+
|
44 |
+
def reset(self):
|
45 |
+
self._values = []
|
synthesizer/utils/_cmudict.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
valid_symbols = [
|
4 |
+
"AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2",
|
5 |
+
"AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2",
|
6 |
+
"B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2", "EY",
|
7 |
+
"EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1",
|
8 |
+
"IY2", "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", "OW2", "OY", "OY0",
|
9 |
+
"OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH", "UH0", "UH1", "UH2", "UW",
|
10 |
+
"UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH"
|
11 |
+
]
|
12 |
+
|
13 |
+
_valid_symbol_set = set(valid_symbols)
|
14 |
+
|
15 |
+
|
16 |
+
class CMUDict:
|
17 |
+
"""Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
|
18 |
+
def __init__(self, file_or_path, keep_ambiguous=True):
|
19 |
+
if isinstance(file_or_path, str):
|
20 |
+
with open(file_or_path, encoding="latin-1") as f:
|
21 |
+
entries = _parse_cmudict(f)
|
22 |
+
else:
|
23 |
+
entries = _parse_cmudict(file_or_path)
|
24 |
+
if not keep_ambiguous:
|
25 |
+
entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
|
26 |
+
self._entries = entries
|
27 |
+
|
28 |
+
|
29 |
+
def __len__(self):
|
30 |
+
return len(self._entries)
|
31 |
+
|
32 |
+
|
33 |
+
def lookup(self, word):
|
34 |
+
"""Returns list of ARPAbet pronunciations of the given word."""
|
35 |
+
return self._entries.get(word.upper())
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
_alt_re = re.compile(r"\([0-9]+\)")
|
40 |
+
|
41 |
+
|
42 |
+
def _parse_cmudict(file):
|
43 |
+
cmudict = {}
|
44 |
+
for line in file:
|
45 |
+
if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
|
46 |
+
parts = line.split(" ")
|
47 |
+
word = re.sub(_alt_re, "", parts[0])
|
48 |
+
pronunciation = _get_pronunciation(parts[1])
|
49 |
+
if pronunciation:
|
50 |
+
if word in cmudict:
|
51 |
+
cmudict[word].append(pronunciation)
|
52 |
+
else:
|
53 |
+
cmudict[word] = [pronunciation]
|
54 |
+
return cmudict
|
55 |
+
|
56 |
+
|
57 |
+
def _get_pronunciation(s):
|
58 |
+
parts = s.strip().split(" ")
|
59 |
+
for part in parts:
|
60 |
+
if part not in _valid_symbol_set:
|
61 |
+
return None
|
62 |
+
return " ".join(parts)
|
synthesizer/utils/cleaners.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Cleaners are transformations that run over the input text at both training and eval time.
|
3 |
+
|
4 |
+
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
5 |
+
hyperparameter. Some cleaners are English-specific. You"ll typically want to use:
|
6 |
+
1. "english_cleaners" for English text
|
7 |
+
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
8 |
+
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
9 |
+
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
10 |
+
the symbols in symbols.py to match your data).
|
11 |
+
"""
|
12 |
+
|
13 |
+
import re
|
14 |
+
from unidecode import unidecode
|
15 |
+
from .numbers import normalize_numbers
|
16 |
+
|
17 |
+
# Regular expression matching whitespace:
|
18 |
+
_whitespace_re = re.compile(r"\s+")
|
19 |
+
|
20 |
+
# List of (regular expression, replacement) pairs for abbreviations:
|
21 |
+
_abbreviations = [(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [
|
22 |
+
("mrs", "misess"),
|
23 |
+
("mr", "mister"),
|
24 |
+
("dr", "doctor"),
|
25 |
+
("st", "saint"),
|
26 |
+
("co", "company"),
|
27 |
+
("jr", "junior"),
|
28 |
+
("maj", "major"),
|
29 |
+
("gen", "general"),
|
30 |
+
("drs", "doctors"),
|
31 |
+
("rev", "reverend"),
|
32 |
+
("lt", "lieutenant"),
|
33 |
+
("hon", "honorable"),
|
34 |
+
("sgt", "sergeant"),
|
35 |
+
("capt", "captain"),
|
36 |
+
("esq", "esquire"),
|
37 |
+
("ltd", "limited"),
|
38 |
+
("col", "colonel"),
|
39 |
+
("ft", "fort"),
|
40 |
+
]]
|
41 |
+
|
42 |
+
|
43 |
+
def expand_abbreviations(text):
|
44 |
+
for regex, replacement in _abbreviations:
|
45 |
+
text = re.sub(regex, replacement, text)
|
46 |
+
return text
|
47 |
+
|
48 |
+
|
49 |
+
def expand_numbers(text):
|
50 |
+
return normalize_numbers(text)
|
51 |
+
|
52 |
+
|
53 |
+
def lowercase(text):
|
54 |
+
"""lowercase input tokens."""
|
55 |
+
return text.lower()
|
56 |
+
|
57 |
+
|
58 |
+
def collapse_whitespace(text):
|
59 |
+
return re.sub(_whitespace_re, " ", text)
|
60 |
+
|
61 |
+
|
62 |
+
def convert_to_ascii(text):
|
63 |
+
return unidecode(text)
|
64 |
+
|
65 |
+
|
66 |
+
def basic_cleaners(text):
|
67 |
+
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
68 |
+
text = lowercase(text)
|
69 |
+
text = collapse_whitespace(text)
|
70 |
+
return text
|
71 |
+
|
72 |
+
|
73 |
+
def transliteration_cleaners(text):
|
74 |
+
"""Pipeline for non-English text that transliterates to ASCII."""
|
75 |
+
text = convert_to_ascii(text)
|
76 |
+
text = lowercase(text)
|
77 |
+
text = collapse_whitespace(text)
|
78 |
+
return text
|
79 |
+
|
80 |
+
|
81 |
+
def english_cleaners(text):
|
82 |
+
"""Pipeline for English text, including number and abbreviation expansion."""
|
83 |
+
text = convert_to_ascii(text)
|
84 |
+
text = lowercase(text)
|
85 |
+
text = expand_numbers(text)
|
86 |
+
text = expand_abbreviations(text)
|
87 |
+
text = collapse_whitespace(text)
|
88 |
+
return text
|
synthesizer/utils/numbers.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import inflect
|
3 |
+
|
4 |
+
_inflect = inflect.engine()
|
5 |
+
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
|
6 |
+
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
|
7 |
+
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
|
8 |
+
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
|
9 |
+
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
|
10 |
+
_number_re = re.compile(r"[0-9]+")
|
11 |
+
|
12 |
+
|
13 |
+
def _remove_commas(m):
|
14 |
+
return m.group(1).replace(",", "")
|
15 |
+
|
16 |
+
|
17 |
+
def _expand_decimal_point(m):
|
18 |
+
return m.group(1).replace(".", " point ")
|
19 |
+
|
20 |
+
|
21 |
+
def _expand_dollars(m):
|
22 |
+
match = m.group(1)
|
23 |
+
parts = match.split(".")
|
24 |
+
if len(parts) > 2:
|
25 |
+
return match + " dollars" # Unexpected format
|
26 |
+
dollars = int(parts[0]) if parts[0] else 0
|
27 |
+
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
28 |
+
if dollars and cents:
|
29 |
+
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
30 |
+
cent_unit = "cent" if cents == 1 else "cents"
|
31 |
+
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
|
32 |
+
elif dollars:
|
33 |
+
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
34 |
+
return "%s %s" % (dollars, dollar_unit)
|
35 |
+
elif cents:
|
36 |
+
cent_unit = "cent" if cents == 1 else "cents"
|
37 |
+
return "%s %s" % (cents, cent_unit)
|
38 |
+
else:
|
39 |
+
return "zero dollars"
|
40 |
+
|
41 |
+
|
42 |
+
def _expand_ordinal(m):
|
43 |
+
return _inflect.number_to_words(m.group(0))
|
44 |
+
|
45 |
+
|
46 |
+
def _expand_number(m):
|
47 |
+
num = int(m.group(0))
|
48 |
+
if num > 1000 and num < 3000:
|
49 |
+
if num == 2000:
|
50 |
+
return "two thousand"
|
51 |
+
elif num > 2000 and num < 2010:
|
52 |
+
return "two thousand " + _inflect.number_to_words(num % 100)
|
53 |
+
elif num % 100 == 0:
|
54 |
+
return _inflect.number_to_words(num // 100) + " hundred"
|
55 |
+
else:
|
56 |
+
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
|
57 |
+
else:
|
58 |
+
return _inflect.number_to_words(num, andword="")
|
59 |
+
|
60 |
+
|
61 |
+
def normalize_numbers(text):
|
62 |
+
text = re.sub(_comma_number_re, _remove_commas, text)
|
63 |
+
text = re.sub(_pounds_re, r"\1 pounds", text)
|
64 |
+
text = re.sub(_dollars_re, _expand_dollars, text)
|
65 |
+
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
66 |
+
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
67 |
+
text = re.sub(_number_re, _expand_number, text)
|
68 |
+
return text
|
synthesizer/utils/plot.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib
|
2 |
+
matplotlib.use("Agg")
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
def split_title_line(title_text, max_words=5):
|
8 |
+
"""
|
9 |
+
A function that splits any string based on specific character
|
10 |
+
(returning it with the string), with maximum number of words on it
|
11 |
+
"""
|
12 |
+
seq = title_text.split()
|
13 |
+
return "\n".join([" ".join(seq[i:i + max_words]) for i in range(0, len(seq), max_words)])
|
14 |
+
|
15 |
+
def plot_alignment(alignment, path, title=None, split_title=False, max_len=None):
|
16 |
+
if max_len is not None:
|
17 |
+
alignment = alignment[:, :max_len]
|
18 |
+
|
19 |
+
fig = plt.figure(figsize=(8, 6))
|
20 |
+
ax = fig.add_subplot(111)
|
21 |
+
|
22 |
+
im = ax.imshow(
|
23 |
+
alignment,
|
24 |
+
aspect="auto",
|
25 |
+
origin="lower",
|
26 |
+
interpolation="none")
|
27 |
+
fig.colorbar(im, ax=ax)
|
28 |
+
xlabel = "Decoder timestep"
|
29 |
+
|
30 |
+
if split_title:
|
31 |
+
title = split_title_line(title)
|
32 |
+
|
33 |
+
plt.xlabel(xlabel)
|
34 |
+
plt.title(title)
|
35 |
+
plt.ylabel("Encoder timestep")
|
36 |
+
plt.tight_layout()
|
37 |
+
plt.savefig(path, format="png")
|
38 |
+
plt.close()
|
39 |
+
|
40 |
+
|
41 |
+
def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False):
|
42 |
+
if max_len is not None:
|
43 |
+
target_spectrogram = target_spectrogram[:max_len]
|
44 |
+
pred_spectrogram = pred_spectrogram[:max_len]
|
45 |
+
|
46 |
+
if split_title:
|
47 |
+
title = split_title_line(title)
|
48 |
+
|
49 |
+
fig = plt.figure(figsize=(10, 8))
|
50 |
+
# Set common labels
|
51 |
+
fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)
|
52 |
+
|
53 |
+
#target spectrogram subplot
|
54 |
+
if target_spectrogram is not None:
|
55 |
+
ax1 = fig.add_subplot(311)
|
56 |
+
ax2 = fig.add_subplot(312)
|
57 |
+
|
58 |
+
if auto_aspect:
|
59 |
+
im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none")
|
60 |
+
else:
|
61 |
+
im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none")
|
62 |
+
ax1.set_title("Target Mel-Spectrogram")
|
63 |
+
fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
|
64 |
+
ax2.set_title("Predicted Mel-Spectrogram")
|
65 |
+
else:
|
66 |
+
ax2 = fig.add_subplot(211)
|
67 |
+
|
68 |
+
if auto_aspect:
|
69 |
+
im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none")
|
70 |
+
else:
|
71 |
+
im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none")
|
72 |
+
fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
|
73 |
+
|
74 |
+
plt.tight_layout()
|
75 |
+
plt.savefig(path, format="png")
|
76 |
+
plt.close()
|
synthesizer/utils/symbols.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Defines the set of symbols used in text input to the model.
|
3 |
+
|
4 |
+
The default is a set of ASCII characters that works well for English or text that has been run
|
5 |
+
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
|
6 |
+
"""
|
7 |
+
# from . import cmudict
|
8 |
+
|
9 |
+
_pad = "_"
|
10 |
+
_eos = "~"
|
11 |
+
_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'\"(),-.:;? "
|
12 |
+
|
13 |
+
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
|
14 |
+
#_arpabet = ["@' + s for s in cmudict.valid_symbols]
|
15 |
+
|
16 |
+
# Export all symbols:
|
17 |
+
symbols = [_pad, _eos] + list(_characters) #+ _arpabet
|
synthesizer/utils/text.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .symbols import symbols
|
2 |
+
from . import cleaners
|
3 |
+
import re
|
4 |
+
|
5 |
+
# Mappings from symbol to numeric ID and vice versa:
|
6 |
+
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
7 |
+
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
|
8 |
+
|
9 |
+
# Regular expression matching text enclosed in curly braces:
|
10 |
+
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
|
11 |
+
|
12 |
+
|
13 |
+
def text_to_sequence(text, cleaner_names):
|
14 |
+
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
15 |
+
|
16 |
+
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
|
17 |
+
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
|
18 |
+
|
19 |
+
Args:
|
20 |
+
text: string to convert to a sequence
|
21 |
+
cleaner_names: names of the cleaner functions to run the text through
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
List of integers corresponding to the symbols in the text
|
25 |
+
"""
|
26 |
+
sequence = []
|
27 |
+
|
28 |
+
# Check for curly braces and treat their contents as ARPAbet:
|
29 |
+
while len(text):
|
30 |
+
m = _curly_re.match(text)
|
31 |
+
if not m:
|
32 |
+
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
|
33 |
+
break
|
34 |
+
sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
|
35 |
+
sequence += _arpabet_to_sequence(m.group(2))
|
36 |
+
text = m.group(3)
|
37 |
+
|
38 |
+
# Append EOS token
|
39 |
+
sequence.append(_symbol_to_id["~"])
|
40 |
+
return sequence
|
41 |
+
|
42 |
+
|
43 |
+
def sequence_to_text(sequence):
|
44 |
+
"""Converts a sequence of IDs back to a string"""
|
45 |
+
result = ""
|
46 |
+
for symbol_id in sequence:
|
47 |
+
if symbol_id in _id_to_symbol:
|
48 |
+
s = _id_to_symbol[symbol_id]
|
49 |
+
# Enclose ARPAbet back in curly braces:
|
50 |
+
if len(s) > 1 and s[0] == "@":
|
51 |
+
s = "{%s}" % s[1:]
|
52 |
+
result += s
|
53 |
+
return result.replace("}{", " ")
|
54 |
+
|
55 |
+
|
56 |
+
def _clean_text(text, cleaner_names):
|
57 |
+
for name in cleaner_names:
|
58 |
+
cleaner = getattr(cleaners, name)
|
59 |
+
if not cleaner:
|
60 |
+
raise Exception("Unknown cleaner: %s" % name)
|
61 |
+
text = cleaner(text)
|
62 |
+
return text
|
63 |
+
|
64 |
+
|
65 |
+
def _symbols_to_sequence(symbols):
|
66 |
+
return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
|
67 |
+
|
68 |
+
|
69 |
+
def _arpabet_to_sequence(text):
|
70 |
+
return _symbols_to_sequence(["@" + s for s in text.split()])
|
71 |
+
|
72 |
+
|
73 |
+
def _should_keep_symbol(s):
|
74 |
+
return s in _symbol_to_id and s not in ("_", "~")
|
synthesizer_preprocess_audio.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from synthesizer.preprocess import preprocess_dataset
|
2 |
+
from synthesizer.hparams import hparams
|
3 |
+
from utils.argutils import print_args
|
4 |
+
from pathlib import Path
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
|
8 |
+
if __name__ == "__main__":
|
9 |
+
parser = argparse.ArgumentParser(
|
10 |
+
description="Preprocesses audio files from datasets, encodes them as mel spectrograms "
|
11 |
+
"and writes them to the disk. Audio files are also saved, to be used by the "
|
12 |
+
"vocoder for training.",
|
13 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
14 |
+
)
|
15 |
+
parser.add_argument("datasets_root", type=Path, help=\
|
16 |
+
"Path to the directory containing your LibriSpeech/TTS datasets.")
|
17 |
+
parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
|
18 |
+
"Path to the output directory that will contain the mel spectrograms, the audios and the "
|
19 |
+
"embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/")
|
20 |
+
parser.add_argument("-n", "--n_processes", type=int, default=None, help=\
|
21 |
+
"Number of processes in parallel.")
|
22 |
+
parser.add_argument("-s", "--skip_existing", action="store_true", help=\
|
23 |
+
"Whether to overwrite existing files with the same name. Useful if the preprocessing was "
|
24 |
+
"interrupted.")
|
25 |
+
parser.add_argument("--hparams", type=str, default="", help=\
|
26 |
+
"Hyperparameter overrides as a comma-separated list of name-value pairs")
|
27 |
+
parser.add_argument("--no_trim", action="store_true", help=\
|
28 |
+
"Preprocess audio without trimming silences (not recommended).")
|
29 |
+
parser.add_argument("--no_alignments", action="store_true", help=\
|
30 |
+
"Use this option when dataset does not include alignments\
|
31 |
+
(these are used to split long audio files into sub-utterances.)")
|
32 |
+
parser.add_argument("--datasets_name", type=str, default="LibriSpeech", help=\
|
33 |
+
"Name of the dataset directory to process.")
|
34 |
+
parser.add_argument("--subfolders", type=str, default="train-clean-100, train-clean-360", help=\
|
35 |
+
"Comma-separated list of subfolders to process inside your dataset directory")
|
36 |
+
args = parser.parse_args()
|
37 |
+
|
38 |
+
# Process the arguments
|
39 |
+
if not hasattr(args, "out_dir"):
|
40 |
+
args.out_dir = args.datasets_root.joinpath("SV2TTS", "synthesizer")
|
41 |
+
|
42 |
+
# Create directories
|
43 |
+
assert args.datasets_root.exists()
|
44 |
+
args.out_dir.mkdir(exist_ok=True, parents=True)
|
45 |
+
|
46 |
+
# Verify webrtcvad is available
|
47 |
+
if not args.no_trim:
|
48 |
+
try:
|
49 |
+
import webrtcvad
|
50 |
+
except:
|
51 |
+
raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
|
52 |
+
"noise removal and is recommended. Please install and try again. If installation fails, "
|
53 |
+
"use --no_trim to disable this error message.")
|
54 |
+
del args.no_trim
|
55 |
+
|
56 |
+
# Preprocess the dataset
|
57 |
+
print_args(args, parser)
|
58 |
+
args.hparams = hparams.parse(args.hparams)
|
59 |
+
preprocess_dataset(**vars(args))
|
synthesizer_preprocess_embeds.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from synthesizer.preprocess import create_embeddings
|
2 |
+
from utils.argutils import print_args
|
3 |
+
from pathlib import Path
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
|
7 |
+
if __name__ == "__main__":
|
8 |
+
parser = argparse.ArgumentParser(
|
9 |
+
description="Creates embeddings for the synthesizer from the LibriSpeech utterances.",
|
10 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
11 |
+
)
|
12 |
+
parser.add_argument("synthesizer_root", type=Path, help=\
|
13 |
+
"Path to the synthesizer training data that contains the audios and the train.txt file. "
|
14 |
+
"If you let everything as default, it should be <datasets_root>/SV2TTS/synthesizer/.")
|
15 |
+
parser.add_argument("-e", "--encoder_model_fpath", type=Path,
|
16 |
+
default="encoder/saved_models/pretrained.pt", help=\
|
17 |
+
"Path your trained encoder model.")
|
18 |
+
parser.add_argument("-n", "--n_processes", type=int, default=4, help= \
|
19 |
+
"Number of parallel processes. An encoder is created for each, so you may need to lower "
|
20 |
+
"this value on GPUs with low memory. Set it to 1 if CUDA is unhappy.")
|
21 |
+
args = parser.parse_args()
|
22 |
+
|
23 |
+
# Preprocess the dataset
|
24 |
+
print_args(args, parser)
|
25 |
+
create_embeddings(**vars(args))
|
synthesizer_train.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from synthesizer.hparams import hparams
|
2 |
+
from synthesizer.train import train
|
3 |
+
from utils.argutils import print_args
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
|
7 |
+
if __name__ == "__main__":
|
8 |
+
parser = argparse.ArgumentParser()
|
9 |
+
parser.add_argument("run_id", type=str, help= \
|
10 |
+
"Name for this model instance. If a model state from the same run ID was previously "
|
11 |
+
"saved, the training will restart from there. Pass -f to overwrite saved states and "
|
12 |
+
"restart from scratch.")
|
13 |
+
parser.add_argument("syn_dir", type=str, default=argparse.SUPPRESS, help= \
|
14 |
+
"Path to the synthesizer directory that contains the ground truth mel spectrograms, "
|
15 |
+
"the wavs and the embeds.")
|
16 |
+
parser.add_argument("-m", "--models_dir", type=str, default="synthesizer/saved_models/", help=\
|
17 |
+
"Path to the output directory that will contain the saved model weights and the logs.")
|
18 |
+
parser.add_argument("-s", "--save_every", type=int, default=1000, help= \
|
19 |
+
"Number of steps between updates of the model on the disk. Set to 0 to never save the "
|
20 |
+
"model.")
|
21 |
+
parser.add_argument("-b", "--backup_every", type=int, default=25000, help= \
|
22 |
+
"Number of steps between backups of the model. Set to 0 to never make backups of the "
|
23 |
+
"model.")
|
24 |
+
parser.add_argument("-f", "--force_restart", action="store_true", help= \
|
25 |
+
"Do not load any saved model and restart from scratch.")
|
26 |
+
parser.add_argument("--hparams", default="",
|
27 |
+
help="Hyperparameter overrides as a comma-separated list of name=value "
|
28 |
+
"pairs")
|
29 |
+
args = parser.parse_args()
|
30 |
+
print_args(args, parser)
|
31 |
+
|
32 |
+
args.hparams = hparams.parse(args.hparams)
|
33 |
+
|
34 |
+
# Run the training
|
35 |
+
train(**vars(args))
|
toolbox/__init__.py
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from toolbox.ui import UI
|
2 |
+
from encoder import inference as encoder
|
3 |
+
from synthesizer.inference import Synthesizer
|
4 |
+
from vocoder import inference as vocoder
|
5 |
+
from pathlib import Path
|
6 |
+
from time import perf_counter as timer
|
7 |
+
from toolbox.utterance import Utterance
|
8 |
+
import numpy as np
|
9 |
+
import traceback
|
10 |
+
import sys
|
11 |
+
import torch
|
12 |
+
import librosa
|
13 |
+
from audioread.exceptions import NoBackendError
|
14 |
+
|
15 |
+
# Use this directory structure for your datasets, or modify it to fit your needs
|
16 |
+
recognized_datasets = [
|
17 |
+
"LibriSpeech/dev-clean",
|
18 |
+
"LibriSpeech/dev-other",
|
19 |
+
"LibriSpeech/test-clean",
|
20 |
+
"LibriSpeech/test-other",
|
21 |
+
"LibriSpeech/train-clean-100",
|
22 |
+
"LibriSpeech/train-clean-360",
|
23 |
+
"LibriSpeech/train-other-500",
|
24 |
+
"LibriTTS/dev-clean",
|
25 |
+
"LibriTTS/dev-other",
|
26 |
+
"LibriTTS/test-clean",
|
27 |
+
"LibriTTS/test-other",
|
28 |
+
"LibriTTS/train-clean-100",
|
29 |
+
"LibriTTS/train-clean-360",
|
30 |
+
"LibriTTS/train-other-500",
|
31 |
+
"LJSpeech-1.1",
|
32 |
+
"VoxCeleb1/wav",
|
33 |
+
"VoxCeleb1/test_wav",
|
34 |
+
"VoxCeleb2/dev/aac",
|
35 |
+
"VoxCeleb2/test/aac",
|
36 |
+
"VCTK-Corpus/wav48",
|
37 |
+
]
|
38 |
+
|
39 |
+
#Maximum of generated wavs to keep on memory
|
40 |
+
MAX_WAVES = 15
|
41 |
+
|
42 |
+
class Toolbox:
|
43 |
+
def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, seed, no_mp3_support):
|
44 |
+
if not no_mp3_support:
|
45 |
+
try:
|
46 |
+
librosa.load("samples/6829_00000.mp3")
|
47 |
+
except NoBackendError:
|
48 |
+
print("Librosa will be unable to open mp3 files if additional software is not installed.\n"
|
49 |
+
"Please install ffmpeg or add the '--no_mp3_support' option to proceed without support for mp3 files.")
|
50 |
+
exit(-1)
|
51 |
+
self.no_mp3_support = no_mp3_support
|
52 |
+
sys.excepthook = self.excepthook
|
53 |
+
self.datasets_root = datasets_root
|
54 |
+
self.utterances = set()
|
55 |
+
self.current_generated = (None, None, None, None) # speaker_name, spec, breaks, wav
|
56 |
+
|
57 |
+
self.synthesizer = None # type: Synthesizer
|
58 |
+
self.current_wav = None
|
59 |
+
self.waves_list = []
|
60 |
+
self.waves_count = 0
|
61 |
+
self.waves_namelist = []
|
62 |
+
|
63 |
+
# Check for webrtcvad (enables removal of silences in vocoder output)
|
64 |
+
try:
|
65 |
+
import webrtcvad
|
66 |
+
self.trim_silences = True
|
67 |
+
except:
|
68 |
+
self.trim_silences = False
|
69 |
+
|
70 |
+
# Initialize the events and the interface
|
71 |
+
self.ui = UI()
|
72 |
+
self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, seed)
|
73 |
+
self.setup_events()
|
74 |
+
self.ui.start()
|
75 |
+
|
76 |
+
def excepthook(self, exc_type, exc_value, exc_tb):
|
77 |
+
traceback.print_exception(exc_type, exc_value, exc_tb)
|
78 |
+
self.ui.log("Exception: %s" % exc_value)
|
79 |
+
|
80 |
+
def setup_events(self):
|
81 |
+
# Dataset, speaker and utterance selection
|
82 |
+
self.ui.browser_load_button.clicked.connect(lambda: self.load_from_browser())
|
83 |
+
random_func = lambda level: lambda: self.ui.populate_browser(self.datasets_root,
|
84 |
+
recognized_datasets,
|
85 |
+
level)
|
86 |
+
self.ui.random_dataset_button.clicked.connect(random_func(0))
|
87 |
+
self.ui.random_speaker_button.clicked.connect(random_func(1))
|
88 |
+
self.ui.random_utterance_button.clicked.connect(random_func(2))
|
89 |
+
self.ui.dataset_box.currentIndexChanged.connect(random_func(1))
|
90 |
+
self.ui.speaker_box.currentIndexChanged.connect(random_func(2))
|
91 |
+
|
92 |
+
# Model selection
|
93 |
+
self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder)
|
94 |
+
def func():
|
95 |
+
self.synthesizer = None
|
96 |
+
self.ui.synthesizer_box.currentIndexChanged.connect(func)
|
97 |
+
self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder)
|
98 |
+
|
99 |
+
# Utterance selection
|
100 |
+
func = lambda: self.load_from_browser(self.ui.browse_file())
|
101 |
+
self.ui.browser_browse_button.clicked.connect(func)
|
102 |
+
func = lambda: self.ui.draw_utterance(self.ui.selected_utterance, "current")
|
103 |
+
self.ui.utterance_history.currentIndexChanged.connect(func)
|
104 |
+
func = lambda: self.ui.play(self.ui.selected_utterance.wav, Synthesizer.sample_rate)
|
105 |
+
self.ui.play_button.clicked.connect(func)
|
106 |
+
self.ui.stop_button.clicked.connect(self.ui.stop)
|
107 |
+
self.ui.record_button.clicked.connect(self.record)
|
108 |
+
|
109 |
+
#Audio
|
110 |
+
self.ui.setup_audio_devices(Synthesizer.sample_rate)
|
111 |
+
|
112 |
+
#Wav playback & save
|
113 |
+
func = lambda: self.replay_last_wav()
|
114 |
+
self.ui.replay_wav_button.clicked.connect(func)
|
115 |
+
func = lambda: self.export_current_wave()
|
116 |
+
self.ui.export_wav_button.clicked.connect(func)
|
117 |
+
self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)
|
118 |
+
|
119 |
+
# Generation
|
120 |
+
func = lambda: self.synthesize() or self.vocode()
|
121 |
+
self.ui.generate_button.clicked.connect(func)
|
122 |
+
self.ui.synthesize_button.clicked.connect(self.synthesize)
|
123 |
+
self.ui.vocode_button.clicked.connect(self.vocode)
|
124 |
+
self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox)
|
125 |
+
|
126 |
+
# UMAP legend
|
127 |
+
self.ui.clear_button.clicked.connect(self.clear_utterances)
|
128 |
+
|
129 |
+
def set_current_wav(self, index):
|
130 |
+
self.current_wav = self.waves_list[index]
|
131 |
+
|
132 |
+
def export_current_wave(self):
|
133 |
+
self.ui.save_audio_file(self.current_wav, Synthesizer.sample_rate)
|
134 |
+
|
135 |
+
def replay_last_wav(self):
|
136 |
+
self.ui.play(self.current_wav, Synthesizer.sample_rate)
|
137 |
+
|
138 |
+
def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, seed):
|
139 |
+
self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True)
|
140 |
+
self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir)
|
141 |
+
self.ui.populate_gen_options(seed, self.trim_silences)
|
142 |
+
|
143 |
+
def load_from_browser(self, fpath=None):
|
144 |
+
if fpath is None:
|
145 |
+
fpath = Path(self.datasets_root,
|
146 |
+
self.ui.current_dataset_name,
|
147 |
+
self.ui.current_speaker_name,
|
148 |
+
self.ui.current_utterance_name)
|
149 |
+
name = str(fpath.relative_to(self.datasets_root))
|
150 |
+
speaker_name = self.ui.current_dataset_name + '_' + self.ui.current_speaker_name
|
151 |
+
|
152 |
+
# Select the next utterance
|
153 |
+
if self.ui.auto_next_checkbox.isChecked():
|
154 |
+
self.ui.browser_select_next()
|
155 |
+
elif fpath == "":
|
156 |
+
return
|
157 |
+
else:
|
158 |
+
name = fpath.name
|
159 |
+
speaker_name = fpath.parent.name
|
160 |
+
|
161 |
+
if fpath.suffix.lower() == ".mp3" and self.no_mp3_support:
|
162 |
+
self.ui.log("Error: No mp3 file argument was passed but an mp3 file was used")
|
163 |
+
return
|
164 |
+
|
165 |
+
# Get the wav from the disk. We take the wav with the vocoder/synthesizer format for
|
166 |
+
# playback, so as to have a fair comparison with the generated audio
|
167 |
+
wav = Synthesizer.load_preprocess_wav(fpath)
|
168 |
+
self.ui.log("Loaded %s" % name)
|
169 |
+
|
170 |
+
self.add_real_utterance(wav, name, speaker_name)
|
171 |
+
|
172 |
+
def record(self):
|
173 |
+
wav = self.ui.record_one(encoder.sampling_rate, 5)
|
174 |
+
if wav is None:
|
175 |
+
return
|
176 |
+
self.ui.play(wav, encoder.sampling_rate)
|
177 |
+
|
178 |
+
speaker_name = "user01"
|
179 |
+
name = speaker_name + "_rec_%05d" % np.random.randint(100000)
|
180 |
+
self.add_real_utterance(wav, name, speaker_name)
|
181 |
+
|
182 |
+
def add_real_utterance(self, wav, name, speaker_name):
|
183 |
+
# Compute the mel spectrogram
|
184 |
+
spec = Synthesizer.make_spectrogram(wav)
|
185 |
+
self.ui.draw_spec(spec, "current")
|
186 |
+
|
187 |
+
# Compute the embedding
|
188 |
+
if not encoder.is_loaded():
|
189 |
+
self.init_encoder()
|
190 |
+
encoder_wav = encoder.preprocess_wav(wav)
|
191 |
+
embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
|
192 |
+
|
193 |
+
# Add the utterance
|
194 |
+
utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False)
|
195 |
+
self.utterances.add(utterance)
|
196 |
+
self.ui.register_utterance(utterance)
|
197 |
+
|
198 |
+
# Plot it
|
199 |
+
self.ui.draw_embed(embed, name, "current")
|
200 |
+
self.ui.draw_umap_projections(self.utterances)
|
201 |
+
|
202 |
+
def clear_utterances(self):
|
203 |
+
self.utterances.clear()
|
204 |
+
self.ui.draw_umap_projections(self.utterances)
|
205 |
+
|
206 |
+
def synthesize(self):
|
207 |
+
self.ui.log("Generating the mel spectrogram...")
|
208 |
+
self.ui.set_loading(1)
|
209 |
+
|
210 |
+
# Update the synthesizer random seed
|
211 |
+
if self.ui.random_seed_checkbox.isChecked():
|
212 |
+
seed = int(self.ui.seed_textbox.text())
|
213 |
+
self.ui.populate_gen_options(seed, self.trim_silences)
|
214 |
+
else:
|
215 |
+
seed = None
|
216 |
+
|
217 |
+
if seed is not None:
|
218 |
+
torch.manual_seed(seed)
|
219 |
+
|
220 |
+
# Synthesize the spectrogram
|
221 |
+
if self.synthesizer is None or seed is not None:
|
222 |
+
self.init_synthesizer()
|
223 |
+
|
224 |
+
texts = self.ui.text_prompt.toPlainText().split("\n")
|
225 |
+
embed = self.ui.selected_utterance.embed
|
226 |
+
embeds = [embed] * len(texts)
|
227 |
+
specs = self.synthesizer.synthesize_spectrograms(texts, embeds)
|
228 |
+
breaks = [spec.shape[1] for spec in specs]
|
229 |
+
spec = np.concatenate(specs, axis=1)
|
230 |
+
|
231 |
+
self.ui.draw_spec(spec, "generated")
|
232 |
+
self.current_generated = (self.ui.selected_utterance.speaker_name, spec, breaks, None)
|
233 |
+
self.ui.set_loading(0)
|
234 |
+
|
235 |
+
def vocode(self):
|
236 |
+
speaker_name, spec, breaks, _ = self.current_generated
|
237 |
+
assert spec is not None
|
238 |
+
|
239 |
+
# Initialize the vocoder model and make it determinstic, if user provides a seed
|
240 |
+
if self.ui.random_seed_checkbox.isChecked():
|
241 |
+
seed = int(self.ui.seed_textbox.text())
|
242 |
+
self.ui.populate_gen_options(seed, self.trim_silences)
|
243 |
+
else:
|
244 |
+
seed = None
|
245 |
+
|
246 |
+
if seed is not None:
|
247 |
+
torch.manual_seed(seed)
|
248 |
+
|
249 |
+
# Synthesize the waveform
|
250 |
+
if not vocoder.is_loaded() or seed is not None:
|
251 |
+
self.init_vocoder()
|
252 |
+
|
253 |
+
def vocoder_progress(i, seq_len, b_size, gen_rate):
|
254 |
+
real_time_factor = (gen_rate / Synthesizer.sample_rate) * 1000
|
255 |
+
line = "Waveform generation: %d/%d (batch size: %d, rate: %.1fkHz - %.2fx real time)" \
|
256 |
+
% (i * b_size, seq_len * b_size, b_size, gen_rate, real_time_factor)
|
257 |
+
self.ui.log(line, "overwrite")
|
258 |
+
self.ui.set_loading(i, seq_len)
|
259 |
+
if self.ui.current_vocoder_fpath is not None:
|
260 |
+
self.ui.log("")
|
261 |
+
wav = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
|
262 |
+
else:
|
263 |
+
self.ui.log("Waveform generation with Griffin-Lim... ")
|
264 |
+
wav = Synthesizer.griffin_lim(spec)
|
265 |
+
self.ui.set_loading(0)
|
266 |
+
self.ui.log(" Done!", "append")
|
267 |
+
|
268 |
+
# Add breaks
|
269 |
+
b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
|
270 |
+
b_starts = np.concatenate(([0], b_ends[:-1]))
|
271 |
+
wavs = [wav[start:end] for start, end, in zip(b_starts, b_ends)]
|
272 |
+
breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks)
|
273 |
+
wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
|
274 |
+
|
275 |
+
# Trim excessive silences
|
276 |
+
if self.ui.trim_silences_checkbox.isChecked():
|
277 |
+
wav = encoder.preprocess_wav(wav)
|
278 |
+
|
279 |
+
# Play it
|
280 |
+
wav = wav / np.abs(wav).max() * 0.97
|
281 |
+
self.ui.play(wav, Synthesizer.sample_rate)
|
282 |
+
|
283 |
+
# Name it (history displayed in combobox)
|
284 |
+
# TODO better naming for the combobox items?
|
285 |
+
wav_name = str(self.waves_count + 1)
|
286 |
+
|
287 |
+
#Update waves combobox
|
288 |
+
self.waves_count += 1
|
289 |
+
if self.waves_count > MAX_WAVES:
|
290 |
+
self.waves_list.pop()
|
291 |
+
self.waves_namelist.pop()
|
292 |
+
self.waves_list.insert(0, wav)
|
293 |
+
self.waves_namelist.insert(0, wav_name)
|
294 |
+
|
295 |
+
self.ui.waves_cb.disconnect()
|
296 |
+
self.ui.waves_cb_model.setStringList(self.waves_namelist)
|
297 |
+
self.ui.waves_cb.setCurrentIndex(0)
|
298 |
+
self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)
|
299 |
+
|
300 |
+
# Update current wav
|
301 |
+
self.set_current_wav(0)
|
302 |
+
|
303 |
+
#Enable replay and save buttons:
|
304 |
+
self.ui.replay_wav_button.setDisabled(False)
|
305 |
+
self.ui.export_wav_button.setDisabled(False)
|
306 |
+
|
307 |
+
# Compute the embedding
|
308 |
+
# TODO: this is problematic with different sampling rates, gotta fix it
|
309 |
+
if not encoder.is_loaded():
|
310 |
+
self.init_encoder()
|
311 |
+
encoder_wav = encoder.preprocess_wav(wav)
|
312 |
+
embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
|
313 |
+
|
314 |
+
# Add the utterance
|
315 |
+
name = speaker_name + "_gen_%05d" % np.random.randint(100000)
|
316 |
+
utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, True)
|
317 |
+
self.utterances.add(utterance)
|
318 |
+
|
319 |
+
# Plot it
|
320 |
+
self.ui.draw_embed(embed, name, "generated")
|
321 |
+
self.ui.draw_umap_projections(self.utterances)
|
322 |
+
|
323 |
+
def init_encoder(self):
|
324 |
+
model_fpath = self.ui.current_encoder_fpath
|
325 |
+
|
326 |
+
self.ui.log("Loading the encoder %s... " % model_fpath)
|
327 |
+
self.ui.set_loading(1)
|
328 |
+
start = timer()
|
329 |
+
encoder.load_model(model_fpath)
|
330 |
+
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
331 |
+
self.ui.set_loading(0)
|
332 |
+
|
333 |
+
def init_synthesizer(self):
|
334 |
+
model_fpath = self.ui.current_synthesizer_fpath
|
335 |
+
|
336 |
+
self.ui.log("Loading the synthesizer %s... " % model_fpath)
|
337 |
+
self.ui.set_loading(1)
|
338 |
+
start = timer()
|
339 |
+
self.synthesizer = Synthesizer(model_fpath)
|
340 |
+
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
341 |
+
self.ui.set_loading(0)
|
342 |
+
|
343 |
+
def init_vocoder(self):
|
344 |
+
model_fpath = self.ui.current_vocoder_fpath
|
345 |
+
# Case of Griffin-lim
|
346 |
+
if model_fpath is None:
|
347 |
+
return
|
348 |
+
|
349 |
+
self.ui.log("Loading the vocoder %s... " % model_fpath)
|
350 |
+
self.ui.set_loading(1)
|
351 |
+
start = timer()
|
352 |
+
vocoder.load_model(model_fpath)
|
353 |
+
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
354 |
+
self.ui.set_loading(0)
|
355 |
+
|
356 |
+
def update_seed_textbox(self):
|
357 |
+
self.ui.update_seed_textbox()
|
toolbox/ui.py
ADDED
@@ -0,0 +1,611 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
|
3 |
+
from matplotlib.figure import Figure
|
4 |
+
from PyQt5.QtCore import Qt, QStringListModel
|
5 |
+
from PyQt5.QtWidgets import *
|
6 |
+
from encoder.inference import plot_embedding_as_heatmap
|
7 |
+
from toolbox.utterance import Utterance
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import List, Set
|
10 |
+
import sounddevice as sd
|
11 |
+
import soundfile as sf
|
12 |
+
import numpy as np
|
13 |
+
# from sklearn.manifold import TSNE # You can try with TSNE if you like, I prefer UMAP
|
14 |
+
from time import sleep
|
15 |
+
import umap
|
16 |
+
import sys
|
17 |
+
from warnings import filterwarnings, warn
|
18 |
+
filterwarnings("ignore")
|
19 |
+
|
20 |
+
|
21 |
+
colormap = np.array([
|
22 |
+
[0, 127, 70],
|
23 |
+
[255, 0, 0],
|
24 |
+
[255, 217, 38],
|
25 |
+
[0, 135, 255],
|
26 |
+
[165, 0, 165],
|
27 |
+
[255, 167, 255],
|
28 |
+
[97, 142, 151],
|
29 |
+
[0, 255, 255],
|
30 |
+
[255, 96, 38],
|
31 |
+
[142, 76, 0],
|
32 |
+
[33, 0, 127],
|
33 |
+
[0, 0, 0],
|
34 |
+
[183, 183, 183],
|
35 |
+
[76, 255, 0],
|
36 |
+
], dtype=np.float) / 255
|
37 |
+
|
38 |
+
default_text = \
|
39 |
+
"Welcome to the toolbox! To begin, load an utterance from your datasets or record one " \
|
40 |
+
"yourself.\nOnce its embedding has been created, you can synthesize any text written here.\n" \
|
41 |
+
"The synthesizer expects to generate " \
|
42 |
+
"outputs that are somewhere between 5 and 12 seconds.\nTo mark breaks, write a new line. " \
|
43 |
+
"Each line will be treated separately.\nThen, they are joined together to make the final " \
|
44 |
+
"spectrogram. Use the vocoder to generate audio.\nThe vocoder generates almost in constant " \
|
45 |
+
"time, so it will be more time efficient for longer inputs like this one.\nOn the left you " \
|
46 |
+
"have the embedding projections. Load or record more utterances to see them.\nIf you have " \
|
47 |
+
"at least 2 or 3 utterances from a same speaker, a cluster should form.\nSynthesized " \
|
48 |
+
"utterances are of the same color as the speaker whose voice was used, but they're " \
|
49 |
+
"represented with a cross."
|
50 |
+
|
51 |
+
|
52 |
+
class UI(QDialog):
|
53 |
+
min_umap_points = 4
|
54 |
+
max_log_lines = 5
|
55 |
+
max_saved_utterances = 20
|
56 |
+
|
57 |
+
def draw_utterance(self, utterance: Utterance, which):
|
58 |
+
self.draw_spec(utterance.spec, which)
|
59 |
+
self.draw_embed(utterance.embed, utterance.name, which)
|
60 |
+
|
61 |
+
def draw_embed(self, embed, name, which):
|
62 |
+
embed_ax, _ = self.current_ax if which == "current" else self.gen_ax
|
63 |
+
embed_ax.figure.suptitle("" if embed is None else name)
|
64 |
+
|
65 |
+
## Embedding
|
66 |
+
# Clear the plot
|
67 |
+
if len(embed_ax.images) > 0:
|
68 |
+
embed_ax.images[0].colorbar.remove()
|
69 |
+
embed_ax.clear()
|
70 |
+
|
71 |
+
# Draw the embed
|
72 |
+
if embed is not None:
|
73 |
+
plot_embedding_as_heatmap(embed, embed_ax)
|
74 |
+
embed_ax.set_title("embedding")
|
75 |
+
embed_ax.set_aspect("equal", "datalim")
|
76 |
+
embed_ax.set_xticks([])
|
77 |
+
embed_ax.set_yticks([])
|
78 |
+
embed_ax.figure.canvas.draw()
|
79 |
+
|
80 |
+
def draw_spec(self, spec, which):
|
81 |
+
_, spec_ax = self.current_ax if which == "current" else self.gen_ax
|
82 |
+
|
83 |
+
## Spectrogram
|
84 |
+
# Draw the spectrogram
|
85 |
+
spec_ax.clear()
|
86 |
+
if spec is not None:
|
87 |
+
im = spec_ax.imshow(spec, aspect="auto", interpolation="none")
|
88 |
+
# spec_ax.figure.colorbar(mappable=im, shrink=0.65, orientation="horizontal",
|
89 |
+
# spec_ax=spec_ax)
|
90 |
+
spec_ax.set_title("mel spectrogram")
|
91 |
+
|
92 |
+
spec_ax.set_xticks([])
|
93 |
+
spec_ax.set_yticks([])
|
94 |
+
spec_ax.figure.canvas.draw()
|
95 |
+
if which != "current":
|
96 |
+
self.vocode_button.setDisabled(spec is None)
|
97 |
+
|
98 |
+
def draw_umap_projections(self, utterances: Set[Utterance]):
|
99 |
+
self.umap_ax.clear()
|
100 |
+
|
101 |
+
speakers = np.unique([u.speaker_name for u in utterances])
|
102 |
+
colors = {speaker_name: colormap[i] for i, speaker_name in enumerate(speakers)}
|
103 |
+
embeds = [u.embed for u in utterances]
|
104 |
+
|
105 |
+
# Display a message if there aren't enough points
|
106 |
+
if len(utterances) < self.min_umap_points:
|
107 |
+
self.umap_ax.text(.5, .5, "Add %d more points to\ngenerate the projections" %
|
108 |
+
(self.min_umap_points - len(utterances)),
|
109 |
+
horizontalalignment='center', fontsize=15)
|
110 |
+
self.umap_ax.set_title("")
|
111 |
+
|
112 |
+
# Compute the projections
|
113 |
+
else:
|
114 |
+
if not self.umap_hot:
|
115 |
+
self.log(
|
116 |
+
"Drawing UMAP projections for the first time, this will take a few seconds.")
|
117 |
+
self.umap_hot = True
|
118 |
+
|
119 |
+
reducer = umap.UMAP(int(np.ceil(np.sqrt(len(embeds)))), metric="cosine")
|
120 |
+
# reducer = TSNE()
|
121 |
+
projections = reducer.fit_transform(embeds)
|
122 |
+
|
123 |
+
speakers_done = set()
|
124 |
+
for projection, utterance in zip(projections, utterances):
|
125 |
+
color = colors[utterance.speaker_name]
|
126 |
+
mark = "x" if "_gen_" in utterance.name else "o"
|
127 |
+
label = None if utterance.speaker_name in speakers_done else utterance.speaker_name
|
128 |
+
speakers_done.add(utterance.speaker_name)
|
129 |
+
self.umap_ax.scatter(projection[0], projection[1], c=[color], marker=mark,
|
130 |
+
label=label)
|
131 |
+
# self.umap_ax.set_title("UMAP projections")
|
132 |
+
self.umap_ax.legend(prop={'size': 10})
|
133 |
+
|
134 |
+
# Draw the plot
|
135 |
+
self.umap_ax.set_aspect("equal", "datalim")
|
136 |
+
self.umap_ax.set_xticks([])
|
137 |
+
self.umap_ax.set_yticks([])
|
138 |
+
self.umap_ax.figure.canvas.draw()
|
139 |
+
|
140 |
+
def save_audio_file(self, wav, sample_rate):
|
141 |
+
dialog = QFileDialog()
|
142 |
+
dialog.setDefaultSuffix(".wav")
|
143 |
+
fpath, _ = dialog.getSaveFileName(
|
144 |
+
parent=self,
|
145 |
+
caption="Select a path to save the audio file",
|
146 |
+
filter="Audio Files (*.flac *.wav)"
|
147 |
+
)
|
148 |
+
if fpath:
|
149 |
+
#Default format is wav
|
150 |
+
if Path(fpath).suffix == "":
|
151 |
+
fpath += ".wav"
|
152 |
+
sf.write(fpath, wav, sample_rate)
|
153 |
+
|
154 |
+
def setup_audio_devices(self, sample_rate):
|
155 |
+
input_devices = []
|
156 |
+
output_devices = []
|
157 |
+
for device in sd.query_devices():
|
158 |
+
# Check if valid input
|
159 |
+
try:
|
160 |
+
sd.check_input_settings(device=device["name"], samplerate=sample_rate)
|
161 |
+
input_devices.append(device["name"])
|
162 |
+
except:
|
163 |
+
pass
|
164 |
+
|
165 |
+
# Check if valid output
|
166 |
+
try:
|
167 |
+
sd.check_output_settings(device=device["name"], samplerate=sample_rate)
|
168 |
+
output_devices.append(device["name"])
|
169 |
+
except Exception as e:
|
170 |
+
# Log a warning only if the device is not an input
|
171 |
+
if not device["name"] in input_devices:
|
172 |
+
warn("Unsupported output device %s for the sample rate: %d \nError: %s" % (device["name"], sample_rate, str(e)))
|
173 |
+
|
174 |
+
if len(input_devices) == 0:
|
175 |
+
self.log("No audio input device detected. Recording may not work.")
|
176 |
+
self.audio_in_device = None
|
177 |
+
else:
|
178 |
+
self.audio_in_device = input_devices[0]
|
179 |
+
|
180 |
+
if len(output_devices) == 0:
|
181 |
+
self.log("No supported output audio devices were found! Audio output may not work.")
|
182 |
+
self.audio_out_devices_cb.addItems(["None"])
|
183 |
+
self.audio_out_devices_cb.setDisabled(True)
|
184 |
+
else:
|
185 |
+
self.audio_out_devices_cb.clear()
|
186 |
+
self.audio_out_devices_cb.addItems(output_devices)
|
187 |
+
self.audio_out_devices_cb.currentTextChanged.connect(self.set_audio_device)
|
188 |
+
|
189 |
+
self.set_audio_device()
|
190 |
+
|
191 |
+
def set_audio_device(self):
|
192 |
+
|
193 |
+
output_device = self.audio_out_devices_cb.currentText()
|
194 |
+
if output_device == "None":
|
195 |
+
output_device = None
|
196 |
+
|
197 |
+
# If None, sounddevice queries portaudio
|
198 |
+
sd.default.device = (self.audio_in_device, output_device)
|
199 |
+
|
200 |
+
def play(self, wav, sample_rate):
|
201 |
+
try:
|
202 |
+
sd.stop()
|
203 |
+
sd.play(wav, sample_rate)
|
204 |
+
except Exception as e:
|
205 |
+
print(e)
|
206 |
+
self.log("Error in audio playback. Try selecting a different audio output device.")
|
207 |
+
self.log("Your device must be connected before you start the toolbox.")
|
208 |
+
|
209 |
+
def stop(self):
|
210 |
+
sd.stop()
|
211 |
+
|
212 |
+
def record_one(self, sample_rate, duration):
|
213 |
+
self.record_button.setText("Recording...")
|
214 |
+
self.record_button.setDisabled(True)
|
215 |
+
|
216 |
+
self.log("Recording %d seconds of audio" % duration)
|
217 |
+
sd.stop()
|
218 |
+
try:
|
219 |
+
wav = sd.rec(duration * sample_rate, sample_rate, 1)
|
220 |
+
except Exception as e:
|
221 |
+
print(e)
|
222 |
+
self.log("Could not record anything. Is your recording device enabled?")
|
223 |
+
self.log("Your device must be connected before you start the toolbox.")
|
224 |
+
return None
|
225 |
+
|
226 |
+
for i in np.arange(0, duration, 0.1):
|
227 |
+
self.set_loading(i, duration)
|
228 |
+
sleep(0.1)
|
229 |
+
self.set_loading(duration, duration)
|
230 |
+
sd.wait()
|
231 |
+
|
232 |
+
self.log("Done recording.")
|
233 |
+
self.record_button.setText("Record")
|
234 |
+
self.record_button.setDisabled(False)
|
235 |
+
|
236 |
+
return wav.squeeze()
|
237 |
+
|
238 |
+
@property
|
239 |
+
def current_dataset_name(self):
|
240 |
+
return self.dataset_box.currentText()
|
241 |
+
|
242 |
+
@property
|
243 |
+
def current_speaker_name(self):
|
244 |
+
return self.speaker_box.currentText()
|
245 |
+
|
246 |
+
@property
|
247 |
+
def current_utterance_name(self):
|
248 |
+
return self.utterance_box.currentText()
|
249 |
+
|
250 |
+
def browse_file(self):
|
251 |
+
fpath = QFileDialog().getOpenFileName(
|
252 |
+
parent=self,
|
253 |
+
caption="Select an audio file",
|
254 |
+
filter="Audio Files (*.mp3 *.flac *.wav *.m4a)"
|
255 |
+
)
|
256 |
+
return Path(fpath[0]) if fpath[0] != "" else ""
|
257 |
+
|
258 |
+
@staticmethod
|
259 |
+
def repopulate_box(box, items, random=False):
|
260 |
+
"""
|
261 |
+
Resets a box and adds a list of items. Pass a list of (item, data) pairs instead to join
|
262 |
+
data to the items
|
263 |
+
"""
|
264 |
+
box.blockSignals(True)
|
265 |
+
box.clear()
|
266 |
+
for item in items:
|
267 |
+
item = list(item) if isinstance(item, tuple) else [item]
|
268 |
+
box.addItem(str(item[0]), *item[1:])
|
269 |
+
if len(items) > 0:
|
270 |
+
box.setCurrentIndex(np.random.randint(len(items)) if random else 0)
|
271 |
+
box.setDisabled(len(items) == 0)
|
272 |
+
box.blockSignals(False)
|
273 |
+
|
274 |
+
def populate_browser(self, datasets_root: Path, recognized_datasets: List, level: int,
|
275 |
+
random=True):
|
276 |
+
# Select a random dataset
|
277 |
+
if level <= 0:
|
278 |
+
if datasets_root is not None:
|
279 |
+
datasets = [datasets_root.joinpath(d) for d in recognized_datasets]
|
280 |
+
datasets = [d.relative_to(datasets_root) for d in datasets if d.exists()]
|
281 |
+
self.browser_load_button.setDisabled(len(datasets) == 0)
|
282 |
+
if datasets_root is None or len(datasets) == 0:
|
283 |
+
msg = "Warning: you d" + ("id not pass a root directory for datasets as argument" \
|
284 |
+
if datasets_root is None else "o not have any of the recognized datasets" \
|
285 |
+
" in %s" % datasets_root)
|
286 |
+
self.log(msg)
|
287 |
+
msg += ".\nThe recognized datasets are:\n\t%s\nFeel free to add your own. You " \
|
288 |
+
"can still use the toolbox by recording samples yourself." % \
|
289 |
+
("\n\t".join(recognized_datasets))
|
290 |
+
print(msg, file=sys.stderr)
|
291 |
+
|
292 |
+
self.random_utterance_button.setDisabled(True)
|
293 |
+
self.random_speaker_button.setDisabled(True)
|
294 |
+
self.random_dataset_button.setDisabled(True)
|
295 |
+
self.utterance_box.setDisabled(True)
|
296 |
+
self.speaker_box.setDisabled(True)
|
297 |
+
self.dataset_box.setDisabled(True)
|
298 |
+
self.browser_load_button.setDisabled(True)
|
299 |
+
self.auto_next_checkbox.setDisabled(True)
|
300 |
+
return
|
301 |
+
self.repopulate_box(self.dataset_box, datasets, random)
|
302 |
+
|
303 |
+
# Select a random speaker
|
304 |
+
if level <= 1:
|
305 |
+
speakers_root = datasets_root.joinpath(self.current_dataset_name)
|
306 |
+
speaker_names = [d.stem for d in speakers_root.glob("*") if d.is_dir()]
|
307 |
+
self.repopulate_box(self.speaker_box, speaker_names, random)
|
308 |
+
|
309 |
+
# Select a random utterance
|
310 |
+
if level <= 2:
|
311 |
+
utterances_root = datasets_root.joinpath(
|
312 |
+
self.current_dataset_name,
|
313 |
+
self.current_speaker_name
|
314 |
+
)
|
315 |
+
utterances = []
|
316 |
+
for extension in ['mp3', 'flac', 'wav', 'm4a']:
|
317 |
+
utterances.extend(Path(utterances_root).glob("**/*.%s" % extension))
|
318 |
+
utterances = [fpath.relative_to(utterances_root) for fpath in utterances]
|
319 |
+
self.repopulate_box(self.utterance_box, utterances, random)
|
320 |
+
|
321 |
+
def browser_select_next(self):
|
322 |
+
index = (self.utterance_box.currentIndex() + 1) % len(self.utterance_box)
|
323 |
+
self.utterance_box.setCurrentIndex(index)
|
324 |
+
|
325 |
+
@property
|
326 |
+
def current_encoder_fpath(self):
|
327 |
+
return self.encoder_box.itemData(self.encoder_box.currentIndex())
|
328 |
+
|
329 |
+
@property
|
330 |
+
def current_synthesizer_fpath(self):
|
331 |
+
return self.synthesizer_box.itemData(self.synthesizer_box.currentIndex())
|
332 |
+
|
333 |
+
@property
|
334 |
+
def current_vocoder_fpath(self):
|
335 |
+
return self.vocoder_box.itemData(self.vocoder_box.currentIndex())
|
336 |
+
|
337 |
+
def populate_models(self, encoder_models_dir: Path, synthesizer_models_dir: Path,
|
338 |
+
vocoder_models_dir: Path):
|
339 |
+
# Encoder
|
340 |
+
encoder_fpaths = list(encoder_models_dir.glob("*.pt"))
|
341 |
+
if len(encoder_fpaths) == 0:
|
342 |
+
raise Exception("No encoder models found in %s" % encoder_models_dir)
|
343 |
+
self.repopulate_box(self.encoder_box, [(f.stem, f) for f in encoder_fpaths])
|
344 |
+
|
345 |
+
# Synthesizer
|
346 |
+
synthesizer_fpaths = list(synthesizer_models_dir.glob("**/*.pt"))
|
347 |
+
if len(synthesizer_fpaths) == 0:
|
348 |
+
raise Exception("No synthesizer models found in %s" % synthesizer_models_dir)
|
349 |
+
self.repopulate_box(self.synthesizer_box, [(f.stem, f) for f in synthesizer_fpaths])
|
350 |
+
|
351 |
+
# Vocoder
|
352 |
+
vocoder_fpaths = list(vocoder_models_dir.glob("**/*.pt"))
|
353 |
+
vocoder_items = [(f.stem, f) for f in vocoder_fpaths] + [("Griffin-Lim", None)]
|
354 |
+
self.repopulate_box(self.vocoder_box, vocoder_items)
|
355 |
+
|
356 |
+
@property
|
357 |
+
def selected_utterance(self):
|
358 |
+
return self.utterance_history.itemData(self.utterance_history.currentIndex())
|
359 |
+
|
360 |
+
def register_utterance(self, utterance: Utterance):
|
361 |
+
self.utterance_history.blockSignals(True)
|
362 |
+
self.utterance_history.insertItem(0, utterance.name, utterance)
|
363 |
+
self.utterance_history.setCurrentIndex(0)
|
364 |
+
self.utterance_history.blockSignals(False)
|
365 |
+
|
366 |
+
if len(self.utterance_history) > self.max_saved_utterances:
|
367 |
+
self.utterance_history.removeItem(self.max_saved_utterances)
|
368 |
+
|
369 |
+
self.play_button.setDisabled(False)
|
370 |
+
self.generate_button.setDisabled(False)
|
371 |
+
self.synthesize_button.setDisabled(False)
|
372 |
+
|
373 |
+
def log(self, line, mode="newline"):
|
374 |
+
if mode == "newline":
|
375 |
+
self.logs.append(line)
|
376 |
+
if len(self.logs) > self.max_log_lines:
|
377 |
+
del self.logs[0]
|
378 |
+
elif mode == "append":
|
379 |
+
self.logs[-1] += line
|
380 |
+
elif mode == "overwrite":
|
381 |
+
self.logs[-1] = line
|
382 |
+
log_text = '\n'.join(self.logs)
|
383 |
+
|
384 |
+
self.log_window.setText(log_text)
|
385 |
+
self.app.processEvents()
|
386 |
+
|
387 |
+
def set_loading(self, value, maximum=1):
|
388 |
+
self.loading_bar.setValue(value * 100)
|
389 |
+
self.loading_bar.setMaximum(maximum * 100)
|
390 |
+
self.loading_bar.setTextVisible(value != 0)
|
391 |
+
self.app.processEvents()
|
392 |
+
|
393 |
+
def populate_gen_options(self, seed, trim_silences):
|
394 |
+
if seed is not None:
|
395 |
+
self.random_seed_checkbox.setChecked(True)
|
396 |
+
self.seed_textbox.setText(str(seed))
|
397 |
+
self.seed_textbox.setEnabled(True)
|
398 |
+
else:
|
399 |
+
self.random_seed_checkbox.setChecked(False)
|
400 |
+
self.seed_textbox.setText(str(0))
|
401 |
+
self.seed_textbox.setEnabled(False)
|
402 |
+
|
403 |
+
if not trim_silences:
|
404 |
+
self.trim_silences_checkbox.setChecked(False)
|
405 |
+
self.trim_silences_checkbox.setDisabled(True)
|
406 |
+
|
407 |
+
def update_seed_textbox(self):
|
408 |
+
if self.random_seed_checkbox.isChecked():
|
409 |
+
self.seed_textbox.setEnabled(True)
|
410 |
+
else:
|
411 |
+
self.seed_textbox.setEnabled(False)
|
412 |
+
|
413 |
+
def reset_interface(self):
|
414 |
+
self.draw_embed(None, None, "current")
|
415 |
+
self.draw_embed(None, None, "generated")
|
416 |
+
self.draw_spec(None, "current")
|
417 |
+
self.draw_spec(None, "generated")
|
418 |
+
self.draw_umap_projections(set())
|
419 |
+
self.set_loading(0)
|
420 |
+
self.play_button.setDisabled(True)
|
421 |
+
self.generate_button.setDisabled(True)
|
422 |
+
self.synthesize_button.setDisabled(True)
|
423 |
+
self.vocode_button.setDisabled(True)
|
424 |
+
self.replay_wav_button.setDisabled(True)
|
425 |
+
self.export_wav_button.setDisabled(True)
|
426 |
+
[self.log("") for _ in range(self.max_log_lines)]
|
427 |
+
|
428 |
+
def __init__(self):
|
429 |
+
## Initialize the application
|
430 |
+
self.app = QApplication(sys.argv)
|
431 |
+
super().__init__(None)
|
432 |
+
self.setWindowTitle("SV2TTS toolbox")
|
433 |
+
|
434 |
+
|
435 |
+
## Main layouts
|
436 |
+
# Root
|
437 |
+
root_layout = QGridLayout()
|
438 |
+
self.setLayout(root_layout)
|
439 |
+
|
440 |
+
# Browser
|
441 |
+
browser_layout = QGridLayout()
|
442 |
+
root_layout.addLayout(browser_layout, 0, 0, 1, 2)
|
443 |
+
|
444 |
+
# Generation
|
445 |
+
gen_layout = QVBoxLayout()
|
446 |
+
root_layout.addLayout(gen_layout, 0, 2, 1, 2)
|
447 |
+
|
448 |
+
# Projections
|
449 |
+
self.projections_layout = QVBoxLayout()
|
450 |
+
root_layout.addLayout(self.projections_layout, 1, 0, 1, 1)
|
451 |
+
|
452 |
+
# Visualizations
|
453 |
+
vis_layout = QVBoxLayout()
|
454 |
+
root_layout.addLayout(vis_layout, 1, 1, 1, 3)
|
455 |
+
|
456 |
+
|
457 |
+
## Projections
|
458 |
+
# UMap
|
459 |
+
fig, self.umap_ax = plt.subplots(figsize=(3, 3), facecolor="#F0F0F0")
|
460 |
+
fig.subplots_adjust(left=0.02, bottom=0.02, right=0.98, top=0.98)
|
461 |
+
self.projections_layout.addWidget(FigureCanvas(fig))
|
462 |
+
self.umap_hot = False
|
463 |
+
self.clear_button = QPushButton("Clear")
|
464 |
+
self.projections_layout.addWidget(self.clear_button)
|
465 |
+
|
466 |
+
|
467 |
+
## Browser
|
468 |
+
# Dataset, speaker and utterance selection
|
469 |
+
i = 0
|
470 |
+
self.dataset_box = QComboBox()
|
471 |
+
browser_layout.addWidget(QLabel("<b>Dataset</b>"), i, 0)
|
472 |
+
browser_layout.addWidget(self.dataset_box, i + 1, 0)
|
473 |
+
self.speaker_box = QComboBox()
|
474 |
+
browser_layout.addWidget(QLabel("<b>Speaker</b>"), i, 1)
|
475 |
+
browser_layout.addWidget(self.speaker_box, i + 1, 1)
|
476 |
+
self.utterance_box = QComboBox()
|
477 |
+
browser_layout.addWidget(QLabel("<b>Utterance</b>"), i, 2)
|
478 |
+
browser_layout.addWidget(self.utterance_box, i + 1, 2)
|
479 |
+
self.browser_load_button = QPushButton("Load")
|
480 |
+
browser_layout.addWidget(self.browser_load_button, i + 1, 3)
|
481 |
+
i += 2
|
482 |
+
|
483 |
+
# Random buttons
|
484 |
+
self.random_dataset_button = QPushButton("Random")
|
485 |
+
browser_layout.addWidget(self.random_dataset_button, i, 0)
|
486 |
+
self.random_speaker_button = QPushButton("Random")
|
487 |
+
browser_layout.addWidget(self.random_speaker_button, i, 1)
|
488 |
+
self.random_utterance_button = QPushButton("Random")
|
489 |
+
browser_layout.addWidget(self.random_utterance_button, i, 2)
|
490 |
+
self.auto_next_checkbox = QCheckBox("Auto select next")
|
491 |
+
self.auto_next_checkbox.setChecked(True)
|
492 |
+
browser_layout.addWidget(self.auto_next_checkbox, i, 3)
|
493 |
+
i += 1
|
494 |
+
|
495 |
+
# Utterance box
|
496 |
+
browser_layout.addWidget(QLabel("<b>Use embedding from:</b>"), i, 0)
|
497 |
+
self.utterance_history = QComboBox()
|
498 |
+
browser_layout.addWidget(self.utterance_history, i, 1, 1, 3)
|
499 |
+
i += 1
|
500 |
+
|
501 |
+
# Random & next utterance buttons
|
502 |
+
self.browser_browse_button = QPushButton("Browse")
|
503 |
+
browser_layout.addWidget(self.browser_browse_button, i, 0)
|
504 |
+
self.record_button = QPushButton("Record")
|
505 |
+
browser_layout.addWidget(self.record_button, i, 1)
|
506 |
+
self.play_button = QPushButton("Play")
|
507 |
+
browser_layout.addWidget(self.play_button, i, 2)
|
508 |
+
self.stop_button = QPushButton("Stop")
|
509 |
+
browser_layout.addWidget(self.stop_button, i, 3)
|
510 |
+
i += 1
|
511 |
+
|
512 |
+
|
513 |
+
# Model and audio output selection
|
514 |
+
self.encoder_box = QComboBox()
|
515 |
+
browser_layout.addWidget(QLabel("<b>Encoder</b>"), i, 0)
|
516 |
+
browser_layout.addWidget(self.encoder_box, i + 1, 0)
|
517 |
+
self.synthesizer_box = QComboBox()
|
518 |
+
browser_layout.addWidget(QLabel("<b>Synthesizer</b>"), i, 1)
|
519 |
+
browser_layout.addWidget(self.synthesizer_box, i + 1, 1)
|
520 |
+
self.vocoder_box = QComboBox()
|
521 |
+
browser_layout.addWidget(QLabel("<b>Vocoder</b>"), i, 2)
|
522 |
+
browser_layout.addWidget(self.vocoder_box, i + 1, 2)
|
523 |
+
|
524 |
+
self.audio_out_devices_cb=QComboBox()
|
525 |
+
browser_layout.addWidget(QLabel("<b>Audio Output</b>"), i, 3)
|
526 |
+
browser_layout.addWidget(self.audio_out_devices_cb, i + 1, 3)
|
527 |
+
i += 2
|
528 |
+
|
529 |
+
#Replay & Save Audio
|
530 |
+
browser_layout.addWidget(QLabel("<b>Toolbox Output:</b>"), i, 0)
|
531 |
+
self.waves_cb = QComboBox()
|
532 |
+
self.waves_cb_model = QStringListModel()
|
533 |
+
self.waves_cb.setModel(self.waves_cb_model)
|
534 |
+
self.waves_cb.setToolTip("Select one of the last generated waves in this section for replaying or exporting")
|
535 |
+
browser_layout.addWidget(self.waves_cb, i, 1)
|
536 |
+
self.replay_wav_button = QPushButton("Replay")
|
537 |
+
self.replay_wav_button.setToolTip("Replay last generated vocoder")
|
538 |
+
browser_layout.addWidget(self.replay_wav_button, i, 2)
|
539 |
+
self.export_wav_button = QPushButton("Export")
|
540 |
+
self.export_wav_button.setToolTip("Save last generated vocoder audio in filesystem as a wav file")
|
541 |
+
browser_layout.addWidget(self.export_wav_button, i, 3)
|
542 |
+
i += 1
|
543 |
+
|
544 |
+
|
545 |
+
## Embed & spectrograms
|
546 |
+
vis_layout.addStretch()
|
547 |
+
|
548 |
+
gridspec_kw = {"width_ratios": [1, 4]}
|
549 |
+
fig, self.current_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0",
|
550 |
+
gridspec_kw=gridspec_kw)
|
551 |
+
fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8)
|
552 |
+
vis_layout.addWidget(FigureCanvas(fig))
|
553 |
+
|
554 |
+
fig, self.gen_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0",
|
555 |
+
gridspec_kw=gridspec_kw)
|
556 |
+
fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8)
|
557 |
+
vis_layout.addWidget(FigureCanvas(fig))
|
558 |
+
|
559 |
+
for ax in self.current_ax.tolist() + self.gen_ax.tolist():
|
560 |
+
ax.set_facecolor("#F0F0F0")
|
561 |
+
for side in ["top", "right", "bottom", "left"]:
|
562 |
+
ax.spines[side].set_visible(False)
|
563 |
+
|
564 |
+
|
565 |
+
## Generation
|
566 |
+
self.text_prompt = QPlainTextEdit(default_text)
|
567 |
+
gen_layout.addWidget(self.text_prompt, stretch=1)
|
568 |
+
|
569 |
+
self.generate_button = QPushButton("Synthesize and vocode")
|
570 |
+
gen_layout.addWidget(self.generate_button)
|
571 |
+
|
572 |
+
layout = QHBoxLayout()
|
573 |
+
self.synthesize_button = QPushButton("Synthesize only")
|
574 |
+
layout.addWidget(self.synthesize_button)
|
575 |
+
self.vocode_button = QPushButton("Vocode only")
|
576 |
+
layout.addWidget(self.vocode_button)
|
577 |
+
gen_layout.addLayout(layout)
|
578 |
+
|
579 |
+
layout_seed = QGridLayout()
|
580 |
+
self.random_seed_checkbox = QCheckBox("Random seed:")
|
581 |
+
self.random_seed_checkbox.setToolTip("When checked, makes the synthesizer and vocoder deterministic.")
|
582 |
+
layout_seed.addWidget(self.random_seed_checkbox, 0, 0)
|
583 |
+
self.seed_textbox = QLineEdit()
|
584 |
+
self.seed_textbox.setMaximumWidth(80)
|
585 |
+
layout_seed.addWidget(self.seed_textbox, 0, 1)
|
586 |
+
self.trim_silences_checkbox = QCheckBox("Enhance vocoder output")
|
587 |
+
self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output."
|
588 |
+
" This feature requires `webrtcvad` to be installed.")
|
589 |
+
layout_seed.addWidget(self.trim_silences_checkbox, 0, 2, 1, 2)
|
590 |
+
gen_layout.addLayout(layout_seed)
|
591 |
+
|
592 |
+
self.loading_bar = QProgressBar()
|
593 |
+
gen_layout.addWidget(self.loading_bar)
|
594 |
+
|
595 |
+
self.log_window = QLabel()
|
596 |
+
self.log_window.setAlignment(Qt.AlignBottom | Qt.AlignLeft)
|
597 |
+
gen_layout.addWidget(self.log_window)
|
598 |
+
self.logs = []
|
599 |
+
gen_layout.addStretch()
|
600 |
+
|
601 |
+
|
602 |
+
## Set the size of the window and of the elements
|
603 |
+
max_size = QDesktopWidget().availableGeometry(self).size() * 0.8
|
604 |
+
self.resize(max_size)
|
605 |
+
|
606 |
+
## Finalize the display
|
607 |
+
self.reset_interface()
|
608 |
+
self.show()
|
609 |
+
|
610 |
+
def start(self):
|
611 |
+
self.app.exec_()
|
toolbox/utterance.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
|
3 |
+
Utterance = namedtuple("Utterance", "name speaker_name wav spec embed partial_embeds synth")
|
4 |
+
Utterance.__eq__ = lambda x, y: x.name == y.name
|
5 |
+
Utterance.__hash__ = lambda x: hash(x.name)
|
utils/__init__.py
ADDED
File without changes
|
utils/argutils.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import numpy as np
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
_type_priorities = [ # In decreasing order
|
6 |
+
Path,
|
7 |
+
str,
|
8 |
+
int,
|
9 |
+
float,
|
10 |
+
bool,
|
11 |
+
]
|
12 |
+
|
13 |
+
def _priority(o):
|
14 |
+
p = next((i for i, t in enumerate(_type_priorities) if type(o) is t), None)
|
15 |
+
if p is not None:
|
16 |
+
return p
|
17 |
+
p = next((i for i, t in enumerate(_type_priorities) if isinstance(o, t)), None)
|
18 |
+
if p is not None:
|
19 |
+
return p
|
20 |
+
return len(_type_priorities)
|
21 |
+
|
22 |
+
def print_args(args: argparse.Namespace, parser=None):
|
23 |
+
args = vars(args)
|
24 |
+
if parser is None:
|
25 |
+
priorities = list(map(_priority, args.values()))
|
26 |
+
else:
|
27 |
+
all_params = [a.dest for g in parser._action_groups for a in g._group_actions ]
|
28 |
+
priority = lambda p: all_params.index(p) if p in all_params else len(all_params)
|
29 |
+
priorities = list(map(priority, args.keys()))
|
30 |
+
|
31 |
+
pad = max(map(len, args.keys())) + 3
|
32 |
+
indices = np.lexsort((list(args.keys()), priorities))
|
33 |
+
items = list(args.items())
|
34 |
+
|
35 |
+
print("Arguments:")
|
36 |
+
for i in indices:
|
37 |
+
param, value = items[i]
|
38 |
+
print(" {0}:{1}{2}".format(param, ' ' * (pad - len(param)), value))
|
39 |
+
print("")
|
40 |
+
|
utils/logmmse.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# The MIT License (MIT)
|
2 |
+
#
|
3 |
+
# Copyright (c) 2015 braindead
|
4 |
+
#
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
#
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
#
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
#
|
23 |
+
#
|
24 |
+
# This code was extracted from the logmmse package (https://pypi.org/project/logmmse/) and I
|
25 |
+
# simply modified the interface to meet my needs.
|
26 |
+
|
27 |
+
|
28 |
+
import numpy as np
|
29 |
+
import math
|
30 |
+
from scipy.special import expn
|
31 |
+
from collections import namedtuple
|
32 |
+
|
33 |
+
NoiseProfile = namedtuple("NoiseProfile", "sampling_rate window_size len1 len2 win n_fft noise_mu2")
|
34 |
+
|
35 |
+
|
36 |
+
def profile_noise(noise, sampling_rate, window_size=0):
|
37 |
+
"""
|
38 |
+
Creates a profile of the noise in a given waveform.
|
39 |
+
|
40 |
+
:param noise: a waveform containing noise ONLY, as a numpy array of floats or ints.
|
41 |
+
:param sampling_rate: the sampling rate of the audio
|
42 |
+
:param window_size: the size of the window the logmmse algorithm operates on. A default value
|
43 |
+
will be picked if left as 0.
|
44 |
+
:return: a NoiseProfile object
|
45 |
+
"""
|
46 |
+
noise, dtype = to_float(noise)
|
47 |
+
noise += np.finfo(np.float64).eps
|
48 |
+
|
49 |
+
if window_size == 0:
|
50 |
+
window_size = int(math.floor(0.02 * sampling_rate))
|
51 |
+
|
52 |
+
if window_size % 2 == 1:
|
53 |
+
window_size = window_size + 1
|
54 |
+
|
55 |
+
perc = 50
|
56 |
+
len1 = int(math.floor(window_size * perc / 100))
|
57 |
+
len2 = int(window_size - len1)
|
58 |
+
|
59 |
+
win = np.hanning(window_size)
|
60 |
+
win = win * len2 / np.sum(win)
|
61 |
+
n_fft = 2 * window_size
|
62 |
+
|
63 |
+
noise_mean = np.zeros(n_fft)
|
64 |
+
n_frames = len(noise) // window_size
|
65 |
+
for j in range(0, window_size * n_frames, window_size):
|
66 |
+
noise_mean += np.absolute(np.fft.fft(win * noise[j:j + window_size], n_fft, axis=0))
|
67 |
+
noise_mu2 = (noise_mean / n_frames) ** 2
|
68 |
+
|
69 |
+
return NoiseProfile(sampling_rate, window_size, len1, len2, win, n_fft, noise_mu2)
|
70 |
+
|
71 |
+
|
72 |
+
def denoise(wav, noise_profile: NoiseProfile, eta=0.15):
|
73 |
+
"""
|
74 |
+
Cleans the noise from a speech waveform given a noise profile. The waveform must have the
|
75 |
+
same sampling rate as the one used to create the noise profile.
|
76 |
+
|
77 |
+
:param wav: a speech waveform as a numpy array of floats or ints.
|
78 |
+
:param noise_profile: a NoiseProfile object that was created from a similar (or a segment of
|
79 |
+
the same) waveform.
|
80 |
+
:param eta: voice threshold for noise update. While the voice activation detection value is
|
81 |
+
below this threshold, the noise profile will be continuously updated throughout the audio.
|
82 |
+
Set to 0 to disable updating the noise profile.
|
83 |
+
:return: the clean wav as a numpy array of floats or ints of the same length.
|
84 |
+
"""
|
85 |
+
wav, dtype = to_float(wav)
|
86 |
+
wav += np.finfo(np.float64).eps
|
87 |
+
p = noise_profile
|
88 |
+
|
89 |
+
nframes = int(math.floor(len(wav) / p.len2) - math.floor(p.window_size / p.len2))
|
90 |
+
x_final = np.zeros(nframes * p.len2)
|
91 |
+
|
92 |
+
aa = 0.98
|
93 |
+
mu = 0.98
|
94 |
+
ksi_min = 10 ** (-25 / 10)
|
95 |
+
|
96 |
+
x_old = np.zeros(p.len1)
|
97 |
+
xk_prev = np.zeros(p.len1)
|
98 |
+
noise_mu2 = p.noise_mu2
|
99 |
+
for k in range(0, nframes * p.len2, p.len2):
|
100 |
+
insign = p.win * wav[k:k + p.window_size]
|
101 |
+
|
102 |
+
spec = np.fft.fft(insign, p.n_fft, axis=0)
|
103 |
+
sig = np.absolute(spec)
|
104 |
+
sig2 = sig ** 2
|
105 |
+
|
106 |
+
gammak = np.minimum(sig2 / noise_mu2, 40)
|
107 |
+
|
108 |
+
if xk_prev.all() == 0:
|
109 |
+
ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
|
110 |
+
else:
|
111 |
+
ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
|
112 |
+
ksi = np.maximum(ksi_min, ksi)
|
113 |
+
|
114 |
+
log_sigma_k = gammak * ksi/(1 + ksi) - np.log(1 + ksi)
|
115 |
+
vad_decision = np.sum(log_sigma_k) / p.window_size
|
116 |
+
if vad_decision < eta:
|
117 |
+
noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
|
118 |
+
|
119 |
+
a = ksi / (1 + ksi)
|
120 |
+
vk = a * gammak
|
121 |
+
ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
|
122 |
+
hw = a * np.exp(ei_vk)
|
123 |
+
sig = sig * hw
|
124 |
+
xk_prev = sig ** 2
|
125 |
+
xi_w = np.fft.ifft(hw * spec, p.n_fft, axis=0)
|
126 |
+
xi_w = np.real(xi_w)
|
127 |
+
|
128 |
+
x_final[k:k + p.len2] = x_old + xi_w[0:p.len1]
|
129 |
+
x_old = xi_w[p.len1:p.window_size]
|
130 |
+
|
131 |
+
output = from_float(x_final, dtype)
|
132 |
+
output = np.pad(output, (0, len(wav) - len(output)), mode="constant")
|
133 |
+
return output
|
134 |
+
|
135 |
+
|
136 |
+
## Alternative VAD algorithm to webrctvad. It has the advantage of not requiring to install that
|
137 |
+
## darn package and it also works for any sampling rate. Maybe I'll eventually use it instead of
|
138 |
+
## webrctvad
|
139 |
+
# def vad(wav, sampling_rate, eta=0.15, window_size=0):
|
140 |
+
# """
|
141 |
+
# TODO: fix doc
|
142 |
+
# Creates a profile of the noise in a given waveform.
|
143 |
+
#
|
144 |
+
# :param wav: a waveform containing noise ONLY, as a numpy array of floats or ints.
|
145 |
+
# :param sampling_rate: the sampling rate of the audio
|
146 |
+
# :param window_size: the size of the window the logmmse algorithm operates on. A default value
|
147 |
+
# will be picked if left as 0.
|
148 |
+
# :param eta: voice threshold for noise update. While the voice activation detection value is
|
149 |
+
# below this threshold, the noise profile will be continuously updated throughout the audio.
|
150 |
+
# Set to 0 to disable updating the noise profile.
|
151 |
+
# """
|
152 |
+
# wav, dtype = to_float(wav)
|
153 |
+
# wav += np.finfo(np.float64).eps
|
154 |
+
#
|
155 |
+
# if window_size == 0:
|
156 |
+
# window_size = int(math.floor(0.02 * sampling_rate))
|
157 |
+
#
|
158 |
+
# if window_size % 2 == 1:
|
159 |
+
# window_size = window_size + 1
|
160 |
+
#
|
161 |
+
# perc = 50
|
162 |
+
# len1 = int(math.floor(window_size * perc / 100))
|
163 |
+
# len2 = int(window_size - len1)
|
164 |
+
#
|
165 |
+
# win = np.hanning(window_size)
|
166 |
+
# win = win * len2 / np.sum(win)
|
167 |
+
# n_fft = 2 * window_size
|
168 |
+
#
|
169 |
+
# wav_mean = np.zeros(n_fft)
|
170 |
+
# n_frames = len(wav) // window_size
|
171 |
+
# for j in range(0, window_size * n_frames, window_size):
|
172 |
+
# wav_mean += np.absolute(np.fft.fft(win * wav[j:j + window_size], n_fft, axis=0))
|
173 |
+
# noise_mu2 = (wav_mean / n_frames) ** 2
|
174 |
+
#
|
175 |
+
# wav, dtype = to_float(wav)
|
176 |
+
# wav += np.finfo(np.float64).eps
|
177 |
+
#
|
178 |
+
# nframes = int(math.floor(len(wav) / len2) - math.floor(window_size / len2))
|
179 |
+
# vad = np.zeros(nframes * len2, dtype=np.bool)
|
180 |
+
#
|
181 |
+
# aa = 0.98
|
182 |
+
# mu = 0.98
|
183 |
+
# ksi_min = 10 ** (-25 / 10)
|
184 |
+
#
|
185 |
+
# xk_prev = np.zeros(len1)
|
186 |
+
# noise_mu2 = noise_mu2
|
187 |
+
# for k in range(0, nframes * len2, len2):
|
188 |
+
# insign = win * wav[k:k + window_size]
|
189 |
+
#
|
190 |
+
# spec = np.fft.fft(insign, n_fft, axis=0)
|
191 |
+
# sig = np.absolute(spec)
|
192 |
+
# sig2 = sig ** 2
|
193 |
+
#
|
194 |
+
# gammak = np.minimum(sig2 / noise_mu2, 40)
|
195 |
+
#
|
196 |
+
# if xk_prev.all() == 0:
|
197 |
+
# ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
|
198 |
+
# else:
|
199 |
+
# ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
|
200 |
+
# ksi = np.maximum(ksi_min, ksi)
|
201 |
+
#
|
202 |
+
# log_sigma_k = gammak * ksi / (1 + ksi) - np.log(1 + ksi)
|
203 |
+
# vad_decision = np.sum(log_sigma_k) / window_size
|
204 |
+
# if vad_decision < eta:
|
205 |
+
# noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
|
206 |
+
# print(vad_decision)
|
207 |
+
#
|
208 |
+
# a = ksi / (1 + ksi)
|
209 |
+
# vk = a * gammak
|
210 |
+
# ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
|
211 |
+
# hw = a * np.exp(ei_vk)
|
212 |
+
# sig = sig * hw
|
213 |
+
# xk_prev = sig ** 2
|
214 |
+
#
|
215 |
+
# vad[k:k + len2] = vad_decision >= eta
|
216 |
+
#
|
217 |
+
# vad = np.pad(vad, (0, len(wav) - len(vad)), mode="constant")
|
218 |
+
# return vad
|
219 |
+
|
220 |
+
|
221 |
+
def to_float(_input):
|
222 |
+
if _input.dtype == np.float64:
|
223 |
+
return _input, _input.dtype
|
224 |
+
elif _input.dtype == np.float32:
|
225 |
+
return _input.astype(np.float64), _input.dtype
|
226 |
+
elif _input.dtype == np.uint8:
|
227 |
+
return (_input - 128) / 128., _input.dtype
|
228 |
+
elif _input.dtype == np.int16:
|
229 |
+
return _input / 32768., _input.dtype
|
230 |
+
elif _input.dtype == np.int32:
|
231 |
+
return _input / 2147483648., _input.dtype
|
232 |
+
raise ValueError('Unsupported wave file format')
|
233 |
+
|
234 |
+
|
235 |
+
def from_float(_input, dtype):
|
236 |
+
if dtype == np.float64:
|
237 |
+
return _input, np.float64
|
238 |
+
elif dtype == np.float32:
|
239 |
+
return _input.astype(np.float32)
|
240 |
+
elif dtype == np.uint8:
|
241 |
+
return ((_input * 128) + 128).astype(np.uint8)
|
242 |
+
elif dtype == np.int16:
|
243 |
+
return (_input * 32768).astype(np.int16)
|
244 |
+
elif dtype == np.int32:
|
245 |
+
print(_input)
|
246 |
+
return (_input * 2147483648).astype(np.int32)
|
247 |
+
raise ValueError('Unsupported wave file format')
|
utils/modelutils.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
def check_model_paths(encoder_path: Path, synthesizer_path: Path, vocoder_path: Path):
|
4 |
+
# This function tests the model paths and makes sure at least one is valid.
|
5 |
+
if encoder_path.is_file() or encoder_path.is_dir():
|
6 |
+
return
|
7 |
+
if synthesizer_path.is_file() or synthesizer_path.is_dir():
|
8 |
+
return
|
9 |
+
if vocoder_path.is_file() or vocoder_path.is_dir():
|
10 |
+
return
|
11 |
+
|
12 |
+
# If none of the paths exist, remind the user to download models if needed
|
13 |
+
print("********************************************************************************")
|
14 |
+
print("Error: Model files not found. Follow these instructions to get and install the models:")
|
15 |
+
print("https://github.com/CorentinJ/Real-Time-Voice-Cloning/wiki/Pretrained-models")
|
16 |
+
print("********************************************************************************\n")
|
17 |
+
quit(-1)
|
utils/profiler.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from time import perf_counter as timer
|
2 |
+
from collections import OrderedDict
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
class Profiler:
|
7 |
+
def __init__(self, summarize_every=5, disabled=False):
|
8 |
+
self.last_tick = timer()
|
9 |
+
self.logs = OrderedDict()
|
10 |
+
self.summarize_every = summarize_every
|
11 |
+
self.disabled = disabled
|
12 |
+
|
13 |
+
def tick(self, name):
|
14 |
+
if self.disabled:
|
15 |
+
return
|
16 |
+
|
17 |
+
# Log the time needed to execute that function
|
18 |
+
if not name in self.logs:
|
19 |
+
self.logs[name] = []
|
20 |
+
if len(self.logs[name]) >= self.summarize_every:
|
21 |
+
self.summarize()
|
22 |
+
self.purge_logs()
|
23 |
+
self.logs[name].append(timer() - self.last_tick)
|
24 |
+
|
25 |
+
self.reset_timer()
|
26 |
+
|
27 |
+
def purge_logs(self):
|
28 |
+
for name in self.logs:
|
29 |
+
self.logs[name].clear()
|
30 |
+
|
31 |
+
def reset_timer(self):
|
32 |
+
self.last_tick = timer()
|
33 |
+
|
34 |
+
def summarize(self):
|
35 |
+
n = max(map(len, self.logs.values()))
|
36 |
+
assert n == self.summarize_every
|
37 |
+
print("\nAverage execution time over %d steps:" % n)
|
38 |
+
|
39 |
+
name_msgs = ["%s (%d/%d):" % (name, len(deltas), n) for name, deltas in self.logs.items()]
|
40 |
+
pad = max(map(len, name_msgs))
|
41 |
+
for name_msg, deltas in zip(name_msgs, self.logs.values()):
|
42 |
+
print(" %s mean: %4.0fms std: %4.0fms" %
|
43 |
+
(name_msg.ljust(pad), np.mean(deltas) * 1000, np.std(deltas) * 1000))
|
44 |
+
print("", flush=True)
|
45 |
+
|