fix conflic
inference.py  CHANGED  (+18, -19)
@@ -2,7 +2,6 @@ import os
 import lightning as L
 import torch
 import time
-import spaces
 from snac import SNAC
 from litgpt import Tokenizer
 from litgpt.utils import (
@@ -148,8 +147,8 @@ def load_audio(path):
 
 def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
                 snacmodel, out_dir=None):
-
-
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=2)
     tokenlist = generate_TA_BATCH(
         model,
         audio_feature,
@@ -192,8 +191,8 @@ def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, s
 
 
 def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
-
-
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_AT(
         model,
         audio_feature,
@@ -215,8 +214,8 @@ def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
 
 def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
           snacmodel, out_dir=None):
-
-
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_AA(
         model,
         audio_feature,
@@ -257,8 +256,8 @@ def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
 
 
 def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
-
-
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_ASR(
         model,
         audio_feature,
@@ -281,8 +280,8 @@ def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
 
 def T1_A2(fabric, input_ids, model, text_tokenizer, step,
          snacmodel, out_dir=None):
-
-
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_TA(
         model,
         None,
@@ -326,8 +325,8 @@ def T1_A2(fabric, input_ids, model, text_tokenizer, step,
 
 def T1_T2(fabric, input_ids, model, text_tokenizer, step):
 
-
-
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_TT(
         model,
         None,
@@ -357,13 +356,12 @@ def load_model(ckpt_dir, device):
     config.post_adapter = False
 
     with fabric.init_module(empty_init=False):
-        model = GPT(config
+        model = GPT(config)
 
-
+    model = fabric.setup(model)
     state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
     model.load_state_dict(state_dict, strict=True)
-    model
-    model.eval()
+    model.to(device).eval()
 
     return fabric, model, text_tokenizer, snacmodel, whispermodel
 
@@ -401,7 +399,8 @@ class OmniInference:
         assert os.path.exists(audio_path), f"audio file {audio_path} not found"
         model = self.model
 
-
+        with self.fabric.init_tensor():
+            model.set_kv_cache(batch_size=2)
 
         mel, leng = load_audio(audio_path)
         audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
@@ -419,7 +418,7 @@ class OmniInference:
         list_output = [[] for i in range(8)]
         tokens_A, token_T = next_token_batch(
             model,
-            audio_feature.to(torch.float32).to(device),
+            audio_feature.to(torch.float32).to(model.device),
            input_ids,
            [T - 3, T - 3],
            ["A1T2", "A1T2"],
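For readers following the change: each inference helper now (re)creates the model's KV cache inside fabric.init_tensor() before generating, and load_model wires the model through fabric.setup before loading the checkpoint and calling model.to(device).eval(). Below is a minimal sketch of that Lightning Fabric + litgpt recipe; it mirrors the stock litgpt API rather than this Space's vendored copy, and the checkpoint path, config file name, and precision flag are placeholders, not values taken from this repository.

import lightning as L
import torch
from litgpt import Config
from litgpt.model import GPT
from litgpt.utils import lazy_load

ckpt_dir = "./checkpoint"                                   # placeholder path
device = "cuda" if torch.cuda.is_available() else "cpu"

fabric = L.Fabric(devices=1, precision="bf16-true")         # precision is an assumption

# Instantiate the model under Fabric so parameters land on the managed device/dtype.
config = Config.from_file(ckpt_dir + "/model_config.yaml")  # assumed config file name
with fabric.init_module(empty_init=False):
    model = GPT(config)

model = fabric.setup(model)

# Load the checkpoint lazily and switch to inference mode, mirroring load_model above.
state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
model.load_state_dict(state_dict, strict=True)
model.to(device).eval()

# Allocate a fresh KV cache before a generation call; init_tensor() ensures the
# cache buffers are created on the device/dtype Fabric manages. The batched
# A1_A2 path in this commit uses batch_size=2, the single-stream paths use 1.
with fabric.init_tensor():
    model.set_kv_cache(batch_size=2)

Re-allocating the cache per call, rather than once at load time, keeps the cache sized to the current request's batch size and avoids carrying stale attention state between requests, which appears to be the intent of the added lines.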