Ritesh-hf committed on
Commit 1639e80
1 Parent(s): 8870bd5

change GPU duration

Files changed (4)
  1. .gitattributes +1 -0
  2. app.py +1 -1
  3. temp.py +0 -176
  4. test.py +16 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+test.py
app.py CHANGED
@@ -120,7 +120,7 @@ conversational_rag_chain = RunnableWithMessageHistory(
     output_messages_key="answer",
 )
 
-@spaces.GPU(duration=5)
+@spaces.GPU(duration=10)
 def handle_message(question, history={}):
     response = ''
     chain = conversational_rag_chain.pick("answer")
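
For context: spaces.GPU is the Hugging Face ZeroGPU decorator, and duration is the number of seconds of GPU time a single call may hold, so this commit raises the budget for handle_message from 5 to 10 seconds. A minimal sketch of how the decorator behaves, adapted from the greet demo in the deleted temp.py below (the added duration argument is the only change):

import gradio as gr
import spaces
import torch

zero = torch.Tensor([0]).cuda()   # created at startup: reports 'cpu' under ZeroGPU

@spaces.GPU(duration=10)          # request up to 10 seconds of GPU time per call
def greet(n):
    print(zero.device)            # inside the decorated call: 'cuda:0'
    return f"Hello {zero + n} Tensor"

demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
demo.launch()

A longer duration lets a single request run further before ZeroGPU reclaims the device (at the cost of a longer quota hold per call), which is presumably why the streaming handle_message needed more than 5 seconds.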
temp.py DELETED
@@ -1,176 +0,0 @@
-import os
-from dotenv import load_dotenv
-load_dotenv(".env")
-
-os.environ['USER_AGENT'] = os.getenv("USER_AGENT")
-os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
-os.environ["TOKENIZERS_PARALLELISM"]='true'
-
-from langchain.chains import create_history_aware_retriever, create_retrieval_chain
-from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain_community.chat_message_histories import ChatMessageHistory
-from langchain_community.document_loaders import WebBaseLoader
-from langchain_core.chat_history import BaseChatMessageHistory
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain_core.runnables.history import RunnableWithMessageHistory
-
-from pinecone import Pinecone
-from pinecone_text.sparse import BM25Encoder
-
-from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_community.retrievers import PineconeHybridSearchRetriever
-
-from langchain_groq import ChatGroq
-
-# from flask import Flask, request, render_template
-# from flask_cors import CORS
-# from flask_socketio import SocketIO, emit
-
-import gradio as gr
-import spaces
-import torch
-
-zero = torch.Tensor([0]).cuda()
-print(zero.device) # <-- 'cpu' 🤔
-
-@spaces.GPU
-def greet(n):
-    print(zero.device) # <-- 'cuda:0' 🤗
-    return f"Hello {zero + n} Tensor"
-
-
-# app = Flask(__name__)
-# CORS(app)
-# socketio = SocketIO(app, cors_allowed_origins="*")
-# app.config['SESSION_COOKIE_SECURE'] = True # Use HTTPS
-# app.config['SESSION_COOKIE_HTTPONLY'] = True
-# app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'
-# app.config['SECRET_KEY'] = os.getenv('SECRET_KEY')
-
-try:
-    pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
-    index_name = "traveler-demo-website-vectorstore"
-    # connect to index
-    pinecone_index = pc.Index(index_name)
-except:
-    pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
-    index_name = "traveler-demo-website-vectorstore"
-    # connect to index
-    pinecone_index = pc.Index(index_name)
-
-bm25 = BM25Encoder().load("./bm25_traveler_website.json")
-
-embed_model = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-large-en-v1.5", model_kwargs={"trust_remote_code":True})
-
-retriever = PineconeHybridSearchRetriever(
-    embeddings=embed_model,
-    sparse_encoder=bm25,
-    index=pinecone_index,
-    top_k=20,
-    alpha=0.5,
-)
-
-llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.1, max_tokens=1024, max_retries=2)
-
-### Contextualize question ###
-contextualize_q_system_prompt = """Given a chat history and the latest user question \
-which might reference context in the chat history, formulate a standalone question \
-which can be understood without the chat history. Do NOT answer the question, \
-just reformulate it if needed and otherwise return it as is.
-"""
-contextualize_q_prompt = ChatPromptTemplate.from_messages(
-    [
-        ("system", contextualize_q_system_prompt),
-        MessagesPlaceholder("chat_history"),
-        ("human", "{input}")
-    ]
-)
-
-history_aware_retriever = create_history_aware_retriever(
-    llm, retriever, contextualize_q_prompt
-)
-
-
-qa_system_prompt = """You are a highly skilled information retrieval assistant. Use the following pieces of retrieved context to answer the question. \
-Provide links to sources provided in the answer. \
-If you don't know the answer, just say that you don't know. \
-Do not give extra long answers. \
-When responding to queries, your responses should be comprehensive and well-organized. For each response: \
-1. Provide Clear Answers \
-2. Include Detailed References: \
- - Include links to sources and any links or sites where there is a mentioned in the answer.
- - Links to Sources: Provide URLs to credible sources where users can verify the information or explore further. \
- - Downloadable Materials: Include links to any relevant downloadable resources if applicable. \
- - Reference Sites: Mention specific websites or platforms that offer additional information. \
-3. Formatting for Readability: \
- - Bullet Points or Lists: Where applicable, use bullet points or numbered lists to present information clearly. \
- - Emphasize Important Information: Use bold or italics to highlight key details. \
-4. Organize Content Logically \
-Do not include anything about context in the answer. \
-{context}
-"""
-qa_prompt = ChatPromptTemplate.from_messages(
-    [
-        ("system", qa_system_prompt),
-        MessagesPlaceholder("chat_history"),
-        ("human", "{input}")
-    ]
-)
-question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
-
-rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
-
-### Statefully manage chat history ###
-store = {}
-
-def clean_temporary_data():
-    store = {}
-
-def get_session_history(session_id: str) -> BaseChatMessageHistory:
-    if session_id not in store:
-        store[session_id] = ChatMessageHistory()
-    return store[session_id]
-
-
-conversational_rag_chain = RunnableWithMessageHistory(
-    rag_chain,
-    get_session_history,
-    input_messages_key="input",
-    history_messages_key="chat_history",
-    output_messages_key="answer",
-)
-
-# Stream response to client
-@socketio.on('message')
-def handle_message(data):
-    question = data.get('question')
-    session_id = data.get('session_id', 'abc123')
-    chain = conversational_rag_chain.pick("answer")
-
-    try:
-        for chunk in chain.stream(
-            {"input": question},
-            config={
-                "configurable": {"session_id": "abc123"}
-            },
-        ):
-            emit('response', chunk, room=request.sid)
-    except:
-        for chunk in chain.stream(
-            {"input": question},
-            config={
-                "configurable": {"session_id": "abc123"}
-            },
-        ):
-            emit('response', chunk, room=request.sid)
-
-@app.route("/")
-def index_view():
-    return render_template('chat.html')
-
-if __name__ == '__main__':
-    socketio.run(app, debug=True)
-
-
-demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
-demo.launch()
test.py ADDED
@@ -0,0 +1,16 @@
+from gradio_client import Client
+import timeit
+
+client = Client("Ritesh-hf/rag-api")
+
+
+while True:
+    question = input("Question: ")
+    start_time = timeit.default_timer()
+    result = client.predict(
+        question=question,
+        api_name="/chat"
+    )
+    end_time = timeit.default_timer()
+    print(result)
+    print("Time Taken: ", end_time-start_time)