import os

import chainlit as cl
import PyPDF2
from chainlit.input_widget import Select
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_groq import ChatGroq


@cl.cache
def get_memory():
    # NOTE: @cl.cache caches the return value process-wide, so this memory
    # object is shared across sessions rather than created per user.
    # Initialize message history for conversation
    message_history = ChatMessageHistory()
    # Memory for conversational context
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=message_history,
        return_messages=True,
    )
    return memory


@cl.on_chat_start
async def on_chat_start():
    user_env = cl.user_session.get("env")
    os.environ["GROQ_API_KEY"] = user_env.get("GROQ_API_KEY")

    settings = await cl.ChatSettings(
        [
            Select(
                id="Model",
                label="Choose your favorite LLM:",
                values=[
                    "llama3-8b-8192",
                    "llama3-70b-8192",
                    "mixtral-8x7b-32768",
                    "gemma-7b-it",
                ],
                initial_index=1,
            )
        ]
    ).send()

    files = None  # Initialize variable to store uploaded files

    # Wait for the user to upload at least one file
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a PDF file to begin!",
            accept=["application/pdf"],
            max_size_mb=100,
            timeout=180,
            max_files=10,
        ).send()

    pdf_text = ""
    for file in files:
        # Inform the user that processing has started
        msg = cl.Message(content=f"Processing `{file.name}`...")
        await msg.send()

        # Read the PDF file and collect the text of every page
        pdf = PyPDF2.PdfReader(file.path)
        for page in pdf.pages:
            pdf_text += page.extract_text()

    # Split the text into chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
    texts = text_splitter.split_text(pdf_text)

    # Create metadata for each chunk
    metadatas = [{"source": f"{i}-pl"} for i in range(len(texts))]

    # Create a Chroma vector store (alternative embedding models kept as comments)
    # embeddings = OllamaEmbeddings(model="nomic-embed-text")
    # embeddings = SentenceTransformerEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    # embeddings = OllamaEmbeddings(model="llama2:7b")
    embeddings = SentenceTransformerEmbeddings(model_name="Snowflake/snowflake-arctic-embed-m")
    docsearch = await cl.make_async(Chroma.from_texts)(
        texts, embeddings, metadatas=metadatas
    )
    cl.user_session.set("docsearch", docsearch)

    # Let the user know that the system is ready
    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()

    await setup_agent(settings)


@cl.on_settings_update
async def setup_agent(settings):
    user_env = cl.user_session.get("env")
    os.environ["GROQ_API_KEY"] = user_env.get("GROQ_API_KEY")

    msg = cl.Message(
        content=f"You are using `{settings['Model']}` as LLM. "
        "You can change the model in the `Settings Panel` in the chat box."
    )
    await msg.send()

    memory = get_memory()
    docsearch = cl.user_session.get("docsearch")

    # Create a chain that uses the Chroma vector store
    chain = ConversationalRetrievalChain.from_llm(
        llm=ChatGroq(model=settings["Model"]),
        chain_type="stuff",
        retriever=docsearch.as_retriever(),
        memory=memory,
        return_source_documents=True,
    )

    # Store the chain in the user session
    cl.user_session.set("chain", chain)


@cl.on_message
async def main(message: cl.Message):
    # Retrieve the chain from the user session
    chain = cl.user_session.get("chain")
    # Callback handler that streams LangChain's intermediate steps to the UI
    cb = cl.AsyncLangchainCallbackHandler()

    user_env = cl.user_session.get("env")
    os.environ["GROQ_API_KEY"] = user_env.get("GROQ_API_KEY")

    # Call the chain with the user's message content; callbacks must be passed
    # through the Runnable config (a bare callbacks= kwarg is ignored by ainvoke)
    res = await chain.ainvoke(message.content, config={"callbacks": [cb]})
    answer = res["answer"]
    source_documents = res["source_documents"]

    text_elements = []  # Initialize list to store text elements

    # Process source documents if available
    if source_documents:
        for source_idx, source_doc in enumerate(source_documents):
            source_name = f"source_{source_idx}"
            # Create the text element referenced in the message
            text_elements.append(
                cl.Text(content=source_doc.page_content, name=source_name)
            )
        source_names = [text_el.name for text_el in text_elements]

        # Add source references to the answer
        if source_names:
            answer += f"\nSources: {', '.join(source_names)}"
        else:
            answer += "\nNo sources found"

    # Return the results
    await cl.Message(content=answer, elements=text_elements).send()


@cl.on_stop
def on_stop():
    print("The user wants to stop the task!")
    docsearch = cl.user_session.get("docsearch")
    docsearch.delete_collection()


@cl.on_chat_end
def on_chat_end():
    print("The user disconnected!")
    docsearch = cl.user_session.get("docsearch")
    docsearch.delete_collection()
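

# Usage sketch (assumptions, not part of the original app): the handlers above
# read the Groq key from cl.user_session.get("env"), which is only populated
# when the Chainlit project asks each user for it. Assuming this file is saved
# as app.py, that would mean declaring the variable in .chainlit/config.toml:
#
#   [project]
#   user_env = ["GROQ_API_KEY"]
#
# and then launching the app in watch mode with:
#
#   chainlit run app.py -w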