# Spaces: Running on Zero  (page-header text captured together with the source; not part of the program)
import copy | |
import os | |
import random | |
import sys | |
import xxhash | |
import gradio as gr | |
import librosa | |
import numpy as np | |
import soundfile as sf | |
import torch | |
import torch.nn.functional as F | |
from accelerate import infer_auto_device_map | |
from datasets import Audio | |
from models.salmonn import SALMONN | |
from safetensors.torch import load, load_model | |
import spaces | |
from torch import nn | |
from transformers import ( | |
AutoModelForCausalLM, | |
AutoProcessor, | |
AutoTokenizer, | |
LlamaForCausalLM, | |
TextIteratorStreamer, | |
WhisperForConditionalGeneration, | |
AutoProcessor, | |
AutoModel, | |
) | |
from transformers.generation import GenerationConfig | |
# When True, the UI hides the model name behind a generic "Model N" label and
# the per-session model order is shuffled on first page load.
anonymous = False

# NOTE: trust_remote_code=True lets the checkpoint repository supply the model
# implementation that gets executed locally.
_DIVA_CHECKPOINT = "WillHeld/DiVA-llama-3-v0-8b"
diva_model = AutoModel.from_pretrained(_DIVA_CHECKPOINT, trust_remote_code=True)

# Used by diva_audio() to round-trip recordings through a 16 kHz Audio feature.
resampler = Audio(sampling_rate=16_000)
def diva_audio(audio_input, do_sample=False, temperature=0.001):
    """Stream DiVA's response for a single recorded clip.

    Args:
        audio_input: ``(sample_rate, samples)`` pair. ``samples`` must be a
            NumPy array (it is converted with ``astype``).
        do_sample: Forwarded to ``diva_model.generate_stream``.
        temperature: Accepted for interface compatibility; it is not
            forwarded to the model by this function.

    Yields:
        Each item produced by ``diva_model.generate_stream``.
    """
    sr, y = audio_input
    y = y.astype(np.float32)

    # Peak-normalise, but leave silent (all-zero) or empty clips untouched:
    # dividing by a zero peak would fill the array with NaN, and np.max raises
    # on an empty array.
    peak = np.max(np.abs(y)) if y.size else 0.0
    if peak > 0:
        y /= peak

    # Round-trip through the module-level `resampler` Audio feature so the
    # model receives the clip at the feature's configured sampling rate.
    a = resampler.decode_example(
        resampler.encode_example({"array": y, "sampling_rate": sr})
    )
    yield from diva_model.generate_stream(
        a["array"], None, do_sample=do_sample, max_new_tokens=256
    )
def transcribe(audio_input, state, model_order):
    """Click handler: stream DiVA's reply while animating the submit button.

    Every yielded item is a 3-tuple matching the handler's outputs
    ``(button update, response textbox update, state)``.

    Args:
        audio_input: Value of the audio component, or ``None`` when nothing
            has been recorded.
        state: Session state value; passed through unchanged.
        model_order: Session model ordering; only consulted to build the
            anonymised label when ``anonymous`` is enabled.

    Yields:
        One tuple per partial response (button disabled, spinner shown), then
        a final tuple that re-enables the button.
    """
    if audio_input is None:
        # This function is a generator, so the update must be yielded: a
        # `return <value>` inside a generator is never produced as an item.
        yield (
            "Click to run inference!",
            "",
            state,
        )
        return

    if anonymous:
        # The label needs a display slot number. Use the 1-based position of
        # model 0 (DiVA) in the session's ordering, falling back to 1.
        # NOTE(review): confirm this is the intended numbering.
        try:
            slot = list(model_order).index(0) + 1
        except (TypeError, ValueError):
            slot = 1
        label = f"Model {slot}"
    else:
        label = model_names[0]

    # NOTE(review): these glyphs look mis-encoded (probably a four-frame
    # Unicode spinner); kept byte-for-byte until the intended characters are
    # confirmed.
    spinners = ["β ", "β ", "β", "β"]

    # Defined up front so the final yield is valid even if the model streams
    # nothing at all.
    response = ""
    for step, partial in enumerate(diva_audio(audio_input)):
        response = gr.Textbox(
            value=partial,
            visible=True,
            label=label,
        )
        spinner = spinners[step % len(spinners)]
        yield (
            gr.Button(
                value=spinner + " Generating Responses " + spinner,
                interactive=False,
                variant="primary",
            ),
            response,
            state,
        )

    # Streaming finished: restore the button so another request can be made.
    yield (
        gr.Button(value="Click to run inference!", interactive=True, variant="primary"),
        response,
        state,
    )
