"""Chainlit question-answering app over a folder of PDFs.

At import time every ``*.pdf`` under ``PDF_STORAGE_PATH`` is loaded, split
into chunks, embedded into a Chroma vector store and registered with a
SQLite-backed record manager. Each chat session then gets a retrieval chain
(retriever -> prompt -> Hugging Face endpoint -> string parser), and every
answer is streamed back with the sources of the retrieved chunks attached.
"""

import os
from pathlib import Path
from typing import List

import chainlit as cl
from langchain.callbacks.base import BaseCallbackHandler
from langchain.indexes import SQLRecordManager, index
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
    PyMuPDFLoader,
)
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEndpoint
from literalai import LiteralClient

# Alternative back-ends that were previously wired in (kept for reference):
#   from langchain_community.llms import HuggingFaceEndpoint
#   from langchain_openai import ChatOpenAI, OpenAIEmbeddings

literal_client = LiteralClient(api_key=os.getenv("LITERAL_API_KEY"))

# NOTE(review): these two values are not read anywhere in the code shown
# here; the splitter in process_pdfs() uses its own literals (1000 / 100).
# Confirm which pair is intended before wiring them together.
chunk_size = 1024
chunk_overlap = 50

embeddings_model = HuggingFaceEmbeddings()

PDF_STORAGE_PATH = "./public/pdfs"


def process_pdfs(pdf_storage_path: str) -> Chroma:
    """Load, split, embed and index every PDF directly under a directory.

    Args:
        pdf_storage_path: Directory that is searched (non-recursively) for
            ``*.pdf`` files.

    Returns:
        The Chroma vector store holding the embedded chunks. When no PDF is
        found the store is empty, so retrieval yields no context.
    """
    pdf_directory = Path(pdf_storage_path)
    docs: List[Document] = []
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

    # Sorted so that chunks are produced in the same order on every run;
    # glob() itself gives no ordering guarantee.
    for pdf_path in sorted(pdf_directory.glob("*.pdf")):
        loader = PyMuPDFLoader(str(pdf_path))
        docs.extend(text_splitter.split_documents(loader.load()))

    if docs:
        doc_search = Chroma.from_documents(docs, embeddings_model)
    else:
        # Building the store from an empty list is likely to fail inside the
        # vector store with an unhelpful message, so create an empty store
        # explicitly and say why nothing was indexed.
        print(f"No PDF content found in {pdf_directory}; the vector store is empty.")
        doc_search = Chroma(embedding_function=embeddings_model)

    namespace = "chromadb/my_documents"
    record_manager = SQLRecordManager(
        namespace, db_url="sqlite:///record_manager_cache.sql"
    )
    record_manager.create_schema()

    # NOTE(review): the chunks were already inserted by from_documents()
    # above; on a run where the record manager has not seen them yet, index()
    # presumably inserts them a second time under its own ids. Left as-is
    # because the record-manager cache lives in a SQLite file while the
    # store is rebuilt on each start -- confirm the intended set-up before
    # dropping either call.
    index_result = index(
        docs,
        record_manager,
        doc_search,
        cleanup="incremental",
        source_id_key="source",
    )
    print(f"Indexing stats: {index_result}")

    return doc_search


doc_search = process_pdfs(PDF_STORAGE_PATH)

# The endpoint needs this token. Fail at start-up with a clear message (still
# a KeyError, as the former self-assignment raised) rather than on first use.
if "HUGGINGFACEHUB_API_TOKEN" not in os.environ:
    raise KeyError(
        "HUGGINGFACEHUB_API_TOKEN must be set to call the Hugging Face endpoint"
    )

repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
model = HuggingFaceEndpoint(
    repo_id=repo_id,
    max_new_tokens=8000,
    temperature=1.0,
    task="text2text-generation",
    streaming=True,
)


@cl.on_chat_start
async def on_chat_start():
    """Greet the user and store this session's retrieval chain."""
    await cl.Message("> REVIEWSTREAM").send()

    template = """Answer the question based only on the following context:

    {context}

    Question: {question}
    """
    prompt = ChatPromptTemplate.from_template(template)

    def format_docs(docs):
        """Join the page contents of the retrieved documents into one string."""
        return "\n\n".join(d.page_content for d in docs)

    retriever = doc_search.as_retriever()

    runnable = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )

    cl.user_session.set("runnable", runnable)


class PostMessageHandler(BaseCallbackHandler):
    """Callback handler for the retriever and LLM steps of the chain.

    Collects the (source, page) pair of every retrieved document and, once
    the LLM has finished, attaches them to the message as a Chainlit text
    element named "Sources".
    """

    def __init__(self, msg: cl.Message):
        super().__init__()
        self.msg = msg
        # Dict used as an insertion-ordered set: pairs stay unique and are
        # listed in the order the retriever returned them.
        self.sources = {}

    def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
        """Remember the source file and page of each retrieved document."""
        for d in documents:
            source_page_pair = (d.metadata['source'], d.metadata['page'])
            self.sources[source_page_pair] = None

    def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
        """Append the collected sources to the message, if there are any."""
        if self.sources:
            sources_text = "\n".join(
                f"{source}#page={page}" for source, page in self.sources
            )
            self.msg.elements.append(
                cl.Text(name="Sources", content=sources_text, display="inline")
            )


@cl.on_message
async def on_message(message: cl.Message):
    """Run the session's chain on the incoming message and stream the answer."""
    runnable: Runnable = cl.user_session.get("runnable")
    msg = cl.Message(content="")

    async with cl.Step(type="run", name="QA Assistant"):
        # The callback handlers are created inside the step, as before, so
        # anything they capture on construction belongs to this step.
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(callbacks=[
                cl.LangchainCallbackHandler(),
                PostMessageHandler(msg),
            ]),
        ):
            await msg.stream_token(chunk)

    await msg.send()