Vexa's picture
Duplicate from pivich/sovits-new
d5d7329
raw
history blame
No virus
1.46 kB
import torch
from torch import nn
from so_vits_svc_fork.modules import attentions as attentions
class F0Decoder(nn.Module):
    """Predict an F0-related output from detached features plus a normalized F0 input.

    Pipeline in ``forward``: optional speaker conditioning -> add the projected
    ``norm_f0`` -> 3-tap Conv1d prenet -> ``attentions.FFT`` stack -> 1x1
    Conv1d projection to ``out_channels``. ``x_mask`` is multiplied back in
    after the prenet, before the FFT stack, and after the projection.

    Args:
        out_channels: Channel count of the final 1x1 projection.
        hidden_channels: Working width of the prenet / FFT stack; the ``x``
            passed to ``forward`` must have this many channels.
        filter_channels: Passed through to ``attentions.FFT``.
        n_heads: Passed through to ``attentions.FFT``.
        n_layers: Passed through to ``attentions.FFT``.
        kernel_size: Passed through to ``attentions.FFT``.
        p_dropout: Passed through to ``attentions.FFT``.
        spk_channels: Channel count of the optional speaker embedding given
            to ``forward``.
    """

    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        spk_channels=0,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.spk_channels = spk_channels
        # Length-preserving 3-tap conv (padding=1) applied before the FFT stack.
        self.prenet = nn.Conv1d(hidden_channels, hidden_channels, 3, padding=1)
        self.decoder = attentions.FFT(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        # Lifts the single-channel normalized F0 to hidden_channels.
        self.f0_prenet = nn.Conv1d(1, hidden_channels, 3, padding=1)
        # Built unconditionally (even when spk_channels == 0); making it
        # conditional would change this module's state_dict keys.
        self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)

    def forward(self, x, norm_f0, x_mask, spk_emb=None):
        """Run the decoder.

        Args:
            x: Input features with ``hidden_channels`` channels; assumed to be
                ``(batch, hidden_channels, time)`` (Conv1d layout) — TODO
                confirm against the caller. Not modified.
            norm_f0: Normalized F0 with a single channel (``f0_prenet`` has
                ``in_channels=1``); presumably ``(batch, 1, time)``.
            x_mask: Multiplied into intermediate and final results; presumably
                a 0/1 ``(batch, 1, time)`` mask — verify against caller.
            spk_emb: Optional speaker embedding with ``spk_channels``
                channels. After the 1x1 ``cond`` conv, it is added to ``x``.

        Returns:
            The masked projection with ``out_channels`` channels.
        """
        # Detach inputs so gradients from this branch do not flow back into
        # whatever produced ``x`` / ``spk_emb``.
        x = torch.detach(x)
        if spk_emb is not None:
            spk_emb = torch.detach(spk_emb)
            x = x + self.cond(spk_emb)
        # Out-of-place on purpose: ``torch.detach`` shares storage with the
        # caller's tensor, so ``x += ...`` would overwrite the caller's input
        # when no speaker embedding is given (and could trip autograd's
        # in-place version check if that tensor was saved for backward).
        x = x + self.f0_prenet(norm_f0)
        x = self.prenet(x) * x_mask
        x = self.decoder(x * x_mask, x_mask)
        x = self.proj(x) * x_mask
        return x