Spaces:
Running
Running
changed pipelin
Browse files- .gitattributes +2 -1
- data/ebola_virus_zotero_items.json +0 -0
- data/gene_xpert_zotero_items.json +0 -0
- data/vaccine_coverage_zotero_items.json +0 -0
- database/vaccine_coverage_db.py +3 -46
- initialize_db.py +2 -2
- rag/rag_pipeline.py +54 -81
- utils/helpers.py +8 -6
- vaccine_coverage_study.db +0 -3
.gitattributes
CHANGED
@@ -34,4 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
*db* filter=lfs diff=lfs merge=lfs -text
|
37 |
-
vaccine_coverage_study.db filter=lfs diff=lfs merge=lfs -text
|
|
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
*db* filter=lfs diff=lfs merge=lfs -text
|
37 |
+
vaccine_coverage_study.db filter=lfs diff=lfs merge=lfs -text
|
38 |
+
*.db filter=lfs diff=lfs merge=lfs -text
|
data/ebola_virus_zotero_items.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/gene_xpert_zotero_items.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/vaccine_coverage_zotero_items.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
database/vaccine_coverage_db.py
CHANGED
@@ -1,46 +1,3 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
class VaccineCoverageDB:
|
6 |
-
def __init__(self, db_path: str):
|
7 |
-
self.conn = sqlite3.connect(db_path)
|
8 |
-
self.conn.row_factory = sqlite3.Row
|
9 |
-
|
10 |
-
def get_all_items(self) -> List[Dict[str, Any]]:
|
11 |
-
cursor = self.conn.execute("SELECT * FROM items")
|
12 |
-
return [dict(row) for row in cursor.fetchall()]
|
13 |
-
|
14 |
-
def get_item_by_key(self, key: str) -> Dict[str, Any]:
|
15 |
-
cursor = self.conn.execute("SELECT * FROM items WHERE key = ?", (key,))
|
16 |
-
return dict(cursor.fetchone())
|
17 |
-
|
18 |
-
def get_attachments_for_item(self, item_key: str) -> List[Dict[str, Any]]:
|
19 |
-
cursor = self.conn.execute(
|
20 |
-
"SELECT * FROM attachments WHERE parent_key = ?", (item_key,)
|
21 |
-
)
|
22 |
-
return [dict(row) for row in cursor.fetchall()]
|
23 |
-
|
24 |
-
def get_pdf_content(self, attachment_key: str) -> bytes:
|
25 |
-
cursor = self.conn.execute(
|
26 |
-
"SELECT content FROM attachments WHERE key = ?", (attachment_key,)
|
27 |
-
)
|
28 |
-
result = cursor.fetchone()
|
29 |
-
return result["content"] if result else None
|
30 |
-
|
31 |
-
def save_pdf_to_file(self, attachment_key: str, output_path: str) -> bool:
|
32 |
-
pdf_content = self.get_pdf_content(attachment_key)
|
33 |
-
if pdf_content:
|
34 |
-
try:
|
35 |
-
with open(output_path, "wb") as f:
|
36 |
-
f.write(pdf_content)
|
37 |
-
return True
|
38 |
-
except Exception as e:
|
39 |
-
print(f"Error saving PDF: {str(e)}")
|
40 |
-
return False
|
41 |
-
else:
|
42 |
-
print(f"No PDF content found for attachment key: {attachment_key}")
|
43 |
-
return False
|
44 |
-
|
45 |
-
def close(self):
|
46 |
-
self.conn.close()
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:42a0645cdd38f2d7ede525768eb21a4cbe08b4d86959cb4eb2349887f2bcf70e
|
3 |
+
size 1774
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
initialize_db.py
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:08030c4783a86d9a9afb9437b102dde959405b6b2857725eec02b6d9c2699e97
|
3 |
+
size 2346
|
rag/rag_pipeline.py
CHANGED
@@ -2,116 +2,81 @@ import json
|
|
2 |
import os
|
3 |
from typing import Dict, Any
|
4 |
from llama_index.core import (
|
5 |
-
SimpleDirectoryReader,
|
6 |
VectorStoreIndex,
|
7 |
Document,
|
8 |
-
|
9 |
-
|
|
|
|
|
10 |
)
|
11 |
-
from llama_index.core.node_parser import SentenceSplitter, SemanticSplitterNodeParser
|
12 |
-
from llama_index.embeddings.openai import OpenAIEmbedding
|
13 |
from llama_index.core import PromptTemplate
|
14 |
|
15 |
|
16 |
class RAGPipeline:
|
17 |
def __init__(
|
18 |
-
self,
|
|
|
|
|
19 |
):
|
20 |
-
|
21 |
-
self.
|
22 |
-
self.use_semantic_splitter = use_semantic_splitter
|
23 |
self.index = None
|
|
|
24 |
self.load_documents()
|
25 |
self.build_index()
|
26 |
|
27 |
def load_documents(self):
|
28 |
-
|
29 |
-
|
30 |
-
self.documents = []
|
31 |
-
return
|
32 |
-
|
33 |
-
with open(self.metadata_file, "r") as f:
|
34 |
-
self.metadata = json.load(f)
|
35 |
|
36 |
self.documents = []
|
37 |
-
for item_key, item_data in self.metadata.items():
|
38 |
-
metadata = item_data["metadata"]
|
39 |
-
pdf_path = item_data.get("pdf_path")
|
40 |
-
|
41 |
-
if pdf_path:
|
42 |
-
full_pdf_path = os.path.join(self.pdf_dir, os.path.basename(pdf_path))
|
43 |
-
if os.path.exists(full_pdf_path):
|
44 |
-
pdf_content = (
|
45 |
-
SimpleDirectoryReader(input_files=[full_pdf_path])
|
46 |
-
.load_data()[0]
|
47 |
-
.text
|
48 |
-
)
|
49 |
-
else:
|
50 |
-
pdf_content = "PDF file not found"
|
51 |
-
else:
|
52 |
-
pdf_content = "PDF path not available in metadata"
|
53 |
|
|
|
54 |
doc_content = (
|
55 |
-
f"Title: {
|
56 |
-
f"Abstract: {
|
57 |
-
f"Authors: {
|
58 |
-
f"Year: {
|
59 |
-
f"DOI: {
|
60 |
-
f"Full Text: {
|
61 |
)
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
self.documents.append(
|
64 |
-
Document(
|
|
|
|
|
|
|
|
|
65 |
)
|
66 |
|
67 |
def build_index(self):
|
68 |
-
|
69 |
-
embed_model = OpenAIEmbedding()
|
70 |
-
splitter = SemanticSplitterNodeParser(
|
71 |
-
buffer_size=1,
|
72 |
-
breakpoint_percentile_threshold=95,
|
73 |
-
embed_model=embed_model,
|
74 |
-
)
|
75 |
-
else:
|
76 |
-
splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
|
77 |
-
|
78 |
-
nodes = splitter.get_nodes_from_documents(self.documents)
|
79 |
-
self.index = VectorStoreIndex(nodes)
|
80 |
|
81 |
-
|
82 |
-
|
83 |
|
84 |
-
|
85 |
-
|
|
|
|
|
|
|
86 |
)
|
87 |
-
response = query_engine.query(question)
|
88 |
|
89 |
-
|
90 |
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
"{context_str}\n"
|
97 |
-
"---------------------\n"
|
98 |
-
"Given this information, please answer the question: {query_str}\n"
|
99 |
-
"Include all relevant information from the provided context. "
|
100 |
-
"Highlight key information by enclosing it in **asterisks**. "
|
101 |
-
"When quoting specific information, please use square brackets to indicate the source, e.g. [1], [2], etc."
|
102 |
-
)
|
103 |
-
elif prompt_type == "evidence_based":
|
104 |
-
return PromptTemplate(
|
105 |
-
"Context information is below.\n"
|
106 |
-
"---------------------\n"
|
107 |
-
"{context_str}\n"
|
108 |
-
"---------------------\n"
|
109 |
-
"Given this information, please answer the question: {query_str}\n"
|
110 |
-
"Provide an answer to the question using evidence from the context above. "
|
111 |
-
"Cite sources using square brackets."
|
112 |
-
)
|
113 |
-
else:
|
114 |
-
return PromptTemplate(
|
115 |
"Context information is below.\n"
|
116 |
"---------------------\n"
|
117 |
"{context_str}\n"
|
@@ -122,3 +87,11 @@ class RAGPipeline:
|
|
122 |
"If the information is not available in the context, please state that clearly. "
|
123 |
"When quoting specific information, please use square brackets to indicate the source, e.g. [1], [2], etc."
|
124 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import os
|
3 |
from typing import Dict, Any
|
4 |
from llama_index.core import (
|
|
|
5 |
VectorStoreIndex,
|
6 |
Document,
|
7 |
+
SentenceWindowNodeParser,
|
8 |
+
)
|
9 |
+
from llama_index.core.node_parser import (
|
10 |
+
SentenceSplitter,
|
11 |
)
|
|
|
|
|
12 |
from llama_index.core import PromptTemplate
|
13 |
|
14 |
|
15 |
class RAGPipeline:
|
16 |
def __init__(
|
17 |
+
self,
|
18 |
+
study_json,
|
19 |
+
use_semantic_splitter=False,
|
20 |
):
|
21 |
+
|
22 |
+
self.study_json = study_json
|
|
|
23 |
self.index = None
|
24 |
+
self.use_semantic_splitter = use_semantic_splitter
|
25 |
self.load_documents()
|
26 |
self.build_index()
|
27 |
|
28 |
def load_documents(self):
|
29 |
+
with open(self.study_json, "r") as f:
|
30 |
+
self.data = json.load(f)
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
self.documents = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
+
for index, doc_data in enumerate(self.data):
|
35 |
doc_content = (
|
36 |
+
f"Title: {doc_data['title']}\n"
|
37 |
+
f"Abstract: {doc_data['abstract']}\n"
|
38 |
+
f"Authors: {', '.join(doc_data['authors'])}\n"
|
39 |
+
f"Year: {doc_data['year']}\n"
|
40 |
+
f"DOI: {doc_data['doi']}\n"
|
41 |
+
f"Full Text: {doc_data['full_text']}"
|
42 |
)
|
43 |
|
44 |
+
metadata = {
|
45 |
+
"title": doc_data.get("title"),
|
46 |
+
"abstract": doc_data.get("abstract"),
|
47 |
+
"authors": doc_data.get("authors", []),
|
48 |
+
"year": doc_data.get("year"),
|
49 |
+
"doi": doc_data.get("doi"),
|
50 |
+
}
|
51 |
+
|
52 |
self.documents.append(
|
53 |
+
Document(
|
54 |
+
text=doc_content,
|
55 |
+
id_=f"doc_{index}",
|
56 |
+
metadata=metadata,
|
57 |
+
)
|
58 |
)
|
59 |
|
60 |
def build_index(self):
|
61 |
+
sentence_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=13)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
+
def _split(text: str) -> List[str]:
|
64 |
+
return sentence_splitter.split_text(text)
|
65 |
|
66 |
+
node_parser = SentenceWindowNodeParser.from_defaults(
|
67 |
+
sentence_splitter=_split,
|
68 |
+
window_size=3,
|
69 |
+
window_metadata_key="window",
|
70 |
+
original_text_metadata_key="original_text",
|
71 |
)
|
|
|
72 |
|
73 |
+
nodes = node_parser.get_nodes_from_documents(self.documents)
|
74 |
|
75 |
+
self.index = VectorStoreIndex(nodes)
|
76 |
+
|
77 |
+
def query(self, question, prompt_template=None):
|
78 |
+
if prompt_template is None:
|
79 |
+
prompt_template = PromptTemplate(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
"Context information is below.\n"
|
81 |
"---------------------\n"
|
82 |
"{context_str}\n"
|
|
|
87 |
"If the information is not available in the context, please state that clearly. "
|
88 |
"When quoting specific information, please use square brackets to indicate the source, e.g. [1], [2], etc."
|
89 |
)
|
90 |
+
|
91 |
+
query_engine = self.index.as_query_engine(
|
92 |
+
text_qa_template=prompt_template,
|
93 |
+
similarity_top_k=5,
|
94 |
+
)
|
95 |
+
response = query_engine.query(question)
|
96 |
+
|
97 |
+
return response
|
utils/helpers.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from typing import Dict, Any
|
2 |
from llama_index.core import Response
|
3 |
|
|
|
4 |
def process_response(response: Response) -> Dict[str, Any]:
|
5 |
source_nodes = response.source_nodes
|
6 |
sources = {}
|
@@ -18,12 +19,13 @@ def process_response(response: Response) -> Dict[str, Any]:
|
|
18 |
|
19 |
return {"markdown": markdown_text, "raw": raw_text, "sources": sources}
|
20 |
|
|
|
21 |
def format_source(metadata: Dict[str, Any]) -> str:
|
22 |
-
authors = metadata.get(
|
23 |
-
year = metadata.get(
|
24 |
-
title = metadata.get(
|
25 |
|
26 |
-
author_list = authors.split(
|
27 |
if len(author_list) > 2:
|
28 |
formatted_authors = f"{author_list[0].strip()} et al."
|
29 |
elif len(author_list) == 2:
|
@@ -31,10 +33,10 @@ def format_source(metadata: Dict[str, Any]) -> str:
|
|
31 |
else:
|
32 |
formatted_authors = author_list[0].strip()
|
33 |
|
34 |
-
year =
|
35 |
|
36 |
max_title_length = 250
|
37 |
if len(title) > max_title_length:
|
38 |
-
title = title[:max_title_length] +
|
39 |
|
40 |
return f"{formatted_authors} ({year}). {title}"
|
|
|
1 |
from typing import Dict, Any
|
2 |
from llama_index.core import Response
|
3 |
|
4 |
+
|
5 |
def process_response(response: Response) -> Dict[str, Any]:
|
6 |
source_nodes = response.source_nodes
|
7 |
sources = {}
|
|
|
19 |
|
20 |
return {"markdown": markdown_text, "raw": raw_text, "sources": sources}
|
21 |
|
22 |
+
|
23 |
def format_source(metadata: Dict[str, Any]) -> str:
|
24 |
+
authors = metadata.get("authors", "Unknown Author")
|
25 |
+
year = metadata.get("year", "n.d.")
|
26 |
+
title = metadata.get("title", "Untitled")
|
27 |
|
28 |
+
author_list = authors.split(",")
|
29 |
if len(author_list) > 2:
|
30 |
formatted_authors = f"{author_list[0].strip()} et al."
|
31 |
elif len(author_list) == 2:
|
|
|
33 |
else:
|
34 |
formatted_authors = author_list[0].strip()
|
35 |
|
36 |
+
year = "n.d." if year is None or year == "None" else str(year)
|
37 |
|
38 |
max_title_length = 250
|
39 |
if len(title) > max_title_length:
|
40 |
+
title = title[:max_title_length] + "..."
|
41 |
|
42 |
return f"{formatted_authors} ({year}). {title}"
|
vaccine_coverage_study.db
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:121fd525453b27b5008a3714840c929402ec01b74aea4d21bdd87be1a60bc008
|
3 |
-
size 41222144
|
|
|
|
|
|
|
|