WillHeld commited on
Commit
a5c279e
β€’
1 Parent(s): e21d2a8
Files changed (1) hide show
  1. app.py +33 -275
app.py CHANGED
@@ -1,259 +1,82 @@
1
  import copy
2
  import os
3
  import random
4
- import shutil
5
  import sys
6
- from pathlib import Path
7
 
 
8
  import gradio as gr
9
  import librosa
10
  import numpy as np
11
  import soundfile as sf
12
- import spaces
13
  import torch
14
  import torch.nn.functional as F
15
  from accelerate import infer_auto_device_map
16
  from datasets import Audio
17
- from huggingface_hub import CommitScheduler, delete_file, hf_hub_download
18
  from safetensors.torch import load, load_model
19
- from tinydb import TinyDB
20
  from torch import nn
21
  from transformers import (
22
- AutoModel,
23
  AutoModelForCausalLM,
24
  AutoProcessor,
25
  AutoTokenizer,
26
  LlamaForCausalLM,
27
  TextIteratorStreamer,
28
  WhisperForConditionalGeneration,
 
 
29
  )
30
  from transformers.generation import GenerationConfig
31
- import spaces
32
-
33
- # Set an environment variable
34
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
35
 
36
- model_id = "meta-llama/Meta-Llama-3-8B"
37
- # Load the tokenizer and model
38
- AutoTokenizer.from_pretrained(model_id)
39
- AutoModelForCausalLM.from_pretrained(model_id)
40
-
41
- #from models.salmonn import SALMONN
42
-
43
- DB_PATH = "user_study.json"
44
- DB_DATASET_ID = "WillHeld/DiVAVotes"
45
-
46
- # Download existing DB
47
- if not os.path.isfile(DB_PATH):
48
- print("Downloading DB...")
49
- try:
50
- cache_path = hf_hub_download(
51
- repo_id=DB_DATASET_ID, repo_type="dataset", filename=DB_PATH
52
- )
53
- shutil.copyfile(cache_path, DB_PATH)
54
- print("Downloaded DB")
55
- except Exception as e:
56
- print("Error while downloading DB:", e)
57
-
58
- db = TinyDB(DB_PATH)
59
-
60
- # Sync local DB with remote repo every 5 minute (only if a change is detected)
61
- scheduler = CommitScheduler(
62
- repo_id=DB_DATASET_ID,
63
- repo_type="dataset",
64
- folder_path=Path(DB_PATH).parent,
65
- every=5,
66
- allow_patterns=DB_PATH,
67
- )
68
-
69
- tokenizer = AutoTokenizer.from_pretrained("WillHeld/via-llama")
70
- prefix = torch.tensor([128000, 128006, 882, 128007, 271]).to("cuda")
71
- pre_user_suffix = torch.tensor([271]).to("cuda")
72
- final_header = torch.tensor([128009, 128006, 78191, 128007, 271]).to("cuda")
73
- cache = None
74
  anonymous = False
75
 
