File size: 5,681 Bytes
cb80e0b
 
 
 
 
d7f914d
7f27042
d7f914d
 
 
 
cb80e0b
 
 
 
 
 
 
 
 
 
 
d7f914d
cb80e0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7f914d
 
cb80e0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7f914d
 
cb80e0b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
from conversation import Conversation


def get_tab_playground(download_bot_config, get_bot_profile, model_mapping):
    """Build the "Playground" UI: chat with a selectable model as a chosen bot.

    Creates the Gradio components (bot selector, bot profile, chat window,
    message box, model dropdown, generation-parameter sliders) and wires
    their events. Nothing is returned; the components are created as a side
    effect of calling the ``gr`` constructors.
    NOTE(review): presumably this must be called inside an active
    ``gr.Blocks``/tab context -- confirm against the caller.

    Args:
        download_bot_config: Callable taking a bot ID string and returning a
            bot config that is subscriptable with the keys ``"memory"``,
            ``"prompt"`` and ``"firstMessage"``.
        get_bot_profile: Callable taking a bot config and returning the value
            shown in the ``gr.HTML`` profile component.
        model_mapping: Mapping from model tag (shown in the dropdown) to a
            model object exposing ``generation_params`` (subscriptable with
            the five generation-parameter keys) and
            ``generate_response(conversation, generation_params)``.
    """
    # Single source of truth for the generation parameters: the order here is
    # the order of the slider inputs/outputs used in the event wiring below.
    generation_param_keys = (
        "temperature",
        "repetition_penalty",
        "max_new_tokens",
        "top_k",
        "top_p",
    )

    def format_bot_config(config):
        """Render the bot's memory and prompt as Markdown text."""
        return f"# Memory\n{config['memory']}\n# Prompt\n{config['prompt']}"

    def initial_history(config):
        """Return a chat history holding only the bot's first message."""
        return [(None, config["firstMessage"])]

    def collect_generation_params(*slider_values):
        """Pair the slider values with their parameter names."""
        return dict(zip(generation_param_keys, slider_values))

    gr.Markdown("""
    # 🎢 Playground 🎢
    ## Rules
    * Chat with any model you would like with any bot from the Chai app.
    * Click “Clear” to start a new conversation.
    """)
    default_bot_id = "_bot_e21de304-6151-4a04-b025-4c553ae8cbca"
    bot_config = download_bot_config(default_bot_id)
    # Holds the config of the currently loaded bot; replaced by `reload_bot`.
    user_state = gr.State(bot_config)
    with gr.Row():
        bot_id = gr.Textbox(label="Chai bot ID", value=default_bot_id, interactive=True)
        reload_bot_button = gr.Button("Reload bot")

    bot_profile = gr.HTML(get_bot_profile(bot_config))
    with gr.Accordion("Bot config:", open=False):
        bot_config_text = gr.Markdown(format_bot_config(bot_config))

    chatbot = gr.Chatbot(initial_history(bot_config))

    msg = gr.Textbox(label="Message", value="Hi there!")
    with gr.Row():
        send = gr.Button("Send")
        regenerate = gr.Button("Regenerate")
        clear = gr.Button("Clear")
    values = list(model_mapping.keys())
    model_tag = gr.Dropdown(values, value=values[0], label="Model version")
    # The sliders start at the defaults of the initially selected model.
    default_params = model_mapping[model_tag.value].generation_params

    with gr.Accordion("Generation parameters", open=False):
        temperature = gr.Slider(minimum=0.0, maximum=1.0, value=default_params["temperature"],
                                interactive=True, label="Temperature")
        repetition_penalty = gr.Slider(minimum=0.0, maximum=2.0,
                                       value=default_params["repetition_penalty"],
                                       interactive=True, label="Repetition penalty")
        max_new_tokens = gr.Slider(minimum=1, maximum=512, value=default_params["max_new_tokens"],
                                   interactive=True, label="Max new tokens")
        top_k = gr.Slider(minimum=1, maximum=100, value=default_params["top_k"],
                          interactive=True, label="Top-K")
        top_p = gr.Slider(minimum=0.0, maximum=1.0, value=default_params["top_p"],
                          interactive=True, label="Top-P")

    def respond(message, chat_history, user_state, model_tag,
                temperature, repetition_penalty, max_new_tokens, top_k, top_p):
        """Generate a reply to `message` and append the exchange to the chat.

        Returns an empty string (to clear the message box) and the updated
        chat history.
        """
        custom_generation_params = collect_generation_params(
            temperature, repetition_penalty, max_new_tokens, top_k, top_p)
        conv = Conversation(user_state)
        conv.set_chat_history(chat_history)
        conv.add_user_message(message)
        model = model_mapping[model_tag]
        bot_message = model.generate_response(conv, custom_generation_params)
        chat_history.append((message, bot_message))
        return "", chat_history

    def clear_chat(chat_history, user_state):
        """Reset the chat to just the current bot's first message."""
        return initial_history(user_state)

    def regenerate_response(chat_history, user_state, model_tag,
                            temperature, repetition_penalty, max_new_tokens, top_k, top_p):
        """Replace the bot's last reply with a freshly generated one.

        Returns the updated chat history; an empty history is returned
        unchanged because there is no reply to regenerate.
        """
        if not chat_history:
            return chat_history
        custom_generation_params = collect_generation_params(
            temperature, repetition_penalty, max_new_tokens, top_k, top_p)
        last_user_message = chat_history[-1][0]
        # Blank out the previous reply so the conversation handed to the model
        # ends with the user's message and no bot answer.
        chat_history[-1] = (last_user_message, None)
        model = model_mapping[model_tag]
        conv = Conversation(user_state)
        conv.set_chat_history(chat_history)
        bot_message = model.generate_response(conv, custom_generation_params)
        chat_history[-1] = (last_user_message, bot_message)
        return chat_history

    def reload_bot(bot_id, bot_profile, chat_history):
        """Download the bot with `bot_id` and reset profile, chat and config.

        `bot_profile` and `chat_history` are received from the event inputs
        but their incoming values are not used.

        Returns the new profile value, the greeting-only chat history, the
        new bot config (stored in the state) and the config Markdown text.
        """
        new_config = download_bot_config(bot_id)
        return (
            get_bot_profile(new_config),
            initial_history(new_config),
            new_config,
            format_bot_config(new_config),
        )

    def get_generation_args(model_tag):
        """Return the selected model's default generation parameters, in slider order."""
        model_params = model_mapping[model_tag].generation_params
        return tuple(model_params[key] for key in generation_param_keys)

    # Slider components in the same order as `generation_param_keys`.
    generation_controls = [temperature, repetition_penalty, max_new_tokens, top_k, top_p]
    # The Send button and pressing Enter in the message box trigger the same handler.
    respond_inputs = [msg, chatbot, user_state, model_tag, *generation_controls]

    model_tag.change(get_generation_args, [model_tag], generation_controls, queue=False)
    send.click(respond, respond_inputs, [msg, chatbot], queue=False)
    msg.submit(respond, respond_inputs, [msg, chatbot], queue=False)
    clear.click(clear_chat, [chatbot, user_state], [chatbot], queue=False)
    regenerate.click(regenerate_response,
                     [chatbot, user_state, model_tag, *generation_controls],
                     [chatbot], queue=False)
    reload_bot_button.click(reload_bot, [bot_id, bot_profile, chatbot],
                            [bot_profile, chatbot, user_state, bot_config_text],
                            queue=False)