import torch
from torch import nn

from so_vits_svc_fork.modules import attentions as attentions


class F0Decoder(nn.Module):
    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        spk_channels=0,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.spk_channels = spk_channels

        # Pre-net that smooths the content features before the attention stack.
        self.prenet = nn.Conv1d(hidden_channels, hidden_channels, 3, padding=1)
        # Feed-forward Transformer (FFT) blocks that decode F0 from the content features.
        self.decoder = attentions.FFT(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        # 1x1 projection down to the output F0 channel(s).
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        # Embeds the normalized-F0 conditioning signal into the hidden space.
        self.f0_prenet = nn.Conv1d(1, hidden_channels, 3, padding=1)
        # Projects the speaker embedding for speaker conditioning.
        self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)

    def forward(self, x, norm_f0, x_mask, spk_emb=None):
        # Detach so gradients from F0 prediction do not flow back into the encoder.
        x = torch.detach(x)
        if spk_emb is not None:
            spk_emb = torch.detach(spk_emb)
            x = x + self.cond(spk_emb)
        x += self.f0_prenet(norm_f0)
        x = self.prenet(x) * x_mask
        x = self.decoder(x * x_mask, x_mask)
        x = self.proj(x) * x_mask
        return x
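

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream module): it only illustrates
# the tensor shapes F0Decoder expects. The hyperparameter values below are
# illustrative assumptions, not the project's configured defaults.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    decoder = F0Decoder(
        out_channels=1,       # one predicted (normalized) F0 value per frame
        hidden_channels=192,
        filter_channels=768,
        n_heads=2,
        n_layers=3,
        kernel_size=3,
        p_dropout=0.1,
        spk_channels=256,
    )

    batch, frames = 2, 100
    x = torch.randn(batch, 192, frames)      # content features, (B, hidden, T)
    norm_f0 = torch.randn(batch, 1, frames)  # normalized F0 conditioning, (B, 1, T)
    x_mask = torch.ones(batch, 1, frames)    # frame validity mask, (B, 1, T)
    spk_emb = torch.randn(batch, 256, 1)     # speaker embedding, broadcast over time

    pred_f0 = decoder(x, norm_f0, x_mask, spk_emb=spk_emb)
    print(pred_f0.shape)  # torch.Size([2, 1, 100])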