76
- resampler = Audio(sampling_rate=16_000)
77
-
78
-
79
- qwen_tokenizer = AutoTokenizer.from_pretrained(
80
- "Qwen/Qwen-Audio-Chat", trust_remote_code=True
81
- )
82
- qwen_model = AutoModelForCausalLM.from_pretrained(
83
- "Qwen/Qwen-Audio-Chat",
84
- device_map="auto",
85
- trust_remote_code=True,
86
- torch_dtype=torch.float16,
87
- ).eval()
88
-
89
- qwen_model.generation_config = GenerationConfig.from_pretrained(
90
- "Qwen/Qwen-Audio-Chat",
91
- trust_remote_code=True,
92
- do_sample=False,
93
- top_k=50,
94
- top_p=1.0,
95
- )
96
-
97
-
98
- # salmonn_model = SALMONN(
99
- # ckpt="./SALMONN_PATHS/salmonn_v1.pth",
100
- # whisper_path="./SALMONN_PATHS/whisper-large-v2",
101
- # beats_path="./SALMONN_PATHS/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt",
102
- # vicuna_path="./SALMONN_PATHS/vicuna-13b-v1.1",
103
- # low_resource=False,
104
- # device="cuda",
105
- # )
106
- # salmonn_tokenizer = salmonn_model.llama_tokenizer
107
-
108
-
109
- diva = AutoModel.from_pretrained("WillHeld/DiVA-llama-3-v0-8b", trust_remote_code=True, speech_encoder_device="cuda")
110
-
111
- # @spaces.GPU
112
- # @torch.no_grad
113
- # def salmonn_fwd(audio_input, prompt, do_sample=False, temperature=0.001):
114
- # if audio_input == None:
115
- # return ""
116
- # sr, y = audio_input
117
- # y = y.astype(np.float32)
118
- # y /= np.max(np.abs(y))
119
- # a = resampler.decode_example(
120
- # resampler.encode_example({"array": y, "sampling_rate": sr})
121
- # )
122
- # sf.write("tmp.wav", a["array"], a["sampling_rate"], format="wav")
123
- # streamer = TextIteratorStreamer(salmonn_tokenizer)
124
- # with torch.cuda.amp.autocast(dtype=torch.float16):
125
- # llm_message = salmonn_model.generate(
126
- # wav_path="tmp.wav",
127
- # prompt=prompt,
128
- # do_sample=False,
129
- # top_p=1.0,
130
- # temperature=0.0,
131
- # device="cuda:0",
132
- # streamer=streamer,
133
- # )
134
-
135
- # response = ""
136
- # for new_tokens in streamer:
137
- # response += new_tokens
138
- # yield response.replace("</s>", "")
139
 
 
140
 
141
  @spaces.GPU
142
  @torch.no_grad
143
- def qwen_audio(audio_input, prompt, do_sample=False, temperature=0.001):
144
- if audio_input == None:
145
- return ""
146
  sr, y = audio_input
 
147
  y = y.astype(np.float32)
148
  y /= np.max(np.abs(y))
149
  a = resampler.decode_example(
150
  resampler.encode_example({"array": y, "sampling_rate": sr})
151
  )
152
- sf.write("tmp.wav", a["array"], a["sampling_rate"], format="wav")
153
- query = qwen_tokenizer.from_list_format([{"audio": "tmp.wav"}, {"text": prompt}])
154
-
155
- response, history = qwen_model.chat(
156
- qwen_tokenizer,
157
- query=query,
158
- system="You are a helpful assistant.",
159
- history=None,
160
- )
161
- return response
162
-
163
 
164
  @spaces.GPU
165
- @torch.no_grad
166
- def via(audio_input, prompt, do_sample=False, temperature=0.001):
167
- if audio_input == None:
168
- return ""
169
- sr, y = audio_input
170
- y = y.astype(np.float32)
171
- y /= np.max(np.abs(y))
172
- a = resampler.decode_example(
173
- resampler.encode_example({"array": y, "sampling_rate": sr})
174
- )
175
-
176
- audio = a["array"]
177
-
178
- yield from diva.generate_stream(audio, prompt)
179
-
180
-
181
- def transcribe(audio_input, text_prompt, state, model_order):
182
- yield (
183
- gr.Button(
184
- value="Waiting in queue for GPU time...",
185
- interactive=False,
186
- variant="primary",
187
- ),
188
- "",
189
- "",
190
- "",
191
- gr.Button(visible=False),
192
- gr.Button(visible=False),
193
- gr.Button(visible=False),
194
- state,
195
- )
196
  if audio_input == None:
197
  return (
198
  "",
199
- "",
200
- "",
201
- gr.Button(visible=False),
202
- gr.Button(visible=False),
203
- gr.Button(visible=False),
204
  state,
205
  )
206
 
207
- def gen_from_via():
208
- via_resp = via(audio_input, text_prompt)
209
- for resp in via_resp:
210
- v_resp = gr.Textbox(
211
  value=resp,
212
  visible=True,
213
  label=model_names[0] if not anonymous else f"Model {order}",
214
  )
215
- yield (v_resp, s_resp, q_resp)
216
-
217
- # def gen_from_salmonn():
218
- # salmonn_resp = salmonn_fwd(audio_input, text_prompt)
219
- # for resp in salmonn_resp:
220
- # s_resp = gr.Textbox(
221
- # value=resp,
222
- # visible=True,
223
- # label=model_names[1] if not anonymous else f"Model {order}",
224
- # )
225
- # yield (v_resp, s_resp, q_resp)
226
-
227
- def gen_from_qwen():
228
- qwen_resp = qwen_audio(audio_input, text_prompt)
229
- q_resp = gr.Textbox(
230
- value=qwen_resp,
231
- visible=True,
232
- label=model_names[2] if not anonymous else f"Model {order}",
233
- )
234
- yield (v_resp, s_resp, q_resp)
235
 
