# app.py by ahuang11 (Hugging Face Space) - last commit a7d2870 "Update app.py"
# (verified), 5.39 kB; Hugging Face file scan: no virus.
import asyncio
import contextlib
import os
import random
import sqlite3

import pandas as pd
import panel as pn
from litellm import acompletion
# Enable Panel's "perspective" extension; it backs the pn.pane.Perspective
# used for the "Voting Results" tab further down.
pn.extension("perspective")
# Candidate models; each chat pane draws one at random per round.
MODELS = [
    "mistral/mistral-tiny",
    "mistral/mistral-small",
    "mistral/mistral-medium",
    "mistral/mistral-large-latest",
]

# Button labels, in the order click_vote indexes them:
# [0] credit A, [1] credit both, [2] credit neither, [3] credit B.
# (These previously held mis-decoded UTF-8 such as "πŸ‘ˆ" instead of the
# intended emoji.)
VOTING_LABELS = [
    "👈 A is better",
    "🤗 About the same",
    "😓 Both not good",
    "👉 B is better",
]
def set_api_key(api_key):
    """Store *api_key* in this process's ``MISTRAL_API_KEY`` environment variable.

    Presumably this is where ``litellm`` looks up Mistral credentials; verify
    against the litellm version in use.
    """
    os.environ.update({"MISTRAL_API_KEY": api_key})
async def respond(content, user, instance):
    """
    Stream a model reply to the latest message into one chat interface.

    On the first message of a round (no model recorded in ``chat_models``
    for this interface's name), the interface's history is trimmed to the
    latest message, both headers are reset to their anonymous titles, and a
    model is drawn at random from ``MODELS``. Later messages in the same
    round reuse that model.

    The interface is disabled while the reply streams and re-enabled
    afterwards, even if the completion call raises.

    Parameters
    ----------
    content : str
        Text of the user's message.
    user : str
        Name of the sender. It is unused here and is kept for the
        ChatInterface callback signature.
    instance : pn.chat.ChatInterface
        The interface to answer in; its ``name`` keys ``chat_models``.
    """
    try:
        instance.disabled = True
        chat_label = instance.name
        model = chat_models.get(chat_label)
        if not model:
            # New round: drop previous history, keeping only the message
            # that triggered this response, and hide the model identities.
            instance.objects = instance.objects[-1:]
            header_a.object = "## Model: A"
            header_b.object = "## Model: B"
            model = chat_models[chat_label] = random.choice(MODELS)

        # The triggering user message is already the last object in
        # ``instance``: the history trim above keeps exactly that message,
        # and forward_message appends it to the other pane before calling us.
        # The serialized history therefore already ends with it. Appending
        # ``content`` again would send the same user turn to the model twice.
        messages = instance.serialize()
        response = await acompletion(
            model=model, messages=messages, stream=True, max_tokens=128
        )
        message = None
        async for chunk in response:
            token = chunk.choices[0].delta["content"]
            if not token:
                # Skip empty/None deltas (e.g. role-only or final chunks).
                continue
            message = instance.stream(token, user="Assistant", message=message)
    finally:
        instance.disabled = False
async def forward_message(content, user, instance):
    """
    Mirror a message into the opposite chat interface, then let both answer.

    The message typed into one interface is copied into the other, so both
    sides see the same prompt. ``respond`` is then run concurrently for
    interface A and interface B, and this coroutine waits for both.
    """
    other_instance = (
        chat_interface_b if instance is chat_interface_a else chat_interface_a
    )
    other_instance.append(pn.chat.ChatMessage(content, user=user))
    await asyncio.gather(
        respond(content, user, chat_interface_a),
        respond(content, user, chat_interface_b),
    )
