omitakahiro committed
Commit 1803d3a
1 Parent(s): 8f53238

Delete notebooks/QLoRA.ipynb

Files changed (1)
  1. notebooks/QLoRA.ipynb +0 -253
notebooks/QLoRA.ipynb DELETED
@@ -1,253 +0,0 @@
- {
-  "cells": [
-   {
-    "cell_type": "markdown",
-    "id": "1a884871-7a65-4501-9063-c85ad260d0da",
-    "metadata": {},
-    "source": [
-     "This notebook is an example of QLoRA tuning of the stockmark/stockmark-13b model on the kunishou/databricks-dolly-15k-ja dataset. It is mainly intended for GPUs with limited memory, such as the T4 or V100. If you are using an A100 or H100 GPU, please try the LoRA tuning sample in this repository instead.\n",
-     "\n",
-     "- Model: https://huggingface.co/stockmark/stockmark-13b\n",
-     "- Data: https://github.com/kunishou/databricks-dolly-15k-ja\n",
-     "\n",
-     "In the example below, training runs for only 0.1 epoch so that the whole flow from training to inference can be exercised quickly. On a T4 GPU this takes about 30 minutes.\n",
-     "\n",
-     "Note that the hyperparameters used here have not been optimized; adjust them as needed."
-    ]
-   },
-   {
-    "cell_type": "markdown",
-    "id": "93b3f4b5-2825-4ef3-a0ee-7a60155aee5d",
-    "metadata": {},
-    "source": [
-     "# Setup\n",
-     "\n",
-     "If you are running on Google Colaboratory, install the following libraries."
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "id": "7e05d465-4851-46fe-984d-942cab5e09a2",
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "!python3 -m pip install transformers accelerate datasets peft bitsandbytes"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "id": "6a694ba9-a0fa-4f14-81cf-f35f683ba889",
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "import torch\n",
-     "import datasets\n",
-     "from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, BitsAndBytesConfig\n",
-     "from peft import get_peft_model, LoraConfig, PeftModel, PeftConfig, prepare_model_for_kbit_training\n",
-     "\n",
-     "model_name = \"stockmark/stockmark-13b\"\n",
-     "peft_model_name = \"stockmark-13b-adapter\"\n",
-     "\n",
-     "prompt_template = \"\"\"### Instruction:\n",
-     "{instruction}\n",
-     "\n",
-     "### Input:\n",
-     "{input}\n",
-     "\n",
-     "### Response:\n",
-     "\"\"\"\n",
-     "\n",
-     "def encode(sample):\n",
-     "    prompt = prompt_template.format(instruction=sample[\"instruction\"], input=sample[\"input\"])\n",
-     "    target = sample[\"output\"]\n",
-     "    input_ids_prompt, input_ids_target = tokenizer([prompt, target]).input_ids\n",
-     "    input_ids_target = input_ids_target + [tokenizer.eos_token_id]\n",
-     "    input_ids = input_ids_prompt + input_ids_target\n",
-     "    labels = input_ids.copy()\n",
-     "    labels[:len(input_ids_prompt)] = [-100] * len(input_ids_prompt)  # ignore label tokens in a prompt for loss calculation\n",
-     "    return {\"input_ids\": input_ids, \"labels\": labels}\n",
-     "\n",
-     "def get_collator(tokenizer, max_length):\n",
-     "    def collator(batch):\n",
-     "        batch = [{key: value[:max_length] for key, value in sample.items()} for sample in batch]\n",
-     "        batch = tokenizer.pad(batch)\n",
-     "        batch[\"labels\"] = [e + [-100] * (len(batch[\"input_ids\"][0]) - len(e)) for e in batch[\"labels\"]]\n",
-     "        batch = {key: torch.tensor(value) for key, value in batch.items()}\n",
-     "        return batch\n",
-     "\n",
-     "    return collator\n",
-     "\n",
-     "bnb_config = BitsAndBytesConfig(\n",
-     "    load_in_4bit=True,\n",
-     "    bnb_4bit_use_double_quant=True,\n",
-     "    bnb_4bit_quant_type=\"nf4\",\n",
-     "    bnb_4bit_compute_dtype=torch.float16\n",
-     ")"
-    ]
-   },
-   {
-    "cell_type": "markdown",
-    "id": "51e6cfcf-1ac1-400e-a4bc-ea64375d0f9e",
-    "metadata": {},
-    "source": [
-     "# Load the dataset and model"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "id": "3ac80067-4e60-46c4-90da-05647cf96ccd",
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "# load tokenizer\n",
-     "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
-     "\n",
-     "# prepare dataset\n",
-     "dataset_name = \"kunishou/databricks-dolly-15k-ja\"\n",
-     "dataset = datasets.load_dataset(dataset_name)\n",
-     "dataset = dataset.map(encode)\n",
-     "dataset = dataset[\"train\"].train_test_split(0.1)\n",
-     "train_dataset = dataset[\"train\"]\n",
-     "eval_dataset = dataset[\"test\"]\n",
-     "\n",
-     "# load model\n",
-     "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\", quantization_config=bnb_config, torch_dtype=torch.float16)\n",
-     "model = prepare_model_for_kbit_training(model)  # recommended prep for k-bit training (imported above but previously unused)\n",
-     "\n",
-     "peft_config = LoraConfig(\n",
-     "    task_type=\"CAUSAL_LM\",\n",
-     "    inference_mode=False,\n",
-     "    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n",
-     "    r=16,\n",
-     "    lora_alpha=32,\n",
-     "    lora_dropout=0.05\n",
-     ")\n",
-     "\n",
-     "model = get_peft_model(model, peft_config)\n",
-     "model.print_trainable_parameters()"
-    ]
-   },
-   {
-    "cell_type": "markdown",
-    "id": "9b471da0-7fba-4127-8b07-22da4cbee6a9",
-    "metadata": {},
-    "source": [
-     "# QLoRA Tuning"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "id": "b9bafa12-538c-4abb-b8b3-bffeb0990b46",
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "training_args = TrainingArguments(\n",
-     "    output_dir=\"./log_stockmark_13b\",\n",
-     "    learning_rate=2e-4,\n",
-     "    per_device_train_batch_size=1,\n",
-     "    gradient_accumulation_steps=16,\n",
-     "    per_device_eval_batch_size=1,\n",
-     "    num_train_epochs=0.1,\n",
-     "    logging_strategy='steps',\n",
-     "    logging_steps=10,\n",
-     "    save_strategy='epoch',\n",
-     "    evaluation_strategy='epoch',\n",
-     "    load_best_model_at_end=True,\n",
-     "    metric_for_best_model=\"eval_loss\",\n",
-     "    greater_is_better=False,\n",
-     "    save_total_limit=2\n",
-     ")\n",
-     "\n",
-     "trainer = Trainer(\n",
-     "    model=model,\n",
-     "    args=training_args,\n",
-     "    train_dataset=train_dataset,\n",
-     "    eval_dataset=eval_dataset,\n",
-     "    data_collator=get_collator(tokenizer, 256)\n",
-     ")\n",
-     "\n",
-     "# QLoRA tuning\n",
-     "trainer.train()\n",
-     "\n",
-     "# save model\n",
-     "model = trainer.model\n",
-     "model.save_pretrained(peft_model_name)"
-    ]
-   },
-   {
-    "cell_type": "markdown",
-    "id": "a3f80a8e-1ac2-4bdc-8232-fe0ee18ffff5",
-    "metadata": {},
-    "source": [
-     "# Load the trained model (optional)\n",
-     "\n",
-     "To load the model in a different session, first run the code in the Setup section above, then run this cell."
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "id": "43241395-3035-4cb9-8c1c-45ffe8cd48be",
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
-     "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\", quantization_config=bnb_config, torch_dtype=torch.float16)\n",
-     "model = PeftModel.from_pretrained(model, peft_model_name)"
-    ]
-   },
-   {
-    "cell_type": "markdown",
-    "id": "2ce4db1f-9bad-4c8e-9c04-d1102b299f24",
-    "metadata": {},
-    "source": [
-     "# Inference"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "id": "d7d6359b-e0ac-49df-a178-39bb9f79ca93",
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "prompt = prompt_template.format(instruction=\"自然言語処理とは?\", input=\"\")\n",
-     "\n",
-     "inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
-     "with torch.no_grad():\n",
-     "    tokens = model.generate(\n",
-     "        **inputs,\n",
-     "        max_new_tokens=128,\n",
-     "        do_sample=True,\n",
-     "        temperature=0.7\n",
-     "    )\n",
-     "\n",
-     "output = tokenizer.decode(tokens[0], skip_special_tokens=True)\n",
-     "print(output)"
-    ]
-   }
-  ],
-  "metadata": {
-   "kernelspec": {
-    "display_name": "Python 3",
-    "language": "python",
-    "name": "python3"
-   },
-   "language_info": {
-    "codemirror_mode": {
-     "name": "ipython",
-     "version": 3
-    },
-    "file_extension": ".py",
-    "mimetype": "text/x-python",
-    "name": "python",
-    "nbconvert_exporter": "python",
-    "pygments_lexer": "ipython3",
-    "version": "3.8.10"
-   }
-  },
-  "nbformat": 4,
-  "nbformat_minor": 5
- }
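
For context on the deleted notebook's claim that a 13B model can be tuned on a low-memory GPU such as a 16 GB T4: a back-of-the-envelope sketch of the frozen base-weight footprint under 4-bit NF4. The 13e9 parameter count is assumed from the "13b" in the model name; activations, the LoRA parameters, and optimizer state come on top and are not counted here.

```python
# Rough footprint of the frozen base weights under 4-bit quantization.
n_params = 13e9               # assumed from the "13b" in the model name
bytes_per_weight = 0.5        # NF4 stores ~4 bits per weight
weight_gib = n_params * bytes_per_weight / 2**30
print(f"base weights: ~{weight_gib:.1f} GiB")  # ~6.1 GiB, well under a 16 GiB T4
```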
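The notebook saves only the LoRA adapter (`stockmark-13b-adapter`), not full model weights. A minimal sketch of how such an adapter could be merged into the base model for standalone serving, assuming enough memory to load the base weights in float16 (merging directly into 4-bit quantized weights is not supported, so the base model is reloaded unquantized; this step is not part of the notebook itself, and the output directory name is arbitrary):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_name = "stockmark/stockmark-13b"
adapter_name = "stockmark-13b-adapter"   # directory written by model.save_pretrained above

# Reload the base model unquantized; merging needs the full-precision weights.
base = AutoModelForCausalLM.from_pretrained(base_name, torch_dtype=torch.float16, device_map="auto")
model = PeftModel.from_pretrained(base, adapter_name)

# Fold the low-rank update (alpha/r) * B @ A into each targeted projection
# and strip the PEFT wrappers, leaving a plain transformers model.
merged = model.merge_and_unload()

merged.save_pretrained("stockmark-13b-qlora-merged")  # hypothetical output dir
AutoTokenizer.from_pretrained(base_name).save_pretrained("stockmark-13b-qlora-merged")
```

The merged checkpoint can then be loaded with `AutoModelForCausalLM.from_pretrained` alone, with no peft dependency at inference time.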