236
  spinner_id = 0
237
  spinners = ["◐ ", "β—“ ", "β—‘", "β—’"]
238
- initial_responses = [("", "", "")]
239
  resp_generators = [
240
- gen_from_via(),
241
- # gen_from_salmonn(),
242
- gen_from_qwen(),
243
  ]
244
  order = -1
245
- resp_generators = [
246
- resp_generators[model_order[0]],
247
- #resp_generators[model_order[1]],
248
- resp_generators[model_order[1]],
249
- ]
250
  for generator in [initial_responses, *resp_generators]:
251
  order += 1
252
  for resps in generator:
253
- v_resp, s_resp, q_resp = resps
254
  resp_1 = resps[model_order[0]]
255
- resp_2 = s_resp #resps[model_order[1]]
256
- resp_3 = resps[model_order[1]]
257
  spinner = spinners[spinner_id]
258
  spinner_id = (spinner_id + 1) % 4
259
  yield (
@@ -263,11 +86,6 @@ def transcribe(audio_input, text_prompt, state, model_order):
263
  variant="primary",
264
  ),
265
  resp_1,
266
- resp_2,
267
- resp_3,
268
- gr.Button(visible=False),
269
- gr.Button(visible=False),
270
- gr.Button(visible=False),
271
  state,
272
  )
273
  yield (
@@ -276,10 +94,6 @@ def transcribe(audio_input, text_prompt, state, model_order):
276
  ),
277
  resp_1,
278
  resp_2,
279
- resp_3,
280
- gr.Button(visible=True),
281
- gr.Button(visible=False),
282
- gr.Button(visible=True),
283
  responses_complete(state),
284
  )
285
 
@@ -287,7 +101,7 @@ def transcribe(audio_input, text_prompt, state, model_order):
287
  def on_page_load(state, model_order):
288
  if state == 0:
289
  gr.Info(
290
- "Record what you want to say to your AI Assistant! All Audio recordings are stored only temporarily and will be erased as soon as you exit this page."
291
  )
292
  state = 1
293
  if anonymous:
@@ -298,53 +112,27 @@ def on_page_load(state, model_order):
298
  def recording_complete(state):
299
  if state == 1:
300
  gr.Info(
301
- "Submit your recording to get responses from all three models! You can also influence the model responses with an optional prompt."
302
  )
303
  state = 2
304
  return (
305
  gr.Button(
306
- value="Click to compare models!", interactive=True, variant="primary"
307
  ),
308
  state,
309
  )
310
 
311
 
312
- def responses_complete(state):
313
- if state == 2:
314
- gr.Info(
315
- "Give us your feedback! Mark which model gave you the best response so we can understand the quality of these different voice assistant models. NOTE: This will save an (irreversible) hash of your inputs to deduplicate any repeated votes."
316
- )
317
- state = 3
318
- return state
319
-
320
-
321
  def clear_factory(button_id):
322
- def clear(audio_input, text_prompt, model_order):
323
- if button_id != None:
324
- sr, y = audio_input
325
- with scheduler.lock:
326
- db.insert(
327
- {
328
- "audio_hash": hash(str(y)),
329
- "text_prompt": hash(text_prompt),
330
- "best": model_shorthand[model_order[button_id]],
331
- }
332
- )
333
- if anonymous:
334
- random.shuffle(model_order)
335
  return (
336
  model_order,
337
  gr.Button(
338
  value="Record Audio to Submit!",
339
  interactive=False,
340
  ),
341
- gr.Button(visible=False),
342
- gr.Button(visible=False),
343
- gr.Button(visible=False),
344
  None,
345
- gr.Textbox(visible=False),
346
- gr.Textbox(visible=False),
347
- gr.Textbox(visible=False),
348
  )
349
 
350
  return clear
@@ -368,9 +156,8 @@ theme = gr.themes.Soft(
368
  neutral_hue="stone",
369
  )
370
 
