Spaces:
Running
Running
import os | |
import json | |
from typing import List | |
from pathlib import Path | |
from langchain_huggingface import HuggingFaceEmbeddings | |
#from langchain_community.llms import HuggingFaceEndpoint | |
from langchain_huggingface import HuggingFaceEndpoint | |
#from langchain_openai import ChatOpenAI, OpenAIEmbeddings | |
from langchain.prompts import ChatPromptTemplate | |
from langchain.schema import StrOutputParser | |
from langchain_community.document_loaders import ( | |
PyMuPDFLoader, | |
) | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_community.vectorstores import Chroma | |
from langchain.indexes import SQLRecordManager, index | |
from langchain.schema import Document | |
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig | |
from langchain.callbacks.base import BaseCallbackHandler | |
import chainlit as cl | |
from literalai import LiteralClient | |
def auth_callback(username: str, password: str):
    """Authenticate a user against the accounts listed in ``CHAINLIT_AUTH_LOGIN``.

    The environment variable must hold a JSON array of objects with the keys
    ``ident`` (login name), ``pwd`` (password) and ``role``.  The first entry
    whose ``ident`` equals *username* is the one that is checked.

    Args:
        username: Login name typed by the user.
        password: Password typed by the user.

    Returns:
        A ``cl.User`` with role ``"admin"`` (stored role ``"admindatapcc"``) or
        ``"user"`` (stored role ``"userdatapcc"``) when the credentials match;
        ``None`` for an unknown user, a wrong password, or any other role.

    Raises:
        KeyError: If ``CHAINLIT_AUTH_LOGIN`` is not set, or the matching entry
            lacks a ``pwd`` / ``role`` key.
        json.JSONDecodeError: If the variable does not contain valid JSON.
    """
    # NOTE(review): no Chainlit auth decorator is visible on this function in
    # the file as seen -- confirm it is registered with Chainlit elsewhere.

    # Local import keeps this block self-contained (stdlib only).
    import hmac

    accounts = json.loads(os.environ['CHAINLIT_AUTH_LOGIN'])

    # A default of None avoids the StopIteration that a bare next() raises
    # when the username is not present in the account list.
    account = next((d for d in accounts if d['ident'] == username), None)
    if account is None:
        return None

    # The stored password is plaintext, so hashing it with bcrypt at call time
    # and verifying against that fresh hash was only an expensive equality
    # test (and `bcrypt` is not imported in the visible file).  A constant-time
    # byte comparison gives the same accept/reject decision without it.
    supplied = password.encode('utf-8')
    expected = account['pwd'].encode('utf-8')
    if not hmac.compare_digest(supplied, expected):
        return None

    ident = account['ident']
    role = account['role']
    if role == "admindatapcc":
        return cl.User(
            # U+1F9D1 U+200D U+1F4BC: "office worker" emoji, written as escapes
            # so the literal survives re-encoding of the source file.
            identifier=ident + " : \U0001F9D1\u200D\U0001F4BC Admin Datapcc",
            metadata={"role": "admin", "provider": "credentials"},
        )
    if role == "userdatapcc":
        return cl.User(
            # U+1F9D1 U+200D U+1F393: "student" emoji.
            identifier=ident + " : \U0001F9D1\u200D\U0001F393 User Datapcc",
            metadata={"role": "user", "provider": "credentials"},
        )
    return None
# --- Module-level configuration -------------------------------------------
# Directory scanned for PDFs when the vector store is built.
PDF_STORAGE_PATH = "./public/pdfs"

# NOTE(review): process_pdfs() in this file passes its own 1000 / 100 to the
# text splitter and does not read these two values -- confirm which pair is
# the intended one.
chunk_size = 1024
chunk_overlap = 50

# --- Shared clients, created once at import time ---------------------------
# os.getenv returns None when LITERAL_API_KEY is unset; the client is still
# constructed with that value.
literal_client = LiteralClient(api_key=os.getenv("LITERAL_API_KEY"))

