zetavg commited on
Commit
6c08b63
β€’
2 Parent(s): 4c02e18 bbdf699

Merge branch 'dev-2'

Browse files
LLaMA_LoRA.ipynb CHANGED
@@ -27,13 +27,13 @@
27
  "colab_type": "text"
28
  },
29
  "source": [
30
- "<a href=\"https://colab.research.google.com/github/zetavg/LLaMA-LoRA/blob/main/LLaMA_LoRA.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
31
  ]
32
  },
33
  {
34
  "cell_type": "markdown",
35
  "source": [
36
- "# πŸ¦™πŸŽ›οΈ LLaMA-LoRA\n",
37
  "\n",
38
  "TL;DR: **Runtime > Run All** (`⌘/Ctrl+F9`). Takes about 5 minutes to start. You will be promped to authorize Google Drive access."
39
  ],
@@ -72,9 +72,9 @@
72
  "# @title Git/Project { display-mode: \"form\", run: \"auto\" }\n",
73
  "# @markdown Project settings.\n",
74
  "\n",
75
- "# @markdown The URL of the LLaMA-LoRA project<br>&nbsp;&nbsp;(default: `https://github.com/zetavg/llama-lora.git`):\n",
76
- "llama_lora_project_url = \"https://github.com/zetavg/llama-lora.git\" # @param {type:\"string\"}\n",
77
- "# @markdown The branch to use for LLaMA-LoRA project:\n",
78
  "llama_lora_project_branch = \"main\" # @param {type:\"string\"}\n",
79
  "\n",
80
  "# # @markdown Forces the local directory to be updated by the remote branch:\n",
 
27
  "colab_type": "text"
28
  },
29
  "source": [
30
+ "<a href=\"https://colab.research.google.com/github/zetavg/LLaMA-LoRA-Tuner/blob/main/LLaMA_LoRA.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
31
  ]
32
  },
33
  {
34
  "cell_type": "markdown",
35
  "source": [
36
+ "# πŸ¦™πŸŽ›οΈ LLaMA-LoRA Tuner\n",
37
  "\n",
38
  "TL;DR: **Runtime > Run All** (`⌘/Ctrl+F9`). Takes about 5 minutes to start. You will be promped to authorize Google Drive access."
39
  ],
 
72
  "# @title Git/Project { display-mode: \"form\", run: \"auto\" }\n",
73
  "# @markdown Project settings.\n",
74
  "\n",
75
+ "# @markdown The URL of the LLaMA-LoRA-Tuner project<br>&nbsp;&nbsp;(default: `https://github.com/zetavg/LLaMA-LoRA-Tuner.git`):\n",
76
+ "llama_lora_project_url = \"https://github.com/zetavg/LLaMA-LoRA-Tuner.git\" # @param {type:\"string\"}\n",
77
+ "# @markdown The branch to use for LLaMA-LoRA-Tuner project:\n",
78
  "llama_lora_project_branch = \"main\" # @param {type:\"string\"}\n",
79
  "\n",
80
  "# # @markdown Forces the local directory to be updated by the remote branch:\n",
README.md CHANGED
@@ -1,6 +1,6 @@
1
- # πŸ¦™πŸŽ›οΈ LLaMA-LoRA
2
 
3
- <a href="https://colab.research.google.com/github/zetavg/LLaMA-LoRA/blob/main/LLaMA_LoRA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
4
 
5
  Making evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA) easy.
6
 
@@ -27,7 +27,7 @@ There are various ways to run this app:
27
 
28
  ### Run On Google Colab
29
 
