Farid Karimli commited on
Commit
4de6b1a
1 Parent(s): 4f1706c

Initial streaming implementation

Browse files
code/main.py CHANGED
@@ -12,7 +12,6 @@ from typing import Optional
12
  from dotenv import load_dotenv
13
 
14
  load_dotenv()
15
- print(os.environ.get("OAUTH_GOOGLE_CLIENT_ID"))
16
 
17
  USER_TIMEOUT = 60_000
18
  SYSTEM = "System 🖥️"
@@ -248,25 +247,38 @@ class Chatbot:
248
  processor = cl.user_session.get("chat_processor")
249
  res = await processor.rag(message.content, chain)
250
 
251
- answer = res.get("answer", res.get("result"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  answer_with_sources, source_elements, sources_dict = get_sources(
253
- res, answer, view_sources=view_sources
254
  )
255
  processor._process(message.content, answer, sources_dict)
256
 
257
  await cl.Message(content=answer_with_sources, elements=source_elements).send()
258
 
259
- def oauth_callback(
260
- provider_id: str,
261
- token: str,
262
- raw_user_data: Dict[str, str],
263
- default_user: cl.User,
264
- ) -> Optional[cl.User]:
265
- return default_user
266
-
267
 
268
  chatbot = Chatbot()
269
- cl.oauth_callback(chatbot.oauth_callback)
270
  cl.set_starters(chatbot.set_starters)
271
  cl.author_rename(chatbot.rename)
272
  cl.on_chat_start(chatbot.start)
 
12
  from dotenv import load_dotenv
13
 
14
  load_dotenv()
 
15
 
16
  USER_TIMEOUT = 60_000
17
  SYSTEM = "System 🖥️"
 
247
  processor = cl.user_session.get("chat_processor")
248
  res = await processor.rag(message.content, chain)
249
 
250
+ # TODO: STREAM MESSAGE
251
+ msg = cl.Message(content="")
252
+ await msg.send()
253
+
254
+ output = {}
255
+ for chunk in res:
256
+ if 'answer' in chunk:
257
+ await msg.stream_token(chunk['answer'])
258
+
259
+ for key in chunk:
260
+ if key not in output:
261
+ output[key] = chunk[key]
262
+ else:
263
+ output[key] += chunk[key]
264
+
265
+ answer = output.get("answer", output.get("result"))
266
+
267
  answer_with_sources, source_elements, sources_dict = get_sources(
268
+ output, answer, view_sources=view_sources
269
  )
270
  processor._process(message.content, answer, sources_dict)
271
 
272
  await cl.Message(content=answer_with_sources, elements=source_elements).send()
273
 
274
+ def auth_callback(self, username: str, password: str) -> Optional[cl.User]:
275
+ return cl.User(
276
+ identifier=username,
277
+ metadata={"role": "admin", "provider": "credentials"},
278
+ )
 
 
 
279
 
280
  chatbot = Chatbot()
281
+ cl.password_auth_callback(chatbot.auth_callback)
282
  cl.set_starters(chatbot.set_starters)
283
  cl.author_rename(chatbot.rename)
284
  cl.on_chat_start(chatbot.start)
code/modules/chat/helpers.py CHANGED
@@ -36,8 +36,9 @@ def get_sources(res, answer, view_sources=False):
36
  source_dict[url_name]["text"] += f"\n\n{source.page_content}"
37
 
38
  # First, display the answer
39
- full_answer = "**Answer:**\n"
40
- full_answer += answer
 
41
 
42
  if view_sources:
43
 
 
36
  source_dict[url_name]["text"] += f"\n\n{source.page_content}"
37
 
38
  # First, display the answer
39
+ #full_answer = "**Answer:**\n"
40
+ #full_answer += answer
41
+ full_answer = "" # Not to include the answer again
42
 
43
  if view_sources:
44
 
code/modules/chat/langchain/langchain_rag.py CHANGED
@@ -133,6 +133,10 @@ class Langchain_RAG(BaseRAG):
133
  res["qa_prompt"] = self.qa_prompt
134
  return res
135
 
 
 
 
 
136
  def add_history_from_list(self, history_list):
137
  """
138
  Add messages from a list to the chat history.
 
133
  res["qa_prompt"] = self.qa_prompt
134
  return res
135
 
136
+ def stream(self, user_query, config):
137
+ res = self.rag_chain.stream(user_query, config)
138
+ return res
139
+
140
  def add_history_from_list(self, history_list):
141
  """
142
  Add messages from a list to the chat history.
code/modules/chat_processor/chat_processor.py CHANGED
@@ -50,4 +50,4 @@ class ChatProcessor:
50
  user_query=user_query_dict, config=config, chain=chain
51
  )
52
  else:
53
- return chain.invoke(user_query=user_query_dict, config=config)
 
50
  user_query=user_query_dict, config=config, chain=chain
51
  )
52
  else:
53
+ return chain.stream(user_query=user_query_dict, config=config)