Spaces:
Running
Running
custom prompts
Browse files- app.py +41 -11
- rag/rag_pipeline.py +9 -2
- requirements.txt +2 -1
- utils/prompts.py +98 -0
app.py
CHANGED
@@ -2,9 +2,13 @@ import gradio as gr
|
|
2 |
import json
|
3 |
from rag.rag_pipeline import RAGPipeline
|
4 |
from utils.prompts import highlight_prompt, evidence_based_prompt
|
|
|
|
|
|
|
|
|
|
|
5 |
from config import STUDY_FILES
|
6 |
|
7 |
-
# Cache for RAG pipelines
|
8 |
rag_cache = {}
|
9 |
|
10 |
|
@@ -25,13 +29,19 @@ def query_rag(study_name, question, prompt_type):
|
|
25 |
prompt = highlight_prompt
|
26 |
elif prompt_type == "Evidence-based":
|
27 |
prompt = evidence_based_prompt
|
|
|
|
|
|
|
|
|
28 |
else:
|
29 |
prompt = None
|
30 |
|
31 |
response = rag.query(question, prompt)
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
35 |
|
36 |
return formatted_response
|
37 |
|
@@ -46,6 +56,10 @@ def get_study_info(study_name):
|
|
46 |
return "Invalid study name"
|
47 |
|
48 |
|
|
|
|
|
|
|
|
|
49 |
with gr.Blocks() as demo:
|
50 |
gr.Markdown("# RAG Pipeline Demo")
|
51 |
|
@@ -53,21 +67,37 @@ with gr.Blocks() as demo:
|
|
53 |
study_dropdown = gr.Dropdown(
|
54 |
choices=list(STUDY_FILES.keys()), label="Select Study"
|
55 |
)
|
56 |
-
study_info = gr.
|
57 |
|
58 |
study_dropdown.change(get_study_info, inputs=[study_dropdown], outputs=[study_info])
|
59 |
|
60 |
with gr.Row():
|
61 |
question_input = gr.Textbox(label="Enter your question")
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
submit_button = gr.Button("Submit")
|
69 |
|
70 |
-
# answer_output = gr.Textbox(label="Answer")
|
71 |
answer_output = gr.Markdown(label="Answer")
|
72 |
|
73 |
submit_button.click(
|
|
|
2 |
import json
|
3 |
from rag.rag_pipeline import RAGPipeline
|
4 |
from utils.prompts import highlight_prompt, evidence_based_prompt
|
5 |
+
from utils.custom_prompts import (
|
6 |
+
study_characteristics_prompt,
|
7 |
+
vaccine_coverage_prompt,
|
8 |
+
sample_questions,
|
9 |
+
)
|
10 |
from config import STUDY_FILES
|
11 |
|
|
|
12 |
rag_cache = {}
|
13 |
|
14 |
|
|
|
29 |
prompt = highlight_prompt
|
30 |
elif prompt_type == "Evidence-based":
|
31 |
prompt = evidence_based_prompt
|
32 |
+
elif prompt_type == "Study Characteristics":
|
33 |
+
prompt = study_characteristics_prompt
|
34 |
+
elif prompt_type == "Vaccine Coverage":
|
35 |
+
prompt = vaccine_coverage_prompt
|
36 |
else:
|
37 |
prompt = None
|
38 |
|
39 |
response = rag.query(question, prompt)
|
40 |
+
|
41 |
+
# Format the response as Markdown
|
42 |
+
formatted_response = f"## Question\n\n{response['question']}\n\n## Answer\n\n{response['answer']}\n\n## Sources\n\n"
|
43 |
+
for source in response["sources"]:
|
44 |
+
formatted_response += f"- {source['title']} ({source['year']})\n"
|
45 |
|
46 |
return formatted_response
|
47 |
|
|
|
56 |
return "Invalid study name"
|
57 |
|
58 |
|
59 |
+
def update_sample_questions(study_name):
    """Refresh the sample-question dropdown for the selected study.

    Args:
        study_name: Key looked up in ``sample_questions``.

    Returns:
        A Gradio update payload whose ``choices`` are the questions registered
        for ``study_name``, or an empty list when the study has none.
    """
    # ``gr.Dropdown.update`` was removed in Gradio 4 and requirements.txt does
    # not pin a Gradio version; ``gr.update`` works on both 3.x and 4.x.
    return gr.update(choices=sample_questions.get(study_name, []))
|
61 |
+
|
62 |
+
|
63 |
with gr.Blocks() as demo:
|
64 |
gr.Markdown("# RAG Pipeline Demo")
|
65 |
|
|
|
67 |
study_dropdown = gr.Dropdown(
|
68 |
choices=list(STUDY_FILES.keys()), label="Select Study"
|
69 |
)
|
70 |
+
study_info = gr.Markdown(label="Study Information")
|
71 |
|
72 |
study_dropdown.change(get_study_info, inputs=[study_dropdown], outputs=[study_info])
|
73 |
|
74 |
with gr.Row():
|
75 |
question_input = gr.Textbox(label="Enter your question")
|
76 |
+
sample_question_dropdown = gr.Dropdown(choices=[], label="Sample Questions")
|
77 |
+
|
78 |
+
study_dropdown.change(
|
79 |
+
update_sample_questions,
|
80 |
+
inputs=[study_dropdown],
|
81 |
+
outputs=[sample_question_dropdown],
|
82 |
+
)
|
83 |
+
sample_question_dropdown.change(
|
84 |
+
lambda x: x, inputs=[sample_question_dropdown], outputs=[question_input]
|
85 |
+
)
|
86 |
+
|
87 |
+
prompt_type = gr.Radio(
|
88 |
+
[
|
89 |
+
"Default",
|
90 |
+
"Highlight",
|
91 |
+
"Evidence-based",
|
92 |
+
"Study Characteristics",
|
93 |
+
"Vaccine Coverage",
|
94 |
+
],
|
95 |
+
label="Prompt Type",
|
96 |
+
value="Default",
|
97 |
+
)
|
98 |
|
99 |
submit_button = gr.Button("Submit")
|
100 |
|
|
|
101 |
answer_output = gr.Markdown(label="Answer")
|
102 |
|
103 |
submit_button.click(
|
rag/rag_pipeline.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
# rag/rag_pipeline.py
|
2 |
|
3 |
import json
|
|
|
4 |
from llama_index.core import Document, VectorStoreIndex
|
5 |
from llama_index.core.node_parser import SentenceWindowNodeParser, SentenceSplitter
|
6 |
from llama_index.core import PromptTemplate
|
@@ -58,7 +59,9 @@ class RAGPipeline:
|
|
58 |
nodes = node_parser.get_nodes_from_documents(self.documents)
|
59 |
self.index = VectorStoreIndex(nodes)
|
60 |
|
61 |
-
def query(
|
|
|
|
|
62 |
self.build_index() # This will only build the index if it hasn't been built yet
|
63 |
|
64 |
if prompt_template is None:
|
@@ -79,4 +82,8 @@ class RAGPipeline:
|
|
79 |
)
|
80 |
response = query_engine.query(question)
|
81 |
|
82 |
-
return
|
|
|
|
|
|
|
|
|
|
1 |
# rag/rag_pipeline.py
|
2 |
|
3 |
import json
|
4 |
+
from typing import Dict, Any
|
5 |
from llama_index.core import Document, VectorStoreIndex
|
6 |
from llama_index.core.node_parser import SentenceWindowNodeParser, SentenceSplitter
|
7 |
from llama_index.core import PromptTemplate
|
|
|
59 |
nodes = node_parser.get_nodes_from_documents(self.documents)
|
60 |
self.index = VectorStoreIndex(nodes)
|
61 |
|
62 |
+
def query(
|
63 |
+
self, question: str, prompt_template: PromptTemplate = None
|
64 |
+
) -> Dict[str, Any]:
|
65 |
self.build_index() # This will only build the index if it hasn't been built yet
|
66 |
|
67 |
if prompt_template is None:
|
|
|
82 |
)
|
83 |
response = query_engine.query(question)
|
84 |
|
85 |
+
return {
|
86 |
+
"question": question,
|
87 |
+
"answer": response.response,
|
88 |
+
"sources": [node.metadata for node in response.source_nodes],
|
89 |
+
}
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
gradio
|
2 |
llama-index
|
3 |
openai
|
4 |
-
pandas
|
|
|
|
1 |
gradio
|
2 |
llama-index
|
3 |
openai
|
4 |
+
pandas
|
5 |
+
pydantic
|
utils/prompts.py
CHANGED
@@ -1,4 +1,102 @@
|
|
1 |
from llama_index.core import PromptTemplate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
highlight_prompt = PromptTemplate(
|
4 |
"Context information is below.\n"
|
|
|
1 |
from llama_index.core import PromptTemplate
|
2 |
+
from typing import Optional, List
|
3 |
+
from pydantic import BaseModel, Field
|
4 |
+
from llama_index.core.prompts import PromptTemplate
|
5 |
+
|
6 |
+
|
7 |
+
class StudyCharacteristics(BaseModel):
    """Schema for the study-level fields listed in ``study_characteristics_prompt``.

    Field names mirror the upper-case labels used in that prompt.
    """

    STUDYID: str
    AUTHOR: str
    YEAR: int
    TITLE: str
    # Explicit ``None`` default: Pydantic v2 treats an ``Optional`` annotation
    # without a default as required, whereas v1 silently defaulted it to None.
    APPENDIX: Optional[str] = None
    PUBLICATION_TYPE: str
    STUDY_DESIGN: str
    STUDY_AREA_REGION: str
    STUDY_POPULATION: str
    IMMUNISABLE_DISEASE_UNDER_STUDY: str
    ROUTE_OF_VACCINE_ADMINISTRATION: str
    DURATION_OF_STUDY: str
    DURATION_IN_RELATION_TO_COVID19: str
    # Same reasoning as APPENDIX: may be omitted on either Pydantic major version.
    STUDY_COMMENTS: Optional[str] = None
|
22 |
+
|
23 |
+
|
24 |
+
class VaccineCoverageVariables(BaseModel):
    """Schema for the vaccine-coverage fields listed in ``vaccine_coverage_prompt``.

    Every numeric field is required and constrained to the inclusive range
    0-100, matching the prompt's "percentages as floats between 0 and 100".
    """

    STUDYID: str
    AUTHOR: str
    YEAR: int
    TITLE: str
    # NOTE(review): the prompt tells the model to return null for fields it
    # cannot find, but these required floats reject null - confirm whether
    # they should become Optional before validating LLM output with this model.
    VACCINE_COVERAGE_RATES: float = Field(..., ge=0, le=100)
    PROPORTION_ADMINISTERED_WITHIN_RECOMMENDED_AGE: float = Field(..., ge=0, le=100)
    IMMUNISATION_UPTAKE: float = Field(..., ge=0, le=100)
    VACCINE_DROP_OUT_RATES: float = Field(..., ge=0, le=100)
    INTENTIONS_TO_VACCINATE: float = Field(..., ge=0, le=100)
    VACCINE_CONFIDENCE: float = Field(..., ge=0, le=100)
    # Explicit ``None`` default: Pydantic v2 treats an ``Optional`` annotation
    # without a default as required, whereas v1 silently defaulted it to None.
    STUDY_COMMENTS: Optional[str] = None
|
36 |
+
|
37 |
+
|
38 |
+
# Prompt asking the LLM to extract study characteristics as JSON.
#
# The previous template referred to "the given text" but had no slot for it,
# and its ``{studyid}``-style placeholders were template variables that nothing
# visible ever supplies. The field names are now listed literally and the
# retrieved text / user question are injected through ``{context_str}`` and
# ``{query_str}``.
# NOTE(review): assumes RAGPipeline.query passes this template to the query
# engine as the QA template (formatted with context_str/query_str) - confirm.
study_characteristics_prompt = PromptTemplate(
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Based on the given text, extract the following study characteristics:\n"
    "- STUDYID\n"
    "- AUTHOR\n"
    "- YEAR\n"
    "- TITLE\n"
    "- APPENDIX\n"
    "- PUBLICATION_TYPE\n"
    "- STUDY_DESIGN\n"
    "- STUDY_AREA_REGION\n"
    "- STUDY_POPULATION\n"
    "- IMMUNISABLE_DISEASE_UNDER_STUDY\n"
    "- ROUTE_OF_VACCINE_ADMINISTRATION\n"
    "- DURATION_OF_STUDY\n"
    "- DURATION_IN_RELATION_TO_COVID19\n"
    "- STUDY_COMMENTS\n"
    "Provide the information in a JSON format. If a field is not found, leave it as null.\n"
    "Query: {query_str}\n"
    "Answer: "
)
|
56 |
+
|
57 |
+
# Prompt asking the LLM to extract vaccine-coverage variables as JSON.
#
# The previous template referred to "the given text" but had no slot for it,
# and its ``{coverage_rates}``-style placeholders were template variables that
# nothing visible ever supplies. The field names are now listed literally and
# the retrieved text / user question are injected through ``{context_str}``
# and ``{query_str}``.
# NOTE(review): assumes RAGPipeline.query passes this template to the query
# engine as the QA template (formatted with context_str/query_str) - confirm.
vaccine_coverage_prompt = PromptTemplate(
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Based on the given text, extract the following vaccine coverage variables:\n"
    "- STUDYID\n"
    "- AUTHOR\n"
    "- YEAR\n"
    "- TITLE\n"
    "- VACCINE_COVERAGE_RATES\n"
    "- PROPORTION_ADMINISTERED_WITHIN_RECOMMENDED_AGE\n"
    "- IMMUNISATION_UPTAKE\n"
    "- VACCINE_DROP_OUT_RATES\n"
    "- INTENTIONS_TO_VACCINATE\n"
    "- VACCINE_CONFIDENCE\n"
    "- STUDY_COMMENTS\n"
    "Provide the information in a JSON format. For numerical values, provide percentages as floats between 0 and 100. If a field is not found, leave it as null.\n"
    "Query: {query_str}\n"
    "Answer: "
)
|
72 |
+
|
73 |
+
# Example questions offered in the UI, grouped by study name.
_VACCINE_COVERAGE_QUESTIONS = [
    "What are the vaccine coverage rates reported in the study?",
    "What proportion of vaccines were administered within the recommended age range?",
    "What is the immunisation uptake reported in the study?",
    "What are the vaccine drop-out rates mentioned in the document?",
    "What are the intentions to vaccinate reported in the study?",
    "How is vaccine confidence described in the document?",
]

_EBOLA_VIRUS_QUESTIONS = [
    "What is the sample size of the study?",
    "What is the type of plasma used in the study?",
    "What is the dosage and frequency of administration of the plasma?",
    "Are there any reported side effects?",
    "What is the change in viral load after treatment?",
    "How many survivors were there in the intervention group compared to the control group?",
]

_GENE_XPERT_QUESTIONS = [
    "What is the main objective of the study?",
    "What is the study design?",
    "What disease condition is being studied?",
    "What are the main outcome measures in the study?",
    "What is the sensitivity and specificity of the Gene Xpert test?",
    "How does the cost of the Gene Xpert testing strategy compare to other methods?",
]

# Insertion order is kept the same as before: Vaccine Coverage, Ebola Virus,
# Gene Xpert.
sample_questions = dict(
    [
        ("Vaccine Coverage", _VACCINE_COVERAGE_QUESTIONS),
        ("Ebola Virus", _EBOLA_VIRUS_QUESTIONS),
        ("Gene Xpert", _GENE_XPERT_QUESTIONS),
    ]
)
|
99 |
+
|
100 |
|
101 |
highlight_prompt = PromptTemplate(
|
102 |
"Context information is below.\n"
|