Ritesh-hf committed on
Commit
4c6d98a
1 Parent(s): 32ba13a

initial commit

Files changed (6)
  1. .env +4 -0
  2. __pycache__/app.cpython-38.pyc +0 -0
  3. app.py +138 -0
  4. bm25_traveler_website.json +0 -0
  5. requirements.txt +98 -0
  6. temp.py +176 -0
.env ADDED
@@ -0,0 +1,4 @@
+ USER_AGENT='myagent'
+ GROQ_API_KEY="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"
+ PINECONE_API_KEY="ca8e6a33-7355-453f-ad4b-80c8a1c6a9c7"
+ SECRET_KEY="b0*1x^y@9$)w%v+k=p!8xp@4bkt37s&b8+uf%1=mh+v1=@ybsh"
__pycache__/app.cpython-38.pyc ADDED
Binary file (4.55 kB)
app.py ADDED
@@ -0,0 +1,138 @@
+ import os
+ from dotenv import load_dotenv
+ load_dotenv(".env")
+
+ os.environ['USER_AGENT'] = os.getenv("USER_AGENT")
+ os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
+ os.environ["TOKENIZERS_PARALLELISM"] = 'true'
+
+ from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+ from langchain.chains.combine_documents import create_stuff_documents_chain
+ from langchain_community.chat_message_histories import ChatMessageHistory
+ from langchain_community.document_loaders import WebBaseLoader
+ from langchain_core.chat_history import BaseChatMessageHistory
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+ from langchain_core.runnables.history import RunnableWithMessageHistory
+
+ from pinecone import Pinecone
+ from pinecone_text.sparse import BM25Encoder
+
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_community.retrievers import PineconeHybridSearchRetriever
+
+ from langchain_groq import ChatGroq
+
+ import gradio as gr
+ import spaces
+ import torch
+
+
+ # Connect to the Pinecone index, retrying once in case of a transient failure.
+ try:
+     pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
+     index_name = "traveler-demo-website-vectorstore"
+     pinecone_index = pc.Index(index_name)
+ except Exception:
+     pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
+     index_name = "traveler-demo-website-vectorstore"
+     pinecone_index = pc.Index(index_name)
+
+ # Pre-fitted sparse BM25 encoder for the sparse half of hybrid search.
+ bm25 = BM25Encoder().load("./bm25_traveler_website.json")
+
+ # Dense embeddings for the dense half of hybrid search.
+ embed_model = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-large-en-v1.5", model_kwargs={"trust_remote_code": True})
+
+ retriever = PineconeHybridSearchRetriever(
+     embeddings=embed_model,
+     sparse_encoder=bm25,
+     index=pinecone_index,
+     top_k=20,
+     alpha=0.5,  # equal weighting of dense and sparse scores
+ )
+
+ llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.1, max_tokens=1024, max_retries=2)
+
+ ### Contextualize question ###
+ contextualize_q_system_prompt = """Given a chat history and the latest user question \
+ which might reference context in the chat history, formulate a standalone question \
+ which can be understood without the chat history. Do NOT answer the question, \
+ just reformulate it if needed and otherwise return it as is.
+ """
+ contextualize_q_prompt = ChatPromptTemplate.from_messages(
+     [
+         ("system", contextualize_q_system_prompt),
+         MessagesPlaceholder("chat_history"),
+         ("human", "{input}")
+     ]
+ )
+
+ history_aware_retriever = create_history_aware_retriever(
+     llm, retriever, contextualize_q_prompt
+ )
+
+
+ qa_system_prompt = """You are a highly skilled information retrieval assistant. Use the following pieces of retrieved context to answer the question. \
+ Provide links to sources provided in the answer. \
+ If you don't know the answer, just say that you don't know. \
+ Do not give extra long answers. \
+ When responding to queries, your responses should be comprehensive and well-organized. For each response: \
+ 1. Provide Clear Answers \
+ 2. Include Detailed References: \
+    - Include links to any sources or sites that are mentioned in the answer. \
+    - Links to Sources: Provide URLs to credible sources where users can verify the information or explore further. \
+    - Downloadable Materials: Include links to any relevant downloadable resources if applicable. \
+    - Reference Sites: Mention specific websites or platforms that offer additional information. \
+ 3. Formatting for Readability: \
+    - Bullet Points or Lists: Where applicable, use bullet points or numbered lists to present information clearly. \
+    - Emphasize Important Information: Use bold or italics to highlight key details. \
+ 4. Organize Content Logically \
+ Do not include anything about context in the answer. \
+ {context}
+ """
+ qa_prompt = ChatPromptTemplate.from_messages(
+     [
+         ("system", qa_system_prompt),
+         MessagesPlaceholder("chat_history"),
+         ("human", "{input}")
+     ]
+ )
+ question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+
+ rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+
+ ### Statefully manage chat history ###
+ store = {}
+
+ def get_session_history(session_id: str) -> BaseChatMessageHistory:
+     if session_id not in store:
+         store[session_id] = ChatMessageHistory()
+     return store[session_id]
+
+
+ conversational_rag_chain = RunnableWithMessageHistory(
+     rag_chain,
+     get_session_history,
+     input_messages_key="input",
+     history_messages_key="chat_history",
+     output_messages_key="answer",
+ )
+
+ @spaces.GPU
+ def handle_message(question, history=None):  # None instead of a mutable default
+     zero = torch.Tensor([0]).cuda()
+     print("With GPU: ", zero.device)
+     response = ''
+     # Stream only the "answer" key of the chain output, yielding the
+     # accumulated text so Gradio renders a growing response.
+     chain = conversational_rag_chain.pick("answer")
+     for chunk in chain.stream(
+         {"input": question},
+         config={"configurable": {"session_id": "abc123"}},
+     ):
+         response += chunk
+         yield response
+
+ if __name__ == '__main__':
+     demo = gr.ChatInterface(fn=handle_message)
+     demo.launch()
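
Since handle_message is a generator, gr.ChatInterface streams the answer as it grows: each yielded value is the accumulated text so far. A minimal local smoke test might look like the following sketch (the question text is a made-up placeholder, and it assumes the .env values and Pinecone index above are reachable):

# Hypothetical quick test of app.py's streaming handler.
final_answer = None
for partial in handle_message("What are the visa requirements?"):
    final_answer = partial  # each value is the accumulated answer so far
print(final_answer)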
bm25_traveler_website.json ADDED
The diff for this file is too large to render.
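The fitted BM25 weights in this JSON are not rendered, but a file like this is typically produced by fitting pinecone-text's BM25Encoder on the scraped site corpus and serializing it. A minimal sketch, assuming a placeholder corpus (the real page texts are not part of this commit):

# Hypothetical reconstruction of how bm25_traveler_website.json could be built.
from pinecone_text.sparse import BM25Encoder

corpus = ["First page text...", "Second page text..."]  # placeholder documents
encoder = BM25Encoder()
encoder.fit(corpus)  # learn term statistics (TF/IDF) from the corpus
encoder.dump("bm25_traveler_website.json")  # later re-loaded in app.py via load()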
requirements.txt ADDED
@@ -0,0 +1,98 @@
+ aiohttp==3.9.5
+ aiosignal==1.3.1
+ annotated-types==0.7.0
+ anyio==4.4.0
+ async-timeout==4.0.3
+ attrs==23.2.0
+ bidict==0.23.1
+ blinker==1.8.2
+ certifi==2024.7.4
+ charset-normalizer==3.3.2
+ click==8.1.7
+ dataclasses-json==0.6.7
+ distro==1.9.0
+ exceptiongroup==1.2.2
+ filelock==3.15.4
+ flask==3.0.3
+ Flask-Cors==4.0.1
+ Flask-SocketIO==5.3.6
+ frozenlist==1.4.1
+ fsspec==2024.6.1
+ greenlet==3.0.3
+ groq==0.9.0
+ h11==0.14.0
+ httpcore==1.0.5
+ httpx==0.27.0
+ huggingface-hub==0.24.2
+ idna==3.7
+ importlib-metadata==8.2.0
+ itsdangerous==2.2.0
+ jinja2==3.1.4
+ joblib==1.4.2
+ jsonpatch==1.33
+ jsonpointer==3.0.0
+ langchain==0.2.11
+ langchain-community==0.2.10
+ langchain-core==0.2.24
+ langchain-groq==0.1.6
+ langchain-huggingface==0.0.3
+ langchain-text-splitters==0.2.2
+ langsmith==0.1.93
+ MarkupSafe==2.1.5
+ marshmallow==3.21.3
+ mmh3==4.1.0
+ mpmath==1.3.0
+ multidict==6.0.5
+ mypy-extensions==1.0.0
+ networkx==3.1
+ nltk==3.8.1
+ numpy==1.24.4
+ nvidia-cublas-cu12==12.1.3.1
+ nvidia-cuda-cupti-cu12==12.1.105
+ nvidia-cuda-nvrtc-cu12==12.1.105
+ nvidia-cuda-runtime-cu12==12.1.105
+ nvidia-cudnn-cu12==9.1.0.70
+ nvidia-cufft-cu12==11.0.2.54
+ nvidia-curand-cu12==10.3.2.106
+ nvidia-cusolver-cu12==11.4.5.107
+ nvidia-cusparse-cu12==12.1.0.106
+ nvidia-nccl-cu12==2.20.5
+ nvidia-nvjitlink-cu12==12.5.82
+ nvidia-nvtx-cu12==12.1.105
+ orjson==3.10.6
+ packaging==24.1
+ pillow==10.4.0
+ pinecone==4.0.0
+ pinecone-text==0.9.0
+ pydantic==2.8.2
+ pydantic-core==2.20.1
+ python-dotenv==1.0.1
+ python-engineio==4.9.1
+ python-socketio==5.11.3
+ PyYAML==6.0.1
+ regex==2024.7.24
+ requests==2.32.3
+ safetensors==0.4.3
+ scikit-learn==1.3.2
+ scipy==1.10.1
+ sentence-transformers==3.0.1
+ simple-websocket==1.0.0
+ sniffio==1.3.1
+ SQLAlchemy==2.0.31
+ sympy==1.13.1
+ tenacity==8.5.0
+ threadpoolctl==3.5.0
+ tokenizers==0.19.1
+ torch==2.4.0
+ tqdm==4.66.4
+ transformers==4.43.3
+ triton==3.0.0
+ types-requests==2.32.0.20240712
+ typing-extensions==4.12.2
+ typing-inspect==0.9.0
+ urllib3==2.2.2
+ werkzeug==3.0.3
+ wget==3.2
+ wsproto==1.2.0
+ yarl==1.9.4
+ zipp==3.19.2
temp.py ADDED
@@ -0,0 +1,176 @@
+ import os
+ from dotenv import load_dotenv
+ load_dotenv(".env")
+
+ os.environ['USER_AGENT'] = os.getenv("USER_AGENT")
+ os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
+ os.environ["TOKENIZERS_PARALLELISM"] = 'true'
+
+ from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+ from langchain.chains.combine_documents import create_stuff_documents_chain
+ from langchain_community.chat_message_histories import ChatMessageHistory
+ from langchain_community.document_loaders import WebBaseLoader
+ from langchain_core.chat_history import BaseChatMessageHistory
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+ from langchain_core.runnables.history import RunnableWithMessageHistory
+
+ from pinecone import Pinecone
+ from pinecone_text.sparse import BM25Encoder
+
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_community.retrievers import PineconeHybridSearchRetriever
+
+ from langchain_groq import ChatGroq
+
+ # These imports were commented out, but the @socketio.on/@app.route handlers
+ # below depend on them, so they are restored to make the module importable.
+ from flask import Flask, request, render_template
+ from flask_cors import CORS
+ from flask_socketio import SocketIO, emit
+
+ import gradio as gr
+ import spaces
+ import torch
+
+ zero = torch.Tensor([0]).cuda()
+ print(zero.device)  # <-- 'cpu' 🤔
+
+ @spaces.GPU
+ def greet(n):
+     print(zero.device)  # <-- 'cuda:0' 🤗
+     return f"Hello {zero + n} Tensor"
+
+
+ app = Flask(__name__)
+ CORS(app)
+ socketio = SocketIO(app, cors_allowed_origins="*")
+ app.config['SESSION_COOKIE_SECURE'] = True  # Use HTTPS
+ app.config['SESSION_COOKIE_HTTPONLY'] = True
+ app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'
+ app.config['SECRET_KEY'] = os.getenv('SECRET_KEY')
+
+ # Connect to the Pinecone index, retrying once in case of a transient failure.
+ try:
+     pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
+     index_name = "traveler-demo-website-vectorstore"
+     pinecone_index = pc.Index(index_name)
+ except Exception:
+     pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
+     index_name = "traveler-demo-website-vectorstore"
+     pinecone_index = pc.Index(index_name)
+
+ bm25 = BM25Encoder().load("./bm25_traveler_website.json")
+
+ embed_model = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-large-en-v1.5", model_kwargs={"trust_remote_code": True})
+
+ retriever = PineconeHybridSearchRetriever(
+     embeddings=embed_model,
+     sparse_encoder=bm25,
+     index=pinecone_index,
+     top_k=20,
+     alpha=0.5,
+ )
+
+ llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.1, max_tokens=1024, max_retries=2)
+
+ ### Contextualize question ###
+ contextualize_q_system_prompt = """Given a chat history and the latest user question \
+ which might reference context in the chat history, formulate a standalone question \
+ which can be understood without the chat history. Do NOT answer the question, \
+ just reformulate it if needed and otherwise return it as is.
+ """
+ contextualize_q_prompt = ChatPromptTemplate.from_messages(
+     [
+         ("system", contextualize_q_system_prompt),
+         MessagesPlaceholder("chat_history"),
+         ("human", "{input}")
+     ]
+ )
+
+ history_aware_retriever = create_history_aware_retriever(
+     llm, retriever, contextualize_q_prompt
+ )
+
+
+ qa_system_prompt = """You are a highly skilled information retrieval assistant. Use the following pieces of retrieved context to answer the question. \
+ Provide links to sources provided in the answer. \
+ If you don't know the answer, just say that you don't know. \
+ Do not give extra long answers. \
+ When responding to queries, your responses should be comprehensive and well-organized. For each response: \
+ 1. Provide Clear Answers \
+ 2. Include Detailed References: \
+    - Include links to any sources or sites that are mentioned in the answer. \
+    - Links to Sources: Provide URLs to credible sources where users can verify the information or explore further. \
+    - Downloadable Materials: Include links to any relevant downloadable resources if applicable. \
+    - Reference Sites: Mention specific websites or platforms that offer additional information. \
+ 3. Formatting for Readability: \
+    - Bullet Points or Lists: Where applicable, use bullet points or numbered lists to present information clearly. \
+    - Emphasize Important Information: Use bold or italics to highlight key details. \
+ 4. Organize Content Logically \
+ Do not include anything about context in the answer. \
+ {context}
+ """
+ qa_prompt = ChatPromptTemplate.from_messages(
+     [
+         ("system", qa_system_prompt),
+         MessagesPlaceholder("chat_history"),
+         ("human", "{input}")
+     ]
+ )
+ question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+
+ rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+
+ ### Statefully manage chat history ###
+ store = {}
+
+ def clean_temporary_data():
+     # Clear the shared dict in place; rebinding a local `store` would be a no-op.
+     store.clear()
+
+ def get_session_history(session_id: str) -> BaseChatMessageHistory:
+     if session_id not in store:
+         store[session_id] = ChatMessageHistory()
+     return store[session_id]
+
+
+ conversational_rag_chain = RunnableWithMessageHistory(
+     rag_chain,
+     get_session_history,
+     input_messages_key="input",
+     history_messages_key="chat_history",
+     output_messages_key="answer",
+ )
+
+ # Stream response to client
+ @socketio.on('message')
+ def handle_message(data):
+     question = data.get('question')
+     session_id = data.get('session_id', 'abc123')
+     chain = conversational_rag_chain.pick("answer")
+
+     # Retry the stream once if the first attempt fails.
+     try:
+         for chunk in chain.stream(
+             {"input": question},
+             config={"configurable": {"session_id": session_id}},
+         ):
+             emit('response', chunk, room=request.sid)
+     except Exception:
+         for chunk in chain.stream(
+             {"input": question},
+             config={"configurable": {"session_id": session_id}},
+         ):
+             emit('response', chunk, room=request.sid)
+
+ @app.route("/")
+ def index_view():
+     return render_template('chat.html')
+
+ if __name__ == '__main__':
+     socketio.run(app, debug=True)
+
+
+ # Leftover ZeroGPU smoke test; when run as a script this only executes after
+ # the SocketIO server exits.
+ demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
+ demo.launch()