ak3ra committed on
Commit
6a076b8
1 Parent(s): d0d7d0e

custom prompts

Browse files
Files changed (4) hide show
  1. app.py +41 -11
  2. rag/rag_pipeline.py +9 -2
  3. requirements.txt +2 -1
  4. utils/prompts.py +98 -0
app.py CHANGED
@@ -2,9 +2,13 @@ import gradio as gr
2
  import json
3
  from rag.rag_pipeline import RAGPipeline
4
  from utils.prompts import highlight_prompt, evidence_based_prompt
 
 
 
 
 
5
  from config import STUDY_FILES
6
 
7
- # Cache for RAG pipelines
8
  rag_cache = {}
9
 
10
 
@@ -25,13 +29,19 @@ def query_rag(study_name, question, prompt_type):
25
  prompt = highlight_prompt
26
  elif prompt_type == "Evidence-based":
27
  prompt = evidence_based_prompt
 
 
 
 
28
  else:
29
  prompt = None
30
 
31
  response = rag.query(question, prompt)
32
- formatted_response = (
33
- f"## Question\n\n{question}\n\n## Answer\n\n{response.response}"
34
- )
 
 
35
 
36
  return formatted_response
37
 
@@ -46,6 +56,10 @@ def get_study_info(study_name):
46
  return "Invalid study name"
47
 
48
 
 
 
 
 
49
  with gr.Blocks() as demo:
50
  gr.Markdown("# RAG Pipeline Demo")
51
 
@@ -53,21 +67,37 @@ with gr.Blocks() as demo:
53
  study_dropdown = gr.Dropdown(
54
  choices=list(STUDY_FILES.keys()), label="Select Study"
55
  )
56
- study_info = gr.Textbox(label="Study Information", interactive=False)
57
 
58
  study_dropdown.change(get_study_info, inputs=[study_dropdown], outputs=[study_info])
59
 
60
  with gr.Row():
61
  question_input = gr.Textbox(label="Enter your question")