30
- Open [this Colab Notebook](https://colab.research.google.com/github/zetavg/LLaMA-LoRA/blob/main/LLaMA_LoRA.ipynb) and select **Runtime > Run All** (`⌘/Ctrl+F9`).
31
 
32
  You will be prompted to authorize Google Drive access, as Google Drive will be used to store your data. See the "Config"/"Google Drive" section for settings and more info.
33
 
@@ -38,7 +38,7 @@ After approximately 5 minutes of running, you will see the public URL in the out
38
  After following the [installation guide of SkyPilot](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html), create a `.yaml` to define a task for running the app:
39
 
40
  ```yaml
41
- # llama-lora-multitool.yaml
42
 
43
  resources:
44
  accelerators: A10:1 # 1x NVIDIA A10 GPU, about US$ 0.6 / hr on Lambda Cloud.
@@ -49,13 +49,13 @@ file_mounts:
49
  # (to store train datasets trained models)
50
  # See https://skypilot.readthedocs.io/en/latest/reference/storage.html for details.
51
  /data:
52
- name: llama-lora-multitool-data # Make sure this name is unique or you own this bucket. If it does not exists, SkyPilot will try to create a bucket with this name.
53
  store: s3 # Could be either of [s3, gcs]
54
  mode: MOUNT
55
 
56
- # Clone the LLaMA-LoRA repo and install its dependencies.
57
  setup: |
58
- git clone https://github.com/zetavg/LLaMA-LoRA.git llama_lora
59
  cd llama_lora && pip install -r requirements.lock.txt
60
  cd ..
61
  echo 'Dependencies installed.'
@@ -69,7 +69,7 @@ run: |
69
  Then launch a cluster to run the task:
70
 
71
  ```
72
- sky launch -c llama-lora-multitool llama-lora-multitool.yaml
73
  ```
74
 
75
  `-c ...` is an optional flag to specify a cluster name. If not specified, SkyPilot will automatically generate one.
@@ -86,8 +86,8 @@ When you are done, run `sky stop <cluster_name>` to stop the cluster. To termina
86
  <summary>Prepare environment with conda</summary>
87
 
88
  ```bash
89
- conda create -y python=3.8 -n llama-lora-multitool
90
- conda activate llama-lora-multitool
91
  ```
92
  </details>
93
 
 
1
+ # πŸ¦™πŸŽ›οΈ LLaMA-LoRA Tuner
2
 
3
+ <a href="https://colab.research.google.com/github/zetavg/LLaMA-LoRA-Tuner/blob/main/LLaMA_LoRA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
4
 
5
  Making evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA) easy.
6
 
 
27
 
28
  ### Run On Google Colab
29
 
30
+ Open [this Colab Notebook](https://colab.research.google.com/github/zetavg/LLaMA-LoRA-Tuner/blob/main/LLaMA_LoRA.ipynb) and select **Runtime > Run All** (`⌘/Ctrl+F9`).
31
 
32
  You will be prompted to authorize Google Drive access, as Google Drive will be used to store your data. See the "Config"/"Google Drive" section for settings and more info.
33
 
 
38
  After following the [installation guide of SkyPilot](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html), create a `.yaml` to define a task for running the app:
39
 
40
  ```yaml
41
+ # llama-lora-tuner.yaml
42
 
43
  resources:
44
  accelerators: A10:1 # 1x NVIDIA A10 GPU, about US$ 0.6 / hr on Lambda Cloud.
 
49
  # (to store train datasets trained models)
50
  # See https://skypilot.readthedocs.io/en/latest/reference/storage.html for details.
51
  /data:
52
+ name: llama-lora-tuner-data # Make sure this name is unique or you own this bucket. If it does not exists, SkyPilot will try to create a bucket with this name.
53
  store: s3 # Could be either of [s3, gcs]
54
  mode: MOUNT
55
 
56
+ # Clone the LLaMA-LoRA Tuner repo and install its dependencies.
57
  setup: |
58
+ git clone https://github.com/zetavg/LLaMA-LoRA-Tuner.git llama_lora
59
  cd llama_lora && pip install -r requirements.lock.txt
60
  cd ..
61
  echo 'Dependencies installed.'
 
69
  Then launch a cluster to run the task:
70
 
71
  ```
72
+ sky launch -c llama-lora-tuner llama-lora-tuner.yaml
73
  ```
74
 
75
  `-c ...` is an optional flag to specify a cluster name. If not specified, SkyPilot will automatically generate one.
 
86
  <summary>Prepare environment with conda</summary>
87
 
88
  ```bash
89
+ conda create -y python=3.8 -n llama-lora-tuner
90
+ conda activate llama-lora-tuner
91
  ```
92
  </details>
93
 
app.py CHANGED
@@ -7,6 +7,7 @@ import gradio as gr
7
  from llama_lora.globals import Global
8
  from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
9
  from llama_lora.utils.data import init_data_dir
 
10
 
11
 
12
  def main(
@@ -16,6 +17,7 @@ def main(
16
  # Allows to listen on all interfaces by providing '0.0.0.0'.
17
  server_name: str = "127.0.0.1",
18
  share: bool = False,
 
19
  ui_show_sys_info: bool = True,
20
  ui_dev_mode: bool = False,
21
  ):
@@ -39,6 +41,9 @@ def main(
39
  os.makedirs(data_dir, exist_ok=True)
40
  init_data_dir()
41
 
 
 
 
42
  with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
43
  main_page()
44
 
 
7
  from llama_lora.globals import Global
8
  from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
9
  from llama_lora.utils.data import init_data_dir
10
+ from llama_lora.models import load_base_model
11
 
12
 
13
  def main(
 
17
  # Allows to listen on all interfaces by providing '0.0.0.0'.
18
  server_name: str = "127.0.0.1",
19
  share: bool = False,
20
+ skip_loading_base_model: bool = False,
21
  ui_show_sys_info: bool = True,
22
  ui_dev_mode: bool = False,
23
  ):
 
41
  os.makedirs(data_dir, exist_ok=True)
42
  init_data_dir()
43
 
44
+ if not skip_loading_base_model:
45
+ load_base_model()
46
+
47
  with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
48
  main_page()
49
 
llama_lora/globals.py CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
6
  from numba import cuda
7
  import nvidia_smi
8
 
 
9
  from .lib.finetune import train
10
 
11
 
@@ -25,8 +26,13 @@ class Global:
25
  # Training Control
26
  should_stop_training = False
27
 
 
 
 
 
28
  # Model related
29
  model_has_been_used = False
 
30
 
31
  # GPU Info
32
  gpu_cc = None # GPU compute capability
@@ -35,7 +41,7 @@ class Global:
35
  gpu_total_memory = None
36
 
37
  # UI related
38
- ui_title: str = "LLaMA-LoRA"
39
  ui_emoji: str = "πŸ¦™πŸŽ›οΈ"
40
  ui_subtitle: str = "Toolkit for evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA)."
41
  ui_show_sys_info: bool = True
 
6
  from numba import cuda
7
  import nvidia_smi
8
 
9
+ from .utils.lru_cache import LRUCache
10
  from .lib.finetune import train
11
 
12
 
 
26
  # Training Control
27
  should_stop_training = False
28
 
29
+ # Generation Control
30
+ should_stop_generating = False
31
+ generation_force_stopped_at = None
32
+
33
  # Model related
34
  model_has_been_used = False
35
+ cached_lora_models = LRUCache(10)
36
 
37
  # GPU Info
38
  gpu_cc = None # GPU compute capability
 
41
  gpu_total_memory = None
42
 
43
  # UI related
44
+ ui_title: str = "LLaMA-LoRA Tuner"
45
  ui_emoji: str = "πŸ¦™πŸŽ›οΈ"
46
  ui_subtitle: str = "Toolkit for evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA)."
47
  ui_show_sys_info: bool = True
llama_lora/lib/finetune.py CHANGED
@@ -2,6 +2,8 @@ import os
2
  import sys
3
  from typing import Any, List
4
 
 
 
5
  import fire
6
  import torch
7
  import transformers
@@ -47,6 +49,10 @@ def train(
47
  # logging
48
  callbacks: List[Any] = []
49
  ):
 
 
 
 
50
  device_map = "auto"
51
  world_size = int(os.environ.get("WORLD_SIZE", 1))
52
  ddp = world_size != 1
@@ -202,6 +208,12 @@ def train(
202
  ),
203
  callbacks=callbacks,
204
  )
 
 
 
 
 
 
205
  model.config.use_cache = False
206
 
207
  old_state_dict = model.state_dict
@@ -214,9 +226,16 @@ def train(
214
  if torch.__version__ >= "2" and sys.platform != "win32":
215
  model = torch.compile(model)
216
 
217
- result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)
218
 
219
  model.save_pretrained(output_dir)
220
  print(f"Model saved to {output_dir}.")
221
 
222
- return result
 
 
 
 
 
 
 
 
2
  import sys
3
  from typing import Any, List
4
 
5
+ import json
6
+
7
  import fire
8
  import torch
9
  import transformers
 
49
  # logging
50
  callbacks: List[Any] = []
51
  ):
52
+ if os.path.exists(output_dir):
53
+ if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
54
+ raise ValueError(f"The output directory already exists and is not empty. ({output_dir})")
55
+
56
  device_map = "auto"
57
  world_size = int(os.environ.get("WORLD_SIZE", 1))
58
  ddp = world_size != 1
 
208
  ),
209
  callbacks=callbacks,
210
  )
211
+
212
+ if not os.path.exists(output_dir):
213
+ os.makedirs(output_dir)
214
+ with open(os.path.join(output_dir, "trainer_args.json"), 'w') as trainer_args_json_file:
215
+ json.dump(trainer.args.to_dict(), trainer_args_json_file, indent=2)
216
+
217
  model.config.use_cache = False
218
 
219
  old_state_dict = model.state_dict
 
226
  if torch.__version__ >= "2" and sys.platform != "win32":
227
  model = torch.compile(model)
228
 
229
+ train_output = trainer.train(resume_from_checkpoint=resume_from_checkpoint)
230
 
231
  model.save_pretrained(output_dir)
232
  print(f"Model saved to {output_dir}.")
233
 
234
+ with open(os.path.join(output_dir, "trainer_log_history.jsonl"), 'w') as trainer_log_history_jsonl_file:
235
+ trainer_log_history = "\n".join([json.dumps(line) for line in trainer.state.log_history])
236
+ trainer_log_history_jsonl_file.write(trainer_log_history)
237
+
238
+ with open(os.path.join(output_dir, "train_output.json"), 'w') as train_output_json_file:
239
+ json.dump(train_output, train_output_json_file, indent=2)
240
+
241
+ return train_output
llama_lora/models.py CHANGED
@@ -31,27 +31,32 @@ def get_base_model():
31
  return Global.loaded_base_model
32
 
33
 
34
- def get_model_with_lora(lora_weights: str = "tloen/alpaca-lora-7b"):
35
  Global.model_has_been_used = True
36
 
 
 
 
 
 
37
  if device == "cuda":
38
  model = PeftModel.from_pretrained(
39
  get_base_model(),
40
- lora_weights,
41
  torch_dtype=torch.float16,
42
  device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
43
  )
44
  elif device == "mps":
45
  model = PeftModel.from_pretrained(
46
  get_base_model(),
47
- lora_weights,
48
  device_map={"": device},
49
  torch_dtype=torch.float16,
50
  )
51
  else:
52
  model = PeftModel.from_pretrained(
53
  get_base_model(),
54
- lora_weights,
55
  device_map={"": device},
56
  )
57
 
@@ -65,6 +70,10 @@ def get_model_with_lora(lora_weights: str = "tloen/alpaca-lora-7b"):
65
  model.eval()
66
  if torch.__version__ >= "2" and sys.platform != "win32":
67
  model = torch.compile(model)
 
 
 
 
68
  return model
69
 
70
 
@@ -121,6 +130,8 @@ def unload_models():
121
  del Global.loaded_tokenizer
122
  Global.loaded_tokenizer = None
123
 
 
 
124
  clear_cache()
125
 
126
  Global.model_has_been_used = False
 
31
  return Global.loaded_base_model
32
 
33
 
34
+ def get_model_with_lora(lora_weights_name_or_path: str = "tloen/alpaca-lora-7b"):
35
  Global.model_has_been_used = True
36
 
37
+ if Global.cached_lora_models:
38
+ model_from_cache = Global.cached_lora_models.get(lora_weights_name_or_path)
39
+ if model_from_cache:
40
+ return model_from_cache
41
+
42
  if device == "cuda":
43
  model = PeftModel.from_pretrained(
44
  get_base_model(),
45
+ lora_weights_name_or_path,
46
  torch_dtype=torch.float16,
47
  device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
48
  )
49
  elif device == "mps":
50
  model = PeftModel.from_pretrained(
51
  get_base_model(),
52
+ lora_weights_name_or_path,
53
  device_map={"": device},
54
  torch_dtype=torch.float16,
55
  )
56
  else:
57
  model = PeftModel.from_pretrained(
58
  get_base_model(),
59
+ lora_weights_name_or_path,
60
  device_map={"": device},
61
  )
62
 
 
70
  model.eval()
71
  if torch.__version__ >= "2" and sys.platform != "win32":
72
  model = torch.compile(model)
73
+
74
+ if Global.cached_lora_models:
75
+ Global.cached_lora_models.set(lora_weights_name_or_path, model)
76
+
77
  return model
78
 
79
 
 
130
  del Global.loaded_tokenizer
131
  Global.loaded_tokenizer = None
132
 
133
+ Global.cached_lora_models.clear()
134
+
135
  clear_cache()
136
 
137
  Global.model_has_been_used = False
llama_lora/ui/finetune_ui.py CHANGED
@@ -269,6 +269,9 @@ def do_train(
269
  progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
270
  ):
271
  try:
 
 
 
272
  clear_cache()
273
  # If model has been used in inference, we need to unload it first.
274
  # Otherwise, we'll get a 'Function MmBackward0 returned an invalid
@@ -373,6 +376,9 @@ Train data (first 10):
373
  time.sleep(2)
374
  return message
375
 
 
 
 
376
  log_history = []
377
 
378
  class UiTrainerCallback(TrainerCallback):
@@ -419,11 +425,30 @@ Train data (first 10):
419
  # Do not let other tqdm iterations interfere the progress reporting after training starts.
420
  # progress.track_tqdm = False # setting this dynamically is not working, determining if track_tqdm should be enabled based on GPU cores at start instead.
421
 
422
- results = Global.train_fn(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
  base_model, # base_model
424
  tokenizer, # tokenizer
425
- os.path.join(Global.data_dir, "lora_models",
426
- model_name), # output_dir
427
  train_data,
428
  # 128, # batch_size (is not used, use gradient_accumulation_steps instead)
429
  micro_batch_size, # micro_batch_size
@@ -445,12 +470,13 @@ Train data (first 10):
445
  logs_str = "\n".join([json.dumps(log)
446
  for log in log_history]) or "None"
447
 
448
- result_message = f"Training ended:\n{str(results)}\n\nLogs:\n{logs_str}"
449
  print(result_message)
 
450
  return result_message
451
 
452
  except Exception as e:
453
- raise gr.Error(e)
454
 
455
 
456
  def do_abort_training():
@@ -675,9 +701,9 @@ def finetune_ui():
675
  elem_id="finetune_confirm_stop_btn"
676
  )
677
 
678
- training_status = gr.Text(
679
- "Training status will be shown here.",
680
- label="Training Status/Results",
681
  elem_id="finetune_training_status")
682
 
683
  train_progress = train_btn.click(
@@ -693,7 +719,7 @@ def finetune_ui():
693
  lora_dropout,
694
  model_name
695
  ]),
696
- outputs=training_status
697
  )
698
 
699
  # controlled by JS, shows the confirm_abort_button
 
269
  progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
270
  ):
271
  try:
272
+ if not should_training_progress_track_tqdm:
273
+ progress(0, desc="Preparing train data...")
274
+
275
  clear_cache()
276
  # If model has been used in inference, we need to unload it first.
277
  # Otherwise, we'll get a 'Function MmBackward0 returned an invalid
 
376
  time.sleep(2)
377
  return message
378
 
379
+ if not should_training_progress_track_tqdm:
380
+ progress(0, desc="Preparing model for training...")
381
+
382
  log_history = []
383
 
384
  class UiTrainerCallback(TrainerCallback):
 
425
  # Do not let other tqdm iterations interfere the progress reporting after training starts.
426
  # progress.track_tqdm = False # setting this dynamically is not working, determining if track_tqdm should be enabled based on GPU cores at start instead.
427
 
428
+ output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
429
+ if not os.path.exists(output_dir):
430
+ os.makedirs(output_dir)
431
+
432
+ with open(os.path.join(output_dir, "info.json"), 'w') as info_json_file:
433
+ dataset_name = "N/A (from text input)"
434
+ if load_dataset_from == "Data Dir":
435
+ dataset_name = dataset_from_data_dir
436
+
437
+ info = {
438
+ 'base_model': Global.base_model,
439
+ 'prompt_template': template,
440
+ 'dataset_name': dataset_name,
441
+ 'dataset_rows': len(train_data),
442
+ }
443
+ json.dump(info, info_json_file, indent=2)
444
+
445
+ if not should_training_progress_track_tqdm:
446
+ progress(0, desc="Train starting...")
447
+
448
+ train_output = Global.train_fn(
449
  base_model, # base_model
450
  tokenizer, # tokenizer
451
+ output_dir, # output_dir
 
452
  train_data,
453
  # 128, # batch_size (is not used, use gradient_accumulation_steps instead)
454
  micro_batch_size, # micro_batch_size
 
470
  logs_str = "\n".join([json.dumps(log)
471
  for log in log_history]) or "None"
472
 
473
+ result_message = f"Training ended:\n{str(train_output)}\n\nLogs:\n{logs_str}"
474
  print(result_message)
475
+ clear_cache()
476
  return result_message
477
 
478
  except Exception as e:
479
+ raise gr.Error(f"{e} (To dismiss this error, click the 'Abort' button)")
480
 
481
 
482
  def do_abort_training():
 
701
  elem_id="finetune_confirm_stop_btn"
702
  )
703
 
704
+ train_output = gr.Text(
705
+ "Training results will be shown here.",
706
+ label="Train Output",
707
  elem_id="finetune_training_status")
708
 
709
  train_progress = train_btn.click(
 
719
  lora_dropout,
720
  model_name
721
  ]),
722
+ outputs=train_output
723
  )
724
 
725
  # controlled by JS, shows the confirm_abort_button
llama_lora/ui/inference_ui.py CHANGED
@@ -11,13 +11,15 @@ from ..models import get_base_model, get_model_with_lora, get_tokenizer, get_dev
11
  from ..utils.data import (
12
  get_available_template_names,
13
  get_available_lora_model_names,
14
- get_path_of_available_lora_model)
 
15
  from ..utils.prompter import Prompter
16
  from ..utils.callbacks import Iteratorize, Stream
17
 
18
  device = get_device()
19
 
20
  default_show_raw = True
 
21
 
22
 
23
  def do_inference(
@@ -36,12 +38,23 @@ def do_inference(
36
  progress=gr.Progress(track_tqdm=True),
37
  ):
38
  try:
 
 
 
 
 
 
 
 
 
39
  variables = [variable_0, variable_1, variable_2, variable_3,
40
  variable_4, variable_5, variable_6, variable_7]
41
  prompter = Prompter(prompt_template)
42
  prompt = prompter.generate_prompt(variables)
43
 
44
- if lora_model_name is not None and "/" not in lora_model_name and lora_model_name != "None":
 
 
45
  path_of_available_lora_model = get_path_of_available_lora_model(
46
  lora_model_name)
47
  if path_of_available_lora_model:
@@ -66,16 +79,24 @@ def do_inference(
66
  yield out
67
 
68
  for partial_sentence in word_generator(message):
69
- yield partial_sentence, json.dumps(list(range(len(partial_sentence.split()))), indent=2)
 
 
 
 
 
70
  time.sleep(0.05)
71
 
72
  return
73
  time.sleep(1)
74
- yield message, json.dumps(list(range(len(message.split()))), indent=2)
 
 
 
75
  return
76
 
77
  model = get_base_model()
78
- if not lora_model_name == "None" and lora_model_name is not None:
79
  model = get_model_with_lora(lora_model_name)
80
  tokenizer = get_tokenizer()
81
 
@@ -97,6 +118,19 @@ def do_inference(
97
  "max_new_tokens": max_new_tokens,
98
  }
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  if stream_output:
101
  # Stream the reply 1 token at a time.
102
  # This is based on the trick of using 'stopping_criteria' to create an iterator,
@@ -128,29 +162,61 @@ def do_inference(
128
  raw_output = None
129
  if show_raw:
130
  raw_output = str(output)
131
- yield prompter.get_response(decoded_output), raw_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  return # early return for stream_output
133
 
134
  # Without streaming
135
  with torch.no_grad():
136
- generation_output = model.generate(
137
- input_ids=input_ids,
138
- generation_config=generation_config,
139
- return_dict_in_generate=True,
140
- output_scores=True,
141
- max_new_tokens=max_new_tokens,
142
- )
143
  s = generation_output.sequences[0]
144
  output = tokenizer.decode(s)
145
  raw_output = None
146
  if show_raw:
147
  raw_output = str(s)
148
- yield prompter.get_response(output), raw_output
 
 
 
 
 
 
 
 
149
 
150
  except Exception as e:
151
  raise gr.Error(e)
152
 
153
 
 
 
 
 
 
154
  def reload_selections(current_lora_model, current_prompt_template):
155
  available_template_names = get_available_template_names()
156
  available_template_names_with_none = available_template_names + ["None"]
@@ -172,7 +238,7 @@ def reload_selections(current_lora_model, current_prompt_template):
172
  gr.Dropdown.update(choices=available_template_names_with_none, value=current_prompt_template))
173
 
174
 
175
- def handle_prompt_template_change(prompt_template):
176
  prompter = Prompter(prompt_template)
177
  var_names = prompter.get_variable_names()
178
  human_var_names = [' '.join(word.capitalize()
@@ -182,7 +248,36 @@ def handle_prompt_template_change(prompt_template):
182
  while len(gr_updates) < 8:
183
  gr_updates.append(gr.Textbox.update(
184
  label="Not Used", visible=False))
185
- return gr_updates
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
 
188
  def update_prompt_preview(prompt_template,
@@ -200,12 +295,15 @@ def inference_ui():
200
 
201
  with gr.Blocks() as inference_ui_blocks:
202
  with gr.Row():
203
- lora_model = gr.Dropdown(
204
- label="LoRA Model",
205
- elem_id="inference_lora_model",
206
- value="tloen/alpaca-lora-7b",
207
- allow_custom_value=True,
208
- )
 
 
 
209
  prompt_template = gr.Dropdown(
210
  label="Prompt Template",
211
  elem_id="inference_prompt_template",
@@ -318,7 +416,7 @@ def inference_ui():
318
  with gr.Column(elem_id="inference_output_group_container"):
319
  with gr.Column(elem_id="inference_output_group"):
320
  inference_output = gr.Textbox(
321
- lines=12, label="Output", elem_id="inference_output")
322
  inference_output.style(show_copy_button=True)
323
  with gr.Accordion(
324
  "Raw Output",
@@ -346,10 +444,20 @@ def inference_ui():
346
  )
347
  things_that_might_timeout.append(reload_selections_event)
348
 
349
- prompt_template_change_event = prompt_template.change(fn=handle_prompt_template_change, inputs=[prompt_template], outputs=[
350
- variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
 
 
 
 
351
  things_that_might_timeout.append(prompt_template_change_event)
352
 
 
 
 
 
 
 
353
  generate_event = generate_btn.click(
354
  fn=do_inference,
355
  inputs=[
@@ -369,8 +477,12 @@ def inference_ui():
369
  outputs=[inference_output, inference_raw_output],
370
  api_name="inference"
371
  )
372
- stop_btn.click(fn=None, inputs=None, outputs=None,
373
- cancels=[generate_event])
 
 
 
 
374
 
375
  update_prompt_preview_event = update_prompt_preview_btn.click(fn=update_prompt_preview, inputs=[prompt_template,
376
  variable_0, variable_1, variable_2, variable_3,
@@ -543,9 +655,15 @@ def inference_ui():
543
  return function (...args) {
544
  const context = this;
545
  clearTimeout(timeout);
546
- timeout = setTimeout(() => {
 
 
 
 
 
547
  func.apply(context, args);
548
- }, wait);
 
549
  };
550
  }
551
 
@@ -580,5 +698,27 @@ def inference_ui():
580
  });
581
  }
582
  }, 100);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
583
  }
584
  """)
 
11
  from ..utils.data import (
12
  get_available_template_names,
13
  get_available_lora_model_names,
14
+ get_path_of_available_lora_model,
15
+ get_info_of_available_lora_model)
16
  from ..utils.prompter import Prompter
17
  from ..utils.callbacks import Iteratorize, Stream
18
 
19
  device = get_device()
20
 
21
  default_show_raw = True
22
+ inference_output_lines = 12
23
 
24
 
25
  def do_inference(
 
38
  progress=gr.Progress(track_tqdm=True),
39
  ):
40
  try:
41
+ if Global.generation_force_stopped_at is not None:
42
+ required_elapsed_time_after_forced_stop = 1
43
+ current_unix_time = time.time()
44
+ remaining_time = required_elapsed_time_after_forced_stop - \
45
+ (current_unix_time - Global.generation_force_stopped_at)
46
+ if remaining_time > 0:
47
+ time.sleep(remaining_time)
48
+ Global.generation_force_stopped_at = None
49
+
50
  variables = [variable_0, variable_1, variable_2, variable_3,
51
  variable_4, variable_5, variable_6, variable_7]
52
  prompter = Prompter(prompt_template)
53
  prompt = prompter.generate_prompt(variables)
54
 
55
+ if not lora_model_name:
56
+ lora_model_name = "None"
57
+ if "/" not in lora_model_name and lora_model_name != "None":
58
  path_of_available_lora_model = get_path_of_available_lora_model(
59
  lora_model_name)
60
  if path_of_available_lora_model:
 
79
  yield out
80
 
81
  for partial_sentence in word_generator(message):
82
+ yield (
83
+ gr.Textbox.update(
84
+ value=partial_sentence, lines=inference_output_lines),
85
+ json.dumps(
86
+ list(range(len(partial_sentence.split()))), indent=2)
87
+ )
88
  time.sleep(0.05)
89
 
90
  return
91
  time.sleep(1)
92
+ yield (
93
+ gr.Textbox.update(value=message, lines=1), # TODO
94
+ json.dumps(list(range(len(message.split()))), indent=2)
95
+ )
96
  return
97
 
98
  model = get_base_model()
99
+ if lora_model_name != "None":
100
  model = get_model_with_lora(lora_model_name)
101
  tokenizer = get_tokenizer()
102
 
 
118
  "max_new_tokens": max_new_tokens,
119
  }
120
 
121
+ def ui_generation_stopping_criteria(input_ids, score, **kwargs):
122
+ if Global.should_stop_generating:
123
+ return True
124
+ return False
125
+
126
+ Global.should_stop_generating = False
127
+ generate_params.setdefault(
128
+ "stopping_criteria", transformers.StoppingCriteriaList()
129
+ )
130
+ generate_params["stopping_criteria"].append(
131
+ ui_generation_stopping_criteria
132
+ )
133
+
134
  if stream_output:
135
  # Stream the reply 1 token at a time.
136
  # This is based on the trick of using 'stopping_criteria' to create an iterator,
 
162
  raw_output = None
163
  if show_raw:
164
  raw_output = str(output)
165
+ response = prompter.get_response(decoded_output)
166
+
167
+ if Global.should_stop_generating:
168
+ return
169
+
170
+ yield (
171
+ gr.Textbox.update(
172
+ value=response, lines=inference_output_lines),
173
+ raw_output)
174
+
175
+ if Global.should_stop_generating:
176
+ # If the user stops the generation, and then clicks the
177
+ # generation button again, they may mysteriously landed
178
+ # here, in the previous, should-be-stopped generation
179
+ # function call, with the new generation function not be
180
+ # called at all. To workaround this, we yield a message
181
+ # and setting lines=1, and if the front-end JS detects
182
+ # that lines has been set to 1 (rows="1" in HTML),
183
+ # it will automatically click the generate button again
184
+ # (gr.Textbox.update() does not support updating
185
+ # elem_classes or elem_id).
186
+ # [WORKAROUND-UI01]
187
+ yield (
188
+ gr.Textbox.update(
189
+ value="Please retry", lines=1),
190
+ None)
191
  return # early return for stream_output
192
 
193
  # Without streaming
194
  with torch.no_grad():
195
+ generation_output = model.generate(**generate_params)
 
 
 
 
 
 
196
  s = generation_output.sequences[0]
197
  output = tokenizer.decode(s)
198
  raw_output = None
199
  if show_raw:
200
  raw_output = str(s)
201
+
202
+ response = prompter.get_response(output)
203
+ if Global.should_stop_generating:
204
+ return
205
+
206
+ yield (
207
+ gr.Textbox.update(value=response, lines=inference_output_lines),
208
+ raw_output)
209
+
210
 
211
  except Exception as e:
212
  raise gr.Error(e)
213
 
214
 
215
+ def handle_stop_generate():
216
+ Global.generation_force_stopped_at = time.time()
217
+ Global.should_stop_generating = True
218
+
219
+
220
  def reload_selections(current_lora_model, current_prompt_template):
221
  available_template_names = get_available_template_names()
222
  available_template_names_with_none = available_template_names + ["None"]
 
238
  gr.Dropdown.update(choices=available_template_names_with_none, value=current_prompt_template))
239
 
240
 
241
+ def handle_prompt_template_change(prompt_template, lora_model):
242
  prompter = Prompter(prompt_template)
243
  var_names = prompter.get_variable_names()
244
  human_var_names = [' '.join(word.capitalize()
 
248
  while len(gr_updates) < 8:
249
  gr_updates.append(gr.Textbox.update(
250
  label="Not Used", visible=False))
251
+
252
+ model_prompt_template_message_update = gr.Markdown.update(
253
+ "", visible=False)
254
+ lora_mode_info = get_info_of_available_lora_model(lora_model)
255
+ if lora_mode_info and isinstance(lora_mode_info, dict):
256
+ model_prompt_template = lora_mode_info.get("prompt_template")
257
+ if model_prompt_template and model_prompt_template != prompt_template:
258
+ model_prompt_template_message_update = gr.Markdown.update(
259
+ f"Trained with prompt template `{model_prompt_template}`", visible=True)
260
+
261
+ return [model_prompt_template_message_update] + gr_updates
262
+
263
+
264
+ def handle_lora_model_change(lora_model, prompt_template):
265
+ lora_mode_info = get_info_of_available_lora_model(lora_model)
266
+ if not lora_mode_info:
267
+ return gr.Markdown.update("", visible=False), prompt_template
268
+
269
+ if not isinstance(lora_mode_info, dict):
270
+ return gr.Markdown.update("", visible=False), prompt_template
271
+
272
+ model_prompt_template = lora_mode_info.get("prompt_template")
273
+ if not model_prompt_template:
274
+ return gr.Markdown.update("", visible=False), prompt_template
275
+
276
+ available_template_names = get_available_template_names()
277
+ if model_prompt_template in available_template_names:
278
+ return gr.Markdown.update("", visible=False), model_prompt_template
279
+
280
+ return gr.Markdown.update(f"Trained with prompt template `{model_prompt_template}`", visible=True), prompt_template
281
 
282
 
283
  def update_prompt_preview(prompt_template,
 
295
 
296
  with gr.Blocks() as inference_ui_blocks:
297
  with gr.Row():
298
+ with gr.Column(elem_id="inference_lora_model_group"):
299
+ model_prompt_template_message = gr.Markdown(
300
+ "", visible=False, elem_id="inference_lora_model_prompt_template_message")
301
+ lora_model = gr.Dropdown(
302
+ label="LoRA Model",
303
+ elem_id="inference_lora_model",
304
+ value="tloen/alpaca-lora-7b",
305
+ allow_custom_value=True,
306
+ )
307
  prompt_template = gr.Dropdown(
308
  label="Prompt Template",
309
  elem_id="inference_prompt_template",
 
416
  with gr.Column(elem_id="inference_output_group_container"):
417
  with gr.Column(elem_id="inference_output_group"):
418
  inference_output = gr.Textbox(
419
+ lines=inference_output_lines, label="Output", elem_id="inference_output")
420
  inference_output.style(show_copy_button=True)
421
  with gr.Accordion(
422
  "Raw Output",
 
444
  )
445
  things_that_might_timeout.append(reload_selections_event)
446
 
447
+ prompt_template_change_event = prompt_template.change(
448
+ fn=handle_prompt_template_change,
449
+ inputs=[prompt_template, lora_model],
450
+ outputs=[
451
+ model_prompt_template_message,
452
+ variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
453
  things_that_might_timeout.append(prompt_template_change_event)
454
 
455
+ lora_model_change_event = lora_model.change(
456
+ fn=handle_lora_model_change,
457
+ inputs=[lora_model, prompt_template],
458
+ outputs=[model_prompt_template_message, prompt_template])
459
+ things_that_might_timeout.append(lora_model_change_event)
460
+
461
  generate_event = generate_btn.click(
462
  fn=do_inference,
463
  inputs=[
 
477
  outputs=[inference_output, inference_raw_output],
478
  api_name="inference"
479
  )
480
+ stop_btn.click(
481
+ fn=handle_stop_generate,
482
+ inputs=None,
483
+ outputs=None,
484
+ cancels=[generate_event]
485
+ )
486
 
487
  update_prompt_preview_event = update_prompt_preview_btn.click(fn=update_prompt_preview, inputs=[prompt_template,
488
  variable_0, variable_1, variable_2, variable_3,
 
655
  return function (...args) {
656
  const context = this;
657
  clearTimeout(timeout);
658
+ const fn = () => {
659
+ if (document.querySelector('#inference_preview_prompt > .wrap:not(.hide)')) {
660
+ // Preview request is still loading, wait for 10ms and try again.
661
+ timeout = setTimeout(fn, 10);
662
+ return;
663
+ }
664
  func.apply(context, args);
665
+ };
666
+ timeout = setTimeout(fn, wait);
667
  };
668
  }
669
 
 
698
  });
699
  }
700
  }, 100);
701
+
702
+ // [WORKAROUND-UI01]
703
+ setTimeout(function () {
704
+ const inference_output_textarea = document.querySelector(
705
+ '#inference_output textarea'
706
+ );
707
+ if (!inference_output_textarea) return;
708
+ const observer = new MutationObserver(function () {
709
+ if (inference_output_textarea.getAttribute('rows') === '1') {
710
+ setTimeout(function () {
711
+ const inference_generate_btn = document.getElementById(
712
+ 'inference_generate_btn'
713
+ );
714
+ if (inference_generate_btn) inference_generate_btn.click();
715
+ }, 10);
716
+ }
717
+ });
718
+ observer.observe(inference_output_textarea, {
719
+ attributes: true,
720
+ attributeFilter: ['rows'],
721
+ });
722
+ }, 100);
723
  }
724
  """)
llama_lora/ui/main_page.py CHANGED
@@ -30,7 +30,7 @@ def main_page():
30
  tokenizer_ui()
31
  info = []
32
  if Global.version:
33
- info.append(f"LLaMA-LoRA `{Global.version}`")
34
  info.append(f"Base model: `{Global.base_model}`")
35
  if Global.ui_show_sys_info:
36
  info.append(f"Data dir: `{Global.data_dir}`")
@@ -134,6 +134,41 @@ def main_page_custom_css():
134
  /* text-transform: uppercase; */
135
  }
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  #inference_prompt_box > *:first-child {
138
  border-bottom-left-radius: 0;
139
  border-bottom-right-radius: 0;
@@ -266,12 +301,16 @@ def main_page_custom_css():
266
  }
267
 
268
  @media screen and (min-width: 640px) {
269
- #inference_lora_model, #finetune_template {
 
270
  border-top-right-radius: 0;
271
  border-bottom-right-radius: 0;
272
  border-right: 0;
273
  margin-right: -16px;
274
  }
 
 
 
275
 
276
  #inference_prompt_template {
277
  border-top-left-radius: 0;
@@ -301,7 +340,7 @@ def main_page_custom_css():
301
  height: 42px !important;
302
  min-width: 42px !important;
303
  width: 42px !important;
304
- z-index: 1;
305
  }
