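# model-evaluation/tabs/playground.py
# Author: AlekseyKorshuk · last commit d7f914d ("updates") · 5.69 kB
# (Header text captured from a repository file-viewer page — "raw" / "history" / "blame" links —
# kept as comments so the module parses.)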
import gradio as gr
from conversation import Conversation
def _format_bot_config(bot_config):
    """Render a bot config's memory and prompt as Markdown.

    Args:
        bot_config: mapping with at least ``'memory'`` and ``'prompt'`` keys.

    Returns:
        The Markdown string shown inside the "Bot config" accordion.
    """
    return f"# Memory\n{bot_config['memory']}\n# Prompt\n{bot_config['prompt']}"


def _build_generation_params(temperature, repetition_penalty, max_new_tokens, top_k, top_p):
    """Collect the generation-parameter slider values into one dict.

    ``max_new_tokens`` and ``top_k`` are rounded to ``int``. Slider values may
    arrive as floats (e.g. ``64.0``), but both are integer quantities
    (a token count and a candidate-pool size).

    Returns:
        Dict with keys ``temperature``, ``repetition_penalty``,
        ``max_new_tokens``, ``top_k`` and ``top_p``, in the form passed to
        ``model.generate_response``.
    """
    return {
        'temperature': temperature,
        'repetition_penalty': repetition_penalty,
        'max_new_tokens': int(round(max_new_tokens)),
        'top_k': int(round(top_k)),
        'top_p': top_p,
    }


def get_tab_playground(download_bot_config, get_bot_profile, model_mapping):
    """Build the "Playground" tab and wire up its event handlers.

    Components are created in whatever Gradio layout context is active.
    Presumably this is called inside a ``gr.Blocks``/``gr.Tab`` context;
    TODO confirm at the call site.

    Args:
        download_bot_config: callable taking a Chai bot id and returning the
            bot's config dict. The keys ``'memory'``, ``'prompt'`` and
            ``'firstMessage'`` are read from it.
        get_bot_profile: callable mapping a bot config dict to the HTML shown
            in the profile panel.
        model_mapping: dict from model tag (listed in the "Model version"
            dropdown) to a model object. Each model exposes
            ``generation_params`` (a dict with ``temperature``,
            ``repetition_penalty``, ``max_new_tokens``, ``top_k`` and
            ``top_p``) and ``generate_response(conversation, params)``.

    Raises:
        ValueError: if ``model_mapping`` is empty (there would be no default
            model to select).
    """
    if not model_mapping:
        raise ValueError("model_mapping must contain at least one model")

    gr.Markdown("""
# 🎢.️ Playground 🎢.️
## Rules
* Chat with any model you would like with any bot from the Chai app.
* Click “Clear” to start a new conversation.
""")
    default_bot_id = "_bot_e21de304-6151-4a04-b025-4c553ae8cbca"
    bot_config = download_bot_config(default_bot_id)
    # Holds the active bot's config; reload_bot replaces it.
    user_state = gr.State(bot_config)

    with gr.Row():
        bot_id = gr.Textbox(label="Chai bot ID", value=default_bot_id, interactive=True)
        reload_bot_button = gr.Button("Reload bot")
    bot_profile = gr.HTML(get_bot_profile(bot_config))
    with gr.Accordion("Bot config:", open=False):
        bot_config_text = gr.Markdown(_format_bot_config(bot_config))

    # The chat opens with the bot's scripted greeting, which has no user turn.
    chatbot = gr.Chatbot([(None, bot_config["firstMessage"])])
    msg = gr.Textbox(label="Message", value="Hi there!")
    with gr.Row():
        send = gr.Button("Send")
        regenerate = gr.Button("Regenerate")
        clear = gr.Button("Clear")

    model_tags = list(model_mapping.keys())
    model_tag = gr.Dropdown(model_tags, value=model_tags[0], label="Model version")
    # Slider defaults come from the initially selected model.
    default_params = model_mapping[model_tags[0]].generation_params
    with gr.Accordion("Generation parameters", open=False):
        temperature = gr.Slider(minimum=0.0, maximum=1.0, value=default_params["temperature"],
                                interactive=True, label="Temperature")
        repetition_penalty = gr.Slider(minimum=0.0, maximum=2.0,
                                       value=default_params["repetition_penalty"],
                                       interactive=True, label="Repetition penalty")
        # step=1: both are integer quantities. Without an explicit step,
        # Gradio may choose a fractional step.
        max_new_tokens = gr.Slider(minimum=1, maximum=512, step=1,
                                   value=default_params["max_new_tokens"],
                                   interactive=True, label="Max new tokens")
        top_k = gr.Slider(minimum=1, maximum=100, step=1, value=default_params["top_k"],
                          interactive=True, label="Top-K")
        top_p = gr.Slider(minimum=0.0, maximum=1.0, value=default_params["top_p"],
                          interactive=True, label="Top-P")
    # Same order as _build_generation_params' arguments and get_generation_args' result.
    generation_inputs = [temperature, repetition_penalty, max_new_tokens, top_k, top_p]

    def respond(message, chat_history, user_state, model_tag,
                temperature, repetition_penalty, max_new_tokens, top_k, top_p):
        """Append the user's message and the selected model's reply to the chat.

        Returns ``("", chat_history)``. The empty string clears the message box.
        """
        params = _build_generation_params(temperature, repetition_penalty,
                                          max_new_tokens, top_k, top_p)
        conv = Conversation(user_state)
        conv.set_chat_history(chat_history)
        conv.add_user_message(message)
        bot_message = model_mapping[model_tag].generate_response(conv, params)
        chat_history.append((message, bot_message))
        return "", chat_history

    def clear_chat(chat_history, user_state):
        """Reset the chat to just the active bot's first message."""
        return [(None, user_state["firstMessage"])]

    def regenerate_response(chat_history, user_state, model_tag,
                            temperature, repetition_penalty, max_new_tokens, top_k, top_p):
        """Replace the bot's reply to the last user message with a fresh one.

        Returns the history unchanged when there is no user message to
        re-answer. That happens when the history is empty, or when it ends
        with the bot's scripted first message (user turn ``None``).
        """
        if not chat_history or chat_history[-1][0] is None:
            return chat_history
        params = _build_generation_params(temperature, repetition_penalty,
                                          max_new_tokens, top_k, top_p)
        user_message = chat_history[-1][0]
        # Blank the old reply before building the Conversation, so the model
        # does not see the reply it is replacing.
        chat_history[-1] = (user_message, None)
        conv = Conversation(user_state)
        conv.set_chat_history(chat_history)
        bot_message = model_mapping[model_tag].generate_response(conv, params)
        chat_history[-1] = (user_message, bot_message)
        return chat_history

    def reload_bot(bot_id):
        """Download the config for ``bot_id`` and reset the tab to that bot.

        Returns, in output order:
            - the new profile HTML;
            - a chat holding only the bot's first message;
            - the new config (stored in ``user_state``);
            - its Markdown rendering.
        """
        # Strip stray whitespace from pasted ids before the download.
        new_config = download_bot_config(bot_id.strip())
        return (
            get_bot_profile(new_config),
            [(None, new_config["firstMessage"])],
            new_config,
            _format_bot_config(new_config),
        )

    def get_generation_args(model_tag):
        """Return the selected model's default generation params, in slider order."""
        params = model_mapping[model_tag].generation_params
        return tuple(params[key] for key in
                     ("temperature", "repetition_penalty", "max_new_tokens", "top_k", "top_p"))

    model_tag.change(get_generation_args, [model_tag], generation_inputs, queue=False)
    chat_inputs = [msg, chatbot, user_state, model_tag, *generation_inputs]
    send.click(respond, chat_inputs, [msg, chatbot], queue=False)
    msg.submit(respond, chat_inputs, [msg, chatbot], queue=False)
    clear.click(clear_chat, [chatbot, user_state], [chatbot], queue=False)
    regenerate.click(regenerate_response,
                     [chatbot, user_state, model_tag, *generation_inputs],
                     [chatbot], queue=False)
    reload_bot_button.click(reload_bot, [bot_id],
                            [bot_profile, chatbot, user_state, bot_config_text],
                            queue=False)