Upload prompt_tune_phi3.ipynb with huggingface_hub

prompt_tune_phi3.ipynb (CHANGED: +122 -71)
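The commit was created with `huggingface_hub` rather than through the web UI. For orientation, here is a minimal sketch of the kind of call that produces such a commit; the exact invocation is not part of this diff, and the repo id is taken from the `CommitInfo` output further down:

```python
# Hypothetical reconstruction of the upload step; only its CommitInfo result
# is visible in the notebook outputs below.
from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="prompt_tune_phi3.ipynb",  # local notebook file
    path_in_repo="prompt_tune_phi3.ipynb",     # destination path in the repo
    repo_id="Granther/prompt-tuned-phi3",      # repo id from the commit URL below
    commit_message="Upload prompt_tune_phi3.ipynb with huggingface_hub",
)
```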
@@ -36,13 +36,13 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 3,
   "id": "f1cc378f-afb6-441f-a4c6-2ec427b4cd4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup\n",
-   "from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType\n",
+   "from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType, PeftConfig\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "import os\n",
@@ -54,27 +54,42 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 17,
   "id": "e4ab50d7-a4c9-4246-acd8-8875b87fe0da",
   "metadata": {},
-  "outputs": [
+  "outputs": [
+   {
+    "data": {
+     "application/vnd.jupyter.widget-view+json": {
+      "model_id": "7f03fcf3844743fcb41f8bfc9c6c9b70",
+      "version_major": 2,
+      "version_minor": 0
+     },
+     "text/plain": [
+      "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
+     ]
+    },
+    "metadata": {},
+    "output_type": "display_data"
+   }
+  ],
   "source": [
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 3,
   "id": "8a1cb1f9-b89d-4cac-a595-44e1e0ef85b2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
-      "CommitInfo(commit_url='https://huggingface.co/Granther/prompt-tuned-phi3/commit/
+      "CommitInfo(commit_url='https://huggingface.co/Granther/prompt-tuned-phi3/commit/ab5911db092a8e53ea24c33f170e8013a8b172aa', commit_message='Upload prompt_tune_phi3.ipynb with huggingface_hub', commit_description='', oid='ab5911db092a8e53ea24c33f170e8013a8b172aa', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
-    "execution_count":
+    "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
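The login cell calls `notebook_login()` without a visible import; it presumably comes from `huggingface_hub`, whose helper renders the token widget captured in the new output above:

```python
# Assumed import for the cell above; notebook_login is a huggingface_hub helper
# that displays the interactive token prompt (the VBox widget in the output).
from huggingface_hub import notebook_login

notebook_login()
```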
@@ -90,7 +105,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 4,
   "id": "6cad1e5c-038f-4e75-8c3f-8ce0a43713a4",
   "metadata": {},
   "outputs": [],
@@ -103,7 +118,7 @@
    " peft_type=PeftType.PROMPT_TUNING, # what kind of peft\n",
    " task_type=TaskType.CAUSAL_LM, # config task\n",
    " prompt_tuning_init=PromptTuningInit.TEXT, # Set to 'TEXT' to use prompt_tuning_init_text\n",
-   " num_virtual_tokens=
+   " num_virtual_tokens=100, # x times the number of hidden transformer layers\n",
    " prompt_tuning_init_text=\"Classify if the tweet is a complaint or not:\",\n",
    " tokenizer_name_or_path=model_id\n",
    ")\n",
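Reassembling the fragments in this hunk, the cell is the standard PEFT prompt-tuning setup, with the change bumping `num_virtual_tokens` to 100 (the old value is truncated in this view). A sketch using the values visible in the diff; `model_id` is defined earlier in the notebook, so its exact value here is an assumption:

```python
# Sketch of the full config cell; every field except model_id appears in the diff.
from peft import PromptTuningConfig, PromptTuningInit, PeftType, TaskType

model_id = "microsoft/Phi-3-mini-4k-instruct"  # assumption: some Phi-3 checkpoint

peft_config = PromptTuningConfig(
    peft_type=PeftType.PROMPT_TUNING,          # what kind of peft
    task_type=TaskType.CAUSAL_LM,              # config task
    prompt_tuning_init=PromptTuningInit.TEXT,  # 'TEXT' enables prompt_tuning_init_text
    num_virtual_tokens=100,                    # 100 trainable soft-prompt tokens
    prompt_tuning_init_text="Classify if the tweet is a complaint or not:",
    tokenizer_name_or_path=model_id,
)
```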
@@ -123,7 +138,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 5,
   "id": "6f677839-ef23-428a-bcfe-f596590804ca",
   "metadata": {},
   "outputs": [],
@@ -133,7 +148,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 11,
   "id": "c0c05613-7941-4959-ada9-49ed1093bec4",
   "metadata": {},
   "outputs": [
@@ -143,7 +158,7 @@
      "['Unlabeled', 'complaint', 'no complaint']"
     ]
    },
-   "execution_count":
+   "execution_count": 11,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -155,24 +170,10 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 7,
   "id": "14e2bc8b-b4e3-49c9-ae2b-5946e412caa5",
   "metadata": {},
   "outputs": [
-   {
-    "data": {
-     "application/vnd.jupyter.widget-view+json": {
-      "model_id": "11da1eb81527428a95c41816f5bf459f",
-      "version_major": 2,
-      "version_minor": 0
-     },
-     "text/plain": [
-      "Map (num_proc=10):   0%|          | 0/3399 [00:00<?, ? examples/s]"
-     ]
-    },
-    "metadata": {},
-    "output_type": "display_data"
-   },
    {
     "data": {
     "text/plain": [
@@ -182,7 +183,7 @@
      " 'text_label': 'no complaint'}"
     ]
    },
-   "execution_count":
+   "execution_count": 7,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -201,7 +202,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 8,
   "id": "19f0865d-e490-4c9f-a5f4-e781ed270f47",
   "metadata": {},
   "outputs": [
@@ -218,7 +219,7 @@
      "[1, 853, 29880, 24025]"
     ]
    },
-   "execution_count":
+   "execution_count": 8,
    "metadata": {},
    "output_type": "execute_result"
   }
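Two outputs in these hunks anchor the data pipeline: the class list `['Unlabeled', 'complaint', 'no complaint']` with its `text_label` field, and `[1, 853, 29880, 24025]`, which looks like the label `'Unlabeled'` run through the tokenizer (a BOS token plus three subword ids). Both match the RAFT `twitter_complaints` recipe from the PEFT prompt-tuning tutorial, which this notebook appears to follow; a hedged sketch of that preprocessing (dataset id assumed):

```python
# Assumed preprocessing; the class list and text_label field match the outputs above.
from datasets import load_dataset

dataset = load_dataset("ought/raft", "twitter_complaints")  # assumption
classes = [k.replace("_", " ") for k in dataset["train"].features["Label"].names]
print(classes)  # ['Unlabeled', 'complaint', 'no complaint']

# Map the integer Label column to its text form, as in the 'text_label' record above.
dataset = dataset.map(
    lambda batch: {"text_label": [classes[label] for label in batch["Label"]]},
    batched=True,
    num_proc=10,  # matches the "Map (num_proc=10)" bar in the removed output above
)
```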
@@ -250,7 +251,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 14,
   "id": "03f05467-dce3-4e42-ab3b-c39ba620e164",
   "metadata": {},
   "outputs": [],
@@ -291,14 +292,14 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 15,
   "id": "72ddca5f-7bce-4342-9414-9dd9d41d9dec",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
-      "model_id": "
+      "model_id": "5494bc1fbce24646b61e60e119ae1cb2",
       "version_major": 2,
       "version_minor": 0
      },
@@ -312,7 +313,7 @@
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
-      "model_id": "
+      "model_id": "857675d314254672964cafc522e3869f",
       "version_major": 2,
       "version_minor": 0
      },
@@ -337,7 +338,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 16,
   "id": "40cea6bc-e898-4d86-a6bf-5afc3a647e07",
   "metadata": {},
   "outputs": [],
@@ -362,14 +363,21 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 17,
   "id": "a4c529e4-d8ae-42b2-a658-f76d183bb264",
   "metadata": {},
   "outputs": [
+   {
+    "name": "stderr",
+    "output_type": "stream",
+    "text": [
+     "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n"
+    ]
+   },
    {
     "data": {
     "application/vnd.jupyter.widget-view+json": {
-     "model_id": "
+     "model_id": "1d09f75f23894968a6acd482a53fc92b",
      "version_major": 2,
      "version_minor": 0
     },
@@ -391,7 +399,7 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "trainable params:
+    "trainable params: 307,200 || all params: 3,821,386,752 || trainable%: 0.0080\n",
     "None\n"
    ]
   }
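The `trainable params: 307,200` line is a quick sanity check on the prompt-tuning setup: only the soft prompt is trained, and 100 virtual tokens times a hidden size of 3072 gives exactly 307,200 parameters (the hidden size is an assumption based on the Phi-3-mini family; the totals come from the diff). The trailing `None` suggests the cell wrapped `print_trainable_parameters()`, which returns `None`, in another `print`:

```python
# Arithmetic check of the printed parameter counts (hidden_size assumed for Phi-3-mini).
num_virtual_tokens = 100
hidden_size = 3072                      # assumption: Phi-3-mini hidden dimension
trainable = num_virtual_tokens * hidden_size
print(trainable)                        # 307200, matching "trainable params: 307,200"
print(trainable / 3_821_386_752 * 100)  # ~0.008, matching "trainable%: 0.0080"
```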
@@ -406,7 +414,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 18,
   "id": "3289e4e3-9b9a-4256-921b-5df21d18344e",
   "metadata": {},
   "outputs": [],
@@ -421,7 +429,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 19,
   "id": "e7939d75-c6b9-47a8-b1a3-88f7c33ff121",
   "metadata": {},
   "outputs": [
@@ -429,8 +437,9 @@
    "name": "stderr",
    "output_type": "stream",
    "text": [
-    "
-    "100%|██████████|
+    "  0%|          | 0/7 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n",
+    "100%|██████████| 7/7 [00:01<00:00,  5.36it/s]\n",
+    "100%|██████████| 425/425 [00:29<00:00, 14.23it/s]\n"
    ]
   },
   {
@@ -444,8 +453,8 @@
    "name": "stderr",
    "output_type": "stream",
    "text": [
-    "100%|██████████| 7/7 [00:00<00:00,
-    "100%|██████████| 425/425 [00:
+    "100%|██████████| 7/7 [00:00<00:00,  7.66it/s]\n",
+    "100%|██████████| 425/425 [00:29<00:00, 14.26it/s]\n"
    ]
   },
   {
@@ -459,8 +468,8 @@
    "name": "stderr",
    "output_type": "stream",
    "text": [
-    "100%|██████████| 7/7 [00:00<00:00,
-    "100%|██████████| 425/425 [00:
+    "100%|██████████| 7/7 [00:00<00:00,  7.76it/s]\n",
+    "100%|██████████| 425/425 [00:29<00:00, 14.25it/s]\n"
    ]
   },
   {
@@ -474,8 +483,8 @@
    "name": "stderr",
    "output_type": "stream",
    "text": [
-    "100%|██████████| 7/7 [00:00<00:00,
-    "100%|██████████| 425/425 [00:
+    "100%|██████████| 7/7 [00:00<00:00,  7.72it/s]\n",
+    "100%|██████████| 425/425 [00:29<00:00, 14.24it/s]\n"
    ]
   },
   {
@@ -489,8 +498,8 @@
    "name": "stderr",
    "output_type": "stream",
    "text": [
-    "100%|██████████| 7/7 [00:00<00:00,
-    "100%|██████████| 425/425 [00:
+    "100%|██████████| 7/7 [00:00<00:00,  7.77it/s]\n",
+    "100%|██████████| 425/425 [00:29<00:00, 14.18it/s]"
    ]
   },
   {
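The five pairs of `7/7` and `425/425` bars look like a five-epoch loop over a small train split and a larger evaluation split, consistent with `default_data_collator` and `get_linear_schedule_with_warmup` being imported at the top of the notebook. A hedged sketch of the usual PEFT prompt-tuning loop this resembles; `model`, `device`, and the dataloaders come from earlier cells not shown in the diff, and the learning rate and epoch count are assumptions:

```python
# Sketch of the presumed training loop; names below come from the imports and
# progress bars in the diff, not from visible source cells.
import torch
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

num_epochs = 5  # assumption: five bar pairs appear in the outputs
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-2)  # assumed lr
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=len(train_dataloader) * num_epochs,
)

for epoch in range(num_epochs):
    model.train()
    for batch in tqdm(train_dataloader):        # the 7/7 bars
        batch = {k: v.to(device) for k, v in batch.items()}
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    model.eval()
    with torch.no_grad():
        for batch in tqdm(eval_dataloader):     # the 425/425 bars
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
```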
@@ -546,7 +555,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
+  "execution_count": 20,
   "id": "806d36f8-499e-4af8-b717-68e5d849866d",
   "metadata": {},
   "outputs": [],
@@ -556,14 +565,14 @@
  },
  {
   "cell_type": "code",
-  "execution_count":
-  "id": "
+  "execution_count": 10,
+  "id": "cff41965-fa71-420b-80d8-ce597510f1d3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
-      "model_id": "
+      "model_id": "821777d6daa442c7a5779f3aff695739",
       "version_major": 2,
       "version_minor": 0
      },
@@ -573,48 +582,90 @@
     },
     "metadata": {},
     "output_type": "display_data"
-   },
-   {
-    "name": "stderr",
-    "output_type": "stream",
-    "text": [
-     "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
-    ]
    }
   ],
   "source": [
-   "from
-   "
-   "
+   "from peft import PeftModel, PeftConfig\n",
+   "from transformers import AutoModelForCausalLM, AutoTokenizer \n",
+   "\n",
+   "#tokenizer = AutoTokenizer.from_pretrained('model')\n",
+   "\n",
+   "config = PeftConfig.from_pretrained('model')\n",
+   "model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)\n",
+   "model = PeftModel.from_pretrained(model, 'model')"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count":
-  "id": "
+  "execution_count": 11,
+  "id": "d8a432c9-9ddb-4bb7-a7f0-c4cadd612535",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "inputs = tokenizer(\n",
+   "    f'{text_col} : {\"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?\"} Label : ',\n",
+   "    return_tensors=\"pt\",\n",
+   ")"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 15,
+  "id": "66cfaab3-dc63-4a1e-ab4d-2a687695993d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "
-     "
+     "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:1249: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
+     "  warnings.warn(\n"
     ]
    },
+   {
+    "ename": "ValueError",
+    "evalue": "Input length of input_ids is 32, but `max_length` is set to 20. This can lead to unexpected behavior. You should consider increasing `max_length` or, better yet, setting `max_new_tokens`.",
+    "output_type": "error",
+    "traceback": [
+     "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+     "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
+     "Cell \u001b[0;32mIn[15], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m 4\u001b[0m inputs \u001b[38;5;241m=\u001b[39m {k: v\u001b[38;5;241m.\u001b[39mto(device) \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m inputs\u001b[38;5;241m.\u001b[39mitems()}\n\u001b[0;32m----> 5\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minput_ids\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mattention_mask\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
+     "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/peft/peft_model.py:1493\u001b[0m, in \u001b[0;36mPeftModelForCausalLM.generate\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1491\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbase_model\u001b[38;5;241m.\u001b[39mgenerate(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1492\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1493\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbase_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1494\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m:\n\u001b[1;32m 1495\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbase_model\u001b[38;5;241m.\u001b[39mprepare_inputs_for_generation \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbase_model_prepare_inputs_for_generation\n",
+     "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+     "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:1786\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m 1783\u001b[0m model_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpast_key_values\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m DynamicCache\u001b[38;5;241m.\u001b[39mfrom_legacy_cache(past)\n\u001b[1;32m 1784\u001b[0m use_dynamic_cache_by_default \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m-> 1786\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_generated_length\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_ids_length\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhas_default_max_length\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1788\u001b[0m \u001b[38;5;66;03m# 7. determine generation mode\u001b[39;00m\n\u001b[1;32m 1789\u001b[0m generation_mode \u001b[38;5;241m=\u001b[39m generation_config\u001b[38;5;241m.\u001b[39mget_generation_mode(assistant_model)\n",
+     "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:1257\u001b[0m, in \u001b[0;36mGenerationMixin._validate_generated_length\u001b[0;34m(self, generation_config, input_ids_length, has_default_max_length)\u001b[0m\n\u001b[1;32m 1255\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m input_ids_length \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m generation_config\u001b[38;5;241m.\u001b[39mmax_length:\n\u001b[1;32m 1256\u001b[0m input_ids_string \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdecoder_input_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mis_encoder_decoder \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m-> 1257\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 1258\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInput length of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00minput_ids_string\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is \u001b[39m\u001b[38;5;132;01m{\u001b[39;00minput_ids_length\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, but `max_length` is set to\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1259\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mgeneration_config\u001b[38;5;241m.\u001b[39mmax_length\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. This can lead to unexpected behavior. You should consider\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1260\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m increasing `max_length` or, better yet, setting `max_new_tokens`.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1261\u001b[0m )\n\u001b[1;32m 1263\u001b[0m \u001b[38;5;66;03m# 2. Min length warnings due to unfeasible parameter combinations\u001b[39;00m\n\u001b[1;32m 1264\u001b[0m min_length_error_suffix \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1265\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m Generation will stop at the defined maximum length. You should decrease the minimum length and/or \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1266\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mincrease the maximum length.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1267\u001b[0m )\n",
+     "\u001b[0;31mValueError\u001b[0m: Input length of input_ids is 32, but `max_length` is set to 20. This can lead to unexpected behavior. You should consider increasing `max_length` or, better yet, setting `max_new_tokens`."
+    ]
+   }
+  ],
+  "source": [
+   "model.to(device)\n",
+   "\n",
+   "with torch.no_grad():\n",
+   "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
+   "    out = model.generate(input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"])#, max_new_tokens=10) #, eos_token_id=3)\n",
+   "    #print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": 24,
+  "id": "26438301-3601-44f4-bbe4-3c573a1c28be",
+  "metadata": {},
+  "outputs": [
    {
     "data": {
      "text/plain": [
-      "[{'generated_text':
+      "[{'generated_text': '@HMRCcustomers No this is my first job and I am not sure what to do. I have been told that I need to register with HMRC but I am not sure how to do this. Can you please help me?\\n\\n### response\\nTo register with HMRC for your first job, you need to complete a Self Assessment tax return if you are self-employed or have income to report. For employees, you may need to complete'}]"
      ]
     },
-    "execution_count":
+    "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
-   "pipe(\"@
+   "pipe(\"@HMRCcustomers No this is my first job\")"
   ]
  },
  {