fearlessdots committed
Commit
7c7feff
1 Parent(s): f35ba3f

Upload Merge_LoRA_And_Push_To_Hub.ipynb

Files changed (1)
  1. Merge_LoRA_And_Push_To_Hub.ipynb +341 -0
Merge_LoRA_And_Push_To_Hub.ipynb ADDED
@@ -0,0 +1,341 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "57b52683-2ad6-4670-a36a-a5cd7d3ca00d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Based on: <https://www.datacamp.com/tutorial/fine-tuning-llama-2>"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "77796674-8a83-4ce1-b275-0f681591a647",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import time\n",
+ "import torch\n",
+ "from datasets import load_dataset\n",
+ "from transformers import (\n",
+ "    AutoModelForCausalLM,\n",
+ "    AutoTokenizer,\n",
+ "    BitsAndBytesConfig,\n",
+ "    TrainingArguments,\n",
+ "    pipeline,\n",
+ "    logging,\n",
+ ")\n",
+ "from peft import LoraConfig\n",
+ "from trl import SFTTrainer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c7cb8b2e-6019-4872-8ba1-99242354b761",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Model from Hugging Face hub\n",
+ "base_model = \"failspy/Phi-3-mini-128k-instruct-abliterated-v3\"\n",
+ "\n",
+ "# New instruction dataset\n",
+ "instruct_dataset = \"NobodyExistsOnTheInternet/ToxicQAFinal\"\n",
+ "\n",
+ "# Fine-tuned model\n",
+ "new_model = \"Ophiuchus-mini-128k-v0.1\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a1aefbfc-215e-41b8-b3fa-b0c5db62ebd0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = load_dataset(instruct_dataset, split=\"train\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "dcf420d2-5bd7-4049-ba06-3ba5ff90ddd2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "compute_dtype = torch.float16\n",
+ "\n",
+ "quant_config = BitsAndBytesConfig(\n",
+ "    load_in_4bit=True,\n",
+ "    bnb_4bit_quant_type=\"fp4\",\n",
+ "    bnb_4bit_compute_dtype=compute_dtype,\n",
+ "    bnb_4bit_use_double_quant=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5a5ab3dc-68aa-41e7-b724-6c4e0544beca",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = AutoModelForCausalLM.from_pretrained(\n",
+ "    base_model,\n",
+ "    quantization_config=quant_config,\n",
+ "    device_map={\"\": 0}\n",
+ ")\n",
+ "model.config.use_cache = False\n",
+ "model.config.pretraining_tp = 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5db8846f-0af4-4f04-8fa6-273656de4397",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)\n",
+ "tokenizer.pad_token = tokenizer.eos_token\n",
+ "tokenizer.padding_side = \"right\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e7ab69cb-1f99-46e3-a17b-e33565d11679",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "peft_params = LoraConfig(\n",
+ "    lora_alpha=64,\n",
+ "    lora_dropout=0.05,\n",
+ "    r=128,\n",
+ "    bias=\"none\",\n",
+ "    task_type=\"CAUSAL_LM\",\n",
+ "    target_modules=\"all-linear\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e6bf2e24-a15e-4568-b76d-43541c6bdeae",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "training_params = TrainingArguments(\n",
+ "    output_dir=\"./mnt/ft_results\", # change this accordingly\n",
+ "    num_train_epochs=1,\n",
+ "    per_device_train_batch_size=1,\n",
+ "    gradient_accumulation_steps=4,\n",
+ "    optim=\"adamw_bnb_8bit\",\n",
+ "    save_steps=25,\n",
+ "    logging_steps=25,\n",
+ "    learning_rate=2e-4,\n",
+ "    weight_decay=0.001,\n",
+ "    fp16=False,\n",
+ "    bf16=False,\n",
+ "    max_grad_norm=0.3,\n",
+ "    max_steps=-1,\n",
+ "    warmup_ratio=0.03,\n",
+ "    group_by_length=True,\n",
+ "    lr_scheduler_type=\"constant\",\n",
+ "    report_to=\"tensorboard\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c4e2a9a1-db0f-46ff-b477-aae422badada",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def formatting_prompts_func(example):\n",
+ "    output_texts = []\n",
+ "    for conv in example['conversations']:\n",
+ "        ## For Llama-3:\n",
+ "        #text = f\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n{conv[0]['value']}<|eot_id|><|start_header_id|>user<|end_header_id|>\\n{conv[1]['value']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n{conv[2]['value']}<|eot_id|>\"\"\"\n",
+ "        ## For WizardLM-2:\n",
+ "        #text = f\"\"\"{conv[0]['value']} USER: {conv[1]['value']} ASSISTANT: {conv[2]['value']}</s>\"\"\"\n",
+ "        ## For Phi-3:\n",
+ "        text = f\"\"\"<|system|>\\n{conv[0]['value']}<|end|>\\n<|user|>\\n{conv[1]['value']}<|end|>\\n<|assistant|>\\n{conv[2]['value']}<|end|>\"\"\"\n",
+ "\n",
+ "        output_texts.append(text)\n",
+ "    return output_texts"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1c9dc3f1-999e-4a16-a29a-1752c08306d3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer = SFTTrainer(\n",
+ "    model=model,\n",
+ "    train_dataset=dataset,\n",
+ "    peft_config=peft_params,\n",
+ "    max_seq_length=None,\n",
+ "    tokenizer=tokenizer,\n",
+ "    args=training_params,\n",
+ "    packing=False,\n",
+ "    formatting_func=formatting_prompts_func\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "86d66c37-b963-42cb-afa1-f3999ae0216d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "49f3eff6-4795-4b44-b506-cd49ec068986",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer.model.save_pretrained(new_model)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "519f6339-101a-4015-96b2-c1b54f8e1fa7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer.tokenizer.save_pretrained(new_model)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "257095fb-f597-4a90-ac88-f44443d7af29",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def create_message_template(user_message):\n",
+ "    ## For Llama-3:\n",
+ "    #return f\"\"\"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\\n{user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\"\"\"\n",
+ "    ## For WizardLM-2:\n",
+ "    #return f\"\"\"USER: {user_message} ASSISTANT:\"\"\"\n",
+ "    ## For Phi-3:\n",
+ "    return f\"\"\"<|user|>\\n{user_message}<|end|>\\n<|assistant|>\\n\"\"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "63cb7f11-0fd1-46b7-bca7-a3fb55aa3669",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prompt = \"Ask something here.\"\n",
+ "\n",
+ "messages = create_message_template(prompt)\n",
+ "\n",
+ "messages"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "103d75b9-9ed5-4724-a1e0-2026cd7e08fd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipe = pipeline(task=\"text-generation\", model=model, tokenizer=tokenizer, max_length=4000)\n",
+ "result = pipe(messages)\n",
+ "print(result[0]['generated_text'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ea3dbc94-1e9a-4f37-b7a2-7378c15442a3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from huggingface_hub import login\n",
+ "from huggingface_hub import HfApi\n",
+ "\n",
+ "login()\n",
+ "api = HfApi()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8b172584-478b-479e-ba95-e5071f6ffc40",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer.model.push_to_hub(\"fearlessdots/Ophiuchus-mini-128k-v0.1-LoRA\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b8a7901a-53cc-4c0b-90ba-2f19f1abe7ac",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def upload_files(path):\n",
+ "    api.upload_file(\n",
+ "        path_or_fileobj=path,\n",
+ "        repo_id=\"fearlessdots/Ophiuchus-mini-128k-v0.1-LoRA\",\n",
+ "        path_in_repo=f\"{path.split('/')[-1]}\",\n",
+ "        repo_type=\"model\"\n",
+ "    )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ee851d0c-9968-43f7-9b5a-ed14b4dc0066",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Upload files to LoRA repo\n",
313
+ "upload_files(\"/home/ubuntu/Llama-3-8B-Alpha-Centauri-v0.1/tokenizer_config.json\")\n",
314
+ "upload_files(\"/home/ubuntu/Llama-3-8B-Alpha-Centauri-v0.1/tokenizer.json\")\n",
315
+ "upload_files(\"/home/ubuntu/Llama-3-8B-Alpha-Centauri-v0.1/tokenizer.model\") # Only for models that contain this file. Llama-3 does not.\n",
316
+ "upload_files(\"/home/ubuntu/Llama-3-8B-Alpha-Centauri-v0.1/special_tokens_map.json\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }