zetavg commited on
Commit
90c428d
1 Parent(s): 2f0a0ce

change model loading mechanism

Browse files
LLaMA_LoRA.ipynb CHANGED
@@ -281,7 +281,7 @@
281
  "\n",
282
  "# Set Configs\n",
283
  "from llama_lora.llama_lora.globals import Global\n",
284
- "Global.base_model = base_model\n",
285
  "data_dir_realpath = !realpath ./data\n",
286
  "Global.data_dir = data_dir_realpath[0]\n",
287
  "Global.load_8bit = True\n",
@@ -289,12 +289,7 @@
289
  "# Prepare Data Dir\n",
290
  "import os\n",
291
  "from llama_lora.llama_lora.utils.data import init_data_dir\n",
292
- "init_data_dir()\n",
293
- "\n",
294
- "# Load the Base Model\n",
295
- "from llama_lora.llama_lora.models import load_base_model\n",
296
- "load_base_model()\n",
297
- "print(f\"Base model loaded: '{Global.base_model}'.\")"
298
  ],
299
  "metadata": {
300
  "id": "Yf6g248ylteP"
 
281
  "\n",
282
  "# Set Configs\n",
283
  "from llama_lora.llama_lora.globals import Global\n",
284
+ "Global.default_base_model_name = base_model\n",
285
  "data_dir_realpath = !realpath ./data\n",
286
  "Global.data_dir = data_dir_realpath[0]\n",
287
  "Global.load_8bit = True\n",
 
289
  "# Prepare Data Dir\n",
290
  "import os\n",
291
  "from llama_lora.llama_lora.utils.data import init_data_dir\n",
292
+ "init_data_dir()"
 
 
 
 
 
293
  ],
