Mylo commited on
Commit
62895ea
1 Parent(s): c1a6347

Bug fix, (why did it take this long to get a bug report?) #2

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -24,17 +24,21 @@ encodec_model = EncodecModel.encodec_model_24khz()
24
 
25
  def clone(audio, *args):
26
  sr, wav = audio
 
 
 
 
 
 
27
  if wav.shape[0] == 2: # Stereo to mono if needed
28
  wav = wav.mean(0, keepdim=True)
 
 
29
 
30
  wav = wav[-int(sr*20):] # Take only the last 20 seconds
31
 
32
- duration = wav.shape[0]
33
-
34
  wav = wav.reshape(1, -1) # Reshape from gradio style to HuBERT shape. (N, 1) to (1, N)
35
 
36
- wav = torch.tensor(wav, dtype=torch.float32)
37
-
38
  semantic_vectors = hubert_model.forward(wav, input_sample_hz=sr)
39
  semantic_tokens = tokenizer_model.get_token(semantic_vectors)
40
 
 
24
 
25
  def clone(audio, *args):
26
  sr, wav = audio
27
+
28
+ wav = torch.tensor(wav)
29
+
30
+ if wav.dtype == torch.int16:
31
+ wav = wav.float() / 32767.0
32
+
33
  if wav.shape[0] == 2: # Stereo to mono if needed
34
  wav = wav.mean(0, keepdim=True)
35
+ if wav.shape[1] == 2:
36
+ wav = wav.mean(1, keepdim=False).unsqueeze(-1)
37
 
38
  wav = wav[-int(sr*20):] # Take only the last 20 seconds
39
 
 
 
40
  wav = wav.reshape(1, -1) # Reshape from gradio style to HuBERT shape. (N, 1) to (1, N)
41
 
 
 
42
  semantic_vectors = hubert_model.forward(wav, input_sample_hz=sr)
43
  semantic_tokens = tokenizer_model.get_token(semantic_vectors)
44