Spaces:
Runtime error
Runtime error
import os | |
import json | |
from tqdm import tqdm | |
from copy import deepcopy | |
import numpy as np | |
import gradio as gr | |
import torch | |
import random | |
random.seed(0) | |
torch.manual_seed(0) | |
np.random.seed(0) | |
from scipy.io.wavfile import write as wavwrite | |
from util import print_size, sampling | |
from network import CleanUNet | |
import torchaudio | |
def load_simple(filename):
    """Load an audio file and return only its waveform.

    ``torchaudio.load`` yields a ``(waveform, sample_rate)`` pair; the sample
    rate is intentionally dropped here, so callers get no rate information.
    """
    waveform = torchaudio.load(filename)[0]
    return waveform
CONFIG = "configs/DNS-large-full.json"
CHECKPOINT = "./exp/DNS-large-high/checkpoint/pretrained.pkl"

# Parse configs once at import time. The sections are kept as plain
# module-level names (no ``global`` statement is needed at module scope)
# so that the functions below can read them directly.
with open(CONFIG, encoding="utf-8") as f:
    data = f.read()
config = json.loads(data)

gen_config = config["gen_config"]
network_config = config["network_config"]  # keyword arguments for CleanUNet
train_config = config["train_config"]  # train config (provides "exp_path")
trainset_config = config["trainset_config"]  # to read trainset configurations
def denoise(files, ckpt_path, sample_rate=32000):
    """
    Denoise audio files with a pretrained CleanUNet checkpoint.

    Each input is loaded with ``load_simple``, split into chunks along the
    sample axis, run through the network chunk by chunk, and written as
    ``<input basename>_denoised.wav`` in the same directory as the input.

    Parameters:
    files (path or iterable of paths): audio file(s) to denoise. A single
                                       path (str / os.PathLike) is accepted
                                       as well as a collection of paths
    ckpt_path (str):                   checkpoint to load; it must contain a
                                       'model_state_dict' entry
    sample_rate (int):                 rate written to the output WAV header.
                                       The input's own rate is not available
                                       from ``load_simple`` and no resampling
                                       is done here, so pass the input rate
                                       if it differs from the default

    Returns:
    The output path when ``files`` is a single path, otherwise a list of
    output paths in input order.

    Raises:
    ValueError: if an input file contains no samples.
    """

    # A bare path must be wrapped; iterating it would yield single characters.
    single_input = isinstance(files, (str, os.PathLike))
    file_paths = [files] if single_input else list(files)

    # setup local experiment path
    exp_path = train_config["exp_path"]
    print('exp_path:', exp_path)

    # predefine model
    net = CleanUNet(**network_config)
    print_size(net)

    # load checkpoint
    # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    # inference
    batch_size = 1000000  # upper bound on samples per forward pass

    saved_paths = []
    for file_path in tqdm(file_paths):
        file_path = os.fspath(file_path)
        file_name = os.path.basename(file_path)
        # Take the directory from the full path; the basename has none.
        file_dir = os.path.dirname(file_path)
        new_file_name = file_name + "_denoised.wav"

        # torchaudio returns (channels, samples), hence the split on dim=1.
        noisy_audio = load_simple(file_path)
        num_samples = noisy_audio.shape[1]
        if num_samples == 0:
            raise ValueError(f"{file_path} contains no audio samples")
        chunks = torch.chunk(noisy_audio, num_samples // batch_size + 1, dim=1)

        denoised_chunks = []
        with torch.no_grad():
            for chunk in tqdm(chunks):
                generated_audio = sampling(net, chunk).cpu().numpy()
                # Keep an explicit (rows, samples) layout so chunks are joined
                # in time rather than stacked as extra rows.
                # assumes sampling() yields one output row per input row
                # -- TODO confirm against util.sampling
                denoised_chunks.append(
                    generated_audio.reshape(chunk.shape[0], -1)
                )
        all_audio = np.concatenate(denoised_chunks, axis=1)

        # scipy expects (samples,) for mono or (samples, channels) otherwise.
        if all_audio.shape[0] == 1:
            all_audio = all_audio[0]
        else:
            all_audio = all_audio.T

        save_file = os.path.join(file_dir, new_file_name)
        wavwrite(save_file, sample_rate, all_audio)
        print("saved to:", save_file)
        saved_paths.append(save_file)

    return saved_paths[0] if single_input else saved_paths
def _denoise_upload(audio_path):
    """Gradio callback: denoise one uploaded file using the CHECKPOINT weights."""
    return denoise(audio_path, CHECKPOINT)


# Only the audio widget is a UI input. CHECKPOINT is a filesystem path, not a
# Gradio component, so it is bound in the callback above instead of being
# listed among the inputs.
# NOTE(review): the output component expects the callback to return the path
# of the denoised file.
# NOTE(review): gr.inputs / gr.outputs and enable_queue belong to older Gradio
# releases -- confirm the Gradio version pinned for this app.
audio = gr.inputs.Audio(label="Audio to denoise", type='filepath')
inputs = [audio]
outputs = gr.outputs.Audio(label="Denoised audio", type='filepath')
title = "Speech Denoising in the Waveform Domain with Self-Attention from Nvidia"
gr.Interface(_denoise_upload, inputs, outputs, title=title, enable_queue=True).launch()