|
import os |
|
import requests |
|
import json |
|
import time |
|
|
|
import gradio as gr |
|
from transformers import AutoTokenizer |
|
import psycopg2 |
|
|
|
|
|
import socket |
|
# Startup diagnostics: print the machine's hostname and the IP it resolves to.
hostname = socket.gethostname()
try:
    IPAddr = socket.gethostbyname(hostname)
except OSError:
    # socket.gaierror (an OSError) is raised when the hostname does not resolve,
    # e.g. in some sandboxed containers. This is only a diagnostic print, so it
    # must not take the whole app down at import time.
    IPAddr = "unknown"

print("Your Computer Name is:" + hostname)
print("Your Computer IP Address is:" + IPAddr)
|
|
|
|
|
# Markdown rendered at the top of the page: model summary, credits and a
# disclaimer (in Traditional Chinese) that responses are not safety-guarded.
DESCRIPTION = """
# MediaTek Research Breexe-8x7B

Breexe-8x7B is a language model family that builds on top of [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1),
specifically intended for Traditional Chinese use.
[Breexe-8x7B-Instruct-v0_1](https://huggingface.co/MediaTek-Research/Breexe-8x7B-Instruct-v0_1) demonstrates impressive performance in benchmarks for Traditional Chinese and English, on par with OpenAI's gpt-3.5-turbo-1106.


*A project by the members (in alphabetical order): Chan-Jan Hsu 許湛然, Chang-Le Liu 劉昶樂, Feng-Ting Liao 廖峰挺, Po-Chun Hsu 許博竣, Yi-Chang Chen 陳宜昌, and the supervisor Da-Shan Shiu 許大山.*

**免責聲明: Breexe-8x7B-Instruct 並未針對問答進行安全保護,因此語言模型的任何回應不代表 MediaTek Research 立場。**
"""

# Markdown rendered at the bottom of the page; currently effectively empty.
LICENSE = """
"""

# Initial value of the "System prompt" textbox. It is sent as the system
# message of every request unless the user blanks the box.
DEFAULT_SYSTEM_PROMPT = "You are a helpful AI assistant built by MediaTek Research. The user you are helping speaks Traditional Chinese and comes from Taiwan."
|
|
|
# --- Remote inference endpoint -------------------------------------------
# Both values come from the environment. NOTE(review): if TOKEN is unset the
# Authorization header below becomes the literal "Bearer None".
API_URL = os.environ.get("API_URL")
TOKEN = os.environ.get("TOKEN")

# The tokenizer is only used locally to render the chat template into a
# prompt string; the model itself runs behind API_URL.
TOKENIZER_REPO = "MediaTek-Research/Breeze-7B-Instruct-v1_0"
API_MODEL_TYPE = "breexe-8x7b-instruct-v01"

HEADERS = dict(
    [
        ("Authorization", "Bearer {}".format(TOKEN)),
        ("Content-Type", "application/json"),
        ("accept", "application/json"),
    ]
)

# NOTE(review): MAX_SEC is not referenced elsewhere in this file; the HTTP
# request timeout is set separately where the request is made.
MAX_SEC = 30
# Upper bound, in characters, on the rendered prompt string.
MAX_INPUT_LENGTH = 5000

tokenizer = AutoTokenizer.from_pretrained(
    TOKENIZER_REPO,
    use_auth_token=os.environ.get("HF_TOKEN"),
)
|
|
|
|
|
# Keyword groups (already lower-case, whitespace-free) used by refusal_condition.
# A query mentioning at least one keyword from BOTH of the first two groups is
# refused; a query containing any keyword from the third group is refused
# on its own.
_TW_KEYWORDS = ('台灣', '台湾', 'taiwan', 'tw', '中華民國', '中华民国')
_CN_KEYWORDS = ('中國', '中国', 'cn', 'china', '大陸', '內地', '大陆', '内地',
                '中華人民共和國', '中华人民共和国')
_CROSS_STRAIT_KEYWORDS = ('一個中國', '兩岸', '一中原則', '一中政策', '一个中国',
                          '两岸', '一中原则')