306
  }
307
 
 
30
  tokenizer_ui()
31
  info = []
32
  if Global.version:
33
+ info.append(f"LLaMA-LoRA Tuner `{Global.version}`")
34
  info.append(f"Base model: `{Global.base_model}`")
35
  if Global.ui_show_sys_info:
36
  info.append(f"Data dir: `{Global.data_dir}`")
 
134
  /* text-transform: uppercase; */
135
  }
136
 
137
+ #inference_lora_model_group {
138
+ border-radius: var(--block-radius);
139
+ background: var(--block-background-fill);
140
+ }
141
+ #inference_lora_model_group #inference_lora_model {
142
+ background: transparent;
143
+ }
144
+ #inference_lora_model_prompt_template_message:not(.hidden) + #inference_lora_model {
145
+ padding-bottom: 28px;
146
+ }
147
+ #inference_lora_model_group > #inference_lora_model_prompt_template_message {
148
+ position: absolute;
149
+ bottom: 8px;
150
+ left: 20px;
151
+ z-index: 1;
152
+ font-size: 12px;
153
+ opacity: 0.7;
154
+ }
155
+ #inference_lora_model_group > #inference_lora_model_prompt_template_message p {
156
+ font-size: 12px;
157
+ }
158
+ #inference_lora_model_prompt_template_message > .wrap {
159
+ display: none;
160
+ }
161
+ #inference_lora_model > .wrap:first-child:not(.hide),
162
+ #inference_prompt_template > .wrap:first-child:not(.hide) {
163
+ opacity: 0.5;
164
+ }
165
+ #inference_lora_model_group, #inference_lora_model {
166
+ z-index: 60;
167
+ }
168
+ #inference_prompt_template {
169
+ z-index: 55;
170
+ }
171
+
172
  #inference_prompt_box > *:first-child {
