azamat commited on
Commit
73e61ac
1 Parent(s): 04d9b94
Files changed (1) hide show
  1. app.py +9 -13
app.py CHANGED
@@ -15,11 +15,15 @@ np.random.seed(0)
15
  from util import print_size, sampling
16
  from network import CleanUNet
17
  import torchaudio
 
 
 
18
 
19
  def load_simple(filename):
20
- print(filename)
21
- audio, _ = torchaudio.load(filename)
22
- return audio
 
23
 
24
  CONFIG = "configs/DNS-large-full.json"
25
  CHECKPOINT = "./exp/DNS-large-high/checkpoint/pretrained.pkl"
@@ -65,24 +69,16 @@ def denoise(filename, ckpt_path = CHECKPOINT, out = "out.wav"):
65
  net.eval()
66
 
67
  # inference
68
- batch_size = 1000000
69
  noisy_audio = load_simple(filename)
70
- LENGTH = len(noisy_audio[0].squeeze())
71
- noisy_audio = torch.chunk(noisy_audio, LENGTH // batch_size + 1, dim=1)
72
- all_audio = []
73
 
74
  for batch in tqdm(noisy_audio):
75
  with torch.no_grad():
76
  generated_audio = sampling(net, batch)
77
- generated_audio = generated_audio.cpu().numpy().squeeze()
78
- all_audio.append(generated_audio)
79
-
80
- all_audio = np.concatenate(all_audio, axis=0)
81
- sf.write(out, np.ravel(all_audio.squeeze()), 32000)
82
 
83
  return out
84
 
85
-
86
  audio = gr.inputs.Audio(label = "Audio to denoise", type = 'filepath')
87
  inputs = [audio]
88
  outputs = gr.outputs.Audio(label = "Denoised audio", type = 'filepath')
 
15
  from util import print_size, sampling
16
  from network import CleanUNet
17
  import torchaudio
18
+ import torchaudio.transforms as T
19
+
20
+ SAMPLE_RATE = 22050
21
 
22
  def load_simple(filename):
23
+ wav, sr = torchaudio.load(filename)
24
+ resampler = T.Resample(sr, SAMPLE_RATE, dtype=wav.dtype)
25
+ resampled_wav = resampler(audio)
26
+ return resampled_wav
27
 
28
  CONFIG = "configs/DNS-large-full.json"
29
  CHECKPOINT = "./exp/DNS-large-high/checkpoint/pretrained.pkl"
 
69
  net.eval()
70
 
71
  # inference
 
72
  noisy_audio = load_simple(filename)
 
 
 
73
 
74
  for batch in tqdm(noisy_audio):
75
  with torch.no_grad():
76
  generated_audio = sampling(net, batch)
77
+ generated_audio = generated_audio.cpu()
78
+ sf.write(out, np.ravel(generated_audio.squeeze()), SAMPLE_RATE)
 
 
 
79
 
80
  return out
81
 
 
82
  audio = gr.inputs.Audio(label = "Audio to denoise", type = 'filepath')
83
  inputs = [audio]
84
  outputs = gr.outputs.Audio(label = "Denoised audio", type = 'filepath')