294
  "metadata": {
295
  "id": "Yf6g248ylteP"
app.py CHANGED
@@ -7,7 +7,6 @@ import gradio as gr
7
  from llama_lora.globals import Global
8
  from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
9
  from llama_lora.utils.data import init_data_dir
10
- from llama_lora.models import load_base_model
11
 
12
 
13
  def main(
@@ -31,7 +30,7 @@ def main(
31
  data_dir
32
  ), "Please specify a --data_dir, e.g. --data_dir='./data'"
33
 
34
- Global.base_model = base_model
35
  Global.data_dir = os.path.abspath(data_dir)
36
  Global.load_8bit = load_8bit
37
 
@@ -41,9 +40,6 @@ def main(
41
  os.makedirs(data_dir, exist_ok=True)
42
  init_data_dir()
43
 
44
- if not skip_loading_base_model:
45
- load_base_model()
46
-
47
  with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
48
  main_page()
49
 
 
7
  from llama_lora.globals import Global
8
  from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
9
  from llama_lora.utils.data import init_data_dir
 
10
 
11
 
12
  def main(
 
30
  data_dir
31
  ), "Please specify a --data_dir, e.g. --data_dir='./data'"
32
 
33
+ Global.default_base_model_name = base_model
34
  Global.data_dir = os.path.abspath(data_dir)
35
  Global.load_8bit = load_8bit
36
 
 
40
  os.makedirs(data_dir, exist_ok=True)
41
  init_data_dir()
42
 
 
 
 
43
  with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
44
  main_page()
45
 
llama_lora/globals.py CHANGED
@@ -13,12 +13,10 @@ from .lib.finetune import train
13
  class Global:
14
  version = None
15
 
16
- base_model: str = ""
17
  data_dir: str = ""
18
  load_8bit: bool = False
19
 
20
- loaded_tokenizer: Any = None
21
- loaded_base_model: Any = None
22
 
23
  # Functions
24
  train_fn: Any = train
@@ -31,8 +29,8 @@ class Global:
31
  generation_force_stopped_at = None
32
 
33
  # Model related
34
- model_has_been_used = False
35
- cached_lora_models = LRUCache(10)
36
 
37
  # GPU Info
38
  gpu_cc = None # GPU compute capability
 
13
  class Global:
14
  version = None
15
 
 
16
  data_dir: str = ""
17
  load_8bit: bool = False
18
 
19
+ default_base_model_name: str = ""
 
20
 
21
  # Functions
22
  train_fn: Any = train
 
29
  generation_force_stopped_at = None
30
 
31
  # Model related
32
+ loaded_models = LRUCache(1)
33
+ loaded_tokenizers = LRUCache(1)
34
 
35
  # GPU Info
36
  gpu_cc = None # GPU compute capability
llama_lora/models.py CHANGED
@@ -3,9 +3,8 @@ import sys
3
  import gc
4
 
5
  import torch
6
- import transformers
7
  from peft import PeftModel
8
- from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
9
 
10
  from .globals import Global
11
 
@@ -23,96 +22,120 @@ def get_device():
23
  pass
24
 
25
 
26
- device = get_device()
27
-
28
-
29
- def get_base_model():
30
- load_base_model()
31
- return Global.loaded_base_model
32
-
33
-
34
- def get_model_with_lora(lora_weights_name_or_path: str = "tloen/alpaca-lora-7b"):
35
- Global.model_has_been_used = True
36
 
37
- if Global.cached_lora_models:
38
- model_from_cache = Global.cached_lora_models.get(lora_weights_name_or_path)
39
- if model_from_cache:
40
- return model_from_cache
41
 
42
  if device == "cuda":
43
- model = PeftModel.from_pretrained(
44
- get_base_model(),
45
- lora_weights_name_or_path,
46
  torch_dtype=torch.float16,
 
47
  device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
48
  )
49
  elif device == "mps":
50
- model = PeftModel.from_pretrained(
51
- get_base_model(),
52
- lora_weights_name_or_path,
53
  device_map={"": device},
54
  torch_dtype=torch.float16,
55
  )
56
  else:
57
- model = PeftModel.from_pretrained(
58
- get_base_model(),
59
- lora_weights_name_or_path,
60
- device_map={"": device},
61
  )
62
 
63
- model.config.pad_token_id = get_tokenizer().pad_token_id = 0
64
  model.config.bos_token_id = 1
65
  model.config.eos_token_id = 2
66
 
67
- if not Global.load_8bit:
68
- model.half() # seems to fix bugs for some users.
69
 
70
- model.eval()
71
- if torch.__version__ >= "2" and sys.platform != "win32":
72
- model = torch.compile(model)
73
 
74
- if Global.cached_lora_models:
75
- Global.cached_lora_models.set(lora_weights_name_or_path, model)
 
76
 
77
- return model
 
 
78
 
 
 
79
 
80
- def get_tokenizer():
81
- load_base_model()
82
- return Global.loaded_tokenizer
83
 
84
 
85
- def load_base_model():
 
 
86
  if Global.ui_dev_mode:
87
  return
88
 
89
- if Global.loaded_tokenizer is None:
90
- Global.loaded_tokenizer = LlamaTokenizer.from_pretrained(
91
- Global.base_model
92
- )
93
- if Global.loaded_base_model is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  if device == "cuda":
95
- Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
96
- Global.base_model,
97
- load_in_8bit=Global.load_8bit,
98
  torch_dtype=torch.float16,
99
- # device_map="auto",
100
  device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
101
  )
102
  elif device == "mps":
103
- Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
104
- Global.base_model,
 
105
  device_map={"": device},
106
  torch_dtype=torch.float16,
107
  )
108
  else:
109
- Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
110
- Global.base_model, device_map={"": device}, low_cpu_mem_usage=True
 
 
111
  )
112
 
113
- Global.loaded_base_model.config.pad_token_id = get_tokenizer().pad_token_id = 0
114
- Global.loaded_base_model.config.bos_token_id = 1
115
- Global.loaded_base_model.config.eos_token_id = 2
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
 
118
  def clear_cache():
@@ -124,19 +147,137 @@ def clear_cache():
124
 
125
 
126
  def unload_models():
127
- del Global.loaded_base_model
128
- Global.loaded_base_model = None
 
129
 
130
- del Global.loaded_tokenizer
131
- Global.loaded_tokenizer = None
132
 
133
- Global.cached_lora_models.clear()
134
 
135
- clear_cache()
136
 
137
- Global.model_has_been_used = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
 
