"""Gradio chat front-end for the ACRES RAG platform.

Builds a two-column UI: a chat panel (with optional sample questions) on the
left, and study selection / prompt settings on the right. Queries are answered
by a per-study ``RAGPipeline`` that is created lazily and cached.
"""

import json

import gradio as gr

from rag.rag_pipeline import RAGPipeline
from utils.prompts import highlight_prompt, evidence_based_prompt, sample_questions
from config import STUDY_FILES

# Cache of RAG pipelines keyed by study name, so each one is built at most once.
rag_cache = {}

# Number of sample-question buttons available in the UI.
MAX_SAMPLE_QUESTIONS = 3

# Prompt templates selectable through the "Prompt Type" radio. Any other choice
# (including "Default") results in prompt_template=None being passed on.
PROMPT_TEMPLATES = {
    "Highlight": highlight_prompt,
    "Evidence-based": evidence_based_prompt,
}


def get_rag_pipeline(study_name):
    """Return the cached ``RAGPipeline`` for *study_name*, building it on first use.

    Args:
        study_name: Key into ``STUDY_FILES``.

    Returns:
        The ``RAGPipeline`` constructed from the study's file entry.

    Raises:
        ValueError: If ``STUDY_FILES`` has no (truthy) entry for *study_name*.
    """
    if study_name not in rag_cache:
        study_file = STUDY_FILES.get(study_name)
        if not study_file:
            raise ValueError(f"Invalid study name: {study_name}")
        rag_cache[study_name] = RAGPipeline(study_file)
    return rag_cache[study_name]


def chat_function(message, history, study_name, prompt_type):
    """Answer *message* using the RAG pipeline of the selected study.

    Args:
        message: The user's query text.
        history: Chat history; accepted for interface compatibility, not used.
        study_name: Key into ``STUDY_FILES`` selecting the pipeline.
        prompt_type: "Highlight" or "Evidence-based" to pick a template; any
            other value queries with ``prompt_template=None``.

    Returns:
        The ``response`` attribute of the pipeline's query result, or a fixed
        hint string when the message is missing or only whitespace.

    Raises:
        ValueError: Propagated from ``get_rag_pipeline`` for unknown studies.
    """
    # `not message` also covers None, which would otherwise fail on .strip().
    if not message or not message.strip():
        return "Please enter a valid query."

    rag = get_rag_pipeline(study_name)
    response = rag.query(message, prompt_template=PROMPT_TEMPLATES.get(prompt_type))
    return response.response


def get_study_info(study_name):
    """Return a short, human-readable summary of the study's JSON file.

    Args:
        study_name: Key into ``STUDY_FILES``.

    Returns:
        The item count plus the first item's title, just the count when the
        file holds no items, or "Invalid study name" for unknown studies.
    """
    study_file = STUDY_FILES.get(study_name)
    if not study_file:
        return "Invalid study name"

    # Explicit UTF-8 so the result does not depend on the platform's locale.
    with open(study_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    summary = f"Number of documents: {len(data)}"
    if not data:
        # Nothing to take a title from; avoid IndexError on data[0].
        return summary
    # NOTE(review): assumes the file is a JSON array of objects that each have
    # a "title" key -- confirm against the study file schema.
    return f"{summary}\nFirst document title: {data[0]['title']}"


def update_interface(study_name):
    """Compute the UI updates for a newly selected study.

    Args:
        study_name: The study chosen in the dropdown.

    Returns:
        A tuple of the study summary text followed by exactly
        ``MAX_SAMPLE_QUESTIONS`` button updates: visible buttons carrying the
        study's sample questions first, then hidden ones for unused slots.
    """
    study_info = get_study_info(study_name)
    questions = sample_questions.get(study_name, [])[:MAX_SAMPLE_QUESTIONS]
    shown = [gr.update(visible=True, value=q) for q in questions]
    hidden = [
        gr.update(visible=False)
        for _ in range(MAX_SAMPLE_QUESTIONS - len(questions))
    ]
    return (study_info, *shown, *hidden)


def set_question(question):
    """Return *question* unchanged; used to copy a button's value into the textbox."""
    return question


with gr.Blocks() as demo:
    gr.Markdown("# ACRES RAG Platform")

    with gr.Row():
        # Left column: conversation, message input and sample questions.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(elem_id="chatbot", show_label=False, height=400)
            with gr.Row():
                msg = gr.Textbox(
                    show_label=False,
                    placeholder="Type your message here...",
                    scale=4,
                    lines=1,
                    autofocus=True,
                )
                send_btn = gr.Button("Send", scale=1)

            with gr.Accordion("Sample Questions", open=False):
                sample_btn1 = gr.Button("Sample Question 1", visible=False)
                sample_btn2 = gr.Button("Sample Question 2", visible=False)
                sample_btn3 = gr.Button("Sample Question 3", visible=False)

        # Right column: study selection and prompt settings.
        with gr.Column(scale=1):
            gr.Markdown("### Study Information")
            study_names = list(STUDY_FILES)
            study_dropdown = gr.Dropdown(
                choices=study_names,
                label="Select Study",
                # Guard against an empty STUDY_FILES mapping.
                value=study_names[0] if study_names else None,
            )
            study_info = gr.Textbox(label="Study Details", lines=4)

            gr.Markdown("### Settings")
            prompt_type = gr.Radio(
                ["Default", "Highlight", "Evidence-based"],
                label="Prompt Type",
                value="Default",
            )
            clear = gr.Button("Clear Chat")

    def user(user_message, history):
        """Append the user's message to the history and clear the textbox.

        Returns a ``(textbox_value, history)`` pair. Blank messages leave the
        history unchanged.
        """
        history = history or []  # tolerate a None history
        if not user_message or not user_message.strip():
            return "", history
        return "", history + [[user_message, None]]

    def bot(history, study_name, prompt_type):
        """Fill in the assistant reply for the last, still unanswered turn."""
        # The second condition matters: `user` leaves the history untouched on
        # a blank submit, yet this step still runs. Without the check the
        # previous question would be re-queried and its answer overwritten.
        if not history or history[-1][1] is not None:
            return history

        user_message = history[-1][0]
        history[-1][1] = chat_function(user_message, history, study_name, prompt_type)
        return history

    # Pressing Enter in the textbox and clicking "Send" run the same two-step
    # chain: record the user turn immediately, then compute the reply.
    for trigger in (msg.submit, send_btn.click):
        trigger(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, [chatbot, study_dropdown, prompt_type], chatbot
        )

    clear.click(lambda: None, None, chatbot, queue=False)

    sample_buttons = [sample_btn1, sample_btn2, sample_btn3]
    interface_outputs = [study_info, *sample_buttons]

    study_dropdown.change(
        fn=update_interface,
        inputs=study_dropdown,
        outputs=interface_outputs,
    )
    # Also populate the details and sample questions for the initially selected
    # study; otherwise they stay empty until the dropdown value changes.
    demo.load(
        fn=update_interface,
        inputs=study_dropdown,
        outputs=interface_outputs,
    )

    # Clicking a sample question copies the button's label into the textbox.
    for sample_btn in sample_buttons:
        sample_btn.click(set_question, inputs=[sample_btn], outputs=[msg])

if __name__ == "__main__":
    # NOTE: share=True asks Gradio to create a publicly reachable link.
    demo.launch(share=True, debug=True)