import os
import numpy
import soundfile as sf
import torch
from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface
from Modules.Aligner.Aligner import Aligner
from Modules.ToucanTTS.DurationCalculator import DurationCalculator
from Modules.ToucanTTS.EnergyCalculator import EnergyCalculator
from Modules.ToucanTTS.PitchCalculator import Parselmouth
from Preprocessing.AudioPreprocessor import AudioPreprocessor
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from Preprocessing.articulatory_features import get_feature_to_index_lookup
from Utility.storage_config import MODELS_DIR
from Utility.utils import float2pcm


class UtteranceCloner:
"""
    Clone the prosody of an utterance, but exchange the speaker (or keep it).
    Useful for privacy applications.
"""

    def __init__(self, model_id, device, language="eng"):
self.tts = ToucanTTSInterface(device=device, tts_model_path=model_id)
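        # input_sr=100 is only a placeholder; the preprocessor is re-created with the
        # actual sampling rate of the reference audio once it is known in extract_prosody()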
self.ap = AudioPreprocessor(input_sr=100, output_sr=16000, cut_silence=False)
self.tf = ArticulatoryCombinedTextFrontend(language=language, device=device)
self.device = device
acoustic_checkpoint_path = os.path.join(MODELS_DIR, "Aligner", "aligner.pt")
self.aligner_weights = torch.load(acoustic_checkpoint_path, map_location=device)["asr_model"]
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True # torch 1.9 has a bug in the hub loading, this is a workaround
# careful: assumes 16kHz or 8kHz audio
self.silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
force_reload=False,
onnx=False,
verbose=False)
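        # only the VAD timestamp helper is needed from silero's bundle of utility functions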
(self.get_speech_timestamps, _, _, _, _) = utils
torch.set_grad_enabled(True) # finding this issue was very infuriating: silero sets
# this to false globally during model loading rather than using inference_mode or no_grad
self.acoustic_model = Aligner()
self.acoustic_model = self.acoustic_model.to(self.device)
self.acoustic_model.load_state_dict(self.aligner_weights)
self.acoustic_model.eval()
self.parsel = Parselmouth(reduction_factor=1, fs=16000)
self.energy_calc = EnergyCalculator(reduction_factor=1, fs=16000)
self.dc = DurationCalculator(reduction_factor=1)

    def extract_prosody(self, transcript, ref_audio_path, lang="eng", on_line_fine_tune=True):
if on_line_fine_tune:
self.acoustic_model.load_state_dict(self.aligner_weights)
self.acoustic_model.eval()
wave, sr = sf.read(ref_audio_path)
if self.tf.language != lang:
self.tf = ArticulatoryCombinedTextFrontend(language=lang, device=self.device)
if self.ap.input_sr != sr:
self.ap = AudioPreprocessor(input_sr=sr, output_sr=16000, cut_silence=False)
try:
norm_wave = self.ap.normalize_audio(audio=wave)
        except ValueError as e:
            raise RuntimeError('Normalizing the reference audio failed, the reference wave might be too short.') from e
with torch.inference_mode():
speech_timestamps = self.get_speech_timestamps(norm_wave, self.silero_model, sampling_rate=16000)
if len(speech_timestamps) == 0:
speech_timestamps = [{'start': 0, 'end': len(norm_wave)}]
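        # remember how much leading and trailing silence (in samples at 16kHz) the reference had,
        # so it can be re-inserted around the cloned utterance later on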
start_silence = speech_timestamps[0]['start']
end_silence = len(norm_wave) - speech_timestamps[-1]['end']
norm_wave = norm_wave[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]
norm_wave_length = torch.LongTensor([len(norm_wave)])
text = self.tf.string_to_tensor(transcript, handle_missing=False).squeeze(0)
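        # the mel spectrogram of the reference serves as input to the aligner and to the pitch/energy extractors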
features = self.ap.audio_to_mel_spec_tensor(audio=norm_wave, explicit_sampling_rate=16000).transpose(0, 1)
feature_length = torch.LongTensor([len(features)]).numpy()
if on_line_fine_tune:
# we fine-tune the aligner for a couple steps using SGD. This makes cloning pretty slow, but the results are greatly improved.
steps = 4
tokens = self.tf.text_vectors_to_id_sequence(text_vector=text) # we need an ID sequence for training rather than a sequence of phonological features
tokens = torch.LongTensor(tokens).squeeze().to(self.device)
tokens_len = torch.LongTensor([len(tokens)]).to(self.device)
mel = features.unsqueeze(0).to(self.device)
mel_len = torch.LongTensor([len(mel[0])]).to(self.device)
# actual fine-tuning starts here
optim_asr = torch.optim.Adam(self.acoustic_model.parameters(), lr=0.00001)
self.acoustic_model.train()
for _ in range(steps):
pred = self.acoustic_model(mel.clone())
loss = self.acoustic_model.ctc_loss(pred.transpose(0, 1).log_softmax(2), tokens, mel_len, tokens_len)
                print(f"aligner fine-tuning loss: {loss.item()}")
optim_asr.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.acoustic_model.parameters(), 1.0)
optim_asr.step()
self.acoustic_model.eval()
        # We deal with word boundaries by keeping two versions of the text: with and without word boundaries.
        # We note the indexes of the word boundaries and insert durations of 0 at those positions afterwards.
text_without_word_boundaries = list()
indexes_of_word_boundaries = list()
for phoneme_index, vector in enumerate(text):
if vector[get_feature_to_index_lookup()["word-boundary"]] == 0:
text_without_word_boundaries.append(vector.numpy().tolist())
else:
indexes_of_word_boundaries.append(phoneme_index)
matrix_without_word_boundaries = torch.Tensor(text_without_word_boundaries)
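        # the aligner assigns a phoneme to every spectrogram frame; the duration calculator then
        # turns this monotonic alignment path into one duration value per phoneme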
alignment_path = self.acoustic_model.inference(features=features.to(self.device),
tokens=matrix_without_word_boundaries.to(self.device),
return_ctc=False)
duration = self.dc(torch.LongTensor(alignment_path), vis=None).cpu()
for index_of_word_boundary in indexes_of_word_boundaries:
duration = torch.cat([duration[:index_of_word_boundary],
torch.LongTensor([0]), # insert a 0 duration wherever there is a word boundary
duration[index_of_word_boundary:]])
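        # energy and pitch are extracted from the reference wave and averaged per phoneme
        # according to the durations obtained above, so they can later be fed to the TTS as explicit targets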
energy = self.energy_calc(input_waves=norm_wave.unsqueeze(0),
input_waves_lengths=norm_wave_length,
feats_lengths=feature_length,
text=text,
durations=duration.unsqueeze(0),
durations_lengths=torch.LongTensor([len(duration)]))[0].squeeze(0).cpu()
pitch = self.parsel(input_waves=norm_wave.unsqueeze(0),
input_waves_lengths=norm_wave_length,
feats_lengths=feature_length,
text=text,
durations=duration.unsqueeze(0),
durations_lengths=torch.LongTensor([len(duration)]))[0].squeeze(0).cpu()
return duration, pitch, energy, start_silence, end_silence

    def clone_utterance(self,
path_to_reference_audio_for_intonation,
path_to_reference_audio_for_voice,
transcription_of_intonation_reference,
filename_of_result=None,
lang="eng"):
"""
        What is said in path_to_reference_audio_for_intonation has to match transcription_of_intonation_reference exactly!
"""
self.tts.set_utterance_embedding(path_to_reference_audio=path_to_reference_audio_for_voice)
duration, pitch, energy, silence_frames_start, silence_frames_end = self.extract_prosody(transcription_of_intonation_reference,
path_to_reference_audio_for_intonation,
lang=lang)
self.tts.set_language(lang)
        start_sil = numpy.zeros([int(silence_frames_start * 1.5)])  # the silence was measured in samples at 16kHz, the output is 24kHz, so scale by 24000 / 16000 = 1.5
        end_sil = numpy.zeros([int(silence_frames_end * 1.5)])  # same 1.5x rescaling for the trailing silence
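        # synthesize with the cloned prosody by passing the extracted durations, pitch and energy as explicit targets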
cloned_speech, sr = self.tts(transcription_of_intonation_reference, view=False, durations=duration, pitch=pitch, energy=energy)
cloned_utt = numpy.concatenate([start_sil, cloned_speech, end_sil], axis=0)
if filename_of_result is not None:
sf.write(file=filename_of_result, data=float2pcm(cloned_utt), samplerate=sr, subtype="PCM_16")
return cloned_utt, sr
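

# A minimal usage sketch (not part of the original file): the model id, device and the
# reference audio paths below are placeholder assumptions and need to be adapted to your setup.
if __name__ == "__main__":
    uc = UtteranceCloner(model_id="Meta",  # assumed model id, adjust to the checkpoint you actually use
                         device="cuda" if torch.cuda.is_available() else "cpu")
    # the transcription has to match what is actually said in the intonation reference
    uc.clone_utterance(path_to_reference_audio_for_intonation="audios/prosody_reference.wav",  # hypothetical path
                       path_to_reference_audio_for_voice="audios/voice_reference.wav",          # hypothetical path
                       transcription_of_intonation_reference="Hello world, this is a test.",
                       filename_of_result="cloned.wav",
                       lang="eng")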