"""RAG pipeline over a JSON corpus of study records, built on llama_index.

Loads study records from a JSON file, wraps each in a ``Document``,
indexes them with a sentence-window node parser, and exposes helpers for
structured extraction and free-form querying.
"""

import json
from typing import Any, Dict, List, Optional

from llama_index.core import Document, PromptTemplate, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter, SentenceWindowNodeParser


class RAGPipeline:
    """Retrieval-augmented generation pipeline over a set of studies."""

    def __init__(self, study_json: str, use_semantic_splitter: bool = False) -> None:
        """Load documents from *study_json* and build the vector index.

        Args:
            study_json: Path to a JSON file containing a list of study
                records (each with ``title``/``authors``/``full_text`` and
                optional ``abstract``/``year``/``doi`` keys).
            use_semantic_splitter: Reserved flag; currently not consulted by
                ``build_index``. Kept for interface compatibility.
        """
        self.study_json = study_json
        self.use_semantic_splitter = use_semantic_splitter
        self.documents: Optional[List[Document]] = None
        self.index: Optional[VectorStoreIndex] = None
        self.load_documents()
        self.build_index()

    def load_documents(self) -> None:
        """Read the study JSON file and build ``Document`` objects (idempotent)."""
        if self.documents is not None:
            return
        # Explicit encoding avoids platform-dependent decoding of the corpus.
        with open(self.study_json, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        self.documents = []
        for index, doc_data in enumerate(self.data):
            # Use .get(...) consistently so a record missing a field degrades
            # to an empty value instead of raising KeyError (the original
            # indexed doc_data['title'] here but used .get in the metadata).
            doc_content = (
                f"Title: {doc_data.get('title', '')}\n"
                f"Authors: {', '.join(doc_data.get('authors', []))}\n"
                f"Full Text: {doc_data.get('full_text', '')}"
            )
            metadata = {
                "title": doc_data.get("title"),
                "abstract": doc_data.get("abstract"),
                "authors": doc_data.get("authors", []),
                "year": doc_data.get("year"),
                "doi": doc_data.get("doi"),
            }
            self.documents.append(
                Document(text=doc_content, id_=f"doc_{index}", metadata=metadata)
            )

    def build_index(self) -> None:
        """Parse documents into sentence-window nodes and index them (idempotent)."""
        if self.index is not None:
            return
        sentence_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=13)

        def _split(text: str) -> List[str]:
            # SentenceWindowNodeParser expects a plain callable that returns a
            # list of sentences, not the splitter object itself.
            return sentence_splitter.split_text(text)

        node_parser = SentenceWindowNodeParser.from_defaults(
            sentence_splitter=_split,
            window_size=3,
            window_metadata_key="window",
            original_text_metadata_key="original_text",
        )
        nodes = node_parser.get_nodes_from_documents(self.documents)
        self.index = VectorStoreIndex(nodes)

    def extract_study_info(self) -> Dict[str, Any]:
        """Query the index for structured study fields.

        Returns:
            Mapping of snake_cased field names (e.g. ``study_id``) to the
            extracted text values. Fields the model could not find carry
            whatever text it produced (typically 'Not found').
        """
        extraction_prompt = PromptTemplate(
            "Based on the given context, please extract the following information about the study:\n"
            "1. Study ID\n"
            "2. Author(s)\n"
            "3. Year\n"
            "4. Title\n"
            "5. Study design\n"
            "6. Study area/region\n"
            "7. Study population\n"
            "8. Disease under study\n"
            "9. Duration of study\n"
            "If the information is not available, please respond with 'Not found' for that field.\n"
            "Context: {context_str}\n"
            "Extracted information:"
        )
        query_engine = self.index.as_query_engine(
            text_qa_template=extraction_prompt, similarity_top_k=5
        )
        response = query_engine.query("Extract study information")

        # Parse "Key: value" lines from the LLM response into a dict.
        extracted_info: Dict[str, Any] = {}
        for line in response.response.split("\n"):
            if ":" in line:
                key, value = line.split(":", 1)
                extracted_info[key.strip().lower().replace(" ", "_")] = value.strip()
        return extracted_info

    def query(
        self, context: str, prompt_template: Optional[PromptTemplate] = None
    ) -> Dict[str, Any]:
        """Answer *context* against the index.

        Args:
            context: The question/query text sent to the query engine.
            prompt_template: Optional custom QA template; a citation-aware
                default is used when omitted.

        Returns:
            Dict with the answer text under ``"answer"`` and the metadata of
            the retrieved source nodes under ``"sources"``.
        """
        if prompt_template is None:
            prompt_template = PromptTemplate(
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given this information, please answer the question provided in the context. "
                "Include all relevant information from the provided context. "
                "If information comes from multiple sources, please mention all of them. "
                "If the information is not available in the context, please state that clearly. "
                "When quoting specific information, please use square brackets to indicate the source, e.g. [1], [2], etc."
            )
        query_engine = self.index.as_query_engine(
            text_qa_template=prompt_template, similarity_top_k=5
        )
        response = query_engine.query(context)
        return {
            "answer": response.response,
            "sources": [node.metadata for node in response.source_nodes],
        }