# acres / app.py
# Commit 52a64e1 (ak3ra): "minor changes to the chatui" - 4.35 kB
import gradio as gr
from rag.rag_pipeline import RAGPipeline
from utils.prompts import highlight_prompt, evidence_based_prompt, sample_questions
from config import STUDY_FILES
import json
# Cache for RAG pipelines, keyed by study name
rag_cache = {}


def get_rag_pipeline(study_name):
    """Return the RAG pipeline for ``study_name``, building it on first use.

    Pipelines are memoized in the module-level ``rag_cache`` so that a
    ``RAGPipeline`` is constructed at most once per study name.

    Args:
        study_name: Key into ``STUDY_FILES`` identifying the study.

    Returns:
        The cached (or newly created) ``RAGPipeline`` for the study.

    Raises:
        ValueError: If the study is not cached and ``STUDY_FILES`` has no
            truthy entry for ``study_name``.
    """
    # Fast path: the pipeline was already built by an earlier call.
    if study_name in rag_cache:
        return rag_cache[study_name]

    source_path = STUDY_FILES.get(study_name)
    if not source_path:
        raise ValueError(f"Invalid study name: {study_name}")

    pipeline = RAGPipeline(source_path)
    rag_cache[study_name] = pipeline
    return pipeline
def chat_function(message, history, study_name, prompt_type):
    """Answer ``message`` using the RAG pipeline of the selected study.

    Args:
        message: The user's query text.
        history: Chat history; accepted for the caller's convenience but not
            read by this function.
        study_name: Study whose pipeline is obtained via ``get_rag_pipeline``.
        prompt_type: ``"Highlight"`` selects ``highlight_prompt``,
            ``"Evidence-based"`` selects ``evidence_based_prompt``; any other
            value passes ``None`` as the prompt template.

    Returns:
        The ``response`` attribute of the pipeline's query result, or a fixed
        hint string when ``message`` is empty or whitespace-only.
    """
    if not message.strip():
        return "Please enter a valid query."

    pipeline = get_rag_pipeline(study_name)

    if prompt_type == "Highlight":
        template = highlight_prompt
    elif prompt_type == "Evidence-based":
        template = evidence_based_prompt
    else:
        template = None

    result = pipeline.query(message, prompt_template=template)
    return result.response
def get_study_info(study_name):
    """Summarize the JSON document file registered for ``study_name``.

    Args:
        study_name: Key into ``STUDY_FILES``.

    Returns:
        A two-line string with the number of documents and the ``title`` of
        the first one, or ``"Invalid study name"`` when ``STUDY_FILES`` has
        no truthy entry for ``study_name``. An empty document collection is
        reported with a count of 0 and ``N/A`` as the title.
    """
    study_file = STUDY_FILES.get(study_name)
    if not study_file:
        return "Invalid study name"

    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding.
    with open(study_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Indexing data[0] on an empty collection would raise; report it instead.
    if not data:
        return "Number of documents: 0\nFirst document title: N/A"

    return f"Number of documents: {len(data)}\nFirst document title: {data[0]['title']}"
def update_interface(study_name):
    """Build the UI updates shown after a study is selected.

    Args:
        study_name: The study chosen in the dropdown.

    Returns:
        A 4-tuple: the study summary text from ``get_study_info`` followed by
        three ``gr.update`` objects, one per sample-question button. Buttons
        with a question (at most the first three from ``sample_questions``)
        are made visible with that question as their value; the remaining
        buttons are hidden.
    """
    info_text = get_study_info(study_name)
    questions = sample_questions.get(study_name, [])[:3]

    button_updates = []
    for slot in range(3):
        if slot < len(questions):
            button_updates.append(gr.update(visible=True, value=questions[slot]))
        else:
            button_updates.append(gr.update(visible=False))

    return (info_text, *button_updates)
def set_question(question):
    """Return ``question`` unchanged.

    Used as an event callback so a sample button's value can be copied into
    the message textbox.
    """
    return question
# Build the chat UI: a chat column (transcript, message box, sample questions)
# next to a side column (study picker, study details, prompt settings).
with gr.Blocks() as demo:
    gr.Markdown("# ACRES RAG Platform")

    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(elem_id="chatbot", show_label=False, height=400)
            with gr.Row():
                msg = gr.Textbox(
                    show_label=False,
                    placeholder="Type your message here...",
                    scale=4,
                    lines=1,
                    autofocus=True,
                )
                send_btn = gr.Button("Send", scale=1)

            with gr.Accordion("Sample Questions", open=False):
                # Hidden until update_interface supplies questions for a study.
                sample_btn1 = gr.Button("Sample Question 1", visible=False)
                sample_btn2 = gr.Button("Sample Question 2", visible=False)
                sample_btn3 = gr.Button("Sample Question 3", visible=False)

        with gr.Column(scale=1):
            gr.Markdown("### Study Information")
            study_dropdown = gr.Dropdown(
                choices=list(STUDY_FILES),
                label="Select Study",
                # First configured study, or no selection when none exist
                # (indexing [0] would raise on an empty mapping).
                value=next(iter(STUDY_FILES), None),
            )
            study_info = gr.Textbox(label="Study Details", lines=4)

            gr.Markdown("### Settings")
            prompt_type = gr.Radio(
                ["Default", "Highlight", "Evidence-based"],
                label="Prompt Type",
                value="Default",
            )
            clear = gr.Button("Clear Chat")

    def user(user_message, history):
        """Append the user's message to the history as a pending turn.

        Returns a pair ``(textbox_value, history)``; the textbox value is
        always ``""`` so the input is cleared. Empty or whitespace-only
        messages leave the history unchanged.
        """
        if not user_message.strip():
            return "", history  # Return unchanged if the message is empty
        return "", history + [[user_message, None]]

    def bot(history, study_name, prompt_type):
        """Fill in the reply for the last pending turn of ``history``.

        Returns the history unchanged when it is empty or when its last turn
        already has a reply. The latter happens when ``user`` ignored an
        empty submit: without this guard the previous question would be sent
        to the pipeline again and its answer overwritten.
        """
        if not history or history[-1][1]:
            return history
        user_message = history[-1][0]
        history[-1][1] = chat_function(user_message, history, study_name, prompt_type)
        return history

    # Pressing Enter in the textbox and clicking "Send" run the same two-step
    # chain: record the user turn, then generate the reply.
    for trigger in (msg.submit, send_btn.click):
        trigger(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, [chatbot, study_dropdown, prompt_type], chatbot
        )

    clear.click(lambda: None, None, chatbot, queue=False)

    study_outputs = [study_info, sample_btn1, sample_btn2, sample_btn3]
    study_dropdown.change(
        fn=update_interface,
        inputs=study_dropdown,
        outputs=study_outputs,
    )
    # The change event does not fire for the dropdown's initial value, so also
    # populate the study details and sample questions when the page loads.
    demo.load(
        fn=update_interface,
        inputs=study_dropdown,
        outputs=study_outputs,
    )

    # Clicking a sample button copies its question into the message textbox.
    for sample_btn in (sample_btn1, sample_btn2, sample_btn3):
        sample_btn.click(set_question, inputs=[sample_btn], outputs=[msg])
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True, share=True)