Spaces:
Running
Running
import os | |
import gradio as gr | |
from text_generation import Client, InferenceAPIClient | |
import retrieval | |
NUM_ANSWERS_GENERATED = 3 | |
openchat_preprompt = ( | |
"\n<human>: Hi!\n<bot>: My name is Bot, model version is 0.15, part of an open-source kit for " | |
"fine-tuning new bots! I was created by Together, LAION, and Ontocord.ai and the open-source " | |
"community. I am not human, not evil and not alive, and thus have no thoughts and feelings, " | |
"but I am programmed to be helpful, polite, honest, and friendly.\n" | |
) | |
def get_client(model: str): | |
if model == "Rallio67/joi2_20Be_instruct_alpha": | |
return Client(os.getenv("JOI_API_URL")) | |
if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B": | |
return Client(os.getenv("OPENCHAT_API_URL")) | |
return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None)) | |
def get_usernames(model: str): | |
""" | |
Returns: | |
(str, str, str, str): pre-prompt, username, bot name, separator | |
""" | |
if model == "OpenAssistant/oasst-sft-1-pythia-12b": | |
return "", "<|prompter|>", "<|assistant|>", "<|endoftext|>" | |
if model == "Rallio67/joi2_20Be_instruct_alpha": | |
return "", "User: ", "Joi: ", "\n\n" | |
if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B": | |
return openchat_preprompt, "<human>: ", "<bot>: ", "\n" | |
return "", "User: ", "Assistant: ", "\n" | |
def predict( | |
model: str, | |
inputs: str, | |
typical_p: float, | |
top_p: float, | |
temperature: float, | |
top_k: int, | |
repetition_penalty: float, | |
watermark: bool, | |
chatbot, | |
history, | |
): | |
client = get_client(model) | |
preprompt, user_name, assistant_name, sep = get_usernames(model) | |
history.append(inputs) | |
past = [] | |
for data in chatbot: | |
user_data, model_data = data | |
if not user_data.startswith(user_name): | |
user_data = user_name + user_data | |
if not model_data.startswith(sep + assistant_name): | |
model_data = sep + assistant_name + model_data | |
past.append(user_data + model_data.rstrip() + sep) | |
if not inputs.startswith(user_name): | |
inputs = user_name + inputs | |
total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip() | |
partial_words = "" | |
if model == "OpenAssistant/oasst-sft-1-pythia-12b": | |
iterator = client.generate_stream( | |
total_inputs, | |
typical_p=typical_p, | |
truncate=1000, | |
watermark=watermark, | |
max_new_tokens=500, | |
) | |
else: | |
iterator = client.generate_stream( | |
total_inputs, | |
top_p=top_p if top_p < 1.0 else None, | |
top_k=top_k, | |
truncate=1000, | |
repetition_penalty=repetition_penalty, | |
watermark=watermark, | |
temperature=temperature, | |
max_new_tokens=500, | |
stop_sequences=[user_name.rstrip(), assistant_name.rstrip()], | |
) | |
for i, response in enumerate(iterator): | |
if response.token.special: | |
continue | |
partial_words = partial_words + response.token.text | |
if partial_words.endswith(user_name.rstrip()): | |
partial_words = partial_words.rstrip(user_name.rstrip()) | |
if partial_words.endswith(assistant_name.rstrip()): | |
partial_words = partial_words.rstrip(assistant_name.rstrip()) | |
if i == 0: | |
history.append(" " + partial_words) | |
elif response.token.text not in user_name: | |
history[-1] = partial_words | |
chat = [ | |
(history[i].strip(), history[i + 1].strip()) | |
for i in range(0, len(history) - 1, 2) | |
] | |
yield chat, history | |
# add context retrieval part here | |
ta = retrieval.Retrieval() | |
ta._load_pinecone_vectorstore() | |
question = inputs | |
top_context_list = ta.retrieve_contexts_from_pinecone(user_question=question, topk=NUM_ANSWERS_GENERATED) | |
print(top_context_list) | |
def reset_textbox(): | |
return gr.update(value="") | |
def radio_on_change( | |
value: str, | |
disclaimer, | |
typical_p, | |
top_p, | |
top_k, | |
temperature, | |
repetition_penalty, | |
watermark, | |
): | |
if value == "OpenAssistant/oasst-sft-1-pythia-12b": | |
typical_p = typical_p.update(value=0.2, visible=True) | |
top_p = top_p.update(visible=False) | |
top_k = top_k.update(visible=False) | |
temperature = temperature.update(visible=False) | |
disclaimer = disclaimer.update(visible=False) | |
repetition_penalty = repetition_penalty.update(visible=False) | |
watermark = watermark.update(False) | |
elif value == "togethercomputer/GPT-NeoXT-Chat-Base-20B": | |
typical_p = typical_p.update(visible=False) | |
top_p = top_p.update(value=0.25, visible=True) | |
top_k = top_k.update(value=50, visible=True) | |
temperature = temperature.update(value=0.6, visible=True) | |
repetition_penalty = repetition_penalty.update(value=1.01, visible=True) | |
watermark = watermark.update(False) | |
disclaimer = disclaimer.update(visible=True) | |
else: | |
typical_p = typical_p.update(visible=False) | |
top_p = top_p.update(value=0.95, visible=True) | |
top_k = top_k.update(value=4, visible=True) | |
temperature = temperature.update(value=0.5, visible=True) | |
repetition_penalty = repetition_penalty.update(value=1.03, visible=True) | |
watermark = watermark.update(True) | |
disclaimer = disclaimer.update(visible=False) | |
return ( | |
disclaimer, | |
typical_p, | |
top_p, | |
top_k, | |
temperature, | |
repetition_penalty, | |
watermark, | |
) | |
title = """<h1 align="center">🔥Teaching Assistant Chatbot 🚀Streaming🚀</h1>""" | |
description = """Language models can be conditioned to act like dialogue agents through a conversational prompt that typically takes the form: | |
``` | |
User: <utterance> | |
Assistant: <utterance> | |
User: <utterance> | |
Assistant: <utterance> | |
... | |
``` | |
In this app, you can explore the outputs of multiple LLMs when prompted in this way. | |
""" | |
openchat_disclaimer = """ | |
<div align="center">Checkout the official <a href=https://huggingface.co/spaces/togethercomputer/OpenChatKit>OpenChatKit feedback app</a> for the full experience.</div> | |
""" | |
with gr.Blocks( | |
css="""#col_container {margin-left: auto; margin-right: auto;} | |
#chatbot {height: 520px; overflow: auto;}""" | |
) as demo: | |
gr.HTML(title) | |
with gr.Row(): | |
model = gr.Radio( | |
value="OpenAssistant/oasst-sft-1-pythia-12b", | |
choices=[ | |
"OpenAssistant/oasst-sft-1-pythia-12b", | |
# "togethercomputer/GPT-NeoXT-Chat-Base-20B", | |
"Rallio67/joi2_20Be_instruct_alpha", | |
"google/flan-t5-xxl", | |
"google/flan-ul2", | |
"bigscience/bloom", | |
"bigscience/bloomz", | |
"EleutherAI/gpt-neox-20b", | |
], | |
label="Model", | |
interactive=True, | |
) | |
with gr.Row(): | |
with gr.Column(): | |
use_gpt3_checkbox = gr.Checkbox(label="Include GPT-3 (paid)?") | |
with gr.Column(): | |
use_equation_checkbox = gr.Checkbox(label="Prioritize equations?") | |
with gr.Row(): | |
with gr.Column(): | |
chatbot = gr.Chatbot(elem_id="chatbot") | |
inputs = gr.Textbox( | |
placeholder="Ask an Electrical Engineering question!", label="Type an input and press Enter" | |
) | |
examples = gr.Examples( | |
examples=[ | |
["What is a Finite State Machine?"], | |
["How do you design a functional a Two-Bit Gray Code Counter?"], | |
], | |
inputs=[inputs], | |
outputs=[], | |
) | |
gr.Markdown("The top 3 retrieved contexts are:") | |
with gr.Row(): | |
with gr.Column(): | |
# add context retrieval here | |
context1 = gr.Textbox(label="Context 1") | |
with gr.Column(): | |
context2 = gr.Textbox(label="Context 2") | |
with gr.Column(): | |
context3 = gr.Textbox(label="Context 3") | |
gr.Markdown("Lecture Slides:") | |
with gr.Row(): | |
# add CLIP part here | |
lec_gallery = gr.Gallery(label="Lecture images", show_label=False, elem_id="gallery").style(grid=[2], height="auto") | |
disclaimer = gr.Markdown(openchat_disclaimer, visible=False) | |
state = gr.State([]) | |
run_btn = gr.Button(variant='primary',) | |
with gr.Row(): | |
with gr.Accordion("Parameters", open=False): | |
typical_p = gr.Slider( | |
minimum=-0, | |
maximum=1.0, | |
value=0.2, | |
step=0.05, | |
interactive=True, | |
label="Typical P mass", | |
) | |
top_p = gr.Slider( | |
minimum=-0, | |
maximum=1.0, | |
value=0.25, | |
step=0.05, | |
interactive=True, | |
label="Top-p (nucleus sampling)", | |
visible=False, | |
) | |
temperature = gr.Slider( | |
minimum=-0, | |
maximum=5.0, | |
value=0.6, | |
step=0.1, | |
interactive=True, | |
label="Temperature", | |
visible=False, | |
) | |
top_k = gr.Slider( | |
minimum=1, | |
maximum=50, | |
value=50, | |
step=1, | |
interactive=True, | |
label="Top-k", | |
visible=False, | |
) | |
repetition_penalty = gr.Slider( | |
minimum=0.1, | |
maximum=3.0, | |
value=1.03, | |
step=0.01, | |
interactive=True, | |
label="Repetition Penalty", | |
visible=False, | |
) | |
watermark = gr.Checkbox(value=False, label="Text watermarking") | |
model.change( | |
lambda value: radio_on_change( | |
value, | |
disclaimer, | |
typical_p, | |
top_p, | |
top_k, | |
temperature, | |
repetition_penalty, | |
watermark, | |
), | |
inputs=model, | |
outputs=[ | |
disclaimer, | |
typical_p, | |
top_p, | |
top_k, | |
temperature, | |
repetition_penalty, | |
watermark, | |
], | |
) | |
inputs.submit( | |
predict, | |
[ | |
model, | |
inputs, | |
typical_p, | |
top_p, | |
temperature, | |
top_k, | |
repetition_penalty, | |
watermark, | |
chatbot, | |
state, | |
], | |
[chatbot, state], | |
) | |
run_btn.click( | |
predict, | |
[ | |
model, | |
inputs, | |
typical_p, | |
top_p, | |
temperature, | |
top_k, | |
repetition_penalty, | |
watermark, | |
chatbot, | |
state, | |
], | |
[chatbot, state], | |
) | |
run_btn.click(reset_textbox, [], [inputs]) | |
inputs.submit(reset_textbox, [], [inputs]) | |
gr.Markdown(description) | |
demo.queue(concurrency_count=16).launch(debug=True) | |