File size: 4,353 Bytes
6ff89e0
daee42b
c6040d0
cfb1a62
660cea6
cfb1a62
9cb07f9
669d93a
183168e
669d93a
 
 
 
 
 
 
 
 
daee42b
 
c6040d0
52a64e1
 
 
669d93a
660cea6
 
 
 
 
c6040d0
d377a8f
daee42b
 
cfb1a62
 
 
 
 
52a64e1
cfb1a62
 
daee42b
 
52a64e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daee42b
cfb1a62
52a64e1
 
660cea6
 
52a64e1
 
 
 
 
660cea6
 
52a64e1
 
 
 
660cea6
 
52a64e1
660cea6
 
 
 
 
52a64e1
 
660cea6
 
 
 
 
 
daee42b
c6040d0
52a64e1
 
c6040d0
daee42b
c6040d0
52a64e1
 
c6040d0
 
 
 
 
 
 
cfb1a62
660cea6
 
 
c6040d0
 
 
52a64e1
c6040d0
52a64e1
660cea6
daee42b
52a64e1
 
 
 
 
daee42b
d762ede
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import gradio as gr
from rag.rag_pipeline import RAGPipeline
from utils.prompts import highlight_prompt, evidence_based_prompt, sample_questions
from config import STUDY_FILES
import json

# Module-level cache mapping study name -> RAGPipeline instance, so each
# study's pipeline is built once and reused across chat turns.
rag_cache = {}


def get_rag_pipeline(study_name):
    """Return the RAGPipeline for *study_name*, building and caching it on first use.

    Args:
        study_name: Key into the STUDY_FILES mapping.

    Returns:
        The cached (or newly created) RAGPipeline.

    Raises:
        ValueError: If *study_name* has no corresponding study file.
    """
    cached = rag_cache.get(study_name)
    if cached is not None:
        return cached

    study_file = STUDY_FILES.get(study_name)
    if not study_file:
        raise ValueError(f"Invalid study name: {study_name}")

    pipeline = RAGPipeline(study_file)
    rag_cache[study_name] = pipeline
    return pipeline


def chat_function(message, history, study_name, prompt_type):
    """Answer *message* using the selected study's RAG pipeline.

    Args:
        message: The user's query text.
        history: Chat history (accepted for the Gradio callback signature;
            not used by the pipeline itself).
        study_name: Which study corpus to query.
        prompt_type: "Highlight", "Evidence-based", or anything else for the
            pipeline's default prompt.

    Returns:
        The response text, or a validation message for blank input.
    """
    if not message.strip():
        return "Please enter a valid query."

    pipeline = get_rag_pipeline(study_name)

    if prompt_type == "Highlight":
        template = highlight_prompt
    elif prompt_type == "Evidence-based":
        template = evidence_based_prompt
    else:
        template = None  # fall through to the pipeline's default prompt

    result = pipeline.query(message, prompt_template=template)
    return result.response