140
- def unload_models_if_already_used():
141
- if Global.model_has_been_used:
142
- unload_models()
 
3
  import gc
4
 
5
  import torch
6
+ from transformers import LlamaForCausalLM, LlamaTokenizer
7
  from peft import PeftModel
 
8
 
9
  from .globals import Global
10
 
 
22
  pass
23
 
24
 
25
+ def get_new_base_model(base_model_name):
26
+ if Global.ui_dev_mode:
27
+ return
 
 
 
 
 
 
 
28
 
29
+ device = get_device()
 
 
 
30
 
31
  if device == "cuda":
32
+ model = LlamaForCausalLM.from_pretrained(
33
+ base_model_name,
34
+ load_in_8bit=Global.load_8bit,
35
  torch_dtype=torch.float16,
36
+ # device_map="auto",
37
  device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
38
  )
39
  elif device == "mps":
40
+ model = LlamaForCausalLM.from_pretrained(
41
+ base_model_name,
 
42
  device_map={"": device},
43
  torch_dtype=torch.float16,
44
  )
45
  else:
46
+ model = LlamaForCausalLM.from_pretrained(
47
+ base_model_name, device_map={"": device}, low_cpu_mem_usage=True
 
 
48
  )
49
 
50
+ model.config.pad_token_id = get_tokenizer(base_model_name).pad_token_id = 0
51
  model.config.bos_token_id = 1
52
  model.config.eos_token_id = 2
53
 
54
+ return model
 
55
 
 
 
 
56
 
57
+ def get_tokenizer(base_model_name):
58
+ if Global.ui_dev_mode:
59
+ return
60
 
61
+ loaded_tokenizer = Global.loaded_tokenizers.get(base_model_name)
62
+ if loaded_tokenizer:
63
+ return loaded_tokenizer
64
 
65
+ tokenizer = LlamaTokenizer.from_pretrained(base_model_name)
66
+ Global.loaded_tokenizers.set(base_model_name, tokenizer)
67
 
68
+ return tokenizer
 
 
69
 
70
 
71
+ def get_model(
72
+ base_model_name,
73
+ peft_model_name = None):
74
  if Global.ui_dev_mode:
75
  return
76
 
77
+ if peft_model_name == "None":
78
+ peft_model_name = None
79
+
80
+ model_key = base_model_name
81
+ if peft_model_name:
82
+ model_key = f"{base_model_name}//{peft_model_name}"
83
+
84
+ loaded_model = Global.loaded_models.get(model_key)
85
+ if loaded_model:
86
+ return loaded_model
87
+
88
+ peft_model_name_or_path = peft_model_name
89
+
90
+ lora_models_directory_path = os.path.join(Global.data_dir, "lora_models")
91
+ possible_lora_model_path = os.path.join(lora_models_directory_path, peft_model_name)
92
+ if os.path.isdir(possible_lora_model_path):
93
+ peft_model_name_or_path = possible_lora_model_path
94
+
95
+ Global.loaded_models.prepare_to_set()
96
+ clear_cache()
97
+
98
+ model = get_new_base_model(base_model_name)
99
+
100
+ if peft_model_name:
101
+ device = get_device()
102
+
103
  if device == "cuda":
104
+ model = PeftModel.from_pretrained(
105
+ model,
106
+ peft_model_name_or_path,
107
  torch_dtype=torch.float16,
 
108
  device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
109
  )
110
  elif device == "mps":
111
+ model = PeftModel.from_pretrained(
112
+ model,
113
+ peft_model_name_or_path,
114
  device_map={"": device},
115
  torch_dtype=torch.float16,
116
  )
117
  else:
118
+ model = PeftModel.from_pretrained(
119
+ model,
120
+ peft_model_name_or_path,
121
+ device_map={"": device},
122
  )
123
 
124
+ model.config.pad_token_id = get_tokenizer(base_model_name).pad_token_id = 0
125
+ model.config.bos_token_id = 1
126
+ model.config.eos_token_id = 2
127
+
128
+ if not Global.load_8bit:
129
+ model.half() # seems to fix bugs for some users.
130
+
131
+ model.eval()
132
+ if torch.__version__ >= "2" and sys.platform != "win32":
133
+ model = torch.compile(model)
134
+
135
+ Global.loaded_models.set(model_key, model)
136
+ clear_cache()
137
+
138
+ return model
139
 
