dh-mc committed on
Commit
6011708
1 Parent(s): 7fded8d

use ConversationChain + ConversationSummaryBufferMemory

Browse files
Makefile CHANGED
@@ -12,9 +12,15 @@ endif
12
  test:
13
  python test.py
14
 
 
 
 
15
  chat:
16
  python test.py chat
17
 
 
 
 
18
  unittest:
19
  python unit_test.py $(TEST)
20
 
 
12
  test:
13
  python test.py
14
 
15
+ test2:
16
+ python server.py
17
+
18
  chat:
19
  python test.py chat
20
 
21
+ chat2:
22
+ python unit_test.py chat
23
+
24
  unittest:
25
  python unit_test.py $(TEST)
26
 
app_modules/llm_chat_chain.py CHANGED
@@ -1,9 +1,9 @@
1
  import os
 
2
 
3
- from langchain import LLMChain, PromptTemplate
4
- from langchain.chains import ConversationalRetrievalChain
5
  from langchain.chains.base import Chain
6
- from langchain.memory import ConversationBufferMemory
7
 
8
  from app_modules.llm_inference import LLMInference
9
 
@@ -12,7 +12,7 @@ def get_llama_2_prompt_template():
12
  B_INST, E_INST = "[INST]", "[/INST]"
13
  B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
14
 
15
- instruction = "Chat History:\n\n{chat_history} \n\nUser: {question}"
16
  system_prompt = "You are a helpful assistant, you always only answer for the assistant then you stop. Read the chat history to get context"
17
  # system_prompt = """\
18
  # You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. \n\nDo not output any emotional expression. Read the chat history to get context.\
@@ -32,20 +32,20 @@ class ChatChain(LLMInference):
32
  get_llama_2_prompt_template()
33
  if os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
34
  else """You are a chatbot having a conversation with a human.
35
- {chat_history}
36
- Human: {question}
37
  Chatbot:"""
38
  )
39
 
40
  print(f"template: {template}")
41
 
42
- prompt = PromptTemplate(
43
- input_variables=["chat_history", "question"], template=template
44
- )
45
 
46
- memory = ConversationBufferMemory(memory_key="chat_history")
 
 
47
 
48
- llm_chain = LLMChain(
49
  llm=self.llm_loader.llm,
50
  prompt=prompt,
51
  verbose=True,
@@ -53,3 +53,6 @@ Chatbot:"""
53
  )
54
 
55
  return llm_chain
 
 
 
 
1
  import os
2
+ from typing import List, Optional
3
 
4
+ from langchain import ConversationChain, PromptTemplate
 
5
  from langchain.chains.base import Chain
6
+ from langchain.memory import ConversationSummaryBufferMemory
7
 
8
  from app_modules.llm_inference import LLMInference
9
 
 
12
  B_INST, E_INST = "[INST]", "[/INST]"
13
  B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
14
 
15
+ instruction = "Chat History:\n\n{history} \n\nUser: {input}"
16
  system_prompt = "You are a helpful assistant, you always only answer for the assistant then you stop. Read the chat history to get context"
17
  # system_prompt = """\
18
  # You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. \n\nDo not output any emotional expression. Read the chat history to get context.\
 
32
  get_llama_2_prompt_template()
33
  if os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
34
  else """You are a chatbot having a conversation with a human.
35
+ {history}
36
+ Human: {input}
37
  Chatbot:"""
38
  )
39
 
40
  print(f"template: {template}")
41
 
42
+ prompt = PromptTemplate(input_variables=["history", "input"], template=template)
 
 
43
 
44
+ memory = ConversationSummaryBufferMemory(
45
+ llm=self.llm_loader.llm, max_token_limit=1024, return_messages=True
46
+ )
47
 
48
+ llm_chain = ConversationChain(
49
  llm=self.llm_loader.llm,
50
  prompt=prompt,
51
  verbose=True,
 
53
  )
54
 
55
  return llm_chain
56
+
57
+ def run_chain(self, chain, inputs, callbacks: Optional[List] = []):
58
+ return chain({"input": inputs["question"]}, callbacks)
app_modules/llm_inference.py CHANGED
@@ -4,6 +4,7 @@ import time
4
  import urllib
5
  from queue import Queue
6
  from threading import Thread
 
7
 
8
  from langchain.chains.base import Chain
9
 
@@ -29,6 +30,9 @@ class LLMInference(metaclass=abc.ABCMeta):
29
 
30
  return self.chain
