File size: 5,387 Bytes
a7d2870
 
 
 
 
ab13803
a7d2870
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e333df
a7d2870
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ebfb41
a7d2870
 
 
 
3ebfb41
a7d2870
 
 
 
 
 
 
 
 
 
 
ab13803
 
a7d2870
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab13803
a7d2870
 
 
 
 
 
 
 
 
ab13803
a7d2870
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import asyncio
import os
import random
import sqlite3
from contextlib import closing

import pandas as pd
import panel as pn
from litellm import acompletion

# Load Panel's "perspective" extension so the pn.pane.Perspective used in the
# "Voting Results" tab can render.
pn.extension("perspective")


# Candidate models; respond() assigns one of these at random to each chat pane.
MODELS = [
    f"mistral/mistral-{size}" for size in ("tiny", "small", "medium", "large-latest")
]

# Voting button captions. click_vote() dispatches on the position in this list:
# index 0 credits pane A, 1 credits both panes, 2 credits nobody, 3 credits B.
VOTING_LABELS = [
    "👈 A is better",  # 0
    "🤗 About the same",  # 1
    "😓 Both not good",  # 2
    "👉 B is better",  # 3
]


def set_api_key(api_key):
    """Store *api_key* in the MISTRAL_API_KEY environment variable.

    Presumably litellm reads the Mistral key from this variable when
    acompletion() is called; verify against the litellm version in use.
    """
    os.environ.update(MISTRAL_API_KEY=api_key)


async def respond(content, user, instance):
    """
    Stream a model reply to *content* into the chat pane *instance*.

    If no model is assigned to this pane yet (the start of a new round), the
    pane's earlier history is dropped except for the newest message, both
    headers are reset to the anonymous "A"/"B" titles, and a random model from
    MODELS is assigned to the pane in ``chat_models``.

    Args:
        content: Text of the user's latest message.
        user: Name of the sender (not used here; part of the callback signature).
        instance: The chat pane (pn.chat.ChatInterface) to answer in.
    """
    try:
        # Block further input on this pane while it is answering.
        instance.disabled = True
        chat_label = instance.name
        if chat_model := chat_models.get(chat_label):
            model = chat_model
        else:
            # New round: keep only the message that just arrived.
            instance.objects = instance.objects[-1:]
            header_a.object = "## Model: A"
            header_b.object = "## Model: B"
            model = chat_models[chat_label] = random.choice(MODELS)

        # By the time this runs, the incoming message is normally already in
        # the pane: forward_message appends it to the other pane, and
        # ChatInterface appends the sent message before invoking its callback
        # (Panel behaviour; re-check on upgrades). serialize() therefore
        # usually ends with it. Appending it unconditionally would send the
        # prompt to the model twice, so only add it when it is missing.
        messages = instance.serialize()
        user_message = {"role": "user", "content": content}
        if not messages or messages[-1] != user_message:
            messages.append(user_message)

        response = await acompletion(
            model=model, messages=messages, stream=True, max_tokens=128
        )

        message = None
        async for chunk in response:
            token = chunk.choices[0].delta["content"]
            # Skip chunks that carry no text (e.g. role-only or final chunks).
            if not token:
                continue
            # Passing the previous message back makes stream() extend it
            # instead of creating a new chat bubble per token.
            message = instance.stream(token, user="Assistant", message=message)
    finally:
        instance.disabled = False


async def forward_message(content, user, instance):
    """
    Chat callback shared by both panes.

    Copies the message just sent in *instance* into the opposite pane, then
    has both panes answer it concurrently.
    """
    mirror = (
        chat_interface_b if instance is chat_interface_a else chat_interface_a
    )
    mirror.append(pn.chat.ChatMessage(content, user=user))

    # Pane A first, then pane B, awaited together.
    await asyncio.gather(
        respond(content, user, chat_interface_a),
        respond(content, user, chat_interface_b),
    )


def _record_votes(models):
    """
    Add one vote per model in *models* to voting_counts.db and return the
    full tally as a ``{model: votes}`` dict.

    Counts are incremented inside the database rather than rewritten from an
    in-memory dict. When the app runs under ``panel serve``, the script (and
    so ``voting_counts``) is presumably instantiated per session. A
    whole-table rewrite would then discard votes cast in other sessions.
    The UPDATE-then-INSERT pair works whether or not the table still has its
    PRIMARY KEY; older versions of this app recreated it without one.
    """
    # closing() is required: a sqlite3 connection used as a context manager
    # only commits/rolls back, it never closes itself.
    with closing(sqlite3.connect("voting_counts.db")) as conn:
        with conn:  # one transaction: commit on success, roll back on error
            conn.execute(
                "CREATE TABLE IF NOT EXISTS voting_counts "
                "(Model TEXT PRIMARY KEY, Votes INTEGER)"
            )
            for model in models:
                cursor = conn.execute(
                    "UPDATE voting_counts SET Votes = Votes + 1 WHERE Model = ?",
                    (model,),
                )
                if cursor.rowcount == 0:
                    conn.execute(
                        "INSERT INTO voting_counts (Model, Votes) VALUES (?, 1)",
                        (model,),
                    )
        return dict(conn.execute("SELECT Model, Votes FROM voting_counts"))


