import os

from dotenv import load_dotenv

load_dotenv(".env")
os.environ["USER_AGENT"] = os.getenv("USER_AGENT")
os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
os.environ["TOKENIZERS_PARALLELISM"] = "true"

from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.retrievers import PineconeHybridSearchRetriever
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from pinecone import Pinecone
from pinecone_text.sparse import BM25Encoder

# from flask import Flask, request, render_template
# from flask_cors import CORS
# from flask_socketio import SocketIO, emit

import gradio as gr
import spaces
import torch

# On ZeroGPU Spaces, CUDA is attached only inside functions decorated with
# @spaces.GPU, so this tensor reports 'cpu' at module scope.
zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cpu' 🤔


@spaces.GPU
def greet(n):
    print(zero.device)  # <-- 'cuda:0' 🤗
    return f"Hello {zero + n} Tensor"


# app = Flask(__name__)
# CORS(app)
# socketio = SocketIO(app, cors_allowed_origins="*")
# app.config['SESSION_COOKIE_SECURE'] = True  # Use HTTPS
# app.config['SESSION_COOKIE_HTTPONLY'] = True
# app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'
# app.config['SECRET_KEY'] = os.getenv('SECRET_KEY')

# Connect to the Pinecone index, retrying once on a transient failure.
index_name = "traveler-demo-website-vectorstore"
try:
    pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
    pinecone_index = pc.Index(index_name)
except Exception:
    pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
    pinecone_index = pc.Index(index_name)

# Hybrid retrieval: BM25 sparse vectors fitted on the website corpus plus a
# dense embedding model; alpha=0.5 weights the two signals equally.
bm25 = BM25Encoder().load("./bm25_traveler_website.json")

embed_model = HuggingFaceEmbeddings(
    model_name="Alibaba-NLP/gte-large-en-v1.5",
    model_kwargs={"trust_remote_code": True},
)

retriever = PineconeHybridSearchRetriever(
    embeddings=embed_model,
    sparse_encoder=bm25,
    index=pinecone_index,
    top_k=20,
    alpha=0.5,
)

llm = ChatGroq(
    model="llama-3.1-70b-versatile",
    temperature=0.1,
    max_tokens=1024,
    max_retries=2,
)

### Contextualize question ###
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is.
"""
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
history_aware_retriever = create_history_aware_retriever(
    llm, retriever, contextualize_q_prompt
)
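# A minimal sketch of querying the history-aware retriever on its own (the
# question text and empty history are illustrative): with no prior turns the
# input is passed straight to the hybrid retriever; with history, the LLM
# first rewrites it into a standalone query.
#
#     docs = history_aware_retriever.invoke(
#         {"input": "What documents do I need to travel?", "chat_history": []}
#     )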
### Answer question ###
qa_system_prompt = """You are a highly skilled information retrieval assistant. \
Use the following pieces of retrieved context to answer the question. \
Provide links to the sources in the answer. \
If you don't know the answer, just say that you don't know. \
Do not give overly long answers. \
When responding to queries, your responses should be comprehensive and well-organized. For each response: \
1. Provide Clear Answers \
2. Include Detailed References: \
   - Include links to the sources and to any sites mentioned in the answer. \
   - Links to Sources: Provide URLs to credible sources where users can verify the information or explore further. \
   - Downloadable Materials: Include links to any relevant downloadable resources, if applicable. \
   - Reference Sites: Mention specific websites or platforms that offer additional information. \
3. Formatting for Readability: \
   - Bullet Points or Lists: Where applicable, use bullet points or numbered lists to present information clearly. \
   - Emphasize Important Information: Use bold or italics to highlight key details. \
4. Organize Content Logically. \
Do not include anything about the context itself in the answer.

{context}
"""
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

### Statefully manage chat history ###
store = {}


def clean_temporary_data():
    # Clear the shared history store in place; rebinding a local name called
    # `store` would leave the module-level dict untouched.
    store.clear()


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]


conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

# Stream responses to the client over Socket.IO. Kept commented out, together
# with the Flask setup above, while the Gradio demo below is active.
# @socketio.on('message')
# def handle_message(data):
#     question = data.get('question')
#     session_id = data.get('session_id', 'abc123')
#     chain = conversational_rag_chain.pick("answer")
#     try:
#         for chunk in chain.stream(
#             {"input": question},
#             config={"configurable": {"session_id": session_id}},
#         ):
#             emit('response', chunk, room=request.sid)
#     except Exception:
#         # Retry the stream once on a transient failure.
#         for chunk in chain.stream(
#             {"input": question},
#             config={"configurable": {"session_id": session_id}},
#         ):
#             emit('response', chunk, room=request.sid)
#
#
# @app.route("/")
# def index_view():
#     return render_template('chat.html')
#
#
# if __name__ == '__main__':
#     socketio.run(app, debug=True)

demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
demo.launch()
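# A minimal sketch of invoking the conversational RAG chain directly (assumes
# the Pinecone index, the BM25 weights file, and GROQ_API_KEY are available;
# the question and session id are illustrative). RunnableWithMessageHistory
# keys the per-user history on the configurable session_id:
#
#     result = conversational_rag_chain.invoke(
#         {"input": "Do I need a visa for a short stay?"},
#         config={"configurable": {"session_id": "demo-session"}},
#     )
#     print(result["answer"])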