Merge branch 'dev-2'
- LLaMA_LoRA.ipynb +5 -5
- README.md +10 -10
- app.py +5 -0
- llama_lora/globals.py +7 -1
- llama_lora/lib/finetune.py +21 -2
- llama_lora/models.py +15 -4
- llama_lora/ui/finetune_ui.py +35 -9
- llama_lora/ui/inference_ui.py +169 -29
- llama_lora/ui/main_page.py +42 -3
- llama_lora/utils/data.py +16 -0
- llama_lora/utils/lru_cache.py +27 -0
- requirements.lock.txt +2 -2
- templates/user_and_ai.json +7 -0

LLaMA_LoRA.ipynb CHANGED

@@ -27,13 +27,13 @@
  "colab_type": "text"
  },
  "source": [
- "<a href=\"https://colab.research.google.com/github/zetavg/LLaMA-LoRA/blob/main/LLaMA_LoRA.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+ "<a href=\"https://colab.research.google.com/github/zetavg/LLaMA-LoRA-Tuner/blob/main/LLaMA_LoRA.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
  ]
  },
  {
  "cell_type": "markdown",
  "source": [
- "# 🦙🎛️ LLaMA-LoRA\n",
+ "# 🦙🎛️ LLaMA-LoRA Tuner\n",
  "\n",
  "TL;DR: **Runtime > Run All** (`⌘/Ctrl+F9`). Takes about 5 minutes to start. You will be promped to authorize Google Drive access."
  ],
@@ -72,9 +72,9 @@
  "# @title Git/Project { display-mode: \"form\", run: \"auto\" }\n",
  "# @markdown Project settings.\n",
  "\n",
- "# @markdown The URL of the LLaMA-LoRA project<br> (default: `https://github.com/zetavg/…
- "llama_lora_project_url = \"https://github.com/zetavg/…
- "# @markdown The branch to use for LLaMA-LoRA project:\n",
+ "# @markdown The URL of the LLaMA-LoRA-Tuner project<br> (default: `https://github.com/zetavg/LLaMA-LoRA-Tuner.git`):\n",
+ "llama_lora_project_url = \"https://github.com/zetavg/LLaMA-LoRA-Tuner.git\" # @param {type:\"string\"}\n",
+ "# @markdown The branch to use for LLaMA-LoRA-Tuner project:\n",
  "llama_lora_project_branch = \"main\" # @param {type:\"string\"}\n",
  "\n",
  "# # @markdown Forces the local directory to be updated by the remote branch:\n",

README.md CHANGED

@@ -1,6 +1,6 @@
- # 🦙🎛️ LLaMA-LoRA
+ # 🦙🎛️ LLaMA-LoRA Tuner

- <a href="https://colab.research.google.com/github/zetavg/LLaMA-LoRA/blob/main/LLaMA_LoRA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
+ <a href="https://colab.research.google.com/github/zetavg/LLaMA-LoRA-Tuner/blob/main/LLaMA_LoRA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

  Making evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA) easy.

@@ -27,7 +27,7 @@ There are various ways to run this app:

  ### Run On Google Colab

- Open [this Colab Notebook](https://colab.research.google.com/github/zetavg/LLaMA-LoRA/blob/main/LLaMA_LoRA.ipynb) and select **Runtime > Run All** (`⌘/Ctrl+F9`).
+ Open [this Colab Notebook](https://colab.research.google.com/github/zetavg/LLaMA-LoRA-Tuner/blob/main/LLaMA_LoRA.ipynb) and select **Runtime > Run All** (`⌘/Ctrl+F9`).

  You will be prompted to authorize Google Drive access, as Google Drive will be used to store your data. See the "Config"/"Google Drive" section for settings and more info.

@@ -38,7 +38,7 @@ After approximately 5 minutes of running, you will see the public URL in the out
  After following the [installation guide of SkyPilot](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html), create a `.yaml` to define a task for running the app:

  ```yaml
- # llama-lora-…
+ # llama-lora-tuner.yaml

  resources:
  accelerators: A10:1 # 1x NVIDIA A10 GPU, about US$ 0.6 / hr on Lambda Cloud.
@@ -49,13 +49,13 @@ file_mounts:
  # (to store train datasets trained models)
  # See https://skypilot.readthedocs.io/en/latest/reference/storage.html for details.
  /data:
- name: llama-lora-…
+ name: llama-lora-tuner-data # Make sure this name is unique or you own this bucket. If it does not exists, SkyPilot will try to create a bucket with this name.
  store: s3 # Could be either of [s3, gcs]
  mode: MOUNT

- # Clone the LLaMA-LoRA repo and install its dependencies.
+ # Clone the LLaMA-LoRA Tuner repo and install its dependencies.
  setup: |
- git clone https://github.com/zetavg/LLaMA-LoRA.git llama_lora
+ git clone https://github.com/zetavg/LLaMA-LoRA-Tuner.git llama_lora
  cd llama_lora && pip install -r requirements.lock.txt
  cd ..
  echo 'Dependencies installed.'
@@ -69,7 +69,7 @@ run: |
  Then launch a cluster to run the task:

  ```
- sky launch -c llama-lora-…
+ sky launch -c llama-lora-tuner llama-lora-tuner.yaml
  ```

  `-c ...` is an optional flag to specify a cluster name. If not specified, SkyPilot will automatically generate one.
@@ -86,8 +86,8 @@ When you are done, run `sky stop <cluster_name>` to stop the cluster. To termina
  <summary>Prepare environment with conda</summary>

  ```bash
- conda create -y python=3.8 -n llama-lora-…
- conda activate llama-lora-…
+ conda create -y python=3.8 -n llama-lora-tuner
+ conda activate llama-lora-tuner
  ```
  </details>

app.py CHANGED

@@ -7,6 +7,7 @@ import gradio as gr
  from llama_lora.globals import Global
  from llama_lora.ui.main_page import main_page, get_page_title, main_page_custom_css
  from llama_lora.utils.data import init_data_dir
+ from llama_lora.models import load_base_model


  def main(
@@ -16,6 +17,7 @@ def main(
  # Allows to listen on all interfaces by providing '0.0.0.0'.
  server_name: str = "127.0.0.1",
  share: bool = False,
+ skip_loading_base_model: bool = False,
  ui_show_sys_info: bool = True,
  ui_dev_mode: bool = False,
  ):
@@ -39,6 +41,9 @@
  os.makedirs(data_dir, exist_ok=True)
  init_data_dir()

+ if not skip_loading_base_model:
+ load_base_model()
+
  with gr.Blocks(title=get_page_title(), css=main_page_custom_css()) as demo:
  main_page()

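Note: the new `skip_loading_base_model` flag keeps the default behaviour of loading the base model eagerly at startup while letting the UI come up without it. A minimal, self-contained sketch of that startup pattern (the `load_base_model` stub below is a stand-in for the real loader imported from `llama_lora.models`, not the project's actual code):

```python
# Sketch of the startup behaviour added to app.py: load the expensive base
# model at launch unless the caller opts out.
def load_base_model():
    print("loading base model ...")  # placeholder for the real, slow model load

def main(skip_loading_base_model: bool = False):
    if not skip_loading_base_model:
        load_base_model()
    print("launching UI ...")

main(skip_loading_base_model=True)  # the UI starts without touching the model
```
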
llama_lora/globals.py CHANGED

@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
  from numba import cuda
  import nvidia_smi

+ from .utils.lru_cache import LRUCache
  from .lib.finetune import train


@@ -25,8 +26,13 @@ class Global:
  # Training Control
  should_stop_training = False

+ # Generation Control
+ should_stop_generating = False
+ generation_force_stopped_at = None
+
  # Model related
  model_has_been_used = False
+ cached_lora_models = LRUCache(10)

  # GPU Info
  gpu_cc = None # GPU compute capability
@@ -35,7 +41,7 @@ class Global:
  gpu_total_memory = None

  # UI related
- ui_title: str = "LLaMA-LoRA"
+ ui_title: str = "LLaMA-LoRA Tuner"
  ui_emoji: str = "🦙🎛️"
  ui_subtitle: str = "Toolkit for evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA)."
  ui_show_sys_info: bool = True

llama_lora/lib/finetune.py CHANGED

@@ -2,6 +2,8 @@ import os
  import sys
  from typing import Any, List

+ import json
+
  import fire
  import torch
  import transformers
@@ -47,6 +49,10 @@ def train(
  # logging
  callbacks: List[Any] = []
  ):
+ if os.path.exists(output_dir):
+ if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
+ raise ValueError(f"The output directory already exists and is not empty. ({output_dir})")
+
  device_map = "auto"
  world_size = int(os.environ.get("WORLD_SIZE", 1))
  ddp = world_size != 1
@@ -202,6 +208,12 @@ def train(
  ),
  callbacks=callbacks,
  )
+
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ with open(os.path.join(output_dir, "trainer_args.json"), 'w') as trainer_args_json_file:
+ json.dump(trainer.args.to_dict(), trainer_args_json_file, indent=2)
+
  model.config.use_cache = False

  old_state_dict = model.state_dict
@@ -214,9 +226,16 @@ def train(
  if torch.__version__ >= "2" and sys.platform != "win32":
  model = torch.compile(model)

- …
+ train_output = trainer.train(resume_from_checkpoint=resume_from_checkpoint)

  model.save_pretrained(output_dir)
  print(f"Model saved to {output_dir}.")

- …
+ with open(os.path.join(output_dir, "trainer_log_history.jsonl"), 'w') as trainer_log_history_jsonl_file:
+ trainer_log_history = "\n".join([json.dumps(line) for line in trainer.state.log_history])
+ trainer_log_history_jsonl_file.write(trainer_log_history)
+
+ with open(os.path.join(output_dir, "train_output.json"), 'w') as train_output_json_file:
+ json.dump(train_output, train_output_json_file, indent=2)
+
+ return train_output

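Besides saving the adapter, `train()` now leaves three artifacts in `output_dir`: `trainer_args.json`, `trainer_log_history.jsonl`, and `train_output.json`. A hedged sketch of reading them back (the file names are taken from the diff above; the directory path is a made-up example):

```python
# Read back the training artifacts written by llama_lora/lib/finetune.py.
# "data/lora_models/my-model" is a hypothetical output directory.
import json
import os

output_dir = "data/lora_models/my-model"

with open(os.path.join(output_dir, "trainer_args.json")) as f:
    trainer_args = json.load(f)  # TrainingArguments serialized via .to_dict()

log_history = []
with open(os.path.join(output_dir, "trainer_log_history.jsonl")) as f:
    for line in f:
        if line.strip():
            log_history.append(json.loads(line))  # one dict per logged training step

print(trainer_args.get("learning_rate"), "steps logged:", len(log_history))
```
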
llama_lora/models.py CHANGED

@@ -31,27 +31,32 @@ def get_base_model():
  return Global.loaded_base_model


- def get_model_with_lora(lora_weights: str = "tloen/alpaca-lora-7b"):
+ def get_model_with_lora(lora_weights_name_or_path: str = "tloen/alpaca-lora-7b"):
  Global.model_has_been_used = True

+ if Global.cached_lora_models:
+ model_from_cache = Global.cached_lora_models.get(lora_weights_name_or_path)
+ if model_from_cache:
+ return model_from_cache
+
  if device == "cuda":
  model = PeftModel.from_pretrained(
  get_base_model(),
- lora_weights,
+ lora_weights_name_or_path,
  torch_dtype=torch.float16,
  device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
  )
  elif device == "mps":
  model = PeftModel.from_pretrained(
  get_base_model(),
- lora_weights,
+ lora_weights_name_or_path,
  device_map={"": device},
  torch_dtype=torch.float16,
  )
  else:
  model = PeftModel.from_pretrained(
  get_base_model(),
- lora_weights,
+ lora_weights_name_or_path,
  device_map={"": device},
  )

@@ -65,6 +70,10 @@ def get_model_with_lora(lora_weights: str = "tloen/alpaca-lora-7b"):
  model.eval()
  if torch.__version__ >= "2" and sys.platform != "win32":
  model = torch.compile(model)
+
+ if Global.cached_lora_models:
+ Global.cached_lora_models.set(lora_weights_name_or_path, model)
+
  return model


@@ -121,6 +130,8 @@ def unload_models():
  del Global.loaded_tokenizer
  Global.loaded_tokenizer = None

+ Global.cached_lora_models.clear()
+
  clear_cache()

  Global.model_has_been_used = False

llama_lora/ui/finetune_ui.py CHANGED

@@ -269,6 +269,9 @@ def do_train(
  progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
  ):
  try:
+ if not should_training_progress_track_tqdm:
+ progress(0, desc="Preparing train data...")
+
  clear_cache()
  # If model has been used in inference, we need to unload it first.
  # Otherwise, we'll get a 'Function MmBackward0 returned an invalid
@@ -373,6 +376,9 @@ Train data (first 10):
  time.sleep(2)
  return message

+ if not should_training_progress_track_tqdm:
+ progress(0, desc="Preparing model for training...")
+
  log_history = []

  class UiTrainerCallback(TrainerCallback):
@@ -419,11 +425,30 @@ Train data (first 10):
  # Do not let other tqdm iterations interfere the progress reporting after training starts.
  # progress.track_tqdm = False # setting this dynamically is not working, determining if track_tqdm should be enabled based on GPU cores at start instead.

- …
+ output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
+ with open(os.path.join(output_dir, "info.json"), 'w') as info_json_file:
+ dataset_name = "N/A (from text input)"
+ if load_dataset_from == "Data Dir":
+ dataset_name = dataset_from_data_dir
+
+ info = {
+ 'base_model': Global.base_model,
+ 'prompt_template': template,
+ 'dataset_name': dataset_name,
+ 'dataset_rows': len(train_data),
+ }
+ json.dump(info, info_json_file, indent=2)
+
+ if not should_training_progress_track_tqdm:
+ progress(0, desc="Train starting...")
+
+ train_output = Global.train_fn(
  base_model, # base_model
  tokenizer, # tokenizer
- …
- model_name), # output_dir
+ output_dir, # output_dir
  train_data,
  # 128, # batch_size (is not used, use gradient_accumulation_steps instead)
  micro_batch_size, # micro_batch_size
@@ -445,12 +470,13 @@ Train data (first 10):
  logs_str = "\n".join([json.dumps(log)
  for log in log_history]) or "None"

- result_message = f"Training ended:\n{str(…
+ result_message = f"Training ended:\n{str(train_output)}\n\nLogs:\n{logs_str}"
  print(result_message)
+ clear_cache()
  return result_message

  except Exception as e:
- raise gr.Error(e)
+ raise gr.Error(f"{e} (To dismiss this error, click the 'Abort' button)")


  def do_abort_training():
@@ -675,9 +701,9 @@ def finetune_ui():
  elem_id="finetune_confirm_stop_btn"
  )

- …
- "Training…
- label="…
+ train_output = gr.Text(
+ "Training results will be shown here.",
+ label="Train Output",
  elem_id="finetune_training_status")

  train_progress = train_btn.click(
@@ -693,7 +719,7 @@ def finetune_ui():
  lora_dropout,
  model_name
  ]),
- outputs=…
+ outputs=train_output
  )

  # controlled by JS, shows the confirm_abort_button

llama_lora/ui/inference_ui.py CHANGED

@@ -11,13 +11,15 @@ from ..models import get_base_model, get_model_with_lora, get_tokenizer, get_dev
  from ..utils.data import (
  get_available_template_names,
  get_available_lora_model_names,
- get_path_of_available_lora_model)
+ get_path_of_available_lora_model,
+ get_info_of_available_lora_model)
  from ..utils.prompter import Prompter
  from ..utils.callbacks import Iteratorize, Stream

  device = get_device()

  default_show_raw = True
+ inference_output_lines = 12


  def do_inference(
@@ -36,12 +38,23 @@ def do_inference(
  progress=gr.Progress(track_tqdm=True),
  ):
  try:
+ if Global.generation_force_stopped_at is not None:
+ required_elapsed_time_after_forced_stop = 1
+ current_unix_time = time.time()
+ remaining_time = required_elapsed_time_after_forced_stop - \
+ (current_unix_time - Global.generation_force_stopped_at)
+ if remaining_time > 0:
+ time.sleep(remaining_time)
+ Global.generation_force_stopped_at = None
+
  variables = [variable_0, variable_1, variable_2, variable_3,
  variable_4, variable_5, variable_6, variable_7]
  prompter = Prompter(prompt_template)
  prompt = prompter.generate_prompt(variables)

- if …
+ if not lora_model_name:
+ lora_model_name = "None"
+ if "/" not in lora_model_name and lora_model_name != "None":
  path_of_available_lora_model = get_path_of_available_lora_model(
  lora_model_name)
  if path_of_available_lora_model:
@@ -66,16 +79,24 @@ def do_inference(
  yield out

  for partial_sentence in word_generator(message):
- yield …
+ yield (
+ gr.Textbox.update(
+ value=partial_sentence, lines=inference_output_lines),
+ json.dumps(
+ list(range(len(partial_sentence.split()))), indent=2)
+ )
  time.sleep(0.05)

  return
  time.sleep(1)
- yield …
+ yield (
+ gr.Textbox.update(value=message, lines=1), # TODO
+ json.dumps(list(range(len(message.split()))), indent=2)
+ )
  return

  model = get_base_model()
- if …
+ if lora_model_name != "None":
  model = get_model_with_lora(lora_model_name)
  tokenizer = get_tokenizer()

@@ -97,6 +118,19 @@ def do_inference(
  "max_new_tokens": max_new_tokens,
  }

+ def ui_generation_stopping_criteria(input_ids, score, **kwargs):
+ if Global.should_stop_generating:
+ return True
+ return False
+
+ Global.should_stop_generating = False
+ generate_params.setdefault(
+ "stopping_criteria", transformers.StoppingCriteriaList()
+ )
+ generate_params["stopping_criteria"].append(
+ ui_generation_stopping_criteria
+ )
+
  if stream_output:
  # Stream the reply 1 token at a time.
  # This is based on the trick of using 'stopping_criteria' to create an iterator,
@@ -128,29 +162,61 @@ def do_inference(
  raw_output = None
  if show_raw:
  raw_output = str(output)
- …
+ response = prompter.get_response(decoded_output)
+
+ if Global.should_stop_generating:
+ return
+
+ yield (
+ gr.Textbox.update(
+ value=response, lines=inference_output_lines),
+ raw_output)
+
+ if Global.should_stop_generating:
+ # If the user stops the generation, and then clicks the
+ # generation button again, they may mysteriously landed
+ # here, in the previous, should-be-stopped generation
+ # function call, with the new generation function not be
+ # called at all. To workaround this, we yield a message
+ # and setting lines=1, and if the front-end JS detects
+ # that lines has been set to 1 (rows="1" in HTML),
+ # it will automatically click the generate button again
+ # (gr.Textbox.update() does not support updating
+ # elem_classes or elem_id).
+ # [WORKAROUND-UI01]
+ yield (
+ gr.Textbox.update(
+ value="Please retry", lines=1),
+ None)
  return # early return for stream_output

  # Without streaming
  with torch.no_grad():
- generation_output = model.generate(
- input_ids=input_ids,
- generation_config=generation_config,
- return_dict_in_generate=True,
- output_scores=True,
- max_new_tokens=max_new_tokens,
- )
+ generation_output = model.generate(**generate_params)
  s = generation_output.sequences[0]
  output = tokenizer.decode(s)
  raw_output = None
  if show_raw:
  raw_output = str(s)
- …
+
+ response = prompter.get_response(output)
+ if Global.should_stop_generating:
+ return
+
+ yield (
+ gr.Textbox.update(value=response, lines=inference_output_lines),
+ raw_output)
+

  except Exception as e:
  raise gr.Error(e)


+ def handle_stop_generate():
+ Global.generation_force_stopped_at = time.time()
+ Global.should_stop_generating = True
+
+
  def reload_selections(current_lora_model, current_prompt_template):
  available_template_names = get_available_template_names()
  available_template_names_with_none = available_template_names + ["None"]
@@ -172,7 +238,7 @@ def reload_selections(current_lora_model, current_prompt_template):
  gr.Dropdown.update(choices=available_template_names_with_none, value=current_prompt_template))


- def handle_prompt_template_change(prompt_template):
+ def handle_prompt_template_change(prompt_template, lora_model):
  prompter = Prompter(prompt_template)
  var_names = prompter.get_variable_names()
  human_var_names = [' '.join(word.capitalize()
@@ -182,7 +248,36 @@ def handle_prompt_template_change(prompt_template):
  while len(gr_updates) < 8:
  gr_updates.append(gr.Textbox.update(
  label="Not Used", visible=False))
- …
+
+ model_prompt_template_message_update = gr.Markdown.update(
+ "", visible=False)
+ lora_mode_info = get_info_of_available_lora_model(lora_model)
+ if lora_mode_info and isinstance(lora_mode_info, dict):
+ model_prompt_template = lora_mode_info.get("prompt_template")
+ if model_prompt_template and model_prompt_template != prompt_template:
+ model_prompt_template_message_update = gr.Markdown.update(
+ f"Trained with prompt template `{model_prompt_template}`", visible=True)
+
+ return [model_prompt_template_message_update] + gr_updates
+
+
+ def handle_lora_model_change(lora_model, prompt_template):
+ lora_mode_info = get_info_of_available_lora_model(lora_model)
+ if not lora_mode_info:
+ return gr.Markdown.update("", visible=False), prompt_template
+
+ if not isinstance(lora_mode_info, dict):
+ return gr.Markdown.update("", visible=False), prompt_template
+
+ model_prompt_template = lora_mode_info.get("prompt_template")
+ if not model_prompt_template:
+ return gr.Markdown.update("", visible=False), prompt_template
+
+ available_template_names = get_available_template_names()
+ if model_prompt_template in available_template_names:
+ return gr.Markdown.update("", visible=False), model_prompt_template
+
+ return gr.Markdown.update(f"Trained with prompt template `{model_prompt_template}`", visible=True), prompt_template


  def update_prompt_preview(prompt_template,
@@ -200,12 +295,15 @@ def inference_ui():

  with gr.Blocks() as inference_ui_blocks:
  with gr.Row():
- …
+ with gr.Column(elem_id="inference_lora_model_group"):
+ model_prompt_template_message = gr.Markdown(
+ "", visible=False, elem_id="inference_lora_model_prompt_template_message")
+ lora_model = gr.Dropdown(
+ label="LoRA Model",
+ elem_id="inference_lora_model",
+ value="tloen/alpaca-lora-7b",
+ allow_custom_value=True,
+ )
  prompt_template = gr.Dropdown(
  label="Prompt Template",
  elem_id="inference_prompt_template",
@@ -318,7 +416,7 @@ def inference_ui():
  with gr.Column(elem_id="inference_output_group_container"):
  with gr.Column(elem_id="inference_output_group"):
  inference_output = gr.Textbox(
- lines=…
+ lines=inference_output_lines, label="Output", elem_id="inference_output")
  inference_output.style(show_copy_button=True)
  with gr.Accordion(
  "Raw Output",
@@ -346,10 +444,20 @@ def inference_ui():
  )
  things_that_might_timeout.append(reload_selections_event)

- prompt_template_change_event = prompt_template.change(
- …
+ prompt_template_change_event = prompt_template.change(
+ fn=handle_prompt_template_change,
+ inputs=[prompt_template, lora_model],
+ outputs=[
+ model_prompt_template_message,
+ variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
  things_that_might_timeout.append(prompt_template_change_event)

+ lora_model_change_event = lora_model.change(
+ fn=handle_lora_model_change,
+ inputs=[lora_model, prompt_template],
+ outputs=[model_prompt_template_message, prompt_template])
+ things_that_might_timeout.append(lora_model_change_event)
+
  generate_event = generate_btn.click(
  fn=do_inference,
  inputs=[
@@ -369,8 +477,12 @@ def inference_ui():
  outputs=[inference_output, inference_raw_output],
  api_name="inference"
  )
- stop_btn.click(
- …
+ stop_btn.click(
+ fn=handle_stop_generate,
+ inputs=None,
+ outputs=None,
+ cancels=[generate_event]
+ )

  update_prompt_preview_event = update_prompt_preview_btn.click(fn=update_prompt_preview, inputs=[prompt_template,
  variable_0, variable_1, variable_2, variable_3,
@@ -543,9 +655,15 @@ def inference_ui():
  return function (...args) {
  const context = this;
  clearTimeout(timeout);
- …
+ const fn = () => {
+ if (document.querySelector('#inference_preview_prompt > .wrap:not(.hide)')) {
+ // Preview request is still loading, wait for 10ms and try again.
+ timeout = setTimeout(fn, 10);
+ return;
+ }
  func.apply(context, args);
- }
+ };
+ timeout = setTimeout(fn, wait);
  };
  }

@@ -580,5 +698,27 @@ def inference_ui():
  });
  }
  }, 100);
+
+ // [WORKAROUND-UI01]
+ setTimeout(function () {
+ const inference_output_textarea = document.querySelector(
+ '#inference_output textarea'
+ );
+ if (!inference_output_textarea) return;
+ const observer = new MutationObserver(function () {
+ if (inference_output_textarea.getAttribute('rows') === '1') {
+ setTimeout(function () {
+ const inference_generate_btn = document.getElementById(
+ 'inference_generate_btn'
+ );
+ if (inference_generate_btn) inference_generate_btn.click();
+ }, 10);
+ }
+ });
+ observer.observe(inference_output_textarea, {
+ attributes: true,
+ attributeFilter: ['rows'],
+ });
+ }, 100);
  }
  """)

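The stop mechanism introduced here is a global flag (`Global.should_stop_generating`) flipped by the Stop button via `handle_stop_generate()` and checked by a stopping-criteria callable appended to `generate_params["stopping_criteria"]`. A condensed, self-contained sketch of the same pattern (the `fake_generate` loop below stands in for `model.generate()`, which in the real code receives a `transformers.StoppingCriteriaList`):

```python
import time

class State:
    # stands in for Global in llama_lora.globals
    should_stop_generating = False

def ui_generation_stopping_criteria(input_ids, score, **kwargs):
    # returning True tells the generation loop to stop
    return State.should_stop_generating

def handle_stop_generate():
    State.should_stop_generating = True

def fake_generate(stopping_criteria, max_new_tokens=100):
    # pretend token-by-token generation loop that checks the criteria each step
    for step in range(max_new_tokens):
        if any(criteria(None, None) for criteria in stopping_criteria):
            return f"stopped after {step} tokens"
        time.sleep(0.01)
    return "finished"

State.should_stop_generating = False
criteria = [ui_generation_stopping_criteria]
handle_stop_generate()          # in the app this is triggered by the Stop button
print(fake_generate(criteria))  # -> "stopped after 0 tokens"
```
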
llama_lora/ui/main_page.py CHANGED

@@ -30,7 +30,7 @@ def main_page():
  tokenizer_ui()
  info = []
  if Global.version:
- info.append(f"LLaMA-LoRA `{Global.version}`")
+ info.append(f"LLaMA-LoRA Tuner `{Global.version}`")
  info.append(f"Base model: `{Global.base_model}`")
  if Global.ui_show_sys_info:
  info.append(f"Data dir: `{Global.data_dir}`")
@@ -134,6 +134,41 @@ def main_page_custom_css():
  /* text-transform: uppercase; */
  }

+ #inference_lora_model_group {
+ border-radius: var(--block-radius);
+ background: var(--block-background-fill);
+ }
+ #inference_lora_model_group #inference_lora_model {
+ background: transparent;
+ }
+ #inference_lora_model_prompt_template_message:not(.hidden) + #inference_lora_model {
+ padding-bottom: 28px;
+ }
+ #inference_lora_model_group > #inference_lora_model_prompt_template_message {
+ position: absolute;
+ bottom: 8px;
+ left: 20px;
+ z-index: 1;
+ font-size: 12px;
+ opacity: 0.7;
+ }
+ #inference_lora_model_group > #inference_lora_model_prompt_template_message p {
+ font-size: 12px;
+ }
+ #inference_lora_model_prompt_template_message > .wrap {
+ display: none;
+ }
+ #inference_lora_model > .wrap:first-child:not(.hide),
+ #inference_prompt_template > .wrap:first-child:not(.hide) {
+ opacity: 0.5;
+ }
+ #inference_lora_model_group, #inference_lora_model {
+ z-index: 60;
+ }
+ #inference_prompt_template {
+ z-index: 55;
+ }
+
  #inference_prompt_box > *:first-child {
  border-bottom-left-radius: 0;
  border-bottom-right-radius: 0;
@@ -266,12 +301,16 @@ def main_page_custom_css():
  }

  @media screen and (min-width: 640px) {
- #inference_lora_model, #…
+ #inference_lora_model, #inference_lora_model_group,
+ #finetune_template {
  border-top-right-radius: 0;
  border-bottom-right-radius: 0;
  border-right: 0;
  margin-right: -16px;
  }
+ #inference_lora_model_group #inference_lora_model {
+ box-shadow: var(--block-shadow);
+ }

  #inference_prompt_template {
  border-top-left-radius: 0;
@@ -301,7 +340,7 @@ def main_page_custom_css():
  height: 42px !important;
  min-width: 42px !important;
  width: 42px !important;
- z-index: …
+ z-index: 61;
  }
  }

llama_lora/utils/data.py CHANGED

@@ -52,6 +52,22 @@ def get_path_of_available_lora_model(name):
  return None


+ def get_info_of_available_lora_model(name):
+ try:
+ if "/" in name:
+ return None
+ path_of_available_lora_model = get_path_of_available_lora_model(
+ name)
+ if not path_of_available_lora_model:
+ return None
+
+ with open(os.path.join(path_of_available_lora_model, "info.json"), "r") as json_file:
+ return json.load(json_file)
+
+ except Exception as e:
+ return None
+
+
  def get_dataset_content(name):
  file_name = os.path.join(Global.data_dir, "datasets", name)
  if not os.path.exists(file_name):

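`get_info_of_available_lora_model()` reads the `info.json` that `finetune_ui.py` now writes next to each trained LoRA model. A sketch of that file's shape and round-trip (the keys come from the diffs above; the values and the temporary directory are made-up examples):

```python
# Write and read an info.json like the one finetune_ui.py stores per LoRA model.
import json
import os
import tempfile

lora_model_dir = tempfile.mkdtemp()  # stand-in for data/lora_models/<model_name>

info = {
    "base_model": "some-base-model",        # example values only
    "prompt_template": "user_and_ai",
    "dataset_name": "N/A (from text input)",
    "dataset_rows": 1234,
}
with open(os.path.join(lora_model_dir, "info.json"), "w") as f:
    json.dump(info, f, indent=2)

with open(os.path.join(lora_model_dir, "info.json")) as f:
    print(json.load(f)["prompt_template"])  # -> "user_and_ai"
```
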
llama_lora/utils/lru_cache.py ADDED

@@ -0,0 +1,27 @@
+ from collections import OrderedDict
+
+
+ class LRUCache:
+ def __init__(self, capacity=5):
+ self.cache = OrderedDict()
+ self.capacity = capacity
+
+ def get(self, key):
+ if key in self.cache:
+ # Move the accessed item to the end of the OrderedDict
+ self.cache.move_to_end(key)
+ return self.cache[key]
+ return None
+
+ def set(self, key, value):
+ if key in self.cache:
+ # If the key already exists, update its value
+ self.cache[key] = value
+ else:
+ # If the cache has reached its capacity, remove the least recently used item
+ if len(self.cache) >= self.capacity:
+ self.cache.popitem(last=False)
+ self.cache[key] = value
+
+ def clear(self):
+ self.cache.clear()

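A usage sketch for the new `LRUCache`, mirroring how `globals.py` wires it up as `Global.cached_lora_models = LRUCache(10)` and how `models.py` keys it by the LoRA weights name:

```python
from llama_lora.utils.lru_cache import LRUCache

cache = LRUCache(capacity=2)
cache.set("lora-model-a", "model object A")  # placeholder values; the app stores PeftModel instances
cache.set("lora-model-b", "model object B")
cache.get("lora-model-a")                    # marks "lora-model-a" as recently used
cache.set("lora-model-c", "model object C")  # evicts "lora-model-b", the least recently used entry

print(cache.get("lora-model-b"))             # -> None
print(cache.get("lora-model-a"))             # -> "model object A"
cache.clear()                                # done by unload_models() in models.py
```
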
requirements.lock.txt CHANGED

@@ -65,10 +65,10 @@ packaging==23.0
  pandas==2.0.0
  parso==0.8.3
  pathspec==0.11.1
- peft @ git+https://github.com/huggingface/peft.git@…
+ peft @ git+https://github.com/huggingface/peft.git@382b178911edff38c1ff619bbac2ba556bd2276b
  pexpect==4.8.0
  pickleshare==0.7.5
- Pillow==9.…
+ Pillow==9.3.0
  pkgutil_resolve_name==1.3.10
  platformdirs==3.2.0
  pluggy==1.0.0

templates/user_and_ai.json ADDED

@@ -0,0 +1,7 @@
+ {
+ "description": "Unhelpful AI assistant.",
+ "variables": ["instruction"],
+ "prompt": "### User:\n{instruction}\n\n### AI:\n",
+ "default": "prompt",
+ "response_split": "### AI:"
+ }

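For illustration, this is roughly how a template like `templates/user_and_ai.json` is turned into a prompt and how the model's reply is split back out. The real logic lives in `llama_lora/utils/prompter.py`, which is not part of this diff; the snippet below is a simplified, self-contained sketch:

```python
template = {
    "variables": ["instruction"],
    "prompt": "### User:\n{instruction}\n\n### AI:\n",
    "response_split": "### AI:",
}

prompt = template["prompt"].format(instruction="Name three llamas.")
print(prompt)

model_output = prompt + "Dalai, Tina, and Kuzco."  # pretend model generation
response = model_output.split(template["response_split"])[-1].strip()
print(response)                                    # -> "Dalai, Tina, and Kuzco."
```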