Ritesh-hf committed on
Commit
4271625
1 Parent(s): c29148e

change GPU settings

Browse files
Files changed (1) hide show
  1. app.py +4 -7
app.py CHANGED
@@ -6,8 +6,8 @@ os.environ['USER_AGENT'] = os.getenv("USER_AGENT")
6
  os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
7
  os.environ["TOKENIZERS_PARALLELISM"]='true'
8
 
9
- import nltk
10
- nltk.download('punkt')
11
 
12
  from langchain.chains import create_history_aware_retriever, create_retrieval_chain
13
  from langchain.chains.combine_documents import create_stuff_documents_chain
@@ -43,7 +43,7 @@ except:
43
 
44
  bm25 = BM25Encoder().load("./bm25_traveler_website.json")
45
 
46
- embed_model = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-large-en-v1.5", model_kwargs={"trust_remote_code":True})
47
 
48
  retriever = PineconeHybridSearchRetriever(
49
  embeddings=embed_model,
@@ -120,11 +120,8 @@ conversational_rag_chain = RunnableWithMessageHistory(
120
  output_messages_key="answer",
121
  )
122
 
123
- @spaces.GPU
124
  def handle_message(question, history={}):
125
- zero = torch.Tensor([0]).cuda()
126
- print("With GPU: ", zero.device)
127
- # question = data.get('question')
128
  response = ''
129
  chain = conversational_rag_chain.pick("answer")
130
  for chunk in chain.stream(
 
6
  os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
7
  os.environ["TOKENIZERS_PARALLELISM"]='true'
8
 
9
+ # import nltk
10
+ # nltk.download('punkt')
11
 
12
  from langchain.chains import create_history_aware_retriever, create_retrieval_chain
13
  from langchain.chains.combine_documents import create_stuff_documents_chain
 
43
 
44
  bm25 = BM25Encoder().load("./bm25_traveler_website.json")
45
 
46
+ embed_model = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-large-en-v1.5", model_kwargs={"trust_remote_code":True, 'device': 'cuda'})
47
 
48
  retriever = PineconeHybridSearchRetriever(
49
  embeddings=embed_model,
 
120
  output_messages_key="answer",
121
  )
122
 
123
+ @spaces.GPU(duration=10)
124
  def handle_message(question, history={}):
 
 
 
125
  response = ''
126
  chain = conversational_rag_chain.pick("answer")
127
  for chunk in chain.stream(