Spaces:
Runtime error
Runtime error
File size: 3,230 Bytes
f745223 504b6c8 e8fb838 1ffd977 953413f e8fb838 f745223 cbcb343 f745223 d8a82cd 52c453e f745223 13a089e 269b919 19cbba1 fdc528c 19cbba1 5a9e82b fdc528c 269b919 349875c 19cbba1 349875c 13a089e a9db698 afed27d 13a089e 1ffd977 a9db698 f57923a a9db698 afed27d 1ffd977 19cbba1 cf7aa4d 9b0bdb7 afed27d 9b0bdb7 a9db698 953413f 9b0bdb7 19cbba1 269b919 19cbba1 269b919 19cbba1 1ffd977 2334dc1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import os
import time
import torch
import gradio as gr
from strings import TITLE, ABSTRACT, EXAMPLES
from gen import get_pretrained_models, get_output, setup_model_parallel
# Fake a single-process distributed environment (rank 0 of world size 1) so
# torch.distributed-based model-parallel init works without a real launcher.
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "50505"
# Initialize model parallelism; presumably returns this process's rank and
# the total world size — opaque helper from gen.py, TODO confirm.
local_rank, world_size = setup_model_parallel()
# Load the model + tokenizer once at startup ("7B" / "tokenizer" look like
# checkpoint and tokenizer paths or names — verify against gen.py).
generator = get_pretrained_models("7B", "tokenizer", local_rank, world_size)
# Process-global conversation log, appended to by chat(); shared across all
# users of this process, not per-session.
history = []
def chat(
    user_input,
    include_input,
    truncate,
    top_p,
    temperature,
    max_gen_len,
    state_chatbot
):
    """Generate a model reply to ``user_input`` and stream it to the chatbot.

    Yields ``(state_chatbot, state_chatbot)`` repeatedly, one word at a time,
    so Gradio renders a typing effect. ``state_chatbot`` is the per-session
    list of ``(user, bot)`` message pairs; the module-level ``history`` list
    is also appended to as a process-wide log.
    """
    bot_response = get_output(
        generator=generator,
        prompt=user_input,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p)[0]

    # The model echoes the prompt at the start of its output; strip that
    # prefix unless the user asked to keep it.
    if not include_input:
        bot_response = bot_response[len(user_input):]
    bot_response = bot_response.replace("\n", "<br>")

    # Trim a possibly unfinished trailing phrase by cutting at the last
    # period. BUGFIX: str.rfind never raises — it returns -1 when "." is
    # absent, in which case the original code computed bot_response[:0]
    # and silently wiped the whole response (the bare try/except could
    # never fire). Only truncate when a period was actually found.
    if truncate:
        last_period = bot_response.rfind(".")
        if last_period != -1:
            bot_response = bot_response[:last_period + 1]

    history.append({
        "role": "user",
        "content": user_input
    })
    history.append({
        "role": "system",
        "content": bot_response
    })

    # Add a placeholder pair, then overwrite it as words stream in.
    state_chatbot = state_chatbot + [(user_input, None)]
    response = ""
    for word in bot_response.split(" "):
        time.sleep(0.1)  # pacing for the typing animation
        response += word + " "
        state_chatbot[-1] = (user_input, response)
        yield state_chatbot, state_chatbot
def reset_textbox():
    """Return a Gradio update that clears the prompt textbox."""
    return gr.update(value="")
# ------------------------------------------------------------------ UI layout
with gr.Blocks(css="""#col_container {width: 95%; margin-left: auto; margin-right: auto;}
#chatbot {height: 400px; overflow: auto;}""") as demo:
    # Per-session list of (user, bot) pairs, threaded through chat().
    state_chatbot = gr.State([])

    with gr.Column(elem_id='col_container'):
        gr.Markdown(f"## {TITLE}\n\n\n\n{ABSTRACT}")

        with gr.Accordion("Example prompts", open=False):
            example_str = "\n"
            for example in EXAMPLES:
                example_str += f"- {example}\n"
            gr.Markdown(example_str)

        chatbot = gr.Chatbot(elem_id='chatbot')
        textbox = gr.Textbox(placeholder="Enter a prompt")

        with gr.Accordion("Parameters", open=False):
            include_input = gr.Checkbox(value=True, label="Do you want to include the input in the generated text?")
            truncate = gr.Checkbox(value=True, label="Truncate the unfinished last words?")
            # BUGFIX: user-facing label typo "Genenration" -> "Generation".
            max_gen_len = gr.Slider(minimum=20, maximum=512, value=256, step=1, interactive=True, label="Max Generation Length",)
            # minimum was written as -0 (identical value); plain 0 is clearer.
            top_p = gr.Slider(minimum=0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p (nucleus sampling)",)
            temperature = gr.Slider(minimum=0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature",)

    # Submitting the textbox streams the reply into the chatbot, then a
    # second handler clears the input box.
    textbox.submit(
        chat,
        [textbox, include_input, truncate, top_p, temperature, max_gen_len, state_chatbot],
        [state_chatbot, chatbot]
    )
    textbox.submit(reset_textbox, [], [textbox])

# queue() is required because chat() is a generator (streams via yield);
# api_open=False hides the programmatic API endpoint.
demo.queue(api_open=False).launch()