ak3ra committed on
Commit
b117341
1 Parent(s): 13faf78

changed pipeline

Browse files
.gitattributes CHANGED
@@ -34,4 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *db* filter=lfs diff=lfs merge=lfs -text
37
- vaccine_coverage_study.db filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *db* filter=lfs diff=lfs merge=lfs -text
37
+ vaccine_coverage_study.db filter=lfs diff=lfs merge=lfs -text
38
+ *.db filter=lfs diff=lfs merge=lfs -text
data/ebola_virus_zotero_items.json ADDED
The diff for this file is too large to render. See raw diff
 
data/gene_xpert_zotero_items.json ADDED
The diff for this file is too large to render. See raw diff
 
data/vaccine_coverage_zotero_items.json ADDED
The diff for this file is too large to render. See raw diff
 
database/vaccine_coverage_db.py CHANGED
@@ -1,46 +1,3 @@
1
- import sqlite3
2
- from typing import List, Dict, Any
3
-
4
-
5
- class VaccineCoverageDB:
6
- def __init__(self, db_path: str):
7
- self.conn = sqlite3.connect(db_path)
8
- self.conn.row_factory = sqlite3.Row
9
-
10
- def get_all_items(self) -> List[Dict[str, Any]]:
11
- cursor = self.conn.execute("SELECT * FROM items")
12
- return [dict(row) for row in cursor.fetchall()]
13
-
14
- def get_item_by_key(self, key: str) -> Dict[str, Any]:
15
- cursor = self.conn.execute("SELECT * FROM items WHERE key = ?", (key,))
16
- return dict(cursor.fetchone())
17
-
18
- def get_attachments_for_item(self, item_key: str) -> List[Dict[str, Any]]:
19
- cursor = self.conn.execute(
20
- "SELECT * FROM attachments WHERE parent_key = ?", (item_key,)
21
- )
22
- return [dict(row) for row in cursor.fetchall()]
23
-
24
- def get_pdf_content(self, attachment_key: str) -> bytes:
25
- cursor = self.conn.execute(
26
- "SELECT content FROM attachments WHERE key = ?", (attachment_key,)
27
- )
28
- result = cursor.fetchone()
29
- return result["content"] if result else None
30
-
31
- def save_pdf_to_file(self, attachment_key: str, output_path: str) -> bool:
32
- pdf_content = self.get_pdf_content(attachment_key)
33
- if pdf_content:
34
- try:
35
- with open(output_path, "wb") as f:
36
- f.write(pdf_content)
37
- return True
38
- except Exception as e:
39
- print(f"Error saving PDF: {str(e)}")
40
- return False
41
- else:
42
- print(f"No PDF content found for attachment key: {attachment_key}")
43
- return False
44
-
45
- def close(self):
46
- self.conn.close()
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42a0645cdd38f2d7ede525768eb21a4cbe08b4d86959cb4eb2349887f2bcf70e
3
+ size 1774
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
initialize_db.py CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c0cb2cf50f14d131b1e999cee44652575fd1029141514dfc2e028af1419b0d46
3
- size 2344
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08030c4783a86d9a9afb9437b102dde959405b6b2857725eec02b6d9c2699e97
3
+ size 2346
rag/rag_pipeline.py CHANGED
@@ -2,116 +2,81 @@ import json
2
  import os
3
  from typing import Dict, Any
4
  from llama_index.core import (
5
- SimpleDirectoryReader,
6
  VectorStoreIndex,
7
  Document,
8
- StorageContext,
9
- load_index_from_storage,
 
 
10
  )
11
- from llama_index.core.node_parser import SentenceSplitter, SemanticSplitterNodeParser
12
- from llama_index.embeddings.openai import OpenAIEmbedding
13
  from llama_index.core import PromptTemplate
14
 
15
 
16
  class RAGPipeline:
17
  def __init__(
18
- self, metadata_file: str, pdf_dir: str, use_semantic_splitter: bool = False
 
 
19
  ):
20
- self.metadata_file = metadata_file
21
- self.pdf_dir = pdf_dir
22
- self.use_semantic_splitter = use_semantic_splitter
23
  self.index = None
 
24
  self.load_documents()
25
  self.build_index()
26
 
27
  def load_documents(self):
28
- if not os.path.exists(self.metadata_file):
29
- print(f"Metadata file not found: {self.metadata_file}")
30
- self.documents = []
31
- return
32
-
33
- with open(self.metadata_file, "r") as f:
34
- self.metadata = json.load(f)
35
 
36
  self.documents = []
