Spaces:
Running
Running
File size: 2,859 Bytes
cfb1a62 8121eff cfb1a62 bc5a5b2 8121eff cfb1a62 b117341 bc5a5b2 b117341 8121eff b117341 8121eff b117341 8121eff b117341 8121eff b117341 bc5a5b2 cfb1a62 bc5a5b2 8121eff b117341 8121eff b117341 bc5a5b2 b117341 bc5a5b2 b117341 8121eff b117341 cfb1a62 b117341 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
# rag/rag_pipeline.py
import json
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.node_parser import SentenceWindowNodeParser, SentenceSplitter
from llama_index.core import PromptTemplate
class RAGPipeline:
    """Question-answering pipeline over a JSON file of study records.

    Constructing an instance reads the JSON file, wraps every record in a
    llama_index ``Document``, parses the documents into sentence-window nodes
    and builds a ``VectorStoreIndex`` over those nodes. ``query`` then answers
    questions against that index.
    """

    # Settings for the sentence splitter / window parser used in build_index().
    CHUNK_SIZE = 128
    CHUNK_OVERLAP = 13
    WINDOW_SIZE = 3
    # Number of nodes retrieved per question in query().
    SIMILARITY_TOP_K = 5

    # Prompt used by query() when the caller does not supply a template.
    DEFAULT_QA_TEMPLATE = (
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given this information, please answer the question: {query_str}\n"
        "Include all relevant information from the provided context. "
        "If information comes from multiple sources, please mention all of them. "
        "If the information is not available in the context, please state that clearly. "
        "When quoting specific information, please use square brackets to indicate the source, e.g. [1], [2], etc."
    )

    def __init__(self, study_json: str, use_semantic_splitter: bool = False) -> None:
        """Load the study file and build the vector index.

        Args:
            study_json: Path to a JSON file containing the study records.
            use_semantic_splitter: Stored on the instance; the code in this
                class does not currently read it.
        """
        self.study_json = study_json
        self.index = None
        self.use_semantic_splitter = use_semantic_splitter
        self.load_documents()
        self.build_index()

    @staticmethod
    def _format_document_text(doc_data: dict) -> str:
        """Return the text indexed for one record (title, authors, full text).

        Missing or null fields are rendered as empty values so that a sparse
        record does not abort loading; the metadata built for the same record
        is equally tolerant of missing keys.
        """
        title = doc_data.get("title") or ""
        authors = doc_data.get("authors") or []
        full_text = doc_data.get("full_text") or ""
        return (
            f"Title: {title}\n"
            f"Authors: {', '.join(authors)}\n"
            f"Full Text: {full_text}"
        )

    def load_documents(self) -> None:
        """Read ``self.study_json`` and populate ``self.data`` / ``self.documents``.

        Each record becomes one ``Document`` with id ``doc_<position>`` and
        title/abstract/authors/year/doi metadata.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        # Decode explicitly as UTF-8 instead of relying on the platform locale.
        with open(self.study_json, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        self.documents = []
        for index, doc_data in enumerate(self.data):
            metadata = {
                "title": doc_data.get("title"),
                "abstract": doc_data.get("abstract"),
                "authors": doc_data.get("authors", []),
                "year": doc_data.get("year"),
                "doi": doc_data.get("doi"),
            }
            self.documents.append(
                Document(
                    text=self._format_document_text(doc_data),
                    id_=f"doc_{index}",
                    metadata=metadata,
                )
            )

    def build_index(self) -> None:
        """Parse ``self.documents`` into nodes and store the index in ``self.index``."""
        sentence_splitter = SentenceSplitter(
            chunk_size=self.CHUNK_SIZE, chunk_overlap=self.CHUNK_OVERLAP
        )
        # The bound method is passed directly as the splitting callable. The
        # previous wrapper function was annotated with ``List``, which is not
        # imported in this module, so defining it raised NameError.
        node_parser = SentenceWindowNodeParser.from_defaults(
            sentence_splitter=sentence_splitter.split_text,
            window_size=self.WINDOW_SIZE,
            window_metadata_key="window",
            original_text_metadata_key="original_text",
        )
        nodes = node_parser.get_nodes_from_documents(self.documents)
        self.index = VectorStoreIndex(nodes)

    def query(self, question, prompt_template=None):
        """Answer ``question`` using the indexed documents.

        Args:
            question: The question passed to the query engine.
            prompt_template: Optional template used as the engine's
                ``text_qa_template``. When ``None``, a ``PromptTemplate``
                built from ``DEFAULT_QA_TEMPLATE`` is used.

        Returns:
            The response object returned by the query engine.
        """
        if prompt_template is None:
            prompt_template = PromptTemplate(self.DEFAULT_QA_TEMPLATE)

        query_engine = self.index.as_query_engine(
            text_qa_template=prompt_template,
            similarity_top_k=self.SIMILARITY_TOP_K,
        )
        return query_engine.query(question)
|