Spaces:
Running
Running
File size: 4,353 Bytes
6ff89e0 daee42b c6040d0 cfb1a62 660cea6 cfb1a62 9cb07f9 669d93a 183168e 669d93a daee42b c6040d0 52a64e1 669d93a 660cea6 c6040d0 d377a8f daee42b cfb1a62 52a64e1 cfb1a62 daee42b 52a64e1 daee42b cfb1a62 52a64e1 660cea6 52a64e1 660cea6 52a64e1 660cea6 52a64e1 660cea6 52a64e1 660cea6 daee42b c6040d0 52a64e1 c6040d0 daee42b c6040d0 52a64e1 c6040d0 cfb1a62 660cea6 c6040d0 52a64e1 c6040d0 52a64e1 660cea6 daee42b 52a64e1 daee42b d762ede |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import gradio as gr
from rag.rag_pipeline import RAGPipeline
from utils.prompts import highlight_prompt, evidence_based_prompt, sample_questions
from config import STUDY_FILES
import json
# Module-level memo of RAGPipeline objects keyed by study name. It is filled
# lazily by get_rag_pipeline, so each study file is wrapped at most once.
rag_cache = dict()
def get_rag_pipeline(study_name):
    """Return the RAG pipeline for ``study_name``, building it on first use.

    Pipelines are memoised in the module-level ``rag_cache`` dict, so a given
    study's ``RAGPipeline`` is constructed only once and reused afterwards.

    Raises:
        ValueError: if ``study_name`` has no (or a falsy) entry in
            ``STUDY_FILES``.
    """
    if study_name in rag_cache:
        return rag_cache[study_name]

    study_file = STUDY_FILES.get(study_name)
    if not study_file:
        raise ValueError(f"Invalid study name: {study_name}")

    rag_cache[study_name] = RAGPipeline(study_file)
    return rag_cache[study_name]
def chat_function(message, history, study_name, prompt_type):
    """Answer ``message`` with the RAG pipeline of the selected study.

    Args:
        message: The user's query text.
        history: Chat history supplied by the UI (not used here).
        study_name: Key selecting the study pipeline via ``get_rag_pipeline``.
        prompt_type: ``"Highlight"`` or ``"Evidence-based"`` select the
            matching prompt template; any other value passes ``None``.

    Returns:
        The ``.response`` attribute of the pipeline's query result, or a
        hint string when ``message`` is empty or whitespace-only.
    """
    if not message.strip():
        return "Please enter a valid query."

    pipeline = get_rag_pipeline(study_name)

    if prompt_type == "Highlight":
        template = highlight_prompt
    elif prompt_type == "Evidence-based":
        template = evidence_based_prompt
    else:
        # NOTE(review): presumably None makes the pipeline use its own default prompt.
        template = None

    result = pipeline.query(message, prompt_template=template)
    return result.response