37
- for item_key, item_data in self.metadata.items():
38
- metadata = item_data["metadata"]
39
- pdf_path = item_data.get("pdf_path")
40
-
41
- if pdf_path:
42
- full_pdf_path = os.path.join(self.pdf_dir, os.path.basename(pdf_path))
43
- if os.path.exists(full_pdf_path):
44
- pdf_content = (
45
- SimpleDirectoryReader(input_files=[full_pdf_path])
46
- .load_data()[0]
47
- .text
48
- )
49
- else:
50
- pdf_content = "PDF file not found"
51
- else:
52
- pdf_content = "PDF path not available in metadata"
53
 
 
54
  doc_content = (
55
- f"Title: {metadata['title']}\n"
56
- f"Abstract: {metadata['abstract']}\n"
57
- f"Authors: {metadata['authors']}\n"
58
- f"Year: {metadata['year']}\n"
59
- f"DOI: {metadata['doi']}\n"
60
- f"Full Text: {pdf_content}"
61
  )
62
 
 
 
 
 
 
 
 
 
63
  self.documents.append(
64
- Document(text=doc_content, id_=item_key, metadata=metadata)
 
 
 
 
65
  )
66
 
67
  def build_index(self):
68
- if self.use_semantic_splitter:
69
- embed_model = OpenAIEmbedding()
70
- splitter = SemanticSplitterNodeParser(
71
- buffer_size=1,
72
- breakpoint_percentile_threshold=95,
73
- embed_model=embed_model,
74
- )
75
- else:
76
- splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
77
-
78
- nodes = splitter.get_nodes_from_documents(self.documents)
79
- self.index = VectorStoreIndex(nodes)
80
 
81
- def query(self, question: str, prompt_type: str = "default") -> Dict[str, Any]:
82
- prompt_template = self._get_prompt_template(prompt_type)
83
 
84
- query_engine = self.index.as_query_engine(
85
- text_qa_template=prompt_template, similarity_top_k=5
 
 
 
86
  )
87
- response = query_engine.query(question)
88
 
89
- return response
90
 
91
- def _get_prompt_template(self, prompt_type: str) -> PromptTemplate:
92
- if prompt_type == "highlight":
93
- return PromptTemplate(
94
- "Context information is below.\n"
95
- "---------------------\n"
96
- "{context_str}\n"
97
- "---------------------\n"
98
- "Given this information, please answer the question: {query_str}\n"
99
- "Include all relevant information from the provided context. "
100
- "Highlight key information by enclosing it in **asterisks**. "
101
- "When quoting specific information, please use square brackets to indicate the source, e.g. [1], [2], etc."
102
- )
103
- elif prompt_type == "evidence_based":
104
- return PromptTemplate(
105
- "Context information is below.\n"
106
- "---------------------\n"
107
- "{context_str}\n"
108
- "---------------------\n"
109
- "Given this information, please answer the question: {query_str}\n"
110
- "Provide an answer to the question using evidence from the context above. "
111
- "Cite sources using square brackets."
112
- )
113
- else:
114
- return PromptTemplate(
115
  "Context information is below.\n"
116
  "---------------------\n"
117
  "{context_str}\n"
@@ -122,3 +87,11 @@ class RAGPipeline:
122
  "If the information is not available in the context, please state that clearly. "
123
  "When quoting specific information, please use square brackets to indicate the source, e.g. [1], [2], etc."
124
  )
 
 
 
 
 
 
 
 
 
2
  import os
3
  from typing import Dict, Any
4
  from llama_index.core import (
 
5
  VectorStoreIndex,
6
  Document,
7
+ SentenceWindowNodeParser,
8
+ )
9
+ from llama_index.core.node_parser import (
10
+ SentenceSplitter,
11
  )
 
 
12
  from llama_index.core import PromptTemplate
13
 
14
 
15
  class RAGPipeline:
16
  def __init__(
17
+ self,
18
+ study_json,
19
+ use_semantic_splitter=False,
20
  ):
21
+
22
+ self.study_json = study_json
 
23
  self.index = None
24
+ self.use_semantic_splitter = use_semantic_splitter
25
  self.load_documents()
26
  self.build_index()
27
 
28
  def load_documents(self):
29
+ with open(self.study_json, "r") as f:
30
+ self.data = json.load(f)
 
 
 
 
 
31
 
32
  self.documents = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ for index, doc_data in enumerate(self.data):
35
  doc_content = (
36
+ f"Title: {doc_data['title']}\n"
37
+ f"Abstract: {doc_data['abstract']}\n"
38
+ f"Authors: {', '.join(doc_data['authors'])}\n"
39
+ f"Year: {doc_data['year']}\n"
40
+ f"DOI: {doc_data['doi']}\n"
41
+ f"Full Text: {doc_data['full_text']}"
42
  )
