# app.py
"""Gradio demo UI for asking questions against per-study RAG pipelines."""

import json

import gradio as gr

from rag.rag_pipeline import RAGPipeline
from utils.prompts import (
    evidence_based_prompt,
    highlight_prompt,
    sample_questions,
    study_characteristics_prompt,
    vaccine_coverage_prompt,
)
from config import STUDY_FILES

# Cache of RAGPipeline instances keyed by study name, so each study's
# pipeline is constructed at most once per process.
rag_cache = {}

# Maps the "Prompt Type" radio labels to prompt templates. Labels not listed
# here (including "Default") resolve to None, which is what the original
# if/elif chain passed to rag.query for them.
PROMPT_TEMPLATES = {
    "Highlight": highlight_prompt,
    "Evidence-based": evidence_based_prompt,
    "Study Characteristics": study_characteristics_prompt,
    "Vaccine Coverage": vaccine_coverage_prompt,
}


def get_rag_pipeline(study_name):
    """Return the cached RAGPipeline for *study_name*, building it on first use.

    Args:
        study_name: A key of ``STUDY_FILES``.

    Returns:
        The ``RAGPipeline`` constructed from ``STUDY_FILES[study_name]``.

    Raises:
        ValueError: If *study_name* is not in ``STUDY_FILES`` (or maps to a
            falsy path).
    """
    if study_name not in rag_cache:
        study_file = STUDY_FILES.get(study_name)
        if not study_file:
            raise ValueError(f"Invalid study name: {study_name}")
        rag_cache[study_name] = RAGPipeline(study_file)
    return rag_cache[study_name]


def query_rag(study_name: str, question: str, prompt_type: str) -> str:
    """Answer *question* against the selected study and render it as Markdown.

    The extracted study info (from ``rag.extract_study_info()``) plus
    ``query_str`` are forwarded to ``rag.query`` as keyword arguments, and the
    same study info is appended to the rendered answer.

    Args:
        study_name: A key of ``STUDY_FILES``, or empty/None if none selected.
        question: The user's question text.
        prompt_type: One of the "Prompt Type" radio labels; unknown labels
            (including "Default") use no explicit prompt template.

    Returns:
        Markdown with Question, Answer, Sources and Extracted Study
        Information sections, or a short hint if input is missing.

    Raises:
        ValueError: From ``get_rag_pipeline`` if *study_name* is unknown.
    """
    # The study dropdown starts with no selection and the question box may be
    # blank, so Submit can fire without usable input. Reply with a hint
    # instead of raising inside get_rag_pipeline or sending an empty query.
    if not study_name:
        return "Please select a study."
    if not question or not question.strip():
        return "Please enter a question."

    rag = get_rag_pipeline(study_name)

    # Extract study information using RAG.
    study_info = rag.extract_study_info()

    # All possible prompt parameters: extracted study info plus the question.
    prompt_params = {**study_info, "query_str": question}
    prompt = PROMPT_TEMPLATES.get(prompt_type)

    response = rag.query(question, prompt, **prompt_params)

    # Build the Markdown with join instead of repeated string concatenation.
    # `response` is indexed as a mapping with "answer" and "sources", where
    # each source has a "title" and an optional "year".
    parts = [
        f"## Question\n\n{question}\n\n"
        f"## Answer\n\n{response['answer']}\n\n"
        "## Sources\n\n"
    ]
    parts.extend(
        f"- {source['title']} ({source.get('year', 'Year not specified')})\n"
        for source in response["sources"]
    )
    parts.append("\n## Extracted Study Information\n\n")
    parts.extend(
        f"- **{key.replace('_', ' ').title()}**: {value}\n"
        for key, value in study_info.items()
    )
    return "".join(parts)


def get_study_info(study_name):
    """Summarize a study's JSON file as Markdown.

    Args:
        study_name: A key of ``STUDY_FILES`` (may be None if nothing selected).

    Returns:
        The document count and the first document's title, a count-only
        summary if the file holds no documents, or "Invalid study name".
    """
    study_file = STUDY_FILES.get(study_name)
    if not study_file:
        return "Invalid study name"

    # JSON text is UTF-8 by specification; do not rely on the locale default.
    with open(study_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # NOTE(review): assumes the file holds a list of dicts with a "title" key.
    # Guard the empty case, which previously raised IndexError on data[0].
    if not data:
        return "**Number of documents:** 0"
    return (
        f"**Number of documents:** {len(data)}\n\n"
        f"**First document title:** {data[0]['title']}"
    )


def update_sample_questions(study_name):
    """Return a Dropdown update listing the sample questions for *study_name*."""
    return gr.Dropdown(choices=sample_questions.get(study_name, []), interactive=True)


with gr.Blocks() as demo:
    gr.Markdown("# RAG Pipeline Demo")

    with gr.Row():
        study_dropdown = gr.Dropdown(
            choices=list(STUDY_FILES.keys()), label="Select Study"
        )
        study_info = gr.Markdown(label="Study Information")

    study_dropdown.change(get_study_info, inputs=[study_dropdown], outputs=[study_info])

    with gr.Row():
        question_input = gr.Textbox(label="Enter your question")
        sample_question_dropdown = gr.Dropdown(
            choices=[], label="Sample Questions", interactive=True
        )

    # Refresh the sample-question choices whenever the study changes.
    study_dropdown.change(
        update_sample_questions,
        inputs=[study_dropdown],
        outputs=[sample_question_dropdown],
    )

    # Copy the chosen sample question into the question textbox.
    sample_question_dropdown.change(
        lambda x: x, inputs=[sample_question_dropdown], outputs=[question_input]
    )

    prompt_type = gr.Radio(
        [
            "Default",
            "Highlight",
            "Evidence-based",
            "Study Characteristics",
            "Vaccine Coverage",
        ],
        label="Prompt Type",
        value="Default",
    )

    submit_button = gr.Button("Submit")

    answer_output = gr.Markdown(label="Answer")

    submit_button.click(
        query_rag,
        inputs=[study_dropdown, question_input, prompt_type],
        outputs=[answer_output],
    )

if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a publicly reachable link;
    # confirm this is intended outside of demos.
    demo.launch(share=True)