File size: 2,784 Bytes
6ff89e0
cfb1a62
daee42b
cfb1a62
042c079
6a076b8
 
d762ede
cfb1a62
 
9cb07f9
669d93a
183168e
669d93a
 
 
 
 
 
 
 
 
daee42b
 
5f52091
669d93a
daee42b
 
cfb1a62
 
 
6ff89e0
cfb1a62
daee42b
9a9bac9
d762ede
6a076b8
d377a8f
daee42b
 
cfb1a62
 
 
 
 
d0d7d0e
cfb1a62
 
daee42b
 
6a076b8
9cb07f9
6a076b8
 
daee42b
cfb1a62
daee42b
cfb1a62
 
 
 
6a076b8
cfb1a62
 
daee42b
cfb1a62
 
9cb07f9
 
 
6a076b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daee42b
cfb1a62
 
d0d7d0e
daee42b
cfb1a62
 
 
 
 
daee42b
 
d762ede
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
import json
from rag.rag_pipeline import RAGPipeline
from utils.prompts import highlight_prompt, evidence_based_prompt
from utils.prompts import (
    sample_questions,
)

from config import STUDY_FILES

# Pipelines already built, keyed by study name, so each study is loaded once.
rag_cache = {}


def get_rag_pipeline(study_name):
    """Return the RAG pipeline for *study_name*, building it on first use.

    A pipeline found in ``rag_cache`` is returned directly. Otherwise the
    study's file is looked up in ``STUDY_FILES``, a ``RAGPipeline`` is built
    from it, stored in the cache and returned.

    Raises:
        ValueError: if ``STUDY_FILES`` has no (truthy) entry for the name.
    """
    if study_name in rag_cache:
        return rag_cache[study_name]

    source_path = STUDY_FILES.get(study_name)
    if not source_path:
        raise ValueError(f"Invalid study name: {study_name}")

    pipeline = RAGPipeline(source_path)
    rag_cache[study_name] = pipeline
    return pipeline


def query_rag(study_name: str, question: str, prompt_type: str) -> str:
    """Answer *question* against the selected study.

    ``prompt_type`` picks the prompt template: "Highlight" and
    "Evidence-based" map to their respective templates; any other value
    passes ``None`` as the template.

    Returns the ``response`` attribute of the pipeline's query result.
    """
    pipeline = get_rag_pipeline(study_name)

    # Start from "no template" and override only for the two known types.
    template = None
    if prompt_type == "Highlight":
        template = highlight_prompt
    elif prompt_type == "Evidence-based":
        template = evidence_based_prompt

    answer = pipeline.query(question, prompt_template=template)
    return answer.response


def get_study_info(study_name):
    """Return a Markdown summary of the JSON file behind *study_name*.

    The summary reports how many top-level entries the file holds and the
    ``title`` of the first one. If ``STUDY_FILES`` has no (truthy) entry for
    the name, the string "Invalid study name" is returned instead. A file
    that loads to an empty collection yields a zero count with no title line
    rather than failing on ``data[0]``.
    """
    study_file = STUDY_FILES.get(study_name)
    if not study_file:
        return "Invalid study name"

    # JSON text is UTF-8; decode explicitly rather than relying on the
    # platform's default locale encoding.
    with open(study_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not data:
        # Nothing to take a first title from.
        return "**Number of documents:** 0"

    return f"**Number of documents:** {len(data)}\n\n**First document title:** {data[0]['title']}"


def update_sample_questions(study_name):
    """Build an interactive dropdown holding the sample questions for a study.

    Studies with no entry in ``sample_questions`` get an empty choice list.
    """
    questions = sample_questions.get(study_name, [])
    return gr.Dropdown(choices=questions, interactive=True)


# Build the UI. Components are declared inside the Blocks/Row contexts in the
# order they appear below; event handlers are wired up next to the components
# they connect.
with gr.Blocks() as demo:
    gr.Markdown("# RAG Pipeline Demo")

    # Row 1: study selector next to a panel describing the selected study.
    with gr.Row():
        study_dropdown = gr.Dropdown(
            choices=list(STUDY_FILES.keys()), label="Select Study"
        )
        study_info = gr.Markdown(label="Study Information")

    # Changing the study refreshes the information panel.
    study_dropdown.change(get_study_info, inputs=[study_dropdown], outputs=[study_info])

    # Row 2: free-text question next to a dropdown of canned questions.
    with gr.Row():
        question_input = gr.Textbox(label="Enter your question")
        sample_question_dropdown = gr.Dropdown(
            choices=[], label="Sample Questions", interactive=True
        )

    # Changing the study also repopulates the sample-question dropdown.
    study_dropdown.change(
        update_sample_questions,
        inputs=[study_dropdown],
        outputs=[sample_question_dropdown],
    )
    # Picking a sample question copies it unchanged into the question textbox.
    sample_question_dropdown.change(
        lambda x: x, inputs=[sample_question_dropdown], outputs=[question_input]
    )

    # "Highlight" and "Evidence-based" are the values query_rag matches on;
    # "Default" falls through to its else branch (no prompt template).
    prompt_type = gr.Radio(
        [
            "Default",
            "Highlight",
            "Evidence-based",
        ],
        label="Prompt Type",
        value="Default",
    )

    submit_button = gr.Button("Submit")

    answer_output = gr.Markdown(label="Answer")

    # Submit runs the query with the selected study, question and prompt type
    # and writes the returned text to the answer panel.
    submit_button.click(
        query_rag,
        inputs=[study_dropdown, question_input, prompt_type],
        outputs=[answer_output],
    )

# Launch only when run as a script, with sharing and debug mode turned on.
# NOTE(review): share=True presumably exposes the app through a public Gradio
# link — confirm that is intended outside of demos.
if __name__ == "__main__":
    demo.launch(share=True, debug=True)