File size: 6,573 Bytes
6127b48 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import torch
from torch import nn
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """Add two tensors and apply a gated tanh/sigmoid activation.

    The element-wise sum of the inputs is split along dim 1 at index
    ``n_channels[0]``: the leading part goes through ``tanh``, the trailing
    part through ``sigmoid``, and the two results are multiplied.

    Args:
        input_a: tensor indexed as ``[:, channel, :]``.
        input_b: tensor added element-wise to ``input_a``.
        n_channels: tensor whose first element is the split index on dim 1.

    Returns:
        ``tanh(sum[:, :n]) * sigmoid(sum[:, n:])`` with ``n = n_channels[0]``.
    """
    split_at = n_channels[0]
    summed = input_a + input_b
    tanh_part = torch.tanh(summed[:, :split_at, :])
    sigmoid_part = torch.sigmoid(summed[:, split_at:, :])
    return tanh_part * sigmoid_part
class WN(torch.nn.Module):
    """Stack of dilated, gated WaveNet layers with optional conditioning.

    Every layer runs a dilated convolution and dropout, adds its slice of the
    projected conditioning input (zeros when no conditioning is given), applies
    the gated tanh/sigmoid activation and finally a 1x1 convolution. For all
    but the last layer the 1x1 output is split into a residual half, added to
    the layer input, and a skip half, accumulated into the output. The last
    layer only produces the skip part.

        |-----------------------------------------------------------------------------|
        |                                    |-> tanh    -|                           |
    res -|- conv1d(dilation) -> dropout -> + -|            * -> conv1d1x1 -> split -|- + -> res
    g -------------------------------------|  |-> sigmoid -|                        |
    o --------------------------------------------------------------------------- + --------- o

    Note:
        ``in_channels`` is stored on the module but is not used to build the
        layers: every dilated convolution takes ``hidden_channels`` input
        channels.

    Args:
        in_channels (int): number of input channels.
        hidden_channels (int): number of hidden channels. Must be even.
        kernel_size (int): filter kernel size for the first conv layer. Must be odd.
        dilation_rate (int): dilations rate to increase dilation per layer.
            If it is 2, dilations are 1, 2, 4, 8 for the next 4 layers.
        num_layers (int): number of wavenet layers.
        c_in_channels (int): number of channels of conditioning input. The
            conditioning projection is only created when this is positive.
        dropout_p (float): dropout rate.
        weight_norm (bool): enable/disable weight norm for convolution layers.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        num_layers,
        c_in_channels=0,
        dropout_p=0,
        weight_norm=True,
    ):
        super().__init__()
        assert kernel_size % 2 == 1
        assert hidden_channels % 2 == 0
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.num_layers = num_layers
        self.c_in_channels = c_in_channels
        self.dropout_p = dropout_p
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.dropout = nn.Dropout(dropout_p)

        # init conditioning layer: a single 1x1 conv produces the conditioning
        # for all layers at once; forward() slices out one chunk per layer.
        if self._has_conditioning():
            cond_layer = torch.nn.Conv1d(c_in_channels, 2 * hidden_channels * num_layers, 1)
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")

        # intermediate layers
        for i in range(num_layers):
            dilation = dilation_rate**i
            # padding that keeps the time dimension unchanged for an odd kernel
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
            )
            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # the last layer has no residual branch, only the skip output
            if i < num_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

        # setup weight norm
        if not weight_norm:
            self.remove_weight_norm()

    def _has_conditioning(self):
        """Return True when the conditioning projection layer exists."""
        return self.c_in_channels > 0

    def forward(self, x, x_mask=None, g=None, **kwargs):  # pylint: disable=unused-argument
        """Run the layer stack and return the masked sum of all skip outputs.

        Args:
            x: input tensor, indexed as ``[:, channel, :]`` and fed to
                convolutions that expect ``hidden_channels`` input channels.
            x_mask: optional mask multiplied into the residual stream after
                every layer and into the final output. ``None`` means no
                masking (a factor of 1.0).
            g: optional conditioning tensor. It is projected by the
                conditioning layer, which only exists when
                ``c_in_channels > 0``.
            **kwargs: ignored.

        Returns:
            Tensor with the shape of ``x``: accumulated skip outputs times ``x_mask``.
        """
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])
        x_mask = 1.0 if x_mask is None else x_mask

        if g is not None:
            g = self.cond_layer(g)

        # Zero conditioning is identical for every layer, so allocate it once.
        zero_cond = None
        last_layer = self.num_layers - 1
        for i in range(self.num_layers):
            x_in = self.in_layers[i](x)
            x_in = self.dropout(x_in)

            if g is not None:
                # each layer consumes its own 2 * hidden_channels slice of g
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                if zero_cond is None:
                    zero_cond = torch.zeros_like(x_in)
                g_l = zero_cond

            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            res_skip_acts = self.res_skip_layers[i](acts)

            if i < last_layer:
                # first half -> residual update of x, second half -> skip output
                x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        """Strip weight normalization from every convolution of the module."""
        # Same condition as in __init__, so the conditioning layer is only
        # touched when it was actually created.
        if self._has_conditioning():
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for layer in self.in_layers:
            torch.nn.utils.remove_weight_norm(layer)
        for layer in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(layer)
class WNBlocks(nn.Module):
    """Chain of ``WN`` blocks applied one after another.

    Note: every block is built with the same ``dilation_rate`` and
    ``num_layers``, so the dilation starts again from 1 in each block and
    grows along the dilation rate inside it.

    Args:
        in_channels (int): number of input channels (passed to the first block only).
        hidden_channels (int): number of hidden channels.
        kernel_size (int): filter kernel size for the first conv layer.
        dilation_rate (int): dilations rate to increase dilation per layer.
            If it is 2, dilations are 1, 2, 4, 8 for the next 4 layers.
        num_blocks (int): number of wavenet blocks.
        num_layers (int): number of wavenet layers.
        c_in_channels (int): number of channels of conditioning input.
        dropout_p (float): dropout rate.
        weight_norm (bool): enable/disable weight norm for convolution layers.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        num_blocks,
        num_layers,
        c_in_channels=0,
        dropout_p=0,
        weight_norm=True,
    ):
        super().__init__()
        # Settings shared by all blocks; only the declared input width differs.
        shared_kwargs = {
            "hidden_channels": hidden_channels,
            "kernel_size": kernel_size,
            "dilation_rate": dilation_rate,
            "num_layers": num_layers,
            "c_in_channels": c_in_channels,
            "dropout_p": dropout_p,
            "weight_norm": weight_norm,
        }
        self.wn_blocks = nn.ModuleList(
            [
                WN(in_channels=hidden_channels if block_idx else in_channels, **shared_kwargs)
                for block_idx in range(num_blocks)
            ]
        )

    def forward(self, x, x_mask=None, g=None):
        """Feed ``x`` through all blocks in order, passing ``x_mask`` and ``g`` to each."""
        hidden = x
        for block in self.wn_blocks:
            hidden = block(hidden, x_mask, g)
        return hidden
|