integrate new autoregressive model and fix new diffusion bug
Browse files- api.py +4 -3
- api_new_autoregressive.py +245 -0
- do_tts.py +2 -2
- models/diffusion_decoder.py +5 -5
- models/new_autoregressive.py +293 -0
api.py
CHANGED
@@ -117,13 +117,14 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_
|
|
117 |
cond_mels.append(cond_mel)
|
118 |
cond_mels = torch.stack(cond_mels, dim=1)
|
119 |
|
120 |
-
|
121 |
-
|
|
|
122 |
|
123 |
noise = torch.randn(output_shape, device=mel_codes.device) * temperature
|
124 |
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
|
125 |
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
|
126 |
-
return denormalize_tacotron_mel(mel)[:,:,:
|
127 |
|
128 |
|
129 |
class TextToSpeech:
|
|
|
117 |
cond_mels.append(cond_mel)
|
118 |
cond_mels = torch.stack(cond_mels, dim=1)
|
119 |
|
120 |
+
output_seq_len = mel_codes.shape[-1]*4*24000//22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
|
121 |
+
output_shape = (mel_codes.shape[0], 100, output_seq_len)
|
122 |
+
precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, output_seq_len, False)
|
123 |
|
124 |
noise = torch.randn(output_shape, device=mel_codes.device) * temperature
|
125 |
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
|
126 |
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
|
127 |
+
return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
|
128 |
|
129 |
|
130 |
class TextToSpeech:
|
api_new_autoregressive.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
from urllib import request
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torchaudio
|
9 |
+
import progressbar
|
10 |
+
import ocotillo
|
11 |
+
|
12 |
+
from models.diffusion_decoder import DiffusionTts
|
13 |
+
from models.autoregressive import UnifiedVoice
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from models.arch_util import TorchMelSpectrogram
|
17 |
+
from models.new_autoregressive import AutoregressiveCodegen
|
18 |
+
from models.text_voice_clip import VoiceCLIP
|
19 |
+
from models.vocoder import UnivNetGenerator
|
20 |
+
from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel
|
21 |
+
from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
|
22 |
+
from utils.tokenizer import VoiceBpeTokenizer, lev_distance
|
23 |
+
|
24 |
+
|
25 |
+
pbar = None
|
26 |
+
def download_models():
|
27 |
+
MODELS = {
|
28 |
+
'clip.pth': 'https://huggingface.co/jbetker/tortoise-tts-clip/resolve/main/pytorch-model.bin',
|
29 |
+
'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-diffusion-v1/resolve/main/pytorch-model.bin',
|
30 |
+
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-autoregressive/resolve/main/pytorch-model.bin'
|
31 |
+
}
|
32 |
+
os.makedirs('.models', exist_ok=True)
|
33 |
+
def show_progress(block_num, block_size, total_size):
|
34 |
+
global pbar
|
35 |
+
if pbar is None:
|
36 |
+
pbar = progressbar.ProgressBar(maxval=total_size)
|
37 |
+
pbar.start()
|
38 |
+
|
39 |
+
downloaded = block_num * block_size
|
40 |
+
if downloaded < total_size:
|
41 |
+
pbar.update(downloaded)
|
42 |
+
else:
|
43 |
+
pbar.finish()
|
44 |
+
pbar = None
|
45 |
+
for model_name, url in MODELS.items():
|
46 |
+
if os.path.exists(f'.models/{model_name}'):
|
47 |
+
continue
|
48 |
+
print(f'Downloading {model_name} from {url}...')
|
49 |
+
request.urlretrieve(url, f'.models/{model_name}', show_progress)
|
50 |
+
print('Done.')
|
51 |
+
|
52 |
+
|
53 |
+
def pad_or_truncate(t, length):
|
54 |
+
if t.shape[-1] == length:
|
55 |
+
return t
|
56 |
+
elif t.shape[-1] < length:
|
57 |
+
return F.pad(t, (0, length-t.shape[-1]))
|
58 |
+
else:
|
59 |
+
return t[..., :length]
|
60 |
+
|
61 |
+
|
62 |
+
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1):
|
63 |
+
"""
|
64 |
+
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
|
65 |
+
"""
|
66 |
+
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
|
67 |
+
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
|
68 |
+
conditioning_free=cond_free, conditioning_free_k=cond_free_k)
|
69 |
+
|
70 |
+
|
71 |
+
def load_conditioning(clip, cond_length=132300):
|
72 |
+
gap = clip.shape[-1] - cond_length
|
73 |
+
if gap < 0:
|
74 |
+
clip = F.pad(clip, pad=(0, abs(gap)))
|
75 |
+
elif gap > 0:
|
76 |
+
rand_start = random.randint(0, gap)
|
77 |
+
clip = clip[:, rand_start:rand_start + cond_length]
|
78 |
+
mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0)
|
79 |
+
return mel_clip.unsqueeze(0).cuda()
|
80 |
+
|
81 |
+
|
82 |
+
def fix_autoregressive_output(codes, stop_token):
|
83 |
+
"""
|
84 |
+
This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
|
85 |
+
trained on and what the autoregressive code generator creates (which has no padding or end).
|
86 |
+
This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with
|
87 |
+
a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE
|
88 |
+
and copying out the last few codes.
|
89 |
+
|
90 |
+
Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar.
|
91 |
+
"""
|
92 |
+
# Strip off the autoregressive stop token and add padding.
|
93 |
+
stop_token_indices = (codes == stop_token).nonzero()
|
94 |
+
if len(stop_token_indices) == 0:
|
95 |
+
print("No stop tokens found, enjoy that output of yours!")
|
96 |
+
return codes
|
97 |
+
else:
|
98 |
+
codes[stop_token_indices] = 83
|
99 |
+
stm = stop_token_indices.min().item()
|
100 |
+
codes[stm:] = 83
|
101 |
+
if stm - 3 < codes.shape[0]:
|
102 |
+
codes[-3] = 45
|
103 |
+
codes[-2] = 45
|
104 |
+
codes[-1] = 248
|
105 |
+
|
106 |
+
return codes
|
107 |
+
|
108 |
+
|
109 |
+
def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_samples, temperature=1):
|
110 |
+
"""
|
111 |
+
Uses the specified diffusion model to convert discrete codes into a spectrogram.
|
112 |
+
"""
|
113 |
+
with torch.no_grad():
|
114 |
+
cond_mels = []
|
115 |
+
for sample in conditioning_samples:
|
116 |
+
sample = pad_or_truncate(sample, 102400)
|
117 |
+
cond_mel = wav_to_univnet_mel(sample.to(mel_codes.device), do_normalization=False)
|
118 |
+
cond_mels.append(cond_mel)
|
119 |
+
cond_mels = torch.stack(cond_mels, dim=1)
|
120 |
+
|
121 |
+
output_seq_len = mel_codes.shape[-1]*4*24000//22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
|
122 |
+
output_shape = (mel_codes.shape[0], 100, output_seq_len)
|
123 |
+
precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, output_seq_len, False)
|
124 |
+
|
125 |
+
noise = torch.randn(output_shape, device=mel_codes.device) * temperature
|
126 |
+
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
|
127 |
+
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
|
128 |
+
return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
|
129 |
+
|
130 |
+
|
131 |
+
class TextToSpeech:
|
132 |
+
def __init__(self, autoregressive_batch_size=32):
|
133 |
+
self.autoregressive_batch_size = autoregressive_batch_size
|
134 |
+
self.tokenizer = VoiceBpeTokenizer()
|
135 |
+
download_models()
|
136 |
+
|
137 |
+
self.autoregressive = AutoregressiveCodegen(512, 12).cpu().eval()
|
138 |
+
self.autoregressive.load_state_dict(torch.load('D:\\dlas\\experiments\\train_autoregressive_codegen\\models\\23000_codegen_ema.pth'))
|
139 |
+
|
140 |
+
self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
|
141 |
+
text_seq_len=350, text_heads=8,
|
142 |
+
num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430,
|
143 |
+
use_xformers=True).cpu().eval()
|
144 |
+
self.clip.load_state_dict(torch.load('.models/clip.pth'))
|
145 |
+
|
146 |
+
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
|
147 |
+
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
|
148 |
+
layer_drop=0, unconditioned_percentage=0).cpu().eval()
|
149 |
+
self.diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
|
150 |
+
|
151 |
+
self.vocoder = UnivNetGenerator().cpu()
|
152 |
+
self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
|
153 |
+
self.vocoder.eval(inference=True)
|
154 |
+
|
155 |
+
def tts(self, text, voice_samples, k=1,
|
156 |
+
# autoregressive generation parameters follow
|
157 |
+
num_autoregressive_samples=512, temperature=.5, length_penalty=2, repetition_penalty=2.0, top_p=.5,
|
158 |
+
typical_sampling=False, typical_mass=.9,
|
159 |
+
# diffusion generation parameters follow
|
160 |
+
diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=.7,):
|
161 |
+
text = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
|
162 |
+
text = F.pad(text, (0, 1)) # This may not be necessary.
|
163 |
+
|
164 |
+
conds = []
|
165 |
+
if not isinstance(voice_samples, list):
|
166 |
+
voice_samples = [voice_samples]
|
167 |
+
for vs in voice_samples:
|
168 |
+
conds.append(load_conditioning(vs))
|
169 |
+
conds = torch.stack(conds, dim=1)
|
170 |
+
|
171 |
+
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
|
172 |
+
|
173 |
+
with torch.no_grad():
|
174 |
+
samples = []
|
175 |
+
num_batches = num_autoregressive_samples // self.autoregressive_batch_size
|
176 |
+
stop_mel_token = self.autoregressive.STOP_TOKEN
|
177 |
+
self.autoregressive = self.autoregressive.cuda()
|
178 |
+
for _ in tqdm(range(num_batches)):
|
179 |
+
codes = self.autoregressive.generate(conds, text,
|
180 |
+
do_sample=True,
|
181 |
+
top_p=top_p,
|
182 |
+
temperature=temperature,
|
183 |
+
num_return_sequences=self.autoregressive_batch_size,
|
184 |
+
length_penalty=length_penalty,
|
185 |
+
repetition_penalty=repetition_penalty,
|
186 |
+
typical_sampling=typical_sampling,
|
187 |
+
typical_mass=typical_mass)
|
188 |
+
padding_needed = 250 - codes.shape[1]
|
189 |
+
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
|
190 |
+
samples.append(codes)
|
191 |
+
#self.autoregressive = self.autoregressive.cpu()
|
192 |
+
|
193 |
+
clip_results = []
|
194 |
+
self.clip = self.clip.cuda()
|
195 |
+
for batch in samples:
|
196 |
+
for i in range(batch.shape[0]):
|
197 |
+
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
|
198 |
+
bad_toks = batch >= 8192
|
199 |
+
batch = batch * bad_toks.logical_not()
|
200 |
+
clip_results.append(self.clip(text.repeat(batch.shape[0], 1), batch, return_loss=False))
|
201 |
+
clip_results = torch.cat(clip_results, dim=0)
|
202 |
+
samples = torch.cat(samples, dim=0)
|
203 |
+
best_results = samples[torch.topk(clip_results, k=k).indices]
|
204 |
+
self.clip = self.clip.cpu()
|
205 |
+
del samples
|
206 |
+
|
207 |
+
print("Performing vocoding..")
|
208 |
+
wav_candidates = []
|
209 |
+
self.diffusion = self.diffusion.cuda()
|
210 |
+
self.vocoder = self.vocoder.cuda()
|
211 |
+
for b in range(best_results.shape[0]):
|
212 |
+
code = best_results[b].unsqueeze(0)
|
213 |
+
mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, voice_samples, temperature=diffusion_temperature)
|
214 |
+
wav = self.vocoder.inference(mel)
|
215 |
+
wav_candidates.append(wav.cpu())
|
216 |
+
self.diffusion = self.diffusion.cpu()
|
217 |
+
self.vocoder = self.vocoder.cpu()
|
218 |
+
|
219 |
+
if len(wav_candidates) > 1:
|
220 |
+
return wav_candidates
|
221 |
+
return wav_candidates[0]
|
222 |
+
|
223 |
+
def refine_for_intellibility(self, wav_candidates, corresponding_codes, output_path):
|
224 |
+
"""
|
225 |
+
Further refine the remaining candidates using a ASR model to pick out the ones that are the most understandable.
|
226 |
+
TODO: finish this function
|
227 |
+
:param wav_candidates:
|
228 |
+
:return:
|
229 |
+
"""
|
230 |
+
transcriber = ocotillo.Transcriber(on_cuda=True)
|
231 |
+
transcriptions = transcriber.transcribe_batch(torch.cat(wav_candidates, dim=0).squeeze(1), 24000)
|
232 |
+
best = 99999999
|
233 |
+
for i, transcription in enumerate(transcriptions):
|
234 |
+
dist = lev_distance(transcription, args.text.lower())
|
235 |
+
if dist < best:
|
236 |
+
best = dist
|
237 |
+
best_codes = corresponding_codes[i].unsqueeze(0)
|
238 |
+
best_wav = wav_candidates[i]
|
239 |
+
del transcriber
|
240 |
+
torchaudio.save(os.path.join(output_path, f'{voice}_poor.wav'), best_wav.squeeze(0).cpu(), 24000)
|
241 |
+
|
242 |
+
# Perform diffusion again with the high-quality diffuser.
|
243 |
+
mel = do_spectrogram_diffusion(diffusion, final_diffuser, best_codes, cond_diffusion, mean=False)
|
244 |
+
wav = vocoder.inference(mel)
|
245 |
+
torchaudio.save(os.path.join(args.output_path, f'{voice}.wav'), wav.squeeze(0).cpu(), 24000)
|
do_tts.py
CHANGED
@@ -5,7 +5,7 @@ import torch
|
|
5 |
import torch.nn.functional as F
|
6 |
import torchaudio
|
7 |
|
8 |
-
from
|
9 |
from utils.audio import load_audio
|
10 |
from utils.tokenizer import VoiceBpeTokenizer
|
11 |
|
@@ -28,7 +28,7 @@ if __name__ == '__main__':
|
|
28 |
parser = argparse.ArgumentParser()
|
29 |
parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
|
30 |
parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dotrice,harris,lescault,otto,atkins,grace,kennard,mol')
|
31 |
-
parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=
|
32 |
parser.add_argument('-batch_size', type=int, help='How many samples to process at once in the autoregressive model.', default=16)
|
33 |
parser.add_argument('-num_diffusion_samples', type=int, help='Number of outputs that progress to the diffusion stage.', default=16)
|
34 |
parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/')
|
|
|
5 |
import torch.nn.functional as F
|
6 |
import torchaudio
|
7 |
|
8 |
+
from api_new_autoregressive import TextToSpeech, load_conditioning
|
9 |
from utils.audio import load_audio
|
10 |
from utils.tokenizer import VoiceBpeTokenizer
|
11 |
|
|
|
28 |
parser = argparse.ArgumentParser()
|
29 |
parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
|
30 |
parser.add_argument('-voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dotrice,harris,lescault,otto,atkins,grace,kennard,mol')
|
31 |
+
parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=32)
|
32 |
parser.add_argument('-batch_size', type=int, help='How many samples to process at once in the autoregressive model.', default=16)
|
33 |
parser.add_argument('-num_diffusion_samples', type=int, help='Number of outputs that progress to the diffusion stage.', default=16)
|
34 |
parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/')
|
models/diffusion_decoder.py
CHANGED
@@ -212,7 +212,7 @@ class DiffusionTts(nn.Module):
|
|
212 |
}
|
213 |
return groups
|
214 |
|
215 |
-
def timestep_independent(self, aligned_conditioning, conditioning_input, return_code_pred):
|
216 |
# Shuffle aligned_latent to BxCxS format
|
217 |
if is_latent(aligned_conditioning):
|
218 |
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
|
@@ -227,7 +227,7 @@ class DiffusionTts(nn.Module):
|
|
227 |
cond_emb = conds.mean(dim=-1)
|
228 |
cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
|
229 |
if is_latent(aligned_conditioning):
|
230 |
-
code_emb = self.
|
231 |
else:
|
232 |
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
|
233 |
code_emb = self.code_converter(code_emb)
|
@@ -240,7 +240,7 @@ class DiffusionTts(nn.Module):
|
|
240 |
device=code_emb.device) < self.unconditioned_percentage
|
241 |
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
|
242 |
code_emb)
|
243 |
-
expanded_code_emb = F.interpolate(code_emb, size=
|
244 |
|
245 |
if not return_code_pred:
|
246 |
return expanded_code_emb
|
@@ -250,7 +250,6 @@ class DiffusionTts(nn.Module):
|
|
250 |
mel_pred = mel_pred * unconditioned_batches.logical_not()
|
251 |
return expanded_code_emb, mel_pred
|
252 |
|
253 |
-
|
254 |
def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
|
255 |
"""
|
256 |
Apply the model to an input batch.
|
@@ -275,11 +274,12 @@ class DiffusionTts(nn.Module):
|
|
275 |
if precomputed_aligned_embeddings is not None:
|
276 |
code_emb = precomputed_aligned_embeddings
|
277 |
else:
|
278 |
-
code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, True)
|
279 |
if is_latent(aligned_conditioning):
|
280 |
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
281 |
else:
|
282 |
unused_params.extend(list(self.latent_converter.parameters()))
|
|
|
283 |
unused_params.append(self.unconditioned_embedding)
|
284 |
|
285 |
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
|
|
212 |
}
|
213 |
return groups
|
214 |
|
215 |
+
def timestep_independent(self, aligned_conditioning, conditioning_input, expected_seq_len, return_code_pred):
|
216 |
# Shuffle aligned_latent to BxCxS format
|
217 |
if is_latent(aligned_conditioning):
|
218 |
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
|
|
|
227 |
cond_emb = conds.mean(dim=-1)
|
228 |
cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
|
229 |
if is_latent(aligned_conditioning):
|
230 |
+
code_emb = self.autoregressive_latent_converter(aligned_conditioning)
|
231 |
else:
|
232 |
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
|
233 |
code_emb = self.code_converter(code_emb)
|
|
|
240 |
device=code_emb.device) < self.unconditioned_percentage
|
241 |
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
|
242 |
code_emb)
|
243 |
+
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')
|
244 |
|
245 |
if not return_code_pred:
|
246 |
return expanded_code_emb
|
|
|
250 |
mel_pred = mel_pred * unconditioned_batches.logical_not()
|
251 |
return expanded_code_emb, mel_pred
|
252 |
|
|
|
253 |
def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
|
254 |
"""
|
255 |
Apply the model to an input batch.
|
|
|
274 |
if precomputed_aligned_embeddings is not None:
|
275 |
code_emb = precomputed_aligned_embeddings
|
276 |
else:
|
277 |
+
code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True)
|
278 |
if is_latent(aligned_conditioning):
|
279 |
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
280 |
else:
|
281 |
unused_params.extend(list(self.latent_converter.parameters()))
|
282 |
+
|
283 |
unused_params.append(self.unconditioned_embedding)
|
284 |
|
285 |
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
models/new_autoregressive.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from transformers import GPT2PreTrainedModel, GPT2Config
|
7 |
+
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
8 |
+
from x_transformers import TransformerWrapper, Encoder, Decoder
|
9 |
+
|
10 |
+
from models.arch_util import AttentionBlock
|
11 |
+
|
12 |
+
|
13 |
+
class InferenceModel(GPT2PreTrainedModel):
|
14 |
+
"""
|
15 |
+
Implementation of GPT2PreTrainedModel from transformers, which allows us to use their generation library with
|
16 |
+
this transformer.
|
17 |
+
"""
|
18 |
+
def __init__(self, model):
|
19 |
+
super().__init__(GPT2Config())
|
20 |
+
self.transformer = model
|
21 |
+
self.context = None
|
22 |
+
|
23 |
+
def parallelize(self, device_map=None):
|
24 |
+
# Not implemented.
|
25 |
+
pass
|
26 |
+
|
27 |
+
def deparallelize(self):
|
28 |
+
# Not implemented.
|
29 |
+
pass
|
30 |
+
|
31 |
+
def get_output_embeddings(self):
|
32 |
+
assert False, "Unsupported operation."
|
33 |
+
|
34 |
+
def set_output_embeddings(self, new_embeddings):
|
35 |
+
assert False, "Unsupported operation."
|
36 |
+
|
37 |
+
def store_context(self, context):
|
38 |
+
self.context = context
|
39 |
+
|
40 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
41 |
+
token_type_ids = kwargs.get("token_type_ids", None)
|
42 |
+
# only last token for inputs_ids if past is defined in kwargs
|
43 |
+
if past:
|
44 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
45 |
+
if token_type_ids is not None:
|
46 |
+
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
47 |
+
|
48 |
+
attention_mask = kwargs.get("attention_mask", None)
|
49 |
+
position_ids = kwargs.get("position_ids", None)
|
50 |
+
|
51 |
+
if attention_mask is not None and position_ids is None:
|
52 |
+
# create position_ids on the fly for batch generation
|
53 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
54 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
55 |
+
if past:
|
56 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
57 |
+
else:
|
58 |
+
position_ids = None
|
59 |
+
return {
|
60 |
+
"input_ids": input_ids,
|
61 |
+
"past_key_values": past,
|
62 |
+
"use_cache": kwargs.get("use_cache"),
|
63 |
+
"position_ids": position_ids,
|
64 |
+
"attention_mask": attention_mask,
|
65 |
+
"token_type_ids": token_type_ids,
|
66 |
+
}
|
67 |
+
|
68 |
+
def forward(
|
69 |
+
self,
|
70 |
+
input_ids=None,
|
71 |
+
past_key_values=None,
|
72 |
+
attention_mask=None,
|
73 |
+
token_type_ids=None,
|
74 |
+
position_ids=None,
|
75 |
+
head_mask=None,
|
76 |
+
inputs_embeds=None,
|
77 |
+
encoder_hidden_states=None,
|
78 |
+
encoder_attention_mask=None,
|
79 |
+
labels=None,
|
80 |
+
use_cache=None,
|
81 |
+
output_attentions=None,
|
82 |
+
output_hidden_states=None,
|
83 |
+
return_dict=None,
|
84 |
+
):
|
85 |
+
assert self.context is not None
|
86 |
+
assert inputs_embeds is None # Not supported by this inference model.
|
87 |
+
assert labels is None # Training not supported by this inference model.
|
88 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
89 |
+
|
90 |
+
hidden_states = self.transformer.decoder(input_ids, context=self.context, return_embeddings=True)
|
91 |
+
logits = self.transformer.decoder.transformer.to_logits(hidden_states)
|
92 |
+
|
93 |
+
if not return_dict:
|
94 |
+
return (logits, )
|
95 |
+
|
96 |
+
return CausalLMOutputWithCrossAttentions(
|
97 |
+
loss=None,
|
98 |
+
logits=logits,
|
99 |
+
past_key_values=None,
|
100 |
+
hidden_states=hidden_states,
|
101 |
+
attentions=None,
|
102 |
+
cross_attentions=None,
|
103 |
+
)
|
104 |
+
|
105 |
+
@staticmethod
|
106 |
+
def _reorder_cache(past, beam_idx):
|
107 |
+
"""
|
108 |
+
This function is used to re-order the :obj:`past_key_values` cache if
|
109 |
+
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
|
110 |
+
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
111 |
+
"""
|
112 |
+
return tuple(
|
113 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
114 |
+
for layer_past in past
|
115 |
+
)
|
116 |
+
|
117 |
+
|
118 |
+
class ResBlock(nn.Module):
|
119 |
+
"""
|
120 |
+
Basic residual convolutional block that uses GroupNorm.
|
121 |
+
"""
|
122 |
+
def __init__(self, chan):
|
123 |
+
super().__init__()
|
124 |
+
self.net = nn.Sequential(
|
125 |
+
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
126 |
+
nn.GroupNorm(chan//8, chan),
|
127 |
+
nn.ReLU(),
|
128 |
+
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
129 |
+
nn.GroupNorm(chan//8, chan)
|
130 |
+
)
|
131 |
+
|
132 |
+
def forward(self, x):
|
133 |
+
return F.relu(self.net(x) + x)
|
134 |
+
|
135 |
+
|
136 |
+
class ConditioningEncoder(nn.Module):
|
137 |
+
def __init__(self,
|
138 |
+
spec_dim,
|
139 |
+
embedding_dim,
|
140 |
+
attn_blocks=6,
|
141 |
+
num_attn_heads=4,
|
142 |
+
do_checkpointing=False):
|
143 |
+
super().__init__()
|
144 |
+
attn = []
|
145 |
+
self.init = nn.Sequential(nn.Conv1d(spec_dim, embedding_dim//4, kernel_size=5, padding=2),
|
146 |
+
nn.Conv1d(embedding_dim//4, embedding_dim//2, kernel_size=3, padding=1, stride=2),
|
147 |
+
ResBlock(embedding_dim//2),
|
148 |
+
nn.Conv1d(embedding_dim//2, embedding_dim, kernel_size=3, padding=1, stride=2))
|
149 |
+
for a in range(attn_blocks):
|
150 |
+
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
|
151 |
+
self.attn = nn.Sequential(*attn)
|
152 |
+
self.dim = embedding_dim
|
153 |
+
|
154 |
+
def forward(self, x):
|
155 |
+
h = self.init(x)
|
156 |
+
h = self.attn(h)
|
157 |
+
return h.mean(dim=2)
|
158 |
+
|
159 |
+
|
160 |
+
class CheckpointedLayer(nn.Module):
|
161 |
+
"""
|
162 |
+
Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
|
163 |
+
checkpoint for all other args.
|
164 |
+
"""
|
165 |
+
def __init__(self, wrap):
|
166 |
+
super().__init__()
|
167 |
+
self.wrap = wrap
|
168 |
+
|
169 |
+
def forward(self, x, *args, **kwargs):
|
170 |
+
for k, v in kwargs.items():
|
171 |
+
assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
|
172 |
+
partial = functools.partial(self.wrap, **kwargs)
|
173 |
+
return torch.utils.checkpoint.checkpoint(partial, x, *args)
|
174 |
+
|
175 |
+
|
176 |
+
class CheckpointedXTransformerWrapper(nn.Module):
|
177 |
+
"""
|
178 |
+
Wraps a TransformerWrapper and applies CheckpointedLayer to each layer.
|
179 |
+
"""
|
180 |
+
def __init__(self, checkpoint=True, **xtransformer_kwargs):
|
181 |
+
super().__init__()
|
182 |
+
self.transformer = TransformerWrapper(**xtransformer_kwargs)
|
183 |
+
|
184 |
+
if not checkpoint:
|
185 |
+
return
|
186 |
+
for i in range(len(self.transformer.attn_layers.layers)):
|
187 |
+
n, b, r = self.transformer.attn_layers.layers[i]
|
188 |
+
self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
|
189 |
+
|
190 |
+
def forward(self, x, **kwargs):
|
191 |
+
return self.transformer(x, **kwargs)
|
192 |
+
|
193 |
+
|
194 |
+
class AutoregressiveCodegen(nn.Module):
|
195 |
+
def __init__(self, model_dim, depth, num_text_tokens=256, num_mel_tokens=8194, max_text_tokens=4000,
|
196 |
+
max_mel_tokens=4000, dropout=.1):
|
197 |
+
super().__init__()
|
198 |
+
|
199 |
+
self.START_TOKEN=8192
|
200 |
+
self.STOP_TOKEN=8193
|
201 |
+
self.max_mel_tokens = max_mel_tokens
|
202 |
+
self.minicoder = ConditioningEncoder(80, model_dim, do_checkpointing=False)
|
203 |
+
self.encoder = CheckpointedXTransformerWrapper(
|
204 |
+
num_tokens=num_text_tokens,
|
205 |
+
max_seq_len=max_text_tokens,
|
206 |
+
attn_layers = Encoder(
|
207 |
+
depth=depth//2,
|
208 |
+
heads=model_dim//64,
|
209 |
+
dim=model_dim,
|
210 |
+
attn_dropout=dropout,
|
211 |
+
ff_dropout=dropout,
|
212 |
+
use_rmsnorm=True,
|
213 |
+
ff_glu=True,
|
214 |
+
ff_mult=1,
|
215 |
+
rotary_pos_emb=True,
|
216 |
+
rel_pos_bias=True,
|
217 |
+
))
|
218 |
+
self.decoder = CheckpointedXTransformerWrapper(
|
219 |
+
num_tokens=num_mel_tokens,
|
220 |
+
max_seq_len=max_mel_tokens,
|
221 |
+
attn_layers=Decoder(
|
222 |
+
depth=depth,
|
223 |
+
heads=model_dim//64,
|
224 |
+
dim=model_dim,
|
225 |
+
attn_dropout=dropout,
|
226 |
+
ff_dropout=dropout,
|
227 |
+
use_rmsnorm=True,
|
228 |
+
ff_glu=True,
|
229 |
+
ff_mult=1,
|
230 |
+
rotary_pos_emb=True,
|
231 |
+
rel_pos_bias=True,
|
232 |
+
cross_attend=True,
|
233 |
+
))
|
234 |
+
|
235 |
+
def get_grad_norm_parameter_groups(self):
|
236 |
+
return {
|
237 |
+
'encoder': list(self.encoder.parameters()),
|
238 |
+
'decoder': list(self.decoder.parameters()),
|
239 |
+
'minicoder': list(self.minicoder.parameters()),
|
240 |
+
}
|
241 |
+
|
242 |
+
def forward(self, text_codes, conditioning_signal, mel_codes, wav_lengths, return_loss=True):
|
243 |
+
# Format mel_codes with a stop token on the end.
|
244 |
+
mel_lengths = wav_lengths // 1024 + 1
|
245 |
+
for b in range(mel_codes.shape[0]):
|
246 |
+
mel_codes[b, mel_lengths[b]:] = self.STOP_TOKEN
|
247 |
+
mel_codes = F.pad(mel_codes, (0, 1), value=self.STOP_TOKEN)
|
248 |
+
|
249 |
+
# Build the context
|
250 |
+
if len(conditioning_signal.shape) != 4:
|
251 |
+
conditioning_signal = conditioning_signal.unsqueeze(1)
|
252 |
+
cond_embs = []
|
253 |
+
for i in range(conditioning_signal.shape[1]):
|
254 |
+
cond_embs.append(self.minicoder(conditioning_signal[:, i]))
|
255 |
+
cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
|
256 |
+
enc_text = self.encoder(text_codes, return_embeddings=True)
|
257 |
+
context = torch.cat([cond_emb, enc_text], dim=1)
|
258 |
+
|
259 |
+
# Execute the decoder
|
260 |
+
dec_inputs = F.pad(mel_codes, (1,0), value=self.START_TOKEN)[:, :-1]
|
261 |
+
dec = self.decoder(dec_inputs, context=context)
|
262 |
+
if not return_loss:
|
263 |
+
return dec
|
264 |
+
loss_mel = F.cross_entropy(dec.permute(0,2,1), mel_codes)
|
265 |
+
return loss_mel
|
266 |
+
|
267 |
+
def generate(self, conditioning_signal, text_codes, **hf_generate_kwargs):
|
268 |
+
if not hasattr(self, 'inference_model'):
|
269 |
+
self.inference_model = InferenceModel(self)
|
270 |
+
|
271 |
+
if len(conditioning_signal.shape) != 4:
|
272 |
+
conditioning_signal = conditioning_signal.unsqueeze(1)
|
273 |
+
cond_embs = []
|
274 |
+
for i in range(conditioning_signal.shape[1]):
|
275 |
+
cond_embs.append(self.minicoder(conditioning_signal[:, i]))
|
276 |
+
cond_emb = torch.stack(cond_embs, dim=1).mean(dim=1, keepdim=True)
|
277 |
+
enc_text = self.encoder(text_codes, return_embeddings=True)
|
278 |
+
context = torch.cat([cond_emb, enc_text], dim=1)
|
279 |
+
self.inference_model.store_context(context)
|
280 |
+
|
281 |
+
gen = self.inference_model.generate(bos_token_id=self.START_TOKEN, pad_token_id=self.STOP_TOKEN, eos_token_id=self.STOP_TOKEN,
|
282 |
+
max_length=250, output_attentions=False, return_dict_in_generate=True,
|
283 |
+
**hf_generate_kwargs)
|
284 |
+
return gen.sequences
|
285 |
+
|
286 |
+
|
287 |
+
if __name__ == '__main__':
|
288 |
+
codegen = AutoregressiveCodegen(1024, 20)
|
289 |
+
codegen.generate(torch.randn((1,80,120)), torch.randint(0,256,(1,200)))
|
290 |
+
codegen(torch.randint(0,256, (2,200)),
|
291 |
+
torch.randn(2,80,120),
|
292 |
+
torch.randint(0,8192, (2,350)),
|
293 |
+
torch.tensor([192,350]))
|