Merge branch 'dev-2'

Changed files:

- LLaMA_LoRA.ipynb +3 -37
- README.md +94 -0
- llama_lora/globals.py +63 -0
- llama_lora/models.py +9 -5
- llama_lora/ui/finetune_ui.py +64 -10
- llama_lora/ui/inference_ui.py +24 -3
- requirements.txt +2 -0
LLaMA_LoRA.ipynb
CHANGED

@@ -6,7 +6,7 @@
     "provenance": [],
     "private_outputs": true,
     "toc_visible": true,
-    "authorship_tag": "
+    "authorship_tag": "ABX9TyMHMc4PwWLbRlhFol+WRzoT",
     "include_colab_link": true
   },
   "kernelspec": {
@@ -34,7 +34,8 @@
      "cell_type": "markdown",
      "source": [
        "# 🦙🎛️ LLaMA-LoRA\n",
-        "\n",
+        "\n",
+        "TL;DR: **Runtime > Run All** (`⌘/Ctrl+F9`). Takes about 5 minutes to start. You will be prompted to authorize Google Drive access."
      ],
      "metadata": {
        "id": "bb4nzBvLfZUj"
@@ -309,41 +310,6 @@
      },
      "execution_count": null,
      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "source": [
-        "# Reset"
-      ],
-      "metadata": {
-        "id": "RW09SrCZpqpa"
-      }
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "# @title Kill Session { display-mode: \"form\" }\n",
-        "# @markdown If you ran out of runtime resources, you can **check the following \n",
-        "# @markdown checkbox and run this code cell to kill the runtime session** while\n",
-        "# @markdown preserving your downloaded data.\n",
-        "do_kill_session = False  # @param {type:\"boolean\"}\n",
-        "# @markdown You will need to re-run this notebook from start after doing this.\n",
-        "#\n",
-        "# @markdown All data that are saved to disk, including Python dependencies, base\n",
-        "# @markdown models will all be preserved, so the second run will be much faster.\n",
-        "\n",
-        "import os\n",
-        "def kill_session():\n",
-        "  os.kill(os.getpid(), 9)\n",
-        "\n",
-        "if do_kill_session:\n",
-        "  kill_session()"
-      ],
-      "metadata": {
-        "id": "bM4sY2tVps8U"
-      },
-      "execution_count": null,
-      "outputs": []
    }
  ]
 }
README.md
CHANGED

@@ -15,6 +15,100 @@ Making evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA) e
 * Supports Stanford Alpaca [seed_tasks](https://github.com/tatsu-lab/stanford_alpaca/blob/main/seed_tasks.jsonl), [alpaca_data](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json) and [OpenAI "prompt"-"completion"](https://platform.openai.com/docs/guides/fine-tuning/data-formatting) format.
 
 
+## How to Start
+
+There are various ways to run this app:
+
+* **[Run on Google Colab](#run-on-google-colab)**: The simplest way to get started; all you need is a Google account. The standard (free) GPU runtime is sufficient to run generation and training with a micro batch size of 8. However, text generation and training are much slower than on other cloud services, and Colab might terminate the execution due to inactivity while running long tasks.
+* **[Run on a cloud service via SkyPilot](#run-on-a-cloud-service-via-skypilot)**: If you have a cloud service (Lambda Labs, GCP, AWS, or Azure) account, you can use SkyPilot to run the app on a cloud service. A cloud bucket can be mounted to preserve your data.
+* **[Run locally](#run-locally)**: Depends on the hardware you have.
+
+### Run on Google Colab
+
+Open [this Colab Notebook](https://colab.research.google.com/github/zetavg/LLaMA-LoRA/blob/main/LLaMA_LoRA.ipynb) and select **Runtime > Run All** (`⌘/Ctrl+F9`).
+
+You will be prompted to authorize Google Drive access, as Google Drive will be used to store your data. See the "Config"/"Google Drive" section for settings and more info.
+
+After approximately 5 minutes of running, you will see the public URL in the output of the "Launch"/"Start Gradio UI 🚀" section (like `Running on public URL: https://xxxx.gradio.live`). Open the URL in your browser to use the app.
+
+### Run on a cloud service via SkyPilot
+
+After following the [installation guide of SkyPilot](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html), create a `.yaml` file to define a task for running the app:
+
+```yaml
+# llama-lora-multitool.yaml
+
+resources:
+  accelerators: A10:1  # 1x NVIDIA A10 GPU
+  cloud: lambda  # Optional; if left out, SkyPilot will automatically pick the cheapest cloud.
+
+file_mounts:
+  # Mount a persisted cloud storage that will be used as the data directory
+  # (to store training datasets and trained models).
+  # See https://skypilot.readthedocs.io/en/latest/reference/storage.html for details.
+  /data:
+    name: llama-lora-multitool-data  # Make sure this name is unique or that you own this bucket. If it does not exist, SkyPilot will try to create a bucket with this name.
+    store: gcs  # Could be either of [s3, gcs]
+    mode: MOUNT
+
+# Clone the LLaMA-LoRA repo and install its dependencies.
+setup: |
+  git clone https://github.com/zetavg/LLaMA-LoRA.git llama_lora
+  cd llama_lora && pip install -r requirements.txt
+  cd ..
+  echo 'Dependencies installed.'
+
+# Start the app.
+run: |
+  echo 'Starting...'
+  python llama_lora/app.py --data_dir='/data' --base_model='decapoda-research/llama-7b-hf' --share
+```
+
+Then launch a cluster to run the task:
+
+```
+sky launch -c llama-lora-multitool llama-lora-multitool.yaml
+```
+
+`-c ...` is an optional flag to specify a cluster name. If not specified, SkyPilot will automatically generate one.
+
+You will see the public URL of the app in the terminal. Open the URL in your browser to use the app.
+
+Note that exiting `sky launch` will only exit log streaming and will not stop the task. You can use `sky queue --skip-finished` to see the status of running or pending tasks, `sky logs <cluster_name> <job_id>` to connect back to log streaming, and `sky cancel <cluster_name> <job_id>` to stop a task.
+
+When you are done, run `sky stop <cluster_name>` to stop the cluster. To terminate a cluster instead, run `sky down <cluster_name>`.
+
+### Run locally
+
+<details>
+  <summary>Prepare environment with conda</summary>
+
+```bash
+conda create -y -n llama-lora-multitool python=3.8
+conda activate llama-lora-multitool
+```
+</details>
+
+```bash
+pip install -r requirements.txt
+python app.py --data_dir='./data' --base_model='decapoda-research/llama-7b-hf' --share
+```
+
+You will see the local and public URLs of the app in the terminal. Open the URL in your browser to use the app.
+
+For more options, see `python app.py --help`.
+
+<details>
+  <summary>UI development mode</summary>
+
+To test the UI without loading the language model, use the `--ui_dev_mode` flag:
+
+```bash
+python app.py --data_dir='./data' --base_model='decapoda-research/llama-7b-hf' --share --ui_dev_mode
+```
+</details>
+
+
 ## Acknowledgements
 
 * https://github.com/tloen/alpaca-lora
llama_lora/globals.py
CHANGED

@@ -3,6 +3,9 @@ import subprocess
 
 from typing import Any, Dict, List, Optional, Tuple, Union
 
+from numba import cuda
+import nvidia_smi
+
 from .lib.finetune import train
 
 
@@ -25,6 +28,12 @@ class Global:
     # Model related
     model_has_been_used = False
 
+    # GPU Info
+    gpu_cc = None  # GPU compute capability
+    gpu_sms = None  # GPU total number of SMs
+    gpu_total_cores = None  # GPU total cores
+    gpu_total_memory = None
+
     # UI related
     ui_title: str = "LLaMA-LoRA"
     ui_emoji: str = "🦙🎛️"
@@ -60,3 +69,57 @@ commit_hash = get_git_commit_hash()
 
 if commit_hash:
     Global.version = commit_hash[:8]
+
+
+def load_gpu_info():
+    try:
+        cc_cores_per_SM_dict = {
+            (2, 0): 32,
+            (2, 1): 48,
+            (3, 0): 192,
+            (3, 5): 192,
+            (3, 7): 192,
+            (5, 0): 128,
+            (5, 2): 128,
+            (6, 0): 64,
+            (6, 1): 128,
+            (7, 0): 64,
+            (7, 5): 64,
+            (8, 0): 64,
+            (8, 6): 128,
+            (8, 9): 128,
+            (9, 0): 128
+        }
+        # the above dictionary should result in a value of "None" if a cc match
+        # is not found. The dictionary needs to be extended as new devices become
+        # available, and currently does not account for all Jetson devices
+        device = cuda.get_current_device()
+        device_sms = getattr(device, 'MULTIPROCESSOR_COUNT')
+        device_cc = device.compute_capability
+        cores_per_sm = cc_cores_per_SM_dict.get(device_cc)
+        total_cores = cores_per_sm * device_sms
+        print("GPU compute capability: ", device_cc)
+        print("GPU total number of SMs: ", device_sms)
+        print("GPU total cores: ", total_cores)
+        Global.gpu_cc = device_cc
+        Global.gpu_sms = device_sms
+        Global.gpu_total_cores = total_cores
+
+        nvidia_smi.nvmlInit()
+        handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
+        info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
+        total_memory = info.total
+
+        total_memory_mb = total_memory / (1024 ** 2)
+        total_memory_gb = total_memory / (1024 ** 3)
+
+        # Print the memory size
+        print(
+            f"GPU total memory: {total_memory} bytes ({total_memory_mb:.2f} MB) ({total_memory_gb:.2f} GB)")
+        Global.gpu_total_memory = total_memory
+
+    except Exception as e:
+        print(f"Notice: cannot get GPU info: {e}")
+
+
+load_gpu_info()
llama_lora/models.py
CHANGED

@@ -102,6 +102,14 @@ def load_base_model():
     )
 
 
+def clear_cache():
+    gc.collect()
+
+    # if not shared.args.cpu:  # will not be running on CPUs anyway
+    with torch.no_grad():
+        torch.cuda.empty_cache()
+
+
 def unload_models():
     del Global.loaded_base_model
     Global.loaded_base_model = None
@@ -109,11 +117,7 @@ def unload_models():
     del Global.loaded_tokenizer
     Global.loaded_tokenizer = None
 
-    gc.collect()
-
-    # if not shared.args.cpu:  # will not be running on CPUs anyway
-    with torch.no_grad():
-        torch.cuda.empty_cache()
+    clear_cache()
 
     Global.model_has_been_used = False
 
llama_lora/ui/finetune_ui.py
CHANGED

@@ -9,7 +9,9 @@ from random_word import RandomWords
 from transformers import TrainerCallback
 
 from ..globals import Global
-from ..models import 
+from ..models import (
+    get_base_model, get_tokenizer,
+    clear_cache, unload_models_if_already_used)
 from ..utils.data import (
     get_available_template_names,
     get_available_dataset_names,
@@ -238,6 +240,12 @@ def parse_plain_text_input(
     return result
 
 
+should_training_progress_track_tqdm = True
+
+if Global.gpu_total_cores is not None and Global.gpu_total_cores > 2560:
+    should_training_progress_track_tqdm = False
+
+
 def do_train(
     # Dataset
     template,
@@ -258,9 +266,15 @@ def do_train(
     lora_alpha,
     lora_dropout,
     model_name,
-    progress=gr.Progress(track_tqdm=
+    progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
 ):
     try:
+        clear_cache()
+        # If model has been used in inference, we need to unload it first.
+        # Otherwise, we'll get a 'Function MmBackward0 returned an invalid
+        # gradient at index 1 - expected device meta but got cuda:0' error.
+        unload_models_if_already_used()
+
         prompter = Prompter(template)
         variable_names = prompter.get_variable_names()
 
@@ -312,7 +326,32 @@
                      'completion': d['output']}
                     for d in data]
 
+        def get_progress_text(epoch, epochs, last_loss):
+            progress_detail = f"Epoch {math.ceil(epoch)}/{epochs}"
+            if last_loss is not None:
+                progress_detail += f", Loss: {last_loss:.4f}"
+            return f"Training... ({progress_detail})"
+
         if Global.ui_dev_mode:
+            Global.should_stop_training = False
+
+            for i in range(300):
+                if (Global.should_stop_training):
+                    return
+                epochs = 3
+                epoch = i / 100
+                last_loss = None
+                if (i > 20):
+                    last_loss = 3 + (i - 0) * (0.5 - 3) / (300 - 0)
+
+                progress(
+                    (i, 300),
+                    desc="(Simulate) " +
+                    get_progress_text(epoch, epochs, last_loss)
+                )
+
+                time.sleep(0.1)
+
             message = f"""Currently in UI dev mode, not doing the actual training.
 
 Train options: {json.dumps({
@@ -368,16 +407,21 @@ Train data (first 10):
 
     training_callbacks = [UiTrainerCallback]
 
-    # If model has been used in inference, we need to unload it first.
-    # Otherwise, we'll get a 'Function MmBackward0 returned an invalid
-    # gradient at index 1 - expected device meta but got cuda:0' error.
+    Global.should_stop_training = False
+
+    # Do this again right before training to make sure the model is not used in inference.
     unload_models_if_already_used()
+    clear_cache()
 
-
+    base_model = get_base_model()
+    tokenizer = get_tokenizer()
+
+    # Do not let other tqdm iterations interfere with the progress reporting after training starts.
+    # progress.track_tqdm = False  # setting this dynamically is not working; determining whether track_tqdm should be enabled based on GPU cores at start instead.
 
     results = Global.train_fn(
-
-
+        base_model,  # base_model
+        tokenizer,  # tokenizer
         os.path.join(Global.data_dir, "lora_models",
                      model_name),  # output_dir
         train_data,
@@ -398,7 +442,8 @@ Train data (first 10):
         training_callbacks  # callbacks
     )
 
-    logs_str = "\n".join([json.dumps(log)
+    logs_str = "\n".join([json.dumps(log)
+                          for log in log_history]) or "None"
 
     result_message = f"Training ended:\n{str(results)}\n\nLogs:\n{logs_str}"
     print(result_message)
@@ -557,9 +602,18 @@ def finetune_ui():
         )
 
         with gr.Row():
+            micro_batch_size_default_value = 1
+
+            if Global.gpu_total_cores is not None and Global.gpu_total_memory is not None:
+                memory_per_core = Global.gpu_total_memory / Global.gpu_total_cores
+                if memory_per_core >= 6291456:
+                    micro_batch_size_default_value = 8
+                elif memory_per_core >= 4000000:  # ?
+                    micro_batch_size_default_value = 4
+
             with gr.Column():
                 micro_batch_size = gr.Slider(
-                    minimum=1, maximum=100, step=1, value=
+                    minimum=1, maximum=100, step=1, value=micro_batch_size_default_value,
                     label="Micro Batch Size",
                     info="The number of examples in each mini-batch for gradient computation. A smaller micro_batch_size reduces memory usage but may increase training time."
                 )
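
As a rough illustration of the new micro batch size heuristic above, here is a worked example (mine, not part of the commit), assuming a Colab-style Tesla T4 with 2560 CUDA cores and about 15 GiB of usable GPU memory; those figures are assumptions, not values from the diff.

```python
# Worked example of the memory_per_core heuristic above, under assumed
# T4-like figures: 2560 CUDA cores, ~15 GiB of usable GPU memory.
gpu_total_cores = 2560              # assumed
gpu_total_memory = 15 * 1024 ** 3   # assumed, in bytes

memory_per_core = gpu_total_memory / gpu_total_cores
print(memory_per_core)  # 6291456.0 -- right at the threshold

micro_batch_size_default_value = 1
if memory_per_core >= 6291456:
    micro_batch_size_default_value = 8
elif memory_per_core >= 4000000:
    micro_batch_size_default_value = 4
print(micro_batch_size_default_value)  # 8
```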
llama_lora/ui/inference_ui.py
CHANGED

@@ -1,5 +1,6 @@
 import gradio as gr
 import time
+import json
 
 import torch
 import transformers
@@ -47,10 +48,30 @@ def do_inference(
         lora_model_name = path_of_available_lora_model
 
     if Global.ui_dev_mode:
-        message = f"
+        message = f"Hi, I’m currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {Global.base_model}\nLoRA model: {lora_model_name}\n\nThe following text is your prompt:\n\n{prompt}"
         print(message)
+
+        if stream_output:
+            def word_generator(sentence):
+                lines = message.split('\n')
+                out = ""
+                for line in lines:
+                    words = line.split(' ')
+                    for i in range(len(words)):
+                        if out:
+                            out += ' '
+                        out += words[i]
+                        yield out
+                    out += "\n"
+                    yield out
+
+            for partial_sentence in word_generator(message):
+                yield partial_sentence, json.dumps(list(range(len(partial_sentence.split()))), indent=2)
+                time.sleep(0.05)
+
+            return
         time.sleep(1)
-        yield message, 
+        yield message, json.dumps(list(range(len(message.split()))), indent=2)
         return
 
     model = get_base_model()
@@ -224,7 +245,7 @@
             preview_prompt = gr.Textbox(
                 show_label=False, interactive=False, elem_id="inference_preview_prompt")
             update_prompt_preview_btn = gr.Button(
-                "↻", elem_id="inference_update_prompt_preview_btn"
+                "↻", elem_id="inference_update_prompt_preview_btn")
             update_prompt_preview_btn.style(size="sm")
 
             # with gr.Column():
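
For reference, a self-contained sketch (not part of the commit) of the streaming behavior of the `word_generator` helper added above: it re-emits the whole accumulated text with one more word on each step, which is what makes the dev-mode output look like token streaming in the UI.

```python
# Standalone version of the word_generator pattern used in the UI dev mode
# branch above: yields the growing message one word at a time.
def word_generator(message):
    out = ""
    for line in message.split('\n'):
        for word in line.split(' '):
            if out:
                out += ' '
            out += word
            yield out
        out += "\n"
        yield out

for partial in word_generator("Hi, this is a streaming demo"):
    print(repr(partial))
# 'Hi,', 'Hi, this', 'Hi, this is', ..., 'Hi, this is a streaming demo\n'
```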
requirements.txt
CHANGED

@@ -7,6 +7,8 @@ datasets
 fire
 git+https://github.com/huggingface/peft.git
 git+https://github.com/huggingface/transformers.git
+numba
+nvidia-ml-py3
 gradio
 loralib
 sentencepiece