johnsmith253325 committed
Commit 52cd289
1 Parent(s): 7691698

feat: add initial LLaMA.cpp support

.gitignore CHANGED
@@ -141,6 +141,7 @@ api_key.txt
 config.json
 auth.json
 .models/
+models/*
 lora/
 .idea
 templates/*
modules/models/{azure.py → Azure.py} RENAMED
File without changes
modules/models/LLaMA.py CHANGED
@@ -3,11 +3,40 @@ from __future__ import annotations
 import json
 import os
 
+from huggingface_hub import hf_hub_download
+from llama_cpp import Llama
+
 from ..index_func import *
 from ..presets import *
 from ..utils import *
 from .base_model import BaseLLMModel
 
+import json
+from llama_cpp import Llama
+from huggingface_hub import hf_hub_download
+
+def download(repo_id, filename, retry=10):
+    if os.path.exists("./models/downloaded_models.json"):
+        with open("./models/downloaded_models.json", "r") as f:
+            downloaded_models = json.load(f)
+        if repo_id in downloaded_models:
+            return downloaded_models[repo_id]["path"]
+    else:
+        downloaded_models = {}
+    while retry > 0:
+        try:
+            model_path = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir="models", resume_download=True)
+            downloaded_models[repo_id] = {"path": model_path}
+            with open("./models/downloaded_models.json", "w") as f:
+                json.dump(downloaded_models, f)
+            break
+        except:
+            print("Error downloading model, retrying...")
+            retry -= 1
+    if retry == 0:
+        raise Exception("Error downloading model, please try again later.")
+    return model_path
+
 
 class LLaMA_Client(BaseLLMModel):
     def __init__(
@@ -17,51 +46,28 @@ class LLaMA_Client(BaseLLMModel):
         user_name=""
     ) -> None:
         super().__init__(model_name=model_name, user=user_name)
-        from lmflow.args import (DatasetArguments, InferencerArguments,
-                                 ModelArguments)
-        from lmflow.datasets.dataset import Dataset
-        from lmflow.models.auto_model import AutoModel
-        from lmflow.pipeline.auto_pipeline import AutoPipeline
 
         self.max_generation_token = 1000
         self.end_string = "\n\n"
         # We don't need input data
-        data_args = DatasetArguments(dataset_path=None)
-        self.dataset = Dataset(data_args)
+        path_to_model = download(MODEL_METADATA[model_name]["repo_id"], MODEL_METADATA[model_name]["filelist"][0])
        self.system_prompt = ""
 
-        global LLAMA_MODEL, LLAMA_INFERENCER
-        if LLAMA_MODEL is None or LLAMA_INFERENCER is None:
-            model_path = None
-            if os.path.exists("models"):
-                model_dirs = os.listdir("models")
-                if model_name in model_dirs:
-                    model_path = f"models/{model_name}"
-            if model_path is not None:
-                model_source = model_path
-            else:
-                model_source = f"decapoda-research/{model_name}"
+        global LLAMA_MODEL
+        if LLAMA_MODEL is None:
+            LLAMA_MODEL = Llama(model_path=path_to_model)
+        # model_path = None
+        # if os.path.exists("models"):
+        #     model_dirs = os.listdir("models")
+        #     if model_name in model_dirs:
+        #         model_path = f"models/{model_name}"
+        # if model_path is not None:
+        #     model_source = model_path
+        # else:
+        #     model_source = f"decapoda-research/{model_name}"
                 # raise Exception(f"models目录下没有这个模型: {model_name}")
-            if lora_path is not None:
-                lora_path = f"lora/{lora_path}"
-            model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None,
-                                        use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
-            pipeline_args = InferencerArguments(
-                local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
-
-            with open(pipeline_args.deepspeed, "r", encoding="utf-8") as f:
-                ds_config = json.load(f)
-            LLAMA_MODEL = AutoModel.get_model(
-                model_args,
-                tune_strategy="none",
-                ds_config=ds_config,
-            )
-            LLAMA_INFERENCER = AutoPipeline.get_pipeline(
-                pipeline_name="inferencer",
-                model_args=model_args,
-                data_args=data_args,
-                pipeline_args=pipeline_args,
-            )
+        # if lora_path is not None:
+        #     lora_path = f"lora/{lora_path}"
 
     def _get_llama_style_input(self):
         history = []
@@ -79,38 +85,14 @@ class LLaMA_Client(BaseLLMModel):
 
     def get_answer_at_once(self):
         context = self._get_llama_style_input()
-
-        input_dataset = self.dataset.from_dict(
-            {"type": "text_only", "instances": [{"text": context}]}
-        )
-
-        output_dataset = LLAMA_INFERENCER.inference(
-            model=LLAMA_MODEL,
-            dataset=input_dataset,
-            max_new_tokens=self.max_generation_token,
-            temperature=self.temperature,
-        )
-
-        response = output_dataset.to_dict()["instances"][0]["text"]
+        response = LLAMA_MODEL(context, max_tokens=self.max_generation_token, stop=[], echo=False, stream=False)
         return response, len(response)
 
     def get_answer_stream_iter(self):
         context = self._get_llama_style_input()
+        iter = LLAMA_MODEL(context, max_tokens=self.max_generation_token, stop=[], echo=False, stream=True)
        partial_text = ""
-        step = 1
-        for _ in range(0, self.max_generation_token, step):
-            input_dataset = self.dataset.from_dict(
-                {"type": "text_only", "instances": [
-                    {"text": context + partial_text}]}
-            )
-            output_dataset = LLAMA_INFERENCER.inference(
-                model=LLAMA_MODEL,
-                dataset=input_dataset,
-                max_new_tokens=step,
-                temperature=self.temperature,
-            )
-            response = output_dataset.to_dict()["instances"][0]["text"]
-            if response == "" or response == self.end_string:
-                break
+        for i in iter:
+            response = i["choices"][0]["text"]
             partial_text += response
             yield partial_text
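
For orientation, below is a minimal, self-contained sketch of the llama.cpp code path this file now depends on: fetch a GGUF file with hf_hub_download, load it with llama_cpp.Llama, and read completions the way get_answer_at_once and get_answer_stream_iter do. The repo id and filename reuse the Llama-2-7B entry that presets.py registers below; the prompt string is illustrative only, not the client's actual prompt format.

# Sketch only: assumes llama-cpp-python and huggingface_hub are installed.
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

# Resolve the GGUF weights locally, cached under ./models like download() above.
model_path = hf_hub_download(
    repo_id="TheBloke/Llama-2-7B-GGUF",
    filename="llama-2-7b.Q6_K.gguf",
    cache_dir="models",
)
llm = Llama(model_path=model_path)

prompt = "Human: Hello\nAssistant: "  # illustrative prompt format

# Non-streaming, as in get_answer_at_once(): the generated text sits in
# result["choices"][0]["text"].
result = llm(prompt, max_tokens=128, stream=False)
print(result["choices"][0]["text"])

# Streaming, as in get_answer_stream_iter(): each chunk contributes one
# fragment of text that is appended to the partial reply.
partial_text = ""
for chunk in llm(prompt, max_tokens=128, stream=True):
    partial_text += chunk["choices"][0]["text"]
print(partial_text)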
modules/presets.py CHANGED
@@ -83,10 +83,7 @@ LOCAL_MODELS = [
     "chatglm2-6b-int4",
     "StableLM",
     "MOSS",
-    "llama-7b-hf",
-    "llama-13b-hf",
-    "llama-30b-hf",
-    "llama-65b-hf",
+    "Llama-2-7B",
 ]
 
 if os.environ.get('HIDE_LOCAL_MODELS', 'false') == 'true':
@@ -134,8 +131,8 @@ REPLY_LANGUAGES = [
 ]
 
 HISTORY_NAME_METHODS = [
-    i18n("根据日期时间"),
-    i18n("第一条提问"),
+    i18n("根据日期时间"),
+    i18n("第一条提问"),
     i18n("模型自动总结(消耗tokens)"),
 ]
 
@@ -266,3 +263,10 @@ small_and_beautiful_theme = gr.themes.Soft(
     chatbot_code_background_color_dark="*neutral_950",
 )
 
+# Additional metadate for local models
+MODEL_METADATA = {
+    "Llama-2-7B":{
+        "repo_id": "TheBloke/Llama-2-7B-GGUF",
+        "filelist": ["llama-2-7b.Q6_K.gguf"],
+    }
+}
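
To make the wiring explicit: LLaMA_Client.__init__ looks a local model's name up in MODEL_METADATA and hands repo_id plus the first filelist entry to download(). A small sketch of that lookup follows; resolve_model is a hypothetical helper added purely for illustration, as the real client does this lookup inline.

# Sketch of how MODEL_METADATA feeds download(); resolve_model is a
# hypothetical helper, not part of the repository.
MODEL_METADATA = {
    "Llama-2-7B": {
        "repo_id": "TheBloke/Llama-2-7B-GGUF",
        "filelist": ["llama-2-7b.Q6_K.gguf"],
    },
}

def resolve_model(model_name: str) -> tuple[str, str]:
    """Return (repo_id, filename) in the form download() expects."""
    meta = MODEL_METADATA[model_name]
    return meta["repo_id"], meta["filelist"][0]

repo_id, filename = resolve_model("Llama-2-7B")
# path_to_model = download(repo_id, filename)  # as called in LLaMA_Client.__init__
print(repo_id, filename)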