import torch
from torch import nn

from so_vits_svc_fork.modules import attentions as attentions


class F0Decoder(nn.Module):
    """Predict an F0 track from hidden features, a normalized F0 and a speaker embedding.

    The module sums three feature streams in ``hidden_channels`` space (the
    detached input features, a projected speaker embedding and a convolved
    normalized-F0 signal), runs them through ``attentions.FFT`` and projects
    the result to ``out_channels``.

    Args:
        out_channels: Number of channels produced by the final 1x1 projection.
        hidden_channels: Channel width of the input features and of every
            internal layer.
        filter_channels: Forwarded to ``attentions.FFT``.
        n_heads: Forwarded to ``attentions.FFT``.
        n_layers: Forwarded to ``attentions.FFT``.
        kernel_size: Forwarded to ``attentions.FFT``.
        p_dropout: Forwarded to ``attentions.FFT``.
        spk_channels: Channel count of the speaker embedding consumed by the
            ``cond`` 1x1 convolution.
    """

    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        spk_channels=0,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.spk_channels = spk_channels

        # Width-3 convolution over the summed features; padding=1 keeps the
        # length of the last axis unchanged.
        self.prenet = nn.Conv1d(hidden_channels, hidden_channels, 3, padding=1)
        self.decoder = attentions.FFT(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        # Lifts the single-channel normalized F0 into the hidden space.
        self.f0_prenet = nn.Conv1d(1, hidden_channels, 3, padding=1)
        # NOTE(review): created unconditionally, even when spk_channels == 0,
        # so the parameter names in state_dict() never depend on the
        # configuration. Presumably existing checkpoints rely on that; confirm
        # before making this layer optional.
        self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)

    def forward(self, x, norm_f0, x_mask, spk_emb=None):
        """Run the decoder.

        Args:
            x: Input features with ``hidden_channels`` channels. Detached
                before use, so no gradient flows back into it from here.
            norm_f0: Normalized F0 with a single channel.
            x_mask: Mask multiplied element-wise into the activations after
                each stage (presumably 1 for valid frames and 0 for padding;
                verify against the caller).
            spk_emb: Optional speaker embedding with ``spk_channels`` channels.
                Detached before use. Skipped entirely when ``None``.

        Returns:
            Masked tensor with ``out_channels`` channels.
        """
        x = torch.detach(x)
        if spk_emb is not None:
            spk_emb = torch.detach(spk_emb)
            x = x + self.cond(spk_emb)
        # Out-of-place on purpose: a detached tensor shares storage with the
        # caller's tensor, so ``x += ...`` would overwrite the caller's data
        # (and bump its autograd version counter) whenever spk_emb is None.
        x = x + self.f0_prenet(norm_f0)
        x = self.prenet(x) * x_mask
        x = self.decoder(x * x_mask, x_mask)
        x = self.proj(x) * x_mask
        return x