62
- prompt_type = gr.Radio(
63
- ["Default", "Highlight", "Evidence-based"],
64
- label="Prompt Type",
65
- value="Default",
66
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  submit_button = gr.Button("Submit")
69
 
70
- # answer_output = gr.Textbox(label="Answer")
71
  answer_output = gr.Markdown(label="Answer")
72
 
73
  submit_button.click(
 
2
  import json
3
  from rag.rag_pipeline import RAGPipeline
4
  from utils.prompts import highlight_prompt, evidence_based_prompt
5
+ from utils.custom_prompts import (
6
+ study_characteristics_prompt,
7
+ vaccine_coverage_prompt,
8
+ sample_questions,
9
+ )
10
  from config import STUDY_FILES
11
 
 
12
  rag_cache = {}
13
 
14
 
 
29
  prompt = highlight_prompt
30
  elif prompt_type == "Evidence-based":
31
  prompt = evidence_based_prompt
32
+ elif prompt_type == "Study Characteristics":
33
+ prompt = study_characteristics_prompt
34
+ elif prompt_type == "Vaccine Coverage":
35
+ prompt = vaccine_coverage_prompt
36
  else:
37
  prompt = None
38
 
39
  response = rag.query(question, prompt)
40
+
41
+ # Format the response as Markdown
42
+ formatted_response = f"## Question\n\n{response['question']}\n\n## Answer\n\n{response['answer']}\n\n## Sources\n\n"
43
+ for source in response["sources"]:
44
+ formatted_response += f"- {source['title']} ({source['year']})\n"
45
 
46
  return formatted_response
47
 
 
56
  return "Invalid study name"
57
 
58
 
59
+ def update_sample_questions(study_name):
60
+ return gr.Dropdown.update(choices=sample_questions.get(study_name, []))
61
+
62
+
63
  with gr.Blocks() as demo:
64
  gr.Markdown("# RAG Pipeline Demo")
65
 
 
67
  study_dropdown = gr.Dropdown(
68
  choices=list(STUDY_FILES.keys()), label="Select Study"
69
  )
70
+ study_info = gr.Markdown(label="Study Information")
71
 
72
  study_dropdown.change(get_study_info, inputs=[study_dropdown], outputs=[study_info])
73
 
74
  with gr.Row():
75
  question_input = gr.Textbox(label="Enter your question")
76
+ sample_question_dropdown = gr.Dropdown(choices=[], label="Sample Questions")
77
+
78
+ study_dropdown.change(
79
+ update_sample_questions,
80
+ inputs=[study_dropdown],
81
+ outputs=[sample_question_dropdown],
82
+ )
83
+ sample_question_dropdown.change(
84
+ lambda x: x, inputs=[sample_question_dropdown], outputs=[question_input]
85
+ )
86
+
87
+ prompt_type = gr.Radio(
88
+ [
89
+ "Default",
90
+ "Highlight",
91
+ "Evidence-based",
92
+ "Study Characteristics",
93
+ "Vaccine Coverage",
94
+ ],
95
+ label="Prompt Type",
96
+ value="Default",
97
+ )
98
 
99
  submit_button = gr.Button("Submit")
100
 
 
101
  answer_output = gr.Markdown(label="Answer")
102
 
103
  submit_button.click(
rag/rag_pipeline.py CHANGED
@@ -1,6 +1,7 @@
1
  # rag/rag_pipeline.py
2
 
3
  import json
 
4
  from llama_index.core import Document, VectorStoreIndex
5
  from llama_index.core.node_parser import SentenceWindowNodeParser, SentenceSplitter
6
  from llama_index.core import PromptTemplate
@@ -58,7 +59,9 @@ class RAGPipeline:
58
  nodes = node_parser.get_nodes_from_documents(self.documents)
59
  self.index = VectorStoreIndex(nodes)
60
 
61
- def query(self, question, prompt_template=None):
 
 
62
  self.build_index() # This will only build the index if it hasn't been built yet
63
 
64
  if prompt_template is None:
@@ -79,4 +82,8 @@ class RAGPipeline:
79
  )
80
  response = query_engine.query(question)
81
 
82
- return response
 
 
 
 
 
1
  # rag/rag_pipeline.py
2
 
3
  import json
4
+ from typing import Dict, Any
5
  from llama_index.core import Document, VectorStoreIndex
6
  from llama_index.core.node_parser import SentenceWindowNodeParser, SentenceSplitter
7
  from llama_index.core import PromptTemplate
 
59
  nodes = node_parser.get_nodes_from_documents(self.documents)
60
  self.index = VectorStoreIndex(nodes)
61
 
62
+ def query(
63
+ self, question: str, prompt_template: PromptTemplate = None
64
+ ) -> Dict[str, Any]:
65
  self.build_index() # This will only build the index if it hasn't been built yet
66
 
67
  if prompt_template is None:
 
82
  )
83
  response = query_engine.query(question)
84
 
85
+ return {
86
+ "question": question,
87
+ "answer": response.response,
88
+ "sources": [node.metadata for node in response.source_nodes],
89
+ }
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  gradio
2
  llama-index
3
  openai
4
- pandas
 
 
1
  gradio
2
  llama-index
3
  openai
4
+ pandas
5
+ pydantic
utils/prompts.py CHANGED
@@ -1,4 +1,102 @@
1
  from llama_index.core import PromptTemplate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  highlight_prompt = PromptTemplate(
4
  "Context information is below.\n"
 
1
  from llama_index.core import PromptTemplate
2
+ from typing import Optional, List
3
+ from pydantic import BaseModel, Field
4
+ from llama_index.core.prompts import PromptTemplate
5
+
6
+
7
+ class StudyCharacteristics(BaseModel):
8
+ STUDYID: str
9
+ AUTHOR: str
10
+ YEAR: int
11
+ TITLE: str
12
+ APPENDIX: Optional[str]
13
+ PUBLICATION_TYPE: str
14
+ STUDY_DESIGN: str
15
+ STUDY_AREA_REGION: str
16
+ STUDY_POPULATION: str
17
+ IMMUNISABLE_DISEASE_UNDER_STUDY: str
18
+ ROUTE_OF_VACCINE_ADMINISTRATION: str
19
+ DURATION_OF_STUDY: str
20
+ DURATION_IN_RELATION_TO_COVID19: str
21
+ STUDY_COMMENTS: Optional[str]
22
+
23
+
24
+ class VaccineCoverageVariables(BaseModel):
25
+ STUDYID: str
26
+ AUTHOR: str
27
+ YEAR: int
28
+ TITLE: str
29
+ VACCINE_COVERAGE_RATES: float = Field(..., ge=0, le=100)
30
+ PROPORTION_ADMINISTERED_WITHIN_RECOMMENDED_AGE: float = Field(..., ge=0, le=100)
31
+ IMMUNISATION_UPTAKE: float = Field(..., ge=0, le=100)
32
+ VACCINE_DROP_OUT_RATES: float = Field(..., ge=0, le=100)
33
+ INTENTIONS_TO_VACCINATE: float = Field(..., ge=0, le=100)
34
+ VACCINE_CONFIDENCE: float = Field(..., ge=0, le=100)
35
+ STUDY_COMMENTS: Optional[str]
36
+
37
+
38
+ study_characteristics_prompt = PromptTemplate(
39
+ "Based on the given text, extract the following study characteristics:\n"
40
+ "STUDYID: {studyid}\n"
41
+ "AUTHOR: {author}\n"
42
+ "YEAR: {year}\n"
43
+ "TITLE: {title}\n"
44
+ "APPENDIX: {appendix}\n"
45
+ "PUBLICATION_TYPE: {publication_type}\n"
46
+ "STUDY_DESIGN: {study_design}\n"
47
+ "STUDY_AREA_REGION: {study_area_region}\n"
48
+ "STUDY_POPULATION: {study_population}\n"
49
+ "IMMUNISABLE_DISEASE_UNDER_STUDY: {immunisable_disease}\n"
50
+ "ROUTE_OF_VACCINE_ADMINISTRATION: {route_of_administration}\n"
51
+ "DURATION_OF_STUDY: {duration_of_study}\n"
52
+ "DURATION_IN_RELATION_TO_COVID19: {duration_covid19}\n"
53
+ "STUDY_COMMENTS: {study_comments}\n"
54
+ "Provide the information in a JSON format. If a field is not found, leave it as null."
55
+ )
56
+
57
+ vaccine_coverage_prompt = PromptTemplate(
58
+ "Based on the given text, extract the following vaccine coverage variables:\n"
59
+ "STUDYID: {studyid}\n"
60
+ "AUTHOR: {author}\n"
61
+ "YEAR: {year}\n"
62
+ "TITLE: {title}\n"
63
+ "VACCINE_COVERAGE_RATES: {coverage_rates}\n"
64
+ "PROPORTION_ADMINISTERED_WITHIN_RECOMMENDED_AGE: {proportion_recommended_age}\n"
65
+ "IMMUNISATION_UPTAKE: {immunisation_uptake}\n"
66
+ "VACCINE_DROP_OUT_RATES: {drop_out_rates}\n"
67
+ "INTENTIONS_TO_VACCINATE: {intentions_to_vaccinate}\n"
68
+ "VACCINE_CONFIDENCE: {vaccine_confidence}\n"
69
+ "STUDY_COMMENTS: {study_comments}\n"
70
+ "Provide the information in a JSON format. For numerical values, provide percentages as floats between 0 and 100. If a field is not found, leave it as null."
71
+ )
72
+
73
+ sample_questions = {
74
+ "Vaccine Coverage": [
75
+ "What are the vaccine coverage rates reported in the study?",
76
+ "What proportion of vaccines were administered within the recommended age range?",
77
+ "What is the immunisation uptake reported in the study?",
78
+ "What are the vaccine drop-out rates mentioned in the document?",
79
+ "What are the intentions to vaccinate reported in the study?",
80
+ "How is vaccine confidence described in the document?",
81
+ ],
82
+ "Ebola Virus": [
83
+ "What is the sample size of the study?",
84
+ "What is the type of plasma used in the study?",
85
+ "What is the dosage and frequency of administration of the plasma?",
86
+ "Are there any reported side effects?",
87
+ "What is the change in viral load after treatment?",
88
+ "How many survivors were there in the intervention group compared to the control group?",
89
+ ],
90
+ "Gene Xpert": [
91
+ "What is the main objective of the study?",
92
+ "What is the study design?",
93
+ "What disease condition is being studied?",
94
+ "What are the main outcome measures in the study?",
95
+ "What is the sensitivity and specificity of the Gene Xpert test?",
96
+ "How does the cost of the Gene Xpert testing strategy compare to other methods?",
97
+ ],
98
+ }
99
+
100
 
101
  highlight_prompt = PromptTemplate(
102
  "Context information is below.\n"