# Captured from a Hugging Face Spaces page (Space status shown at capture time: Paused).
import gradio as gr | |
from vid2persona import init | |
from vid2persona.pipeline import vlm | |
from vid2persona.pipeline import llm | |
# Module-level setup, executed once when this file is imported/run.
# NOTE(review): judging by the names, these calls authenticate against Google
# Cloud and load environment-driven settings; the `init` attributes read later
# in this file (gcp_project_id, gcp_project_location, hf_access_token,
# ALLOWED_LLM_FOR_HF_PRO_ACCOUNTS) are presumably populated here — confirm in
# vid2persona/init.py.
init.auth_gcp()
init.get_env_vars()

# Location of the prompt templates; passed to both vlm.get_traits and llm.chat.
prompt_tpl_path = "vid2persona/prompts"
async def extract_traits(video_path):
    """Extract character traits from a video and unlock the chat widgets.

    Args:
        video_path: Value of the video component, forwarded unchanged to
            ``vlm.get_traits``.

    Returns:
        A six-item list, in the order of the outputs wired to the
        "generate traits" button: the traits, an empty chat history, an
        emptied and editable textbox, and three enabled buttons.
    """
    traits = await vlm.get_traits(
        init.gcp_project_id,
        init.gcp_project_location,
        video_path,
        prompt_tpl_path,
    )

    # When the result is wrapped in a "characters" collection, keep only the
    # first entry. An empty collection is left untouched instead of raising
    # IndexError, so the raw response is still shown in the traits panel.
    if "characters" in traits:
        characters = traits["characters"]
        if characters:
            traits = characters[0]

    return [
        traits,
        [],
        gr.Textbox("", interactive=True),
        gr.Button(interactive=True),
        gr.Button(interactive=True),
        gr.Button(interactive=True),
    ]
async def conversation(
    message: str, messages: list, traits: dict,
    model_id: str, max_input_token_length: int,
    max_new_tokens: int, temperature: float,
    top_p: float, top_k: float, repetition_penalty: float,
):
    """Stream the model's reply to a newly submitted user message.

    Args:
        message: Text the user just submitted.
        messages: Chat history so far, as ``[user, assistant]`` pairs.
        traits: Character traits forwarded to ``llm.chat``.
        model_id: Model identifier forwarded to ``llm.chat``.
        max_input_token_length: Forwarded to ``llm.chat``.
        max_new_tokens: Forwarded to ``llm.chat``.
        temperature: Forwarded to ``llm.chat``.
        top_p: Forwarded to ``llm.chat``.
        top_k: Forwarded to ``llm.chat``.
        repetition_penalty: Forwarded to ``llm.chat``.

    Yields:
        ``[history, textbox_value, button, button]`` after every received
        chunk. Both buttons are non-interactive while streaming and become
        interactive in the final item.
    """
    # New turn with an empty assistant slot that is filled chunk by chunk.
    history = messages + [[message, ""]]

    # First update: show the user turn immediately. The textbox still carries
    # the submitted text at this point; it is emptied from the next yield on.
    yield [history, message, gr.Button(interactive=False), gr.Button(interactive=False)]

    reply_stream = llm.chat(
        message, history, traits,
        prompt_tpl_path, model_id,
        max_input_token_length, max_new_tokens,
        temperature, top_p, top_k,
        repetition_penalty, hf_token=init.hf_access_token,
    )
    async for chunk in reply_stream:
        # Grow the assistant half of the last turn in place.
        history[-1][1] += chunk
        yield [history, "", gr.Button(interactive=False), gr.Button(interactive=False)]

    # Streaming finished: hand control back to the user.
    yield [history, "", gr.Button(interactive=True), gr.Button(interactive=True)]
async def regen_conversation(
    messages: list, traits: dict,
    model_id: str, max_input_token_length: int,
    max_new_tokens: int, temperature: float,
    top_p: float, top_k: float, repetition_penalty: float,
):
    """Discard the last assistant reply and stream a fresh one.

    The last ``[user, assistant]`` pair is replaced by ``[user, ""]`` and the
    same user text is sent to ``llm.chat`` again.

    Args:
        messages: Chat history so far, as ``[user, assistant]`` pairs.
        traits: Character traits forwarded to ``llm.chat``.
        model_id: Model identifier forwarded to ``llm.chat``.
        max_input_token_length: Forwarded to ``llm.chat``.
        max_new_tokens: Forwarded to ``llm.chat``.
        temperature: Forwarded to ``llm.chat``.
        top_p: Forwarded to ``llm.chat``.
        top_k: Forwarded to ``llm.chat``.
        repetition_penalty: Forwarded to ``llm.chat``.

    Yields:
        ``[history, "", button, button]`` after every received chunk. Both
        buttons are non-interactive while streaming and become interactive
        in the final item. Nothing is yielded when the history is empty.
    """
    # The regenerate button is enabled as soon as traits are extracted, i.e.
    # possibly before any turn exists. With nothing to regenerate, stop here
    # rather than touching an unbound `message`.
    if not messages:
        return

    message = messages[-1][0]
    # Drop the last pair and re-add it with an empty assistant slot.
    messages = messages[:-1] + [[message, ""]]
    yield [messages, "", gr.Button(interactive=False), gr.Button(interactive=False)]

    async for partial_response in llm.chat(
        message, messages, traits,
        prompt_tpl_path, model_id,
        max_input_token_length, max_new_tokens,
        temperature, top_p, top_k,
        repetition_penalty, hf_token=init.hf_access_token,
    ):
        # Append the chunk to the assistant half of the last turn.
        messages[-1][1] += partial_response
        yield [messages, "", gr.Button(interactive=False), gr.Button(interactive=False)]

    # Streaming finished: re-enable the buttons.
    yield [messages, "", gr.Button(interactive=True), gr.Button(interactive=True)]
