from huggingface_hub import InferenceClient
import gradio as gr

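# Client for the hosted Mixtral-8x7B-Instruct model on the Hugging Face Inference API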
client = InferenceClient(
    "mistralai/Mixtral-8x7B-Instruct-v0.1"
)


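# Build a Mixtral-Instruct prompt: each turn is wrapped in [INST] ... [/INST],
# and past bot responses are closed with </s> before the new message is appended.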
def format_prompt(message, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

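# Stream a completion from the model, yielding the accumulated text after every token.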
def generate(
    prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
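    # Keep temperature strictly positive so sampling stays well-defined.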
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

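    # Sampling parameters forwarded to the text-generation endpoint (fixed seed for reproducibility).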
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

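    # Prepend the system prompt to the user message before formatting the full conversation.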
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    for response in stream:
        output += response.token.text
        yield output
    return output


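# Extra controls shown below the chat box; their values are passed to generate() in this order.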
additional_inputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=1,
        interactive=True,
    ),
    gr.Slider(
        label="Temperature",
        value=0.5,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Более высокое значение, даёт более разнообразные результаты.",
    ),
    gr.Slider(
        label="Max new tokens",
        value=20480,
        minimum=0,
        maximum=32768,
        step=64,
        interactive=True,
        info="Максимальное количество токенов",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.75,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Более высокое значение, даёт большее разнообразие ",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Степень наказания за повторение токенов",
    )
]

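# Example rows: [message, system prompt, temperature, max new tokens, top-p, repetition penalty].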
examples = [
    ["", "Always answer completely in English", 0.5, 20480, 0.75, 1.2],
    ["", "Répondez toujours complètement en Français", 0.5, 20480, 0.75, 1.2],
    ["", "Отвечай всегда полностью на русском языке", 0.5, 20480, 0.75, 1.2],
]


description = r"""

"""




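# Assemble the streaming chat UI and launch the app.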
gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="Mixtral-8x7B-Chat",
    examples=examples,
    description=description,
    concurrency_limit=20,
).launch(show_api=False)