# acres/rag/rag_pipeline.py
# Hosting-page residue kept as a comment: author ak3ra, last commit "minor" (9a9bac9), file size 4.51 kB.
import json
from typing import Dict, Any
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.node_parser import SentenceWindowNodeParser, SentenceSplitter
from llama_index.core import PromptTemplate
from typing import List
class RAGPipeline:
    """Retrieval-augmented generation pipeline over a JSON file of studies.

    Constructing the pipeline reads ``study_json``, wraps every record in a
    llama-index ``Document`` and builds a ``VectorStoreIndex`` over
    sentence-window nodes. ``query`` and ``extract_study_info`` then run
    LLM-backed queries against that index.

    The JSON file is expected to hold a list of study records (mappings).
    The keys read from each record are ``title``, ``authors``, ``full_text``,
    ``abstract``, ``year`` and ``doi``; all of them are optional.
    """

    # Settings handed to the SentenceSplitter that pre-chunks document text.
    CHUNK_SIZE = 128
    CHUNK_OVERLAP = 13
    # ``window_size`` handed to SentenceWindowNodeParser.
    WINDOW_SIZE = 3
    # ``similarity_top_k`` handed to every query engine built by this class.
    SIMILARITY_TOP_K = 5

    # Default question-answering prompt used by ``query``.  It must contain
    # ``{query_str}``: llama-index only substitutes variables that appear in
    # the template, so without it the question is used for retrieval but is
    # never shown to the LLM.
    DEFAULT_QA_TEMPLATE = (
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given this information, please answer the question provided in the context. "
        "Include all relevant information from the provided context. "
        "If information comes from multiple sources, please mention all of them. "
        "If the information is not available in the context, please state that clearly. "
        "When quoting specific information, please use square brackets to indicate the source, e.g. [1], [2], etc."
        "\n"
        "Question: {query_str}\n"
        "Answer: "
    )

    # Prompt used by ``extract_study_info``; the instruction is part of the
    # template itself, so it does not need ``{query_str}``.
    EXTRACTION_TEMPLATE = (
        "Based on the given context, please extract the following information about the study:\n"
        "1. Study ID\n"
        "2. Author(s)\n"
        "3. Year\n"
        "4. Title\n"
        "5. Study design\n"
        "6. Study area/region\n"
        "7. Study population\n"
        "8. Disease under study\n"
        "9. Duration of study\n"
        "If the information is not available, please respond with 'Not found' for that field.\n"
        "Context: {context_str}\n"
        "Extracted information:"
    )

    def __init__(self, study_json, use_semantic_splitter=False):
        """Load the studies from ``study_json`` and build the vector index.

        Args:
            study_json: Path of the JSON file to read.
            use_semantic_splitter: Stored on the instance; ``build_index``
                does not consult it at present.
        """
        self.study_json = study_json
        self.use_semantic_splitter = use_semantic_splitter
        self.data = None
        self.documents = None
        self.index = None
        self.load_documents()
        self.build_index()

    @staticmethod
    def _format_document_text(doc_data) -> str:
        """Render one study record as the text that gets indexed.

        Missing or null fields become empty strings instead of raising, which
        matches the tolerant ``.get`` lookups used for the metadata.
        """
        title = doc_data.get("title") or ""
        authors = doc_data.get("authors") or []
        full_text = doc_data.get("full_text") or ""
        return (
            f"Title: {title}\n"
            f"Authors: {', '.join(authors)}\n"
            f"Full Text: {full_text}"
        )

    @staticmethod
    def _build_metadata(doc_data) -> Dict[str, Any]:
        """Collect the bibliographic fields attached to a document as metadata."""
        return {
            "title": doc_data.get("title"),
            "abstract": doc_data.get("abstract"),
            "authors": doc_data.get("authors", []),
            "year": doc_data.get("year"),
            "doi": doc_data.get("doi"),
        }

    @staticmethod
    def _parse_key_value_lines(text: str) -> Dict[str, str]:
        """Turn ``"Key: value"`` lines into a dict with normalised keys.

        Lines without a colon are skipped.  Only the first colon splits a
        line.  Keys are stripped, lower-cased and have spaces replaced by
        underscores, so a numbered line such as ``"1. Study ID: X"`` yields
        the key ``"1._study_id"``.  A repeated key keeps its last value.
        """
        parsed = {}
        for line in text.split("\n"):
            key, separator, value = line.partition(":")
            if separator:
                parsed[key.strip().lower().replace(" ", "_")] = value.strip()
        return parsed

    def load_documents(self):
        """Read ``self.study_json`` and populate ``self.data`` / ``self.documents``.

        Does nothing when documents are already loaded.  Each record becomes
        a ``Document`` with id ``doc_<position>``.
        """
        if self.documents is not None:
            return

        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open(self.study_json, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        self.documents = [
            Document(
                text=self._format_document_text(doc_data),
                id_=f"doc_{index}",
                metadata=self._build_metadata(doc_data),
            )
            for index, doc_data in enumerate(self.data)
        ]

    def build_index(self):
        """Parse the loaded documents into nodes and build ``self.index``.

        Does nothing when an index already exists.  The text is split by a
        ``SentenceSplitter`` whose ``split_text`` method is plugged into a
        ``SentenceWindowNodeParser``.
        """
        if self.index is not None:
            return

        sentence_splitter = SentenceSplitter(
            chunk_size=self.CHUNK_SIZE, chunk_overlap=self.CHUNK_OVERLAP
        )
        node_parser = SentenceWindowNodeParser.from_defaults(
            sentence_splitter=sentence_splitter.split_text,
            window_size=self.WINDOW_SIZE,
            window_metadata_key="window",
            original_text_metadata_key="original_text",
        )
        nodes = node_parser.get_nodes_from_documents(self.documents)
        self.index = VectorStoreIndex(nodes)

    def extract_study_info(self) -> Dict[str, Any]:
        """Ask the LLM for a fixed list of study attributes and parse the reply.

        Returns:
            A dict built from the ``"Key: value"`` lines of the response; see
            ``_parse_key_value_lines`` for how keys are normalised.
        """
        query_engine = self.index.as_query_engine(
            text_qa_template=PromptTemplate(self.EXTRACTION_TEMPLATE),
            similarity_top_k=self.SIMILARITY_TOP_K,
        )
        response = query_engine.query("Extract study information")
        # Guard against an empty response so parsing never fails on None.
        return self._parse_key_value_lines(response.response or "")

    def query(
        self, context: str, prompt_template: "PromptTemplate | None" = None
    ) -> Dict[str, Any]:
        """Run ``context`` as a query against the index.

        Args:
            context: The query string passed to the query engine.
            prompt_template: Optional QA prompt; when omitted a template
                built from ``DEFAULT_QA_TEMPLATE`` is used.

        Returns:
            A dict with ``"answer"`` (the response text) and ``"sources"``
            (the metadata dict of every source node of the response).
        """
        if prompt_template is None:
            prompt_template = PromptTemplate(self.DEFAULT_QA_TEMPLATE)

        query_engine = self.index.as_query_engine(
            text_qa_template=prompt_template,
            similarity_top_k=self.SIMILARITY_TOP_K,
        )
        response = query_engine.query(context)
        return {
            "answer": response.response,
            "sources": [node.metadata for node in response.source_nodes],
        }