lindsay-qu's picture
Update core/retriever/chroma_retriever.py
6ec8a6d verified
from .base_retriever import BaseRetriever
from models import BaseModel
from langchain.document_loaders import PyMuPDFLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from chromadb import PersistentClient
import os
class ChromaRetriever(BaseRetriever):
    def __init__(self,
                 pdf_dir: str,
                 collection_name: str,
                 split_args: dict,
                 embed_model: BaseModel = None,
                 refine_model: BaseModel = None,
                 persist_dir: str = "persist"):
        '''
        Load a persistent Chroma collection of PDF text chunks, building it first if it does not exist.

        pdf_dir: directory containing pdfs for vector database (only read when the collection is created)
        collection_name: name of collection to be used (if collection exists, it will be loaded, otherwise it will be created)
        split_args: dictionary of arguments for text splitter ("size": size of chunks, "overlap": overlap between chunks);
                    only used when the collection is created
        embed_model: model to embed text chunks (if not provided, will use chroma's default embeddings).
                     When an existing collection is loaded, this should be the same model that built it,
                     otherwise query embeddings will not be comparable with the stored ones.
        refine_model: accepted for interface compatibility; not used by this class
        persist_dir: directory where chroma stores its data (created if missing; default "persist")
        raises: ValueError if a new collection must be built but pdf_dir yields no text chunks
        example:
            from models import GPT4Model
            dir = "papers"
            retriever = ChromaRetriever(dir, "pdfs", {"size": 2048, "overlap": 10}, embed_model=GPT4Model())
        '''
        self.embed_model = embed_model
        # makedirs(exist_ok=True) avoids the check-then-create race of exists() + mkdir().
        os.makedirs(persist_dir, exist_ok=True)
        client = PersistentClient(path=persist_dir)
        print(client.list_collections())
        try:
            collection = client.get_collection(name=collection_name)
        except Exception:
            # The exception type chromadb raises for a missing collection differs between versions,
            # so catch Exception -- but not a bare except, which would also swallow KeyboardInterrupt.
            print("Creating new collection...")
            collection = self._build_collection(client, collection_name, pdf_dir, split_args)
        self.collection = collection
        print("Papers Loaded.")

    @staticmethod
    def _load_chunks(pdf_dir: str, split_args: dict) -> tuple:
        '''Load every pdf found by DirectoryLoader under pdf_dir and split it; returns (texts, titles).'''
        print("Loading pdf papers into the vector database... ")
        pdf_loader = DirectoryLoader(pdf_dir, loader_cls=PyMuPDFLoader)
        docs = pdf_loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=split_args["size"],
                                                       chunk_overlap=split_args["overlap"])
        split_docs = text_splitter.split_documents(docs)
        texts = [doc.page_content for doc in split_docs]
        # TODO: title extraction -- the PDF "title" metadata may be missing or empty.
        # Use .get(...) or "" so a missing/None title becomes "" instead of raising KeyError,
        # keeping every stored metadata value a string.
        titles = [doc.metadata.get("title") or "" for doc in split_docs]
        return texts, titles

    def _build_collection(self, client, collection_name: str, pdf_dir: str, split_args: dict):
        '''Create collection_name in client, fill it with chunks from pdf_dir, and return it.'''
        texts, titles = self._load_chunks(pdf_dir, split_args)
        if not texts:
            # Fail before creating anything: otherwise an empty collection would be persisted
            # and loaded as-is on every later run.
            raise ValueError(f"No text chunks could be loaded from pdf directory {pdf_dir!r}")
        ids = [str(i + 1) for i in range(len(texts))]
        metadatas = [{"title": title} for title in titles]
        collection = client.create_collection(name=collection_name)
        try:
            if self.embed_model is not None:
                collection.add(
                    embeddings=self.embed_model.embedding(texts),
                    documents=texts,
                    ids=ids,
                    metadatas=metadatas,
                )
            else:
                collection.add(
                    documents=texts,
                    ids=ids,
                    metadatas=metadatas,
                )
        except BaseException:
            # Drop the half-built collection (including on Ctrl-C during a long embedding run)
            # so the next run rebuilds it instead of loading it as if it were complete.
            client.delete_collection(name=collection_name)
            raise
        return collection

    def retrieve(self, query: str, k: int = 5) -> tuple:
        '''
        query: text string used to query the vector database
        k: number of text chunks to return
        returns: tuple (chunks, titles) -- list of the retrieved text chunks and a parallel list with
                 the "title" metadata of each chunk ("" when a chunk has no title)
        example:
            chunks, titles = retriever.retrieve("how do sex chromosomes in rhesus monkeys influence proteome?", k=10)
        '''
        if self.embed_model is not None:
            query_args = {"query_embeddings": self.embed_model.embedding([query])}
        else:
            query_args = {"query_texts": [query]}
        results = self.collection.query(n_results=k, **query_args)
        # A single query was sent, so each result field holds exactly one row: take row 0.
        # Entries stored without metadata may come back as None, hence the `or {}`.
        titles = [(meta or {}).get("title", "") for meta in results["metadatas"][0]]
        return results["documents"][0], titles