import os
from pathlib import Path
from typing import List

import chainlit as cl
from langchain.callbacks.base import BaseCallbackHandler
from langchain.indexes import SQLRecordManager, index
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain_community.document_loaders import (
    PyMuPDFLoader,
)
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
# Directory that is scanned for the PDF files to index.
PDF_STORAGE_PATH = "./public/pdfs"

# Chunking settings.
# NOTE(review): the splitter inside process_pdfs() is built with its own
# literal values (1000 / 100) and does not read these two names -- confirm
# whether anything else depends on them before changing or removing them.
chunk_size = 1024
chunk_overlap = 50

# Embedding model used when building the vector store; constructed once at
# import time with the library's default arguments.
embeddings_model = HuggingFaceEmbeddings()
def process_pdfs(pdf_storage_path: str):
    """Load, split, embed, and index every PDF found in a directory.

    Each ``*.pdf`` file directly inside ``pdf_storage_path`` is read with
    ``PyMuPDFLoader`` and split with a ``RecursiveCharacterTextSplitter``
    (``chunk_size=1000``, ``chunk_overlap=100``). The resulting chunks are
    placed in a Chroma vector store built with the module-level
    ``embeddings_model`` and then passed through LangChain's ``index``
    helper, using a SQLite-backed ``SQLRecordManager`` for bookkeeping.
    The statistics returned by ``index`` are printed.

    Args:
        pdf_storage_path: Directory to scan for PDF files (not recursive).

    Returns:
        The Chroma vector store that holds the chunked documents.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

    # Collect the chunks of every PDF into one flat list.
    chunks = []
    for pdf_file in Path(pdf_storage_path).glob("*.pdf"):
        pages = PyMuPDFLoader(str(pdf_file)).load()
        chunks.extend(splitter.split_documents(pages))

    vector_store = Chroma.from_documents(chunks, embeddings_model)

    # The record manager persists its state in a local SQLite file.
    manager = SQLRecordManager(
        "chromadb/my_documents",
        db_url="sqlite:///record_manager_cache.sql",
    )
    manager.create_schema()

    # NOTE(review): the same chunks were already handed to
    # Chroma.from_documents() above and are passed to index() again here --
    # confirm this does not leave duplicate entries in the vector store.
    stats = index(
        chunks,
        manager,
        vector_store,
        cleanup="incremental",
        source_id_key="source",
    )
    print(f"Indexing stats: {stats}")

    return vector_store
# Build the vector store once at import time; the chat handlers below read
# this module-level object.
doc_search = process_pdfs(PDF_STORAGE_PATH)

# The Hugging Face endpoint needs an API token from the environment. Check
# for it explicitly so a missing token produces a clear error at startup.
if not os.environ.get("HUGGINGFACEHUB_API_TOKEN"):
    raise RuntimeError(
        "The HUGGINGFACEHUB_API_TOKEN environment variable must be set "
        "to use the Hugging Face endpoint."
    )

# Hosted model used to answer questions.
repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# streaming=True so the answer can be forwarded to the UI token by token.
model = HuggingFaceEndpoint(
    repo_id=repo_id,
    max_new_tokens=8000,
    temperature=1.0,
    task="text2text-generation",
    streaming=True,
)
@cl.on_chat_start
async def on_chat_start():
    """Assemble the question-answering chain and store it in the session.

    The chain retrieves documents from the module-level ``doc_search``
    store, joins their text into a single context string, fills the prompt
    template with that context and the user's question, sends the prompt to
    the module-level ``model``, and parses the output to a string. The
    finished chain is saved in the Chainlit user session under the key
    ``"runnable"``.
    """

    def format_docs(docs):
        """Join the page contents of ``docs``, separated by blank lines."""
        return "\n\n".join(d.page_content for d in docs)

    prompt_text = """Answer the question based only on the following context:

{context}

Question: {question}
"""
    qa_prompt = ChatPromptTemplate.from_template(prompt_text)

    # Retrieved documents are piped through format_docs to build the context;
    # the incoming question is passed through unchanged.
    context_source = doc_search.as_retriever() | format_docs
    chain_inputs = {"context": context_source, "question": RunnablePassthrough()}

    qa_chain = chain_inputs | qa_prompt | model | StrOutputParser()

    cl.user_session.set("runnable", qa_chain)
@cl.on_message
async def on_message(message: cl.Message):
    """Answer an incoming message with the chain stored in the session.

    The chain saved under ``"runnable"`` is streamed with the message text
    as input, and every chunk is forwarded into a new Chainlit message. A
    callback handler records the (source, page) pairs of the retrieved
    documents and, when the LLM finishes, attaches them to the reply as an
    inline "Sources" text element.

    Args:
        message: The incoming Chainlit message; its ``content`` is used as
            the question.
    """
    runnable = cl.user_session.get("runnable")
    msg = cl.Message(content="")

    class PostMessageHandler(BaseCallbackHandler):
        """Callback handler for the retriever and LLM processes.

        Used to post the sources of the retrieved documents as a Chainlit
        element.
        """

        def __init__(self, msg: cl.Message):
            super().__init__()
            self.msg = msg
            # (source, page) pairs; a set so repeated chunks of the same
            # page are listed only once.
            self.sources = set()

        def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
            """Record the source and page of every retrieved document."""
            for d in documents:
                self.sources.add((d.metadata["source"], d.metadata["page"]))

        def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
            """Attach the collected sources to the reply message, if any."""
            if not self.sources:
                return
            # Sorted so the listing is stable across runs (set order is
            # arbitrary). The page stored by the PDF loader is zero-based,
            # while the "#page=" URL fragment counts from 1, hence the +1.
            sources_text = "\n".join(
                f"{source}#page={page + 1}"
                for source, page in sorted(self.sources)
            )
            self.msg.elements.append(
                cl.Text(name="Sources", content=sources_text, display="inline")
            )

    async with cl.Step(type="run", name="QA Assistant"):
        # The callback handlers are created inside the step context on
        # purpose, exactly where the chain run starts.
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(
                callbacks=[
                    cl.LangchainCallbackHandler(),
                    PostMessageHandler(msg),
                ]
            ),
        ):
            await msg.stream_token(chunk)

    await msg.send()