zetavg committed
Commit 98dba8d • 1 parent: f3676ed

support hf_access_token

Files changed:
- README.md (+2 -1)
- app.py (+3 -0)
- llama_lora/config.py (+8 -2)
- llama_lora/lib/finetune.py (+18 -7)
- llama_lora/models.py (+21 -7)
README.md CHANGED

@@ -78,11 +78,12 @@ setup: |
   echo "Pre-downloading base models so that you won't have to wait for long once the app is ready..."
   python llm_tuner/download_base_model.py --base_model_names='decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j'
 
-# Start the app. `wandb_api_key` and `wandb_project_name` are optional.
+# Start the app. `hf_access_token`, `wandb_api_key` and `wandb_project_name` are optional.
 run: |
   conda activate llm-tuner
   python llm_tuner/app.py \
     --data_dir='/data' \
+    --hf_access_token="$([ -f /data/secrets/hf_access_token.txt ] && cat /data/secrets/hf_access_token.txt | tr -d '\n')" \
    --wandb_api_key="$([ -f /data/secrets/wandb_api_key.txt ] && cat /data/secrets/wandb_api_key.txt | tr -d '\n')" \
     --wandb_project_name='llm-tuner' \
     --timezone='Atlantic/Reykjavik' \
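The `--hf_access_token` value comes from a shell substitution: `[ -f FILE ]` checks that the secret file exists, `cat` reads it, and `tr -d '\n'` strips the trailing newline so the token arrives as a single clean argument; if the file is missing, the substitution expands to an empty string. The same logic expressed in Python, as a hypothetical standalone helper that is not part of this repo:

    from pathlib import Path
    from typing import Optional

    def read_secret(path: str) -> Optional[str]:
        # Mirrors the shell idiom `$([ -f FILE ] && cat FILE | tr -d '\n')`:
        # return the file's contents with newlines stripped, or None if absent.
        p = Path(path)
        if not p.is_file():
            return None
        return p.read_text().replace("\n", "")

    hf_access_token = read_secret("/data/secrets/hf_access_token.txt")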
app.py CHANGED

@@ -29,6 +29,7 @@ def main(
     ui_dev_mode: Union[bool, None] = None,
     wandb_api_key: Union[str, None] = None,
     wandb_project: Union[str, None] = None,
+    hf_access_token: Union[str, None] = None,
     timezone: Union[str, None] = None,
     config: Union[str, None] = None,
 ):
@@ -45,6 +46,8 @@ def main(
 
     :param wandb_api_key: The API key for Weights & Biases. Setting either this or `wandb_project` will enable Weights & Biases.
     :param wandb_project: The default project name for Weights & Biases. Setting either this or `wandb_api_key` will enable Weights & Biases.
+
+    :param hf_access_token: Provide an access token to load private models from Hugging Face Hub. An access token can be created at https://huggingface.co/settings/tokens.
     '''
 
     config_from_file = read_yaml_config(config_path=config)
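Since `main()`'s keyword arguments are exposed as CLI flags (the README's `--flag='value'` style suggests a python-fire entry point), the new parameter only needs to be copied into the `Config` singleton at startup. A minimal sketch of that wiring, with the surrounding details omitted and assumed:

    import fire
    from llama_lora.config import Config

    def main(hf_access_token=None, **other_flags):
        # Copy the CLI value into the process-wide Config singleton so that
        # model-loading code can read it without threading it through calls.
        if hf_access_token is not None:
            Config.hf_access_token = hf_access_token
        # ... launch the UI ...

    if __name__ == '__main__':
        fire.Fire(main)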
llama_lora/config.py CHANGED

@@ -8,19 +8,25 @@ class Config:
     Stores the application configuration. This is a singleton class.
     """
 
+    # Where data is stored
     data_dir: str = ""
-    load_8bit: bool = False
 
+    # Model Related
     default_base_model_name: str = ""
     base_model_choices: Union[List[str], str] = []
-
+    load_8bit: bool = False
     trust_remote_code: bool = False
 
+    # Application Settings
     timezone: Any = pytz.UTC
 
+    # Authentication
     auth_username: Union[str, None] = None
     auth_password: Union[str, None] = None
 
+    # Hugging Face
+    hf_access_token: Union[str, None] = None
+
     # WandB
     enable_wandb: Union[bool, None] = None
     wandb_api_key: Union[str, None] = None
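`Config` keeps its settings as class attributes, so no instance is ever created: writers and readers both touch the class object itself. A minimal sketch of the pattern, assuming the attribute style shown in the diff:

    from typing import Union

    class Config:
        """Stores the application configuration. This is a singleton class."""
        hf_access_token: Union[str, None] = None

    # app.py writes the token once at startup...
    Config.hf_access_token = 'hf_xxx'  # placeholder value
    # ...and llama_lora/models.py reads it at model-load time.
    token = Config.hf_access_token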
llama_lora/lib/finetune.py CHANGED

@@ -71,6 +71,7 @@ def train(
     wandb_watch: str = "false",  # options: false | gradients | all
     wandb_log_model: str = "true",  # options: false | true
     additional_wandb_config: Union[dict, None] = None,
+    hf_access_token: Union[str, None] = None,
     status_message_callback: Any = None,
     params_info_callback: Any = None,
 ):
@@ -88,9 +89,11 @@ def train(
         additional_training_arguments = None
     if isinstance(additional_training_arguments, str):
         try:
-            additional_training_arguments = json.loads(additional_training_arguments)
+            additional_training_arguments = json.loads(
+                additional_training_arguments)
         except Exception as e:
-            raise ValueError(f"Could not parse additional_training_arguments: {e}")
+            raise ValueError(
+                f"Could not parse additional_training_arguments: {e}")
 
     if isinstance(additional_lora_config, str):
         additional_lora_config = additional_lora_config.strip()
@@ -183,11 +186,13 @@ def train(
 
     if status_message_callback:
         if isinstance(base_model, str):
-            cb_result = status_message_callback(f"Preparing model '{base_model}' for training...")
+            cb_result = status_message_callback(
+                f"Preparing model '{base_model}' for training...")
             if cb_result:
                 return
         else:
-            cb_result = status_message_callback("Preparing model for training...")
+            cb_result = status_message_callback(
+                "Preparing model for training...")
             if cb_result:
                 return
 
@@ -201,6 +206,7 @@ def train(
             torch_dtype=torch.float16,
             llm_int8_skip_modules=lora_modules_to_save,
             device_map=device_map,
+            use_auth_token=hf_access_token
         )
         if re.match("[^/]+/llama", model_name):
             print(f"Setting special tokens for LLaMA model {model_name}...")
@@ -213,11 +219,14 @@ def train(
     if isinstance(tokenizer, str):
         tokenizer_name = tokenizer
         try:
-            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer, use_auth_token=hf_access_token
+            )
         except Exception as e:
             if 'LLaMATokenizer' in str(e):
                 tokenizer = LlamaTokenizer.from_pretrained(
                     tokenizer_name,
+                    use_auth_token=hf_access_token
                 )
             else:
                 raise e
@@ -243,7 +252,8 @@ def train(
                 f"Got error while running prepare_model_for_int8_training(model), maybe the model has already be prepared. Original error: {e}.")
 
     if status_message_callback:
-        cb_result = status_message_callback("Preparing PEFT model for training...")
+        cb_result = status_message_callback(
+            "Preparing PEFT model for training...")
         if cb_result:
             return
 
@@ -299,7 +309,8 @@ def train(
         wandb.config.update({"model": {"all_params": all_params, "trainable_params": trainable_params,
                                        "trainable%": 100 * trainable_params / all_params}})
     if params_info_callback:
-        cb_result = params_info_callback(all_params=all_params, trainable_params=trainable_params)
+        cb_result = params_info_callback(
+            all_params=all_params, trainable_params=trainable_params)
         if cb_result:
             return
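The pattern applied throughout is the same: every `from_pretrained` call that may hit the Hugging Face Hub receives `use_auth_token=hf_access_token`, and a `None` token keeps the previous anonymous behavior. A standalone sketch against a hypothetical private repo; note that transformers later deprecated `use_auth_token` in favor of `token`, so current code would use that kwarg instead:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    hf_access_token = 'hf_xxx'  # placeholder; None means anonymous access

    # Both the model and its tokenizer need the token for a private repo.
    model = AutoModelForCausalLM.from_pretrained(
        'some-org/private-llama',  # hypothetical private model id
        use_auth_token=hf_access_token,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        'some-org/private-llama',
        use_auth_token=hf_access_token,
    )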
llama_lora/models.py CHANGED

@@ -47,7 +47,11 @@ def get_new_base_model(base_model_name):
     while True:
         try:
             model = _get_model_from_pretrained(
-                model_class, base_model_name, from_tf=from_tf, force_download=force_download)
+                model_class,
+                base_model_name,
+                from_tf=from_tf,
+                force_download=force_download
+            )
             break
         except Exception as e:
             if 'from_tf' in str(e):
@@ -83,7 +87,9 @@ def get_new_base_model(base_model_name):
     return model
 
 
-def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_download=False):
+def _get_model_from_pretrained(
+        model_class, model_name,
+        from_tf=False, force_download=False):
     torch = get_torch()
     device = get_device()
 
@@ -97,7 +103,8 @@ def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_download=False):
             device_map={'': 0},
             from_tf=from_tf,
             force_download=force_download,
-            trust_remote_code=Config.trust_remote_code
+            trust_remote_code=Config.trust_remote_code,
+            use_auth_token=Config.hf_access_token
         )
     elif device == "mps":
         return model_class.from_pretrained(
@@ -106,7 +113,8 @@ def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_download=False):
             torch_dtype=torch.float16,
             from_tf=from_tf,
             force_download=force_download,
-            trust_remote_code=Config.trust_remote_code
+            trust_remote_code=Config.trust_remote_code,
+            use_auth_token=Config.hf_access_token
         )
     else:
         return model_class.from_pretrained(
@@ -115,7 +123,8 @@ def _get_model_from_pretrained(model_class, model_name, from_tf=False, force_download=False):
             low_cpu_mem_usage=True,
             from_tf=from_tf,
             force_download=force_download,
-            trust_remote_code=Config.trust_remote_code
+            trust_remote_code=Config.trust_remote_code,
+            use_auth_token=Config.hf_access_token
         )
 
 
@@ -133,13 +142,15 @@ def get_tokenizer(base_model_name):
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             base_model_name,
-            trust_remote_code=Config.trust_remote_code
+            trust_remote_code=Config.trust_remote_code,
+            use_auth_token=Config.hf_access_token
         )
     except Exception as e:
         if 'LLaMATokenizer' in str(e):
             tokenizer = LlamaTokenizer.from_pretrained(
                 base_model_name,
-                trust_remote_code=Config.trust_remote_code
+                trust_remote_code=Config.trust_remote_code,
+                use_auth_token=Config.hf_access_token
             )
         else:
             raise e
@@ -210,6 +221,7 @@ def get_model(
             torch_dtype=torch.float16,
             # ? https://github.com/tloen/alpaca-lora/issues/21
             device_map={'': 0},
+            use_auth_token=Config.hf_access_token
         )
     elif device == "mps":
         model = PeftModel.from_pretrained(
@@ -217,12 +229,14 @@ def get_model(
             peft_model_name_or_path,
             device_map={"": device},
             torch_dtype=torch.float16,
+            use_auth_token=Config.hf_access_token
         )
     else:
         model = PeftModel.from_pretrained(
             model,
             peft_model_name_or_path,
             device_map={"": device},
+            use_auth_token=Config.hf_access_token
         )
 
     if re.match("[^/]+/llama", base_model_name):
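Taken together, the token travels two routes: UI-driven loads in llama_lora/models.py read it from the `Config` singleton, while `finetune.train()` takes it as an explicit parameter. A condensed, hypothetical illustration of the models.py route (the real functions carry more device-specific branches):

    from transformers import AutoModelForCausalLM
    from llama_lora.config import Config

    def load_base_model(name):
        # models.py route: the token is read from the singleton, not passed in.
        return AutoModelForCausalLM.from_pretrained(
            name,
            trust_remote_code=Config.trust_remote_code,
            use_auth_token=Config.hf_access_token,
        )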