def on_page_load(state, model_order):
    """Page-load handler: greet a fresh session and advance its state to 1.

    Args:
        state: Session state value; ``0`` means the hint has not been shown.
        model_order: Session model ordering (a list); shuffled in place on the
            first load when ``anonymous`` is enabled.

    Returns:
        ``(state, model_order)`` with ``state`` set to ``1`` after the first
        load and left untouched otherwise.
    """
    if state != 0:
        return state, model_order

    gr.Info(
        "Record something you'd say to an AI Assistant! Think about what you usually use Siri, Google Assistant, or ChatGPT for."
    )
    # NOTE(review): the shuffle is treated as part of the first-load branch;
    # the source's indentation was lost, so confirm the intended nesting.
    if anonymous:
        random.shuffle(model_order)
    return 1, model_order
def recording_complete(state):
    """Stop-recording handler: enable the submit button and advance the state.

    Args:
        state: Session state value; ``1`` means the streaming hint has not
            been shown yet.

    Returns:
        ``(button update, state)`` where ``state`` becomes ``2`` after the
        hint is shown and is otherwise unchanged.
    """
    next_state = state
    if state == 1:
        gr.Info(
            "Once you submit your recording, DiVA will stream back a response! This might take a second."
        )
        next_state = 2

    ready_button = gr.Button(
        value="Click to run inference!", interactive=True, variant="primary"
    )
    return ready_button, next_state
def clear_factory(button_id):
    """Build the handler that resets the UI when the recording is cleared.

    Args:
        button_id: Not used by the returned handler.

    Returns:
        A function ``clear(audio_input, model_order)`` returning
        ``(model_order, disabled button update, None, None)``.
    """

    def clear(audio_input, model_order):
        # Disable submission again until a new recording exists; the two
        # trailing None values reset the remaining outputs.
        disabled_button = gr.Button(
            value="Record Audio to Submit!",
            interactive=False,
        )
        return model_order, disabled_button, None, None

    return clear
# Primary palette: one base colour (#820000) with a two-hex-digit alpha suffix
# per shade.
# NOTE(review): c50 and c500 share the same alpha ("7f"); confirm c50 is
# intentional.
_PRIMARY_RGB = "#820000"
_PRIMARY_ALPHA = {
    "c50": "7f",
    "c100": "19",
    "c200": "33",
    "c300": "4c",
    "c400": "66",
    "c500": "7f",
    "c600": "99",
    "c700": "b2",
    "c800": "cc",
    "c900": "e5",
    "c950": "f2",
}

theme = gr.themes.Soft(
    primary_hue=gr.themes.Color(
        **{shade: _PRIMARY_RGB + alpha for shade, alpha in _PRIMARY_ALPHA.items()}
    ),
    secondary_hue="rose",
    neutral_hue="stone",
)
# Display names of the served models; index 0 labels the response textbox.
model_names = [
    "DiVA Llama 3 8B",
]
# Short identifiers, in the same order as model_names.
model_shorthand = [
    "diva",
]
with gr.Blocks(theme=theme) as demo:
    # Per-session values: `state` starts at 0 and is advanced by the handlers
    # below; `model_order` is the model ordering handed to the handlers.
    state = gr.State(0)
    model_order = gr.State([0, 1])

    # Layout: microphone input, submit button, (initially hidden) response box.
    with gr.Row():
        audio_input = gr.Audio(
            sources=["microphone"],
            streaming=False,
            label="Audio Input",
        )
    with gr.Row():
        btn = gr.Button(value="Record Audio to Submit!", interactive=False)
    with gr.Row():
        out1 = gr.Textbox(visible=False)

    # Recording finished: enable the submit button.
    audio_input.stop_recording(
        fn=recording_complete,
        inputs=[state],
        outputs=[btn, state],
    )
    # Recording started: show a disabled placeholder on the button.
    audio_input.start_recording(
        fn=lambda: gr.Button(
            value="Uploading Audio to Cloud", interactive=False, variant="primary"
        ),
        inputs=None,
        outputs=btn,
    )
    # Submit: stream the model response into the textbox.
    btn.click(
        fn=transcribe,
        inputs=[audio_input, state, model_order],
        outputs=[btn, out1, state],
    )
    # Recording cleared: reset button, audio widget and response box.
    audio_input.clear(
        fn=clear_factory(None),
        inputs=[audio_input, model_order],
        outputs=[model_order, btn, audio_input, out1],
    )
    # Page load: show the first-visit hint and initialise the session state.
    demo.load(
        fn=on_page_load,
        inputs=[state, model_order],
        outputs=[state, model_order],
    )

demo.launch(share=True)