def get_study_info(study_name):
    """Return a short human-readable summary of a study's JSON file.

    The path comes from ``STUDY_FILES[study_name]``. The indexing below
    assumes the file holds a JSON array whose first element has a ``"title"``
    key (TODO confirm against the data files).

    Returns:
        ``"Number of documents: N\\nFirst document title: T"`` for a non-empty
        file, ``"Number of documents: 0"`` for an empty array, or
        ``"Invalid study name"`` when the study is unknown.
    """
    study_file = STUDY_FILES.get(study_name)
    if not study_file:
        return "Invalid study name"

    # JSON text is UTF-8 by specification; don't depend on the platform's
    # default encoding (e.g. cp1252 on Windows breaks on non-ASCII titles).
    with open(study_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    doc_count = len(data)
    if doc_count == 0:
        # Previously data[0] raised IndexError for an empty study file.
        return "Number of documents: 0"

    return f"Number of documents: {doc_count}\nFirst document title: {data[0]['title']}"
def update_interface(study_name):
    """Refresh the study-details text and the three sample-question buttons.

    Returns a tuple: the ``get_study_info`` summary followed by exactly three
    ``gr.update`` objects, one per sample-question button. Up to three
    questions from ``sample_questions[study_name]`` are shown; any remaining
    buttons are hidden.
    """
    slot_count = 3
    info_text = get_study_info(study_name)
    shown = sample_questions.get(study_name, [])[:slot_count]

    button_updates = [gr.update(visible=True, value=question) for question in shown]
    for _ in range(slot_count - len(shown)):
        button_updates.append(gr.update(visible=False))

    return (info_text, *button_updates)
def set_question(question):
    """Pass a sample-question button's value through unchanged.

    The UI wires each sample button's value into the message textbox.
    """
    chosen_text = question
    return chosen_text
with gr.Blocks() as demo:
    gr.Markdown("# ACRES RAG Platform")

    with gr.Row():
        # Left column: conversation, input box and sample questions.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(elem_id="chatbot", show_label=False, height=400)
            with gr.Row():
                msg = gr.Textbox(
                    show_label=False,
                    placeholder="Type your message here...",
                    scale=4,
                    lines=1,
                    autofocus=True,
                )
                send_btn = gr.Button("Send", scale=1)
            with gr.Accordion("Sample Questions", open=False):
                # Hidden until update_interface fills them for the chosen study.
                sample_btn1 = gr.Button("Sample Question 1", visible=False)
                sample_btn2 = gr.Button("Sample Question 2", visible=False)
                sample_btn3 = gr.Button("Sample Question 3", visible=False)

        # Right column: study selection, study details and settings.
        with gr.Column(scale=1):
            gr.Markdown("### Study Information")
            study_dropdown = gr.Dropdown(
                choices=list(STUDY_FILES.keys()),
                label="Select Study",
                value=list(STUDY_FILES.keys())[0],
            )
            study_info = gr.Textbox(label="Study Details", lines=4)
            gr.Markdown("### Settings")
            prompt_type = gr.Radio(
                ["Default", "Highlight", "Evidence-based"],
                label="Prompt Type",
                value="Default",
            )
            clear = gr.Button("Clear Chat")

    def user(user_message, history):
        """Append the message as a pending ``[message, None]`` turn and clear the textbox.

        A blank message leaves the history unchanged.
        """
        history = history or []  # the chatbot value can be None, e.g. after "Clear Chat"
        if not user_message.strip():
            return "", history
        return "", history + [[user_message, None]]

    def bot(history, study_name, prompt_type):
        """Fill in the reply for the last turn if it is still pending."""
        # Only answer a turn that has no reply yet. A blank submission leaves
        # history unchanged but still triggers this step; without the check it
        # would re-run the previous question and overwrite its answer.
        if not history or history[-1][1] is not None:
            return history
        user_message = history[-1][0]
        history[-1][1] = chat_function(user_message, history, study_name, prompt_type)
        return history

    # Enter in the textbox and the Send button share one two-step flow:
    # echo the user turn immediately (unqueued), then compute the reply.
    for trigger in (msg.submit, send_btn.click):
        trigger(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, [chatbot, study_dropdown, prompt_type], chatbot
        )

    clear.click(lambda: None, None, chatbot, queue=False)

    study_outputs = [study_info, sample_btn1, sample_btn2, sample_btn3]
    study_dropdown.change(
        fn=update_interface,
        inputs=study_dropdown,
        outputs=study_outputs,
    )
    # .change does not fire for the dropdown's initial value, so populate the
    # default study's details and sample questions once when the page loads.
    demo.load(
        fn=update_interface,
        inputs=study_dropdown,
        outputs=study_outputs,
    )

    for sample_btn in (sample_btn1, sample_btn2, sample_btn3):
        # The button's value (the question set by update_interface) is copied
        # into the message box.
        sample_btn.click(set_question, inputs=[sample_btn], outputs=[msg])
if __name__ == "__main__":
    # Direct execution only: launch the UI with share=True and debug=True.
    launch_kwargs = {"share": True, "debug": True}
    demo.launch(**launch_kwargs)
|