173
  border-bottom-left-radius: 0;
174
  border-bottom-right-radius: 0;
 
301
  }
302
 
303
  @media screen and (min-width: 640px) {
304
+ #inference_lora_model, #inference_lora_model_group,
305
+ #finetune_template {
306
  border-top-right-radius: 0;
307
  border-bottom-right-radius: 0;
308
  border-right: 0;
309
  margin-right: -16px;
310
  }
311
+ #inference_lora_model_group #inference_lora_model {
312
+ box-shadow: var(--block-shadow);
313
+ }
314
 
315
  #inference_prompt_template {
316
  border-top-left-radius: 0;
 
340
  height: 42px !important;
341
  min-width: 42px !important;
342
  width: 42px !important;
343
+ z-index: 61;
344
  }
345
  }
346
 
llama_lora/utils/data.py CHANGED
@@ -52,6 +52,22 @@ def get_path_of_available_lora_model(name):
52
  return None
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def get_dataset_content(name):
56
  file_name = os.path.join(Global.data_dir, "datasets", name)
57
  if not os.path.exists(file_name):
 
52
  return None
53
 
54
 
55
+ def get_info_of_available_lora_model(name):
56
+ try:
57
+ if "/" in name:
58
+ return None
59
+ path_of_available_lora_model = get_path_of_available_lora_model(
60
+ name)
61
+ if not path_of_available_lora_model:
62
+ return None
63
+
64
+ with open(os.path.join(path_of_available_lora_model, "info.json"), "r") as json_file:
65
+ return json.load(json_file)
66
+
67
+ except Exception as e:
68
+ return None
69
+
70
+
71
  def get_dataset_content(name):