def refusal_condition(query):
    """Return True if ``query`` touches on cross-strait politics and must be refused.

    Matching is a case-insensitive substring search after removing ALL
    whitespace. This covers ASCII spaces, but also full-width spaces (U+3000),
    tabs and newlines, which would otherwise let a keyword such as
    "一個　中國" slip past the filter.

    Args:
        query: Text to check. ``None`` or empty text is never refused.

    Returns:
        True when the text mentions both a Taiwan keyword and a China keyword,
        or contains any explicit cross-strait keyword; otherwise False.
    """
    if not query:
        return False

    # str.split() with no argument splits on every Unicode whitespace character.
    normalized = ''.join(query.split()).lower()

    def _mentions(keywords):
        return any(keyword in normalized for keyword in keywords)

    if _mentions(_TW_KEYWORDS) and _mentions(_CN_KEYWORDS):
        return True
    return _mentions(_CROSS_STRAIT_KEYWORDS)
|
|
|
with gr.Blocks() as demo:
    # Components render in the order they are created inside this Blocks
    # context, so the statement order below defines the page layout.

    # Page header: model description, credits and disclaimer.
    gr.Markdown(DESCRIPTION)

    # System prompt passed to ``bot`` on every generation; pre-filled with
    # DEFAULT_SYSTEM_PROMPT and editable by the user.
    system_prompt = gr.Textbox(label='System prompt',
                               value=DEFAULT_SYSTEM_PROMPT,
                               lines=1)

    # Generation parameters (collapsed by default). Their values are passed to
    # ``bot`` as max_new_tokens / temperature / top_p.
    with gr.Accordion(label='Advanced options', open=False):

        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=32,
            maximum=2048,
            step=1,
            value=1024,
        )
        # Temperature and top-p both default to 0.01. Presumably this makes output
        # near-deterministic, but that depends on the server's sampler; confirm.
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.01,
            maximum=0.5,
            step=0.01,
            value=0.01,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.01,
            maximum=0.99,
            step=0.01,
            value=0.01,
        )

    # Conversation view. Its value (a list of [user, assistant] pairs) is the
    # chat history read and written by the event handlers wired below.
    chatbot = gr.Chatbot(show_copy_button=True, show_share_button=True, )
    with gr.Row():
        # Message input; both pressing Enter and clicking Submit send it.
        msg = gr.Textbox(
            container=False,
            show_label=False,
            placeholder='Type a message...',
            scale=10,
            lines=6
        )
        submit_button = gr.Button('Submit',
                                  variant='primary',
                                  scale=1,
                                  min_width=0)

    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear = gr.Button('🗑️ Clear', variant='secondary')

    # Holds the user message removed by delete_prev_fn. Retry re-sends it;
    # Undo puts it back into the textbox.
    saved_input = gr.State()
