|
import torch |
|
from nsf_hifigan.nvSTFT import STFT |
|
from nsf_hifigan.models import load_model |
|
from torchaudio.transforms import Resample |
|
|
|
|
|
class Vocoder:
    """Front-end that selects an NSF-HiFiGAN backend by name.

    Supported ``vocoder_type`` values are ``'nsf-hifigan'`` and
    ``'nsf-hifigan-log10'``; anything else raises ``ValueError``.

    After construction the backend's properties are cached as
    ``vocoder_sample_rate``, ``vocoder_hop_size`` and ``dimension``.
    """

    def __init__(self, vocoder_type, vocoder_ckpt, device = None):
        # Fall back to CUDA when available, otherwise CPU.
        self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device

        backends = {
            'nsf-hifigan': NsfHifiGAN,
            'nsf-hifigan-log10': NsfHifiGANLog10,
        }
        backend_cls = backends.get(vocoder_type)
        if backend_cls is None:
            raise ValueError(f" [x] Unknown vocoder: {vocoder_type}")
        self.vocoder = backend_cls(vocoder_ckpt, device = self.device)

        # Resample transforms, keyed by str(source sample rate); filled lazily.
        self.resample_kernel = {}
        self.vocoder_sample_rate = self.vocoder.sample_rate()
        self.vocoder_hop_size = self.vocoder.hop_size()
        self.dimension = self.vocoder.dimension()

    def _resample(self, audio, sample_rate):
        """Resample ``audio`` from ``sample_rate`` to the backend's rate.

        The torchaudio ``Resample`` transform is built once per source rate,
        moved to ``self.device`` and cached in ``self.resample_kernel``.
        """
        cache_key = str(sample_rate)
        resampler = self.resample_kernel.get(cache_key)
        if resampler is None:
            resampler = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width = 128).to(self.device)
            self.resample_kernel[cache_key] = resampler
        return resampler(audio)

    def extract(self, audio, sample_rate, keyshift=0):
        """Compute mel features for ``audio`` recorded at ``sample_rate``.

        Args:
            audio: waveform tensor in whatever layout the backend's
                ``extract`` accepts (presumably (B, T) — TODO confirm).
            sample_rate: sampling rate of ``audio``. If it differs from the
                backend's rate, the audio is resampled first.
            keyshift: forwarded unchanged to the backend's ``extract``.

        Returns:
            The backend's mel output.
        """
        if sample_rate != self.vocoder_sample_rate:
            audio = self._resample(audio, sample_rate)
        return self.vocoder.extract(audio, keyshift=keyshift)

    def infer(self, mel, f0):
        """Synthesize audio from ``mel`` conditioned on ``f0``.

        ``f0`` is trimmed to the first ``mel.size(1)`` frames along dim 1, and
        index 0 of its last axis is taken. It is therefore assumed to be
        (B, T, 1), and ``mel`` presumably (B, frames, bins) — TODO confirm.
        """
        n_frames = mel.size(1)
        return self.vocoder(mel, f0[:, :n_frames, 0])
|
|
|
|
|
class NsfHifiGAN(torch.nn.Module):
    """NSF-HiFiGAN backend: loads a checkpoint and builds a matching STFT.

    ``load_model`` returns the generator (``self.model``) and its config
    (``self.h``). The STFT is built from the config's sampling_rate,
    num_mels, n_fft, win_size, hop_size, fmin and fmax.
    """

    def __init__(self, model_path, device=None):
        super().__init__()
        self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
        print('| Load HifiGAN: ', model_path)
        self.model, self.h = load_model(model_path, device=self.device)
        cfg = self.h
        # STFT arguments are positional, in the order the original code passes them.
        self.stft = STFT(
            cfg.sampling_rate,
            cfg.num_mels,
            cfg.n_fft,
            cfg.win_size,
            cfg.hop_size,
            cfg.fmin,
            cfg.fmax,
        )

    def sample_rate(self):
        """Return the config's ``sampling_rate``."""
        return self.h.sampling_rate

    def hop_size(self):
        """Return the config's ``hop_size``."""
        return self.h.hop_size

    def dimension(self):
        """Return the config's ``num_mels`` (mel feature dimension)."""
        return self.h.num_mels

    def extract(self, audio, keyshift=0):
        """Return ``stft.get_mel(audio, keyshift)`` with dims 1 and 2 swapped."""
        return self.stft.get_mel(audio, keyshift=keyshift).transpose(1, 2)

    @torch.no_grad()
    def forward(self, mel, f0):
        """Run the generator on ``mel`` (dims 1/2 swapped back) and ``f0``.

        Runs without gradient tracking.
        """
        return self.model(mel.transpose(1, 2), f0)
|
|
|
class NsfHifiGANLog10(NsfHifiGAN):
    """NSF-HiFiGAN variant that rescales the mel input by log10(e) before synthesis.

    Multiplying by log10(e) (= 1 / ln 10) converts a natural-log value into a
    base-10 log value. This presumes the incoming ``mel`` is natural-log scaled
    (e.g. as produced by ``extract``) and that the checkpoint expects log10 mels.
    NOTE(review): confirm against the STFT / training setup.
    """

    # Exact log10(e) == 1 / ln(10); replaces the truncated literal 0.434294.
    _LN_TO_LOG10 = 0.4342944819032518

    def forward(self, mel, f0):
        """Scale ``mel`` by log10(e), then delegate to ``NsfHifiGAN.forward``.

        The scaling is elementwise, so applying it before the base class's
        transpose gives the same values as applying it after. Everything runs
        without gradient tracking.
        """
        with torch.no_grad():
            return super().forward(mel * self._LN_TO_LOG10, f0)