371
-
372
- model_names = ["Llama 3 DiVA", "SALMONN", "Qwen Audio"]
373
- model_shorthand = ["via", "salmonn", "qwen"]
374
  with gr.Blocks(theme=theme) as demo:
375
  state = gr.State(0)
376
  model_order = gr.State([0, 1])
@@ -378,26 +165,12 @@ with gr.Blocks(theme=theme) as demo:
378
  audio_input = gr.Audio(
379
  sources=["microphone"], streaming=False, label="Audio Input"
380
  )
381
- with gr.Row():
382
- prompt = gr.Textbox(
383
- value="",
384
- label="Text Prompt",
385
- placeholder="Optional: Additional text prompt to influence how the model responds to your speech. e.g. 'Respond in a Haiku style.' or 'Translate the input to Arabic'",
386
- )
387
 
388
  with gr.Row():
389
  btn = gr.Button(value="Record Audio to Submit!", interactive=False)
390
 
391
  with gr.Row():
392
- with gr.Column(scale=1):
393
- out1 = gr.Textbox(visible=False)
394
- best1 = gr.Button(value="This response is best", visible=False)
395
- with gr.Column(scale=1):
396
- out2 = gr.Textbox(visible=False)
397
- best2 = gr.Button(value="This response is best", visible=False)
398
- with gr.Column(scale=1):
399
- out3 = gr.Textbox(visible=False)
400
- best3 = gr.Button(value="This response is best", visible=False)
401
 
402
  audio_input.stop_recording(
403
  recording_complete,
@@ -413,31 +186,16 @@ with gr.Blocks(theme=theme) as demo:
413
  )
414
  btn.click(
415
  fn=transcribe,
416
- inputs=[audio_input, prompt, state, model_order],
417
- outputs=[btn, out1, out2, out3, best1, best2, best3, state],
418
- )
419
- best1.click(
420
- fn=clear_factory(0),
421
- inputs=[audio_input, prompt, model_order],
422
- outputs=[model_order, btn, best1, best2, best3, audio_input, out1, out2, out3],
423
- )
424
- best2.click(
425
- fn=clear_factory(1),
426
- inputs=[audio_input, prompt, model_order],
427
- outputs=[model_order, btn, best1, best2, best3, audio_input, out1, out2, out3],
428
- )
429
- best3.click(
430
- fn=clear_factory(2),
431
- inputs=[audio_input, prompt, model_order],
432
- outputs=[model_order, btn, best1, best2, best3, audio_input, out1, out2, out3],
433
  )
434
  audio_input.clear(
435
  clear_factory(None),
436
- [audio_input, prompt, model_order],
437
- [model_order, btn, best1, best2, best3, audio_input, out1, out2, out3],
438
  )
439
  demo.load(
440
  fn=on_page_load, inputs=[state, model_order], outputs=[state, model_order]
441
  )
442
 
443
- demo.launch(share=True)
 
1
  import copy
2
  import os
3
  import random
 
4
  import sys
 
5
 
6
+ import xxhash
7
  import gradio as gr
8
  import librosa
9
  import numpy as np
10
  import soundfile as sf
 
11
  import torch
12
  import torch.nn.functional as F
13
  from accelerate import infer_auto_device_map
14
  from datasets import Audio
15
+ from models.salmonn import SALMONN
16
  from safetensors.torch import load, load_model
17
+ import spaces
18
  from torch import nn
19
  from transformers import (
 
20
  AutoModelForCausalLM,
21
  AutoProcessor,
22
  AutoTokenizer,
23
  LlamaForCausalLM,
24
  TextIteratorStreamer,
25
  WhisperForConditionalGeneration,
26
+ AutoProcessor,
27
+ AutoModel
28
  )
29
  from transformers.generation import GenerationConfig
 
 
 
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  anonymous = False
32
 