72
  file_name = os.path.join(Global.data_dir, "datasets", name)
73
  if not os.path.exists(file_name):
llama_lora/utils/lru_cache.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+
4
+ class LRUCache:
5
+ def __init__(self, capacity=5):
6
+ self.cache = OrderedDict()
7
+ self.capacity = capacity
8
+
9
+ def get(self, key):
10
+ if key in self.cache:
11
+ # Move the accessed item to the end of the OrderedDict
12
+ self.cache.move_to_end(key)
13
+ return self.cache[key]
14
+ return None
15
+
16
+ def set(self, key, value):
17
+ if key in self.cache:
18
+ # If the key already exists, update its value
19
+ self.cache[key] = value
20
+ else:
21
+ # If the cache has reached its capacity, remove the least recently used item
22
+ if len(self.cache) >= self.capacity:
23
+ self.cache.popitem(last=False)
24
+ self.cache[key] = value
25
+
26
+ def clear(self):
27
+ self.cache.clear()
requirements.lock.txt CHANGED
@@ -65,10 +65,10 @@ packaging==23.0
65
  pandas==2.0.0
66
  parso==0.8.3
67
  pathspec==0.11.1
68
- peft @ git+https://github.com/huggingface/peft.git@deff03f2c251534fffd2511fc2d440e84cc54b1b
69
  pexpect==4.8.0
70
  pickleshare==0.7.5
71
- Pillow==9.5.0
72
  pkgutil_resolve_name==1.3.10
73
  platformdirs==3.2.0
74
  pluggy==1.0.0
 
65
  pandas==2.0.0
66
  parso==0.8.3
67
  pathspec==0.11.1
68
+ peft @ git+https://github.com/huggingface/peft.git@382b178911edff38c1ff619bbac2ba556bd2276b
69
  pexpect==4.8.0
70
  pickleshare==0.7.5
71
+ Pillow==9.3.0
72
  pkgutil_resolve_name==1.3.10
73
  platformdirs==3.2.0
74
  pluggy==1.0.0
templates/user_and_ai.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "description": "Unhelpful AI assistant.",
3
+ "variables": ["instruction"],
4
+ "prompt": "### User:\n{instruction}\n\n### AI:\n",
5
+ "default": "prompt",
6
+ "response_split": "### AI:"
7
+ }