ak3ra committed on
Commit
5f52091
1 Parent(s): 9f2191f

prompt params

Browse files
Files changed (2) hide show
  1. app.py +15 -23
  2. rag/rag_pipeline.py +34 -1
app.py CHANGED
@@ -25,31 +25,15 @@ def get_rag_pipeline(study_name):
25
  return rag_cache[study_name]
26
 
27
 
28
- def query_rag(study_name, question, prompt_type):
29
  rag = get_rag_pipeline(study_name)
30
 
 
 
 
31
  # Prepare a dictionary with all possible prompt parameters
32
  prompt_params = {
33
- "studyid": "", # retrieve or generate a study ID?
34
- "author": "",
35
- "year": "",
36
- "title": "",
37
- "appendix": "",
38
- "publication_type": "",
39
- "study_design": "",
40
- "study_area_region": "",
41
- "study_population": "",
42
- "immunisable_disease": "",
43
- "route_of_administration": "",
44
- "duration_of_study": "",
45
- "duration_covid19": "",
46
- "study_comments": "",
47
- "coverage_rates": "",
48
- "proportion_recommended_age": "",
49
- "immunisation_uptake": "",
50
- "drop_out_rates": "",
51
- "intentions_to_vaccinate": "",
52
- "vaccine_confidence": "",
53
  "query_str": question, # Add the question to the prompt parameters
54
  }
55
 
@@ -64,12 +48,20 @@ def query_rag(study_name, question, prompt_type):
64
  else:
65
  prompt = None
66
 
 
67
  response = rag.query(question, prompt, **prompt_params)
68
 
69
  # Format the response as Markdown
70
- formatted_response = f"## Question\n\n{response['question']}\n\n## Answer\n\n{response['answer']}\n\n## Sources\n\n"
71
  for source in response["sources"]:
72
- formatted_response += f"- {source['title']} ({source['year']})\n"
 
 
 
 
 
 
 
73
 
74
  return formatted_response
75
 
 
25
  return rag_cache[study_name]
26
 
27
 
28
+ def query_rag(study_name: str, question: str, prompt_type: str) -> str:
29
  rag = get_rag_pipeline(study_name)
30
 
31
+ # Extract study information using RAG
32
+ study_info = rag.extract_study_info()
33
+
34
  # Prepare a dictionary with all possible prompt parameters
35
  prompt_params = {
36
+ **study_info, # Unpack the extracted study info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  "query_str": question, # Add the question to the prompt parameters
38
  }
39
 
 
48
  else:
49
  prompt = None
50
 
51
+ # Use the prompt_params in the query
52
  response = rag.query(question, prompt, **prompt_params)
53
 
54
  # Format the response as Markdown
55
+ formatted_response = f"## Question\n\n{question}\n\n## Answer\n\n{response['answer']}\n\n## Sources\n\n"
56
  for source in response["sources"]:
57
+ formatted_response += (
58
+ f"- {source['title']} ({source.get('year', 'Year not specified')})\n"
59
+ )
60
+
61
+ # Add extracted study information to the response
62
+ formatted_response += "\n## Extracted Study Information\n\n"
63
+ for key, value in study_info.items():
64
+ formatted_response += f"- **{key.replace('_', ' ').title()}**: {value}\n"
65
 
66
  return formatted_response
67
 
rag/rag_pipeline.py CHANGED
@@ -41,6 +41,39 @@ class RAGPipeline:
41
  Document(text=doc_content, id_=f"doc_{index}", metadata=metadata)
42
  )
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def build_index(self):
45
  if self.index is None:
46
  self.load_documents()
@@ -80,7 +113,7 @@ class RAGPipeline:
80
  query_engine = self.index.as_query_engine(
81
  text_qa_template=prompt_template, similarity_top_k=5
82
  )
83
- # response = query_engine.query(question)
84
  response = query_engine.query(question, **kwargs)
85
 
86
  return {
 
41
  Document(text=doc_content, id_=f"doc_{index}", metadata=metadata)
42
  )
43
 
44
def extract_study_info(self) -> Dict[str, Any]:
    """Extract basic study metadata from the indexed documents.

    Runs a fixed extraction prompt through a query engine built on
    ``self.index`` (using the 5 most similar chunks) and parses the reply
    line by line into ``label: value`` pairs.

    Returns:
        A dict mapping normalized field labels (lower-case, spaces replaced
        by underscores, leading list numbering / bullets removed) to the
        stripped text that follows the first colon on the line. Lines
        without a colon, or with an empty label, are ignored. An empty
        dict is returned when the reply contains no such lines.
    """
    # The query engine needs an index. build_index() is expected to
    # populate self.index, so call it on demand instead of failing with an
    # AttributeError on None when this runs before any other query.
    if self.index is None:
        self.build_index()

    extraction_prompt = PromptTemplate(
        "Based on the given context, please extract the following information about the study:\n"
        "1. Study ID\n"
        "2. Author(s)\n"
        "3. Year\n"
        "4. Title\n"
        "5. Study design\n"
        "6. Study area/region\n"
        "7. Study population\n"
        "8. Disease under study\n"
        "9. Duration of study\n"
        "If the information is not available, please respond with 'Not found' for that field.\n"
        "Context: {context_str}\n"
        "Extracted information:"
    )

    query_engine = self.index.as_query_engine(
        text_qa_template=extraction_prompt, similarity_top_k=5
    )

    response = query_engine.query("Extract study information")

    # The response text may be missing; treat that as an empty reply.
    reply_text = response.response or ""

    extracted_info: Dict[str, Any] = {}
    for line in reply_text.splitlines():
        label, colon, value = line.partition(":")
        if not colon:
            continue
        # The prompt lists the fields as "1. Study ID", "2. Author(s)", ...
        # and replies tend to echo that numbering (or use bullets / bold
        # markers). Drop those decorations so keys come out as "study_id"
        # rather than "1._study_id".
        label = label.strip().lstrip("0123456789.)-*• ").rstrip("* ")
        if not label:
            continue
        extracted_info[label.lower().replace(" ", "_")] = value.strip()

    return extracted_info
76
+
77
  def build_index(self):
78
  if self.index is None:
79
  self.load_documents()
 
113
  query_engine = self.index.as_query_engine(
114
  text_qa_template=prompt_template, similarity_top_k=5
115
  )
116
+ # Use kwargs to pass additional parameters to the query
117
  response = query_engine.query(question, **kwargs)
118
 
119
  return {