import json
from typing import Any, Dict, List

from llama_index.core import Document, PromptTemplate, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter, SentenceWindowNodeParser
class RAGPipeline:
    """Retrieval-augmented generation pipeline over a JSON file of studies."""

    def __init__(self, study_json, use_semantic_splitter=False):
        self.study_json = study_json
        # NOTE: currently unused; kept as a hook for a semantic-splitting option.
        self.use_semantic_splitter = use_semantic_splitter
        self.documents = None
        self.index = None
        self.load_documents()
        self.build_index()
    def load_documents(self):
        """Read the study JSON and wrap each record in a LlamaIndex Document."""
        if self.documents is None:
            with open(self.study_json, "r") as f:
                self.data = json.load(f)

            self.documents = []
            for index, doc_data in enumerate(self.data):
                doc_content = (
                    f"Title: {doc_data['title']}\n"
                    f"Authors: {', '.join(doc_data['authors'])}\n"
                    f"Full Text: {doc_data['full_text']}"
                )
                metadata = {
                    "title": doc_data.get("title"),
                    "abstract": doc_data.get("abstract"),
                    "authors": doc_data.get("authors", []),
                    "year": doc_data.get("year"),
                    "doi": doc_data.get("doi"),
                }
                self.documents.append(
                    Document(text=doc_content, id_=f"doc_{index}", metadata=metadata)
                )
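
    # Each record in the study JSON is expected to look roughly like the
    # following (field names inferred from the accesses above; values are
    # purely illustrative):
    # {
    #     "title": "...",
    #     "authors": ["...", "..."],
    #     "full_text": "...",
    #     "abstract": "...",
    #     "year": 2020,
    #     "doi": "10.xxxx/..."
    # }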
    def build_index(self):
        """Chunk the documents into sentence-window nodes and embed them."""
        if self.index is None:
            sentence_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=13)

            def _split(text: str) -> List[str]:
                return sentence_splitter.split_text(text)

            # Each node keeps its surrounding 3 sentences on either side under
            # the "window" metadata key; the sentence itself is stored under
            # "original_text".
            node_parser = SentenceWindowNodeParser.from_defaults(
                sentence_splitter=_split,
                window_size=3,
                window_metadata_key="window",
                original_text_metadata_key="original_text",
            )

            nodes = node_parser.get_nodes_from_documents(self.documents)
            self.index = VectorStoreIndex(nodes)
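
    # NOTE: the query engines below send the retrieved sentences to the LLM
    # as-is and do not swap in the stored "window" metadata; see the sketch
    # after the class for one way to do that at query time.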
    def extract_study_info(self) -> Dict[str, Any]:
        """Ask the LLM to extract structured study fields from the indexed text."""
        extraction_prompt = PromptTemplate(
            "Based on the given context, please extract the following information about the study:\n"
            "1. Study ID\n"
            "2. Author(s)\n"
            "3. Year\n"
            "4. Title\n"
            "5. Study design\n"
            "6. Study area/region\n"
            "7. Study population\n"
            "8. Disease under study\n"
            "9. Duration of study\n"
            "If the information is not available, please respond with 'Not found' for that field.\n"
            "Context: {context_str}\n"
            "Extracted information:"
        )
        query_engine = self.index.as_query_engine(
            text_qa_template=extraction_prompt, similarity_top_k=5
        )
        response = query_engine.query("Extract study information")

        # Parse the LLM's "Field: value" lines into a dict with snake_case keys.
        extracted_info = {}
        for line in response.response.split("\n"):
            if ":" in line:
                key, value = line.split(":", 1)
                extracted_info[key.strip().lower().replace(" ", "_")] = value.strip()
        return extracted_info
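
    # The returned dict maps snake_cased field names to extracted values, e.g.
    # {"study_design": "cohort study", "year": "2019", ...} (illustrative; the
    # exact keys depend on how the LLM formats its answer).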
    def query(
        self, context: str, prompt_template: PromptTemplate = None
    ) -> Dict[str, Any]:
        """Answer a free-form question against the index and return its sources."""
        if prompt_template is None:
            # The default template must include {query_str}; without it the
            # model would never actually see the question being asked.
            prompt_template = PromptTemplate(
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given this information, please answer the question below. "
                "Include all relevant information from the provided context. "
                "If information comes from multiple sources, please mention all of them. "
                "If the information is not available in the context, please state that clearly. "
                "When quoting specific information, please use square brackets to indicate the source, e.g. [1], [2], etc.\n"
                "Question: {query_str}\n"
                "Answer:"
            )
        query_engine = self.index.as_query_engine(
            text_qa_template=prompt_template, similarity_top_k=5
        )
        response = query_engine.query(context)
        return {
            "answer": response.response,
            "sources": [node.metadata for node in response.source_nodes],
        }
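

# A minimal usage sketch, not part of the pipeline itself. It assumes a file
# "studies.json" (hypothetical path) matching the schema noted above, and that
# LlamaIndex's default LLM and embedding backends are configured (e.g. via
# OPENAI_API_KEY); adjust llama_index.core.Settings for other backends.
if __name__ == "__main__":
    pipeline = RAGPipeline("studies.json")

    # Structured extraction of the numbered study fields.
    info = pipeline.extract_study_info()
    print(info)

    # Free-form question answering with source metadata.
    result = pipeline.query("What disease was studied, and in which region?")
    print(result["answer"])
    for source in result["sources"]:
        print(source.get("title"), source.get("doi"))

    # Optional: sentence-window retrieval is usually paired with LlamaIndex's
    # MetadataReplacementPostProcessor, which replaces each retrieved sentence
    # with its stored "window" before the LLM sees it. A sketch:
    from llama_index.core.postprocessor import MetadataReplacementPostProcessor

    window_engine = pipeline.index.as_query_engine(
        similarity_top_k=5,
        node_postprocessors=[
            MetadataReplacementPostProcessor(target_metadata_key="window")
        ],
    )
    print(window_engine.query("Summarise the study design.").response)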