def get_study_info(study_name):
    """Return a short human-readable summary of a study's document file.

    Args:
        study_name: Key into the STUDY_FILES mapping.

    Returns:
        A string with the document count and the first document's title,
        or "Invalid study name" when the study is unknown.
    """
    study_file = STUDY_FILES.get(study_name)
    if not study_file:
        return "Invalid study name"

    # Explicit encoding: study files are JSON and should be read as UTF-8
    # regardless of the platform's default locale encoding.
    with open(study_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Guard against an empty corpus: the original data[0]["title"] access
    # raised IndexError on an empty list and KeyError on a title-less doc.
    if not data:
        return "Number of documents: 0"
    first_title = data[0].get("title", "Unknown")
    return f"Number of documents: {len(data)}\nFirst document title: {first_title}"


def update_interface(study_name):
    """Refresh the study-info panel and up to three sample-question buttons.

    Args:
        study_name: The newly selected study.

    Returns:
        A tuple of (info text, three gr.update objects) — one visible update
        per available sample question, hidden updates for the remainder.
    """
    questions = sample_questions.get(study_name, [])[:3]
    shown = [gr.update(visible=True, value=q) for q in questions]
    hidden = [gr.update(visible=False) for _ in range(3 - len(questions))]
    return (get_study_info(study_name), *shown, *hidden)


def set_question(question):
    """Identity passthrough used to copy a sample-question button's label into the message box."""
    return question


# UI definition: a two-column layout — chat on the left, study selection and
# settings on the right — with event wiring for submit/click/clear.
with gr.Blocks() as demo:
    gr.Markdown("# ACRES RAG Platform")

    with gr.Row():
        with gr.Column(scale=2):
            # Main chat area plus a message box, send button, and a collapsed
            # accordion of sample-question buttons (populated per study).
            chatbot = gr.Chatbot(elem_id="chatbot", show_label=False, height=400)
            with gr.Row():
                msg = gr.Textbox(
                    show_label=False,
                    placeholder="Type your message here...",
                    scale=4,
                    lines=1,
                    autofocus=True,
                )
                send_btn = gr.Button("Send", scale=1)
            with gr.Accordion("Sample Questions", open=False):
                # Buttons start hidden; update_interface() shows one per
                # available sample question for the selected study.
                sample_btn1 = gr.Button("Sample Question 1", visible=False)
                sample_btn2 = gr.Button("Sample Question 2", visible=False)
                sample_btn3 = gr.Button("Sample Question 3", visible=False)

        with gr.Column(scale=1):
            gr.Markdown("### Study Information")
            study_dropdown = gr.Dropdown(
                choices=list(STUDY_FILES.keys()),
                label="Select Study",
                # Default to the first configured study.
                value=list(STUDY_FILES.keys())[0],
            )
            study_info = gr.Textbox(label="Study Details", lines=4)
            gr.Markdown("### Settings")
            prompt_type = gr.Radio(
                ["Default", "Highlight", "Evidence-based"],
                label="Prompt Type",
                value="Default",
            )
            clear = gr.Button("Clear Chat")

    def user(user_message, history):
        """Stage 1 of a chat turn: clear the textbox and append the user's
        message to history with a None placeholder for the bot reply."""
        if not user_message.strip():
            return "", history  # Return unchanged if the message is empty
        return "", history + [[user_message, None]]

    def bot(history, study_name, prompt_type):
        """Stage 2 of a chat turn: fill in the bot reply for the last
        history entry via chat_function()."""
        if not history:
            return history
        user_message = history[-1][0]
        bot_message = chat_function(user_message, history, study_name, prompt_type)
        history[-1][1] = bot_message
        return history

    # Both Enter-in-textbox and the Send button run the same two-stage
    # pipeline: user() (queue=False so the textbox clears instantly),
    # then bot() to generate the answer.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, study_dropdown, prompt_type], chatbot
    )
    send_btn.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, study_dropdown, prompt_type], chatbot
    )
    # Clearing the chat just resets the Chatbot component's value.
    clear.click(lambda: None, None, chatbot, queue=False)

    # Switching studies refreshes the info panel and the sample-question
    # buttons together (update_interface returns one value per output).
    study_dropdown.change(
        fn=update_interface,
        inputs=study_dropdown,
        outputs=[study_info, sample_btn1, sample_btn2, sample_btn3],
    )

    # Clicking a sample-question button copies its label into the textbox;
    # the button component itself is passed as the input, so its current
    # value (the question text) becomes the textbox value.
    sample_btn1.click(set_question, inputs=[sample_btn1], outputs=[msg])
    sample_btn2.click(set_question, inputs=[sample_btn2], outputs=[msg])
    sample_btn3.click(set_question, inputs=[sample_btn3], outputs=[msg])


if __name__ == "__main__":
    # share=True exposes a public Gradio link; debug=True enables verbose
    # error output — both are development settings.
    demo.launch(share=True, debug=True)