Riddhi Bhagwat
commited on
Commit
·
fd59c75
1
Parent(s):
3f8b25a
Add files via upload
Browse filesinital commit; moving files from old repo to organization
- README.md +2 -0
- dataset_training.ipynb +398 -0
- kto_quickstart.ipynb +590 -0
- trl_rlhf_data.py +97 -0
README.md
CHANGED
@@ -6,3 +6,5 @@
|
|
6 |
This code repository (or "repo") is designed to demonstrate the best GitHub has to offer with the least amount of noise.
|
7 |
|
8 |
The repo includes an `index.html` file (so it can render a web page), two GitHub Actions workflows, and a CSS stylesheet dependency.
|
|
|
|
|
|
6 |
This code repository (or "repo") is designed to demonstrate the best GitHub has to offer with the least amount of noise.
|
7 |
|
8 |
The repo includes an `index.html` file (so it can render a web page), two GitHub Actions workflows, and a CSS stylesheet dependency.
|
9 |
+
# Model-Improvement-Platform-With-RLHF
|
10 |
+
Platform being developed at MIT in collaboration with HuggingFace. Aimed at improving performance of existing Large Language Models through real time human feedback loop.
|
dataset_training.ipynb
ADDED
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 43,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"#dependencies:\n",
|
10 |
+
"import pandas as pd\n",
|
11 |
+
"\n",
|
12 |
+
"import torch\n",
|
13 |
+
"from transformers import GPT2Tokenizer\n",
|
14 |
+
"\n",
|
15 |
+
"from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer"
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"cell_type": "code",
|
20 |
+
"execution_count": 44,
|
21 |
+
"metadata": {},
|
22 |
+
"outputs": [
|
23 |
+
{
|
24 |
+
"data": {
|
25 |
+
"application/vnd.jupyter.widget-view+json": {
|
26 |
+
"model_id": "b8a22b8d60c0417eafbf554832398287",
|
27 |
+
"version_major": 2,
|
28 |
+
"version_minor": 0
|
29 |
+
},
|
30 |
+
"text/plain": [
|
31 |
+
"Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
"metadata": {},
|
35 |
+
"output_type": "display_data"
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"data": {
|
39 |
+
"application/vnd.jupyter.widget-view+json": {
|
40 |
+
"model_id": "b83d2624c2b14986a8297821460225ab",
|
41 |
+
"version_major": 2,
|
42 |
+
"version_minor": 0
|
43 |
+
},
|
44 |
+
"text/plain": [
|
45 |
+
"Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
|
46 |
+
]
|
47 |
+
},
|
48 |
+
"metadata": {},
|
49 |
+
"output_type": "display_data"
|
50 |
+
},
|
51 |
+
{
|
52 |
+
"data": {
|
53 |
+
"application/vnd.jupyter.widget-view+json": {
|
54 |
+
"model_id": "b4304c0f48cb472589b5e80d3a42cba2",
|
55 |
+
"version_major": 2,
|
56 |
+
"version_minor": 0
|
57 |
+
},
|
58 |
+
"text/plain": [
|
59 |
+
"Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
|
60 |
+
]
|
61 |
+
},
|
62 |
+
"metadata": {},
|
63 |
+
"output_type": "display_data"
|
64 |
+
}
|
65 |
+
],
|
66 |
+
"source": [
|
67 |
+
"#loading datasets:\n",
|
68 |
+
"from datasets import load_dataset\n",
|
69 |
+
"\n",
|
70 |
+
"ds = load_dataset(\"stanfordnlp/SHP\", split='train')"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"cell_type": "code",
|
75 |
+
"execution_count": 45,
|
76 |
+
"metadata": {},
|
77 |
+
"outputs": [
|
78 |
+
{
|
79 |
+
"name": "stdout",
|
80 |
+
"output_type": "stream",
|
81 |
+
"text": [
|
82 |
+
"Index(['post_id', 'domain', 'upvote_ratio', 'history', 'c_root_id_A',\n",
|
83 |
+
" 'c_root_id_B', 'created_at_utc_A', 'created_at_utc_B', 'score_A',\n",
|
84 |
+
" 'score_B', 'human_ref_A', 'human_ref_B', 'labels', 'seconds_difference',\n",
|
85 |
+
" 'score_ratio'],\n",
|
86 |
+
" dtype='object')\n"
|
87 |
+
]
|
88 |
+
}
|
89 |
+
],
|
90 |
+
"source": [
|
91 |
+
"df = ds.to_pandas()\n",
|
92 |
+
"print(df.columns)\n"
|
93 |
+
]
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"cell_type": "code",
|
97 |
+
"execution_count": 46,
|
98 |
+
"metadata": {},
|
99 |
+
"outputs": [
|
100 |
+
{
|
101 |
+
"data": {
|
102 |
+
"text/html": [
|
103 |
+
"<div>\n",
|
104 |
+
"<style scoped>\n",
|
105 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
106 |
+
" vertical-align: middle;\n",
|
107 |
+
" }\n",
|
108 |
+
"\n",
|
109 |
+
" .dataframe tbody tr th {\n",
|
110 |
+
" vertical-align: top;\n",
|
111 |
+
" }\n",
|
112 |
+
"\n",
|
113 |
+
" .dataframe thead th {\n",
|
114 |
+
" text-align: right;\n",
|
115 |
+
" }\n",
|
116 |
+
"</style>\n",
|
117 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
118 |
+
" <thead>\n",
|
119 |
+
" <tr style=\"text-align: right;\">\n",
|
120 |
+
" <th></th>\n",
|
121 |
+
" <th>upvote_ratio</th>\n",
|
122 |
+
" <th>history</th>\n",
|
123 |
+
" <th>score_A</th>\n",
|
124 |
+
" <th>score_B</th>\n",
|
125 |
+
" <th>human_ref_A</th>\n",
|
126 |
+
" <th>human_ref_B</th>\n",
|
127 |
+
" <th>labels</th>\n",
|
128 |
+
" <th>score_ratio</th>\n",
|
129 |
+
" </tr>\n",
|
130 |
+
" </thead>\n",
|
131 |
+
" <tbody>\n",
|
132 |
+
" <tr>\n",
|
133 |
+
" <th>0</th>\n",
|
134 |
+
" <td>0.99</td>\n",
|
135 |
+
" <td>In an interview right before receiving the 201...</td>\n",
|
136 |
+
" <td>52</td>\n",
|
137 |
+
" <td>54</td>\n",
|
138 |
+
" <td>Currently wrapping up my PhD. There is a stark...</td>\n",
|
139 |
+
" <td>It’s ironic to me that research has shown that...</td>\n",
|
140 |
+
" <td>0</td>\n",
|
141 |
+
" <td>1.038462</td>\n",
|
142 |
+
" </tr>\n",
|
143 |
+
" <tr>\n",
|
144 |
+
" <th>1</th>\n",
|
145 |
+
" <td>0.95</td>\n",
|
146 |
+
" <td>If any professor is reading this: please do no...</td>\n",
|
147 |
+
" <td>5</td>\n",
|
148 |
+
" <td>17</td>\n",
|
149 |
+
" <td>And when your teacher doesn't listen or pay at...</td>\n",
|
150 |
+
" <td>I'm pretty strict on time, to the point where ...</td>\n",
|
151 |
+
" <td>0</td>\n",
|
152 |
+
" <td>3.400000</td>\n",
|
153 |
+
" </tr>\n",
|
154 |
+
" <tr>\n",
|
155 |
+
" <th>2</th>\n",
|
156 |
+
" <td>0.95</td>\n",
|
157 |
+
" <td>If any professor is reading this: please do no...</td>\n",
|
158 |
+
" <td>5</td>\n",
|
159 |
+
" <td>7</td>\n",
|
160 |
+
" <td>Profs can be oblivious? What’s new!</td>\n",
|
161 |
+
" <td>This sounds like a problem with a specific pro...</td>\n",
|
162 |
+
" <td>0</td>\n",
|
163 |
+
" <td>1.400000</td>\n",
|
164 |
+
" </tr>\n",
|
165 |
+
" <tr>\n",
|
166 |
+
" <th>3</th>\n",
|
167 |
+
" <td>0.95</td>\n",
|
168 |
+
" <td>If any professor is reading this: please do no...</td>\n",
|
169 |
+
" <td>7</td>\n",
|
170 |
+
" <td>5</td>\n",
|
171 |
+
" <td>This sounds like a problem with a specific pro...</td>\n",
|
172 |
+
" <td>And when your teacher doesn't listen or pay at...</td>\n",
|
173 |
+
" <td>1</td>\n",
|
174 |
+
" <td>1.400000</td>\n",
|
175 |
+
" </tr>\n",
|
176 |
+
" <tr>\n",
|
177 |
+
" <th>4</th>\n",
|
178 |
+
" <td>0.95</td>\n",
|
179 |
+
" <td>If any professor is reading this: please do no...</td>\n",
|
180 |
+
" <td>6</td>\n",
|
181 |
+
" <td>7</td>\n",
|
182 |
+
" <td>This would be totally unacceptable in my class...</td>\n",
|
183 |
+
" <td>This sounds like a problem with a specific pro...</td>\n",
|
184 |
+
" <td>0</td>\n",
|
185 |
+
" <td>1.166667</td>\n",
|
186 |
+
" </tr>\n",
|
187 |
+
" <tr>\n",
|
188 |
+
" <th>...</th>\n",
|
189 |
+
" <td>...</td>\n",
|
190 |
+
" <td>...</td>\n",
|
191 |
+
" <td>...</td>\n",
|
192 |
+
" <td>...</td>\n",
|
193 |
+
" <td>...</td>\n",
|
194 |
+
" <td>...</td>\n",
|
195 |
+
" <td>...</td>\n",
|
196 |
+
" <td>...</td>\n",
|
197 |
+
" </tr>\n",
|
198 |
+
" <tr>\n",
|
199 |
+
" <th>348713</th>\n",
|
200 |
+
" <td>0.94</td>\n",
|
201 |
+
" <td>Can I get in trouble for giving my neighbor hi...</td>\n",
|
202 |
+
" <td>7</td>\n",
|
203 |
+
" <td>25</td>\n",
|
204 |
+
" <td>Just put up a fence. Legally he isn't responsi...</td>\n",
|
205 |
+
" <td>Whatever you do, don't cut his trees down.</td>\n",
|
206 |
+
" <td>0</td>\n",
|
207 |
+
" <td>3.571429</td>\n",
|
208 |
+
" </tr>\n",
|
209 |
+
" <tr>\n",
|
210 |
+
" <th>348714</th>\n",
|
211 |
+
" <td>0.94</td>\n",
|
212 |
+
" <td>Can I get in trouble for giving my neighbor hi...</td>\n",
|
213 |
+
" <td>2</td>\n",
|
214 |
+
" <td>25</td>\n",
|
215 |
+
" <td>If OP pays someone to clean his yard, and then...</td>\n",
|
216 |
+
" <td>Whatever you do, don't cut his trees down.</td>\n",
|
217 |
+
" <td>0</td>\n",
|
218 |
+
" <td>12.500000</td>\n",
|
219 |
+
" </tr>\n",
|
220 |
+
" <tr>\n",
|
221 |
+
" <th>348715</th>\n",
|
222 |
+
" <td>0.94</td>\n",
|
223 |
+
" <td>Can I get in trouble for giving my neighbor hi...</td>\n",
|
224 |
+
" <td>9</td>\n",
|
225 |
+
" <td>7</td>\n",
|
226 |
+
" <td>My observation is that both of you are idiots...</td>\n",
|
227 |
+
" <td>Are you Rand Paul's neighbor? https://www.gq....</td>\n",
|
228 |
+
" <td>1</td>\n",
|
229 |
+
" <td>1.285714</td>\n",
|
230 |
+
" </tr>\n",
|
231 |
+
" <tr>\n",
|
232 |
+
" <th>348716</th>\n",
|
233 |
+
" <td>0.94</td>\n",
|
234 |
+
" <td>Can I get in trouble for giving my neighbor hi...</td>\n",
|
235 |
+
" <td>9</td>\n",
|
236 |
+
" <td>7</td>\n",
|
237 |
+
" <td>My observation is that both of you are idiots...</td>\n",
|
238 |
+
" <td>Just put up a fence. Legally he isn't responsi...</td>\n",
|
239 |
+
" <td>1</td>\n",
|
240 |
+
" <td>1.285714</td>\n",
|
241 |
+
" </tr>\n",
|
242 |
+
" <tr>\n",
|
243 |
+
" <th>348717</th>\n",
|
244 |
+
" <td>0.94</td>\n",
|
245 |
+
" <td>Can I get in trouble for giving my neighbor hi...</td>\n",
|
246 |
+
" <td>7</td>\n",
|
247 |
+
" <td>2</td>\n",
|
248 |
+
" <td>Capture his acts on camera. Collect and bag l...</td>\n",
|
249 |
+
" <td>If OP pays someone to clean his yard, and then...</td>\n",
|
250 |
+
" <td>1</td>\n",
|
251 |
+
" <td>3.500000</td>\n",
|
252 |
+
" </tr>\n",
|
253 |
+
" </tbody>\n",
|
254 |
+
"</table>\n",
|
255 |
+
"<p>348718 rows × 8 columns</p>\n",
|
256 |
+
"</div>"
|
257 |
+
],
|
258 |
+
"text/plain": [
|
259 |
+
" upvote_ratio history \\\n",
|
260 |
+
"0 0.99 In an interview right before receiving the 201... \n",
|
261 |
+
"1 0.95 If any professor is reading this: please do no... \n",
|
262 |
+
"2 0.95 If any professor is reading this: please do no... \n",
|
263 |
+
"3 0.95 If any professor is reading this: please do no... \n",
|
264 |
+
"4 0.95 If any professor is reading this: please do no... \n",
|
265 |
+
"... ... ... \n",
|
266 |
+
"348713 0.94 Can I get in trouble for giving my neighbor hi... \n",
|
267 |
+
"348714 0.94 Can I get in trouble for giving my neighbor hi... \n",
|
268 |
+
"348715 0.94 Can I get in trouble for giving my neighbor hi... \n",
|
269 |
+
"348716 0.94 Can I get in trouble for giving my neighbor hi... \n",
|
270 |
+
"348717 0.94 Can I get in trouble for giving my neighbor hi... \n",
|
271 |
+
"\n",
|
272 |
+
" score_A score_B human_ref_A \\\n",
|
273 |
+
"0 52 54 Currently wrapping up my PhD. There is a stark... \n",
|
274 |
+
"1 5 17 And when your teacher doesn't listen or pay at... \n",
|
275 |
+
"2 5 7 Profs can be oblivious? What’s new! \n",
|
276 |
+
"3 7 5 This sounds like a problem with a specific pro... \n",
|
277 |
+
"4 6 7 This would be totally unacceptable in my class... \n",
|
278 |
+
"... ... ... ... \n",
|
279 |
+
"348713 7 25 Just put up a fence. Legally he isn't responsi... \n",
|
280 |
+
"348714 2 25 If OP pays someone to clean his yard, and then... \n",
|
281 |
+
"348715 9 7 My observation is that both of you are idiots... \n",
|
282 |
+
"348716 9 7 My observation is that both of you are idiots... \n",
|
283 |
+
"348717 7 2 Capture his acts on camera. Collect and bag l... \n",
|
284 |
+
"\n",
|
285 |
+
" human_ref_B labels score_ratio \n",
|
286 |
+
"0 It’s ironic to me that research has shown that... 0 1.038462 \n",
|
287 |
+
"1 I'm pretty strict on time, to the point where ... 0 3.400000 \n",
|
288 |
+
"2 This sounds like a problem with a specific pro... 0 1.400000 \n",
|
289 |
+
"3 And when your teacher doesn't listen or pay at... 1 1.400000 \n",
|
290 |
+
"4 This sounds like a problem with a specific pro... 0 1.166667 \n",
|
291 |
+
"... ... ... ... \n",
|
292 |
+
"348713 Whatever you do, don't cut his trees down. 0 3.571429 \n",
|
293 |
+
"348714 Whatever you do, don't cut his trees down. 0 12.500000 \n",
|
294 |
+
"348715 Are you Rand Paul's neighbor? https://www.gq.... 1 1.285714 \n",
|
295 |
+
"348716 Just put up a fence. Legally he isn't responsi... 1 1.285714 \n",
|
296 |
+
"348717 If OP pays someone to clean his yard, and then... 1 3.500000 \n",
|
297 |
+
"\n",
|
298 |
+
"[348718 rows x 8 columns]"
|
299 |
+
]
|
300 |
+
},
|
301 |
+
"execution_count": 46,
|
302 |
+
"metadata": {},
|
303 |
+
"output_type": "execute_result"
|
304 |
+
}
|
305 |
+
],
|
306 |
+
"source": [
|
307 |
+
"# df['response_length'] = df['history'].apply(len)\n",
|
308 |
+
"# df['label'] = df['response_length'].apply(lambda x: 'long' if x > 100 else 'short')\n",
|
309 |
+
"df.drop(columns=['post_id', 'domain', 'c_root_id_A', 'c_root_id_B', 'created_at_utc_A', 'created_at_utc_B', 'seconds_difference'])"
|
310 |
+
]
|
311 |
+
},
|
312 |
+
{
|
313 |
+
"cell_type": "code",
|
314 |
+
"execution_count": 47,
|
315 |
+
"metadata": {},
|
316 |
+
"outputs": [
|
317 |
+
{
|
318 |
+
"name": "stderr",
|
319 |
+
"output_type": "stream",
|
320 |
+
"text": [
|
321 |
+
"/Users/riddhib/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
|
322 |
+
" warnings.warn(\n"
|
323 |
+
]
|
324 |
+
}
|
325 |
+
],
|
326 |
+
"source": [
|
327 |
+
"model = AutoModelForCausalLMWithValueHead.from_pretrained(\"gpt2\")\n",
|
328 |
+
"ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(\"gpt2\")\n",
|
329 |
+
"tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
|
330 |
+
"tokenizer.pad_token = tokenizer.eos_token"
|
331 |
+
]
|
332 |
+
},
|
333 |
+
{
|
334 |
+
"cell_type": "code",
|
335 |
+
"execution_count": 48,
|
336 |
+
"metadata": {},
|
337 |
+
"outputs": [],
|
338 |
+
"source": [
|
339 |
+
"from trl_rlhf_data import runner, ScriptArguments\n",
|
340 |
+
"import re\n",
|
341 |
+
"from dataclasses import dataclass\n",
|
342 |
+
"from typing import Dict, List, Optional\n",
|
343 |
+
"\n",
|
344 |
+
"from datasets import load_dataset\n",
|
345 |
+
"from transformers import HfArgumentParser"
|
346 |
+
]
|
347 |
+
},
|
348 |
+
{
|
349 |
+
"cell_type": "code",
|
350 |
+
"execution_count": 49,
|
351 |
+
"metadata": {},
|
352 |
+
"outputs": [
|
353 |
+
{
|
354 |
+
"ename": "TypeError",
|
355 |
+
"evalue": "runner() takes 0 positional arguments but 1 was given",
|
356 |
+
"output_type": "error",
|
357 |
+
"traceback": [
|
358 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
359 |
+
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
|
360 |
+
"Cell \u001b[0;32mIn[49], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m dataset \u001b[38;5;241m=\u001b[39m \u001b[43mrunner\u001b[49m\u001b[43m(\u001b[49m\u001b[43mScriptArguments\u001b[49m\u001b[43m)\u001b[49m\n",
|
361 |
+
"\u001b[0;31mTypeError\u001b[0m: runner() takes 0 positional arguments but 1 was given"
|
362 |
+
]
|
363 |
+
}
|
364 |
+
],
|
365 |
+
"source": [
|
366 |
+
"dataset = runner(ScriptArguments)"
|
367 |
+
]
|
368 |
+
},
|
369 |
+
{
|
370 |
+
"cell_type": "code",
|
371 |
+
"execution_count": null,
|
372 |
+
"metadata": {},
|
373 |
+
"outputs": [],
|
374 |
+
"source": []
|
375 |
+
}
|
376 |
+
],
|
377 |
+
"metadata": {
|
378 |
+
"kernelspec": {
|
379 |
+
"display_name": "Python 3",
|
380 |
+
"language": "python",
|
381 |
+
"name": "python3"
|
382 |
+
},
|
383 |
+
"language_info": {
|
384 |
+
"codemirror_mode": {
|
385 |
+
"name": "ipython",
|
386 |
+
"version": 3
|
387 |
+
},
|
388 |
+
"file_extension": ".py",
|
389 |
+
"mimetype": "text/x-python",
|
390 |
+
"name": "python",
|
391 |
+
"nbconvert_exporter": "python",
|
392 |
+
"pygments_lexer": "ipython3",
|
393 |
+
"version": "3.10.13"
|
394 |
+
}
|
395 |
+
},
|
396 |
+
"nbformat": 4,
|
397 |
+
"nbformat_minor": 2
|
398 |
+
}
|
kto_quickstart.ipynb
ADDED
@@ -0,0 +1,590 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# .KTO Example"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "markdown",
|
12 |
+
"metadata": {},
|
13 |
+
"source": [
|
14 |
+
"https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py"
|
15 |
+
]
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"cell_type": "code",
|
19 |
+
"execution_count": 2,
|
20 |
+
"metadata": {},
|
21 |
+
"outputs": [],
|
22 |
+
"source": [
|
23 |
+
"from dataclasses import dataclass\n",
|
24 |
+
"\n",
|
25 |
+
"from accelerate import PartialState\n",
|
26 |
+
"from datasets import load_dataset\n",
|
27 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser\n",
|
28 |
+
"\n",
|
29 |
+
"from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format"
|
30 |
+
]
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"cell_type": "code",
|
34 |
+
"execution_count": 3,
|
35 |
+
"metadata": {},
|
36 |
+
"outputs": [],
|
37 |
+
"source": [
|
38 |
+
"# Define and parse arguments.\n",
|
39 |
+
"@dataclass\n",
|
40 |
+
"class ScriptArguments:\n",
|
41 |
+
" \"\"\"\n",
|
42 |
+
" The arguments for the KTO training script.\n",
|
43 |
+
" \"\"\"\n",
|
44 |
+
"\n",
|
45 |
+
" dataset_name: str = \"trl-lib/kto-mix-14k\"\n",
|
46 |
+
"\n",
|
47 |
+
"\n",
|
48 |
+
"# Initialize the arguments directly\n",
|
49 |
+
"script_args = ScriptArguments(\n",
|
50 |
+
" dataset_name=\"trl-lib/kto-mix-14k\"\n",
|
51 |
+
")\n",
|
52 |
+
"\n",
|
53 |
+
"training_args = KTOConfig(\n",
|
54 |
+
" output_dir=\"kto-aligned-model\",\n",
|
55 |
+
" num_train_epochs=1,\n",
|
56 |
+
" per_device_train_batch_size=16,\n",
|
57 |
+
" learning_rate=5e-7,\n",
|
58 |
+
" lr_scheduler_type=\"cosine\",\n",
|
59 |
+
" gradient_accumulation_steps=1,\n",
|
60 |
+
" logging_steps=10,\n",
|
61 |
+
" eval_steps=500,\n",
|
62 |
+
" warmup_ratio=0.1,\n",
|
63 |
+
" bf16=True,\n",
|
64 |
+
" logging_first_step=True\n",
|
65 |
+
")\n",
|
66 |
+
"\n",
|
67 |
+
"model_args = ModelConfig(\n",
|
68 |
+
" model_name_or_path=\"trl-lib/qwen1.5-1.8b-sft\",\n",
|
69 |
+
" # any additional model-specific arguments\n",
|
70 |
+
")"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"cell_type": "markdown",
|
75 |
+
"metadata": {},
|
76 |
+
"source": [
|
77 |
+
"- @dataclass makes it easier to create classes that only contain data, making your argument definitions compact, easier to read, and automatically initialized without the need to write a custom __init__ method.\n",
|
78 |
+
"- @dataclass is used here to define a structure for the arguments that you are going to pass to the training script:\n",
|
79 |
+
"- You define a simple data structure (ScriptArguments) with a list of variables (e.g., dataset_name).\n",
|
80 |
+
"- You can quickly create instances of this structure (script_args = ScriptArguments(...)) without manually writing the initializer.\n"
|
81 |
+
]
|
82 |
+
},
|
83 |
+
{
|
84 |
+
"cell_type": "code",
|
85 |
+
"execution_count": 4,
|
86 |
+
"metadata": {},
|
87 |
+
"outputs": [
|
88 |
+
{
|
89 |
+
"data": {
|
90 |
+
"application/vnd.jupyter.widget-view+json": {
|
91 |
+
"model_id": "194616275edb45c5a41065cd24d32510",
|
92 |
+
"version_major": 2,
|
93 |
+
"version_minor": 0
|
94 |
+
},
|
95 |
+
"text/plain": [
|
96 |
+
"config.json: 0%| | 0.00/702 [00:00<?, ?B/s]"
|
97 |
+
]
|
98 |
+
},
|
99 |
+
"metadata": {},
|
100 |
+
"output_type": "display_data"
|
101 |
+
},
|
102 |
+
{
|
103 |
+
"data": {
|
104 |
+
"application/vnd.jupyter.widget-view+json": {
|
105 |
+
"model_id": "487b6524aee9484ea889b896dae886d9",
|
106 |
+
"version_major": 2,
|
107 |
+
"version_minor": 0
|
108 |
+
},
|
109 |
+
"text/plain": [
|
110 |
+
"model.safetensors: 0%| | 0.00/3.67G [00:00<?, ?B/s]"
|
111 |
+
]
|
112 |
+
},
|
113 |
+
"metadata": {},
|
114 |
+
"output_type": "display_data"
|
115 |
+
},
|
116 |
+
{
|
117 |
+
"data": {
|
118 |
+
"application/vnd.jupyter.widget-view+json": {
|
119 |
+
"model_id": "dc564aa7d2704c7baca796e3a4bd6bd5",
|
120 |
+
"version_major": 2,
|
121 |
+
"version_minor": 0
|
122 |
+
},
|
123 |
+
"text/plain": [
|
124 |
+
"generation_config.json: 0%| | 0.00/117 [00:00<?, ?B/s]"
|
125 |
+
]
|
126 |
+
},
|
127 |
+
"metadata": {},
|
128 |
+
"output_type": "display_data"
|
129 |
+
},
|
130 |
+
{
|
131 |
+
"data": {
|
132 |
+
"application/vnd.jupyter.widget-view+json": {
|
133 |
+
"model_id": "2fd83e4e335e4b1fa014c7bb71990d3b",
|
134 |
+
"version_major": 2,
|
135 |
+
"version_minor": 0
|
136 |
+
},
|
137 |
+
"text/plain": [
|
138 |
+
"tokenizer_config.json: 0%| | 0.00/1.17k [00:00<?, ?B/s]"
|
139 |
+
]
|
140 |
+
},
|
141 |
+
"metadata": {},
|
142 |
+
"output_type": "display_data"
|
143 |
+
},
|
144 |
+
{
|
145 |
+
"data": {
|
146 |
+
"application/vnd.jupyter.widget-view+json": {
|
147 |
+
"model_id": "cb5d6cc62c5b4a79a2e72d68d003fac3",
|
148 |
+
"version_major": 2,
|
149 |
+
"version_minor": 0
|
150 |
+
},
|
151 |
+
"text/plain": [
|
152 |
+
"vocab.json: 0%| | 0.00/2.78M [00:00<?, ?B/s]"
|
153 |
+
]
|
154 |
+
},
|
155 |
+
"metadata": {},
|
156 |
+
"output_type": "display_data"
|
157 |
+
},
|
158 |
+
{
|
159 |
+
"data": {
|
160 |
+
"application/vnd.jupyter.widget-view+json": {
|
161 |
+
"model_id": "59bd030296c44f9eb74d110e91bebdbe",
|
162 |
+
"version_major": 2,
|
163 |
+
"version_minor": 0
|
164 |
+
},
|
165 |
+
"text/plain": [
|
166 |
+
"merges.txt: 0%| | 0.00/1.67M [00:00<?, ?B/s]"
|
167 |
+
]
|
168 |
+
},
|
169 |
+
"metadata": {},
|
170 |
+
"output_type": "display_data"
|
171 |
+
},
|
172 |
+
{
|
173 |
+
"data": {
|
174 |
+
"application/vnd.jupyter.widget-view+json": {
|
175 |
+
"model_id": "1e6be00a1a8740d08016c438bfc3c9ea",
|
176 |
+
"version_major": 2,
|
177 |
+
"version_minor": 0
|
178 |
+
},
|
179 |
+
"text/plain": [
|
180 |
+
"tokenizer.json: 0%| | 0.00/7.03M [00:00<?, ?B/s]"
|
181 |
+
]
|
182 |
+
},
|
183 |
+
"metadata": {},
|
184 |
+
"output_type": "display_data"
|
185 |
+
},
|
186 |
+
{
|
187 |
+
"data": {
|
188 |
+
"application/vnd.jupyter.widget-view+json": {
|
189 |
+
"model_id": "d54439e2d7d0400a8498f3f80a8df94a",
|
190 |
+
"version_major": 2,
|
191 |
+
"version_minor": 0
|
192 |
+
},
|
193 |
+
"text/plain": [
|
194 |
+
"added_tokens.json: 0%| | 0.00/80.0 [00:00<?, ?B/s]"
|
195 |
+
]
|
196 |
+
},
|
197 |
+
"metadata": {},
|
198 |
+
"output_type": "display_data"
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"data": {
|
202 |
+
"application/vnd.jupyter.widget-view+json": {
|
203 |
+
"model_id": "4583e80e1c534c2aaaec54cbe22fe987",
|
204 |
+
"version_major": 2,
|
205 |
+
"version_minor": 0
|
206 |
+
},
|
207 |
+
"text/plain": [
|
208 |
+
"special_tokens_map.json: 0%| | 0.00/419 [00:00<?, ?B/s]"
|
209 |
+
]
|
210 |
+
},
|
211 |
+
"metadata": {},
|
212 |
+
"output_type": "display_data"
|
213 |
+
}
|
214 |
+
],
|
215 |
+
"source": [
|
216 |
+
"# Load a pretrained model\n",
|
217 |
+
"model = AutoModelForCausalLM.from_pretrained(\n",
|
218 |
+
" model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code\n",
|
219 |
+
")\n",
|
220 |
+
"ref_model = AutoModelForCausalLM.from_pretrained(\n",
|
221 |
+
" model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code\n",
|
222 |
+
")\n",
|
223 |
+
"\n",
|
224 |
+
"# load a tokenaizer\n",
|
225 |
+
"tokenizer = AutoTokenizer.from_pretrained(\n",
|
226 |
+
" model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code\n",
|
227 |
+
")\n",
|
228 |
+
"if tokenizer.pad_token is None:\n",
|
229 |
+
" tokenizer.pad_token = tokenizer.eos_token\n",
|
230 |
+
"\n",
|
231 |
+
"# If we are aligning a base model, we use ChatML as the default template\n",
|
232 |
+
"if tokenizer.chat_template is None:\n",
|
233 |
+
" model, tokenizer = setup_chat_format(model, tokenizer)"
|
234 |
+
]
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"cell_type": "code",
|
238 |
+
"execution_count": 5,
|
239 |
+
"metadata": {},
|
240 |
+
"outputs": [
|
241 |
+
{
|
242 |
+
"data": {
|
243 |
+
"application/vnd.jupyter.widget-view+json": {
|
244 |
+
"model_id": "edc71904a99c485e9ff32d6c4740249d",
|
245 |
+
"version_major": 2,
|
246 |
+
"version_minor": 0
|
247 |
+
},
|
248 |
+
"text/plain": [
|
249 |
+
"README.md: 0%| | 0.00/814 [00:00<?, ?B/s]"
|
250 |
+
]
|
251 |
+
},
|
252 |
+
"metadata": {},
|
253 |
+
"output_type": "display_data"
|
254 |
+
},
|
255 |
+
{
|
256 |
+
"data": {
|
257 |
+
"application/vnd.jupyter.widget-view+json": {
|
258 |
+
"model_id": "27ac71372f8b493bbb8833148d381f75",
|
259 |
+
"version_major": 2,
|
260 |
+
"version_minor": 0
|
261 |
+
},
|
262 |
+
"text/plain": [
|
263 |
+
"train-00000-of-00001.parquet: 0%| | 0.00/16.3M [00:00<?, ?B/s]"
|
264 |
+
]
|
265 |
+
},
|
266 |
+
"metadata": {},
|
267 |
+
"output_type": "display_data"
|
268 |
+
},
|
269 |
+
{
|
270 |
+
"data": {
|
271 |
+
"application/vnd.jupyter.widget-view+json": {
|
272 |
+
"model_id": "f7ca9416d7c643ceb00109f8ce9a512f",
|
273 |
+
"version_major": 2,
|
274 |
+
"version_minor": 0
|
275 |
+
},
|
276 |
+
"text/plain": [
|
277 |
+
"test-00000-of-00001.parquet: 0%| | 0.00/1.81M [00:00<?, ?B/s]"
|
278 |
+
]
|
279 |
+
},
|
280 |
+
"metadata": {},
|
281 |
+
"output_type": "display_data"
|
282 |
+
},
|
283 |
+
{
|
284 |
+
"data": {
|
285 |
+
"application/vnd.jupyter.widget-view+json": {
|
286 |
+
"model_id": "34b0aa59e9474cb29a7d38956bcac892",
|
287 |
+
"version_major": 2,
|
288 |
+
"version_minor": 0
|
289 |
+
},
|
290 |
+
"text/plain": [
|
291 |
+
"Generating train split: 0%| | 0/13500 [00:00<?, ? examples/s]"
|
292 |
+
]
|
293 |
+
},
|
294 |
+
"metadata": {},
|
295 |
+
"output_type": "display_data"
|
296 |
+
},
|
297 |
+
{
|
298 |
+
"data": {
|
299 |
+
"application/vnd.jupyter.widget-view+json": {
|
300 |
+
"model_id": "111f8817e354479ea2c99838d91bdcae",
|
301 |
+
"version_major": 2,
|
302 |
+
"version_minor": 0
|
303 |
+
},
|
304 |
+
"text/plain": [
|
305 |
+
"Generating test split: 0%| | 0/1500 [00:00<?, ? examples/s]"
|
306 |
+
]
|
307 |
+
},
|
308 |
+
"metadata": {},
|
309 |
+
"output_type": "display_data"
|
310 |
+
}
|
311 |
+
],
|
312 |
+
"source": [
|
313 |
+
"# Load the dataset\n",
|
314 |
+
"dataset = load_dataset(script_args.dataset_name)\n",
|
315 |
+
"\n",
|
316 |
+
"# If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label)\n",
|
317 |
+
"dataset = maybe_unpair_preference_dataset(dataset, num_proc=training_args.dataset_num_proc)"
|
318 |
+
]
|
319 |
+
},
|
320 |
+
{
|
321 |
+
"cell_type": "code",
|
322 |
+
"execution_count": 6,
|
323 |
+
"metadata": {},
|
324 |
+
"outputs": [],
|
325 |
+
"source": [
|
326 |
+
"# Apply chat template\n",
|
327 |
+
"def format_dataset(example):\n",
|
328 |
+
" example[\"prompt\"] = tokenizer.apply_chat_template(example[\"prompt\"], tokenize=False)\n",
|
329 |
+
" example[\"completion\"] = tokenizer.apply_chat_template(example[\"completion\"], tokenize=False)\n",
|
330 |
+
" return example"
|
331 |
+
]
|
332 |
+
},
|
333 |
+
{
|
334 |
+
"cell_type": "code",
|
335 |
+
"execution_count": 7,
|
336 |
+
"metadata": {},
|
337 |
+
"outputs": [
|
338 |
+
{
|
339 |
+
"data": {
|
340 |
+
"application/vnd.jupyter.widget-view+json": {
|
341 |
+
"model_id": "e1ec07668de94a1580a72a208fc90c47",
|
342 |
+
"version_major": 2,
|
343 |
+
"version_minor": 0
|
344 |
+
},
|
345 |
+
"text/plain": [
|
346 |
+
"Map: 0%| | 0/13500 [00:00<?, ? examples/s]"
|
347 |
+
]
|
348 |
+
},
|
349 |
+
"metadata": {},
|
350 |
+
"output_type": "display_data"
|
351 |
+
},
|
352 |
+
{
|
353 |
+
"data": {
|
354 |
+
"application/vnd.jupyter.widget-view+json": {
|
355 |
+
"model_id": "f411d87acc1840a4a5650565cab06018",
|
356 |
+
"version_major": 2,
|
357 |
+
"version_minor": 0
|
358 |
+
},
|
359 |
+
"text/plain": [
|
360 |
+
"Map: 0%| | 0/1500 [00:00<?, ? examples/s]"
|
361 |
+
]
|
362 |
+
},
|
363 |
+
"metadata": {},
|
364 |
+
"output_type": "display_data"
|
365 |
+
}
|
366 |
+
],
|
367 |
+
"source": [
|
368 |
+
"# Compute that only on the main process for faster data processing.\n",
|
369 |
+
"# see: https://github.com/huggingface/trl/pull/1255\n",
|
370 |
+
"with PartialState().local_main_process_first():\n",
|
371 |
+
" dataset = dataset.map(format_dataset, num_proc=training_args.dataset_num_proc)\n"
|
372 |
+
]
|
373 |
+
},
|
374 |
+
{
|
375 |
+
"cell_type": "code",
|
376 |
+
"execution_count": 8,
|
377 |
+
"metadata": {},
|
378 |
+
"outputs": [
|
379 |
+
{
|
380 |
+
"name": "stderr",
|
381 |
+
"output_type": "stream",
|
382 |
+
"text": [
|
383 |
+
"/Users/riddhib/.pyenv/versions/3.10.13/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:466: UserWarning: When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init it will be set to `512` by default, but you should do it yourself in the future.\n",
|
384 |
+
" warnings.warn(\n",
|
385 |
+
"/Users/riddhib/.pyenv/versions/3.10.13/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:476: UserWarning: When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init it will be set to `128` by default, but you should do it yourself in the future.\n",
|
386 |
+
" warnings.warn(\n",
|
387 |
+
"/Users/riddhib/.pyenv/versions/3.10.13/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:506: UserWarning: When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig we have set it for you, but you should do it yourself in the future.\n",
|
388 |
+
" warnings.warn(\n"
|
389 |
+
]
|
390 |
+
},
|
391 |
+
{
|
392 |
+
"data": {
|
393 |
+
"application/vnd.jupyter.widget-view+json": {
|
394 |
+
"model_id": "a0546a58479a4cf3ae8c14ead7d2f21a",
|
395 |
+
"version_major": 2,
|
396 |
+
"version_minor": 0
|
397 |
+
},
|
398 |
+
"text/plain": [
|
399 |
+
"Tokenizing train dataset: 0%| | 0/13500 [00:00<?, ? examples/s]"
|
400 |
+
]
|
401 |
+
},
|
402 |
+
"metadata": {},
|
403 |
+
"output_type": "display_data"
|
404 |
+
},
|
405 |
+
{
|
406 |
+
"data": {
|
407 |
+
"application/vnd.jupyter.widget-view+json": {
|
408 |
+
"model_id": "93957abece9b440181de6bcd3d7ac9d4",
|
409 |
+
"version_major": 2,
|
410 |
+
"version_minor": 0
|
411 |
+
},
|
412 |
+
"text/plain": [
|
413 |
+
"Processing tokenized train dataset: 0%| | 0/13500 [00:00<?, ? examples/s]"
|
414 |
+
]
|
415 |
+
},
|
416 |
+
"metadata": {},
|
417 |
+
"output_type": "display_data"
|
418 |
+
},
|
419 |
+
{
|
420 |
+
"data": {
|
421 |
+
"application/vnd.jupyter.widget-view+json": {
|
422 |
+
"model_id": "6307b14ccc3b455db5610bf269f054eb",
|
423 |
+
"version_major": 2,
|
424 |
+
"version_minor": 0
|
425 |
+
},
|
426 |
+
"text/plain": [
|
427 |
+
"Tokenizing eval dataset: 0%| | 0/1500 [00:00<?, ? examples/s]"
|
428 |
+
]
|
429 |
+
},
|
430 |
+
"metadata": {},
|
431 |
+
"output_type": "display_data"
|
432 |
+
},
|
433 |
+
{
|
434 |
+
"data": {
|
435 |
+
"application/vnd.jupyter.widget-view+json": {
|
436 |
+
"model_id": "8d8ae59b643b4ba99258f2e579e24614",
|
437 |
+
"version_major": 2,
|
438 |
+
"version_minor": 0
|
439 |
+
},
|
440 |
+
"text/plain": [
|
441 |
+
"Processing tokenized eval dataset: 0%| | 0/1500 [00:00<?, ? examples/s]"
|
442 |
+
]
|
443 |
+
},
|
444 |
+
"metadata": {},
|
445 |
+
"output_type": "display_data"
|
446 |
+
},
|
447 |
+
{
|
448 |
+
"data": {
|
449 |
+
"application/vnd.jupyter.widget-view+json": {
|
450 |
+
"model_id": "4a3a9f330cc849339c6bd99fffae40ff",
|
451 |
+
"version_major": 2,
|
452 |
+
"version_minor": 0
|
453 |
+
},
|
454 |
+
"text/plain": [
|
455 |
+
"Extracting KL train dataset: 0%| | 0/13500 [00:00<?, ? examples/s]"
|
456 |
+
]
|
457 |
+
},
|
458 |
+
"metadata": {},
|
459 |
+
"output_type": "display_data"
|
460 |
+
},
|
461 |
+
{
|
462 |
+
"data": {
|
463 |
+
"application/vnd.jupyter.widget-view+json": {
|
464 |
+
"model_id": "0c77beb7b355417c91b5b1b974f01a22",
|
465 |
+
"version_major": 2,
|
466 |
+
"version_minor": 0
|
467 |
+
},
|
468 |
+
"text/plain": [
|
469 |
+
"Processing tokenized train KL dataset: 0%| | 0/13500 [00:00<?, ? examples/s]"
|
470 |
+
]
|
471 |
+
},
|
472 |
+
"metadata": {},
|
473 |
+
"output_type": "display_data"
|
474 |
+
},
|
475 |
+
{
|
476 |
+
"data": {
|
477 |
+
"application/vnd.jupyter.widget-view+json": {
|
478 |
+
"model_id": "9da59bf2f1e14a6aad34be6b6dcd56c4",
|
479 |
+
"version_major": 2,
|
480 |
+
"version_minor": 0
|
481 |
+
},
|
482 |
+
"text/plain": [
|
483 |
+
"Extracting eval KL dataset: 0%| | 0/1500 [00:00<?, ? examples/s]"
|
484 |
+
]
|
485 |
+
},
|
486 |
+
"metadata": {},
|
487 |
+
"output_type": "display_data"
|
488 |
+
},
|
489 |
+
{
|
490 |
+
"data": {
|
491 |
+
"application/vnd.jupyter.widget-view+json": {
|
492 |
+
"model_id": "0a8772792384450394df61957660aa56",
|
493 |
+
"version_major": 2,
|
494 |
+
"version_minor": 0
|
495 |
+
},
|
496 |
+
"text/plain": [
|
497 |
+
"Processing tokenized eval KL dataset: 0%| | 0/1500 [00:00<?, ? examples/s]"
|
498 |
+
]
|
499 |
+
},
|
500 |
+
"metadata": {},
|
501 |
+
"output_type": "display_data"
|
502 |
+
},
|
503 |
+
{
|
504 |
+
"data": {
|
505 |
+
"application/vnd.jupyter.widget-view+json": {
|
506 |
+
"model_id": "8afe00509fe14b16b163b38b8774a4c6",
|
507 |
+
"version_major": 2,
|
508 |
+
"version_minor": 0
|
509 |
+
},
|
510 |
+
"text/plain": [
|
511 |
+
" 0%| | 0/844 [00:00<?, ?it/s]"
|
512 |
+
]
|
513 |
+
},
|
514 |
+
"metadata": {},
|
515 |
+
"output_type": "display_data"
|
516 |
+
},
|
517 |
+
{
|
518 |
+
"ename": "RuntimeError",
|
519 |
+
"evalue": "MPS backend out of memory (MPS allocated: 17.37 GB, other allocations: 664.64 MB, max allowed: 18.13 GB). Tried to allocate 172.34 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).",
|
520 |
+
"output_type": "error",
|
521 |
+
"traceback": [
|
522 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
523 |
+
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
524 |
+
"Cell \u001b[0;32mIn[8], line 13\u001b[0m\n\u001b[1;32m 2\u001b[0m trainer \u001b[38;5;241m=\u001b[39m KTOTrainer(\n\u001b[1;32m 3\u001b[0m model,\n\u001b[1;32m 4\u001b[0m ref_model,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 9\u001b[0m peft_config\u001b[38;5;241m=\u001b[39mget_peft_config(model_args),\n\u001b[1;32m 10\u001b[0m )\n\u001b[1;32m 12\u001b[0m \u001b[38;5;66;03m# Train and push the model to the Hub\u001b[39;00m\n\u001b[0;32m---> 13\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;66;03m# Save and push to hub\u001b[39;00m\n\u001b[1;32m 16\u001b[0m trainer\u001b[38;5;241m.\u001b[39msave_model(training_args\u001b[38;5;241m.\u001b[39moutput_dir)\n",
|
525 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/trainer.py:2052\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 2050\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 2051\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2052\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2053\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2054\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2055\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2056\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2057\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
|
526 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/trainer.py:2388\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2385\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m 2387\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 2388\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2390\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 2391\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 2392\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m 2393\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 2394\u001b[0m ):\n\u001b[1;32m 2395\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 2396\u001b[0m tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
|
527 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/trainer.py:3485\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 3482\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss_mb\u001b[38;5;241m.\u001b[39mreduce_mean()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 3484\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m-> 3485\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3487\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m inputs\n\u001b[1;32m 3488\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 3489\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mtorch_empty_cache_steps \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 3490\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m%\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mtorch_empty_cache_steps \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 3491\u001b[0m ):\n",
|
528 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:1237\u001b[0m, in \u001b[0;36mKTOTrainer.compute_loss\u001b[0;34m(self, model, inputs, return_outputs)\u001b[0m\n\u001b[1;32m 1234\u001b[0m compute_loss_context_manager \u001b[38;5;241m=\u001b[39m amp\u001b[38;5;241m.\u001b[39mautocast(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_peft_has_been_casted_to_bf16 \u001b[38;5;28;01melse\u001b[39;00m nullcontext()\n\u001b[1;32m 1236\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m compute_loss_context_manager:\n\u001b[0;32m-> 1237\u001b[0m loss, metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_batch_loss_metrics\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1239\u001b[0m \u001b[38;5;66;03m# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:\u001b[39;00m\n\u001b[1;32m 1240\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdevice)\n",
|
529 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:1143\u001b[0m, in \u001b[0;36mKTOTrainer.get_batch_loss_metrics\u001b[0;34m(self, model, batch)\u001b[0m\n\u001b[1;32m 1140\u001b[0m metrics \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 1141\u001b[0m batch \u001b[38;5;241m=\u001b[39m {k: (v\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mdevice) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(v, torch\u001b[38;5;241m.\u001b[39mTensor) \u001b[38;5;28;01melse\u001b[39;00m v) \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m batch\u001b[38;5;241m.\u001b[39mitems()}\n\u001b[0;32m-> 1143\u001b[0m forward_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1144\u001b[0m (\n\u001b[1;32m 1145\u001b[0m policy_chosen_logps,\n\u001b[1;32m 1146\u001b[0m policy_rejected_logps,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1149\u001b[0m policy_KL_logps,\n\u001b[1;32m 1150\u001b[0m ) \u001b[38;5;241m=\u001b[39m forward_output[:\u001b[38;5;241m5\u001b[39m]\n\u001b[1;32m 1151\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maux_loss_enabled:\n",
|
530 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:1002\u001b[0m, in \u001b[0;36mKTOTrainer.forward\u001b[0;34m(self, model, batch)\u001b[0m\n\u001b[1;32m 988\u001b[0m KL_model_kwargs \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 989\u001b[0m {\n\u001b[1;32m 990\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m: batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mKL_prompt_input_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 999\u001b[0m }\n\u001b[1;32m 1000\u001b[0m )\n\u001b[1;32m 1001\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m-> 1002\u001b[0m KL_logits \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1003\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mKL_model_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1004\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mlogits\n\u001b[1;32m 1006\u001b[0m KL_logps \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_batch_logps(\n\u001b[1;32m 1007\u001b[0m KL_logits,\n\u001b[1;32m 1008\u001b[0m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mKL_completion_labels\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1011\u001b[0m label_pad_token_id\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlabel_pad_token_id,\n\u001b[1;32m 1012\u001b[0m )\n\u001b[1;32m 1013\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
|
531 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
532 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
533 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:1167\u001b[0m, in \u001b[0;36mQwen2ForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)\u001b[0m\n\u001b[1;32m 1164\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 1166\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1167\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1168\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1169\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1170\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1171\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1172\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1173\u001b[0m \u001b[43m 
\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1174\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1175\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1176\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1177\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1178\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1180\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1181\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m labels \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling():\n",
|
534 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
535 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
536 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:976\u001b[0m, in \u001b[0;36mQwen2Model.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)\u001b[0m\n\u001b[1;32m 964\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 965\u001b[0m decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 966\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 973\u001b[0m position_embeddings,\n\u001b[1;32m 974\u001b[0m )\n\u001b[1;32m 975\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 976\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 977\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 978\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 979\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 980\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 981\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 982\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 983\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 984\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 985\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 987\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 989\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n",
|
537 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
538 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
539 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:717\u001b[0m, in \u001b[0;36mQwen2DecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)\u001b[0m\n\u001b[1;32m 715\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[1;32m 716\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpost_attention_layernorm(hidden_states)\n\u001b[0;32m--> 717\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmlp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 718\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[1;32m 720\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (hidden_states,)\n",
|
540 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
541 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
542 |
+
"File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:276\u001b[0m, in \u001b[0;36mQwen2MLP.forward\u001b[0;34m(self, hidden_state)\u001b[0m\n\u001b[1;32m 275\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, hidden_state):\n\u001b[0;32m--> 276\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdown_proj(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mact_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgate_proj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_state\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mup_proj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_state\u001b[49m\u001b[43m)\u001b[49m)\n",
|
543 |
+
"\u001b[0;31mRuntimeError\u001b[0m: MPS backend out of memory (MPS allocated: 17.37 GB, other allocations: 664.64 MB, max allowed: 18.13 GB). Tried to allocate 172.34 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure)."
|
544 |
+
]
|
545 |
+
}
|
546 |
+
],
|
547 |
+
"source": [
|
548 |
+
"# Initialize the KTO trainer\n",
|
549 |
+
"trainer = KTOTrainer(\n",
|
550 |
+
" model,\n",
|
551 |
+
" ref_model,\n",
|
552 |
+
" args=training_args,\n",
|
553 |
+
" train_dataset=dataset[\"train\"],\n",
|
554 |
+
" eval_dataset=dataset[\"test\"],\n",
|
555 |
+
" tokenizer=tokenizer,\n",
|
556 |
+
" peft_config=get_peft_config(model_args),\n",
|
557 |
+
")\n",
|
558 |
+
"\n",
|
559 |
+
"# Train and push the model to the Hub\n",
|
560 |
+
"trainer.train()\n",
|
561 |
+
"\n",
|
562 |
+
"# Save and push to hub\n",
|
563 |
+
"trainer.save_model(training_args.output_dir)\n",
|
564 |
+
"if training_args.push_to_hub:\n",
|
565 |
+
" trainer.push_to_hub()"
|
566 |
+
]
|
567 |
+
}
|
568 |
+
],
|
569 |
+
"metadata": {
|
570 |
+
"kernelspec": {
|
571 |
+
"display_name": "rlhf",
|
572 |
+
"language": "python",
|
573 |
+
"name": "python3"
|
574 |
+
},
|
575 |
+
"language_info": {
|
576 |
+
"codemirror_mode": {
|
577 |
+
"name": "ipython",
|
578 |
+
"version": 3
|
579 |
+
},
|
580 |
+
"file_extension": ".py",
|
581 |
+
"mimetype": "text/x-python",
|
582 |
+
"name": "python",
|
583 |
+
"nbconvert_exporter": "python",
|
584 |
+
"pygments_lexer": "ipython3",
|
585 |
+
"version": "3.10.13"
|
586 |
+
}
|
587 |
+
},
|
588 |
+
"nbformat": 4,
|
589 |
+
"nbformat_minor": 2
|
590 |
+
}
|
trl_rlhf_data.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import re
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import Dict, List, Optional
|
18 |
+
|
19 |
+
from datasets import load_dataset
|
20 |
+
from transformers import HfArgumentParser
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass
class ScriptArguments:
    r"""
    Command-line options for the hh-rlhf dataset-preparation script.

    Args:
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the processed dataset to the Hugging Face Hub.
        repo_id (`str`, *optional*, defaults to `"trl-lib/hh-rlhf-helpful-base"`):
            Hugging Face repository ID to push the dataset to.
        dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
            Number of worker processes to use when mapping over the dataset.
    """

    push_to_hub: bool = False
    repo_id: str = "trl-lib/hh-rlhf-helpful-base"
    dataset_num_proc: Optional[int] = None
40 |
+
|
41 |
+
|
42 |
+
def common_start(str1: str, str2: str) -> str:
    """Return the longest common prefix of *str1* and *str2*."""
    common_chars = []
    for c1, c2 in zip(str1, str2):
        if c1 != c2:
            break
        common_chars.append(c1)
    return "".join(common_chars)


def extract_dialogue(example: Dict[str, str]) -> Dict[str, List[Dict[str, str]]]:
    """Convert a raw hh-rlhf example into prompt/chosen/rejected chat messages.

    Args:
        example: A mapping with ``"chosen"`` and ``"rejected"`` keys, each a
            full dialogue transcript using the ``Human:`` / ``Assistant:``
            turn markers (each marker preceded by a blank line).

    Returns:
        A dict with:
            ``"prompt"``: the list of ``{"role", "content"}`` messages shared
            by both completions;
            ``"chosen"`` / ``"rejected"``: single-message lists holding the
            two differing assistant replies.
    """
    # The prompt corresponds to the common start of the two transcripts.
    prompt_text = common_start(example["chosen"], example["rejected"])

    # The two replies may themselves share a prefix, so cut the prompt back to
    # the last generation marker to keep the full replies intact.
    if not prompt_text.endswith("\n\nAssistant: "):
        prompt_text = prompt_text[: prompt_text.rfind("\n\nAssistant: ")] + "\n\nAssistant: "

    # Everything after the prompt is the assistant reply in each branch.
    chosen_line = example["chosen"][len(prompt_text) :]
    rejected_line = example["rejected"][len(prompt_text) :]

    # Remove the trailing generation prompt itself from the prompt text.
    prompt_text = prompt_text[: -len("\n\nAssistant: ")]

    # Split at every turn marker; the captured separators identify the role.
    prompt_lines = re.split(r"(\n\nAssistant: |\n\nHuman: )", prompt_text)

    # re.split yields a leading empty string before the first marker.
    prompt_lines = prompt_lines[1:]

    prompt = []
    for idx in range(0, len(prompt_lines), 2):
        role = "user" if prompt_lines[idx] == "\n\nHuman: " else "assistant"
        content = prompt_lines[idx + 1]
        prompt.append({"role": role, "content": content})

    # Bug fix: the chosen message was tagged with the misspelled role
    # "assitant", which would break role-based processing downstream.
    chosen = [{"role": "assistant", "content": chosen_line}]
    rejected = [{"role": "assistant", "content": rejected_line}]

    return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
|
86 |
+
|
87 |
+
|
88 |
+
def runner(arguments):
    """Parse CLI args, build the conversational hh-rlhf dataset, and return it.

    Args:
        arguments: The dataclass type (e.g. ``ScriptArguments``) describing
            the script's command-line options, parsed via ``HfArgumentParser``.

    Returns:
        The mapped dataset with "prompt"/"chosen"/"rejected" message columns.
    """
    parser = HfArgumentParser(arguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    dataset = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-base")
    dataset = dataset.map(extract_dialogue, num_proc=script_args.dataset_num_proc)

    # Bug fix: a bare `return` previously made this push unreachable (it sat
    # commented out after the return) and discarded the mapped dataset.
    # `push_to_hub` defaults to False, so default runs behave as before.
    if script_args.push_to_hub:
        dataset.push_to_hub(script_args.repo_id)

    return dataset
|