ak3ra commited on
Commit
9a9bac9
1 Parent(s): 122cee1
Files changed (2) hide show
  1. app.py +11 -10
  2. rag/rag_pipeline.py +3 -7
app.py CHANGED
@@ -1,5 +1,3 @@
1
- # app.py
2
-
3
  import gradio as gr
4
  import json
5
  from rag.rag_pipeline import RAGPipeline
@@ -31,12 +29,6 @@ def query_rag(study_name: str, question: str, prompt_type: str) -> str:
31
  # Extract study information using RAG
32
  study_info = rag.extract_study_info()
33
 
34
- # Prepare a dictionary with all possible prompt parameters
35
- prompt_params = {
36
- **study_info, # Unpack the extracted study info
37
- "query_str": question, # Add the question to the prompt parameters
38
- }
39
-
40
  if prompt_type == "Highlight":
41
  prompt = highlight_prompt
42
  elif prompt_type == "Evidence-based":
@@ -48,8 +40,17 @@ def query_rag(study_name: str, question: str, prompt_type: str) -> str:
48
  else:
49
  prompt = None
50
 
51
- # Use the prompt_params in the query
52
- response = rag.query(question, prompt, **prompt_params)
 
 
 
 
 
 
 
 
 
53
 
54
  # Format the response as Markdown
55
  formatted_response = f"## Question\n\n{question}\n\n## Answer\n\n{response['answer']}\n\n## Sources\n\n"
 
 
 
1
  import gradio as gr
2
  import json
3
  from rag.rag_pipeline import RAGPipeline
 
29
  # Extract study information using RAG
30
  study_info = rag.extract_study_info()
31
 
 
 
 
 
 
 
32
  if prompt_type == "Highlight":
33
  prompt = highlight_prompt
34
  elif prompt_type == "Evidence-based":
 
40
  else:
41
  prompt = None
42
 
43
+ # Prepare the context with study info
44
+ context = "Study Information:\n"
45
+ for key, value in study_info.items():
46
+ context += f"{key}: {value}\n"
47
+ context += "\n"
48
+
49
+ # Add the question to the context
50
+ context += f"Question: {question}\n"
51
+
52
+ # Use the prepared context in the query
53
+ response = rag.query(context, prompt_template=prompt)
54
 
55
  # Format the response as Markdown
56
  formatted_response = f"## Question\n\n{question}\n\n## Answer\n\n{response['answer']}\n\n## Sources\n\n"
rag/rag_pipeline.py CHANGED
@@ -1,5 +1,3 @@
1
- # rag/rag_pipeline.py
2
-
3
  import json
4
  from typing import Dict, Any
5
  from llama_index.core import Document, VectorStoreIndex
@@ -94,7 +92,7 @@ class RAGPipeline:
94
  return extracted_info
95
 
96
  def query(
97
- self, question: str, prompt_template: PromptTemplate = None, **kwargs
98
  ) -> Dict[str, Any]:
99
  if prompt_template is None:
100
  prompt_template = PromptTemplate(
@@ -102,7 +100,7 @@ class RAGPipeline:
102
  "---------------------\n"
103
  "{context_str}\n"
104
  "---------------------\n"
105
- "Given this information, please answer the question: {query_str}\n"
106
  "Include all relevant information from the provided context. "
107
  "If information comes from multiple sources, please mention all of them. "
108
  "If the information is not available in the context, please state that clearly. "
@@ -113,11 +111,9 @@ class RAGPipeline:
113
  text_qa_template=prompt_template, similarity_top_k=5
114
  )
115
 
116
- # Use kwargs to pass additional parameters to the query
117
- response = query_engine.query(question, **kwargs)
118
 
119
  return {
120
- "question": question,
121
  "answer": response.response,
122
  "sources": [node.metadata for node in response.source_nodes],
123
  }
 
 
 
1
  import json
2
  from typing import Dict, Any
3
  from llama_index.core import Document, VectorStoreIndex
 
92
  return extracted_info
93
 
94
  def query(
95
+ self, context: str, prompt_template: PromptTemplate = None
96
  ) -> Dict[str, Any]:
97
  if prompt_template is None:
98
  prompt_template = PromptTemplate(
 
100
  "---------------------\n"
101
  "{context_str}\n"
102
  "---------------------\n"
103
+ "Given this information, please answer the question provided in the context. "
104
  "Include all relevant information from the provided context. "
105
  "If information comes from multiple sources, please mention all of them. "
106
  "If the information is not available in the context, please state that clearly. "
 
111
  text_qa_template=prompt_template, similarity_top_k=5
112
  )
113
 
114
+ response = query_engine.query(context)
 
115
 
116
  return {
 
117
  "answer": response.response,
118
  "sources": [node.metadata for node in response.source_nodes],
119
  }