Spaces:
Runtime error
Runtime error
Upload 3 files
Browse files- inference.py +19 -21
inference.py
CHANGED
@@ -2,7 +2,6 @@ import os
|
|
2 |
import lightning as L
|
3 |
import torch
|
4 |
import time
|
5 |
-
import spaces
|
6 |
from snac import SNAC
|
7 |
from litgpt import Tokenizer
|
8 |
from litgpt.utils import (
|
@@ -147,8 +146,8 @@ def load_audio(path):
|
|
147 |
|
148 |
def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
|
149 |
snacmodel, out_dir=None):
|
150 |
-
|
151 |
-
|
152 |
tokenlist = generate_TA_BATCH(
|
153 |
model,
|
154 |
audio_feature,
|
@@ -191,8 +190,8 @@ def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, s
|
|
191 |
|
192 |
|
193 |
def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
194 |
-
|
195 |
-
|
196 |
tokenlist = generate_AT(
|
197 |
model,
|
198 |
audio_feature,
|
@@ -214,8 +213,8 @@ def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
|
214 |
|
215 |
def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
|
216 |
snacmodel, out_dir=None):
|
217 |
-
|
218 |
-
|
219 |
tokenlist = generate_AA(
|
220 |
model,
|
221 |
audio_feature,
|
@@ -256,8 +255,8 @@ def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
|
|
256 |
|
257 |
|
258 |
def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
259 |
-
|
260 |
-
|
261 |
tokenlist = generate_ASR(
|
262 |
model,
|
263 |
audio_feature,
|
@@ -280,8 +279,8 @@ def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
|
280 |
|
281 |
def T1_A2(fabric, input_ids, model, text_tokenizer, step,
|
282 |
snacmodel, out_dir=None):
|
283 |
-
|
284 |
-
|
285 |
tokenlist = generate_TA(
|
286 |
model,
|
287 |
None,
|
@@ -325,8 +324,8 @@ def T1_A2(fabric, input_ids, model, text_tokenizer, step,
|
|
325 |
|
326 |
def T1_T2(fabric, input_ids, model, text_tokenizer, step):
|
327 |
|
328 |
-
|
329 |
-
|
330 |
tokenlist = generate_TT(
|
331 |
model,
|
332 |
None,
|
@@ -356,13 +355,12 @@ def load_model(ckpt_dir, device):
|
|
356 |
config.post_adapter = False
|
357 |
|
358 |
with fabric.init_module(empty_init=False):
|
359 |
-
model = GPT(config
|
360 |
|
361 |
-
|
362 |
state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
|
363 |
model.load_state_dict(state_dict, strict=True)
|
364 |
-
model
|
365 |
-
model.eval()
|
366 |
|
367 |
return fabric, model, text_tokenizer, snacmodel, whispermodel
|
368 |
|
@@ -385,8 +383,7 @@ class OmniInference:
|
|
385 |
for _ in self.run_AT_batch_stream(sample):
|
386 |
pass
|
387 |
|
388 |
-
|
389 |
-
@spaces.GPU
|
390 |
def run_AT_batch_stream(self,
|
391 |
audio_path,
|
392 |
stream_stride=4,
|
@@ -401,7 +398,8 @@ class OmniInference:
|
|
401 |
assert os.path.exists(audio_path), f"audio file {audio_path} not found"
|
402 |
model = self.model
|
403 |
|
404 |
-
|
|
|
405 |
|
406 |
mel, leng = load_audio(audio_path)
|
407 |
audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
|
@@ -419,7 +417,7 @@ class OmniInference:
|
|
419 |
list_output = [[] for i in range(8)]
|
420 |
tokens_A, token_T = next_token_batch(
|
421 |
model,
|
422 |
-
audio_feature.to(torch.float32).to(device),
|
423 |
input_ids,
|
424 |
[T - 3, T - 3],
|
425 |
["A1T2", "A1T2"],
|
|
|
2 |
import lightning as L
|
3 |
import torch
|
4 |
import time
|
|
|
5 |
from snac import SNAC
|
6 |
from litgpt import Tokenizer
|
7 |
from litgpt.utils import (
|
|
|
146 |
|
147 |
def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
|
148 |
snacmodel, out_dir=None):
|
149 |
+
with fabric.init_tensor():
|
150 |
+
model.set_kv_cache(batch_size=2)
|
151 |
tokenlist = generate_TA_BATCH(
|
152 |
model,
|
153 |
audio_feature,
|
|
|
190 |
|
191 |
|
192 |
def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
193 |
+
with fabric.init_tensor():
|
194 |
+
model.set_kv_cache(batch_size=1)
|
195 |
tokenlist = generate_AT(
|
196 |
model,
|
197 |
audio_feature,
|
|
|
213 |
|
214 |
def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
|
215 |
snacmodel, out_dir=None):
|
216 |
+
with fabric.init_tensor():
|
217 |
+
model.set_kv_cache(batch_size=1)
|
218 |
tokenlist = generate_AA(
|
219 |
model,
|
220 |
audio_feature,
|
|
|
255 |
|
256 |
|
257 |
def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
|
258 |
+
with fabric.init_tensor():
|
259 |
+
model.set_kv_cache(batch_size=1)
|
260 |
tokenlist = generate_ASR(
|
261 |
model,
|
262 |
audio_feature,
|
|
|
279 |
|
280 |
def T1_A2(fabric, input_ids, model, text_tokenizer, step,
|
281 |
snacmodel, out_dir=None):
|
282 |
+
with fabric.init_tensor():
|
283 |
+
model.set_kv_cache(batch_size=1)
|
284 |
tokenlist = generate_TA(
|
285 |
model,
|
286 |
None,
|
|
|
324 |
|
325 |
def T1_T2(fabric, input_ids, model, text_tokenizer, step):
|
326 |
|
327 |
+
with fabric.init_tensor():
|
328 |
+
model.set_kv_cache(batch_size=1)
|
329 |
tokenlist = generate_TT(
|
330 |
model,
|
331 |
None,
|
|
|
355 |
config.post_adapter = False
|
356 |
|
357 |
with fabric.init_module(empty_init=False):
|
358 |
+
model = GPT(config)
|
359 |
|
360 |
+
model = fabric.setup(model)
|
361 |
state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
|
362 |
model.load_state_dict(state_dict, strict=True)
|
363 |
+
model.to(device).eval()
|
|
|
364 |
|
365 |
return fabric, model, text_tokenizer, snacmodel, whispermodel
|
366 |
|
|
|
383 |
for _ in self.run_AT_batch_stream(sample):
|
384 |
pass
|
385 |
|
386 |
+
@torch.inference_mode()
|
|
|
387 |
def run_AT_batch_stream(self,
|
388 |
audio_path,
|
389 |
stream_stride=4,
|
|
|
398 |
assert os.path.exists(audio_path), f"audio file {audio_path} not found"
|
399 |
model = self.model
|
400 |
|
401 |
+
with self.fabric.init_tensor():
|
402 |
+
model.set_kv_cache(batch_size=2)
|
403 |
|
404 |
mel, leng = load_audio(audio_path)
|
405 |
audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
|
|
|
417 |
list_output = [[] for i in range(8)]
|
418 |
tokens_A, token_T = next_token_batch(
|
419 |
model,
|
420 |
+
audio_feature.to(torch.float32).to(model.device),
|
421 |
input_ids,
|
422 |
[T - 3, T - 3],
|
423 |
["A1T2", "A1T2"],
|