# File size: 2,996 Bytes
# 9c9ed59 7f7d37c 9c9ed59 ca677a9 9c9ed59 ca677a9 9c9ed59 ca677a9 9c9ed59 d2304af 9c9ed59 c8c7772 9c9ed59 5d623bb 9c9ed59 28d0e79 9c9ed59 c8c7772 9c9ed59 d2304af 9c9ed59 c8c7772 9c9ed59 c8c7772 9c9ed59 619b3ea 11332ca 619b3ea cac98cc 9c9ed59 4579d7a cac98cc 4579d7a eda969f 619b3ea eda969f e95e8e1 2891dae d95e984 1afe06d 4579d7a 1afe06d e95e8e1 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
from huggingface_hub import InferenceClient
import gradio as gr
# Inference client bound to the Mixtral-8x7B-Instruct model; generate() streams
# completions from it via client.text_generation().
client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
def format_prompt(message, history):
    """Render a conversation in the Mixtral-Instruct prompt template.

    Every past ``(user_prompt, bot_response)`` pair from *history* becomes
    ``"[INST] <user> [/INST] <bot></s> "``; the whole conversation is prefixed
    with a single ``<s>`` and ends with an open ``[INST] <message> [/INST]``
    turn for the model to complete.

    Args:
        message: the newest user message.
        history: iterable of two-item ``(user_prompt, bot_response)`` pairs.

    Returns:
        The formatted prompt string.
    """
    past_turns = "".join(
        f"[INST] {user_turn} [/INST] {assistant_turn}</s> "
        for user_turn, assistant_turn in history
    )
    return f"<s>{past_turns}[INST] {message} [/INST]"
def generate(
prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
temperature = float(temperature)
if temperature < 1e-2:
temperature = 1e-2
top_p = float(top_p)
generate_kwargs = dict(
temperature=temperature,
max_new_tokens=max_new_tokens,
top_p=top_p,
repetition_penalty=repetition_penalty,
do_sample=True,
seed=42,
)
formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
output = ""
for response in stream:
output += response.token.text
yield output
return output
# Extra controls shown below the chat box. gr.ChatInterface passes their values
# positionally to generate() after (message, history), so this order must match
# generate's signature: system_prompt, temperature, max_new_tokens, top_p,
# repetition_penalty.
additional_inputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=1,
        interactive=True,
    ),
    gr.Slider(
        label="Temperature",
        value=0.5,
        minimum=0.0,  # generate() raises values below 0.01 to 0.01
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Более высокое значение, даёт более разнообразные результаты.",
    ),
    gr.Slider(
        label="Max new tokens",
        value=20480,
        # Generating 0 new tokens is not a valid request; starting at one step
        # keeps every selectable value a positive multiple of 64.
        minimum=64,
        maximum=32768,
        step=64,
        interactive=True,
        info="Максимальное количество токенов",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.75,
        # The text-generation backend only accepts 0 < top_p < 1, so keep the
        # slider strictly inside that open interval.
        minimum=0.05,
        maximum=0.95,
        step=0.05,
        interactive=True,
        info="Более высокое значение, даёт большее разнообразие ",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Степень наказания за повторение токенов",
    ),
]
# Preset rows for gr.ChatInterface: an empty chat message followed by one value
# per additional input (system prompt, temperature, max new tokens, top-p,
# repetition penalty). Only the system prompt differs between the rows.
examples = [
    ["", language_instruction, 0.5, 20480, 0.75, 1.2]
    for language_instruction in (
        "Always answer completely in English",
        "Répondez toujours complètement en Français",
        "Отвечай всегда полностью на русском языке",
    )
]
# Text shown under the chat title; currently just a single newline (no visible text).
description = "\n"
# Build the chat UI around the streaming generate() callback and serve it.
# (The original last line ended in a stray " |", which is a syntax error.)
demo = gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(
        show_label=False,
        show_share_button=False,
        show_copy_button=True,
        # NOTE(review): like/dislike buttons are shown, but no .like() handler
        # is registered in this file, so that feedback is not recorded.
        likeable=True,
        layout="panel",
    ),
    additional_inputs=additional_inputs,
    title="Mixtral-8x7B-Chat",
    examples=examples,
    description=description,
    concurrency_limit=20,
)
demo.launch(show_api=False)