Tuchuanhuhuhu committed

Commit ddd1766
1 Parent(s): 469aa95

LLaMA + LoRA works now

Files changed (2):
  1. modules/base_model.py +2 -0
  2. modules/models.py +28 -31
modules/base_model.py CHANGED
@@ -70,6 +70,7 @@ class BaseLLMModel:
         self.interrupted = False
         self.system_prompt = system_prompt
         self.api_key = None
+        self.need_api_key = False

         self.temperature = temperature
         self.top_p = top_p
@@ -263,6 +264,7 @@ class BaseLLMModel:
         display_reference = ""

         if (
+            self.need_api_key and
             self.api_key is None
             and not shared.state.multi_api_key
         ):
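On the base-model side, every model now carries a need_api_key flag that defaults to False, and the key check only fires for models that opt in; OpenAIClient (below) sets it to True, while local models such as LLaMA_Client inherit the default and are never blocked on a missing key. A minimal sketch of the pattern, where missing_key is a simplified, hypothetical stand-in for the check inside the real method and multi_api_key stands in for shared.state.multi_api_key:

class BaseLLMModel:
    def __init__(self, system_prompt=""):
        self.system_prompt = system_prompt
        self.api_key = None
        self.need_api_key = False  # local models keep the default

    def missing_key(self, multi_api_key: bool = False) -> bool:
        # The gate added in this commit: only models that opted in
        # via need_api_key can be blocked for lacking a key.
        return self.need_api_key and self.api_key is None and not multi_api_key

class OpenAIClient(BaseLLMModel):
    def __init__(self, api_key=None):
        super().__init__()
        self.api_key = api_key
        self.need_api_key = True  # remote, metered API: a key is required

class LLaMA_Client(BaseLLMModel):
    pass  # runs locally; need_api_key stays False

assert OpenAIClient().missing_key() is True
assert LLaMA_Client().missing_key() is False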
modules/models.py CHANGED
@@ -42,6 +42,7 @@ class OpenAIClient(BaseLLMModel):
             system_prompt=system_prompt,
         )
         self.api_key = api_key
+        self.need_api_key = True
         self.headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {self.api_key}",
@@ -276,7 +277,7 @@ class LLaMA_Client(BaseLLMModel):
         from lmflow.datasets.dataset import Dataset
         from lmflow.pipeline.auto_pipeline import AutoPipeline
         from lmflow.models.auto_model import AutoModel
-        from lmflow.args import ModelArguments, DatasetArguments, AutoArguments, InferencerArguments
+        from lmflow.args import ModelArguments, DatasetArguments, InferencerArguments
         model_path = None
         if os.path.exists("models"):
             model_dirs = os.listdir("models")
@@ -286,36 +287,12 @@ class LLaMA_Client(BaseLLMModel):
             model_source = model_path
         else:
             raise Exception(f"models目录下没有这个模型: {model_name}")
+        if lora_path is not None:
+            lora_path = f"lora/{lora_path}"
         self.max_generation_token = 1000
         pipeline_name = "inferencer"
-        PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
-
-        """
-        if [ $# -ge 2 ]; then
-            lora_args="--lora_model_path $2"
-        fi
-        CUDA_VISIBLE_DEVICES=2 \
-        deepspeed examples/chatbot.py \
-            --deepspeed configs/ds_config_chatbot.json \
-            --model_name_or_path ${model} \
-            ${lora_args}
-
-        model_args:
-        ModelArguments(model_name_or_path='/home/guest/llm_models/llama/7B', lora_model_path='/home/guest/llm_models/lora/baize-lora-7B', model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None, use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
-        pipeline_args:
-        InferencerArguments(local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
-        """
-
-        # parser = HfArgumentParser(
-        #     (
-        #         ModelArguments,
-        #         PipelineArguments,
-        #         ChatbotArguments,
-        #     )
-        # )
         model_args = ModelArguments(model_name_or_path=model_source, lora_model_path=lora_path, model_type=None, config_overrides=None, config_name=None, tokenizer_name=None, cache_dir=None, use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None, use_lora=False, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_ram_optimized_load=True)
         pipeline_args = InferencerArguments(local_rank=0, random_seed=1, deepspeed='configs/ds_config_chatbot.json', mixed_precision='bf16')
-        # model_args, pipeline_args, chatbot_args = parser.parse_args_into_dataclasses()

         with open(pipeline_args.deepspeed, "r") as f:
             ds_config = json.load(f)
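The two added lines are what the commit title refers to: the LoRA selector passes a bare adapter directory name, and LLaMA_Client now resolves it against the repo-local lora/ folder before handing it to lmflow's ModelArguments; a None value passes through unchanged, so the base model loads without an adapter. The rest of the hunk deletes the scratch notes and the commented-out HfArgumentParser experiment in favor of the two explicit constructor calls. A small sketch of the resolution, with a hypothetical adapter name:

lora_path = "baize-lora-7B"          # hypothetical name chosen in the UI
if lora_path is not None:
    lora_path = f"lora/{lora_path}"  # -> "lora/baize-lora-7B"
# lora_path = None skips the prefixing and disables the adapter:
# ModelArguments(..., lora_model_path=None) loads the base weights only.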
@@ -326,7 +303,7 @@ class LLaMA_Client(BaseLLMModel):
             ds_config=ds_config,
         )

-        # We don't need input data, we will read interactively from stdin
+        # We don't need input data
         data_args = DatasetArguments(dataset_path=None)
         self.dataset = Dataset(data_args)

@@ -379,7 +356,28 @@ class LLaMA_Client(BaseLLMModel):
         return response, len(response)

     def get_answer_stream_iter(self):
-        response, _ = self.get_answer_at_once()
+        context = self._get_llama_style_input()
+
+        input_dataset = self.dataset.from_dict(
+            {"type": "text_only", "instances": [{"text": context}]}
+        )
+
+        output_dataset = self.inferencer.inference(
+            model=self.model,
+            dataset=input_dataset,
+            max_new_tokens=self.max_generation_token,
+            temperature=self.temperature,
+        )
+
+        response = output_dataset.to_dict()["instances"][0]["text"]
+
+        try:
+            index = response.index(self.end_string)
+        except ValueError:
+            response += self.end_string
+            index = response.index(self.end_string)
+
+        response = response[: index + 1]
         yield response

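get_answer_stream_iter no longer proxies get_answer_at_once: it builds the LLaMA-style prompt, runs one lmflow inference, and truncates the output at the model's stop marker. Note that it yields the finished response once rather than streaming token by token, and that the index + 1 slice keeps the whole marker only when end_string is a single character. A self-contained sketch of the truncation step, assuming a newline stop marker:

def truncate_at_end_string(response: str, end_string: str) -> str:
    # Cut at the first stop marker, appending one if the model produced none.
    try:
        index = response.index(end_string)
    except ValueError:
        response += end_string
        index = response.index(end_string)
    # Keeps up to and including the character at `index`; for a
    # multi-character end_string only its first character survives.
    return response[: index + 1]

assert truncate_at_end_string("Hi!\nUser:", "\n") == "Hi!\n"
assert truncate_at_end_string("Hi!", "\n") == "Hi!\n"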
@@ -396,9 +394,7 @@ class ModelManager:
         top_p=None,
         system_prompt=None,
     ) -> BaseLLMModel:
-        print(lora_model_path)
         msg = f"模型设置为了: {model_name}"
-        logging.info(msg)
         model_type = ModelType.get_type(model_name)
         lora_selector_visibility = False
         lora_choices = []
@@ -435,6 +431,7 @@ class ModelManager:
                 pass
             elif model_type == ModelType.Unknown:
                 raise ValueError(f"未知模型: {model_name}")
+            logging.info(msg)
         except Exception as e:
             logging.error(e)
             msg = f"{STANDARD_ERROR_MSG}: {e}"
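In ModelManager.get_model, the stray print(lora_model_path) debug line is removed and logging.info(msg) moves below the model-construction dispatch, so the success message is only logged when no exception was raised; on failure the except branch logs the error and rewrites msg instead. A condensed, hypothetical sketch of the resulting control flow, where build_model stands in for the if/elif dispatch above:

import logging

STANDARD_ERROR_MSG = "Error"  # stand-in for the project's constant

def build_model(model_name: str):
    # Hypothetical stand-in for the ModelType dispatch chain.
    if model_name == "unknown":
        raise ValueError(f"Unknown model: {model_name}")
    return object()

def get_model(model_name: str):
    msg = f"Model set to: {model_name}"
    try:
        model = build_model(model_name)
        logging.info(msg)  # moved here: logged only on success
        return model, msg
    except Exception as e:
        logging.error(e)
        msg = f"{STANDARD_ERROR_MSG}: {e}"
        return None, msg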