masanorihirano commited on
Commit
d91928f
·
1 Parent(s): dc15b84
Files changed (2) hide show
  1. app.py +17 -15
  2. pyproject.toml +3 -2
app.py CHANGED
@@ -7,8 +7,13 @@ from typing import Tuple
7
 
8
  import gradio as gr
9
  import torch
 
 
10
  from huggingface_hub import Repository
 
 
11
  from peft import PeftModel
 
12
  from transformers import AutoModelForCausalLM
13
  from transformers import GenerationConfig
14
  from transformers import LlamaTokenizer
@@ -58,35 +63,32 @@ try:
58
  except Exception:
59
  pass
60
 
 
61
  if device == "cuda":
62
  model = AutoModelForCausalLM.from_pretrained(
63
- BASE_MODEL,
64
- load_in_8bit=True,
65
- device_map="auto",
66
  )
67
- model = PeftModel.from_pretrained(model, LORA_WEIGHTS, load_in_8bit=True,)
68
  elif device == "mps":
69
  model = AutoModelForCausalLM.from_pretrained(
70
  BASE_MODEL,
71
  device_map={"": device},
72
  load_in_8bit=True,
73
- )
74
- model = PeftModel.from_pretrained(
75
- model,
76
- LORA_WEIGHTS,
77
- device_map={"": device},
78
- load_in_8bit=True,
79
  )
80
  else:
81
  model = AutoModelForCausalLM.from_pretrained(
82
- BASE_MODEL, device_map={"": device},load_in_8bit=True, low_cpu_mem_usage=True
83
- )
84
- model = PeftModel.from_pretrained(
85
- model,
86
- LORA_WEIGHTS,
87
  device_map={"": device},
88
  load_in_8bit=True,
 
 
89
  )
 
 
 
 
 
 
90
 
91
 
92
  def generate_prompt(instruction: str, input: Optional[str] = None):
 
7
 
8
  import gradio as gr
9
  import torch
10
+ from fastchat.serve.inference import compress_module
11
+ from fastchat.serve.inference import raise_warning_for_old_weights
12
  from huggingface_hub import Repository
13
+ from huggingface_hub import hf_hub_download
14
+ from peft import LoraConfig
15
  from peft import PeftModel
16
+ from peft import set_peft_model_state_dict
17
  from transformers import AutoModelForCausalLM
18
  from transformers import GenerationConfig
19
  from transformers import LlamaTokenizer
 
63
  except Exception:
64
  pass
65
 
66
+ checkpoint_name = hf_hub_download(repo_id=LORA_WEIGHTS, filename="adapter_model.bin", use_auth_token=HF_TOKEN)
67
  if device == "cuda":
68
  model = AutoModelForCausalLM.from_pretrained(
69
+ BASE_MODEL, load_in_8bit=True, device_map="auto", torch_dtype=torch.float16
 
 
70
  )
 
71
  elif device == "mps":
72
  model = AutoModelForCausalLM.from_pretrained(
73
  BASE_MODEL,
74
  device_map={"": device},
75
  load_in_8bit=True,
76
+ torch_dtype=torch.float16,
 
 
 
 
 
77
  )
78
  else:
79
  model = AutoModelForCausalLM.from_pretrained(
80
+ BASE_MODEL,
 
 
 
 
81
  device_map={"": device},
82
  load_in_8bit=True,
83
+ low_cpu_mem_usage=True,
84
+ torch_dtype=torch.float16,
85
  )
86
+ adapters_weights = torch.load(checkpoint_name)
87
+ set_peft_model_state_dict(model, adapters_weights)
88
+ raise_warning_for_old_weights(BASE_MODEL, model)
89
+ compress_module(model, device)
90
+ if device == "cuda" or device == "mps":
91
+ model = model.to(device)
92
 
93
 
94
  def generate_prompt(instruction: str, input: Optional[str] = None):
pyproject.toml CHANGED
@@ -9,13 +9,14 @@ readme = "README.md"
9
  [tool.poetry.dependencies]
10
  python = "^3.9"
11
  peft = "^0.3.0"
12
- transformers = {git = "https://github.com/huggingface/transformers.git", branch = "main"}
13
- gradio = "^3.32.0"
14
  torch = "^2.0.1"
15
  huggingface-hub = "^0.14.1"
16
  sentencepiece = "^0.1.99"
17
  bitsandbytes = "^0.38.1"
18
  accelerate = "^0.19.0"
 
 
19
 
20
 
21
  [tool.poetry.group.dev.dependencies]
 
9
  [tool.poetry.dependencies]
10
  python = "^3.9"
11
  peft = "^0.3.0"
12
+ gradio = "^3.23.0"
 
13
  torch = "^2.0.1"
14
  huggingface-hub = "^0.14.1"
15
  sentencepiece = "^0.1.99"
16
  bitsandbytes = "^0.38.1"
17
  accelerate = "^0.19.0"
18
+ fschat = "^0.2.3"
19
+ transformers = "^4.29.2"
20
 
21
 
22
  [tool.poetry.group.dev.dependencies]