File size: 3,620 Bytes
8121eff
6a076b8
cfb1a62
 
bc5a5b2
f4b7267
d762ede
 
8121eff
 
 
cfb1a62
b117341
 
669d93a
 
122cee1
 
8121eff
 
669d93a
 
 
8121eff
669d93a
8121eff
669d93a
 
 
d377a8f
669d93a
52a64e1
669d93a
8121eff
669d93a
 
 
 
 
 
b117341
669d93a
 
 
8121eff
122cee1
 
d762ede
122cee1
 
 
 
 
 
d762ede
122cee1
 
 
 
 
d762ede
 
 
5f52091
6a076b8
9a9bac9
6a076b8
b117341
 
8121eff
 
 
 
660cea6
 
 
 
 
 
8121eff
b117341
d762ede
 
b117341
d762ede
 
 
 
b117341
122cee1
9a9bac9
b117341
d377a8f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import json
from typing import Dict, Any
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.node_parser import SentenceWindowNodeParser, SentenceSplitter
from llama_index.core import PromptTemplate
from typing import List
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI


class RAGPipeline:
    """Index study records from a JSON file and answer questions over them.

    Each record becomes one ``Document`` whose text holds the title, abstract
    and authors, and whose metadata holds title, authors, year and DOI. The
    documents are parsed into sentence-window nodes and stored in a
    ``VectorStoreIndex``. Both loading and indexing run in the constructor.
    """

    def __init__(self, study_json, use_semantic_splitter=False):
        """Load the records in *study_json* and build the index.

        Args:
            study_json: Path to a JSON file whose top-level value is a list
                of record objects.
            use_semantic_splitter: Stored on the instance; no method of this
                class reads it.
        """
        self.study_json = study_json
        self.use_semantic_splitter = use_semantic_splitter
        self.data = None
        self.documents = None
        self.index = None
        self.load_documents()
        self.build_index()

    @staticmethod
    def _format_document_text(doc_data: Dict[str, Any]) -> str:
        """Return the text that is embedded for one study record.

        Missing or null fields are rendered as empty strings rather than
        raising ``KeyError`` or embedding the literal text ``"None"``.
        """
        title = doc_data.get("title") or ""
        abstract = doc_data.get("abstract") or ""
        authors = doc_data.get("authors") or []
        if isinstance(authors, str):
            # A bare string would otherwise be joined character by character.
            authors = [authors]
        author_line = ", ".join(str(author) for author in authors)
        # NOTE: the record's 'full_text' field is currently not part of the
        # embedded text.
        return (
            f"Title: {title}\n"
            f"Abstract: {abstract}\n"
            f"Authors: {author_line}\n"
        )

    def load_documents(self):
        """Read the JSON file and populate ``self.documents``.

        Does nothing when documents have already been loaded.

        Raises:
            TypeError: If the top-level JSON value is not a list.
        """
        if self.documents is not None:
            return

        # JSON is UTF-8 by specification; do not rely on the locale default.
        with open(self.study_json, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        if not isinstance(self.data, list):
            raise TypeError(
                f"Expected a JSON list of records in {self.study_json!r}, "
                f"got {type(self.data).__name__}"
            )

        self.documents = [
            Document(
                text=self._format_document_text(doc_data),
                id_=f"doc_{index}",
                metadata={
                    "title": doc_data.get("title"),
                    "authors": doc_data.get("authors", []),
                    "year": doc_data.get("year"),
                    "doi": doc_data.get("doi"),
                },
            )
            for index, doc_data in enumerate(self.data)
        ]

    def build_index(self):
        """Parse the loaded documents into nodes and build ``self.index``.

        Does nothing when the index already exists. Nodes are embedded with
        the OpenAI ``text-embedding-3-large`` model.
        """
        if self.index is not None:
            return

        sentence_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=20)
        node_parser = SentenceWindowNodeParser.from_defaults(
            sentence_splitter=sentence_splitter.split_text,
            window_size=5,
            window_metadata_key="window",
            original_text_metadata_key="original_text",
        )

        nodes = node_parser.get_nodes_from_documents(self.documents)
        self.index = VectorStoreIndex(
            nodes, embed_model=OpenAIEmbedding(model_name="text-embedding-3-large")
        )

    def query(
        self, context: str, prompt_template: "PromptTemplate | None" = None
    ) -> Any:
        """Run *context* as a query against the index and return the response.

        Args:
            context: The query text handed to the query engine.
            prompt_template: Template used as the engine's text QA prompt.
                When omitted, a default template that asks for
                square-bracket citations is used.

        Returns:
            The object returned by the query engine's ``query`` call,
            unchanged.
        """
        if prompt_template is None:
            prompt_template = PromptTemplate(
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given this information, please answer the question: {query_str}\n"
                "Provide an answer to the question using evidence from the context above. "
                "Cite sources using square brackets for EVERY piece of information, e.g. [1], [2], etc. "
                "Even if there's only one source, still include the citation. "
                "If you're unsure about a source, use [?]. "
                "Ensure that EVERY statement from the context is properly cited."
            )

        # Ask the retriever for as many hits as the docstore has entries so
        # that nothing is filtered out before summarisation.
        # NOTE(review): docstore entries are presumably nodes, not source
        # documents - confirm.
        n_documents = len(self.index.docstore.docs)
        query_engine = self.index.as_query_engine(
            text_qa_template=prompt_template,
            similarity_top_k=n_documents,
            response_mode="tree_summarize",
            llm=OpenAI(model="gpt-4o-mini"),
        )

        return query_engine.query(context)