import json
from typing import Any, Dict, List

from llama_index.core import Document, PromptTemplate, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter, SentenceWindowNodeParser
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI


class RAGPipeline:
    """Retrieval-augmented generation pipeline over a JSON file of study records."""

    def __init__(self, study_json, use_semantic_splitter=False):
        self.study_json = study_json
        self.use_semantic_splitter = use_semantic_splitter
        self.documents = None
        self.index = None
        self.load_documents()
        self.build_index()

    def load_documents(self):
        """Load study records from the JSON file and wrap each one in a Document."""
        if self.documents is None:
            with open(self.study_json, "r") as f:
                self.data = json.load(f)

            self.documents = []
            for index, doc_data in enumerate(self.data):
                doc_content = (
                    f"Title: {doc_data['title']}\n"
                    f"Abstract: {doc_data['abstract']}\n"
                    f"Authors: {', '.join(doc_data['authors'])}\n"
                    # f"full_text: {doc_data['full_text']}"
                )
                metadata = {
                    "title": doc_data.get("title"),
                    "authors": doc_data.get("authors", []),
                    "year": doc_data.get("year"),
                    "doi": doc_data.get("doi"),
                }
                self.documents.append(
                    Document(text=doc_content, id_=f"doc_{index}", metadata=metadata)
                )

    def build_index(self):
        """Chunk the documents with a sentence-window parser and build a vector index."""
        if self.index is None:
            sentence_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=20)

            def _split(text: str) -> List[str]:
                return sentence_splitter.split_text(text)

            node_parser = SentenceWindowNodeParser.from_defaults(
                sentence_splitter=_split,
                window_size=5,
                window_metadata_key="window",
                original_text_metadata_key="original_text",
            )

            nodes = node_parser.get_nodes_from_documents(self.documents)
            self.index = VectorStoreIndex(
                nodes, embed_model=OpenAIEmbedding(model_name="text-embedding-3-large")
            )

    def query(
        self, context: str, prompt_template: PromptTemplate = None
    ) -> Dict[str, Any]:
        """Answer a question against the index, citing sources from the retrieved context."""
        if prompt_template is None:
            prompt_template = PromptTemplate(
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given this information, please answer the question: {query_str}\n"
                "Provide an answer to the question using evidence from the context above. "
                "Cite sources using square brackets for EVERY piece of information, e.g. [1], [2], etc. "
                "Even if there's only one source, still include the citation. "
                "If you're unsure about a source, use [?]. "
                "Ensure that EVERY statement from the context is properly cited."
            )

        # Hack: set similarity_top_k to the total number of stored nodes so that
        # every document in the store is retrieved for the answer.
        n_documents = len(self.index.docstore.docs)
        query_engine = self.index.as_query_engine(
            text_qa_template=prompt_template,
            similarity_top_k=n_documents,
            response_mode="tree_summarize",
            llm=OpenAI(model="gpt-4o-mini"),
        )
        response = query_engine.query(context)
        return response
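

# Minimal usage sketch, assuming an OPENAI_API_KEY is set in the environment.
# The file path and the question below are illustrative placeholders, not values
# shipped with this module.
if __name__ == "__main__":
    pipeline = RAGPipeline("studies.json")  # hypothetical path to exported study metadata
    answer = pipeline.query("What outcomes were reported across the included studies?")
    print(answer)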