indic/TTS/vocoder/models/univnet_generator.py
azamat's picture
Init
6127b48
raw
history blame
5.5 kB
from typing import List
import numpy as np
import torch
import torch.nn.functional as F
from TTS.vocoder.layers.lvc_block import LVCBlock
# Negative slope of the leaky ReLU applied before each output convolution in
# ``UnivnetGenerator.forward``.
LRELU_SLOPE = 0.1


class UnivnetGenerator(torch.nn.Module):
    """UnivNet generator.

    Pipeline (see ``forward``): random noise -> ``first_conv`` -> one ``LVCBlock``
    per upsample factor (each conditioned on ``c``) -> leaky ReLU + output
    convolution(s) -> ``tanh``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        cond_channels: int,
        upsample_factors: List[int],
        lvc_layers_each_block: int,
        lvc_kernel_size: int,
        kpnet_hidden_channels: int,
        kpnet_conv_size: int,
        dropout: float,
        use_weight_norm: bool = True,
    ):
        """Univnet Generator network.

        Paper: https://arxiv.org/pdf/2106.07889.pdf

        Args:
            in_channels (int): Number of channels of the random noise fed to the first convolution.
            out_channels (int): Number of channels of the output tensor.
            hidden_channels (int): Number of hidden network channels.
            cond_channels (int): Number of channels of the conditioning tensors.
            upsample_factors (List[int]): Upsample factor of each LVC block; one block is built per entry.
            lvc_layers_each_block (int): Number of LVC layers in each block.
            lvc_kernel_size (int): Kernel size of the LVC layers.
            kpnet_hidden_channels (int): Forwarded to every ``LVCBlock`` as ``kpnet_hidden_channels``
                (presumably the hidden width of its kernel-predictor network - verify in LVCBlock).
            kpnet_conv_size (int): Forwarded to every ``LVCBlock`` as ``kpnet_conv_size``
                (presumably a convolution kernel size of its kernel-predictor network - verify in LVCBlock).
            dropout (float): Dropout rate, forwarded to every ``LVCBlock`` as ``kpnet_dropout``.
            use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cond_channels = cond_channels
        # Total upsampling ratio of the whole stack (product of all factors).
        self.upsample_scale = np.prod(upsample_factors)
        self.lvc_block_nums = len(upsample_factors)

        # First convolution: noise (in_channels) -> hidden_channels.
        # kernel 7 / padding 3 / stride 1 keeps the time length unchanged.
        self.first_conv = torch.nn.Conv1d(
            in_channels, hidden_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
        )

        # LVC blocks. Each block receives the cumulative hop length, i.e. the
        # product of its own upsample factor and those of all preceding blocks.
        self.lvc_blocks = torch.nn.ModuleList()
        cond_hop_length = 1
        for upsample_ratio in upsample_factors:
            cond_hop_length = cond_hop_length * upsample_ratio
            self.lvc_blocks.append(
                LVCBlock(
                    in_channels=hidden_channels,
                    cond_channels=cond_channels,
                    upsample_ratio=upsample_ratio,
                    conv_layers=lvc_layers_each_block,
                    conv_kernel_size=lvc_kernel_size,
                    cond_hop_length=cond_hop_length,
                    kpnet_hidden_channels=kpnet_hidden_channels,
                    kpnet_conv_size=kpnet_conv_size,
                    kpnet_dropout=dropout,
                )
            )

        # Output layers: hidden_channels -> out_channels, length-preserving.
        self.last_conv_layers = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(
                    hidden_channels, out_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
                ),
            ]
        )

        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Local conditioning auxiliary features (B, C, T').

        Returns:
            Tensor: Output tensor (B, out_channels, T), squashed into [-1, 1] by ``tanh``.
        """
        # Random noise with the batch size and time length of the conditioning.
        # It is drawn on the default device, then cast to the device *and dtype*
        # of the model parameters. Casting only the device would feed float32
        # noise into half/double-precision convolutions and crash.
        x = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
        x = x.to(self.first_conv.bias)
        x = self.first_conv(x)
        for lvc_block in self.lvc_blocks:
            x = lvc_block(x, c)
        # Final layers: leaky ReLU before every output convolution.
        for conv in self.last_conv_layers:
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = conv(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Remove weight normalization from every submodule that has it.

        Modules without weight norm make ``torch.nn.utils.remove_weight_norm``
        raise ``ValueError``; those are skipped, so calling this twice is safe.
        """

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to every ``Conv1d``/``Conv2d`` submodule (including those inside LVC blocks)."""

        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    @staticmethod
    def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
        """Return ``(kernel_size - 1) * sum(dilations) + 1`` for a stack of dilated convolutions.

        ``layers`` must be divisible by ``stacks``; within each cycle of
        ``layers // stacks`` layers, layer ``i`` uses ``dilation(i)``.
        """
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1

    @property
    def receptive_field_size(self):
        """Return receptive field size.

        NOTE(review): this reads ``self.layers``, ``self.stacks`` and
        ``self.kernel_size``, none of which ``__init__`` sets, so accessing it
        raises ``AttributeError`` unless those attributes are assigned
        externally. Kept unchanged for interface compatibility.
        """
        return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)

    @torch.no_grad()
    def inference(self, c):
        """Perform inference (``forward`` without gradient tracking).

        Args:
            c (Tensor): Local conditioning auxiliary features :math:`(B, C, T)`.

        Returns:
            Tensor: Output tensor :math:`(B, out\\_channels, T_{out})`, as returned by ``forward``.
        """
        # Match the conditioning to the parameters' dtype/device. ``forward``
        # samples its own noise, so no noise tensor is created here.
        c = c.to(next(self.parameters()))
        return self.forward(c)