|
import torch |
|
from torch import nn |
|
from torch.nn.utils import weight_norm |
|
|
|
from TTS.utils.io import load_fsspec |
|
from TTS.vocoder.layers.melgan import ResidualStack |
|
|
|
|
|
class MelganGenerator(nn.Module):
    """MelGAN-style generator built as a single ``nn.Sequential`` stack.

    Layer layout:

    * input projection: ``ReflectionPad1d`` + weight-normed ``Conv1d(in_channels -> base_channels)``;
    * one stage per entry ``u`` of ``upsample_factors``: ``LeakyReLU`` -> weight-normed
      ``ConvTranspose1d`` (halves the channel count, multiplies the time length by exactly ``u``)
      -> ``ResidualStack``;
    * output: ``LeakyReLU`` -> ``ReflectionPad1d`` + weight-normed ``Conv1d(-> out_channels)`` -> ``Tanh``.

    The input is presumably shaped ``(batch, in_channels, frames)`` (e.g. spectrogram frames) —
    TODO confirm against callers. Output length is ``frames * prod(upsample_factors)``.

    Args:
        in_channels: Channels of the input features.
        out_channels: Channels of the generated signal.
        proj_kernel: Kernel size of the input/output projections. Must be odd so the symmetric
            reflection padding keeps the time length unchanged.
        base_channels: Channel count after the input projection; halved (integer division) by
            every upsampling stage.
        upsample_factors: Per-stage time upsampling factors. May be empty.
        res_kernel: ``kernel_size`` passed to each ``ResidualStack``.
        num_res_blocks: ``num_res_blocks`` passed to each ``ResidualStack``.
    """

    def __init__(
        self,
        in_channels=80,
        out_channels=1,
        proj_kernel=7,
        base_channels=512,
        upsample_factors=(8, 8, 2, 2),
        res_kernel=3,
        num_res_blocks=3,
    ):
        super().__init__()

        # An odd kernel lets (k - 1) // 2 frames of padding on each side preserve the length.
        assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number."

        base_padding = (proj_kernel - 1) // 2
        act_slope = 0.2
        # Frames replicated on each side of the input by `inference`.
        self.inference_padding = 2

        # Input projection.
        layers = [
            nn.ReflectionPad1d(base_padding),
            weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=proj_kernel, stride=1, bias=True)),
        ]

        # Bound before the loop so an empty `upsample_factors` still yields a valid
        # channel count for the output projection (previously an UnboundLocalError).
        layer_out_channels = base_channels

        # Upsampling stages.
        for idx, upsample_factor in enumerate(upsample_factors):
            layer_in_channels = base_channels // (2**idx)
            layer_out_channels = base_channels // (2 ** (idx + 1))
            layer_filter_size = upsample_factor * 2
            layer_stride = upsample_factor
            # With kernel 2u and stride u, this padding / output_padding pair makes the
            # transposed conv output exactly u * L frames for both even and odd u.
            layer_output_padding = upsample_factor % 2
            layer_padding = upsample_factor // 2 + layer_output_padding
            layers += [
                nn.LeakyReLU(act_slope),
                weight_norm(
                    nn.ConvTranspose1d(
                        layer_in_channels,
                        layer_out_channels,
                        layer_filter_size,
                        stride=layer_stride,
                        padding=layer_padding,
                        output_padding=layer_output_padding,
                        bias=True,
                    )
                ),
                ResidualStack(channels=layer_out_channels, num_res_blocks=num_res_blocks, kernel_size=res_kernel),
            ]

        layers += [nn.LeakyReLU(act_slope)]

        # Output projection, squashed into [-1, 1] by Tanh.
        layers += [
            nn.ReflectionPad1d(base_padding),
            weight_norm(nn.Conv1d(layer_out_channels, out_channels, proj_kernel, stride=1, bias=True)),
            nn.Tanh(),
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, c):
        """Run the full layer stack on ``c`` and return the generated signal."""
        return self.layers(c)

    def inference(self, c):
        """Generate from ``c`` after moving it to the model's device and edge-padding it.

        ``c`` is padded by ``self.inference_padding`` frames on both ends of its last axis in
        "replicate" mode. The padded region is NOT trimmed from the output, so the result is
        ``(frames + 2 * inference_padding) * prod(upsample_factors)`` samples long.
        """
        # Read the device from a registered parameter. With legacy weight_norm,
        # `conv.weight` is a plain attribute recomputed only by the forward pre-hook, so it
        # still lives on the old device after `.to(...)` until a forward pass runs.
        device = next(self.parameters()).device
        c = c.to(device)
        c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
        return self.layers(c)

    def remove_weight_norm(self):
        """Strip weight normalization from every parameterized layer in ``self.layers``."""
        for layer in self.layers:
            # Padding / activation layers have no state, so there is nothing to strip.
            if len(layer.state_dict()) == 0:
                continue
            try:
                nn.utils.remove_weight_norm(layer)
            except ValueError:
                # The layer itself is not weight-normed (here: a ResidualStack); delegate to
                # its own remove_weight_norm, presumably handling its inner convolutions.
                layer.remove_weight_norm()

    def load_checkpoint(
        self, config, checkpoint_path, eval=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load weights from ``checkpoint_path`` onto the CPU.

        Args:
            config: Unused here; kept for interface compatibility.
            checkpoint_path: Path passed to ``load_fsspec``. The loaded object must be a mapping
                whose ``"model"`` entry is this module's state dict.
            eval: When true, switch to eval mode and remove weight normalization for inference.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training
            self.remove_weight_norm()
|
|