learn-ai / app.py
dh-mc's picture
completed gradio app for qa
e182c41
raw
history blame
No virus
6.2 kB
"""Main entrypoint for the app."""
import os
import time
from queue import Queue
from timeit import default_timer as timer
import gradio as gr
from anyio.from_thread import start_blocking_portal
from app_modules.init import app_init
from app_modules.utils import print_llm_response
qa_chain = app_init()
chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") == "true"
show_param_settings = os.environ.get("SHOW_PARAM_SETTINGS") == "true"
share_gradio_app = os.environ.get("SHARE_GRADIO_APP") == "true"
using_openai = os.environ.get("LLM_MODEL_TYPE") == "openai"
model = (
"OpenAI GPT-4" if using_openai else os.environ.get("HUGGINGFACE_MODEL_NAME_OR_PATH")
)
href = "https://openai.com/gpt-4" if using_openai else f"https://huggingface.co/{model}"
title = """<h1 align="left" style="min-width:200px; margin-top:0;"> Chat with AI Books </h1>"""
description_top = f"""\
<div align="left">
<p> Currently Running: <a href="{href}">{model}</a></p>
</div>
"""
description = """\
<div align="center" style="margin:16px 0">
The demo is built on <a href="https://github.com/hwchase17/langchain">LangChain</a>.
</div>
"""
CONCURRENT_COUNT = 100
def qa(chatbot):
user_msg = chatbot[-1][0]
q = Queue()
result = Queue()
job_done = object()
def task(question, chat_history):
start = timer()
ret = qa_chain.call_chain(
{"question": question, "chat_history": chat_history}, None, q
)
end = timer()
print(f"Completed in {end - start:.3f}s")
print_llm_response(ret)
q.put(job_done)
result.put(ret)
with start_blocking_portal() as portal:
chat_history = []
if chat_history_enabled:
for i in range(len(chatbot) - 1):
element = chatbot[i]
item = (element[0] or "", element[1] or "")
chat_history.append(item)
portal.start_task_soon(task, user_msg, chat_history)
content = ""
count = 2 if len(chat_history) > 0 else 1
while count > 0:
while q.empty():
print("nothing generated yet - retry in 0.5s")
time.sleep(0.5)
for next_token in qa_chain.streamer:
if next_token is job_done:
break
content += next_token or ""
chatbot[-1][1] = remove_extra_spaces(content)
if count == 1:
yield chatbot
count -= 1
chatbot[-1][1] += "\n\nSources:\n"
ret = result.get()
titles = []
for doc in ret["source_documents"]:
page = doc.metadata["page"] + 1
url = f"{doc.metadata['url']}#page={page}"
file_name = doc.metadata["source"].split("/")[-1]
title = f"{file_name} Page: {page}"
if title not in titles:
titles.append(title)
chatbot[-1][1] += f"1. [{title}]({url})\n"
yield chatbot
with open("assets/custom.css", "r", encoding="utf-8") as f:
customCSS = f.read()
with gr.Blocks(css=customCSS) as demo:
user_question = gr.State("")
with gr.Row():
gr.HTML(title)
gr.Markdown(description_top)
with gr.Row().style(equal_height=True):
with gr.Column(scale=5):
with gr.Row():
chatbot = gr.Chatbot(elem_id="inflaton_chatbot").style(height="100%")
with gr.Row():
with gr.Column(scale=2):
user_input = gr.Textbox(
show_label=False, placeholder="Enter your question here"
).style(container=False)
with gr.Column(
min_width=70,
):
submitBtn = gr.Button("Send")
with gr.Column(
min_width=70,
):
clearBtn = gr.Button("Clear")
if show_param_settings:
with gr.Column():
with gr.Column(
min_width=50,
):
with gr.Tab(label="Parameter Setting"):
gr.Markdown("# Parameters")
top_p = gr.Slider(
minimum=-0,
maximum=1.0,
value=0.95,
step=0.05,
# interactive=True,
label="Top-p",
)
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=0,
step=0.1,
# interactive=True,
label="Temperature",
)
max_new_tokens = gr.Slider(
minimum=0,
maximum=2048,
value=2048,
step=8,
# interactive=True,
label="Max Generation Tokens",
)
max_context_length_tokens = gr.Slider(
minimum=0,
maximum=4096,
value=4096,
step=128,
# interactive=True,
label="Max Context Tokens",
)
gr.Markdown(description)
def chat(user_message, history):
return "", history + [[user_message, None]]
user_input.submit(
chat, [user_input, chatbot], [user_input, chatbot], queue=True
).then(qa, chatbot, chatbot)
submitBtn.click(
chat, [user_input, chatbot], [user_input, chatbot], queue=True, api_name="chat"
).then(qa, chatbot, chatbot)
def reset():
return "", []
clearBtn.click(
reset,
outputs=[user_input, chatbot],
show_progress=True,
api_name="reset",
)
demo.title = "Chat with AI Books"
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(share=share_gradio_app)