xu song committed on
Commit 2fa4e4c
1 Parent(s): c619300
Files changed (6)
  1. README.md +4 -1
  2. app.py +30 -22
  3. app_util.py +12 -13
  4. config.py +2 -1
  5. models/cpp_qwen2.py +31 -17
  6. models/hf_qwen2.py +11 -8
README.md CHANGED
@@ -1,6 +1,6 @@
 ---
 title: Self Chat
-emoji: 💬
+emoji: 🤖🤖
 colorFrom: yellow
 colorTo: purple
 sdk: gradio
@@ -8,6 +8,9 @@ sdk_version: 4.39.0
 app_file: app.py
 pinned: false
 license: apache-2.0
+tags:
+- chatbot
+short_description: Generating synthetic data via self-chat
 ---
 
 An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
app.py CHANGED
@@ -1,19 +1,29 @@
 """
 """
+import random
 import gradio
-
 import config
 from app_util import *
 
-
-
 user_simulator_doc = """\
+The agent acts as the user simulator.
+
 There are mainly two types of user simulator:
 - prompt-based user-simulator (role-play)
 - model-based user-simulator
 
-In most cases, large language models (LLMs) are used to serve as assistant generator.
-Besides, it can also used as user simulator.
+This demo is a model-based user simulator.
+"""
+# In most cases, large language models (LLMs) serve as the assistant generator.
+# Besides, they can also be used as the user simulator.
+
+assistant_simulator_doc = """\
+The agent acts as the assistant simulator.
+"""
+
+self_chat_doc = """\
+Self-chat is a demo which makes the model talk to itself.
+It is a combination of user simulator and response generator.
 """
 
 survey = """\
@@ -28,16 +38,16 @@ Essentially, it is a form of model compression.
 ## 有不用概率的知识蒸馏吗?
 """
 
-with gr.Blocks() as demo:
+with gr.Blocks(head=None) as demo:
     # Knowledge Distillation through Self Chatting
     # Distilling the Knowledge from LLM through Self Chatting
     # Generating Synthetic Data through Self Chat
-    gr.HTML("""<h1 align="center">Generating Synthetic Data Through Self-Chat</h1>""")
+    gr.HTML("""<h1 align="center">Generating Synthetic Data via Self-Chat</h1>""")
     with gr.Row():
         with gr.Column(scale=5):
             system = gr.Dropdown(
                 choices=system_list,
-                value=system_list[0],
+                # value=system_list[0],
                 allow_custom_value=True,
                 interactive=True,
                 label="System message",
@@ -46,7 +56,8 @@ with gr.Blocks() as demo:
 
             chatbot = gr.Chatbot(show_copy_button=True,
                                  show_share_button=True,
-                                 avatar_images=("assets/man.png", "assets/bot.png"))
+                                 avatar_images=("assets/man.png", "assets/bot.png"),
+                                 likeable=True)
 
             # gr.Textbox("For faster inference, you can build locally with ")
             # ss
@@ -54,30 +65,27 @@ with gr.Blocks() as demo:
             input_text_1 = gr.Textbox(show_label=False, placeholder="...", lines=10, visible=False)
             generate_btn = gr.Button("🤔️ Self-Chat", variant="primary")
             with gr.Row():
-                retry_btn = gr.Button("🔄 Retry", variant="secondary", size="sm", )
+                retry_btn = gr.Button("🔄 Regenerate", variant="secondary", size="sm", )
                 undo_btn = gr.Button("↩️ Undo", variant="secondary", size="sm", )
                 clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="sm", )  # 🧹 Clear History (清除历史)
                 # stop_btn = gr.Button("停止生成", variant="stop", visible=False)
-            # gr.Markdown(
-            #     "Self-chat is a demo, which makes the model talk to itself. "
-            #     "It is based on user simulator and response generator.",
-            #     visible=True)
+            gr.Markdown(self_chat_doc)
 
             # also called chat-assistant
-            with gradio.Tab("Response Generator", visible=False):
+            with gradio.Tab("Response Generator"):
                 with gr.Row():
-                    input_text_2 = gr.Textbox(show_label=False, placeholder="Please type your input", scale=7)
+                    input_text_2 = gr.Textbox(show_label=False, placeholder="Please type user input", scale=7)
                     generate_btn_2 = gr.Button("Send", variant="primary")
                 with gr.Row():
                     retry_btn_2 = gr.Button("🔄 Regenerate", variant="secondary", size="sm", )
                     undo_btn_2 = gr.Button("↩️ Undo", variant="secondary", size="sm", )
                     clear_btn_2 = gr.Button("🗑️ Clear", variant="secondary", size="sm", )  # 🧹 Clear History (清除历史)
-                gr.Markdown("Response simulator is the most commonly used chatbot.")
+                gr.Markdown(assistant_simulator_doc)
 
             #
-            with gradio.Tab("User Simulator", visible=False):
+            with gradio.Tab("User Simulator"):
                 with gr.Row():
-                    input_text_3 = gr.Textbox(show_label=False, placeholder="Please type your response", scale=7)
+                    input_text_3 = gr.Textbox(show_label=False, placeholder="Please type assistant response", scale=7)
                     generate_btn_3 = gr.Button("Send", variant="primary")
                 with gr.Row():
                     retry_btn_3 = gr.Button("🔄 Regenerate", variant="secondary", size="sm", )
@@ -85,7 +93,7 @@ with gr.Blocks() as demo:
                     clear_btn_3 = gr.Button("🗑️ Clear", variant="secondary", size="sm", )  # 🧹 Clear History (清除历史)
                 gr.Markdown(user_simulator_doc)
 
-        with gr.Column(variant="compact"):
+        with gr.Column(variant="compact", scale=1, min_width=300):
             # with gr.Column():
             model = gr.Dropdown(
                 ["Qwen2-0.5B-Instruct", "llama3.1", "gemini"],
@@ -155,8 +163,8 @@ with gr.Blocks() as demo:
     slider_top_k.change(set_top_k, inputs=[slider_top_k])
 
 
-
-# demo.queue().launch(share=False, server_name="0.0.0.0")
+    demo.load(lambda: gr.update(value=random.choice(system_list)), None, system)
+
+# demo.queue().launch(share=False, server_name="0.0.0.0", debug=True)
 # demo.queue().launch(concurrency_count=1, max_size=5)
 demo.queue().launch()
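Note: the click wiring for these buttons sits outside the hunks shown above. A minimal sketch of how the Self-Chat button plausibly drives `generate` from app_util.py; the `gr.State` holder and the exact inputs/outputs lists are assumptions, not part of this diff:

```python
# Hypothetical wiring sketch (inside the `with gr.Blocks()` block).
# generate() decides who speaks next from history[-1]["role"], so repeated
# clicks alternate user-simulator and assistant turns.
history = gr.State([{"role": "system", "content": system_list[0]}])

generate_btn.click(
    fn=generate,                # generator function: streams (chatbot, history) updates
    inputs=[chatbot, history],
    outputs=[chatbot, history],
)
```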
app_util.py CHANGED
@@ -1,8 +1,8 @@
 import json
 import gradio as gr
 from utils.logging_util import logger
-from models.cpp_qwen2 import bot
-# from models.hf_qwen2 import bot
+from models.cpp_qwen2 import Qwen2Simulator as Bot
+# from models.hf_qwen2 import Qwen2Simulator as Bot
 
 
 #
@@ -22,15 +22,16 @@ system_list = [
     "You are a helpful assistant.",
     "你是一个导游。",
     "你是一名投资经理。",
-    # "你是一名医生。",
-    # "你是一个英语老师。",
-    # "你是一个程序员。",
-    # "你是一个心理咨询师。",
-    # "你是一名AI写作助手。"
-    # "你是一名作家,擅长写小说。"
+    "你是一名医生。",
+    "你是一个英语老师。",
+    "你是一个程序员。",
+    "你是一个心理咨询师。",
+    "你是一名AI写作助手。",
+    "你是一名作家,擅长写小说。"
 ]
 
-bot.pre_cache_system(system_list)
+
+bot = Bot(system_list)
 
 def generate_user_message(chatbot, history):
     if history and history[-1]["role"] == "user":
@@ -52,7 +53,6 @@ def generate_assistant_message(chatbot, history):
     auto-mode: query is None
     manual-mode: query is the user input
     """
-    logger.info(f"generating {json.dumps(history, ensure_ascii=False)}")
     user_content = history[-1]["content"]
     if history[-1]["role"] != "user":
         gr.Warning('You should generate or type user-input first.')
@@ -65,13 +65,12 @@ def generate_assistant_message(chatbot, history):
 
     assistant_tokens = bot.strip_stoptokens(assistant_tokens)
     history.append({"role": "assistant", "content": assistant_content, "tokens": assistant_tokens})
-    print(f"chatbot is {chatbot}")
-    print(f"history is {history}")
     yield chatbot, history
 
 
 def generate(chatbot, history):
-    logger.info(f"chatbot: {chatbot}; history: {history}")
+    request_param = json.dumps({'chatbot': chatbot, 'history': history}, ensure_ascii=False)
+    logger.info(f"request_param: {request_param}")
     streamer = None
     if history[-1]["role"] in ["assistant", "system"]:
         streamer = generate_user_message(chatbot, history)
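Read together, `generate_user_message` and `generate_assistant_message` each implement one self-chat step. A minimal offline sketch of the resulting loop, assuming `bot.generate(history, stream=True)` yields `(content, tokens)` pairs; that return shape is an assumption, only the role-based branching is visible in this diff:

```python
# Illustrative self-chat loop: the next speaker is chosen from the last role,
# exactly as generate() does above.
history = [{"role": "system", "content": "You are a helpful assistant."}]

for _ in range(4):  # two user turns + two assistant turns
    next_role = "user" if history[-1]["role"] in ("assistant", "system") else "assistant"
    content, tokens = "", []
    for content, tokens in bot.generate(history, stream=True):
        pass  # keep only the final streamed state
    history.append({"role": next_role, "content": content,
                    "tokens": bot.strip_stoptokens(tokens)})
```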
config.py CHANGED
@@ -1,6 +1,7 @@
 
 
-MAX_SEQUENCE_LENGTH = 32768  # max_seq_len
+# MAX_SEQUENCE_LENGTH = 32768  # uses too much memory
+MAX_SEQUENCE_LENGTH = 8192
 
 DEFAULT_MAX_NEW_TOKENS = 128
 DEFAULT_TOP_K = 100
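The halved context window trades prompt capacity for memory: the KV cache grows linearly with the sequence length. A back-of-the-envelope helper as a sketch; the formula is the generic 2 × layers × KV heads × head dim × bytes per element per token, and the example layer/head numbers are illustrative, not measured from this Space:

```python
def kv_cache_bytes(n_ctx, n_layers, n_kv_heads, head_dim, bytes_per_elt=2):
    # keys + values, per token, summed over all layers
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elt * n_ctx

# e.g. a small GQA model with 24 layers, 2 KV heads, head_dim 64, fp16 cache:
print(kv_cache_bytes(32768, 24, 2, 64) / 2**20)  # ~384 MiB
print(kv_cache_bytes(8192, 24, 2, 64) / 2**20)   # ~96 MiB
```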
models/cpp_qwen2.py CHANGED
@@ -77,7 +77,7 @@ import config
 
 class Qwen2Simulator(Simulator):
 
-    def __init__(self):
+    def __init__(self, system_list=None):
         local_path = "/workspace/xusong/huggingface/models/Qwen2-0.5B-Instruct-GGUF/qwen2-0_5b-instruct-fp16.gguf"
         if os.path.exists(local_path):
             self.hf_tokenizer = AutoTokenizer.from_pretrained(
@@ -105,30 +105,37 @@ class Qwen2Simulator(Simulator):
                     f"n_threads={self.llm.n_threads}, n_ctx={self.llm.n_ctx}, "
                     f"env[CACHE]={os.environ.get('CACHE', None)}")
 
-        self.stop_words = [
+
+        # qwen2-0.5b-chat sometimes ends generation without <|im_end|> and goes straight to <|im_start|>
+        self.assistant_stop_words = [
             "<|im_end|>",
             "<|im_start|>",
             "<|endoftext|>",
         ]
-        self.stop_tokens = self.tokenize("".join(self.stop_words))
+        self.assistant_stop_tokens = self.tokenize("".join(self.assistant_stop_words))
+        self.user_stop_words = self.assistant_stop_words + ["?", "?"]
+        self.user_stop_tokens = self.tokenize("".join(self.user_stop_words))
+        logger.info(f"assistant_stop_tokens: {self.assistant_stop_tokens}")
+        logger.info(f"user_stop_tokens: {self.user_stop_tokens}")
+
         self.generation_kwargs = dict(
             temperature=config.DEFAULT_TEMPERATURE,
             top_p=config.DEFAULT_TOP_P,
             top_k=config.DEFAULT_TOP_K,
             max_tokens=config.DEFAULT_MAX_NEW_TOKENS,
             repeat_penalty=1.1,
-            # qwen2-0.5b-chat sometimes ends generation without <|im_end|> and goes straight to <|im_start|>
-            stop=self.stop_words,
         )
-
         self.user_start_tokens = self.tokenize("<|im_start|>user\n")
         self.assistant_start_tokens = self.tokenize("<|im_start|>assistant\n")
         # self.llm.generate .set_cache .last_n_tokens_size .reset .ctx ._ctx
 
         # cache = llama_cpp.LlamaDiskCache(capacity_bytes=cache_size)
         cache = llama_cpp.LlamaRAMCache(capacity_bytes=2 << 30)  # 2G
         self.llm.set_cache(cache)
 
+        if system_list is not None:
+            self.pre_cache_system(system_list)
+
     def tokenize(self, text):
         return self.llm.tokenize(text.encode("utf-8"))
 
@@ -136,10 +143,10 @@ class Qwen2Simulator(Simulator):
         return self.llm.detokenize(tokens).decode("utf-8")
 
     def strip_stoptokens(self, tokens):
-        while tokens and tokens[0] in self.stop_tokens:
+        while tokens and tokens[0] in self.assistant_stop_tokens:
             logger.info(f"head-stripping {tokens[0]} {self.detokenize([tokens[0]])}")
             tokens.pop(0)
-        while tokens and tokens[-1] in self.stop_tokens:
+        while tokens and tokens[-1] in self.assistant_stop_tokens:
             logger.info(f"tail-stripping {tokens[-1]} {self.detokenize([tokens[-1]])}")
             tokens.pop()
         return tokens
@@ -154,9 +161,12 @@ class Qwen2Simulator(Simulator):
         """
         if history[-1]['role'] in ["user"]:
            start_tokens = self.assistant_start_tokens
+            stop_words = self.assistant_stop_words
             suffix_tokens = self.user_start_tokens
+
         elif history[-1]['role'] in ["assistant", "system"]:
             start_tokens = self.user_start_tokens
+            stop_words = self.user_stop_words
             suffix_tokens = self.assistant_start_tokens
 
         input_ids = []
@@ -168,15 +178,16 @@ class Qwen2Simulator(Simulator):
                         + self.tokenize("<|im_end|>\n")
         input_ids += start_tokens
         if stream:
-            return self._stream_generate(input_ids, suffix_tokens)
+            return self._stream_generate(input_ids, stop_words, suffix_tokens)
         else:
             return self._generate(input_ids)
 
-    def _stream_generate(self, input_ids, suffix_tokens=None):
+    def _stream_generate(self, input_ids, stop_words, suffix_tokens=None):
         logger.info(f"generation_kwargs {self.generation_kwargs}")
         output = self.llm.create_completion(
             input_ids,
             stream=True,
+            stop=stop_words,
             **self.generation_kwargs
         )
         # TODO: check the finish reason; if it is "length", shift the context and continue generating.
@@ -201,37 +212,40 @@ class Qwen2Simulator(Simulator):
         for system_prompt in system_list:
             logger.info(f"pre caching '{system_prompt}'")
             input_ids = self.tokenize(f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n")
-            output = self.llm.create_completion(
+            _output = self.llm.create_completion(
                 input_ids,
                 stream=False,
                 max_tokens=1,
                 top_k=1
             )
-            logger.info(f"cache size {self.llm.cache.cache_size}, process_mem: "
-                        f"{psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024:.2f} GB")
+            logger.info(
+                f"cache size {self.llm.cache.cache_size}={self.llm.cache.cache_size / 1024 / 1024 / 1024:.2f} GB, "
+                f"process_mem: {psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024:.2f} GB")
 
         self._disable_cache()
 
-
     def post_cache(self, suffix_tokens):
         """ warmup for next turn generation
         :param suffix_tokens:
         :return:
         """
+        logger.info(f"cache size {self.llm.cache.cache_size}={self.llm.cache.cache_size / 1024 / 1024 / 1024:.2f} GB, "
+                    f"process_mem: {psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024:.2f} GB")
         if suffix_tokens:
             logger.info(f"before warmup: n_tokens = {self.llm.n_tokens}")
             self.llm.eval([151645, 198] + suffix_tokens)  # <|im_end|>\n
             logger.info(f"after warmup: n_tokens = {self.llm.n_tokens}")
-
+        logger.info(f"cache size {self.llm.cache.cache_size}={self.llm.cache.cache_size / 1024 / 1024 / 1024:.2f} GB, "
+                    f"process_mem: {psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024:.2f} GB")
 
     def _disable_cache(self):
         llama_cpp.LlamaRAMCache.__setitem__ = lambda *args: None
         llama_cpp.Llama.save_state = lambda *args: None
 
-bot = Qwen2Simulator()
 
 if __name__ == "__main__":
 
+    bot = Qwen2Simulator()
     messages = [{"role": "system", "content": "你是一个导游。"}]
     generated_tokens = None
     print("######## requesting", messages)
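The constructor change folds `pre_cache_system` into `__init__`: one single-token completion per system prompt pushes the prompt prefix into the `LlamaRAMCache`, so the first real turn can reuse that KV state instead of re-evaluating the prefix. A standalone sketch of the same idea, with a placeholder model path and prompt; the llama-cpp-python calls mirror the ones used in this file:

```python
from llama_cpp import Llama, LlamaRAMCache

# Build the model and attach a 2 GB RAM cache, as __init__ does.
llm = Llama(model_path="qwen2-0_5b-instruct-fp16.gguf", n_ctx=8192)  # placeholder path
llm.set_cache(LlamaRAMCache(capacity_bytes=2 << 30))

# Warm the cache: evaluating the shared prefix once stores its KV state, so a later
# completion that starts with the same prefix skips re-evaluating those tokens.
prefix = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
llm.create_completion(llm.tokenize(prefix.encode("utf-8")), max_tokens=1, top_k=1)
```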
models/hf_qwen2.py CHANGED
@@ -14,13 +14,15 @@ class Qwen2Simulator(Simulator):
         When device_map is passed, low_cpu_mem_usage is automatically set to True.
         """
 
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name_or_path,
-            torch_dtype="auto",
-            device_map="auto"
-        )
-        self.model.eval()
+        self.tokenizer = None
+        # self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        self.model = None
+        # self.model = AutoModelForCausalLM.from_pretrained(
+        #     model_name_or_path,
+        #     torch_dtype="auto",
+        #     device_map="auto"
+        # )
+        # self.model.eval()
         self.generation_kwargs = dict(
             do_sample=True,
             temperature=0.7,
@@ -93,11 +95,12 @@ class Qwen2Simulator(Simulator):
         return self.tokenizer.decode(response[0][input_ids_length:], skip_special_tokens=True)
 
 
-bot = Qwen2Simulator(r"E:\data_model\Qwen2-0.5B-Instruct")
+
 # bot = Qwen2Simulator("Qwen/Qwen2-0.5B-Instruct")
 
 
 if __name__ == "__main__":
+    bot = Qwen2Simulator(r"E:\data_model\Qwen2-0.5B-Instruct")
     messages = [
         {"role": "system", "content": "you are a helpful assistant"},
         {"role": "user", "content": "hi, what your name"}