140
 
141
  def clear_cache():
 
147
 
148
 
149
  def unload_models():
150
+ Global.loaded_models.clear()
151
+ Global.loaded_tokenizers.clear()
152
+ clear_cache()
153
 
 
 
154
 
 
155
 
 
156
 
157
+
158
+ ########
159
+
160
+ # def get_base_model():
161
+ # load_base_model()
162
+ # return Global.loaded_base_model
163
+
164
+
165
+ # def get_model_with_lora(lora_weights_name_or_path: str = "tloen/alpaca-lora-7b"):
166
+ # # Global.model_has_been_used = True
167
+ # #
168
+ # #
169
+ # if Global.loaded_tokenizer is None:
170
+ # Global.loaded_tokenizer = LlamaTokenizer.from_pretrained(
171
+ # Global.base_model
172
+ # )
173
+
174
+ # if Global.cached_lora_models:
175
+ # model_from_cache = Global.cached_lora_models.get(lora_weights_name_or_path)
176
+ # if model_from_cache:
177
+ # return model_from_cache
178
+
179
+ # Global.cached_lora_models.prepare_to_set()
180
+
181
+ # if device == "cuda":
182
+ # model = PeftModel.from_pretrained(
183
+ # get_new_base_model(),
184
+ # lora_weights_name_or_path,
185
+ # torch_dtype=torch.float16,
186
+ # device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
187
+ # )
188
+ # elif device == "mps":
189
+ # model = PeftModel.from_pretrained(
190
+ # get_new_base_model(),
191
+ # lora_weights_name_or_path,
192
+ # device_map={"": device},
193
+ # torch_dtype=torch.float16,
194
+ # )
195
+ # else:
196
+ # model = PeftModel.from_pretrained(
197
+ # get_new_base_model(),
198
+ # lora_weights_name_or_path,
199
+ # device_map={"": device},
200
+ # )
201
+
202
+ # model.config.pad_token_id = get_tokenizer().pad_token_id = 0
203
+ # model.config.bos_token_id = 1
204
+ # model.config.eos_token_id = 2
205
+
206
+ # if not Global.load_8bit:
207
+ # model.half() # seems to fix bugs for some users.
208
+
209
+ # model.eval()
210
+ # if torch.__version__ >= "2" and sys.platform != "win32":
211
+ # model = torch.compile(model)
212
+
213
+ # if Global.cached_lora_models:
214
+ # Global.cached_lora_models.set(lora_weights_name_or_path, model)
215
+
216
+ # clear_cache()
217
+
218
+ # return model
219
+
220
+
221
+
222
+
223
+
224
+ # def load_base_model():
225
+ # return;
226
+
227
+ # if Global.ui_dev_mode:
228
+ # return
229
+
230
+ # if Global.loaded_tokenizer is None:
231
+ # Global.loaded_tokenizer = LlamaTokenizer.from_pretrained(
232
+ # Global.base_model
233
+ # )
234
+ # if Global.loaded_base_model is None:
235
+ # if device == "cuda":
236
+ # Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
237
+ # Global.base_model,
238
+ # load_in_8bit=Global.load_8bit,
239
+ # torch_dtype=torch.float16,
240
+ # # device_map="auto",
241
+ # device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
242
+ # )
243
+ # elif device == "mps":
244
+ # Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
245
+ # Global.base_model,
246
+ # device_map={"": device},
247
+ # torch_dtype=torch.float16,
248
+ # )
249
+ # else:
250
+ # Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
251
+ # Global.base_model, device_map={"": device}, low_cpu_mem_usage=True
252
+ # )
253
+
254
+ # Global.loaded_base_model.config.pad_token_id = get_tokenizer().pad_token_id = 0
255
+ # Global.loaded_base_model.config.bos_token_id = 1
256
+ # Global.loaded_base_model.config.eos_token_id = 2
257
+
258
+
259
+ # def clear_cache():
260
+ # gc.collect()
261
+
262
+ # # if not shared.args.cpu: # will not be running on CPUs anyway
263
+ # with torch.no_grad():
264
+ # torch.cuda.empty_cache()
265
+
266
+
267
+ # def unload_models():
268
+ # del Global.loaded_base_model
269
+ # Global.loaded_base_model = None
270
+
271
+ # del Global.loaded_tokenizer
272
+ # Global.loaded_tokenizer = None
273
+
274
+ # Global.cached_lora_models.clear()
275
+
276
+ # clear_cache()
277
+
278
+ # Global.model_has_been_used = False
279
 
