Tuchuanhuhuhu committed on
Commit 9a2b13d
1 Parent(s): 6a2dc28

Fixed the ChatGLM MPS acceleration issue

Files changed (2)
  1. modules/base_model.py +2 -2
  2. modules/models.py +75 -57
modules/base_model.py CHANGED
@@ -126,6 +126,7 @@ class BaseLLMModel:
 
         stream_iter = self.get_answer_stream_iter()
 
+        self.history.append(construct_assistant(""))
         for partial_text in stream_iter:
             self.history[-1] = construct_assistant(partial_text)
             chatbot[-1] = (chatbot[-1][0], partial_text + display_append)
@@ -144,9 +145,9 @@ class BaseLLMModel:
         user_token_count = self.count_token(inputs)
         self.all_token_counts.append(user_token_count)
         ai_reply, total_token_count = self.get_answer_at_once()
+        self.history.append(construct_assistant(ai_reply))
         if fake_input is not None:
             self.history[-2] = construct_user(fake_input)
-        self.history[-1] = construct_assistant(ai_reply)
         chatbot[-1] = (chatbot[-1][0], ai_reply + display_append)
         if fake_input is not None:
             self.all_token_counts[-1] += count_token(construct_assistant(ai_reply))
@@ -265,7 +266,6 @@ class BaseLLMModel:
            return
 
        self.history.append(construct_user(inputs))
-       self.history.append(construct_assistant(""))
 
        if stream:
            logging.debug("使用流式传输")
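The base-class change above moves assistant-side history bookkeeping into the branch that actually produces the reply: an empty placeholder is appended just before the streaming loop, and the finished reply is appended right after get_answer_at_once() returns, instead of unconditionally inside predict(). Below is a minimal, self-contained sketch of that streaming idiom; the helpers are hypothetical stand-ins, and construct_assistant is assumed to build a role/content dict as in modules/base_model.py.

```python
from typing import Iterator


def construct_user(text: str) -> dict:
    # assumed shape, mirroring modules/base_model.py
    return {"role": "user", "content": text}


def construct_assistant(text: str) -> dict:
    return {"role": "assistant", "content": text}


def fake_stream() -> Iterator[str]:
    # stand-in for get_answer_stream_iter(): yields ever-growing partial answers
    for partial in ("Hel", "Hello", "Hello!"):
        yield partial


history = [construct_user("Hi")]
stream_iter = fake_stream()

# the placeholder is appended only once streaming is about to start ...
history.append(construct_assistant(""))
for partial_text in stream_iter:
    # ... and then overwritten in place with each partial answer
    history[-1] = construct_assistant(partial_text)

assert history[-1]["content"] == "Hello!"
```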
modules/models.py CHANGED
@@ -200,17 +200,13 @@ class OpenAIClient(BaseLLMModel):
                 # logging.error(f"Error: {e}")
                 continue
 
+
 class ChatGLM_Client(BaseLLMModel):
-    def __init__(
-        self,
-        model_name,
-        model_path = None
-    ) -> None:
-        super().__init__(
-            model_name=model_name
-        )
+    def __init__(self, model_name, model_path=None) -> None:
+        super().__init__(model_name=model_name)
         from transformers import AutoTokenizer, AutoModel
         import torch
+
         system_name = platform.system()
         if os.path.exists("models"):
             model_dirs = os.listdir("models")
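For context, the constructor resolves model_source to either a locally downloaded checkpoint under models/ or the THUDM/{model_name} Hub id; this matters for the MPS branch in the next hunk, which requires model_path is not None. The lines that match model_name against the local directories fall outside this hunk, so the sketch below fills that gap with a simple case-insensitive comparison purely as a placeholder.

```python
import os
from typing import Optional, Tuple


def resolve_model_source(model_name: str, models_root: str = "models") -> Tuple[str, Optional[str]]:
    # The directory-matching rule below is a placeholder; the real one lives outside this hunk.
    model_path = None
    if os.path.exists(models_root):
        for d in os.listdir(models_root):
            if d.lower() == model_name.lower():           # placeholder matching rule
                model_path = os.path.join(models_root, d)
                break
    if model_path is not None:
        model_source = model_path                          # locally downloaded checkpoint
    else:
        model_source = f"THUDM/{model_name}"               # fall back to the Hugging Face Hub id
    return model_source, model_path


print(resolve_model_source("chatglm-6b-int4"))
```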
@@ -220,23 +216,29 @@ class ChatGLM_Client(BaseLLMModel):
             model_source = model_path
         else:
             model_source = f"THUDM/{model_name}"
-        self.tokenizer = AutoTokenizer.from_pretrained(model_source, trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_source, trust_remote_code=True
+        )
         quantified = False
         if "int4" in model_name:
             quantified = True
         if quantified:
-            model = AutoModel.from_pretrained(model_source, trust_remote_code=True).float()
+            model = AutoModel.from_pretrained(
+                model_source, trust_remote_code=True
+            ).float()
         else:
-            model = AutoModel.from_pretrained(model_source, trust_remote_code=True).half()
+            model = AutoModel.from_pretrained(
+                model_source, trust_remote_code=True
+            ).half()
         if torch.cuda.is_available():
             # run on CUDA
             logging.info("CUDA is available, using CUDA")
             model = model.cuda()
         # mps加速还存在一些问题,暂时不使用
-        # elif system_name == "Darwin" and model_path is not None:
-        #     logging.info("Running on macOS, using MPS")
-        #     # running on macOS and model already downloaded
-        #     model = model.to('mps')
+        elif system_name == "Darwin" and model_path is not None and not quantified:
+            logging.info("Running on macOS, using MPS")
+            # running on macOS and model already downloaded
+            model = model.to("mps")
         else:
             logging.info("GPU is not available, using CPU")
         model = model.eval()
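This hunk is the core of the commit: the previously commented-out MPS branch is re-enabled, so a non-quantized ChatGLM checkpoint that is already downloaded locally is moved to Apple's Metal (MPS) backend on macOS, while the int4 build keeps running as float() on CPU. A hedged sketch of the same device-selection logic in isolation follows; the model id is illustrative, and the torch.backends.mps.is_available() guard is an extra check that is not part of the diff.

```python
import logging
import platform

import torch
from transformers import AutoModel, AutoTokenizer

model_source = "THUDM/chatglm-6b"          # illustrative; a local path under models/ also works
quantified = "int4" in model_source        # int4 builds stay on CPU in float32

tokenizer = AutoTokenizer.from_pretrained(model_source, trust_remote_code=True)
model = AutoModel.from_pretrained(model_source, trust_remote_code=True)
model = model.float() if quantified else model.half()

if torch.cuda.is_available():
    logging.info("CUDA is available, using CUDA")
    model = model.cuda()
elif (
    platform.system() == "Darwin"
    and not quantified
    and torch.backends.mps.is_available()  # extra guard, not in the original diff
):
    logging.info("Running on macOS, using MPS")
    model = model.to("mps")
else:
    logging.info("GPU is not available, using CPU")

model = model.eval()
```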
@@ -246,8 +248,10 @@ class ChatGLM_Client(BaseLLMModel):
         history = [x["content"] for x in self.history]
         query = history.pop()
         logging.info(colorama.Fore.YELLOW + f"{history}" + colorama.Fore.RESET)
-        assert len(history) % 2 == 0
-        history = [[history[i], history[i+1]] for i in range(0, len(history), 2)]
+        assert (
+            len(history) % 2 == 0
+        ), f"History should be even length. current history is: {history}"
+        history = [[history[i], history[i + 1]] for i in range(0, len(history), 2)]
         return history, query
 
     def get_answer_at_once(self):
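_get_glm_style_input now fails loudly (with the offending history in the message) if the stored turns do not pair up, before folding the flat, alternating history into the [user, assistant] pairs that ChatGLM's chat API expects. A small standalone illustration of that conversion, with made-up contents:

```python
# flat history as stored in self.history: alternating user/assistant contents,
# with the pending user query already popped off the end
history = ["Hi", "Hello! How can I help?", "Explain MPS", "MPS is Apple's GPU backend."]

assert (
    len(history) % 2 == 0
), f"History should be even length. current history is: {history}"

pairs = [[history[i], history[i + 1]] for i in range(0, len(history), 2)]
# -> [["Hi", "Hello! How can I help?"],
#     ["Explain MPS", "MPS is Apple's GPU backend."]]
print(pairs)
```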
@@ -257,42 +261,48 @@ class ChatGLM_Client(BaseLLMModel):
 
     def get_answer_stream_iter(self):
         history, query = self._get_glm_style_input()
-        for response, history in self.model.stream_chat(self.tokenizer, query, history, max_length=self.token_upper_limit, top_p=self.top_p,
-                                                         temperature=self.temperature):
+        for response, history in self.model.stream_chat(
+            self.tokenizer,
+            query,
+            history,
+            max_length=self.token_upper_limit,
+            top_p=self.top_p,
+            temperature=self.temperature,
+        ):
             yield response
 
+
 @dataclass
 class ChatbotArguments:
     pass
 
+
 class LLaMA_Client(BaseLLMModel):
     def __init__(
         self,
         model_name,
-        lora_path = None,
+        lora_path=None,
     ) -> None:
-        super().__init__(
-            model_name=model_name
-        )
+        super().__init__(model_name=model_name)
         self.max_generation_token = 1000
         pipeline_name = "inferencer"
         PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
 
-        parser = HfArgumentParser((
-            ModelArguments,
-            PipelineArguments,
-            ChatbotArguments,
-        ))
-        model_args, pipeline_args, chatbot_args = (
-            parser.parse_args_into_dataclasses()
+        parser = HfArgumentParser(
+            (
+                ModelArguments,
+                PipelineArguments,
+                ChatbotArguments,
+            )
         )
+        model_args, pipeline_args, chatbot_args = parser.parse_args_into_dataclasses()
 
-        with open (pipeline_args.deepspeed, "r") as f:
+        with open(pipeline_args.deepspeed, "r") as f:
            ds_config = json.load(f)
 
         self.model = AutoModel.get_model(
             model_args,
-            tune_strategy='none',
+            tune_strategy="none",
             ds_config=ds_config,
         )
 
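The stream_chat call above is only reformatted into one argument per line, but it is worth seeing in isolation: the method, provided by ChatGLM's trust_remote_code model class, is a generator that yields the growing response together with the updated history on every step, which is why get_answer_stream_iter can simply re-yield response. A usage sketch with illustrative sampling values (the client passes self.token_upper_limit, self.top_p and self.temperature instead):

```python
from transformers import AutoModel, AutoTokenizer

# load as in ChatGLM_Client.__init__ (CPU/float here to keep the sketch simple)
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float().eval()

history = [["Hi", "Hello! How can I help?"]]   # [user, assistant] pairs
query = "What does MPS acceleration change?"

# each iteration yields the answer so far plus the updated history
for response, history in model.stream_chat(
    tokenizer,
    query,
    history,
    max_length=2048,     # self.token_upper_limit in the client
    top_p=0.7,           # self.top_p
    temperature=0.95,    # self.temperature
):
    print(response)
```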
 
@@ -323,14 +333,12 @@ class LLaMA_Client(BaseLLMModel):
         context = "\n".join(history)
         return context
 
-
     def get_answer_at_once(self):
         context = self._get_llama_style_input()
 
-        input_dataset = self.dataset.from_dict({
-            "type": "text_only",
-            "instances": [ { "text": context } ]
-        })
+        input_dataset = self.dataset.from_dict(
+            {"type": "text_only", "instances": [{"text": context}]}
+        )
 
         output_dataset = self.inferencer.inference(
             model=self.model,
@@ -347,7 +355,7 @@ class LLaMA_Client(BaseLLMModel):
         response += self.end_string
         index = response.index(self.end_string)
 
-        response = response[:index + 1]
+        response = response[: index + 1]
         return response, len(response)
 
     def get_answer_stream_iter(self):
@@ -355,34 +363,44 @@ class LLaMA_Client(BaseLLMModel):
         yield response
 
 
+class ModelManager:
+    def __init__(self, **kwargs) -> None:
+        self.model, self.msg = self.get_model(**kwargs)
 
-def get_model(
-    model_name, access_key=None, temperature=None, top_p=None, system_prompt=None
-) -> BaseLLMModel:
-    msg = f"模型设置为了: {model_name}"
-    logging.info(msg)
-    model_type = ModelType.get_type(model_name)
-    print(model_type.name)
-    if model_type == ModelType.OpenAI:
-        model = OpenAIClient(
-            model_name=model_name,
-            api_key=access_key,
-            system_prompt=system_prompt,
-            temperature=temperature,
-            top_p=top_p,
-        )
-    elif model_type == ModelType.ChatGLM:
-        model = ChatGLM_Client(model_name)
-    return model, msg
+    def get_model(
+        self,
+        model_name,
+        access_key=None,
+        temperature=None,
+        top_p=None,
+        system_prompt=None,
+    ) -> BaseLLMModel:
+        msg = f"模型设置为了: {model_name}"
+        logging.info(msg)
+        model_type = ModelType.get_type(model_name)
+        print(model_type.name)
+        if model_type == ModelType.OpenAI:
+            model = OpenAIClient(
+                model_name=model_name,
+                api_key=access_key,
+                system_prompt=system_prompt,
+                temperature=temperature,
+                top_p=top_p,
+            )
+        elif model_type == ModelType.ChatGLM:
+            model = ChatGLM_Client(model_name)
+        return model, msg
 
 
 if __name__ == "__main__":
     with open("config.json", "r") as f:
         openai_api_key = cjson.load(f)["openai_api_key"]
+    # set logging level to debug
+    logging.basicConfig(level=logging.DEBUG)
     # client, _ = get_model("gpt-3.5-turbo", openai_api_key)
     client, _ = get_model("chatglm-6b-int4")
     chatbot = []
-    stream = True
+    stream = False
     # 测试账单功能
     logging.info(colorama.Back.GREEN + "测试账单功能" + colorama.Back.RESET)
     logging.info(client.billing_info())
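The module-level get_model() factory is folded into the new ModelManager class, which builds the client on construction and keeps both the instance and the status message. A hedged usage sketch; the import path is assumed from the modules/ layout shown in this commit:

```python
import logging

# assumed import path, based on the modules/ layout shown in this commit
from modules.models import ModelManager

manager = ModelManager(model_name="chatglm-6b-int4")  # takes the ChatGLM branch; no API key needed
client = manager.model                                # a BaseLLMModel subclass (ChatGLM_Client here)
logging.info(manager.msg)                             # "模型设置为了: chatglm-6b-int4"
logging.info(client.billing_info())
```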
 