Update autoregressive to support type inputs
models/autoregressive.py: +12 -7
```diff
@@ -278,9 +278,10 @@ class MelEncoder(nn.Module):
 class UnifiedVoice(nn.Module):
     def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
                  mel_length_compression=1024, number_text_tokens=256,
-                 start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
+                 start_text_token=None, number_mel_codes=8194, start_mel_token=8192,
                  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
-                 checkpointing=True, average_conditioning_embeddings=False):
+                 checkpointing=True, average_conditioning_embeddings=False,
+                 types=1):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -304,8 +305,8 @@ class UnifiedVoice(nn.Module):
         super().__init__()
 
         self.number_text_tokens = number_text_tokens
-        self.start_text_token = start_text_token
-        self.stop_text_token = stop_text_token
+        self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
+        self.stop_text_token = 0
         self.number_mel_codes = number_mel_codes
         self.start_mel_token = start_mel_token
         self.stop_mel_token = stop_mel_token
@@ -318,7 +319,7 @@ class UnifiedVoice(nn.Module):
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
         self.average_conditioning_embeddings = average_conditioning_embeddings
-        self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
+        self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
         if use_mel_codes_as_input:
             self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
         else:
@@ -333,7 +334,7 @@ class UnifiedVoice(nn.Module):
             self.text_solo_embedding = 0
 
         self.final_norm = nn.LayerNorm(model_dim)
-        self.text_head = nn.Linear(model_dim, self.number_text_tokens)
+        self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
         self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
 
         # Initialize the embeddings per the GPT-2 scheme
@@ -389,7 +390,7 @@ class UnifiedVoice(nn.Module):
         else:
             return first_logits
 
-    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False,
+    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
                 return_latent=False, clip_inputs=True):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
@@ -406,6 +407,10 @@ class UnifiedVoice(nn.Module):
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
         If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
         """
+        # Types are expressed by expanding the text embedding space.
+        if types is not None:
+            text_inputs = text_inputs * (1+types).unsqueeze(-1)
+
         if clip_inputs:
             # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
             # chopping the inputs by the maximum actual length.
```
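The core of the change is the remapping in `forward()`: each sample's text token ids are scaled by `(1 + type)`, and the text embedding and text head are enlarged to `number_text_tokens*types+1` slots so every remapped id, plus the new default start token at `number_text_tokens * types`, still fits. A minimal sketch of that arithmetic, with illustrative values that are not from the repo:

```python
import torch

# Illustrative values only (not from the repo): a 256-token text vocabulary
# with two input types.
number_text_tokens, types = 256, 2

text_inputs = torch.tensor([[12, 7, 0],
                            [12, 7, 0]])    # token ids < number_text_tokens
type_labels = torch.tensor([0, 1])          # one type id per sample, in [0, types)

# Same operation as the new lines in forward(): scale ids by (1 + type).
remapped = text_inputs * (1 + type_labels).unsqueeze(-1)
print(remapped)  # tensor([[12,  7,  0],
                 #         [24, 14,  0]])

# Every remapped id, plus the default start token number_text_tokens * types,
# fits in the enlarged vocabulary used by text_embedding and text_head.
vocab_size = number_text_tokens * types + 1     # 513
assert remapped.max().item() < vocab_size
assert number_text_tokens * types < vocab_size  # start token is the last slot
```

Note that id 0 (the stop text token) is unchanged by the scaling, so all types share the same stop token, while the default start token occupies the final slot of the expanded table.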
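To see the new argument in context, here is a hypothetical call sketch. The import path follows the file touched by this commit; the tensor shapes, value ranges, and the three return values (text loss, mel loss, mel logits) are assumptions about the unchanged parts of `forward()`, not something this diff shows:

```python
import torch
from models.autoregressive import UnifiedVoice  # path per this commit

b, t, m = 2, 20, 80                      # assumed batch, text, and mel-code lengths
model = UnifiedVoice(number_text_tokens=256, types=2)

loss_text, loss_mel, mel_logits = model(
    torch.randn(b, 80, 400),             # speech_conditioning_input: assumed (b, 80, frames) MEL
    torch.randint(0, 256, (b, t)),       # text_inputs: ids < number_text_tokens
    torch.tensor([t, t]),                # text_lengths
    torch.randint(0, 8192, (b, m)),      # mel_codes: below the start/stop mel tokens
    torch.tensor([m * 1024] * b),        # wav_lengths: mel_length_compression samples per code
    types=torch.tensor([0, 1]),          # per-sample type ids; ids are scaled by (1 + type)
)
```

Passing `types=None` (the default) skips the remapping entirely, so existing callers and checkpoints that never supply types keep their original behavior.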