ak3ra committed
Commit d762ede
Parent(s): 70691a9

Bugfix: pin fastapi, because bleeding-edge gradio breaks without it (https://github.com/gradio-app/gradio/issues/9275)

__pycache__/config.cpython-311.pyc ADDED
Binary file (455 Bytes)
 
app.py CHANGED
@@ -5,6 +5,7 @@ from utils.prompts import highlight_prompt, evidence_based_prompt
 from utils.prompts import (
     sample_questions,
 )
+
 from config import STUDY_FILES
 
 # Cache for RAG pipelines
@@ -24,9 +25,6 @@ def get_rag_pipeline(study_name):
 def query_rag(study_name: str, question: str, prompt_type: str) -> str:
     rag = get_rag_pipeline(study_name)
 
-    # Extract study information using RAG
-    study_info = rag.extract_study_info()
-
     if prompt_type == "Highlight":
         prompt = highlight_prompt
     elif prompt_type == "Evidence-based":
@@ -34,17 +32,8 @@ def query_rag(study_name: str, question: str, prompt_type: str) -> str:
     else:
         prompt = None
 
-    # Prepare the context with study info
-    context = "Study Information:\n"
-    for key, value in study_info.items():
-        context += f"{key}: {value}\n"
-    context += "\n"
-
-    # Add the question to the context
-    context += f"Question: {question}\n"
-
     # Use the prepared context in the query
-    response = rag.query(context, prompt_template=prompt)
+    response = rag.query(question, prompt_template=prompt)
 
     # Format the response as Markdown
     formatted_response = f"## Question\n\n{question}\n\n## Answer\n\n{response['answer']}\n\n## Sources\n\n"
@@ -53,11 +42,6 @@ def query_rag(study_name: str, question: str, prompt_type: str) -> str:
         f"- {source['title']} ({source.get('year', 'Year not specified')})\n"
     )
 
-    # Add extracted study information to the response
-    formatted_response += "\n## Extracted Study Information\n\n"
-    for key, value in study_info.items():
-        formatted_response += f"- **{key.replace('_', ' ').title()}**: {value}\n"
-
     return formatted_response
@@ -122,4 +106,4 @@ with gr.Blocks() as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch(share=True)
+    demo.launch(share=True, debug=True)
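
With the extraction step gone, query_rag() simply forwards the user's question to the pipeline and formats the Markdown answer. A minimal usage sketch of the new path; the study name below is hypothetical, not a key taken from this repo's STUDY_FILES:

    # Hypothetical smoke test for the simplified query path in app.py.
    markdown = query_rag(
        study_name="example_study",          # hypothetical STUDY_FILES key
        question="What population was studied?",
        prompt_type="Evidence-based",
    )
    print(markdown)  # "## Question ... ## Answer ... ## Sources ..."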
rag/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (159 Bytes)
 
rag/__pycache__/rag_pipeline.cpython-311.pyc ADDED
Binary file (5.58 kB)
 
rag/rag_pipeline.py CHANGED
@@ -4,6 +4,8 @@ from llama_index.core import Document, VectorStoreIndex
 from llama_index.core.node_parser import SentenceWindowNodeParser, SentenceSplitter
 from llama_index.core import PromptTemplate
 from typing import List
+from llama_index.embeddings.openai import OpenAIEmbedding
+from llama_index.llms.openai import OpenAI
 
 
 class RAGPipeline:
@@ -43,53 +45,22 @@ class RAGPipeline:
 
     def build_index(self):
         if self.index is None:
-            sentence_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=13)
+            sentence_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=20)
 
             def _split(text: str) -> List[str]:
                 return sentence_splitter.split_text(text)
 
             node_parser = SentenceWindowNodeParser.from_defaults(
                 sentence_splitter=_split,
-                window_size=3,
+                window_size=5,
                 window_metadata_key="window",
                 original_text_metadata_key="original_text",
             )
 
             nodes = node_parser.get_nodes_from_documents(self.documents)
-            self.index = VectorStoreIndex(nodes)
-
-    def extract_study_info(self) -> Dict[str, Any]:
-        extraction_prompt = PromptTemplate(
-            "Based on the given context, please extract the following information about the study:\n"
-            "1. Study ID\n"
-            "2. Author(s)\n"
-            "3. Year\n"
-            "4. Title\n"
-            "5. Study design\n"
-            "6. Study area/region\n"
-            "7. Study population\n"
-            "8. Disease under study\n"
-            "9. Duration of study\n"
-            "If the information is not available, please respond with 'Not found' for that field.\n"
-            "Context: {context_str}\n"
-            "Extracted information:"
-        )
-
-        query_engine = self.index.as_query_engine(
-            text_qa_template=extraction_prompt, similarity_top_k=5
-        )
-
-        response = query_engine.query("Extract study information")
-
-        # Parse the response to extract key-value pairs
-        lines = response.response.split("\n")
-        extracted_info = {}
-        for line in lines:
-            if ":" in line:
-                key, value = line.split(":", 1)
-                extracted_info[key.strip().lower().replace(" ", "_")] = value.strip()
-
-        return extracted_info
+            self.index = VectorStoreIndex(
+                nodes, embed_model=OpenAIEmbedding(model_name="text-embedding-3-large")
+            )
 
     def query(
         self, context: str, prompt_template: PromptTemplate = None
@@ -107,8 +78,13 @@ class RAGPipeline:
             "When quoting specific information, please use square brackets to indicate the source, e.g. [1], [2], etc."
         )
 
+        # This is a hack to index all the documents in the store :)
+        n_documents = len(self.index.docstore.docs)
         query_engine = self.index.as_query_engine(
-            text_qa_template=prompt_template, similarity_top_k=5
+            text_qa_template=prompt_template,
+            similarity_top_k=n_documents,
+            response_mode="tree_summarize",
+            llm=OpenAI(model="gpt-4o-mini"),
        )
 
        response = query_engine.query(context)
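
Net effect of the pipeline changes: build_index() now embeds 2048-token chunks (with 5-sentence windows) using text-embedding-3-large, and query() retrieves every node in the docstore before tree-summarizing with gpt-4o-mini. A self-contained sketch of that wiring, assuming a toy document and an OPENAI_API_KEY in the environment; it mirrors, but is not, the repo's RAGPipeline:

    from llama_index.core import Document, VectorStoreIndex
    from llama_index.core.node_parser import SentenceSplitter, SentenceWindowNodeParser
    from llama_index.embeddings.openai import OpenAIEmbedding
    from llama_index.llms.openai import OpenAI

    # Toy corpus standing in for the study documents (hypothetical content).
    docs = [Document(text="Sentence one. Sentence two. Sentence three. Sentence four.")]

    # Same splitter/window settings as the new build_index().
    splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=20)
    parser = SentenceWindowNodeParser.from_defaults(
        sentence_splitter=splitter.split_text,
        window_size=5,  # each node carries five neighbouring sentences in metadata
        window_metadata_key="window",
        original_text_metadata_key="original_text",
    )
    nodes = parser.get_nodes_from_documents(docs)
    index = VectorStoreIndex(
        nodes, embed_model=OpenAIEmbedding(model_name="text-embedding-3-large")
    )

    # The "retrieve everything" hack from query(): top_k equals the docstore
    # size, and tree_summarize merges all retrieved chunks into one answer.
    engine = index.as_query_engine(
        similarity_top_k=len(index.docstore.docs),
        response_mode="tree_summarize",
        llm=OpenAI(model="gpt-4o-mini"),
    )
    print(engine.query("What does the corpus say?").response)

Retrieving the full docstore defeats the point of top-k similarity search for large corpora, but for a handful of study documents it guarantees nothing is missed before summarization.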
requirements.txt CHANGED
@@ -1,3 +1,4 @@
+fastapi==0.112.2
 gradio
 llama-index
 openai
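
The fastapi pin is the actual workaround for the gradio issue referenced in the commit message. A quick environment sanity check, assuming both packages are installed:

    import fastapi
    import gradio

    # Pin matches requirements.txt; relax once the upstream gradio fix lands.
    assert fastapi.__version__ == "0.112.2", fastapi.__version__
    print(f"gradio {gradio.__version__} / fastapi {fastapi.__version__}")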
utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (161 Bytes)
 
utils/__pycache__/prompts.cpython-311.pyc ADDED
Binary file (5.68 kB)