33
+ diva_model = AutoModel.from_pretrained("WillHeld/DiVA-llama-3-v0-8b", trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ resampler = Audio(sampling_rate=16_000)
36
 
37
  @spaces.GPU
38
  @torch.no_grad
39
+ def diva_audio(audio_input, do_sample=False, temperature=0.001):
 
 
40
  sr, y = audio_input
41
+ x = xxhash.xxh32(bytes(y)).hexdigest()
42
  y = y.astype(np.float32)
43
  y /= np.max(np.abs(y))
44
  a = resampler.decode_example(
45
  resampler.encode_example({"array": y, "sampling_rate": sr})
46
  )
47
+ yield from diva_model.generate_stream(a["array"], None, do_sample=do_sample, max_new_tokens = 256)
 
 
 
 
 
 
 
 
 
 
48
 
49
  @spaces.GPU
50
+ def transcribe(audio_input, state, model_order):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  if audio_input == None:
52
  return (
53
  "",
 
 
 
 
 
54
  state,
55
  )
56
 
57
+ def gen_from_diva():
58
+ diva_resp = diva_audio(audio_input)
59
+ for resp in diva_resp:
60
+ d_resp = gr.Textbox(
61
  value=resp,
62
  visible=True,
63
  label=model_names[0] if not anonymous else f"Model {order}",
64
  )
65
+ yield (d_resp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  spinner_id = 0
68
  spinners = ["◐ ", "β—“ ", "β—‘", "β—’"]
69
+ initial_responses = [("")]
70
  resp_generators = [
71
+ gen_from_diva(),
 
 
72
  ]
73
  order = -1
74
+ resp_generators = [resp_generators[model_order[0]]]
 
 
 
 
75
  for generator in [initial_responses, *resp_generators]:
76
  order += 1
77
  for resps in generator:
78
+ s_resp, q_resp = resps
79
  resp_1 = resps[model_order[0]]
 
 
80
  spinner = spinners[spinner_id]
81
  spinner_id = (spinner_id + 1) % 4
82
  yield (
 
86
  variant="primary",
87
  ),
88
  resp_1,
 
 
 
 
 
89
  state,
90
  )
91
  yield (
 
94
  ),
95
  resp_1,
96
  resp_2,
 
 
 
 
97
  responses_complete(state),
98
  )
99
 
 
101
  def on_page_load(state, model_order):
102
  if state == 0:
103
  gr.Info(
104
+ "Record something you'd say to an AI Assistant! Think about what you usually use Siri, Google Assistant, or ChatGPT for."
105
  )
106
  state = 1
107
  if anonymous:
 
112
  def recording_complete(state):
113
  if state == 1:
114
  gr.Info(
115
+ "Once you submit your recording, DiVA will stream back a response! This might take a second."
116
  )
117
  state = 2
118
  return (
119
  gr.Button(
120
+ value="Click to run inference!", interactive=True, variant="primary"
121
  ),
122
  state,
123
  )
124
 
125
 
 
 
 
 
 
 
 
 
 
126
  def clear_factory(button_id):
127
+ def clear(audio_input, model_order):
 
 
 
 
 
 
 
 
 
 
 
 
128
  return (
129
  model_order,
130
  gr.Button(
131
  value="Record Audio to Submit!",
132
  interactive=False,
133
  ),
 
 
 
134
  None,
135
+ None
 
 
136
  )
137
 
138
  return clear
 
156
  neutral_hue="stone",
157
  )
158
 
159
+ model_names = ["DiVA Llama 3 8B"]
160
+ model_shorthand = ["diva"]
 
161
  with gr.Blocks(theme=theme) as demo:
162
  state = gr.State(0)
163
  model_order = gr.State([0, 1])
 
165
  audio_input = gr.Audio(
166
  sources=["microphone"], streaming=False, label="Audio Input"
167
  )
 
 
 
 
 
 
168
 
169
  with gr.Row():
170
  btn = gr.Button(value="Record Audio to Submit!", interactive=False)
171
 
172
  with gr.Row():
173
+ out1 = gr.Textbox(visible=False)
 
 
 
 
 
 
 
 
174
 
175
  audio_input.stop_recording(
176
  recording_complete,
 
186
  )
187
  btn.click(
188
  fn=transcribe,
189
+ inputs=[audio_input, state, model_order],
190
+ outputs=[btn, out1, state],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  )
192
  audio_input.clear(
193
  clear_factory(None),
194
+ [audio_input, model_order],
195
+ [model_order, btn, audio_input, out1],
196
  )
197
  demo.load(
198
  fn=on_page_load, inputs=[state, model_order], outputs=[state, model_order]
199
  )
200
 
201
+ demo.launch(share=True)