File size: 4,120 Bytes
9cb07f9
 
6ff89e0
cfb1a62
daee42b
cfb1a62
042c079
6a076b8
 
 
 
cfb1a62
 
9cb07f9
669d93a
183168e
669d93a
 
 
 
 
 
 
 
 
daee42b
 
cfb1a62
669d93a
daee42b
9f2191f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daee42b
cfb1a62
 
 
6a076b8
 
 
 
6ff89e0
cfb1a62
daee42b
9f2191f
6a076b8
 
 
 
 
d0d7d0e
 
daee42b
 
cfb1a62
 
 
 
 
d0d7d0e
cfb1a62
 
daee42b
 
6a076b8
9cb07f9
6a076b8
 
daee42b
cfb1a62
daee42b
cfb1a62
 
 
 
6a076b8
cfb1a62
 
daee42b
cfb1a62
 
9cb07f9
 
 
6a076b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daee42b
cfb1a62
 
d0d7d0e
daee42b
cfb1a62
 
 
 
 
daee42b
 
9cb07f9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# app.py

import gradio as gr
import json
from rag.rag_pipeline import RAGPipeline
from utils.prompts import highlight_prompt, evidence_based_prompt
from utils.prompts import (
    study_characteristics_prompt,
    vaccine_coverage_prompt,
    sample_questions,
)
from config import STUDY_FILES

# Process-wide memo of study name -> constructed RAGPipeline; filled lazily by
# get_rag_pipeline() so each study's pipeline is only built on first request.
rag_cache = dict()


def get_rag_pipeline(study_name):
    """Return the RAGPipeline for ``study_name``, constructing it on first use.

    The pipeline is built from the path ``STUDY_FILES`` maps the study name to
    and memoised in ``rag_cache`` so subsequent calls reuse the same object.

    Raises:
        ValueError: if ``study_name`` has no truthy entry in ``STUDY_FILES``.
    """
    if study_name in rag_cache:
        return rag_cache[study_name]

    source_path = STUDY_FILES.get(study_name)
    if not source_path:
        raise ValueError(f"Invalid study name: {study_name}")

    pipeline = RAGPipeline(source_path)
    rag_cache[study_name] = pipeline
    return pipeline


def query_rag(study_name, question, prompt_type):
    """Answer ``question`` against the selected study and render it as Markdown.

    Args:
        study_name: Key into ``STUDY_FILES`` chosen in the study dropdown; may
            be ``None`` or empty when nothing has been selected yet.
        question: Free-text question from the textbox.
        prompt_type: Label from the "Prompt Type" radio. "Highlight",
            "Evidence-based", "Study Characteristics" and "Vaccine Coverage"
            select the matching prompt template; any other value (including
            "Default") passes ``None`` as the prompt (presumably meaning the
            pipeline's own default prompt -- confirm in RAGPipeline.query).

    Returns:
        A Markdown string with Question, Answer and Sources sections, or a
        short instruction message when no study or no question was given.

    Raises:
        ValueError: propagated from ``get_rag_pipeline`` for an unknown study.
    """
    # Guard the UI callback: the study dropdown starts with no value, and
    # without these checks a submit either raises on study ``None`` or spends
    # a full RAG query on an empty question.
    if not study_name:
        return "Please select a study before submitting a question."
    if not question or not question.strip():
        return "Please enter a question."

    rag = get_rag_pipeline(study_name)

    # Keyword arguments forwarded to rag.query. Only ``query_str`` carries real
    # data; the others are blank placeholders (presumably template variables
    # used by the prompts in utils.prompts -- confirm there).
    prompt_params = {
        "studyid": "",  # TODO: retrieve or generate a study ID?
        "author": "",
        "year": "",
        "title": "",
        "appendix": "",
        "publication_type": "",
        "study_design": "",
        "study_area_region": "",
        "study_population": "",
        "immunisable_disease": "",
        "route_of_administration": "",
        "duration_of_study": "",
        "duration_covid19": "",
        "study_comments": "",
        "coverage_rates": "",
        "proportion_recommended_age": "",
        "immunisation_uptake": "",
        "drop_out_rates": "",
        "intentions_to_vaccinate": "",
        "vaccine_confidence": "",
        "query_str": question,
    }

    # Radio label -> prompt template; unknown labels (e.g. "Default") map to None.
    prompt = {
        "Highlight": highlight_prompt,
        "Evidence-based": evidence_based_prompt,
        "Study Characteristics": study_characteristics_prompt,
        "Vaccine Coverage": vaccine_coverage_prompt,
    }.get(prompt_type)

    response = rag.query(question, prompt, **prompt_params)

    # One bullet per source, assembled with a single join rather than
    # repeated string concatenation.
    source_lines = "".join(
        f"- {source['title']} ({source['year']})\n" for source in response["sources"]
    )
    return (
        f"## Question\n\n{response['question']}\n\n"
        f"## Answer\n\n{response['answer']}\n\n"
        "## Sources\n\n"
        f"{source_lines}"
    )


