indic / TTS /tts /layers /tacotron /capacitron_layers.py
azamat's picture
Init
6127b48
raw
history blame
9.32 kB
import torch
from torch import nn
from torch.distributions.multivariate_normal import MultivariateNormal as MVN
from torch.nn import functional as F
class CapacitronVAE(nn.Module):
"""Effective Use of Variational Embedding Capacity for prosody transfer.
See https://arxiv.org/abs/1906.03402"""
def __init__(
self,
num_mel,
capacitron_VAE_embedding_dim,
encoder_output_dim=256,
reference_encoder_out_dim=128,
speaker_embedding_dim=None,
text_summary_embedding_dim=None,
):
super().__init__()
# Init distributions
self.prior_distribution = MVN(
torch.zeros(capacitron_VAE_embedding_dim), torch.eye(capacitron_VAE_embedding_dim)
)
self.approximate_posterior_distribution = None
# define output ReferenceEncoder dim to the capacitron_VAE_embedding_dim
self.encoder = ReferenceEncoder(num_mel, out_dim=reference_encoder_out_dim)
# Init beta, the lagrange-like term for the KL distribution
self.beta = torch.nn.Parameter(torch.log(torch.exp(torch.Tensor([1.0])) - 1), requires_grad=True)
mlp_input_dimension = reference_encoder_out_dim
if text_summary_embedding_dim is not None:
self.text_summary_net = TextSummary(text_summary_embedding_dim, encoder_output_dim=encoder_output_dim)
mlp_input_dimension += text_summary_embedding_dim
if speaker_embedding_dim is not None:
# TODO: Test a multispeaker model!
mlp_input_dimension += speaker_embedding_dim
self.post_encoder_mlp = PostEncoderMLP(mlp_input_dimension, capacitron_VAE_embedding_dim)
def forward(self, reference_mel_info=None, text_info=None, speaker_embedding=None):
# Use reference
if reference_mel_info is not None:
reference_mels = reference_mel_info[0] # [batch_size, num_frames, num_mels]
mel_lengths = reference_mel_info[1] # [batch_size]
enc_out = self.encoder(reference_mels, mel_lengths)
# concat speaker_embedding and/or text summary embedding
if text_info is not None:
text_inputs = text_info[0] # [batch_size, num_characters, num_embedding]
input_lengths = text_info[1]
text_summary_out = self.text_summary_net(text_inputs, input_lengths).to(reference_mels.device)
enc_out = torch.cat([enc_out, text_summary_out], dim=-1)
if speaker_embedding is not None:
enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)
# Feed the output of the ref encoder and information about text/speaker into
# an MLP to produce the parameteres for the approximate poterior distributions
mu, sigma = self.post_encoder_mlp(enc_out)
# convert to cpu because prior_distribution was created on cpu
mu = mu.cpu()
sigma = sigma.cpu()
# Sample from the posterior: z ~ q(z|x)
self.approximate_posterior_distribution = MVN(mu, torch.diag_embed(sigma))
VAE_embedding = self.approximate_posterior_distribution.rsample()
# Infer from the model, bypasses encoding
else:
# Sample from the prior: z ~ p(z)
VAE_embedding = self.prior_distribution.sample().unsqueeze(0)
# reshape to [batch_size, 1, capacitron_VAE_embedding_dim]
return VAE_embedding.unsqueeze(1), self.approximate_posterior_distribution, self.prior_distribution, self.beta
class ReferenceEncoder(nn.Module):
"""NN module creating a fixed size prosody embedding from a spectrogram.
inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
outputs: [batch_size, embedding_dim]
"""
def __init__(self, num_mel, out_dim):
super().__init__()
self.num_mel = num_mel
filters = [1] + [32, 32, 64, 64, 128, 128]
num_layers = len(filters) - 1
convs = [
nn.Conv2d(
in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(2, 2)
)
for i in range(num_layers)
]
self.convs = nn.ModuleList(convs)
self.training = False
self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]])
post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 2, num_layers)
self.recurrence = nn.LSTM(
input_size=filters[-1] * post_conv_height, hidden_size=out_dim, batch_first=True, bidirectional=False
)
def forward(self, inputs, input_lengths):
batch_size = inputs.size(0)
x = inputs.view(batch_size, 1, -1, self.num_mel) # [batch_size, num_channels==1, num_frames, num_mel]
valid_lengths = input_lengths.float() # [batch_size]
for conv, bn in zip(self.convs, self.bns):
x = conv(x)
x = bn(x)
x = F.relu(x)
# Create the post conv width mask based on the valid lengths of the output of the convolution.
# The valid lengths for the output of a convolution on varying length inputs is
# ceil(input_length/stride) + 1 for stride=3 and padding=2
# For example (kernel_size=3, stride=2, padding=2):
# 0 0 x x x x x 0 0 -> Input = 5, 0 is zero padding, x is valid values coming from padding=2 in conv2d
# _____
# x _____
# x _____
# x ____
# x
# x x x x -> Output valid length = 4
# Since every example in te batch is zero padded and therefore have separate valid_lengths,
# we need to mask off all the values AFTER the valid length for each example in the batch.
# Otherwise, the convolutions create noise and a lot of not real information
valid_lengths = (valid_lengths / 2).float()
valid_lengths = torch.ceil(valid_lengths).to(dtype=torch.int64) + 1 # 2 is stride -- size: [batch_size]
post_conv_max_width = x.size(2)
mask = torch.arange(post_conv_max_width).to(inputs.device).expand(
len(valid_lengths), post_conv_max_width
) < valid_lengths.unsqueeze(1)
mask = mask.expand(1, 1, -1, -1).transpose(2, 0).transpose(-1, 2) # [batch_size, 1, post_conv_max_width, 1]
x = x * mask
x = x.transpose(1, 2)
# x: 4D tensor [batch_size, post_conv_width,
# num_channels==128, post_conv_height]
post_conv_width = x.size(1)
x = x.contiguous().view(batch_size, post_conv_width, -1)
# x: 3D tensor [batch_size, post_conv_width,
# num_channels*post_conv_height]
# Routine for fetching the last valid output of a dynamic LSTM with varying input lengths and padding
post_conv_input_lengths = valid_lengths
packed_seqs = nn.utils.rnn.pack_padded_sequence(
x, post_conv_input_lengths.tolist(), batch_first=True, enforce_sorted=False
) # dynamic rnn sequence padding
self.recurrence.flatten_parameters()
_, (ht, _) = self.recurrence(packed_seqs)
last_output = ht[-1]
return last_output.to(inputs.device) # [B, 128]
@staticmethod
def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs):
"""Height of spec after n convolutions with fixed kernel/stride/pad."""
for _ in range(n_convs):
height = (height - kernel_size + 2 * pad) // stride + 1
return height
class TextSummary(nn.Module):
def __init__(self, embedding_dim, encoder_output_dim):
super().__init__()
self.lstm = nn.LSTM(
encoder_output_dim, # text embedding dimension from the text encoder
embedding_dim, # fixed length output summary the lstm creates from the input
batch_first=True,
bidirectional=False,
)
def forward(self, inputs, input_lengths):
# Routine for fetching the last valid output of a dynamic LSTM with varying input lengths and padding
packed_seqs = nn.utils.rnn.pack_padded_sequence(
inputs, input_lengths.tolist(), batch_first=True, enforce_sorted=False
) # dynamic rnn sequence padding
self.lstm.flatten_parameters()
_, (ht, _) = self.lstm(packed_seqs)
last_output = ht[-1]
return last_output
class PostEncoderMLP(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
modules = [
nn.Linear(input_size, hidden_size), # Hidden Layer
nn.Tanh(),
nn.Linear(hidden_size, hidden_size * 2),
] # Output layer twice the size for mean and variance
self.net = nn.Sequential(*modules)
self.softplus = nn.Softplus()
def forward(self, _input):
mlp_output = self.net(_input)
# The mean parameter is unconstrained
mu = mlp_output[:, : self.hidden_size]
# The standard deviation must be positive. Parameterise with a softplus
sigma = self.softplus(mlp_output[:, self.hidden_size :])
return mu, sigma