280
 
281
+ # def unload_models_if_already_used():
282
+ # if Global.model_has_been_used:
283
+ # unload_models()
llama_lora/ui/finetune_ui.py CHANGED
@@ -10,8 +10,8 @@ from transformers import TrainerCallback
10
 
11
  from ..globals import Global
12
  from ..models import (
13
- get_base_model, get_tokenizer,
14
- clear_cache, unload_models_if_already_used)
15
  from ..utils.data import (
16
  get_available_template_names,
17
  get_available_dataset_names,
@@ -269,14 +269,16 @@ def do_train(
269
  progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
270
  ):
271
  try:
 
 
 
 
 
 
272
  if not should_training_progress_track_tqdm:
273
  progress(0, desc="Preparing train data...")
274
 
275
- clear_cache()
276
- # If model has been used in inference, we need to unload it first.
277
- # Otherwise, we'll get a 'Function MmBackward0 returned an invalid
278
- # gradient at index 1 - expected device meta but got cuda:0' error.
279
- unload_models_if_already_used()
280
 
281
  prompter = Prompter(template)
282
  variable_names = prompter.get_variable_names()
@@ -415,17 +417,12 @@ Train data (first 10):
415
 
416
  Global.should_stop_training = False
417
 
418
- # Do this again right before training to make sure the model is not used in inference.
419
- unload_models_if_already_used()
420
- clear_cache()
421
-
422
- base_model = get_base_model()
423
- tokenizer = get_tokenizer()
424
 
425
  # Do not let other tqdm iterations interfere the progress reporting after training starts.
426
  # progress.track_tqdm = False # setting this dynamically is not working, determining if track_tqdm should be enabled based on GPU cores at start instead.
427
 
428
- output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
429
  if not os.path.exists(output_dir):
430
  os.makedirs(output_dir)
431
 
@@ -435,10 +432,11 @@ Train data (first 10):
435
  dataset_name = dataset_from_data_dir
436
 
437
  info = {
438
- 'base_model': Global.base_model,
439
  'prompt_template': template,
440
  'dataset_name': dataset_name,
441
  'dataset_rows': len(train_data),
 
442
  }
443
  json.dump(info, info_json_file, indent=2)
444
 
@@ -472,7 +470,11 @@ Train data (first 10):
472
 
473
  result_message = f"Training ended:\n{str(train_output)}\n\nLogs:\n{logs_str}"
474
  print(result_message)
 
 
 
475
  clear_cache()
 
476
  return result_message
477
 
478
  except Exception as e:
@@ -837,6 +839,12 @@ def finetune_ui():
837
  document.getElementById('finetune_confirm_stop_btn').style.display =
838
  'none';
839
  }, 5000);
 
 
 
 
 
 
840
  document.getElementById('finetune_stop_btn').style.display = 'none';
841
  document.getElementById('finetune_confirm_stop_btn').style.display =
842
  'block';
 
10
 
11
  from ..globals import Global
12
  from ..models import (
13
+ get_new_base_model, get_tokenizer,
14
+ clear_cache, unload_models)
15
  from ..utils.data import (
16
  get_available_template_names,
17
  get_available_dataset_names,
 
269
  progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
270
  ):
271
  try:
272
+ base_model_name = Global.default_base_model_name
273
+ output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
274
+ if os.path.exists(output_dir):
275
+ if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
276
+ raise ValueError(f"The output directory already exists and is not empty. ({output_dir})")
277
+
278
  if not should_training_progress_track_tqdm:
279
  progress(0, desc="Preparing train data...")
280
 
