gpt-omni committed
Commit 31c2333
Parents: 7d577d3 399ac1f

fix conflict

Files changed (1):
  inference.py +18 -19
inference.py CHANGED
@@ -2,7 +2,6 @@ import os
 import lightning as L
 import torch
 import time
-import spaces
 from snac import SNAC
 from litgpt import Tokenizer
 from litgpt.utils import (
@@ -148,8 +147,8 @@ def load_audio(path):
 
 def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
                 snacmodel, out_dir=None):
-
-    model.set_kv_cache(batch_size=2)
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=2)
     tokenlist = generate_TA_BATCH(
         model,
         audio_feature,
@@ -192,8 +191,8 @@ def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, s
 
 
 def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
-
-    model.set_kv_cache(batch_size=1)
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_AT(
         model,
         audio_feature,
@@ -215,8 +214,8 @@ def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
 
 def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
           snacmodel, out_dir=None):
-
-    model.set_kv_cache(batch_size=1)
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_AA(
         model,
         audio_feature,
@@ -257,8 +256,8 @@ def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
 
 
 def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
-
-    model.set_kv_cache(batch_size=1)
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_ASR(
         model,
         audio_feature,
@@ -281,8 +280,8 @@ def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
 
 def T1_A2(fabric, input_ids, model, text_tokenizer, step,
           snacmodel, out_dir=None):
-
-    model.set_kv_cache(batch_size=1)
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_TA(
         model,
         None,
@@ -326,8 +325,8 @@ def T1_A2(fabric, input_ids, model, text_tokenizer, step,
 
 def T1_T2(fabric, input_ids, model, text_tokenizer, step):
 
-
-    model.set_kv_cache(batch_size=1)
+    with fabric.init_tensor():
+        model.set_kv_cache(batch_size=1)
     tokenlist = generate_TT(
         model,
         None,
@@ -357,13 +356,12 @@ def load_model(ckpt_dir, device):
     config.post_adapter = False
 
     with fabric.init_module(empty_init=False):
-        model = GPT(config, device=device)
+        model = GPT(config)
 
-    # model = fabric.setup(model)
+    model = fabric.setup(model)
     state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
     model.load_state_dict(state_dict, strict=True)
-    model = model.to(device)
-    model.eval()
+    model.to(device).eval()
 
     return fabric, model, text_tokenizer, snacmodel, whispermodel
 
@@ -401,7 +399,8 @@ class OmniInference:
         assert os.path.exists(audio_path), f"audio file {audio_path} not found"
         model = self.model
 
-        model.set_kv_cache(batch_size=2)
+        with self.fabric.init_tensor():
+            model.set_kv_cache(batch_size=2)
 
         mel, leng = load_audio(audio_path)
         audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
@@ -419,7 +418,7 @@ class OmniInference:
         list_output = [[] for i in range(8)]
         tokens_A, token_T = next_token_batch(
            model,
-           audio_feature.to(torch.float32).to(device),
+           audio_feature.to(torch.float32).to(model.device),
            input_ids,
            [T - 3, T - 3],
            ["A1T2", "A1T2"],
 