31
 
 
 
 
32
  def call_chain(
33
  self,
34
  inputs,
@@ -45,9 +49,11 @@ class LLMInference(metaclass=abc.ABCMeta):
45
 
46
  chain = self.get_chain()
47
  result = (
48
- self._run_chain(chain, inputs, streaming_handler, testing)
 
 
49
  if streaming_handler is not None
50
- else chain(inputs)
51
  )
52
 
53
  if "answer" in result:
@@ -67,9 +73,11 @@ class LLMInference(metaclass=abc.ABCMeta):
67
  self.llm_loader.lock.release()
68
 
69
  def _execute_chain(self, chain, inputs, q, sh):
70
- q.put(chain(inputs, callbacks=[sh]))
71
 
72
- def _run_chain(self, chain, inputs, streaming_handler, testing):
 
 
73
  que = Queue()
74
 
75
  t = Thread(
 
4
  import urllib
5
  from queue import Queue
6
  from threading import Thread
7
+ from typing import List, Optional
8
 
9
  from langchain.chains.base import Chain
10
 
 
30
 
31
  return self.chain
32
 
33
+ def run_chain(self, chain, inputs, callbacks: Optional[List] = []):
34
+ return chain(inputs, callbacks)
35
+
36
  def call_chain(
37
  self,
38
  inputs,
 
49
 
50
  chain = self.get_chain()
51
  result = (
52
+ self._run_chain_with_streaming_handler(
53
+ chain, inputs, streaming_handler, testing
54
+ )
55
  if streaming_handler is not None
56
+ else self.run_chain(chain, inputs)
57
  )
58
 
59
  if "answer" in result:
 
73
  self.llm_loader.lock.release()
74
 
75
  def _execute_chain(self, chain, inputs, q, sh):
76
+ q.put(self.run_chain(chain, inputs, callbacks=[sh]))
77
 
78
+ def _run_chain_with_streaming_handler(
79
+ self, chain, inputs, streaming_handler, testing
80
+ ):
81
  que = Queue()
82
 
83
  t = Thread(
app_modules/llm_loader.py CHANGED
@@ -188,6 +188,7 @@ class LLMLoader:
188
  )
189
  elif self.llm_model_type == "hftgi":
190
  HFTGI_SERVER_URL = os.environ.get("HFTGI_SERVER_URL")
 
191
  self.llm = HuggingFaceTextGenInference(
192
  inference_server_url=HFTGI_SERVER_URL,
193
  max_new_tokens=self.max_tokens_limit / 2,
 
188
  )
189
  elif self.llm_model_type == "hftgi":
190
  HFTGI_SERVER_URL = os.environ.get("HFTGI_SERVER_URL")
191
+ self.max_tokens_limit = 4096
192
  self.llm = HuggingFaceTextGenInference(
193
  inference_server_url=HFTGI_SERVER_URL,
194
  max_new_tokens=self.max_tokens_limit / 2,
server.py CHANGED
@@ -78,17 +78,18 @@ def chat_sync(
78
  ) -> str:
79
  print("question@chat_sync:", question)
80
  result = do_chat(question, history, chat_id, None)
81
- return result["text"]
82
 
83
 
84
  if __name__ == "__main__":
85
  # print_llm_response(json.loads(chat("What's deep learning?", [])))
86
  chat_start = timer()
87
- chat_sync("What's generative AI?", chat_id="test_user")
88
  chat_sync("more on finance", chat_id="test_user")
89
- # chat_sync("给我讲一个年轻人奋斗创业最终取得成功的故事。", chat_id="test_user")
90
- # chat_sync("给这个故事起一个标题", chat_id="test_user")
91
- # chat_sync("Write the game 'snake' in python", chat_id="test_user")
 
92
  chat_end = timer()
93
  total_time = chat_end - chat_start
94
  print(f"Total time used: {total_time:.3f} s")
 
78
  ) -> str:
79
  print("question@chat_sync:", question)
80
  result = do_chat(question, history, chat_id, None)
81
+ return result["response"]
82
 
83
 
84
  if __name__ == "__main__":
85
  # print_llm_response(json.loads(chat("What's deep learning?", [])))
86
  chat_start = timer()
87
+ chat_sync("what's deep learning?", chat_id="test_user")
88
  chat_sync("more on finance", chat_id="test_user")
89
+ chat_sync("more on Sentiment analysis", chat_id="test_user")
90
+ chat_sync("Write the game 'snake' in python", chat_id="test_user")
91
+ chat_sync("给我讲一个年轻人奋斗创业最终取得成功的故事。", chat_id="test_user")
92
+ chat_sync("给这个故事起一个标题", chat_id="test_user")
93
  chat_end = timer()
94
  total_time = chat_end - chat_start
95
  print(f"Total time used: {total_time:.3f} s")
unit_test.py CHANGED
@@ -170,7 +170,7 @@ def chat():
170
  end = timer()
171
  print(f"Completed in {end - start:.3f}s")
172
 
173
- chat_history.append((query, result["text"]))
174
 
175
  chat_end = timer()
176
  print(f"Total time used: {chat_end - chat_start:.3f}s")
 
170
  end = timer()
171
  print(f"Completed in {end - start:.3f}s")
172
 
173
+ chat_history.append((query, result["response"]))
174
 
175
  chat_end = timer()
176
  print(f"Total time used: {chat_end - chat_start:.3f}s")