jbetker commited on
Commit
76c30fe
1 Parent(s): 713281e

Update autoregressive to support type inputs

Browse files
Files changed (1) hide show
  1. models/autoregressive.py +12 -7
models/autoregressive.py CHANGED
@@ -278,9 +278,10 @@ class MelEncoder(nn.Module):
278
  class UnifiedVoice(nn.Module):
279
  def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
280
  mel_length_compression=1024, number_text_tokens=256,
281
- start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
282
  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
283
- checkpointing=True, average_conditioning_embeddings=False):
 
284
  """
285
  Args:
286
  layers: Number of layers in transformer stack.
@@ -304,8 +305,8 @@ class UnifiedVoice(nn.Module):
304
  super().__init__()
305
 
306
  self.number_text_tokens = number_text_tokens
307
- self.start_text_token = start_text_token
308
- self.stop_text_token = stop_text_token
309
  self.number_mel_codes = number_mel_codes
310
  self.start_mel_token = start_mel_token
311
  self.stop_mel_token = stop_mel_token
@@ -318,7 +319,7 @@ class UnifiedVoice(nn.Module):
318
  self.mel_length_compression = mel_length_compression
319
  self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
320
  self.average_conditioning_embeddings = average_conditioning_embeddings
321
- self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
322
  if use_mel_codes_as_input:
323
  self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
324
  else:
@@ -333,7 +334,7 @@ class UnifiedVoice(nn.Module):
333
  self.text_solo_embedding = 0
334
 
335
  self.final_norm = nn.LayerNorm(model_dim)
336
- self.text_head = nn.Linear(model_dim, self.number_text_tokens)
337
  self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
338
 
339
  # Initialize the embeddings per the GPT-2 scheme
@@ -389,7 +390,7 @@ class UnifiedVoice(nn.Module):
389
  else:
390
  return first_logits
391
 
392
- def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False,
393
  return_latent=False, clip_inputs=True):
394
  """
395
  Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
@@ -406,6 +407,10 @@ class UnifiedVoice(nn.Module):
406
  If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
407
  If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
408
  """
 
 
 
 
409
  if clip_inputs:
410
  # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
411
  # chopping the inputs by the maximum actual length.
 
278
  class UnifiedVoice(nn.Module):
279
  def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
280
  mel_length_compression=1024, number_text_tokens=256,
281
+ start_text_token=None, number_mel_codes=8194, start_mel_token=8192,
282
  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
283
+ checkpointing=True, average_conditioning_embeddings=False,
284
+ types=1):
285
  """
286
  Args:
287
  layers: Number of layers in transformer stack.
 
305
  super().__init__()
306
 
307
  self.number_text_tokens = number_text_tokens
308
+ self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
309
+ self.stop_text_token = 0
310
  self.number_mel_codes = number_mel_codes
311
  self.start_mel_token = start_mel_token
312
  self.stop_mel_token = stop_mel_token
 
319
  self.mel_length_compression = mel_length_compression
320
  self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
321
  self.average_conditioning_embeddings = average_conditioning_embeddings
322
+ self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
323
  if use_mel_codes_as_input:
324
  self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
325
  else:
 
334
  self.text_solo_embedding = 0
335
 
336
  self.final_norm = nn.LayerNorm(model_dim)
337
+ self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1)
338
  self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
339
 
340
  # Initialize the embeddings per the GPT-2 scheme
 
390
  else:
391
  return first_logits
392
 
393
+ def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
394
  return_latent=False, clip_inputs=True):
395
  """
396
  Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
 
407
  If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
408
  If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
409
  """
410
+ # Types are expressed by expanding the text embedding space.
411
+ if types is not None:
412
+ text_inputs = text_inputs * (1+types).unsqueeze(-1)
413
+
414
  if clip_inputs:
415
  # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
416
  # chopping the inputs by the maximum actual length.