43
 
44
+ metadata = {
45
+ "title": doc_data.get("title"),
46
+ "abstract": doc_data.get("abstract"),
47
+ "authors": doc_data.get("authors", []),
48
+ "year": doc_data.get("year"),
49
+ "doi": doc_data.get("doi"),
50
+ }
51
+
52
  self.documents.append(
53
+ Document(
54
+ text=doc_content,
55
+ id_=f"doc_{index}",
56
+ metadata=metadata,
57
+ )
58
  )
59
 
60
  def build_index(self):
61
+ sentence_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=13)
 
 
 
 
 
 
 
 
 
 
 
62
 
63
+ def _split(text: str) -> List[str]:
64
+ return sentence_splitter.split_text(text)
65
 
66
+ node_parser = SentenceWindowNodeParser.from_defaults(
67
+ sentence_splitter=_split,
68
+ window_size=3,
69
+ window_metadata_key="window",
70
+ original_text_metadata_key="original_text",
71
  )
 
72
 
73
+ nodes = node_parser.get_nodes_from_documents(self.documents)
74
 
75
+ self.index = VectorStoreIndex(nodes)
76
+
77
+ def query(self, question, prompt_template=None):
78
+ if prompt_template is None:
79
+ prompt_template = PromptTemplate(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  "Context information is below.\n"
81
  "---------------------\n"
82
  "{context_str}\n"
 
87
  "If the information is not available in the context, please state that clearly. "
88
  "When quoting specific information, please use square brackets to indicate the source, e.g. [1], [2], etc."
89
  )
90
+
91
+ query_engine = self.index.as_query_engine(
92
+ text_qa_template=prompt_template,
93
+ similarity_top_k=5,
94
+ )
95
+ response = query_engine.query(question)
96
+
97
+ return response
utils/helpers.py CHANGED
@@ -1,6 +1,7 @@
1
  from typing import Dict, Any
2
  from llama_index.core import Response
3
 
 
4
  def process_response(response: Response) -> Dict[str, Any]:
5
  source_nodes = response.source_nodes
6
  sources = {}
@@ -18,12 +19,13 @@ def process_response(response: Response) -> Dict[str, Any]:
18
 
19
  return {"markdown": markdown_text, "raw": raw_text, "sources": sources}
20
 
 
21
  def format_source(metadata: Dict[str, Any]) -> str:
22
- authors = metadata.get('authors', 'Unknown Author')
23
- year = metadata.get('year', 'n.d.')
24
- title = metadata.get('title', 'Untitled')
25
 
26
- author_list = authors.split(',')
27
  if len(author_list) > 2:
28
  formatted_authors = f"{author_list[0].strip()} et al."
29
  elif len(author_list) == 2:
@@ -31,10 +33,10 @@ def format_source(metadata: Dict[str, Any]) -> str:
31
  else:
32
  formatted_authors = author_list[0].strip()
33
 
34
- year = 'n.d.' if year is None or year == 'None' else str(year)
35
 
36
  max_title_length = 250
37
  if len(title) > max_title_length:
38
- title = title[:max_title_length] + '...'
39
 
40
  return f"{formatted_authors} ({year}). {title}"
 
1
  from typing import Dict, Any
2
  from llama_index.core import Response
3
 
4
+
5
  def process_response(response: Response) -> Dict[str, Any]:
6
  source_nodes = response.source_nodes
7
  sources = {}
 
19
 
20
  return {"markdown": markdown_text, "raw": raw_text, "sources": sources}
21
 
22
+
23
  def format_source(metadata: Dict[str, Any]) -> str:
24
+ authors = metadata.get("authors", "Unknown Author")
25
+ year = metadata.get("year", "n.d.")
26
+ title = metadata.get("title", "Untitled")
27
 
28
+ author_list = authors.split(",")
29
  if len(author_list) > 2:
30
  formatted_authors = f"{author_list[0].strip()} et al."
31
  elif len(author_list) == 2:
 
33
  else:
34
  formatted_authors = author_list[0].strip()
35
 
36
+ year = "n.d." if year is None or year == "None" else str(year)
37
 
38
  max_title_length = 250
39
  if len(title) > max_title_length:
40
+ title = title[:max_title_length] + "..."
41
 
42
  return f"{formatted_authors} ({year}). {title}"
vaccine_coverage_study.db DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:121fd525453b27b5008a3714840c929402ec01b74aea4d21bdd87be1a60bc008
3
- size 41222144