# NOTE(review): the original lines here ("Spaces:" / "Runtime error" x2) appear
# to be scraping residue from a Hugging Face Spaces page, not source code.
import gradio as gr

from conversation import Conversation


def get_tab_playground(download_bot_config, get_bot_profile, model_mapping):
    """Render the "Playground" tab: chat with any Chai bot using any model.

    Args:
        download_bot_config: callable(bot_id) -> bot config dict; the code
            below reads the keys "memory", "prompt" and "firstMessage".
        get_bot_profile: callable(bot_config) -> HTML string for the profile.
        model_mapping: dict mapping a model tag (shown in the dropdown) to a
            model object exposing ``generation_params`` (a dict) and
            ``generate_response(conversation, params)``.
    """
    gr.Markdown("""
# 🎢 Playground 🎢
## Rules
* Chat with any model you would like with any bot from the Chai app.
* Click “Clear” to start a new conversation.
""")
    default_bot_id = "_bot_e21de304-6151-4a04-b025-4c553ae8cbca"
    bot_config = download_bot_config(default_bot_id)
    # Holds the currently loaded bot config so the callbacks see bot swaps.
    user_state = gr.State(bot_config)

    with gr.Row():
        bot_id = gr.Textbox(label="Chai bot ID", value=default_bot_id, interactive=True)
        reload_bot_button = gr.Button("Reload bot")
    bot_profile = gr.HTML(get_bot_profile(bot_config))
    with gr.Accordion("Bot config:", open=False):
        bot_config_text = gr.Markdown(f"# Memory\n{bot_config['memory']}\n# Prompt\n{bot_config['prompt']}")

    # Chatbot history rows are (user_message, bot_message); None on the user
    # side marks the bot's scripted opening line.
    first_message = (None, bot_config["firstMessage"])
    chatbot = gr.Chatbot([first_message])
    msg = gr.Textbox(label="Message", value="Hi there!")
    with gr.Row():
        send = gr.Button("Send")
        regenerate = gr.Button("Regenerate")
        clear = gr.Button("Clear")

    values = list(model_mapping.keys())
    model_tag = gr.Dropdown(values, value=values[0], label="Model version")
    # Seed the sliders from the default model's generation parameters.
    model = model_mapping[model_tag.value]
    with gr.Accordion("Generation parameters", open=False):
        temperature = gr.Slider(minimum=0.0, maximum=1.0,
                                value=model.generation_params["temperature"],
                                interactive=True, label="Temperature")
        repetition_penalty = gr.Slider(minimum=0.0, maximum=2.0,
                                       value=model.generation_params["repetition_penalty"],
                                       interactive=True, label="Repetition penalty")
        max_new_tokens = gr.Slider(minimum=1, maximum=512,
                                   value=model.generation_params["max_new_tokens"],
                                   interactive=True, label="Max new tokens")
        top_k = gr.Slider(minimum=1, maximum=100,
                          value=model.generation_params["top_k"],
                          interactive=True, label="Top-K")
        top_p = gr.Slider(minimum=0.0, maximum=1.0,
                          value=model.generation_params["top_p"],
                          interactive=True, label="Top-P")

    def _generation_params(temperature, repetition_penalty, max_new_tokens,
                           top_k, top_p):
        # Fold the slider values into the dict shape the model expects
        # (shared by respond() and regenerate_response()).
        return {
            'temperature': temperature,
            'repetition_penalty': repetition_penalty,
            'max_new_tokens': max_new_tokens,
            'top_k': top_k,
            'top_p': top_p,
        }

    def respond(message, chat_history, user_state, model_tag,
                temperature, repetition_penalty, max_new_tokens, top_k, top_p):
        """Send `message` to the selected model and append the exchange.

        Returns ("", updated_history): the empty string clears the textbox.
        """
        conv = Conversation(user_state)
        conv.set_chat_history(chat_history)
        conv.add_user_message(message)
        model = model_mapping[model_tag]
        bot_message = model.generate_response(
            conv,
            _generation_params(temperature, repetition_penalty,
                               max_new_tokens, top_k, top_p))
        chat_history.append((message, bot_message))
        return "", chat_history

    def clear_chat(chat_history, user_state):
        """Reset the conversation to just the bot's scripted first message."""
        return [(None, user_state["firstMessage"])]

    def regenerate_response(chat_history, user_state, model_tag,
                            temperature, repetition_penalty, max_new_tokens,
                            top_k, top_p):
        """Re-run generation for the last turn, replacing the bot's reply."""
        if not chat_history:
            # Fix: the original popped from an empty list (IndexError) when
            # the history had nothing to regenerate.
            return chat_history
        last_row = chat_history.pop(-1)
        # Blank the bot side of the last turn so the model regenerates it.
        chat_history.append((last_row[0], None))
        model = model_mapping[model_tag]
        conv = Conversation(user_state)
        conv.set_chat_history(chat_history)
        bot_message = model.generate_response(
            conv,
            _generation_params(temperature, repetition_penalty,
                               max_new_tokens, top_k, top_p))
        chat_history[-1] = (last_row[0], bot_message)
        return chat_history

    def reload_bot(bot_id, bot_profile, chat_history):
        """Download a new bot config; reset profile, chat and config text."""
        bot_config = download_bot_config(bot_id)
        bot_profile = get_bot_profile(bot_config)
        return (bot_profile,
                [(None, bot_config["firstMessage"])],
                bot_config,
                f"# Memory\n{bot_config['memory']}\n# Prompt\n{bot_config['prompt']}")

    def get_generation_args(model_tag):
        """Return the selected model's default params, in slider wiring order."""
        model = model_mapping[model_tag]
        return (
            model.generation_params["temperature"],
            model.generation_params["repetition_penalty"],
            model.generation_params["max_new_tokens"],
            model.generation_params["top_k"],
            model.generation_params["top_p"],
        )

    # Event wiring: switching models refreshes the sliders; send/submit run
    # respond(); clear resets; regenerate redoes the last bot reply.
    model_tag.change(get_generation_args, [model_tag],
                     [temperature, repetition_penalty, max_new_tokens, top_k,
                      top_p], queue=False)
    send.click(respond,
               [msg, chatbot, user_state, model_tag, temperature,
                repetition_penalty, max_new_tokens, top_k, top_p],
               [msg, chatbot], queue=False)
    msg.submit(respond,
               [msg, chatbot, user_state, model_tag, temperature,
                repetition_penalty, max_new_tokens, top_k, top_p],
               [msg, chatbot], queue=False)
    clear.click(clear_chat, [chatbot, user_state], [chatbot], queue=False)
    regenerate.click(regenerate_response,
                     [chatbot, user_state, model_tag, temperature,
                      repetition_penalty, max_new_tokens, top_k, top_p],
                     [chatbot], queue=False)
    reload_bot_button.click(reload_bot, [bot_id, bot_profile, chatbot],
                            [bot_profile, chatbot, user_state, bot_config_text],
                            queue=False)