import torch
from BigVGAN.meldataset import get_mel_spectrogram
from voice_restore import VoiceRestore


class OptimizedAudioRestorationModel(torch.nn.Module):
    """Audio restoration pipeline: waveform -> mel -> VoiceRestore -> vocoder.

    The module owns a ``VoiceRestore`` network and holds a reference to an
    externally supplied BigVGAN model. The BigVGAN model provides the
    mel-spectrogram settings (its ``h`` attribute) and is called to turn the
    restored mel-spectrogram back into a waveform.
    """

    def __init__(self, target_sample_rate=24000, device=None, bigvgan_model=None):
        """Build the VoiceRestore network and place it on ``device``.

        Args:
            target_sample_rate: Stored on the instance as
                ``target_sample_rate``; not otherwise used by this class.
            device: Device the VoiceRestore network is moved to. ``forward``
                raises if this is ``None``.
            bigvgan_model: BigVGAN model used for mel settings and vocoding.
                ``forward`` raises if this is ``None``.
        """
        super().__init__()

        # Hyper-parameters of the VoiceRestore transformer backbone.
        transformer_config = {
            'dim': 768,
            'depth': 20,
            'heads': 16,
            'dim_head': 64,
            'skip_connect_type': 'concat',
            'max_seq_len': 2000,
        }
        self.voice_restore = VoiceRestore(
            sigma=0.0,
            transformer=transformer_config,
            num_channels=100,
        )

        self.device = device
        # NOTE(review): exact comparison against the string 'cuda'; a value
        # such as 'cuda:0' does not equal it and therefore skips the bfloat16
        # cast. Confirm this is the intended behaviour.
        running_on_cuda = self.device == 'cuda'
        if running_on_cuda:
            self.voice_restore.bfloat16()

        # Inference mode first, then move the weights to the requested device.
        self.voice_restore.eval().to(self.device)

        self.target_sample_rate = target_sample_rate
        self.bigvgan_model = bigvgan_model

    def forward(self, audio, steps=32, cfg_strength=0.5):
        """Restore ``audio`` and return the vocoder output.

        Args:
            audio: Input passed to ``get_mel_spectrogram`` together with the
                BigVGAN model's ``h`` settings.
            steps: Forwarded to ``VoiceRestore.sample`` as ``steps``.
            cfg_strength: Forwarded to ``VoiceRestore.sample`` as
                ``cfg_strength``.

        Returns:
            Whatever the BigVGAN model returns for the restored
            mel-spectrogram.

        Raises:
            ValueError: If no BigVGAN model or no device was provided.
        """
        vocoder = self.bigvgan_model
        if vocoder is None:
            raise ValueError("BigVGAN model is not provided. Please provide the BigVGAN model.")
        if self.device is None:
            raise ValueError("Device is not provided. Please provide the device (cuda, cpu or mps).")

        # Waveform -> mel-spectrogram, using the vocoder's own settings.
        input_mel = get_mel_spectrogram(audio, vocoder.h)
        input_mel = input_mel.to(self.device)

        # The sampler is fed the mel with its last two axes swapped.
        sampled_mel = self.voice_restore.sample(
            input_mel.transpose(1, 2),
            steps=steps,
            cfg_strength=cfg_strength,
        )

        # Drop the leading axis, swap the remaining two, then restore a
        # leading axis for the vocoder.
        # NOTE(review): this presumably assumes a single-item batch -- confirm.
        vocoder_input = sampled_mel.squeeze(0).transpose(0, 1).unsqueeze(0)

        # Mel-spectrogram -> waveform.
        return vocoder(vocoder_input)