281
+ unload_models() # Need RAM for training
 
 
 
 
282
 
283
  prompter = Prompter(template)
284
  variable_names = prompter.get_variable_names()
 
417
 
418
  Global.should_stop_training = False
419
 
420
+ base_model = get_new_base_model(base_model_name)
421
+ tokenizer = get_tokenizer(base_model_name)
 
 
 
 
422
 
423
  # Do not let other tqdm iterations interfere the progress reporting after training starts.
424
  # progress.track_tqdm = False # setting this dynamically is not working, determining if track_tqdm should be enabled based on GPU cores at start instead.
425
 
 
426
  if not os.path.exists(output_dir):
427
  os.makedirs(output_dir)
428
 
 
432
  dataset_name = dataset_from_data_dir
433
 
434
  info = {
435
+ 'base_model': base_model_name,
436
  'prompt_template': template,
437
  'dataset_name': dataset_name,
438
  'dataset_rows': len(train_data),
439
+ 'timestamp': time.time()
440
  }
441
  json.dump(info, info_json_file, indent=2)
442
 
 
470
 
471
  result_message = f"Training ended:\n{str(train_output)}\n\nLogs:\n{logs_str}"
472
  print(result_message)
473
+
474
+ del base_model
475
+ del tokenizer
476
  clear_cache()
477
+
478
  return result_message
479
 
480
  except Exception as e:
 
839
  document.getElementById('finetune_confirm_stop_btn').style.display =
840
  'none';
841
  }, 5000);
842
+ document.getElementById('finetune_confirm_stop_btn').style['pointer-events'] =
843
+ 'none';
844
+ setTimeout(function () {
845
+ document.getElementById('finetune_confirm_stop_btn').style['pointer-events'] =
846
+ 'inherit';
847
+ }, 300);
848
  document.getElementById('finetune_stop_btn').style.display = 'none';
849
  document.getElementById('finetune_confirm_stop_btn').style.display =
850
  'block';
llama_lora/ui/inference_ui.py CHANGED
@@ -7,11 +7,10 @@ import transformers
7
  from transformers import GenerationConfig
8
 
9
  from ..globals import Global
10
- from ..models import get_base_model, get_model_with_lora, get_tokenizer, get_device
11
  from ..utils.data import (
12
  get_available_template_names,
13
  get_available_lora_model_names,
14
- get_path_of_available_lora_model,
15
  get_info_of_available_lora_model)
16
  from ..utils.prompter import Prompter
17
  from ..utils.callbacks import Iteratorize, Stream
