Spaces:
Runtime error
Runtime error
File size: 3,146 Bytes
33e3a91 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import os
import json
from tqdm import tqdm
from copy import deepcopy
import numpy as np
import gradio as gr
import torch
import random
# A single seed shared by every RNG the app touches (Python, NumPy, PyTorch)
# so that repeated runs behave reproducibly.
_SEED = 0
torch.manual_seed(_SEED)
np.random.seed(_SEED)
random.seed(_SEED)
from scipy.io.wavfile import write as wavwrite
from util import print_size, sampling
from network import CleanUNet
import torchaudio
def load_simple(filename):
    """Read *filename* with ``torchaudio.load`` and return just the waveform.

    ``torchaudio.load`` yields a ``(waveform, sample_rate)`` pair; the sample
    rate is dropped here.
    """
    return torchaudio.load(filename)[0]
CONFIG = "configs/DNS-large-full.json"
CHECKPOINT = "./exp/DNS-large-high/checkpoint/pretrained.pkl"

# Parse the experiment config once at import time. The sections are exposed as
# ordinary module-level globals for the inference code to read; no `global`
# statement is needed (it is a no-op at module scope).
with open(CONFIG, encoding="utf-8") as f:
    config = json.load(f)
gen_config = config["gen_config"]
network_config = config["network_config"]  # keyword arguments for CleanUNet(...)
train_config = config["train_config"]  # training settings
trainset_config = config["trainset_config"]  # training-set (data loading) settings
def _load_network(ckpt_path):
    """Build CleanUNet from ``network_config`` and load the weights in *ckpt_path*."""
    net = CleanUNet(**network_config)
    print_size(net)
    # torch.load unpickles the file: only point this at trusted checkpoints
    # (here, the bundled CHECKPOINT).
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    net.load_state_dict(checkpoint["model_state_dict"])
    net.eval()
    return net


def _output_path(file_path):
    """Return ``<dir>/<stem>_denoised.wav`` next to *file_path*."""
    directory = os.path.dirname(file_path)
    stem, _ext = os.path.splitext(os.path.basename(file_path))
    return os.path.join(directory, stem + "_denoised.wav")


def _denoise_file(net, file_path, chunk_size, sample_rate):
    """Denoise one audio file in time chunks, write it as WAV, return its path."""
    noisy_audio, file_rate = torchaudio.load(file_path)  # (channels, frames)
    length = noisy_audio.shape[-1]
    if length == 0:
        raise ValueError(f"{file_path!r} contains no audio samples")
    # NOTE(review): the input is not resampled; the network presumably expects
    # the rate it was trained at (trainset_config) -- confirm for uploads.
    # n = length // chunk_size + 1 chunks means each has at most chunk_size frames.
    chunks = torch.chunk(noisy_audio, length // chunk_size + 1, dim=1)
    pieces = []
    with torch.no_grad():
        for chunk in tqdm(chunks):
            generated = sampling(net, chunk)
            # atleast_1d keeps a 1-sample chunk concatenable after squeeze().
            pieces.append(np.atleast_1d(generated.cpu().numpy().squeeze()))
    # Join along time (last axis) so multi-channel output stays aligned.
    denoised = np.concatenate(pieces, axis=-1)
    if denoised.ndim == 2:
        # scipy.io.wavfile.write expects (n_samples, n_channels).
        denoised = denoised.T
    save_file = _output_path(file_path)
    wavwrite(save_file, file_rate if sample_rate is None else sample_rate, denoised)
    print("saved to:", save_file)
    return save_file


def denoise(files, ckpt_path, sample_rate=None, chunk_size=1000000):
    """Denoise audio files with a pretrained CleanUNet.

    Each result is written next to its input as ``<stem>_denoised.wav``.

    Parameters:
    files (str, os.PathLike, or iterable of those): audio file(s) to denoise.
        A single path is accepted because the Gradio Audio input
        (type='filepath') passes one path string.
    ckpt_path (str): checkpoint file containing a 'model_state_dict' entry.
    sample_rate (int or None): rate written to the output WAV; None (default)
        reuses each input file's own rate.
    chunk_size (int): maximum number of samples fed to the network at once.

    Returns:
    str for a single input path, otherwise a list of output paths in input
    order.

    Raises:
    ValueError: if chunk_size < 1 or an input file has no samples.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    single = isinstance(files, (str, os.PathLike))
    paths = [files] if single else list(files)
    if not paths:
        return []
    net = _load_network(ckpt_path)
    outputs = [
        _denoise_file(net, path, chunk_size, sample_rate) for path in tqdm(paths)
    ]
    return outputs[0] if single else outputs


def _denoise_upload(file_path):
    """Gradio callback: denoise one uploaded file with the bundled checkpoint."""
    return denoise(file_path, CHECKPOINT)


title = "Speech Denoising in the Waveform Domain with Self-Attention from Nvidia"
# The checkpoint is bound in _denoise_upload rather than listed as an input:
# Gradio treats every `inputs` entry as a UI component.
demo = gr.Interface(
    _denoise_upload,
    inputs=gr.Audio(label="Audio to denoise", type="filepath"),
    outputs=gr.Audio(label="Denoised audio", type="filepath"),
    title=title,
)

if __name__ == "__main__":
    # queue() replaces the Interface(enable_queue=True) flag.
    demo.queue().launch()