dh-mc committed
Commit 3ca5bd8
1 Parent(s): 1e4d37b

added telegram bot

Makefile CHANGED
@@ -15,6 +15,9 @@ test:
 chat:
 	python test.py chat
 
+tele:
+	python telegram_bot.py
+
 ingest:
 	python ingest.py
 
app_modules/llm_inference.py CHANGED
@@ -38,71 +38,64 @@ class LLMInference(metaclass=abc.ABCMeta):
         self, inputs, streaming_handler, q: Queue = None, tracing: bool = False
     ):
         print(inputs)
+        self.llm_loader.lock.acquire()
 
-        if self.llm_loader.streamer is not None and isinstance(
-            self.llm_loader.streamer, TextIteratorStreamer
-        ):
+        try:
             self.llm_loader.streamer.reset(q)
 
-        chain = self.get_chain(tracing)
-        result = (
-            self._run_chain(
-                chain,
-                inputs,
-                streaming_handler,
-            )
-            if streaming_handler is not None
-            else chain(inputs)
-        )
+            chain = self.get_chain(tracing)
+            result = (
+                self._run_chain(
+                    chain,
+                    inputs,
+                    streaming_handler,
+                )
+                if streaming_handler is not None
+                and self.llm_loader.streamer.for_huggingface
+                else chain(inputs)
+            )
 
-        if "answer" in result:
-            result["answer"] = remove_extra_spaces(result["answer"])
+            if "answer" in result:
+                result["answer"] = remove_extra_spaces(result["answer"])
 
-        base_url = os.environ.get("PDF_FILE_BASE_URL")
-        if base_url is not None and len(base_url) > 0:
-            documents = result["source_documents"]
-            for doc in documents:
-                source = doc.metadata["source"]
-                title = source.split("/")[-1]
-                doc.metadata["url"] = f"{base_url}{urllib.parse.quote(title)}"
+            base_url = os.environ.get("PDF_FILE_BASE_URL")
+            if base_url is not None and len(base_url) > 0:
+                documents = result["source_documents"]
+                for doc in documents:
+                    source = doc.metadata["source"]
+                    title = source.split("/")[-1]
+                    doc.metadata["url"] = f"{base_url}{urllib.parse.quote(title)}"
 
-        return result
+            return result
+        finally:
+            self.llm_loader.lock.release()
 
     def _execute_chain(self, chain, inputs, q, sh):
         q.put(chain(inputs, callbacks=[sh]))
 
     def _run_chain(self, chain, inputs, streaming_handler):
-        self.llm_loader.lock.acquire()
-        try:
-            que = Queue()
+        que = Queue()
 
-            t = Thread(
-                target=self._execute_chain,
-                args=(chain, inputs, que, streaming_handler),
-            )
-            t.start()
+        t = Thread(
+            target=self._execute_chain,
+            args=(chain, inputs, que, streaming_handler),
+        )
+        t.start()
 
-            if self.llm_loader.streamer is not None and isinstance(
-                self.llm_loader.streamer, TextIteratorStreamer
-            ):
-                count = (
-                    2
-                    if "chat_history" in inputs and len(inputs.get("chat_history")) > 0
-                    else 1
-                )
+        count = (
+            2 if "chat_history" in inputs and len(inputs.get("chat_history")) > 0 else 1
+        )
 
-                while count > 0:
-                    try:
-                        for token in self.llm_loader.streamer:
-                            streaming_handler.on_llm_new_token(token)
+        while count > 0:
+            try:
+                for token in self.llm_loader.streamer:
+                    streaming_handler.on_llm_new_token(token)
 
-                        self.llm_loader.streamer.reset()
-                        count -= 1
-                    except Exception:
-                        print("nothing generated yet - retry in 0.5s")
-                        time.sleep(0.5)
+                self.llm_loader.streamer.reset()
+                count -= 1
+            except Exception:
+                print("nothing generated yet - retry in 0.5s")
+                time.sleep(0.5)
 
-            t.join()
-            return que.get()
-        finally:
-            self.llm_loader.lock.release()
+        t.join()
+        return que.get()
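
For context (an editor's sketch, not part of the commit): the reworked method at the top of this hunk only routes through `_run_chain()` when a `streaming_handler` is supplied and the loader's streamer was initialized for a HuggingFace pipeline; `_run_chain()` then runs the chain on a worker thread and drains `llm_loader.streamer` on the calling thread. Any object exposing `on_llm_new_token()` can play the handler role, for example:

```python
# Minimal sketch of a compatible streaming handler (assumption, not from the
# repository): _run_chain() calls on_llm_new_token() for every token it reads
# off llm_loader.streamer, so printing them gives console streaming.
class PrintingHandler:
    def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(token, end="", flush=True)
```

The `count` of 2 versus 1 in `_run_chain()` presumably mirrors the two LLM calls a conversational chain makes when chat history is present (question condensation plus answering), so the streamer is drained once per generation.
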
app_modules/llm_loader.py CHANGED
@@ -33,18 +33,22 @@ class TextIteratorStreamer(TextStreamer, StreamingStdOutCallbackHandler):
         tokenizer: "AutoTokenizer",
         skip_prompt: bool = False,
         timeout: Optional[float] = None,
+        for_huggingface: bool = False,
         **decode_kwargs,
     ):
         super().__init__(tokenizer, skip_prompt, **decode_kwargs)
         self.text_queue = Queue()
         self.stop_signal = None
         self.timeout = timeout
+        self.total_tokens = 0
+        self.for_huggingface = for_huggingface
 
     def on_finalized_text(self, text: str, stream_end: bool = False):
         super().on_finalized_text(text, stream_end=stream_end)
 
         """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
         self.text_queue.put(text, timeout=self.timeout)
+        self.total_tokens = self.total_tokens + 1
         if stream_end:
             print("\n")
             self.text_queue.put("\n", timeout=self.timeout)
@@ -54,12 +58,16 @@ class TextIteratorStreamer(TextStreamer, StreamingStdOutCallbackHandler):
         sys.stdout.write(token)
         sys.stdout.flush()
         self.text_queue.put(token, timeout=self.timeout)
+        self.total_tokens = self.total_tokens + 1
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         print("\n")
         self.text_queue.put("\n", timeout=self.timeout)
         self.text_queue.put(self.stop_signal, timeout=self.timeout)
 
+    def for_huggingface(self) -> bool:
+        return self.tokenizer != ""
+
     def __iter__(self):
         return self
 
@@ -88,21 +96,18 @@ class LLMLoader:
     def __init__(self, llm_model_type, lc_serve: bool = False):
         self.llm_model_type = llm_model_type
         self.llm = None
-        self.streamer = None if lc_serve else TextIteratorStreamer("")
+        self.streamer = TextIteratorStreamer("")
         self.max_tokens_limit = 2048
         self.search_kwargs = {"k": 4}
         self.lock = threading.Lock()
 
-    def _init_streamer(self, tokenizer, custom_handler):
-        self.streamer = (
-            TextIteratorStreamer(
-                tokenizer,
-                timeout=10.0,
-                skip_prompt=True,
-                skip_special_tokens=True,
-            )
-            if custom_handler is None
-            else TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    def _init_hf_streamer(self, tokenizer):
+        self.streamer = TextIteratorStreamer(
+            tokenizer,
+            timeout=10.0,
+            skip_prompt=True,
+            skip_special_tokens=True,
+            for_huggingface=True,
         )
 
     def init(
@@ -179,7 +184,11 @@ class LLMLoader:
             MODEL_NAME_OR_PATH = os.environ.get("HUGGINGFACE_MODEL_NAME_OR_PATH")
             print(f" loading model: {MODEL_NAME_OR_PATH}")
 
-            hf_auth_token = os.environ.get("HUGGINGFACE_AUTH_TOKEN")
+            hf_auth_token = (
+                os.environ.get("HUGGINGFACE_AUTH_TOKEN")
+                if "Llama-2" in MODEL_NAME_OR_PATH
+                else None
+            )
             transformers_offline = os.environ.get("TRANSFORMERS_OFFLINE") == "1"
             token = (
                 hf_auth_token
@@ -231,7 +240,7 @@ class LLMLoader:
                 )
             )
 
-            self._init_streamer(tokenizer, custom_handler)
+            self._init_hf_streamer(tokenizer)
 
             task = "text2text-generation" if is_t5 else "text-generation"
 
@@ -343,14 +352,21 @@ class LLMLoader:
                     MODEL_NAME_OR_PATH,
                     config=config,
                     trust_remote_code=True,
-                    token=token,
                 )
                 if is_t5
-                else AutoModelForCausalLM.from_pretrained(
-                    MODEL_NAME_OR_PATH,
-                    config=config,
-                    trust_remote_code=True,
-                    token=token,
+                else (
+                    AutoModelForCausalLM.from_pretrained(
+                        MODEL_NAME_OR_PATH,
+                        config=config,
+                        trust_remote_code=True,
+                    )
+                    if token is None
+                    else AutoModelForCausalLM.from_pretrained(
+                        MODEL_NAME_OR_PATH,
+                        config=config,
+                        trust_remote_code=True,
+                        token=token,
+                    )
                 )
             )
             print(f"Model memory footprint: {model.get_memory_footprint()}")
@@ -405,7 +421,7 @@ class LLMLoader:
             print(f"Model memory footprint: {model.get_memory_footprint()}")
 
             tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
-            self._init_streamer(tokenizer, custom_handler)
+            self._init_hf_streamer(tokenizer)
 
             # mtp-7b is trained to add "<|endoftext|>" at the end of generations
             stop_token_ids = tokenizer.convert_tokens_to_ids(["<|endoftext|>"])
@@ -497,7 +513,7 @@ class LLMLoader:
             print(f"Model memory footprint: {model.get_memory_footprint()}")
 
             tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)
-            self._init_streamer(tokenizer, custom_handler)
+            self._init_hf_streamer(tokenizer)
 
             class StopOnTokens(StoppingCriteria):
                 def __call__(
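
A small illustration (editor's sketch, not in the commit) of the new streamer fields. Note that the `for_huggingface` attribute assigned in `__init__` shadows the `for_huggingface()` method added to the same class, so callers such as `LLMInference` see it as a plain boolean.

```python
# Sketch only: LLMLoader now always starts with a placeholder streamer built
# from an empty tokenizer string, and swaps in a real one via _init_hf_streamer()
# once a HuggingFace model is loaded.
from app_modules.llm_loader import LLMLoader

loader = LLMLoader("huggingface")
print(loader.streamer.for_huggingface)  # False until _init_hf_streamer() runs

# After loader.init(...) has loaded a HuggingFace pipeline:
#   loader.streamer.for_huggingface -> True
#   loader.streamer.total_tokens    -> running count of tokens pushed to the queue
```
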
requirements.txt CHANGED
@@ -31,3 +31,4 @@ einops
 gevent
 pydantic >= 1.10.11
 pypdf
+python-telegram-bot
server.py CHANGED
@@ -11,7 +11,7 @@ from app_modules.init import app_init
 from app_modules.llm_chat_chain import ChatChain
 from app_modules.utils import print_llm_response
 
-llm_loader, qa_chain = app_init(True)
+llm_loader, qa_chain = app_init(__name__ != "__main__")
 
 chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") == "true"
 
@@ -26,14 +26,13 @@ class ChatResponse(BaseModel):
     sourceDocs: Optional[List] = None
 
 
-@serving(websocket=True)
-def chat(
-    question: str, history: Optional[List] = [], uuid: Optional[str] = None, **kwargs
-) -> str:
-    print(f"uuid: {uuid}")
-    # Get the `streaming_handler` from `kwargs`. This is used to stream data to the client.
-    streaming_handler = kwargs.get("streaming_handler")
-    if uuid is None:
+def do_chat(
+    question: str,
+    history: Optional[List] = [],
+    chat_id: Optional[str] = None,
+    streaming_handler: any = None,
+):
+    if chat_id is None:
         chat_history = []
         if chat_history_enabled:
             for element in history:
@@ -48,21 +47,49 @@ def chat(
         print(f"Completed in {end - start:.3f}s")
 
         print(f"qa_chain result: {result}")
-        resp = ChatResponse(sourceDocs=result["source_documents"])
-
-        return json.dumps(resp.dict())
+        return result
     else:
-        if uuid in uuid_to_chat_chain_mapping:
-            chat = uuid_to_chat_chain_mapping[uuid]
+        if chat_id in uuid_to_chat_chain_mapping:
+            chat = uuid_to_chat_chain_mapping[chat_id]
         else:
             chat = ChatChain(llm_loader)
-            uuid_to_chat_chain_mapping[uuid] = chat
+            uuid_to_chat_chain_mapping[chat_id] = chat
         result = chat.call_chain({"question": question}, streaming_handler)
         print(f"chat result: {result}")
+        return result
+
 
-        resp = ChatResponse(sourceDocs=[])
-        return json.dumps(resp.dict())
+@serving(websocket=True)
+def chat(
+    question: str, history: Optional[List] = [], chat_id: Optional[str] = None, **kwargs
+) -> str:
+    print("question@chat:", question)
+    streaming_handler = kwargs.get("streaming_handler")
+    result = do_chat(question, history, chat_id, streaming_handler)
+    resp = ChatResponse(
+        sourceDocs=result["source_documents"] if chat_id is None else []
+    )
+    return json.dumps(resp.dict())
+
+
+@serving
+def chat_sync(
+    question: str, history: Optional[List] = [], chat_id: Optional[str] = None, **kwargs
+) -> str:
+    print("question@chat_sync:", question)
+    result = do_chat(question, history, chat_id, None)
+    return result["text"]
 
 
 if __name__ == "__main__":
-    print_llm_response(json.loads(chat("What's deep learning?", [])))
+    # print_llm_response(json.loads(chat("What's deep learning?", [])))
+    chat_start = timer()
+    chat_sync("What's generative AI?", chat_id="test_user")
+    chat_sync("more on finance", chat_id="test_user")
+    chat_end = timer()
+    total_time = chat_end - chat_start
+    print(f"Total time used: {total_time:.3f} s")
+    print(f"Number of tokens generated: {llm_loader.streamer.total_tokens}")
+    print(
+        f"Average generation speed: {llm_loader.streamer.total_tokens / total_time:.3f} tokens/s"
+    )
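
For reference (an assumption about deployment, not part of the commit): once `server.py` is served through lc-serve, the new `chat_sync` endpoint can be exercised with a plain HTTP POST. The payload shape and the `result` field mirror what `telegram_bot.py` sends and reads; the host, port, and path depend on where the server actually runs (whatever `CHAT_API_URL` points at).

```python
# Hypothetical chat_sync client; the URL below is a placeholder for a local
# lc-serve deployment and must match the real CHAT_API_URL.
import requests

payload = {"question": "What's generative AI?", "chat_id": "test_user"}
resp = requests.post("http://localhost:8080/chat_sync", json=payload).json()
print(resp["result"])  # telegram_bot.py reads the same "result" field
```
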
telegram_bot.py ADDED
@@ -0,0 +1,93 @@
+import os
+import ssl
+import time
+from threading import Thread
+
+import requests
+from telegram import Update
+from telegram import __version__ as TG_VER
+from telegram.ext import (
+    Application,
+    CommandHandler,
+    ContextTypes,
+    MessageHandler,
+    filters,
+)
+
+from app_modules.init import *
+
+ctx = ssl.create_default_context()
+ctx.set_ciphers("DEFAULT")
+
+try:
+    from telegram import __version_info__
+except ImportError:
+    __version_info__ = (0, 0, 0, 0, 0)  # type: ignore[assignment]
+
+if __version_info__ < (20, 0, 0, "alpha", 1):
+    raise RuntimeError(
+        f"This example is not compatible with your current PTB version {TG_VER}. To view the "
+        f"{TG_VER} version of this example, "
+        f"visit https://docs.python-telegram-bot.org/en/v{TG_VER}/examples.html"
+    )
+
+TOKEN = os.getenv("TELEGRAM_API_TOKEN")
+ENDPOINT = os.getenv("CHAT_API_URL")
+
+
+# Define a few command handlers. These usually take the two arguments update and
+# context.
+async def start_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
+    """Send a message when the command /start is issued."""
+    user = update.effective_user
+    await update.message.reply_html(
+        rf"Hi {user.mention_html()}! You are welcome to ask questions on anything!",
+    )
+
+
+async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
+    """Send a message when the command /help is issued."""
+    await update.message.reply_text("Help!")
+
+
+async def chat_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
+    """Forward the user message to the chat API and reply with its answer."""
+    tic = time.perf_counter()
+    try:
+        message = {
+            "question": update.message.text,
+            "chat_id": update.message.chat.username,
+        }
+        print(message)
+        x = requests.post(ENDPOINT, json=message).json()
+        temp = time.perf_counter()
+        print(f"Received response in {temp - tic:0.4f} seconds")
+        result = x["result"]
+        print(result)
+        await update.message.reply_text(result)
+        toc = time.perf_counter()
+        print(f"Response time in {toc - tic:0.4f} seconds")
+    except Exception as e:
+        print("error", e)
+
+
+def start_telegram_bot() -> None:
+    """Start the bot."""
+    print("starting telegram bot ...")
+    # Create the Application and pass it your bot's token.
+    application = Application.builder().token(TOKEN).build()
+
+    # on different commands - answer in Telegram
+    application.add_handler(CommandHandler("start_command", start_command))
+    application.add_handler(CommandHandler("help", help_command))
+
+    # on non-command text messages - hand the message to chat_command
+    application.add_handler(
+        MessageHandler(filters.TEXT & ~filters.COMMAND, chat_command)
+    )
+
+    application.run_polling()
+
+
+if __name__ == "__main__":
+    start_telegram_bot()
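
Usage note: running the bot requires `TELEGRAM_API_TOKEN` (the token issued by BotFather) and `CHAT_API_URL` (the chat_sync endpoint exposed by server.py) to be set in the environment; it can then be started with the new `make tele` target or `python telegram_bot.py` directly. Also note that the greeting handler is registered under the command name "start_command", so it responds to /start_command rather than the conventional /start.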