# Embedding model built with the library defaults (no arguments passed).
embeddings_model = HuggingFaceEmbeddings()
def process_pdfs(pdf_storage_path: str):
    """Split every ``*.pdf`` directly inside *pdf_storage_path* and index it.

    Each PDF is loaded with ``PyMuPDFLoader`` and cut into chunks by a
    ``RecursiveCharacterTextSplitter`` (chunk size 1000, overlap 100).  The
    chunks are put into a Chroma store built with the module-level
    ``embeddings_model`` and then passed through ``index()`` with a SQLite
    record manager (``record_manager_cache.sql``, incremental cleanup keyed
    on the ``source`` metadata field).  The indexing result is printed.

    Args:
        pdf_storage_path: Directory to scan; sub-directories are not searched.

    Returns:
        The Chroma vector store holding the chunks.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

    chunks: List[Document] = []
    for pdf_file in Path(pdf_storage_path).glob("*.pdf"):
        loaded = PyMuPDFLoader(str(pdf_file)).load()
        chunks.extend(splitter.split_documents(loaded))

    vector_store = Chroma.from_documents(chunks, embeddings_model)

    # The record manager remembers what has been indexed under this namespace.
    manager = SQLRecordManager(
        "chromadb/my_documents",
        db_url="sqlite:///record_manager_cache.sql",
    )
    manager.create_schema()

    stats = index(
        chunks,
        manager,
        vector_store,
        cleanup="incremental",
        source_id_key="source",
    )
    print(f"Indexing stats: {stats}")

    return vector_store
# Build the vector store once, when the module is imported.
doc_search = process_pdfs(PDF_STORAGE_PATH)

# Earlier OpenAI-backed model, kept for reference:
# model = ChatOpenAI(model_name="gpt-4", streaming=True)

# Self-assignment: leaves the value untouched when the variable is set and
# raises KeyError at import time when it is missing.
os.environ['HUGGINGFACEHUB_API_TOKEN'] = os.environ['HUGGINGFACEHUB_API_TOKEN']

# Hosted model used to answer questions.
repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
model = HuggingFaceEndpoint(
    repo_id=repo_id,
    max_new_tokens=8000,
    temperature=1.0,
    task="text2text-generation",
    streaming=True,
)
async def on_chat_start():
    """Greet the user and store a retrieval-QA chain in the user session.

    Sends the banner message, then assembles a chain that retrieves documents
    from the module-level ``doc_search`` store, joins their text into the
    prompt's ``context`` slot, forwards the incoming question unchanged, and
    runs the module-level ``model`` followed by a string output parser.  The
    chain is saved under the session key ``"runnable"``.
    """
    # NOTE(review): no Chainlit decorator is visible on this coroutine in the
    # file as seen -- confirm it is registered as the chat-start hook.
    await cl.Message("> REVIEWSTREAM").send()

    template = """Answer the question based only on the following context:
{context}
Question: {question}
"""
    prompt = ChatPromptTemplate.from_template(template)

    def format_docs(docs):
        # Retrieved documents are separated by one blank line.
        return "\n\n".join(d.page_content for d in docs)

    retriever = doc_search.as_retriever()

    # Mapping evaluated by the chain: context comes from retrieval, the
    # question is passed through as received.
    chain_inputs = {
        "context": retriever | format_docs,
        "question": RunnablePassthrough(),
    }
    runnable = chain_inputs | prompt | model | StrOutputParser()

    cl.user_session.set("runnable", runnable)
async def on_message(message: cl.Message):
    """Answer one user message by streaming the session's QA chain.

    Fetches the chain stored under the session key ``"runnable"``, streams its
    output token by token into a new Chainlit message, and sends that message
    once streaming has finished.  While the chain runs, a callback handler
    records the ``(source, page)`` metadata of every retrieved document and,
    when the LLM finishes, attaches them to the message as an inline
    ``Sources`` text element.

    Args:
        message: Incoming Chainlit message; its ``content`` is the question.
    """
    # NOTE(review): no Chainlit decorator is visible on this coroutine in the
    # file as seen -- confirm it is registered as the message hook.
    runnable = cl.user_session.get("runnable")  # type: Runnable
    msg = cl.Message(content="")

    class PostMessageHandler(BaseCallbackHandler):
        """
        Callback handler for handling the retriever and LLM processes.
        Used to post the sources of the retrieved documents as a Chainlit element.
        """

        def __init__(self, msg: cl.Message):
            super().__init__()
            self.msg = msg
            self.sources = set()  # Unique (source, page) pairs

        def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
            # Collect where each retrieved chunk came from; the set removes
            # repeats of the same page.
            for d in documents:
                self.sources.add((d.metadata['source'], d.metadata['page']))

        def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
            if not self.sources:
                return
            # Sort before rendering: set iteration order is arbitrary, so the
            # list would otherwise come out in a different order between runs.
            sources_text = "\n".join(
                f"{source}#page={page}" for source, page in sorted(self.sources)
            )
            self.msg.elements.append(
                cl.Text(name="Sources", content=sources_text, display="inline")
            )

    run_config = RunnableConfig(
        callbacks=[
            cl.LangchainCallbackHandler(),
            PostMessageHandler(msg),
        ]
    )

    async with cl.Step(type="run", name="QA Assistant"):
        async for chunk in runnable.astream(message.content, config=run_config):
            await msg.stream_token(chunk)

    await msg.send()