# UI definition and event wiring.
# NOTE(review): the source lost its indentation; the nesting below is
# reconstructed from the widget order — confirm the accordion/row placement
# against the intended layout.
with gr.Blocks(css="styles.css", theme=gr.themes.Soft()) as demo:
    gr.Markdown("Vid2Persona", elem_classes=["md-center", "h1-font"])
    gr.Markdown(
        "This project breathes life into video characters by using AI to "
        "describe their personality and then chat with you as them."
    )

    # --- Video upload and trait extraction -------------------------------
    with gr.Column(elem_classes=["group"]):
        with gr.Row():
            video = gr.Video(label="upload short video clip")
            traits = gr.Json(label="extracted traits")

        with gr.Row():
            trait_gen = gr.Button("generate traits")

    # --- Chat area (starts disabled; extract_traits enables the widgets) --
    with gr.Column(elem_classes=["group"]):
        chatbot = gr.Chatbot(
            [], label="chatbot", elem_id="chatbot",
            elem_classes=["chatbot-no-label"],
        )
        with gr.Row():
            clear = gr.Button("clear conversation", interactive=False)
            regen = gr.Button("regenerate the last", interactive=False)
            stop = gr.Button("stop", interactive=False)
        user_input = gr.Textbox(
            placeholder="ask anything", interactive=False,
            elem_classes=["textbox-no-label", "textbox-no-top-bottom-borders"],
        )

        with gr.Accordion("parameters' control pane", open=False):
            model_id = gr.Dropdown(
                choices=init.ALLOWED_LLM_FOR_HF_PRO_ACCOUNTS,
                value="HuggingFaceH4/zephyr-7b-beta",
                label="Model ID",
            )

            with gr.Row():
                max_input_token_length = gr.Slider(minimum=1024, maximum=4096, value=4096, label="max-input-tokens")
                max_new_tokens = gr.Slider(minimum=128, maximum=2048, value=256, label="max-new-tokens")

            with gr.Row():
                temperature = gr.Slider(minimum=0, maximum=2, step=0.1, value=0.6, label="temperature")
                # top-p is a cumulative probability, so the slider is capped
                # at 1 (it previously allowed values up to 2).
                # NOTE(review): some backends reject exactly 0 or 1 — confirm
                # against what llm.chat forwards to.
                top_p = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.9, label="top-p")
                # top-k is a token count: whole-number steps and a range that
                # actually contains the default of 50 (it was 0–2 in 0.1
                # steps, which made the default unreachable once moved).
                top_k = gr.Slider(minimum=1, maximum=100, step=1, value=50, label="top-k")
                repetition_penalty = gr.Slider(minimum=0, maximum=2, step=0.1, value=1.2, label="repetition-penalty")

    with gr.Row():
        gr.Markdown(
            "[![GitHub Repo](https://img.shields.io/badge/GitHub%20Repo-gray?style=for-the-badge&logo=github&link=https://github.com/deep-diver/Vid2Persona)](https://github.com/deep-diver/Vid2Persona) "
            "[![Chansung](https://img.shields.io/badge/Chansung-blue?style=for-the-badge&logo=twitter&link=https://twitter.com/algo_diver)](https://twitter.com/algo_diver) "
            "[![Sayak](https://img.shields.io/badge/Sayak-blue?style=for-the-badge&logo=twitter&link=https://twitter.com/RisingSayak)](https://twitter.com/RisingSayak )",
            elem_id="bottom-md"
        )

    # --- Event wiring ------------------------------------------------------
    # Extract traits, reset the chat and enable the chat widgets.
    trait_gen.click(
        extract_traits,
        [video],
        [traits, chatbot, user_input, clear, regen, stop]
    )

    # Sampling controls shared by both streaming handlers, in the order of
    # their trailing parameters.
    generation_inputs = [
        model_id, max_input_token_length,
        max_new_tokens, temperature,
        top_p, top_k, repetition_penalty,
    ]

    # Submitting the textbox streams a reply into the chatbot.
    conv = user_input.submit(
        conversation,
        [user_input, chatbot, traits, *generation_inputs],
        [chatbot, user_input, clear, regen]
    )

    # Empty the chat and disable clear/regen again.
    clear.click(
        lambda: [
            gr.Chatbot([]),
            gr.Button(interactive=False),
            gr.Button(interactive=False),
        ],
        None, [chatbot, clear, regen]
    )

    # Re-run the last user message.
    conv_regen = regen.click(
        regen_conversation,
        [chatbot, traits, *generation_inputs],
        [chatbot, user_input, clear, regen]
    )

    # Cancel whichever streaming handler is running.
    # NOTE(review): a cancelled handler presumably never reaches its final
    # yield, which would leave clear/regen disabled until the next completed
    # turn — confirm and re-enable them here if that is undesired.
    stop.click(
        None, None, None,
        cancels=[conv, conv_regen]
    )

demo.launch()