WillHeld committed
Commit 2430bc8 • 1 Parent(s): f3ba0cb

Tmp demo file

Files changed (1)
  1. demo.py +401 -0
demo.py ADDED
@@ -0,0 +1,401 @@
import copy
import os
import random
import sys

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from accelerate import infer_auto_device_map
from datasets import Audio
from models.salmonn import SALMONN
from safetensors.torch import load, load_model
from tinydb import TinyDB
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoModel,
    AutoTokenizer,
    LlamaForCausalLM,
    TextIteratorStreamer,
    WhisperForConditionalGeneration,
)
from transformers.generation import GenerationConfig

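# Module-level setup: the tokenizer for the WillHeld/via-llama checkpoint, Llama 3
# chat-template special-token id tensors, and a 16 kHz resampler for microphone audio.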
tokenizer = AutoTokenizer.from_pretrained("WillHeld/via-llama")
prefix = torch.tensor([128000, 128006, 882, 128007, 271]).to("cuda:0")
pre_user_suffix = torch.tensor([271]).to("cuda:0")
final_header = torch.tensor([128009, 128006, 78191, 128007, 271]).to("cuda:0")
cache = None
anonymous = False

resampler = Audio(sampling_rate=16_000)


qwen_tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen-Audio-Chat", trust_remote_code=True
)
qwen_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-Audio-Chat",
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.float16,
).eval()

qwen_model.generation_config = GenerationConfig.from_pretrained(
    "Qwen/Qwen-Audio-Chat",
    trust_remote_code=True,
    do_sample=False,
    top_k=50,
    top_p=1.0,
)


salmonn_model = SALMONN(
    ckpt="./SALMONN_PATHS/salmonn_v1.pth",
    whisper_path="./SALMONN_PATHS/whisper-large-v2",
    beats_path="./SALMONN_PATHS/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt",
    vicuna_path="./SALMONN_PATHS/vicuna-13b-v1.1",
    low_resource=False,
    device="cuda:0",
)
salmonn_tokenizer = salmonn_model.llama_tokenizer


diva = AutoModel.from_pretrained("WillHeld/DiVA-llama-3-v0-8b", trust_remote_code=True)

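# Per-model inference helpers: each takes Gradio microphone audio as a
# (sample_rate, samples) tuple plus a text prompt, peak-normalizes and resamples it
# to 16 kHz, then returns or streams the model's text response.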
@torch.no_grad
def salmonn_fwd(audio_input, prompt, do_sample=False, temperature=0.001):
    if audio_input is None:
        return ""
    sr, y = audio_input
    y = y.astype(np.float32)
    y /= np.max(np.abs(y))
    a = resampler.decode_example(
        resampler.encode_example({"array": y, "sampling_rate": sr})
    )
    sf.write("tmp.wav", a["array"], a["sampling_rate"], format="wav")
    streamer = TextIteratorStreamer(salmonn_tokenizer)
    with torch.cuda.amp.autocast(dtype=torch.float16):
        llm_message = salmonn_model.generate(
            wav_path="tmp.wav",
            prompt=prompt,
            do_sample=False,
            top_p=1.0,
            temperature=0.0,
            device="cuda:0",
            streamer=streamer,
        )

    response = ""
    for new_tokens in streamer:
        response += new_tokens
        yield response.replace("</s>", "")

@torch.no_grad
def qwen_audio(audio_input, prompt, do_sample=False, temperature=0.001):
    if audio_input is None:
        return ""
    sr, y = audio_input
    y = y.astype(np.float32)
    y /= np.max(np.abs(y))
    a = resampler.decode_example(
        resampler.encode_example({"array": y, "sampling_rate": sr})
    )
    sf.write("tmp.wav", a["array"], a["sampling_rate"], format="wav")
    query = qwen_tokenizer.from_list_format([{"audio": "tmp.wav"}, {"text": prompt}])

    response, history = qwen_model.chat(
        qwen_tokenizer,
        query=query,
        system="You are a helpful assistant.",
        history=None,
    )
    return response

@torch.no_grad
def via(audio_input, prompt, do_sample=False, temperature=0.001):
    if audio_input is None:
        return ""
    sr, y = audio_input
    y = y.astype(np.float32)
    y /= np.max(np.abs(y))
    a = resampler.decode_example(
        resampler.encode_example({"array": y, "sampling_rate": sr})
    )

    audio = a["array"]

    yield from diva.generate_stream(audio, prompt)

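# Submit handler: streams all three models' responses into the three output boxes.
# The nested generators close over transcribe's locals (v_resp, s_resp, q_resp), so
# each stream re-emits the latest text from the other two; model_order controls which
# model appears in which column.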
def transcribe(audio_input, text_prompt, state, model_order):
    yield (
        gr.Button(
            value="Waiting in queue for GPU time...",
            interactive=False,
            variant="primary",
        ),
        "",
        "",
        "",
        gr.Button(visible=False),
        gr.Button(visible=False),
        gr.Button(visible=False),
        state,
    )
    if audio_input is None:
        return (
            "",
            "",
            "",
            gr.Button(visible=False),
            gr.Button(visible=False),
            gr.Button(visible=False),
            state,
        )

    def gen_from_via():
        via_resp = via(audio_input, text_prompt)
        for resp in via_resp:
            v_resp = gr.Textbox(
                value=resp,
                visible=True,
                label=model_names[0] if not anonymous else f"Model {order}",
            )
            yield (v_resp, s_resp, q_resp)

    def gen_from_salmonn():
        salmonn_resp = salmonn_fwd(audio_input, text_prompt)
        for resp in salmonn_resp:
            s_resp = gr.Textbox(
                value=resp,
                visible=True,
                label=model_names[1] if not anonymous else f"Model {order}",
            )
            yield (v_resp, s_resp, q_resp)

    def gen_from_qwen():
        qwen_resp = qwen_audio(audio_input, text_prompt)
        q_resp = gr.Textbox(
            value=qwen_resp,
            visible=True,
            label=model_names[2] if not anonymous else f"Model {order}",
        )
        yield (v_resp, s_resp, q_resp)

    spinner_id = 0
    spinners = ["◐ ", "◓ ", "◑", "◒"]
    initial_responses = [("", "", "")]
    resp_generators = [
        gen_from_via(),
        gen_from_salmonn(),
        gen_from_qwen(),
    ]
    order = -1
    resp_generators = [
        resp_generators[model_order[0]],
        resp_generators[model_order[1]],
        resp_generators[model_order[2]],
    ]
    for generator in [initial_responses, *resp_generators]:
        order += 1
        for resps in generator:
            v_resp, s_resp, q_resp = resps
            resp_1 = resps[model_order[0]]
            resp_2 = resps[model_order[1]]
            resp_3 = resps[model_order[2]]
            spinner = spinners[spinner_id]
            spinner_id = (spinner_id + 1) % 4
            yield (
                gr.Button(
                    value=spinner + " Generating Responses " + spinner,
                    interactive=False,
                    variant="primary",
                ),
                resp_1,
                resp_2,
                resp_3,
                gr.Button(visible=False),
                gr.Button(visible=False),
                gr.Button(visible=False),
                state,
            )
    yield (
        gr.Button(
            value="Click to compare models!", interactive=True, variant="primary"
        ),
        resp_1,
        resp_2,
        resp_3,
        gr.Button(visible=True),
        gr.Button(visible=True),
        gr.Button(visible=True),
        responses_complete(state),
    )

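# User-study flow callbacks: gr.Info prompts walk the user through record -> submit
# -> vote (tracked in `state`), and clear_factory records the winning model in TinyDB
# before resetting the interface for the next round.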
def on_page_load(state, model_order):
    if state == 0:
        gr.Info(
            "Record what you want to say to your AI Assistant! All Audio recordings are stored only temporarily and will be erased as soon as you exit this page."
        )
        state = 1
        if anonymous:
            random.shuffle(model_order)
    return state, model_order


def recording_complete(state):
    if state == 1:
        gr.Info(
            "Submit your recording to get responses from all three models! You can also influence the model responses with an optional prompt."
        )
        state = 2
    return (
        gr.Button(
            value="Click to compare models!", interactive=True, variant="primary"
        ),
        state,
    )


def responses_complete(state):
    if state == 2:
        gr.Info(
            "Give us your feedback! Mark which model gave you the best response so we can understand the quality of these different voice assistant models."
        )
        state = 3
    return state


def clear_factory(button_id):
    def clear(audio_input, text_prompt, model_order):
        if button_id is not None:
            sr, y = audio_input
            db.insert(
                {
                    "audio_hash": hash(str(y)),
                    "text_prompt": text_prompt,
                    "best": model_shorthand[model_order[button_id]],
                }
            )
        if anonymous:
            random.shuffle(model_order)
        return (
            model_order,
            gr.Button(
                value="Record Audio to Submit!",
                interactive=False,
            ),
            gr.Button(visible=False),
            gr.Button(visible=False),
            gr.Button(visible=False),
            None,
            gr.Textbox(visible=False),
            gr.Textbox(visible=False),
            gr.Textbox(visible=False),
        )

    return clear

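# Theme, vote database, and the Gradio Blocks UI: microphone input, optional text
# prompt, and three response columns, each with a "This response is best" vote button.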
theme = gr.themes.Soft(
    primary_hue=gr.themes.Color(
        c100="#82000019",
        c200="#82000033",
        c300="#8200004c",
        c400="#82000066",
        c50="#8200007f",
        c500="#8200007f",
        c600="#82000099",
        c700="#820000b2",
        c800="#820000cc",
        c900="#820000e5",
        c950="#820000f2",
    ),
    secondary_hue="rose",
    neutral_hue="stone",
)

db = TinyDB("user_study.json")

model_names = ["Llama 3 DiVA", "SALMONN", "Qwen Audio"]
model_shorthand = ["via", "salmonn", "qwen"]
with gr.Blocks(theme=theme) as demo:
    state = gr.State(0)
    model_order = gr.State([0, 1, 2])
    with gr.Row():
        audio_input = gr.Audio(
            sources=["microphone"], streaming=False, label="Audio Input"
        )
    with gr.Row():
        prompt = gr.Textbox(
            value="",
            label="Text Prompt",
            placeholder="Optional: Additional text prompt to influence how the model responds to your speech. e.g. 'Respond in a Haiku style.'",
        )

    with gr.Row():
        btn = gr.Button(value="Record Audio to Submit!", interactive=False)

    with gr.Row():
        with gr.Column(scale=1):
            out1 = gr.Textbox(visible=False)
            best1 = gr.Button(value="This response is best", visible=False)
        with gr.Column(scale=1):
            out2 = gr.Textbox(visible=False)
            best2 = gr.Button(value="This response is best", visible=False)
        with gr.Column(scale=1):
            out3 = gr.Textbox(visible=False)
            best3 = gr.Button(value="This response is best", visible=False)

    audio_input.stop_recording(
        recording_complete,
        [state],
        [btn, state],
    )
    audio_input.start_recording(
        lambda: gr.Button(
            value="Uploading Audio to Cloud", interactive=False, variant="primary"
        ),
        None,
        btn,
    )
    btn.click(
        fn=transcribe,
        inputs=[audio_input, prompt, state, model_order],
        outputs=[btn, out1, out2, out3, best1, best2, best3, state],
    )
    best1.click(
        fn=clear_factory(0),
        inputs=[audio_input, prompt, model_order],
        outputs=[model_order, btn, best1, best2, best3, audio_input, out1, out2, out3],
    )
    best2.click(
        fn=clear_factory(1),
        inputs=[audio_input, prompt, model_order],
        outputs=[model_order, btn, best1, best2, best3, audio_input, out1, out2, out3],
    )
    best3.click(
        fn=clear_factory(2),
        inputs=[audio_input, prompt, model_order],
        outputs=[model_order, btn, best1, best2, best3, audio_input, out1, out2, out3],
    )
    audio_input.clear(
        clear_factory(None),
        [audio_input, prompt, model_order],
        [model_order, btn, best1, best2, best3, audio_input, out1, out2, out3],
    )
    demo.load(
        fn=on_page_load, inputs=[state, model_order], outputs=[state, model_order]
    )

demo.launch(share=True)