Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu
commited on
Commit
•
7620bdc
1
Parent(s):
461db8d
加入了llama模型支持
Browse files- ChuanhuChatbot.py +5 -1
- configs/ds_config_chatbot.json +17 -0
- modules/base_model.py +7 -3
- modules/models.py +84 -26
- modules/presets.py +9 -1
- modules/utils.py +4 -0
ChuanhuChatbot.py
CHANGED
@@ -80,6 +80,9 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
80 |
model_select_dropdown = gr.Dropdown(
|
81 |
label="选择模型", choices=MODELS, multiselect=False, value=MODELS[DEFAULT_MODEL], interactive=True
|
82 |
)
|
|
|
|
|
|
|
83 |
use_streaming_checkbox = gr.Checkbox(
|
84 |
label="实时传输回答", value=True, visible=ENABLE_STREAMING_OPTION
|
85 |
)
|
@@ -350,7 +353,8 @@ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
|
|
350 |
# LLM Models
|
351 |
keyTxt.change(current_model.value.set_key, keyTxt, [status_display]).then(**get_usage_args)
|
352 |
keyTxt.submit(**get_usage_args)
|
353 |
-
model_select_dropdown.change(current_model.value.get_model, [model_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display], show_progress=True)
|
|
|
354 |
|
355 |
# Template
|
356 |
systemPromptTxt.change(current_model.value.set_system_prompt, [systemPromptTxt], None)
|
|
|
80 |
model_select_dropdown = gr.Dropdown(
|
81 |
label="选择模型", choices=MODELS, multiselect=False, value=MODELS[DEFAULT_MODEL], interactive=True
|
82 |
)
|
83 |
+
lora_select_dropdown = gr.Dropdown(
|
84 |
+
label="选择LoRA模型", choices=[], multiselect=False, interactive=True, visible=False
|
85 |
+
)
|
86 |
use_streaming_checkbox = gr.Checkbox(
|
87 |
label="实时传输回答", value=True, visible=ENABLE_STREAMING_OPTION
|
88 |
)
|
|
|
353 |
# LLM Models
|
354 |
keyTxt.change(current_model.value.set_key, keyTxt, [status_display]).then(**get_usage_args)
|
355 |
keyTxt.submit(**get_usage_args)
|
356 |
+
model_select_dropdown.change(current_model.value.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display, lora_select_dropdown], show_progress=True)
|
357 |
+
lora_select_dropdown.change(current_model.value.get_model, [model_select_dropdown, lora_select_dropdown, keyTxt, temperature_slider, top_p_slider, systemPromptTxt], [status_display], show_progress=True)
|
358 |
|
359 |
# Template
|
360 |
systemPromptTxt.change(current_model.value.set_system_prompt, [systemPromptTxt], None)
|
configs/ds_config_chatbot.json
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"fp16": {
|
3 |
+
"enabled": false
|
4 |
+
},
|
5 |
+
"bf16": {
|
6 |
+
"enabled": true
|
7 |
+
},
|
8 |
+
"comms_logger": {
|
9 |
+
"enabled": false,
|
10 |
+
"verbose": false,
|
11 |
+
"prof_all": false,
|
12 |
+
"debug": false
|
13 |
+
},
|
14 |
+
"steps_per_print": 20000000000000000,
|
15 |
+
"train_micro_batch_size_per_gpu": 1,
|
16 |
+
"wall_clock_breakdown": false
|
17 |
+
}
|
modules/base_model.py
CHANGED
@@ -24,6 +24,7 @@ from .config import retrieve_proxy
|
|
24 |
|
25 |
|
26 |
class ModelType(Enum):
|
|
|
27 |
OpenAI = 0
|
28 |
ChatGLM = 1
|
29 |
LLaMA = 2
|
@@ -31,12 +32,15 @@ class ModelType(Enum):
|
|
31 |
@classmethod
|
32 |
def get_type(cls, model_name: str):
|
33 |
model_type = None
|
34 |
-
|
|
|
35 |
model_type = ModelType.OpenAI
|
36 |
-
elif "chatglm" in
|
37 |
model_type = ModelType.ChatGLM
|
38 |
-
|
39 |
model_type = ModelType.LLaMA
|
|
|
|
|
40 |
return model_type
|
41 |
|
42 |
|
|
|
24 |
|
25 |
|
26 |
class ModelType(Enum):
|
27 |
+
Unknown = -1
|
28 |
OpenAI = 0
|
29 |
ChatGLM = 1
|
30 |
LLaMA = 2
|
|
|
32 |
@classmethod
|
33 |
def get_type(cls, model_name: str):
|
34 |
model_type = None
|
35 |
+
model_name_lower = model_name.lower()
|
36 |
+
if "gpt" in model_name_lower:
|
37 |
model_type = ModelType.OpenAI
|
38 |
+
elif "chatglm" in model_name_lower:
|
39 |
model_type = ModelType.ChatGLM
|
40 |
+
elif "llama" in model_name_lower:
|
41 |
model_type = ModelType.LLaMA
|
42 |
+
else:
|
43 |
+
model_type = ModelType.Unknown
|
44 |
return model_type
|
45 |
|
46 |
|
modules/models.py
CHANGED
@@ -13,11 +13,6 @@ import platform
|
|
13 |
from dataclasses import dataclass, field
|
14 |
from transformers import HfArgumentParser
|
15 |
|
16 |
-
from lmflow.datasets.dataset import Dataset
|
17 |
-
from lmflow.pipeline.auto_pipeline import AutoPipeline
|
18 |
-
from lmflow.models.auto_model import AutoModel
|
19 |
-
from lmflow.args import ModelArguments, DatasetArguments, AutoArguments
|
20 |
-
|
21 |
from tqdm import tqdm
|
22 |
import colorama
|
23 |
from duckduckgo_search import ddg
|
@@ -203,12 +198,13 @@ class OpenAIClient(BaseLLMModel):
|
|
203 |
|
204 |
|
205 |
class ChatGLM_Client(BaseLLMModel):
|
206 |
-
def __init__(self, model_name
|
207 |
super().__init__(model_name=model_name)
|
208 |
from transformers import AutoTokenizer, AutoModel
|
209 |
import torch
|
210 |
|
211 |
system_name = platform.system()
|
|
|
212 |
if os.path.exists("models"):
|
213 |
model_dirs = os.listdir("models")
|
214 |
if model_name in model_dirs:
|
@@ -285,18 +281,49 @@ class LLaMA_Client(BaseLLMModel):
|
|
285 |
lora_path=None,
|
286 |
) -> None:
|
287 |
super().__init__(model_name=model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
self.max_generation_token = 1000
|
289 |
pipeline_name = "inferencer"
|
290 |
PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
|
291 |
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
|
301 |
with open(pipeline_args.deepspeed, "r") as f:
|
302 |
ds_config = json.load(f)
|
@@ -377,23 +404,54 @@ class ModelManager:
|
|
377 |
top_p=None,
|
378 |
system_prompt=None,
|
379 |
) -> BaseLLMModel:
|
|
|
380 |
msg = f"模型设置为了: {model_name}"
|
381 |
logging.info(msg)
|
382 |
model_type = ModelType.get_type(model_name)
|
|
|
|
|
|
|
383 |
if model_type != ModelType.OpenAI:
|
384 |
config.local_embedding = True
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
|
398 |
def predict(self, *args):
|
399 |
iter = self.model.predict(*args)
|
|
|
13 |
from dataclasses import dataclass, field
|
14 |
from transformers import HfArgumentParser
|
15 |
|
|
|
|
|
|
|
|
|
|
|
16 |
from tqdm import tqdm
|
17 |
import colorama
|
18 |
from duckduckgo_search import ddg
|
|
|
198 |
|
199 |
|
200 |
class ChatGLM_Client(BaseLLMModel):
|
201 |
+
def __init__(self, model_name) -> None:
|
202 |
super().__init__(model_name=model_name)
|
203 |
from transformers import AutoTokenizer, AutoModel
|
204 |
import torch
|
205 |
|
206 |
system_name = platform.system()
|
207 |
+
model_path=None
|
208 |
if os.path.exists("models"):
|
209 |
model_dirs = os.listdir("models")
|
210 |
if model_name in model_dirs:
|
|
|
281 |
lora_path=None,
|
282 |
) -> None:
|
283 |
super().__init__(model_name=model_name)
|
284 |
+
from lmflow.datasets.dataset import Dataset
|
285 |
+
from lmflow.pipeline.auto_pipeline import AutoPipeline
|
286 |
+
from lmflow.models.auto_model import AutoModel
|
287 |
+
from lmflow.args import ModelArguments, DatasetArguments, AutoArguments, InferencerArguments
|
288 |
+
model_path = None
|
289 |
+
if os.path.exists("models"):
|
290 |
+
model_dirs = os.listdir("models")
|
291 |
+
if model_name in model_dirs:
|
292 |
+
model_path = f"models/{model_name}"
|
293 |
+
if model_path is not None:
|
294 |
+
model_source = model_path
|
295 |
+
else:
|
296 |
+
raise Exception(f"models目录下没有这个模型: {model_name}")
|
297 |
self.max_generation_token = 1000
|
298 |
pipeline_name = "inferencer"
|
299 |
PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
|
300 |
|
301 |
+
"""
|
302 |
+
if [ $# -ge 2 ]; then
|
303 |
+
lora_args="--lora_model_path $2"
|
304 |
+
fi
|
305 |
+
CUDA_VISIBLE_DEVICES=2 \
|
306 |
+
deepspeed examples/chatbot.py \
|
307 |
+
--deepspeed configs/ds_config_chatbot.json \
|
308 |
+
--model_name_or_path ${model} \
|
309 |
+
${lora_args}
|
310 |
+
|
311 |
+
model_args:
|
312 |
+
ModelArguments(model_name_or_path='/home/guest/llm_models/llama/7B', lora_model_path='/home/guest/llm_models/lora/baize-lora-7B', model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None, use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
|
313 |
+
pipeline_args:
|
314 |
+
InferencerArguments(local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
|
315 |
+
"""
|
316 |
+
|
317 |
+
# parser = HfArgumentParser(
|
318 |
+
# (
|
319 |
+
# ModelArguments,
|
320 |
+
# PipelineArguments,
|
321 |
+
# ChatbotArguments,
|
322 |
+
# )
|
323 |
+
# )
|
324 |
+
model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None, use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
|
325 |
+
pipeline_args = InferencerArguments(local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
|
326 |
+
# model_args, pipeline_args, chatbot_args = parser.parse_args_into_dataclasses()
|
327 |
|
328 |
with open(pipeline_args.deepspeed, "r") as f:
|
329 |
ds_config = json.load(f)
|
|
|
404 |
top_p=None,
|
405 |
system_prompt=None,
|
406 |
) -> BaseLLMModel:
|
407 |
+
print(lora_model_path)
|
408 |
msg = f"模型设置为了: {model_name}"
|
409 |
logging.info(msg)
|
410 |
model_type = ModelType.get_type(model_name)
|
411 |
+
lora_selector_visibility = False
|
412 |
+
lora_choices = []
|
413 |
+
dont_change_lora_selector = False
|
414 |
if model_type != ModelType.OpenAI:
|
415 |
config.local_embedding = True
|
416 |
+
model = None
|
417 |
+
try:
|
418 |
+
if model_type == ModelType.OpenAI:
|
419 |
+
model = OpenAIClient(
|
420 |
+
model_name=model_name,
|
421 |
+
api_key=access_key,
|
422 |
+
system_prompt=system_prompt,
|
423 |
+
temperature=temperature,
|
424 |
+
top_p=top_p,
|
425 |
+
)
|
426 |
+
elif model_type == ModelType.ChatGLM:
|
427 |
+
model = ChatGLM_Client(model_name)
|
428 |
+
elif model_type == ModelType.LLaMA and lora_model_path == "":
|
429 |
+
msg = "现在请选择LoRA模型"
|
430 |
+
logging.info(msg)
|
431 |
+
lora_selector_visibility = True
|
432 |
+
if os.path.isdir("lora"):
|
433 |
+
lora_choices = get_file_names("lora", plain=True, filetypes=[""])
|
434 |
+
lora_choices = ["No LoRA"] + lora_choices
|
435 |
+
elif model_type == ModelType.LLaMA and lora_model_path != "":
|
436 |
+
dont_change_lora_selector = True
|
437 |
+
if lora_model_path == "No LoRA":
|
438 |
+
lora_model_path = None
|
439 |
+
msg += " + No LoRA"
|
440 |
+
else:
|
441 |
+
msg += f" + {lora_model_path}"
|
442 |
+
model = LLaMA_Client(model_name, lora_model_path)
|
443 |
+
pass
|
444 |
+
elif model_type == ModelType.Unknown:
|
445 |
+
raise ValueError(f"未知模型: {model_name}")
|
446 |
+
except Exception as e:
|
447 |
+
logging.error(e)
|
448 |
+
msg = f"{STANDARD_ERROR_MSG}: {e}"
|
449 |
+
if model is not None:
|
450 |
+
self.model = model
|
451 |
+
if dont_change_lora_selector:
|
452 |
+
return msg
|
453 |
+
else:
|
454 |
+
return msg, gr.Dropdown.update(choices=lora_choices, visible=lora_selector_visibility)
|
455 |
|
456 |
def predict(self, *args):
|
457 |
iter = self.model.predict(*args)
|
modules/presets.py
CHANGED
@@ -59,7 +59,15 @@ MODELS = [
|
|
59 |
"gpt-4-32k-0314",
|
60 |
"chatglm-6b",
|
61 |
"chatglm-6b-int4",
|
62 |
-
"chatglm-6b-int4-qe"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
] # 可选的模型
|
64 |
|
65 |
DEFAULT_MODEL = 0 # 默认的模型在MODELS中的序号,从0开始数
|
|
|
59 |
"gpt-4-32k-0314",
|
60 |
"chatglm-6b",
|
61 |
"chatglm-6b-int4",
|
62 |
+
"chatglm-6b-int4-qe",
|
63 |
+
"llama-7b-hf",
|
64 |
+
"llama-7b-hf-int4",
|
65 |
+
"llama-7b-hf-int8",
|
66 |
+
"llama-13b-hf",
|
67 |
+
"llama-13b-hf-int4",
|
68 |
+
"llama-30b-hf",
|
69 |
+
"llama-30b-hf-int4",
|
70 |
+
"llama-65b-hf",
|
71 |
] # 可选的模型
|
72 |
|
73 |
DEFAULT_MODEL = 0 # 默认的模型在MODELS中的序号,从0开始数
|
modules/utils.py
CHANGED
@@ -445,3 +445,7 @@ def get_last_day_of_month(any_day):
|
|
445 |
next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
|
446 |
# subtracting the number of the current day brings us back one month
|
447 |
return next_month - datetime.timedelta(days=next_month.day)
|
|
|
|
|
|
|
|
|
|
445 |
next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
|
446 |
# subtracting the number of the current day brings us back one month
|
447 |
return next_month - datetime.timedelta(days=next_month.day)
|
448 |
+
|
449 |
+
def get_model_source(model_name, alternative_source):
|
450 |
+
if model_name == "gpt2-medium":
|
451 |
+
return "https://huggingface.co/gpt2-medium"
|