@@ -22,6 +21,18 @@ default_show_raw = True
22
  inference_output_lines = 12
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def do_inference(
26
  lora_model_name,
27
  prompt_template,
@@ -37,6 +48,8 @@ def do_inference(
37
  show_raw=False,
38
  progress=gr.Progress(track_tqdm=True),
39
  ):
 
 
40
  try:
41
  if Global.generation_force_stopped_at is not None:
42
  required_elapsed_time_after_forced_stop = 1
@@ -52,16 +65,8 @@ def do_inference(
52
  prompter = Prompter(prompt_template)
53
  prompt = prompter.generate_prompt(variables)
54
 
55
- if not lora_model_name:
56
- lora_model_name = "None"
57
- if "/" not in lora_model_name and lora_model_name != "None":
58
- path_of_available_lora_model = get_path_of_available_lora_model(
59
- lora_model_name)
60
- if path_of_available_lora_model:
61
- lora_model_name = path_of_available_lora_model
62
-
63
  if Global.ui_dev_mode:
64
- message = f"Hi, I’m currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {Global.base_model}\nLoRA model: {lora_model_name}\n\nThe following text is your prompt:\n\n{prompt}"
65
  print(message)
66
 
67
  if stream_output:
@@ -90,18 +95,13 @@ def do_inference(
90
  return
91
  time.sleep(1)
92
  yield (
93
- gr.Textbox.update(value=message, lines=1), # TODO
94
  json.dumps(list(range(len(message.split()))), indent=2)
95
  )
96
  return
97
 
98
- # model = get_base_model()
99
- if lora_model_name != "None":
100
- model = get_model_with_lora(lora_model_name)
101
- else:
102
- raise ValueError("No LoRA model selected.")
103
-
104
- tokenizer = get_tokenizer()
105
 
106
  inputs = tokenizer(prompt, return_tensors="pt")
107
  input_ids = inputs["input_ids"].to(device)
@@ -210,7 +210,6 @@ def do_inference(
210
  gr.Textbox.update(value=response, lines=inference_output_lines),
211
  raw_output)
212
 
213
-
214
  except Exception as e:
215
  raise gr.Error(e)
216
 
@@ -232,7 +231,7 @@ def reload_selections(current_lora_model, current_prompt_template):
232
 
233
  default_lora_models = ["tloen/alpaca-lora-7b"]
234
  available_lora_models = default_lora_models + get_available_lora_model_names()
235
- available_lora_models = available_lora_models
236
 
237
  current_lora_model = current_lora_model or next(
238
  iter(available_lora_models), None)
@@ -462,6 +461,10 @@ def inference_ui():
462
  things_that_might_timeout.append(lora_model_change_event)
463
 
464
  generate_event = generate_btn.click(
 
 
 
 
465
  fn=do_inference,
466
  inputs=[
467
  lora_model,
 
7
  from transformers import GenerationConfig
8
 
9
  from ..globals import Global
10
+ from ..models import get_model, get_tokenizer, get_device
11
  from ..utils.data import (
12
  get_available_template_names,
13
  get_available_lora_model_names,
 
14
  get_info_of_available_lora_model)
15
  from ..utils.prompter import Prompter
16
  from ..utils.callbacks import Iteratorize, Stream
 
21
  inference_output_lines = 12
22
 
23
 
24
+ def prepare_inference(lora_model_name, progress=gr.Progress(track_tqdm=True)):
25
+ base_model_name = Global.default_base_model_name
26
+
27
+ try:
28
+ get_tokenizer(base_model_name)
29
+ get_model(base_model_name, lora_model_name)
30
+ return ("", "")
31
+
32
+ except Exception as e:
33
+ raise gr.Error(e)
34
+
35
+
36
  def do_inference(
37
  lora_model_name,
38
  prompt_template,
 
48
  show_raw=False,
49
  progress=gr.Progress(track_tqdm=True),
50
  ):
51
+ base_model_name = Global.default_base_model_name
52
+
53
  try:
54
  if Global.generation_force_stopped_at is not None:
55
  required_elapsed_time_after_forced_stop = 1
 
65
  prompter = Prompter(prompt_template)
66
  prompt = prompter.generate_prompt(variables)
67
 
 
 
 
 
 
 
 
 
68
  if Global.ui_dev_mode:
69
+ message = f"Hi, I’m currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {base_model_name}\nLoRA model: {lora_model_name}\n\nThe following is your prompt:\n\n{prompt}"
70
  print(message)
71
 
72
  if stream_output:
 
95
  return
96
  time.sleep(1)
97
  yield (
98
+ gr.Textbox.update(value=message, lines=1), # TODO
99
  json.dumps(list(range(len(message.split()))), indent=2)
100
  )
101
  return
102
 
103
+ tokenizer = get_tokenizer(base_model_name)
104
+ model = get_model(base_model_name, lora_model_name)
 
 
 
 
 
105
 
106
  inputs = tokenizer(prompt, return_tensors="pt")
107
  input_ids = inputs["input_ids"].to(device)
 
210
  gr.Textbox.update(value=response, lines=inference_output_lines),
211
  raw_output)
212
 
 
213
  except Exception as e:
214
  raise gr.Error(e)
215
 
 
231
 
232
  default_lora_models = ["tloen/alpaca-lora-7b"]
233
  available_lora_models = default_lora_models + get_available_lora_model_names()
234
+ available_lora_models = available_lora_models + ["None"]
235
 
236
  current_lora_model = current_lora_model or next(
237
  iter(available_lora_models), None)
 
461
  things_that_might_timeout.append(lora_model_change_event)
462
 
463
  generate_event = generate_btn.click(
464
+ fn=prepare_inference,
465
+ inputs=[lora_model],
466
+ outputs=[inference_output, inference_raw_output],
467
+ ).then(
468
  fn=do_inference,
469
  inputs=[
470
  lora_model,
llama_lora/ui/main_page.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
 
3
  from ..globals import Global
4
- from ..models import get_model_with_lora
5
 
6
  from .inference_ui import inference_ui
7
  from .finetune_ui import finetune_ui
@@ -31,7 +30,7 @@ def main_page():
31
  info = []
32
  if Global.version:
33
  info.append(f"LLaMA-LoRA Tuner `{Global.version}`")
34
- info.append(f"Base model: `{Global.base_model}`")
35
  if Global.ui_show_sys_info:
36
  info.append(f"Data dir: `{Global.data_dir}`")
37
  gr.Markdown(f"""
 
1
  import gradio as gr
2
 
3
  from ..globals import Global
 
4
 
5
  from .inference_ui import inference_ui
6
  from .finetune_ui import finetune_ui
 
30
  info = []
31
  if Global.version:
32
  info.append(f"LLaMA-LoRA Tuner `{Global.version}`")
33
+ info.append(f"Base model: `{Global.default_base_model_name}`")
34
  if Global.ui_show_sys_info:
35
  info.append(f"Data dir: `{Global.data_dir}`")
36
  gr.Markdown(f"""
llama_lora/ui/tokenizer_ui.py CHANGED
@@ -7,11 +7,12 @@ from ..models import get_tokenizer
7
 
8
 
9
  def handle_decode(encoded_tokens_json):
 
10
  try:
11
  encoded_tokens = json.loads(encoded_tokens_json)
12
  if Global.ui_dev_mode:
13
  return f"Not actually decoding tokens in UI dev mode.", gr.Markdown.update("", visible=False)
14
- tokenizer = get_tokenizer()
15
  decoded_tokens = tokenizer.decode(encoded_tokens)
16
  return decoded_tokens, gr.Markdown.update("", visible=False)
17
  except Exception as e:
@@ -19,10 +20,11 @@ def handle_decode(encoded_tokens_json):
19
 
20
 
21
  def handle_encode(decoded_tokens):
 
22
  try:
23
  if Global.ui_dev_mode:
24
  return f"[\"Not actually encoding tokens in UI dev mode.\"]", gr.Markdown.update("", visible=False)
25
- tokenizer = get_tokenizer()
26
  result = tokenizer(decoded_tokens)
27
  encoded_tokens_json = json.dumps(result['input_ids'], indent=2)
28
  return encoded_tokens_json, gr.Markdown.update("", visible=False)
 
7
 
8
 
9
  def handle_decode(encoded_tokens_json):
10
+ base_model_name = Global.default_base_model_name
11
  try:
12
  encoded_tokens = json.loads(encoded_tokens_json)
13
  if Global.ui_dev_mode:
14
  return f"Not actually decoding tokens in UI dev mode.", gr.Markdown.update("", visible=False)
15
+ tokenizer = get_tokenizer(base_model_name)
16
  decoded_tokens = tokenizer.decode(encoded_tokens)
17
  return decoded_tokens, gr.Markdown.update("", visible=False)
18
  except Exception as e:
 
20
 
21
 
22
  def handle_encode(decoded_tokens):
23
+ base_model_name = Global.default_base_model_name
24
  try:
25
  if Global.ui_dev_mode:
26
  return f"[\"Not actually encoding tokens in UI dev mode.\"]", gr.Markdown.update("", visible=False)
27
+ tokenizer = get_tokenizer(base_model_name)
28
  result = tokenizer(decoded_tokens)
29
  encoded_tokens_json = json.dumps(result['input_ids'], indent=2)
30
  return encoded_tokens_json, gr.Markdown.update("", visible=False)
llama_lora/utils/lru_cache.py CHANGED
@@ -25,3 +25,7 @@ class LRUCache:
25
 
26
  def clear(self):
27
  self.cache.clear()
 
 
 
 
 
25
 
26
  def clear(self):
27
  self.cache.clear()
28
+
29
+ def prepare_to_set(self):
30
+ if len(self.cache) >= self.capacity:
31
+ self.cache.popitem(last=False)