Tuchuanhuhuhu committed
Commit 7620bdc
1 Parent(s): 461db8d

Added LLaMA model support

ChuanhuChatbot.py CHANGED
@@ -80,6 +80,9 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     model_select_dropdown = gr.Dropdown(
         label="选择模型", choices=MODELS, multiselect=False, value=MODELS[DEFAULT_MODEL], interactive=True
     )
+    lora_select_dropdown = gr.Dropdown(
+        label="选择LoRA模型", choices=[], multiselect=False, interactive=True, visible=False
+    )
     use_streaming_checkbox = gr.Checkbox(
         label="实时传输回答", value=True, visible=ENABLE_STREAMING_OPTION
     )
@@ -350,7 +353,8 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
     # LLM Models
     keyTxt.change(current_model.value.set_key, keyTxt, [status_display]).then(**get_usage_args)
     keyTxt.submit(**get_usage_args)
-    model_select_dropdown.change(current_model.value.get_model, [model_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display], show_progress=True)
+    model_select_dropdown.change(current_model.value.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display, lora_select_dropdown], show_progress=True)
+    lora_select_dropdown.change(current_model.value.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display], show_progress=True)
 
     # Template
     systemPromptTxt.change(current_model.value.set_system_prompt, [systemPromptTxt], None)
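
Note: this wiring relies on a Gradio pattern where one component's change handler returns a gr.Dropdown.update(...) that repopulates and reveals a second component, and the handler must return exactly as many values as the event declares outputs. A minimal self-contained sketch of the cascade, assuming Gradio 3.x (the API generation this commit's gr.Dropdown.update call belongs to); the handler name and strings below are illustrative, not from the repo:

import gradio as gr

MODELS = ["llama-7b-hf", "gpt-3.5-turbo"]

def on_model_change(model_name):
    # For a LLaMA model, reveal the second dropdown and fill its choices;
    # for anything else, hide it again.
    if "llama" in model_name.lower():
        return "now pick a LoRA", gr.Dropdown.update(choices=["No LoRA", "baize-lora-7B"], visible=True)
    return f"model set to {model_name}", gr.Dropdown.update(choices=[], visible=False)

with gr.Blocks() as demo:
    status_display = gr.Markdown()
    model_select_dropdown = gr.Dropdown(choices=MODELS, value=MODELS[0], interactive=True)
    lora_select_dropdown = gr.Dropdown(choices=[], visible=False, interactive=True)
    # The handler's two return values map onto the two outputs, in order.
    model_select_dropdown.change(on_model_change, [model_select_dropdown], [status_display, lora_select_dropdown])

demo.launch()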
configs/ds_config_chatbot.json ADDED
@@ -0,0 +1,17 @@
+{
+    "fp16": {
+        "enabled": false
+    },
+    "bf16": {
+        "enabled": true
+    },
+    "comms_logger": {
+        "enabled": false,
+        "verbose": false,
+        "prof_all": false,
+        "debug": false
+    },
+    "steps_per_print": 20000000000000000,
+    "train_micro_batch_size_per_gpu": 1,
+    "wall_clock_breakdown": false
+}
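
This file is consumed by plain json.load in the LLaMA_Client changes to modules/models.py below, so it can be sanity-checked without DeepSpeed installed. A minimal sketch:

import json

# Load the config the same way LLaMA_Client does.
with open("configs/ds_config_chatbot.json", "r") as f:
    ds_config = json.load(f)

# fp16 and bf16 are mutually exclusive in DeepSpeed; this config picks bf16.
assert not (ds_config["fp16"]["enabled"] and ds_config["bf16"]["enabled"])
print(ds_config["train_micro_batch_size_per_gpu"])  # -> 1

The enormous steps_per_print value effectively silences DeepSpeed's periodic step logging.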
modules/base_model.py CHANGED
@@ -24,6 +24,7 @@ from .config import retrieve_proxy
 
 
 class ModelType(Enum):
+    Unknown = -1
     OpenAI = 0
     ChatGLM = 1
     LLaMA = 2
@@ -31,12 +32,15 @@ class ModelType(Enum):
     @classmethod
     def get_type(cls, model_name: str):
         model_type = None
-        if "gpt" in model_name.lower():
+        model_name_lower = model_name.lower()
+        if "gpt" in model_name_lower:
             model_type = ModelType.OpenAI
-        elif "chatglm" in model_name.lower():
+        elif "chatglm" in model_name_lower:
             model_type = ModelType.ChatGLM
-        else:
+        elif "llama" in model_name_lower:
             model_type = ModelType.LLaMA
+        else:
+            model_type = ModelType.Unknown
         return model_type
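
A standalone restatement of the new dispatch for a quick check (it mirrors the hunk above rather than importing the module): names matching none of the three keywords now map to Unknown instead of silently falling through to LLaMA.

from enum import Enum

class ModelType(Enum):
    Unknown = -1
    OpenAI = 0
    ChatGLM = 1
    LLaMA = 2

    @classmethod
    def get_type(cls, model_name: str):
        name = model_name.lower()
        if "gpt" in name:
            return cls.OpenAI
        elif "chatglm" in name:
            return cls.ChatGLM
        elif "llama" in name:
            return cls.LLaMA
        return cls.Unknown

assert ModelType.get_type("gpt-4-32k-0314") is ModelType.OpenAI
assert ModelType.get_type("llama-13b-hf") is ModelType.LLaMA
assert ModelType.get_type("alpaca-7b") is ModelType.Unknown  # previously became LLaMA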
modules/models.py CHANGED
@@ -13,11 +13,6 @@ import platform
 from dataclasses import dataclass, field
 from transformers import HfArgumentParser
 
-from lmflow.datasets.dataset import Dataset
-from lmflow.pipeline.auto_pipeline import AutoPipeline
-from lmflow.models.auto_model import AutoModel
-from lmflow.args import ModelArguments, DatasetArguments, AutoArguments
-
 from tqdm import tqdm
 import colorama
 from duckduckgo_search import ddg
@@ -203,12 +198,13 @@ class OpenAIClient(BaseLLMModel):
 
 
 class ChatGLM_Client(BaseLLMModel):
-    def __init__(self, model_name, model_path=None) -> None:
+    def __init__(self, model_name) -> None:
         super().__init__(model_name=model_name)
         from transformers import AutoTokenizer, AutoModel
         import torch
 
         system_name = platform.system()
+        model_path=None
         if os.path.exists("models"):
             model_dirs = os.listdir("models")
             if model_name in model_dirs:
@@ -285,18 +281,49 @@ class LLaMA_Client(BaseLLMModel):
         lora_path=None,
     ) -> None:
         super().__init__(model_name=model_name)
+        from lmflow.datasets.dataset import Dataset
+        from lmflow.pipeline.auto_pipeline import AutoPipeline
+        from lmflow.models.auto_model import AutoModel
+        from lmflow.args import ModelArguments, DatasetArguments, AutoArguments, InferencerArguments
+        model_path = None
+        if os.path.exists("models"):
+            model_dirs = os.listdir("models")
+            if model_name in model_dirs:
+                model_path = f"models/{model_name}"
+        if model_path is not None:
+            model_source = model_path
+        else:
+            raise Exception(f"models目录下没有这个模型: {model_name}")
         self.max_generation_token = 1000
         pipeline_name = "inferencer"
         PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
 
+        """
+        if [ $# -ge 2 ]; then
+          lora_args="--lora_model_path $2"
+        fi
+        CUDA_VISIBLE_DEVICES=2 \
+          deepspeed examples/chatbot.py \
+          --deepspeed configs/ds_config_chatbot.json \
+          --model_name_or_path ${model} \
+          ${lora_args}
+
+        model_args:
+        ModelArguments(model_name_or_path='/home/guest/llm_models/llama/7B', lora_model_path='/home/guest/llm_models/lora/baize-lora-7B', model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None, use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
+        pipeline_args:
+        InferencerArguments(local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
+        """
+
-        parser = HfArgumentParser(
-            (
-                ModelArguments,
-                PipelineArguments,
-                ChatbotArguments,
-            )
-        )
-        model_args, pipeline_args, chatbot_args = parser.parse_args_into_dataclasses()
+        # parser = HfArgumentParser(
+        #     (
+        #         ModelArguments,
+        #         PipelineArguments,
+        #         ChatbotArguments,
+        #     )
+        # )
+        model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None, use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
+        pipeline_args = InferencerArguments(local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
+        # model_args, pipeline_args, chatbot_args = parser.parse_args_into_dataclasses()
 
         with open(pipeline_args.deepspeed, "r") as f:
             ds_config = json.load(f)
@@ -377,23 +404,54 @@ class ModelManager:
         top_p=None,
         system_prompt=None,
     ) -> BaseLLMModel:
+        print(lora_model_path)
         msg = f"模型设置为了: {model_name}"
         logging.info(msg)
         model_type = ModelType.get_type(model_name)
+        lora_selector_visibility = False
+        lora_choices = []
+        dont_change_lora_selector = False
         if model_type != ModelType.OpenAI:
             config.local_embedding = True
-        if model_type == ModelType.OpenAI:
-            model = OpenAIClient(
-                model_name=model_name,
-                api_key=access_key,
-                system_prompt=system_prompt,
-                temperature=temperature,
-                top_p=top_p,
-            )
-        elif model_type == ModelType.ChatGLM:
-            model = ChatGLM_Client(model_name)
-        self.model = model
-        return msg
+        model = None
+        try:
+            if model_type == ModelType.OpenAI:
+                model = OpenAIClient(
+                    model_name=model_name,
+                    api_key=access_key,
+                    system_prompt=system_prompt,
+                    temperature=temperature,
+                    top_p=top_p,
+                )
+            elif model_type == ModelType.ChatGLM:
+                model = ChatGLM_Client(model_name)
+            elif model_type == ModelType.LLaMA and lora_model_path == "":
+                msg = "现在请选择LoRA模型"
+                logging.info(msg)
+                lora_selector_visibility = True
+                if os.path.isdir("lora"):
+                    lora_choices = get_file_names("lora", plain=True, filetypes=[""])
+                lora_choices = ["No LoRA"] + lora_choices
+            elif model_type == ModelType.LLaMA and lora_model_path != "":
+                dont_change_lora_selector = True
+                if lora_model_path == "No LoRA":
+                    lora_model_path = None
+                    msg += " + No LoRA"
+                else:
+                    msg += f" + {lora_model_path}"
+                model = LLaMA_Client(model_name, lora_model_path)
+                pass
+            elif model_type == ModelType.Unknown:
+                raise ValueError(f"未知模型: {model_name}")
+        except Exception as e:
+            logging.error(e)
+            msg = f"{STANDARD_ERROR_MSG}: {e}"
+        if model is not None:
+            self.model = model
+        if dont_change_lora_selector:
+            return msg
+        else:
+            return msg, gr.Dropdown.update(choices=lora_choices, visible=lora_selector_visibility)
 
     def predict(self, *args):
         iter = self.model.predict(*args)
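
get_model now has two return shapes: a bare msg when dont_change_lora_selector is set, and (msg, dropdown update) otherwise, matching the different output lists of the two .change events wired up in ChuanhuChatbot.py. A stripped-down sketch of the two-stage flow, with the model classes and Gradio stubbed out (function name and strings below are illustrative, not from the repo):

import os

def select_model(model_name, lora_model_path=""):
    # Illustrative stand-in for ModelManager.get_model; dropdown updates are
    # modelled as plain dicts so the flow runs without Gradio.
    msg = f"model set to {model_name}"
    if "llama" not in model_name.lower():
        return msg, {"choices": [], "visible": False}  # hide the LoRA selector
    if lora_model_path == "":
        # Stage 1: LLaMA picked but no LoRA chosen yet -> show the selector.
        choices = os.listdir("lora") if os.path.isdir("lora") else []
        return "now pick a LoRA", {"choices": ["No LoRA"] + choices, "visible": True}
    # Stage 2: LoRA picked -> construct the model, leave the selector alone.
    if lora_model_path == "No LoRA":
        return msg + " + No LoRA"  # maps to lora_path=None internally
    return msg + f" + {lora_model_path}"

print(select_model("llama-7b-hf"))             # stage 1: (msg, update dict)
print(select_model("llama-7b-hf", "No LoRA"))  # stage 2: msg only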
modules/presets.py CHANGED
@@ -59,7 +59,15 @@ MODELS = [
     "gpt-4-32k-0314",
     "chatglm-6b",
     "chatglm-6b-int4",
-    "chatglm-6b-int4-qe"
+    "chatglm-6b-int4-qe",
+    "llama-7b-hf",
+    "llama-7b-hf-int4",
+    "llama-7b-hf-int8",
+    "llama-13b-hf",
+    "llama-13b-hf-int4",
+    "llama-30b-hf",
+    "llama-30b-hf-int4",
+    "llama-65b-hf",
 ] # 可选的模型
 
 DEFAULT_MODEL = 0 # 默认的模型在MODELS中的序号,从0开始数
modules/utils.py CHANGED
@@ -445,3 +445,7 @@ def get_last_day_of_month(any_day):
     next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
     # subtracting the number of the current day brings us back one month
     return next_month - datetime.timedelta(days=next_month.day)
+
+def get_model_source(model_name, alternative_source):
+    if model_name == "gpt2-medium":
+        return "https://huggingface.co/gpt2-medium"
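
get_model_source lands in this commit as a one-entry stub, and its alternative_source parameter is unused, which suggests a fallback for models it doesn't know. A hypothetical completion under that assumption (not part of this commit):

# Hypothetical extension, NOT in this commit: keep the known URLs in a
# table and fall back to alternative_source for everything else.
KNOWN_SOURCES = {"gpt2-medium": "https://huggingface.co/gpt2-medium"}

def get_model_source(model_name, alternative_source):
    return KNOWN_SOURCES.get(model_name, alternative_source)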