change GPU settings
Browse files
app.py
CHANGED
@@ -6,8 +6,8 @@ os.environ['USER_AGENT'] = os.getenv("USER_AGENT")
|
|
6 |
os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
|
7 |
os.environ["TOKENIZERS_PARALLELISM"]='true'
|
8 |
|
9 |
-
import nltk
|
10 |
-
nltk.download('punkt')
|
11 |
|
12 |
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
|
13 |
from langchain.chains.combine_documents import create_stuff_documents_chain
|
@@ -43,7 +43,7 @@ except:
|
|
43 |
|
44 |
bm25 = BM25Encoder().load("./bm25_traveler_website.json")
|
45 |
|
46 |
-
embed_model = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-large-en-v1.5", model_kwargs={"trust_remote_code":True})
|
47 |
|
48 |
retriever = PineconeHybridSearchRetriever(
|
49 |
embeddings=embed_model,
|
@@ -120,11 +120,8 @@ conversational_rag_chain = RunnableWithMessageHistory(
|
|
120 |
output_messages_key="answer",
|
121 |
)
|
122 |
|
123 |
-
@spaces.GPU
|
124 |
def handle_message(question, history={}):
|
125 |
-
zero = torch.Tensor([0]).cuda()
|
126 |
-
print("With GPU: ", zero.device)
|
127 |
-
# question = data.get('question')
|
128 |
response = ''
|
129 |
chain = conversational_rag_chain.pick("answer")
|
130 |
for chunk in chain.stream(
|
|
|
6 |
os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
|
7 |
os.environ["TOKENIZERS_PARALLELISM"]='true'
|
8 |
|
9 |
+
# import nltk
|
10 |
+
# nltk.download('punkt')
|
11 |
|
12 |
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
|
13 |
from langchain.chains.combine_documents import create_stuff_documents_chain
|
|
|
43 |
|
44 |
bm25 = BM25Encoder().load("./bm25_traveler_website.json")
|
45 |
|
46 |
+
embed_model = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-large-en-v1.5", model_kwargs={"trust_remote_code":True, 'device': 'cuda'})
|
47 |
|
48 |
retriever = PineconeHybridSearchRetriever(
|
49 |
embeddings=embed_model,
|
|
|
120 |
output_messages_key="answer",
|
121 |
)
|
122 |
|
123 |
+
@spaces.GPU(duration=10)
|
124 |
def handle_message(question, history={}):
|
|
|
|
|
|
|
125 |
response = ''
|
126 |
chain = conversational_rag_chain.pick("answer")
|
127 |
for chunk in chain.stream(
|