File size: 3,285 Bytes
9cb07f9
 
6ff89e0
cfb1a62
daee42b
cfb1a62
042c079
6a076b8
 
 
 
cfb1a62
 
9cb07f9
669d93a
183168e
669d93a
 
 
 
 
 
 
 
 
daee42b
 
cfb1a62
669d93a
daee42b
 
cfb1a62
 
 
6a076b8
 
 
 
6ff89e0
cfb1a62
daee42b
cfb1a62
6a076b8
 
 
 
 
d0d7d0e
 
daee42b
 
cfb1a62
 
 
 
 
d0d7d0e
cfb1a62
 
daee42b
 
6a076b8
9cb07f9
6a076b8
 
daee42b
cfb1a62
daee42b
cfb1a62
 
 
 
6a076b8
cfb1a62
 
daee42b
cfb1a62
 
9cb07f9
 
 
6a076b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daee42b
cfb1a62
 
d0d7d0e
daee42b
cfb1a62
 
 
 
 
daee42b
 
9cb07f9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# app.py

import gradio as gr
import json
from rag.rag_pipeline import RAGPipeline
from utils.prompts import highlight_prompt, evidence_based_prompt
from utils.prompts import (
    study_characteristics_prompt,
    vaccine_coverage_prompt,
    sample_questions,
)
from config import STUDY_FILES

# Cache of RAGPipeline instances keyed by study name.
# Entries are added lazily by get_rag_pipeline() the first time a study is requested.
rag_cache = {}


def get_rag_pipeline(study_name):
    """Return the RAGPipeline for *study_name*, constructing it on first use.

    Pipelines are memoised in the module-level ``rag_cache`` so each study's
    pipeline is built at most once.

    Raises:
        ValueError: If ``STUDY_FILES`` has no (truthy) entry for *study_name*.
    """
    # Fast path: the pipeline was already built for this study.
    if study_name in rag_cache:
        return rag_cache[study_name]

    path = STUDY_FILES.get(study_name)
    if not path:
        raise ValueError(f"Invalid study name: {study_name}")

    pipeline = RAGPipeline(path)
    rag_cache[study_name] = pipeline
    return pipeline


def query_rag(study_name, question, prompt_type):
    """Run *question* through a study's RAG pipeline and render Markdown.

    Args:
        study_name: Study key, resolved to a pipeline via ``get_rag_pipeline``.
        question: Question text forwarded to the pipeline's ``query`` method.
        prompt_type: Label selecting the prompt template. Any label other than
            the four listed below results in ``None`` being passed as the prompt.

    Returns:
        A Markdown string with "Question", "Answer" and "Sources" sections;
        each source is rendered as a ``- title (year)`` bullet.
    """
    pipeline = get_rag_pipeline(study_name)

    # Map the UI label to its prompt template; unknown labels fall back to None.
    prompts_by_label = {
        "Highlight": highlight_prompt,
        "Evidence-based": evidence_based_prompt,
        "Study Characteristics": study_characteristics_prompt,
        "Vaccine Coverage": vaccine_coverage_prompt,
    }
    selected_prompt = prompts_by_label.get(prompt_type)

    result = pipeline.query(question, selected_prompt)

    # Assemble the Markdown: fixed sections first, then one bullet per source.
    header = (
        f"## Question\n\n{result['question']}\n\n"
        f"## Answer\n\n{result['answer']}\n\n"
        "## Sources\n\n"
    )
    bullets = [f"- {src['title']} ({src['year']})\n" for src in result["sources"]]
    return header + "".join(bullets)


def get_study_info(study_name):
    """Summarise a study's JSON file as a short Markdown snippet.

    Args:
        study_name: Key looked up in ``STUDY_FILES`` to find the JSON file.

    Returns:
        Markdown giving the number of top-level entries in the file and the
        ``title`` of the first one, or ``"Invalid study name"`` when the study
        has no (truthy) entry in ``STUDY_FILES``. When the file holds no
        documents, the count is reported as 0 and the title as ``N/A``.

    Raises:
        OSError: If the study file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    study_file = STUDY_FILES.get(study_name)
    if not study_file:
        return "Invalid study name"

    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(study_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # An empty file has no first document; report that instead of raising IndexError.
    if not data:
        return "**Number of documents:** 0\n\n**First document title:** N/A"

    # NOTE(review): assumes the file is a JSON array of objects with a "title" key — confirm.
    return f"**Number of documents:** {len(data)}\n\n**First document title:** {data[0]['title']}"


def update_sample_questions(study_name):
    """Build an interactive Dropdown listing the sample questions for a study.

    Studies without an entry in ``sample_questions`` get an empty choice list.
    """
    questions = sample_questions.get(study_name, [])
    return gr.Dropdown(choices=questions, interactive=True)


# Build the Gradio UI. Components are declared in the order they should appear.
with gr.Blocks() as demo:
    gr.Markdown("# RAG Pipeline Demo")

    # Row 1: study selector next to a Markdown panel summarising the chosen study.
    with gr.Row():
        study_dropdown = gr.Dropdown(
            choices=list(STUDY_FILES.keys()), label="Select Study"
        )
        study_info = gr.Markdown(label="Study Information")

    # Refresh the summary panel whenever a different study is selected.
    study_dropdown.change(get_study_info, inputs=[study_dropdown], outputs=[study_info])

    # Row 2: free-text question box next to a dropdown of sample questions.
    with gr.Row():
        question_input = gr.Textbox(label="Enter your question")
        sample_question_dropdown = gr.Dropdown(
            choices=[], label="Sample Questions", interactive=True
        )

    # Selecting a study also repopulates the sample-question dropdown for that study.
    study_dropdown.change(
        update_sample_questions,
        inputs=[study_dropdown],
        outputs=[sample_question_dropdown],
    )
    # Picking a sample question copies it verbatim into the question textbox.
    sample_question_dropdown.change(
        lambda x: x, inputs=[sample_question_dropdown], outputs=[question_input]
    )

    # Prompt template selector. "Default" matches none of the labels handled in
    # query_rag, so it results in no custom prompt (None) being passed.
    prompt_type = gr.Radio(
        [
            "Default",
            "Highlight",
            "Evidence-based",
            "Study Characteristics",
            "Vaccine Coverage",
        ],
        label="Prompt Type",
        value="Default",
    )

    submit_button = gr.Button("Submit")

    # Rendered Markdown returned by query_rag.
    answer_output = gr.Markdown(label="Answer")

    # Submit runs the RAG query with the selected study, question and prompt type.
    submit_button.click(
        query_rag,
        inputs=[study_dropdown, question_input, prompt_type],
        outputs=[answer_output],
    )

# Start the app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link for the app.
    demo.launch(share=True)