|
import torch |
|
from torch import nn |
|
|
|
from TTS.tts.layers.generic.normalization import ActNorm |
|
from TTS.tts.layers.glow_tts.glow import CouplingBlock, InvConvNear |
|
|
|
|
|
def squeeze(x, x_mask=None, num_sqz=2):
    """GlowTTS squeeze operation.

    Increase the number of channels and reduce the number of time steps
    by the same factor ``num_sqz``. Trailing frames that do not fill a whole
    group of ``num_sqz`` steps are dropped, so the output length is
    ``T // num_sqz``.

    Note:
        each 's' is a n-dimensional vector.
        ``[s1,s2,s3,s4,s5,s6] --> [[s1, s3, s5], [s2, s4, s6]]``

    Args:
        x (Tensor): input of shape ``[B, C, T]``.
        x_mask (Tensor, optional): mask of shape ``[B, 1, T]``. When ``None``,
            an all-ones mask is created on ``x``'s device with ``x``'s dtype.
        num_sqz (int): squeeze factor.

    Returns:
        Tuple[Tensor, Tensor]: the squeezed input of shape
        ``[B, C * num_sqz, T // num_sqz]`` multiplied by the mask, and the
        squeezed mask of shape ``[B, 1, T // num_sqz]``.
    """
    b, c, t = x.size()

    # Drop trailing frames so the time axis splits evenly into groups.
    t = (t // num_sqz) * num_sqz
    x = x[:, :, :t]
    x_sqz = x.view(b, c, t // num_sqz, num_sqz)
    x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * num_sqz, t // num_sqz)

    if x_mask is not None:
        # Keep the mask value of the last frame in each group; this slice
        # yields exactly T // num_sqz frames even when T is not divisible.
        x_mask = x_mask[:, :, num_sqz - 1 :: num_sqz]
    else:
        # Allocate directly on the target device/dtype (no CPU round-trip).
        x_mask = torch.ones(b, 1, t // num_sqz, device=x.device, dtype=x.dtype)
    return x_sqz * x_mask, x_mask
|
|
|
|
|
def unsqueeze(x, x_mask=None, num_sqz=2):
    """GlowTTS unsqueeze operation (inverse of ``squeeze``).

    Reduce the number of channels and increase the number of time steps
    by the same factor ``num_sqz``.

    Note:
        each 's' is a n-dimensional vector.
        ``[[s1, s3, s5], [s2, s4, s6]] --> [s1, s2, s3, s4, s5, s6]``

    Args:
        x (Tensor): input of shape ``[B, C, T]``; ``C`` must be divisible by
            ``num_sqz``.
        x_mask (Tensor, optional): mask of shape ``[B, 1, T]``. When ``None``,
            an all-ones mask is created on ``x``'s device with ``x``'s dtype.
        num_sqz (int): unsqueeze factor.

    Returns:
        Tuple[Tensor, Tensor]: the unsqueezed input of shape
        ``[B, C // num_sqz, T * num_sqz]`` multiplied by the mask, and the
        expanded mask of shape ``[B, 1, T * num_sqz]``.
    """
    b, c, t = x.size()

    x_unsqz = x.view(b, num_sqz, c // num_sqz, t)
    x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // num_sqz, t * num_sqz)

    if x_mask is not None:
        # Repeat every mask frame num_sqz times along the time axis.
        x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, num_sqz).view(b, 1, t * num_sqz)
    else:
        # Allocate directly on the target device/dtype (no CPU round-trip).
        x_mask = torch.ones(b, 1, t * num_sqz, device=x.device, dtype=x.dtype)
    return x_unsqz * x_mask, x_mask
|
|
|
|
|
class Decoder(nn.Module):
    """Stack of Glow Decoder Modules.

    ::

        Squeeze -> [ActNorm -> InvertibleConv1x1 -> AffineCoupling] x num_flow_blocks -> Unsqueeze

    Squeeze/Unsqueeze are applied once around the whole stack, and only when
    ``num_squeeze > 1``.

    Args:
        in_channels (int): channels of input tensor.
        hidden_channels (int): hidden decoder channels.
        kernel_size (int): Coupling block kernel size. (Wavenet filter kernel size.)
        dilation_rate (int): rate to increase dilation by each layer in a decoder block.
        num_flow_blocks (int): number of decoder blocks.
        num_coupling_layers (int): number coupling layers. (number of wavenet layers.)
        dropout_p (float): wavenet dropout rate.
        num_splits (int): forwarded as ``num_splits`` to every ``InvConvNear`` layer.
        num_squeeze (int): squeeze factor; every flow operates on
            ``in_channels * num_squeeze`` channels.
        sigmoid_scale (bool): enable/disable sigmoid scaling in coupling layer.
        c_in_channels (int): forwarded as ``c_in_channels`` to every
            ``CouplingBlock`` (presumably the channel count of the
            conditioning input ``g``).
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        num_flow_blocks,
        num_coupling_layers,
        dropout_p=0.0,
        num_splits=4,
        num_squeeze=2,
        sigmoid_scale=False,
        c_in_channels=0,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.num_flow_blocks = num_flow_blocks
        self.num_coupling_layers = num_coupling_layers
        self.dropout_p = dropout_p
        self.num_splits = num_splits
        self.num_squeeze = num_squeeze
        self.sigmoid_scale = sigmoid_scale
        self.c_in_channels = c_in_channels

        # Flows see the squeezed tensor, hence the multiplied channel count.
        flow_channels = in_channels * num_squeeze
        self.flows = nn.ModuleList()
        for _ in range(num_flow_blocks):
            self.flows.append(ActNorm(channels=flow_channels))
            self.flows.append(InvConvNear(channels=flow_channels, num_splits=num_splits))
            self.flows.append(
                CouplingBlock(
                    flow_channels,
                    hidden_channels,
                    kernel_size=kernel_size,
                    dilation_rate=dilation_rate,
                    num_layers=num_coupling_layers,
                    c_in_channels=c_in_channels,
                    dropout_p=dropout_p,
                    sigmoid_scale=sigmoid_scale,
                )
            )

    def forward(self, x, x_mask, g=None, reverse=False):
        """Run the flow stack forwards, or in reverse order when ``reverse`` is True.

        Shapes:
            - x: :math:`[B, C, T]`
            - x_mask: :math:`[B, 1 ,T]`
            - g: :math:`[B, C]` (passed unchanged to every flow; exact shape
              is whatever the flow modules expect)

        Returns:
            Tuple of the transformed tensor and the summed log-determinant of
            all flows. The log-determinant is ``None`` when ``reverse`` is True.
            When ``num_squeeze > 1``, the time axis of the output is truncated
            to a multiple of ``num_squeeze`` by the squeeze step.
        """
        if self.num_squeeze > 1:
            x, x_mask = squeeze(x, x_mask, self.num_squeeze)

        if reverse:
            # Inverse pass: walk the flows last-to-first; log-determinants
            # returned by the flows are discarded.
            logdet_tot = None
            for f in reversed(self.flows):
                x, _ = f(x, x_mask, g=g, reverse=reverse)
        else:
            logdet_tot = 0
            for f in self.flows:
                x, logdet = f(x, x_mask, g=g, reverse=reverse)
                logdet_tot += logdet

        if self.num_squeeze > 1:
            x, x_mask = unsqueeze(x, x_mask, self.num_squeeze)
        return x, logdet_tot

    def store_inverse(self):
        """Call ``store_inverse()`` on every flow module in the stack."""
        for f in self.flows:
            f.store_inverse()
|
|