def get_study_info(study_name):
    """Return a short Markdown summary of the study's JSON data file.

    Args:
        study_name: Key into ``STUDY_FILES``; may be ``None`` when the
            dropdown has no selection.

    Returns:
        Markdown giving the number of top-level entries in the JSON file and,
        when there is at least one, the ``title`` of the first entry; or
        "Invalid study name" when ``study_name`` has no file mapped to it.
    """
    study_file = STUDY_FILES.get(study_name)
    if not study_file:
        return "Invalid study name"

    # JSON text is UTF-8 by spec; don't depend on the platform locale codec.
    with open(study_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # NOTE(review): assumes the JSON root is a list of objects with a "title"
    # key -- confirm against the files listed in STUDY_FILES.
    summary = f"**Number of documents:** {len(data)}"
    if not data:
        # An empty file has no first document; data[0] would raise IndexError.
        return summary
    return f"{summary}\n\n**First document title:** {data[0]['title']}"


def update_sample_questions(study_name):
    """Return a refreshed, interactive sample-question dropdown for a study.

    The choices come from ``sample_questions[study_name]``, falling back to
    an empty list when the study has no canned questions.
    """
    study_questions = sample_questions.get(study_name, [])
    return gr.Dropdown(choices=study_questions, interactive=True)


# Gradio UI. Components are laid out in the order they are created inside
# this context, so statement order below defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("# RAG Pipeline Demo")

    # Study selector next to a Markdown panel summarising the chosen study.
    with gr.Row():
        study_dropdown = gr.Dropdown(
            choices=list(STUDY_FILES.keys()), label="Select Study"
        )
        study_info = gr.Markdown(label="Study Information")

    # Changing the study refreshes the summary panel.
    study_dropdown.change(get_study_info, inputs=[study_dropdown], outputs=[study_info])

    # Free-text question box next to a dropdown of canned sample questions.
    with gr.Row():
        question_input = gr.Textbox(label="Enter your question")
        sample_question_dropdown = gr.Dropdown(
            choices=[], label="Sample Questions", interactive=True
        )

    # Changing the study also swaps in that study's sample questions...
    study_dropdown.change(
        update_sample_questions,
        inputs=[study_dropdown],
        outputs=[sample_question_dropdown],
    )
    # ...and picking a sample question copies it verbatim into the textbox.
    sample_question_dropdown.change(
        lambda x: x, inputs=[sample_question_dropdown], outputs=[question_input]
    )

    # Prompt template selector; these labels must match the prompt_type
    # values that query_rag recognises ("Default" selects no template).
    prompt_type = gr.Radio(
        [
            "Default",
            "Highlight",
            "Evidence-based",
            "Study Characteristics",
            "Vaccine Coverage",
        ],
        label="Prompt Type",
        value="Default",
    )

    submit_button = gr.Button("Submit")

    answer_output = gr.Markdown(label="Answer")

    # Submit runs the RAG query and renders the returned Markdown.
    submit_button.click(
        query_rag,
        inputs=[study_dropdown, question_input, prompt_type],
        outputs=[answer_output],
    )

if __name__ == "__main__":
    # share=True asks Gradio to also create a publicly reachable share link.
    # NOTE(review): confirm exposing this demo publicly is intended.
    demo.launch(share=True)