ak3ra commited on
Commit
9f2191f
1 Parent(s): 9cb07f9
Files changed (2) hide show
  1. app.py +26 -1
  2. rag/rag_pipeline.py +3 -2
app.py CHANGED
@@ -28,6 +28,31 @@ def get_rag_pipeline(study_name):
28
  def query_rag(study_name, question, prompt_type):
29
  rag = get_rag_pipeline(study_name)
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  if prompt_type == "Highlight":
32
  prompt = highlight_prompt
33
  elif prompt_type == "Evidence-based":
@@ -39,7 +64,7 @@ def query_rag(study_name, question, prompt_type):
39
  else:
40
  prompt = None
41
 
42
- response = rag.query(question, prompt)
43
 
44
  # Format the response as Markdown
45
  formatted_response = f"## Question\n\n{response['question']}\n\n## Answer\n\n{response['answer']}\n\n## Sources\n\n"
 
28
  def query_rag(study_name, question, prompt_type):
29
  rag = get_rag_pipeline(study_name)
30
 
31
+ # Prepare a dictionary with all possible prompt parameters
32
+ prompt_params = {
33
+ "studyid": "", # retrieve or generate a study ID?
34
+ "author": "",
35
+ "year": "",
36
+ "title": "",
37
+ "appendix": "",
38
+ "publication_type": "",
39
+ "study_design": "",
40
+ "study_area_region": "",
41
+ "study_population": "",
42
+ "immunisable_disease": "",
43
+ "route_of_administration": "",
44
+ "duration_of_study": "",
45
+ "duration_covid19": "",
46
+ "study_comments": "",
47
+ "coverage_rates": "",
48
+ "proportion_recommended_age": "",
49
+ "immunisation_uptake": "",
50
+ "drop_out_rates": "",
51
+ "intentions_to_vaccinate": "",
52
+ "vaccine_confidence": "",
53
+ "query_str": question, # Add the question to the prompt parameters
54
+ }
55
+
56
  if prompt_type == "Highlight":
57
  prompt = highlight_prompt
58
  elif prompt_type == "Evidence-based":
 
64
  else:
65
  prompt = None
66
 
67
+ response = rag.query(question, prompt, **prompt_params)
68
 
69
  # Format the response as Markdown
70
  formatted_response = f"## Question\n\n{response['question']}\n\n## Answer\n\n{response['answer']}\n\n## Sources\n\n"
rag/rag_pipeline.py CHANGED
@@ -60,7 +60,7 @@ class RAGPipeline:
60
  self.index = VectorStoreIndex(nodes)
61
 
62
  def query(
63
- self, question: str, prompt_template: PromptTemplate = None
64
  ) -> Dict[str, Any]:
65
  self.build_index() # This will only build the index if it hasn't been built yet
66
 
@@ -80,7 +80,8 @@ class RAGPipeline:
80
  query_engine = self.index.as_query_engine(
81
  text_qa_template=prompt_template, similarity_top_k=5
82
  )
83
- response = query_engine.query(question)
 
84
 
85
  return {
86
  "question": question,
 
60
  self.index = VectorStoreIndex(nodes)
61
 
62
  def query(
63
+ self, question: str, prompt_template: PromptTemplate = None, **kwargs
64
  ) -> Dict[str, Any]:
65
  self.build_index() # This will only build the index if it hasn't been built yet
66
 
 
80
  query_engine = self.index.as_query_engine(
81
  text_qa_template=prompt_template, similarity_top_k=5
82
  )
83
+ # response = query_engine.query(question)
84
+ response = query_engine.query(question, **kwargs)
85
 
86
  return {
87
  "question": question,