def click_vote(event):
    """
    Record the vote of the clicked button, reveal both models, and end the round.

    The button label (one of VOTING_LABELS) decides who is credited:
    pane A, pane B, both models ("About the same"; a model that played both
    sides is counted once), or nobody ("Both not good"). Afterwards the
    headers show the real model names and the pane->model assignments are
    cleared, so the next message starts a new round. The shared
    ``voting_counts`` dict and the results table are then refreshed from the
    database.

    Does nothing unless both panes currently have a model assigned.
    """
    model_a = chat_models.get(chat_interface_a.name)
    model_b = chat_models.get(chat_interface_b.name)
    if model_a is None or model_b is None:
        return

    voting_label = event.obj.name
    if voting_label == VOTING_LABELS[0]:
        winners = [model_a]
    elif voting_label == VOTING_LABELS[3]:
        winners = [model_b]
    elif voting_label == VOTING_LABELS[1]:
        # dict.fromkeys de-duplicates while preserving order.
        winners = list(dict.fromkeys((model_a, model_b)))
    else:
        winners = []

    # Persist first: if the write fails, the round stays open and can be retried.
    tally = _record_votes(winners)

    header_a.object = f"## Model: {model_a}"
    header_b.object = f"## Model: {model_b}"
    chat_models.clear()

    # voting_counts is a module-level dict; mutate it in place instead of
    # rebinding the global name.
    voting_counts.clear()
    voting_counts.update(tally)

    perspective.object = (
        pd.DataFrame(voting_counts, index=["Votes"])
        .melt(var_name="Model", value_name="Votes")
        .set_index("Model")
    )


# initialize
# Model assigned to each chat pane for the current round, keyed by pane name
# ("A"/"B"). respond() fills it and click_vote() empties it.
chat_models = {}
# Load the persisted tally, creating the table on first run. closing() is
# needed because a sqlite3 connection used as a context manager only
# commits/rolls back; it does not close the connection.
with closing(sqlite3.connect("voting_counts.db")) as conn:
    with conn:
        conn.execute(
            "CREATE TABLE IF NOT EXISTS voting_counts (Model TEXT PRIMARY KEY, Votes INTEGER)"
        )
    # {model name: vote count}
    voting_counts = (
        pd.read_sql("SELECT * FROM voting_counts", conn)
        .set_index("Model")["Votes"]
        .to_dict()
    )

# header
api_key_input = pn.widgets.PasswordInput(placeholder="Mistral API Key")
# watch=True makes Panel call set_api_key each time the input's value changes.
# Without it, pn.bind only returns a lazily evaluated function that runs when
# rendered. That function is never placed in the layout, so the key would
# never be exported.
pn.bind(set_api_key, api_key_input, watch=True)

# main
tabs = pn.Tabs()

# tab 1: two anonymous chat panes side by side with a row of voting buttons.
# Both panes share the forward_message callback and hide the
# undo/rerun/clear/stop controls.
chat_interface_kwargs = {
    "callback": forward_message,
    "show_undo": False,
    "show_rerun": False,
    "show_clear": False,
    "show_stop": False,
    "show_button_name": False,
    "callback_exception": "verbose",
}
header_a = pn.pane.Markdown("## Model: A")
header_b = pn.pane.Markdown("## Model: B")
chat_interface_a = pn.chat.ChatInterface(
    name="A", header=header_a, **chat_interface_kwargs
)
chat_interface_b = pn.chat.ChatInterface(
    name="B", header=header_b, **chat_interface_kwargs
)

# One full-width button per VOTING_LABELS entry, all routed to click_vote.
button_kwargs = {"sizing_mode": "stretch_width"}
vote_buttons = [
    pn.widgets.Button(name=label, **button_kwargs) for label in VOTING_LABELS
]
for vote_button in vote_buttons:
    vote_button.on_click(click_vote)
button_row = pn.Row(*vote_buttons)

chat_tab = pn.Column(pn.Row(chat_interface_a, chat_interface_b), button_row)
tabs.append(("Chat", chat_tab))

# tab 2: read-only table of the vote tally (one row per model).
votes_wide = pd.DataFrame(voting_counts, index=["Votes"])
votes_long = votes_wide.melt(var_name="Model", value_name="Votes").set_index(
    "Model"
)
perspective = pn.pane.Perspective(
    votes_long, sizing_mode="stretch_both", editable=False
)
tabs.append(("Voting Results", perspective))

# layout: API-key field in the header bar, the tabs in the main area.
template = pn.template.FastListTemplate(
    title="Mistral Chat Arena", header=[api_key_input], main=[tabs]
)
template.servable()