File size: 1,305 Bytes
6127b48 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
from torch import nn
from .normalization import LayerNorm
class GatedConvBlock(nn.Module):
    """Stack of residual gated 1D convolution layers.

    Gated convolutional block as in https://arxiv.org/pdf/1612.08083.pdf

    Each layer applies dropout, a masked ``Conv1d`` that doubles the channel
    count, ``LayerNorm`` and a GLU over the channel dimension (which halves
    the channels back to ``in_out_channels``); the result is added to the
    layer's input.

    Args:
        in_out_channels (int): number of input/output channels.
        kernel_size (int): convolution kernel size. Must be odd when
            ``num_layers > 0``: the convolutions use
            ``padding=kernel_size // 2``, which keeps the sequence length
            unchanged only for odd kernels, and the residual sum needs
            matching lengths.
        dropout_p (float): dropout rate.
        num_layers (int): number of gated convolution layers to stack.

    Raises:
        ValueError: if ``num_layers > 0`` and ``kernel_size`` is even.
    """

    def __init__(
        self,
        in_out_channels: int,
        kernel_size: int,
        dropout_p: float,
        num_layers: int,
    ) -> None:
        super().__init__()
        # Fail early with a clear message instead of a shape-mismatch error
        # (or a silent broadcast) in the residual sum inside forward().
        if num_layers > 0 and kernel_size % 2 == 0:
            raise ValueError(
                "kernel_size must be odd so that the convolution preserves "
                f"the sequence length for the residual sum, got {kernel_size}"
            )
        # class arguments
        self.dropout_p = dropout_p
        self.num_layers = num_layers
        # define layers
        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        # Never populated or read in this class; kept so the attribute and
        # module tree stay the same for any external code that touches it.
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            # 2x output channels: GLU consumes one half as values and the
            # other half as gates.
            self.conv_layers.append(
                nn.Conv1d(
                    in_out_channels,
                    2 * in_out_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(2 * in_out_channels))

    def forward(self, x, x_mask):
        """Run the input through every gated convolution layer.

        Args:
            x: input tensor; channels are expected on dim 1 (the GLU splits
                dim 1) - presumably ``[B, C, T]``, TODO confirm with callers.
            x_mask: tensor multiplied element-wise with each layer's input
                before the convolution - presumably ``[B, 1, T]`` with zeros
                at padded frames, TODO confirm with callers.

        Returns:
            Tensor with the same shape as ``x``. Note that the mask is only
            applied to convolution inputs; the returned tensor is not
            multiplied by ``x_mask``.
        """
        o = x
        for conv, norm in zip(self.conv_layers, self.norm_layers):
            h = nn.functional.dropout(o, p=self.dropout_p, training=self.training)
            h = conv(h * x_mask)
            h = norm(h)
            # Residual connection around the gated convolution.
            o = o + nn.functional.glu(h, dim=1)
        return o
|