|
|
|
def user(user_message, history): |
|
return "", history + [[user_message, None]] |
|
|
|
|
|
def connect_server(data): |
|
for _ in range(3): |
|
s = requests.Session() |
|
r = s.post(API_URL, headers=HEADERS, json=data, stream=True, timeout=30) |
|
time.sleep(1) |
|
if r.status_code == 200: |
|
return r |
|
return None |
|
|
|
|
|
def stream_response_from_server(r): |
|
|
|
keep_streaming = True |
|
for line in r.iter_lines(): |
|
|
|
|
|
|
|
|
|
if line and keep_streaming: |
|
if r.status_code != 200: |
|
continue |
|
json_response = json.loads(line) |
|
|
|
if "fragment" not in json_response["result"]: |
|
keep_streaming = False |
|
break |
|
|
|
delta = json_response["result"]["fragment"]["data"]["text"] |
|
yield delta |
|
|
|
|
|
|
|
|
|
def bot(history, max_new_tokens, temperature, top_p, system_prompt): |
|
chat_data = [] |
|
system_prompt = system_prompt.strip() |
|
if system_prompt: |
|
chat_data.append({"role": "system", "content": system_prompt}) |
|
for user_msg, assistant_msg in history: |
|
chat_data.append({"role": "user", "content": user_msg if user_msg is not None else ''}) |
|
chat_data.append({"role": "assistant", "content": assistant_msg if assistant_msg is not None else ''}) |
|
|
|
message = tokenizer.apply_chat_template(chat_data, tokenize=False) |
|
message = message[3:] |
|
|
|
if len(message) > MAX_INPUT_LENGTH: |
|
raise Exception() |
|
|
|
response = '[ERROR]' |
|
if refusal_condition(history[-1][0]): |
|
history = [['[安全拒答啟動]', '[安全拒答啟動] 請清除再開啟對話']] |
|
response = '[REFUSAL]' |
|
yield history |
|
else: |
|
data = { |
|
"model_type": API_MODEL_TYPE, |
|
"prompt": str(message), |
|
"parameters": { |
|
"temperature": float(temperature), |
|
"top_p": float(top_p), |
|
"max_new_tokens": int(max_new_tokens), |
|
"repetition_penalty": 1.1 |
|
} |
|
} |
|
|
|
r = connect_server(data) |
|
if r is not None: |
|
for delta in stream_response_from_server(r): |
|
if history[-1][1] is None: |
|
history[-1][1] = '' |
|
history[-1][1] += delta |
|
yield history |
|
|
|
if history[-1][1].endswith('</s>'): |
|
history[-1][1] = history[-1][1][:-4] |
|
yield history |
|
|
|
response = history[-1][1] |
|
|
|
if refusal_condition(history[-1][1]): |
|
history[-1][1] = history[-1][1] + '\n\n**[免責聲明: 此模型並未針對問答進行安全保護,因此語言模型的任何回應不代表 MediaTek Research 立場。]**' |
|
yield history |
|
else: |
|
del history[-1] |
|
yield history |
|
|
|
print('== Record ==\nQuery: {query}\nResponse: {response}'.format(query=repr(message), response=repr(history[-1][1]))) |
|
|
|
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then( |
|
fn=bot, |
|
inputs=[ |
|
chatbot, |
|
max_new_tokens, |
|
temperature, |
|
top_p, |
|
system_prompt, |
|
], |
|
outputs=chatbot |
|
) |
|
submit_button.click( |
|
user, [msg, chatbot], [msg, chatbot], queue=False |
|
).then( |
|
fn=bot, |
|
inputs=[ |
|
chatbot, |
|
max_new_tokens, |
|
temperature, |
|
top_p, |
|
system_prompt, |
|
], |
|
outputs=chatbot |
|
) |
|
|
|
|
|
def delete_prev_fn( |
|
history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]: |
|
try: |
|
message, _ = history.pop() |
|
except IndexError: |
|
message = '' |
|
return history, message or '' |
|
|
|
|
|
def display_input(message: str, |
|
history: list[tuple[str, str]]) -> list[tuple[str, str]]: |
|
history.append((message, '')) |
|
return history |
|
|
|
retry_button.click( |
|
fn=delete_prev_fn, |
|
inputs=chatbot, |
|
outputs=[chatbot, saved_input], |
|
api_name=False, |
|
queue=False, |
|
).then( |
|
fn=display_input, |
|
inputs=[saved_input, chatbot], |
|
outputs=chatbot, |
|
api_name=False, |
|
queue=False, |
|
).then( |
|
fn=bot, |
|
inputs=[ |
|
chatbot, |
|
max_new_tokens, |
|
temperature, |
|
top_p, |
|
system_prompt, |
|
], |
|
outputs=chatbot, |
|
) |
|
|
|
undo_button.click( |
|
fn=delete_prev_fn, |
|
inputs=chatbot, |
|
outputs=[chatbot, saved_input], |
|
api_name=False, |
|
queue=False, |
|
).then( |
|
fn=lambda x: x, |
|
inputs=[saved_input], |
|
outputs=msg, |
|
api_name=False, |
|
queue=False, |
|
) |
|
|
|
clear.click(lambda: None, None, chatbot, queue=False) |
|
|
|
gr.Markdown(LICENSE) |
|
|
|
# Enable the request queue (Gradio 3.x requires it for generator handlers such
# as ``bot``), with 4 concurrent workers and at most 128 waiting requests, then
# start the server. ``queue`` returns the Blocks instance, so the calls chain.
demo.queue(concurrency_count=4, max_size=128).launch()