|
import math |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock |
|
|
|
|
|
class ParallelWaveganDiscriminator(nn.Module):
    """PWGAN discriminator as in https://arxiv.org/abs/1910.11480.

    It classifies each audio window real/fake and returns a sequence
    of predictions.
    It is a stack of convolutional blocks with dilation.

    Args:
        in_channels (int): number of channels of the input signal.
        out_channels (int): number of channels produced by the final convolution.
        kernel_size (int): kernel size of every convolution. Must be odd so that
            the padding used here keeps the time dimension unchanged.
        num_layers (int): total number of convolutions, i.e. ``num_layers - 1``
            hidden (conv + activation) layers followed by one output convolution.
        conv_channels (int): number of channels of the hidden convolutions.
        dilation_factor (int): if 1, the dilation of hidden layer ``i`` is ``i``
            (1 for the first layer); otherwise it is ``dilation_factor**i``.
        nonlinear_activation (str): name of the activation class looked up in
            ``torch.nn``.
        nonlinear_activation_params (dict): keyword arguments for the activation.
            ``inplace=True`` is always passed in addition, so the dict must not
            contain ``inplace``. The dict is only read, never modified.
        bias (bool): whether the convolutions use a bias term.
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_layers=10,
        conv_channels=64,
        dilation_factor=1,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.2},
        bias=True,
    ):
        super().__init__()
        assert (kernel_size - 1) % 2 == 0, " [!] does not support even number kernel size."
        assert dilation_factor > 0, " [!] dilation factor must be > 0."
        self.conv_layers = nn.ModuleList()
        conv_in_channels = in_channels
        for i in range(num_layers - 1):
            if i == 0:
                dilation = 1
            else:
                dilation = i if dilation_factor == 1 else dilation_factor**i
            # "same" padding for an odd kernel with the given dilation
            padding = (kernel_size - 1) // 2 * dilation
            self.conv_layers += [
                nn.Conv1d(
                    conv_in_channels,
                    conv_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    dilation=dilation,
                    bias=bias,
                ),
                getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
            ]
            # Whatever comes next (another hidden layer or the output layer)
            # consumes the ``conv_channels`` features produced above. Updating
            # this only for i >= 1 left the output layer with ``in_channels``
            # inputs when there was exactly one hidden layer (num_layers == 2).
            conv_in_channels = conv_channels
        padding = (kernel_size - 1) // 2
        last_conv_layer = nn.Conv1d(conv_in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
        self.conv_layers += [last_conv_layer]
        self.apply_weight_norm()

    def forward(self, x):
        """Run the input through every layer in order.

        Args:
            x (Tensor): input of shape (B, in_channels, T); (B, 1, T) with the
                default configuration.

        Returns:
            Tensor: predictions of shape (B, out_channels, T); (B, 1, T) with
                the default configuration.
        """
        for f in self.conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        """Apply weight normalization to every Conv1d / Conv2d submodule."""

        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Remove weight normalization from every submodule that has it."""

        def _remove_weight_norm(m):
            try:
                nn.utils.remove_weight_norm(m)
            except ValueError:  # this module has no weight norm
                return

        self.apply(_remove_weight_norm)
|
|
|
|
|
class ResidualParallelWaveganDiscriminator(nn.Module):
    """Discriminator built from a stack of dilated ``ResidualBlock`` layers.

    The input goes through a 1x1 convolution, then through ``num_layers``
    residual blocks whose second outputs are summed and scaled, and finally
    through two activation + 1x1 convolution pairs.

    Args:
        in_channels (int): number of channels of the input signal.
        out_channels (int): number of channels produced by the final convolution.
        kernel_size (int): kernel size passed to every residual block. Must be odd.
        num_layers (int): number of residual blocks. Must be divisible by ``stacks``.
        stacks (int): number of dilation cycles; the dilation restarts at 1
            every ``num_layers // stacks`` blocks and doubles in between.
        res_channels (int): channels of the first convolution output, also
            passed to the residual blocks as ``res_channels``.
        gate_channels (int): passed to the residual blocks as ``gate_channels``.
        skip_channels (int): passed to the residual blocks as ``skip_channels``;
            also the input width of the output convolutions.
        dropout (float): passed to the residual blocks as ``dropout``.
        bias (bool): passed to the residual blocks as ``bias``.
        nonlinear_activation (str): name of the activation class looked up in
            ``torch.nn``.
        nonlinear_activation_params (dict): keyword arguments for the activation.
            ``inplace=True`` is always passed in addition, so the dict must not
            contain ``inplace``. The dict is only read, never modified.
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_layers=30,
        stacks=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        dropout=0.0,
        bias=True,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.2},
    ):
        super().__init__()
        assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.stacks = stacks
        self.kernel_size = kernel_size
        # scale applied to the summed block outputs in ``forward``
        self.res_factor = math.sqrt(1.0 / num_layers)

        # check the number of layers and stacks
        assert num_layers % stacks == 0, " [!] num_layers must be divisible by stacks."
        layers_per_stack = num_layers // stacks

        # define first convolution
        # NOTE(review): ``bias`` is hard-coded to True here and in the output
        # convolutions below; the ``bias`` argument only reaches the residual blocks.
        self.first_conv = nn.Sequential(
            nn.Conv1d(in_channels, res_channels, kernel_size=1, padding=0, dilation=1, bias=True),
            getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
        )

        # define residual blocks
        self.conv_layers = nn.ModuleList()
        for layer in range(num_layers):
            # dilation doubles within a stack and restarts at 1 for each new stack
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                res_channels=res_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=-1,  # presumably disables auxiliary conditioning; verify in ResidualBlock
                dilation=dilation,
                dropout=dropout,
                bias=bias,
                use_causal_conv=False,
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = nn.ModuleList(
            [
                getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
                nn.Conv1d(skip_channels, skip_channels, kernel_size=1, padding=0, dilation=1, bias=True),
                getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params),
                nn.Conv1d(skip_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=True),
            ]
        )

        # apply weight norm
        self.apply_weight_norm()

    def forward(self, x):
        """Compute the discriminator output for ``x``.

        Args:
            x (Tensor): input of shape (B, in_channels, T); (B, 1, T) with the
                default configuration.

        Returns:
            Tensor: output of the last 1x1 convolution (``out_channels`` channels).
        """
        x = self.first_conv(x)

        skips = 0
        for f in self.conv_layers:
            # each block returns the tensor fed to the next block and a second
            # tensor that is accumulated (presumably the skip connection — see ResidualBlock)
            x, h = f(x, None)
            skips += h
        skips *= self.res_factor

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        """Apply weight normalization to every Conv1d / Conv2d submodule."""

        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Remove weight normalization from every submodule that has it.

        A message is printed for each module the normalization was actually
        removed from; modules without weight norm are skipped silently.
        """

        def _remove_weight_norm(m):
            try:
                nn.utils.remove_weight_norm(m)
            except ValueError:  # this module has no weight norm
                return
            # report only after the removal succeeded
            print(f"Weight norm is removed from {m}.")

        self.apply(_remove_weight_norm)
|
|