import gradio as gr
import json
from rag.rag_pipeline import RAGPipeline
from utils.prompts import (
    highlight_prompt,
    evidence_based_prompt,
    study_characteristics_prompt,
    vaccine_coverage_prompt,
    sample_questions,
)
from config import STUDY_FILES

# Cache for RAG pipelines
rag_cache = {}


def get_rag_pipeline(study_name):
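    """Return a cached RAGPipeline for `study_name`, building it on first use.

    Raises ValueError if the study name is not listed in STUDY_FILES.
    """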
    if study_name not in rag_cache:
        study_file = STUDY_FILES.get(study_name)
        if study_file:
            rag_cache[study_name] = RAGPipeline(study_file)
        else:
            raise ValueError(f"Invalid study name: {study_name}")
    return rag_cache[study_name]


def query_rag(study_name: str, question: str, prompt_type: str) -> str:
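    """Answer a question about the selected study.

    Returns a Markdown string containing the answer, its sources, and the
    study information extracted by the RAG pipeline.
    """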
    rag = get_rag_pipeline(study_name)

    # Extract study information using RAG
    study_info = rag.extract_study_info()

    if prompt_type == "Highlight":
        prompt = highlight_prompt
    elif prompt_type == "Evidence-based":
        prompt = evidence_based_prompt
    elif prompt_type == "Study Characteristics":
        prompt = study_characteristics_prompt
    elif prompt_type == "Vaccine Coverage":
        prompt = vaccine_coverage_prompt
    else:
        prompt = None

    # Prepare the context with study info
    context = "Study Information:\n"
    for key, value in study_info.items():
        context += f"{key}: {value}\n"
    context += "\n"

    # Add the question to the context
    context += f"Question: {question}\n"

    # Use the prepared context in the query
    response = rag.query(context, prompt_template=prompt)

    # Format the response as Markdown
    formatted_response = f"## Question\n\n{question}\n\n## Answer\n\n{response['answer']}\n\n## Sources\n\n"
    for source in response["sources"]:
        formatted_response += (
            f"- {source['title']} ({source.get('year', 'Year not specified')})\n"
        )

    # Add extracted study information to the response
    formatted_response += "\n## Extracted Study Information\n\n"
    for key, value in study_info.items():
        formatted_response += f"- **{key.replace('_', ' ').title()}**: {value}\n"

    return formatted_response


def get_study_info(study_name):
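    """Load the study's JSON file and return a short Markdown summary
    (document count and the title of the first document)."""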
    study_file = STUDY_FILES.get(study_name)
    if study_file:
        with open(study_file, "r") as f:
            data = json.load(f)
        return f"**Number of documents:** {len(data)}\n\n**First document title:** {data[0]['title']}"
    else:
        return "Invalid study name"


def update_sample_questions(study_name):
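    """Refresh the sample-question dropdown with questions for the selected study."""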
    return gr.Dropdown(choices=sample_questions.get(study_name, []), interactive=True)


with gr.Blocks() as demo:
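    # UI layout: study selector with an info panel, question inputs,
    # prompt-type radio, submit button, and a Markdown answer panel.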
    gr.Markdown("# RAG Pipeline Demo")

    with gr.Row():
        study_dropdown = gr.Dropdown(
            choices=list(STUDY_FILES.keys()), label="Select Study"
        )
        study_info = gr.Markdown(label="Study Information")

    study_dropdown.change(get_study_info, inputs=[study_dropdown], outputs=[study_info])

    with gr.Row():
        question_input = gr.Textbox(label="Enter your question")
        sample_question_dropdown = gr.Dropdown(
            choices=[], label="Sample Questions", interactive=True
        )

    study_dropdown.change(
        update_sample_questions,
        inputs=[study_dropdown],
        outputs=[sample_question_dropdown],
    )
    sample_question_dropdown.change(
        lambda x: x, inputs=[sample_question_dropdown], outputs=[question_input]
    )

    prompt_type = gr.Radio(
        [
            "Default",
            "Highlight",
            "Evidence-based",
            "Study Characteristics",
            "Vaccine Coverage",
        ],
        label="Prompt Type",
        value="Default",
    )

    submit_button = gr.Button("Submit")

    answer_output = gr.Markdown(label="Answer")

    submit_button.click(
        query_rag,
        inputs=[study_dropdown, question_input, prompt_type],
        outputs=[answer_output],
    )

if __name__ == "__main__":
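    # share=True additionally serves the app through a temporary public Gradio link.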
    demo.launch(share=True)