gpt-omni
commited on
Commit
•
6eacc63
1
Parent(s):
008f4a1
fix snac
Browse files- utils/snac_utils.py +2 -2
utils/snac_utils.py
CHANGED
@@ -21,8 +21,8 @@ def layershift(input_id, layer, stride=4160, shift=152000):
|
|
21 |
return input_id + shift + layer * stride
|
22 |
|
23 |
|
24 |
-
def generate_audio_data(snac_tokens, snacmodel):
|
25 |
-
audio = reconstruct_tensors(snac_tokens)
|
26 |
with torch.inference_mode():
|
27 |
audio_hat = snacmodel.decode(audio)
|
28 |
audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
|
|
|
21 |
return input_id + shift + layer * stride
|
22 |
|
23 |
|
24 |
+
def generate_audio_data(snac_tokens, snacmodel, device=None):
|
25 |
+
audio = reconstruct_tensors(snac_tokens, device)
|
26 |
with torch.inference_mode():
|
27 |
audio_hat = snacmodel.decode(audio)
|
28 |
audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
|