acres / app.py
ak3ra's picture
prompt params
5f52091
raw
history blame
3.96 kB
# app.py
import gradio as gr
import json
from rag.rag_pipeline import RAGPipeline
from utils.prompts import highlight_prompt, evidence_based_prompt
from utils.prompts import (
study_characteristics_prompt,
vaccine_coverage_prompt,
sample_questions,
)
from config import STUDY_FILES
# Pipelines already built, keyed by study name.
rag_cache = {}


def get_rag_pipeline(study_name):
    """Return the RAG pipeline for ``study_name``, building it on first use.

    A pipeline that is already in ``rag_cache`` is returned as-is. Otherwise
    the study's file is looked up in ``STUDY_FILES``, a ``RAGPipeline`` is
    constructed from it, stored in the cache and returned.

    Raises:
        ValueError: If ``STUDY_FILES`` has no (non-empty) entry for
            ``study_name``.
    """
    try:
        return rag_cache[study_name]
    except KeyError:
        # Not built yet; fall through and construct it below.
        pass

    source_file = STUDY_FILES.get(study_name)
    if not source_file:
        raise ValueError(f"Invalid study name: {study_name}")

    pipeline = RAGPipeline(source_file)
    rag_cache[study_name] = pipeline
    return pipeline
def query_rag(study_name: str, question: str, prompt_type: str) -> str:
    """Answer ``question`` for the selected study and format it as Markdown.

    Args:
        study_name: Key identifying the study; passed to ``get_rag_pipeline``.
        question: The user's question. It is sent to the pipeline both as the
            query and as the ``query_str`` prompt parameter.
        prompt_type: Label of the prompt template to use. Any label other
            than "Highlight", "Evidence-based", "Study Characteristics" or
            "Vaccine Coverage" results in ``None`` being passed as the prompt.

    Returns:
        A Markdown document with "Question", "Answer", "Sources" and
        "Extracted Study Information" sections.

    Raises:
        ValueError: Propagated from ``get_rag_pipeline`` for an unknown study.
    """
    rag = get_rag_pipeline(study_name)

    # Extract study information using RAG
    study_info = rag.extract_study_info()

    # "query_str" is listed last so the user's question overrides any
    # extracted field that happens to use the same key.
    prompt_params = {**study_info, "query_str": question}

    prompt_by_type = {
        "Highlight": highlight_prompt,
        "Evidence-based": evidence_based_prompt,
        "Study Characteristics": study_characteristics_prompt,
        "Vaccine Coverage": vaccine_coverage_prompt,
    }
    # Unlisted labels (e.g. "Default") map to None.
    prompt = prompt_by_type.get(prompt_type)

    response = rag.query(question, prompt, **prompt_params)

    # Collect the Markdown fragments and join once instead of growing a
    # string with repeated "+=".
    parts = [
        f"## Question\n\n{question}\n\n## Answer\n\n{response['answer']}\n\n## Sources\n\n"
    ]
    # A source without a title gets a placeholder, mirroring the existing
    # fallback for a missing year, so one incomplete record does not discard
    # the whole answer with a KeyError.
    parts.extend(
        f"- {source.get('title', 'Title not specified')} "
        f"({source.get('year', 'Year not specified')})\n"
        for source in response["sources"]
    )

    # Add extracted study information to the response
    parts.append("\n## Extracted Study Information\n\n")
    parts.extend(
        f"- **{key.replace('_', ' ').title()}**: {value}\n"
        for key, value in study_info.items()
    )
    return "".join(parts)
def get_study_info(study_name):
    """Return a short Markdown summary of the JSON file behind ``study_name``.

    The file path is looked up in ``STUDY_FILES``. The summary reports the
    number of top-level entries in the file and the ``title`` of the first
    one.

    Returns:
        The Markdown summary, ``"Invalid study name"`` when ``STUDY_FILES``
        has no (non-empty) entry for ``study_name``, or a zero-document
        summary when the file holds no entries.
    """
    study_file = STUDY_FILES.get(study_name)
    if not study_file:
        return "Invalid study name"

    # Decode explicitly as UTF-8 rather than relying on the platform's
    # locale-dependent default, which can mis-decode non-ASCII titles.
    with open(study_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not data:
        # Nothing to index into; avoid an IndexError on data[0].
        return "**Number of documents:** 0"

    return f"**Number of documents:** {len(data)}\n\n**First document title:** {data[0]['title']}"
def update_sample_questions(study_name):
    """Build an interactive Dropdown listing the sample questions for a study.

    The choices come from ``sample_questions[study_name]``; a study with no
    entry yields an empty list of choices.
    """
    questions = sample_questions.get(study_name, [])
    return gr.Dropdown(choices=questions, interactive=True)
# Assemble the demo UI: pick a study, enter a question (typed or chosen from
# the study's samples), choose a prompt type, and show the Markdown answer.
# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation was lost; confirm which components belong in each gr.Row().
with gr.Blocks() as demo:
    gr.Markdown("# RAG Pipeline Demo")

    with gr.Row():
        study_dropdown = gr.Dropdown(
            label="Select Study",
            choices=list(STUDY_FILES),
        )
        study_info = gr.Markdown(label="Study Information")

    # Show the study summary whenever the selected study changes.
    study_dropdown.change(
        get_study_info,
        inputs=[study_dropdown],
        outputs=[study_info],
    )

    with gr.Row():
        question_input = gr.Textbox(label="Enter your question")
        sample_question_dropdown = gr.Dropdown(
            label="Sample Questions",
            choices=[],
            interactive=True,
        )

    # Swap in the selected study's sample questions.
    study_dropdown.change(
        update_sample_questions,
        inputs=[study_dropdown],
        outputs=[sample_question_dropdown],
    )
    # Copy a chosen sample question into the question box.
    sample_question_dropdown.change(
        lambda picked: picked,
        inputs=[sample_question_dropdown],
        outputs=[question_input],
    )

    prompt_type = gr.Radio(
        choices=[
            "Default",
            "Highlight",
            "Evidence-based",
            "Study Characteristics",
            "Vaccine Coverage",
        ],
        value="Default",
        label="Prompt Type",
    )

    submit_button = gr.Button("Submit")
    answer_output = gr.Markdown(label="Answer")

    # Run the query with the current study, question and prompt type.
    submit_button.click(
        query_rag,
        inputs=[study_dropdown, question_input, prompt_type],
        outputs=[answer_output],
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=True)