def click_vote(event):
"""
Count the votes and update the voting results.
"""
if len(chat_models) == 0:
return
voting_label = event.obj.name
if voting_label == VOTING_LABELS[0]:
chat_model = chat_models[chat_interface_a.name]
voting_counts[chat_model] = voting_counts.get(chat_model, 0) + 1
elif voting_label == VOTING_LABELS[3]:
chat_model = chat_models[chat_interface_b.name]
voting_counts[chat_model] = voting_counts.get(chat_model, 0) + 1
elif voting_label == VOTING_LABELS[1]:
chat_model_a = chat_models[chat_interface_a.name]
chat_model_b = chat_models[chat_interface_b.name]
if chat_model_a == chat_model_b:
voting_counts[chat_model_a] = voting_counts.get(chat_model_a, 0) + 1
else:
voting_counts[chat_model_a] = voting_counts.get(chat_model_a, 0) + 1
voting_counts[chat_model_b] = voting_counts.get(chat_model_b, 0) + 1
header_a.object = f"## Model: {chat_models[chat_interface_a.name]}"
header_b.object = f"## Model: {chat_models[chat_interface_b.name]}"
for chat_label in set(chat_models.keys()):
chat_models.pop(chat_label)
perspective.object = (
pd.DataFrame(voting_counts, index=["Votes"])
.melt(var_name="Model", value_name="Votes")
.set_index("Model")
)
with sqlite3.connect("voting_counts.db") as conn:
pd.DataFrame(voting_counts.items(), columns=["Model", "Votes"]).to_sql(
"voting_counts", conn, if_exists="replace", index=False
)
# initialize
# Model drawn for each chat interface in the current round, keyed by the
# interface's name; emptied after every vote.
chat_models = {}
# Load the persisted vote totals as {model: votes}, creating the table on the
# first run. A sqlite3 connection's context manager only commits. ``closing``
# makes sure the connection is released afterwards instead of staying open
# for the life of the app.
with contextlib.closing(sqlite3.connect("voting_counts.db")) as conn:
    with conn:
        conn.execute(
            "CREATE TABLE IF NOT EXISTS voting_counts (Model TEXT PRIMARY KEY, Votes INTEGER)"
        )
    voting_counts = (
        pd.read_sql("SELECT * FROM voting_counts", conn)
        .set_index("Model")["Votes"]
        .to_dict()
    )
# header
api_key_input = pn.widgets.PasswordInput(placeholder="Mistral API Key")
# watch=True makes Panel call set_api_key whenever the input's value changes.
# Without it, pn.bind only returns a bound function that runs when rendered.
# This result is never placed in a layout, so the key would never reach the
# environment.
pn.bind(set_api_key, api_key_input, watch=True)
# main
tabs = pn.Tabs()

# tab 1
# Both chat panes share one configuration:
# - every message is routed through forward_message;
# - the undo/rerun/clear/stop controls are hidden;
# - callback errors are shown verbosely.
chat_interface_kwargs = {
    "callback": forward_message,
    "show_undo": False,
    "show_rerun": False,
    "show_clear": False,
    "show_stop": False,
    "show_button_name": False,
    "callback_exception": "verbose",
}

# Anonymous titles until a vote reveals which model answered on each side.
header_a, header_b = (
    pn.pane.Markdown(title) for title in ("## Model: A", "## Model: B")
)
chat_interface_a, chat_interface_b = (
    pn.chat.ChatInterface(name=side, header=side_header, **chat_interface_kwargs)
    for side, side_header in (("A", header_a), ("B", header_b))
)
# One full-width button per voting option; every button reports to click_vote.
button_kwargs = {"sizing_mode": "stretch_width"}
vote_buttons = [
    pn.widgets.Button(name=label, **button_kwargs) for label in VOTING_LABELS
]
for vote_button in vote_buttons:
    vote_button.on_click(click_vote)
button_row = pn.Row(*vote_buttons)

chat_tab = pn.Column(pn.Row(chat_interface_a, chat_interface_b), button_row)
tabs.append(("Chat", chat_tab))
# tab 2
# Vote totals as a table indexed by model name with a single "Votes" column.
initial_votes = pd.DataFrame(voting_counts, index=["Votes"])
initial_votes = initial_votes.melt(var_name="Model", value_name="Votes")
initial_votes = initial_votes.set_index("Model")
perspective = pn.pane.Perspective(
    initial_votes, sizing_mode="stretch_both", editable=False
)
tabs.append(("Voting Results", perspective))
# layout
# The API-key field sits in the template header; the chat and results tabs
# fill the main area.
template = pn.template.FastListTemplate(
    title="Mistral Chat Arena",
    header=[api_key_input],
    main=[tabs],
)
template.servable()