Riddhi Bhagwat committed on
Commit fd59c75 · 1 Parent(s): 3f8b25a

Add files via upload

initial commit; moving files from old repo to organization

Files changed (4)
  1. README.md +2 -0
  2. dataset_training.ipynb +398 -0
  3. kto_quickstart.ipynb +590 -0
  4. trl_rlhf_data.py +97 -0
README.md CHANGED
@@ -6,3 +6,5 @@
6
  This code repository (or "repo") is designed to demonstrate the best GitHub has to offer with the least amount of noise.
7
 
8
  The repo includes an `index.html` file (so it can render a web page), two GitHub Actions workflows, and a CSS stylesheet dependency.
9
+ # Model-Improvement-Platform-With-RLHF
10
+ Platform being developed at MIT in collaboration with HuggingFace, aimed at improving the performance of existing large language models through a real-time human feedback loop.
dataset_training.ipynb ADDED
@@ -0,0 +1,398 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 43,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "#dependencies:\n",
10
+ "import pandas as pd\n",
11
+ "\n",
12
+ "import torch\n",
13
+ "from transformers import GPT2Tokenizer\n",
14
+ "\n",
15
+ "from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": 44,
21
+ "metadata": {},
22
+ "outputs": [
23
+ {
24
+ "data": {
25
+ "application/vnd.jupyter.widget-view+json": {
26
+ "model_id": "b8a22b8d60c0417eafbf554832398287",
27
+ "version_major": 2,
28
+ "version_minor": 0
29
+ },
30
+ "text/plain": [
31
+ "Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
32
+ ]
33
+ },
34
+ "metadata": {},
35
+ "output_type": "display_data"
36
+ },
37
+ {
38
+ "data": {
39
+ "application/vnd.jupyter.widget-view+json": {
40
+ "model_id": "b83d2624c2b14986a8297821460225ab",
41
+ "version_major": 2,
42
+ "version_minor": 0
43
+ },
44
+ "text/plain": [
45
+ "Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
46
+ ]
47
+ },
48
+ "metadata": {},
49
+ "output_type": "display_data"
50
+ },
51
+ {
52
+ "data": {
53
+ "application/vnd.jupyter.widget-view+json": {
54
+ "model_id": "b4304c0f48cb472589b5e80d3a42cba2",
55
+ "version_major": 2,
56
+ "version_minor": 0
57
+ },
58
+ "text/plain": [
59
+ "Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
60
+ ]
61
+ },
62
+ "metadata": {},
63
+ "output_type": "display_data"
64
+ }
65
+ ],
66
+ "source": [
67
+ "#loading datasets:\n",
68
+ "from datasets import load_dataset\n",
69
+ "\n",
70
+ "ds = load_dataset(\"stanfordnlp/SHP\", split='train')"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": 45,
76
+ "metadata": {},
77
+ "outputs": [
78
+ {
79
+ "name": "stdout",
80
+ "output_type": "stream",
81
+ "text": [
82
+ "Index(['post_id', 'domain', 'upvote_ratio', 'history', 'c_root_id_A',\n",
83
+ " 'c_root_id_B', 'created_at_utc_A', 'created_at_utc_B', 'score_A',\n",
84
+ " 'score_B', 'human_ref_A', 'human_ref_B', 'labels', 'seconds_difference',\n",
85
+ " 'score_ratio'],\n",
86
+ " dtype='object')\n"
87
+ ]
88
+ }
89
+ ],
90
+ "source": [
91
+ "df = ds.to_pandas()\n",
92
+ "print(df.columns)\n"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": 46,
98
+ "metadata": {},
99
+ "outputs": [
100
+ {
101
+ "data": {
102
+ "text/html": [
103
+ "<div>\n",
104
+ "<style scoped>\n",
105
+ " .dataframe tbody tr th:only-of-type {\n",
106
+ " vertical-align: middle;\n",
107
+ " }\n",
108
+ "\n",
109
+ " .dataframe tbody tr th {\n",
110
+ " vertical-align: top;\n",
111
+ " }\n",
112
+ "\n",
113
+ " .dataframe thead th {\n",
114
+ " text-align: right;\n",
115
+ " }\n",
116
+ "</style>\n",
117
+ "<table border=\"1\" class=\"dataframe\">\n",
118
+ " <thead>\n",
119
+ " <tr style=\"text-align: right;\">\n",
120
+ " <th></th>\n",
121
+ " <th>upvote_ratio</th>\n",
122
+ " <th>history</th>\n",
123
+ " <th>score_A</th>\n",
124
+ " <th>score_B</th>\n",
125
+ " <th>human_ref_A</th>\n",
126
+ " <th>human_ref_B</th>\n",
127
+ " <th>labels</th>\n",
128
+ " <th>score_ratio</th>\n",
129
+ " </tr>\n",
130
+ " </thead>\n",
131
+ " <tbody>\n",
132
+ " <tr>\n",
133
+ " <th>0</th>\n",
134
+ " <td>0.99</td>\n",
135
+ " <td>In an interview right before receiving the 201...</td>\n",
136
+ " <td>52</td>\n",
137
+ " <td>54</td>\n",
138
+ " <td>Currently wrapping up my PhD. There is a stark...</td>\n",
139
+ " <td>It’s ironic to me that research has shown that...</td>\n",
140
+ " <td>0</td>\n",
141
+ " <td>1.038462</td>\n",
142
+ " </tr>\n",
143
+ " <tr>\n",
144
+ " <th>1</th>\n",
145
+ " <td>0.95</td>\n",
146
+ " <td>If any professor is reading this: please do no...</td>\n",
147
+ " <td>5</td>\n",
148
+ " <td>17</td>\n",
149
+ " <td>And when your teacher doesn't listen or pay at...</td>\n",
150
+ " <td>I'm pretty strict on time, to the point where ...</td>\n",
151
+ " <td>0</td>\n",
152
+ " <td>3.400000</td>\n",
153
+ " </tr>\n",
154
+ " <tr>\n",
155
+ " <th>2</th>\n",
156
+ " <td>0.95</td>\n",
157
+ " <td>If any professor is reading this: please do no...</td>\n",
158
+ " <td>5</td>\n",
159
+ " <td>7</td>\n",
160
+ " <td>Profs can be oblivious? What’s new!</td>\n",
161
+ " <td>This sounds like a problem with a specific pro...</td>\n",
162
+ " <td>0</td>\n",
163
+ " <td>1.400000</td>\n",
164
+ " </tr>\n",
165
+ " <tr>\n",
166
+ " <th>3</th>\n",
167
+ " <td>0.95</td>\n",
168
+ " <td>If any professor is reading this: please do no...</td>\n",
169
+ " <td>7</td>\n",
170
+ " <td>5</td>\n",
171
+ " <td>This sounds like a problem with a specific pro...</td>\n",
172
+ " <td>And when your teacher doesn't listen or pay at...</td>\n",
173
+ " <td>1</td>\n",
174
+ " <td>1.400000</td>\n",
175
+ " </tr>\n",
176
+ " <tr>\n",
177
+ " <th>4</th>\n",
178
+ " <td>0.95</td>\n",
179
+ " <td>If any professor is reading this: please do no...</td>\n",
180
+ " <td>6</td>\n",
181
+ " <td>7</td>\n",
182
+ " <td>This would be totally unacceptable in my class...</td>\n",
183
+ " <td>This sounds like a problem with a specific pro...</td>\n",
184
+ " <td>0</td>\n",
185
+ " <td>1.166667</td>\n",
186
+ " </tr>\n",
187
+ " <tr>\n",
188
+ " <th>...</th>\n",
189
+ " <td>...</td>\n",
190
+ " <td>...</td>\n",
191
+ " <td>...</td>\n",
192
+ " <td>...</td>\n",
193
+ " <td>...</td>\n",
194
+ " <td>...</td>\n",
195
+ " <td>...</td>\n",
196
+ " <td>...</td>\n",
197
+ " </tr>\n",
198
+ " <tr>\n",
199
+ " <th>348713</th>\n",
200
+ " <td>0.94</td>\n",
201
+ " <td>Can I get in trouble for giving my neighbor hi...</td>\n",
202
+ " <td>7</td>\n",
203
+ " <td>25</td>\n",
204
+ " <td>Just put up a fence. Legally he isn't responsi...</td>\n",
205
+ " <td>Whatever you do, don't cut his trees down.</td>\n",
206
+ " <td>0</td>\n",
207
+ " <td>3.571429</td>\n",
208
+ " </tr>\n",
209
+ " <tr>\n",
210
+ " <th>348714</th>\n",
211
+ " <td>0.94</td>\n",
212
+ " <td>Can I get in trouble for giving my neighbor hi...</td>\n",
213
+ " <td>2</td>\n",
214
+ " <td>25</td>\n",
215
+ " <td>If OP pays someone to clean his yard, and then...</td>\n",
216
+ " <td>Whatever you do, don't cut his trees down.</td>\n",
217
+ " <td>0</td>\n",
218
+ " <td>12.500000</td>\n",
219
+ " </tr>\n",
220
+ " <tr>\n",
221
+ " <th>348715</th>\n",
222
+ " <td>0.94</td>\n",
223
+ " <td>Can I get in trouble for giving my neighbor hi...</td>\n",
224
+ " <td>9</td>\n",
225
+ " <td>7</td>\n",
226
+ " <td>My observation is that both of you are idiots...</td>\n",
227
+ " <td>Are you Rand Paul's neighbor? https://www.gq....</td>\n",
228
+ " <td>1</td>\n",
229
+ " <td>1.285714</td>\n",
230
+ " </tr>\n",
231
+ " <tr>\n",
232
+ " <th>348716</th>\n",
233
+ " <td>0.94</td>\n",
234
+ " <td>Can I get in trouble for giving my neighbor hi...</td>\n",
235
+ " <td>9</td>\n",
236
+ " <td>7</td>\n",
237
+ " <td>My observation is that both of you are idiots...</td>\n",
238
+ " <td>Just put up a fence. Legally he isn't responsi...</td>\n",
239
+ " <td>1</td>\n",
240
+ " <td>1.285714</td>\n",
241
+ " </tr>\n",
242
+ " <tr>\n",
243
+ " <th>348717</th>\n",
244
+ " <td>0.94</td>\n",
245
+ " <td>Can I get in trouble for giving my neighbor hi...</td>\n",
246
+ " <td>7</td>\n",
247
+ " <td>2</td>\n",
248
+ " <td>Capture his acts on camera. Collect and bag l...</td>\n",
249
+ " <td>If OP pays someone to clean his yard, and then...</td>\n",
250
+ " <td>1</td>\n",
251
+ " <td>3.500000</td>\n",
252
+ " </tr>\n",
253
+ " </tbody>\n",
254
+ "</table>\n",
255
+ "<p>348718 rows × 8 columns</p>\n",
256
+ "</div>"
257
+ ],
258
+ "text/plain": [
259
+ " upvote_ratio history \\\n",
260
+ "0 0.99 In an interview right before receiving the 201... \n",
261
+ "1 0.95 If any professor is reading this: please do no... \n",
262
+ "2 0.95 If any professor is reading this: please do no... \n",
263
+ "3 0.95 If any professor is reading this: please do no... \n",
264
+ "4 0.95 If any professor is reading this: please do no... \n",
265
+ "... ... ... \n",
266
+ "348713 0.94 Can I get in trouble for giving my neighbor hi... \n",
267
+ "348714 0.94 Can I get in trouble for giving my neighbor hi... \n",
268
+ "348715 0.94 Can I get in trouble for giving my neighbor hi... \n",
269
+ "348716 0.94 Can I get in trouble for giving my neighbor hi... \n",
270
+ "348717 0.94 Can I get in trouble for giving my neighbor hi... \n",
271
+ "\n",
272
+ " score_A score_B human_ref_A \\\n",
273
+ "0 52 54 Currently wrapping up my PhD. There is a stark... \n",
274
+ "1 5 17 And when your teacher doesn't listen or pay at... \n",
275
+ "2 5 7 Profs can be oblivious? What’s new! \n",
276
+ "3 7 5 This sounds like a problem with a specific pro... \n",
277
+ "4 6 7 This would be totally unacceptable in my class... \n",
278
+ "... ... ... ... \n",
279
+ "348713 7 25 Just put up a fence. Legally he isn't responsi... \n",
280
+ "348714 2 25 If OP pays someone to clean his yard, and then... \n",
281
+ "348715 9 7 My observation is that both of you are idiots... \n",
282
+ "348716 9 7 My observation is that both of you are idiots... \n",
283
+ "348717 7 2 Capture his acts on camera. Collect and bag l... \n",
284
+ "\n",
285
+ " human_ref_B labels score_ratio \n",
286
+ "0 It’s ironic to me that research has shown that... 0 1.038462 \n",
287
+ "1 I'm pretty strict on time, to the point where ... 0 3.400000 \n",
288
+ "2 This sounds like a problem with a specific pro... 0 1.400000 \n",
289
+ "3 And when your teacher doesn't listen or pay at... 1 1.400000 \n",
290
+ "4 This sounds like a problem with a specific pro... 0 1.166667 \n",
291
+ "... ... ... ... \n",
292
+ "348713 Whatever you do, don't cut his trees down. 0 3.571429 \n",
293
+ "348714 Whatever you do, don't cut his trees down. 0 12.500000 \n",
294
+ "348715 Are you Rand Paul's neighbor? https://www.gq.... 1 1.285714 \n",
295
+ "348716 Just put up a fence. Legally he isn't responsi... 1 1.285714 \n",
296
+ "348717 If OP pays someone to clean his yard, and then... 1 3.500000 \n",
297
+ "\n",
298
+ "[348718 rows x 8 columns]"
299
+ ]
300
+ },
301
+ "execution_count": 46,
302
+ "metadata": {},
303
+ "output_type": "execute_result"
304
+ }
305
+ ],
306
+ "source": [
307
+ "# df['response_length'] = df['history'].apply(len)\n",
308
+ "# df['label'] = df['response_length'].apply(lambda x: 'long' if x > 100 else 'short')\n",
309
+ "df.drop(columns=['post_id', 'domain', 'c_root_id_A', 'c_root_id_B', 'created_at_utc_A', 'created_at_utc_B', 'seconds_difference'])"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": 47,
315
+ "metadata": {},
316
+ "outputs": [
317
+ {
318
+ "name": "stderr",
319
+ "output_type": "stream",
320
+ "text": [
321
+ "/Users/riddhib/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
322
+ " warnings.warn(\n"
323
+ ]
324
+ }
325
+ ],
326
+ "source": [
327
+ "model = AutoModelForCausalLMWithValueHead.from_pretrained(\"gpt2\")\n",
328
+ "ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(\"gpt2\")\n",
329
+ "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
330
+ "tokenizer.pad_token = tokenizer.eos_token"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": 48,
336
+ "metadata": {},
337
+ "outputs": [],
338
+ "source": [
339
+ "from trl_rlhf_data import runner, ScriptArguments\n",
340
+ "import re\n",
341
+ "from dataclasses import dataclass\n",
342
+ "from typing import Dict, List, Optional\n",
343
+ "\n",
344
+ "from datasets import load_dataset\n",
345
+ "from transformers import HfArgumentParser"
346
+ ]
347
+ },
348
+ {
349
+ "cell_type": "code",
350
+ "execution_count": 49,
351
+ "metadata": {},
352
+ "outputs": [
353
+ {
354
+ "ename": "TypeError",
355
+ "evalue": "runner() takes 0 positional arguments but 1 was given",
356
+ "output_type": "error",
357
+ "traceback": [
358
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
359
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
360
+ "Cell \u001b[0;32mIn[49], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m dataset \u001b[38;5;241m=\u001b[39m \u001b[43mrunner\u001b[49m\u001b[43m(\u001b[49m\u001b[43mScriptArguments\u001b[49m\u001b[43m)\u001b[49m\n",
361
+ "\u001b[0;31mTypeError\u001b[0m: runner() takes 0 positional arguments but 1 was given"
362
+ ]
363
+ }
364
+ ],
365
+ "source": [
366
+ "dataset = runner(ScriptArguments)"
367
+ ]
368
+ },
369
+ {
370
+ "cell_type": "code",
371
+ "execution_count": null,
372
+ "metadata": {},
373
+ "outputs": [],
374
+ "source": []
375
+ }
376
+ ],
377
+ "metadata": {
378
+ "kernelspec": {
379
+ "display_name": "Python 3",
380
+ "language": "python",
381
+ "name": "python3"
382
+ },
383
+ "language_info": {
384
+ "codemirror_mode": {
385
+ "name": "ipython",
386
+ "version": 3
387
+ },
388
+ "file_extension": ".py",
389
+ "mimetype": "text/x-python",
390
+ "name": "python",
391
+ "nbconvert_exporter": "python",
392
+ "pygments_lexer": "ipython3",
393
+ "version": "3.10.13"
394
+ }
395
+ },
396
+ "nbformat": 4,
397
+ "nbformat_minor": 2
398
+ }
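
The notebook's last executed cell fails because `runner` is called with the `ScriptArguments` class as a positional argument, while the `runner` defined in `trl_rlhf_data.py` takes no positional arguments. That script is part of this commit but not reproduced above, so the following is only a minimal sketch of one shape the pair could take, assuming `runner` should accept an arguments object and return the loaded dataset; the field names and defaults here are illustrative assumptions, not the repo's actual code.

```python
# Hypothetical sketch only -- trl_rlhf_data.py is not reproduced in this diff.
# One way to make `dataset = runner(...)` work: have runner accept an
# arguments instance and return the dataset it describes.
from dataclasses import dataclass

from datasets import load_dataset


@dataclass
class ScriptArguments:
    # Assumed fields; the real ScriptArguments in trl_rlhf_data.py may differ.
    dataset_name: str = "stanfordnlp/SHP"
    split: str = "train"


def runner(script_args: ScriptArguments):
    """Load and return the dataset described by the arguments object."""
    return load_dataset(script_args.dataset_name, split=script_args.split)


# Pass an instance, not the class itself:
dataset = runner(ScriptArguments())
```
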
kto_quickstart.ipynb ADDED
@@ -0,0 +1,590 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# .KTO Example"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 2,
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "from dataclasses import dataclass\n",
24
+ "\n",
25
+ "from accelerate import PartialState\n",
26
+ "from datasets import load_dataset\n",
27
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser\n",
28
+ "\n",
29
+ "from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": 3,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "# Define and parse arguments.\n",
39
+ "@dataclass\n",
40
+ "class ScriptArguments:\n",
41
+ " \"\"\"\n",
42
+ " The arguments for the KTO training script.\n",
43
+ " \"\"\"\n",
44
+ "\n",
45
+ " dataset_name: str = \"trl-lib/kto-mix-14k\"\n",
46
+ "\n",
47
+ "\n",
48
+ "# Initialize the arguments directly\n",
49
+ "script_args = ScriptArguments(\n",
50
+ " dataset_name=\"trl-lib/kto-mix-14k\"\n",
51
+ ")\n",
52
+ "\n",
53
+ "training_args = KTOConfig(\n",
54
+ " output_dir=\"kto-aligned-model\",\n",
55
+ " num_train_epochs=1,\n",
56
+ " per_device_train_batch_size=16,\n",
57
+ " learning_rate=5e-7,\n",
58
+ " lr_scheduler_type=\"cosine\",\n",
59
+ " gradient_accumulation_steps=1,\n",
60
+ " logging_steps=10,\n",
61
+ " eval_steps=500,\n",
62
+ " warmup_ratio=0.1,\n",
63
+ " bf16=True,\n",
64
+ " logging_first_step=True\n",
65
+ ")\n",
66
+ "\n",
67
+ "model_args = ModelConfig(\n",
68
+ " model_name_or_path=\"trl-lib/qwen1.5-1.8b-sft\",\n",
69
+ " # any additional model-specific arguments\n",
70
+ ")"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "markdown",
75
+ "metadata": {},
76
+ "source": [
77
+ "- @dataclass makes it easier to create classes that only contain data, making your argument definitions compact, easier to read, and automatically initialized without the need to write a custom __init__ method.\n",
78
+ "- @dataclass is used here to define a structure for the arguments that you are going to pass to the training script:\n",
79
+ "- You define a simple data structure (ScriptArguments) with a list of variables (e.g., dataset_name).\n",
80
+ "- You can quickly create instances of this structure (script_args = ScriptArguments(...)) without manually writing the initializer.\n"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 4,
86
+ "metadata": {},
87
+ "outputs": [
88
+ {
89
+ "data": {
90
+ "application/vnd.jupyter.widget-view+json": {
91
+ "model_id": "194616275edb45c5a41065cd24d32510",
92
+ "version_major": 2,
93
+ "version_minor": 0
94
+ },
95
+ "text/plain": [
96
+ "config.json: 0%| | 0.00/702 [00:00<?, ?B/s]"
97
+ ]
98
+ },
99
+ "metadata": {},
100
+ "output_type": "display_data"
101
+ },
102
+ {
103
+ "data": {
104
+ "application/vnd.jupyter.widget-view+json": {
105
+ "model_id": "487b6524aee9484ea889b896dae886d9",
106
+ "version_major": 2,
107
+ "version_minor": 0
108
+ },
109
+ "text/plain": [
110
+ "model.safetensors: 0%| | 0.00/3.67G [00:00<?, ?B/s]"
111
+ ]
112
+ },
113
+ "metadata": {},
114
+ "output_type": "display_data"
115
+ },
116
+ {
117
+ "data": {
118
+ "application/vnd.jupyter.widget-view+json": {
119
+ "model_id": "dc564aa7d2704c7baca796e3a4bd6bd5",
120
+ "version_major": 2,
121
+ "version_minor": 0
122
+ },
123
+ "text/plain": [
124
+ "generation_config.json: 0%| | 0.00/117 [00:00<?, ?B/s]"
125
+ ]
126
+ },
127
+ "metadata": {},
128
+ "output_type": "display_data"
129
+ },
130
+ {
131
+ "data": {
132
+ "application/vnd.jupyter.widget-view+json": {
133
+ "model_id": "2fd83e4e335e4b1fa014c7bb71990d3b",
134
+ "version_major": 2,
135
+ "version_minor": 0
136
+ },
137
+ "text/plain": [
138
+ "tokenizer_config.json: 0%| | 0.00/1.17k [00:00<?, ?B/s]"
139
+ ]
140
+ },
141
+ "metadata": {},
142
+ "output_type": "display_data"
143
+ },
144
+ {
145
+ "data": {
146
+ "application/vnd.jupyter.widget-view+json": {
147
+ "model_id": "cb5d6cc62c5b4a79a2e72d68d003fac3",
148
+ "version_major": 2,
149
+ "version_minor": 0
150
+ },
151
+ "text/plain": [
152
+ "vocab.json: 0%| | 0.00/2.78M [00:00<?, ?B/s]"
153
+ ]
154
+ },
155
+ "metadata": {},
156
+ "output_type": "display_data"
157
+ },
158
+ {
159
+ "data": {
160
+ "application/vnd.jupyter.widget-view+json": {
161
+ "model_id": "59bd030296c44f9eb74d110e91bebdbe",
162
+ "version_major": 2,
163
+ "version_minor": 0
164
+ },
165
+ "text/plain": [
166
+ "merges.txt: 0%| | 0.00/1.67M [00:00<?, ?B/s]"
167
+ ]
168
+ },
169
+ "metadata": {},
170
+ "output_type": "display_data"
171
+ },
172
+ {
173
+ "data": {
174
+ "application/vnd.jupyter.widget-view+json": {
175
+ "model_id": "1e6be00a1a8740d08016c438bfc3c9ea",
176
+ "version_major": 2,
177
+ "version_minor": 0
178
+ },
179
+ "text/plain": [
180
+ "tokenizer.json: 0%| | 0.00/7.03M [00:00<?, ?B/s]"
181
+ ]
182
+ },
183
+ "metadata": {},
184
+ "output_type": "display_data"
185
+ },
186
+ {
187
+ "data": {
188
+ "application/vnd.jupyter.widget-view+json": {
189
+ "model_id": "d54439e2d7d0400a8498f3f80a8df94a",
190
+ "version_major": 2,
191
+ "version_minor": 0
192
+ },
193
+ "text/plain": [
194
+ "added_tokens.json: 0%| | 0.00/80.0 [00:00<?, ?B/s]"
195
+ ]
196
+ },
197
+ "metadata": {},
198
+ "output_type": "display_data"
199
+ },
200
+ {
201
+ "data": {
202
+ "application/vnd.jupyter.widget-view+json": {
203
+ "model_id": "4583e80e1c534c2aaaec54cbe22fe987",
204
+ "version_major": 2,
205
+ "version_minor": 0
206
+ },
207
+ "text/plain": [
208
+ "special_tokens_map.json: 0%| | 0.00/419 [00:00<?, ?B/s]"
209
+ ]
210
+ },
211
+ "metadata": {},
212
+ "output_type": "display_data"
213
+ }
214
+ ],
215
+ "source": [
216
+ "# Load a pretrained model\n",
217
+ "model = AutoModelForCausalLM.from_pretrained(\n",
218
+ " model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code\n",
219
+ ")\n",
220
+ "ref_model = AutoModelForCausalLM.from_pretrained(\n",
221
+ " model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code\n",
222
+ ")\n",
223
+ "\n",
224
+ "# load a tokenaizer\n",
225
+ "tokenizer = AutoTokenizer.from_pretrained(\n",
226
+ " model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code\n",
227
+ ")\n",
228
+ "if tokenizer.pad_token is None:\n",
229
+ " tokenizer.pad_token = tokenizer.eos_token\n",
230
+ "\n",
231
+ "# If we are aligning a base model, we use ChatML as the default template\n",
232
+ "if tokenizer.chat_template is None:\n",
233
+ " model, tokenizer = setup_chat_format(model, tokenizer)"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": 5,
239
+ "metadata": {},
240
+ "outputs": [
241
+ {
242
+ "data": {
243
+ "application/vnd.jupyter.widget-view+json": {
244
+ "model_id": "edc71904a99c485e9ff32d6c4740249d",
245
+ "version_major": 2,
246
+ "version_minor": 0
247
+ },
248
+ "text/plain": [
249
+ "README.md: 0%| | 0.00/814 [00:00<?, ?B/s]"
250
+ ]
251
+ },
252
+ "metadata": {},
253
+ "output_type": "display_data"
254
+ },
255
+ {
256
+ "data": {
257
+ "application/vnd.jupyter.widget-view+json": {
258
+ "model_id": "27ac71372f8b493bbb8833148d381f75",
259
+ "version_major": 2,
260
+ "version_minor": 0
261
+ },
262
+ "text/plain": [
263
+ "train-00000-of-00001.parquet: 0%| | 0.00/16.3M [00:00<?, ?B/s]"
264
+ ]
265
+ },
266
+ "metadata": {},
267
+ "output_type": "display_data"
268
+ },
269
+ {
270
+ "data": {
271
+ "application/vnd.jupyter.widget-view+json": {
272
+ "model_id": "f7ca9416d7c643ceb00109f8ce9a512f",
273
+ "version_major": 2,
274
+ "version_minor": 0
275
+ },
276
+ "text/plain": [
277
+ "test-00000-of-00001.parquet: 0%| | 0.00/1.81M [00:00<?, ?B/s]"
278
+ ]
279
+ },
280
+ "metadata": {},
281
+ "output_type": "display_data"
282
+ },
283
+ {
284
+ "data": {
285
+ "application/vnd.jupyter.widget-view+json": {
286
+ "model_id": "34b0aa59e9474cb29a7d38956bcac892",
287
+ "version_major": 2,
288
+ "version_minor": 0
289
+ },
290
+ "text/plain": [
291
+ "Generating train split: 0%| | 0/13500 [00:00<?, ? examples/s]"
292
+ ]
293
+ },
294
+ "metadata": {},
295
+ "output_type": "display_data"
296
+ },
297
+ {
298
+ "data": {
299
+ "application/vnd.jupyter.widget-view+json": {
300
+ "model_id": "111f8817e354479ea2c99838d91bdcae",
301
+ "version_major": 2,
302
+ "version_minor": 0
303
+ },
304
+ "text/plain": [
305
+ "Generating test split: 0%| | 0/1500 [00:00<?, ? examples/s]"
306
+ ]
307
+ },
308
+ "metadata": {},
309
+ "output_type": "display_data"
310
+ }
311
+ ],
312
+ "source": [
313
+ "# Load the dataset\n",
314
+ "dataset = load_dataset(script_args.dataset_name)\n",
315
+ "\n",
316
+ "# If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label)\n",
317
+ "dataset = maybe_unpair_preference_dataset(dataset, num_proc=training_args.dataset_num_proc)"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "execution_count": 6,
323
+ "metadata": {},
324
+ "outputs": [],
325
+ "source": [
326
+ "# Apply chat template\n",
327
+ "def format_dataset(example):\n",
328
+ " example[\"prompt\"] = tokenizer.apply_chat_template(example[\"prompt\"], tokenize=False)\n",
329
+ " example[\"completion\"] = tokenizer.apply_chat_template(example[\"completion\"], tokenize=False)\n",
330
+ " return example"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": 7,
336
+ "metadata": {},
337
+ "outputs": [
338
+ {
339
+ "data": {
340
+ "application/vnd.jupyter.widget-view+json": {
341
+ "model_id": "e1ec07668de94a1580a72a208fc90c47",
342
+ "version_major": 2,
343
+ "version_minor": 0
344
+ },
345
+ "text/plain": [
346
+ "Map: 0%| | 0/13500 [00:00<?, ? examples/s]"
347
+ ]
348
+ },
349
+ "metadata": {},
350
+ "output_type": "display_data"
351
+ },
352
+ {
353
+ "data": {
354
+ "application/vnd.jupyter.widget-view+json": {
355
+ "model_id": "f411d87acc1840a4a5650565cab06018",
356
+ "version_major": 2,
357
+ "version_minor": 0
358
+ },
359
+ "text/plain": [
360
+ "Map: 0%| | 0/1500 [00:00<?, ? examples/s]"
361
+ ]
362
+ },
363
+ "metadata": {},
364
+ "output_type": "display_data"
365
+ }
366
+ ],
367
+ "source": [
368
+ "# Compute that only on the main process for faster data processing.\n",
369
+ "# see: https://github.com/huggingface/trl/pull/1255\n",
370
+ "with PartialState().local_main_process_first():\n",
371
+ " dataset = dataset.map(format_dataset, num_proc=training_args.dataset_num_proc)\n"
372
+ ]
373
+ },
374
+ {
375
+ "cell_type": "code",
376
+ "execution_count": 8,
377
+ "metadata": {},
378
+ "outputs": [
379
+ {
380
+ "name": "stderr",
381
+ "output_type": "stream",
382
+ "text": [
383
+ "/Users/riddhib/.pyenv/versions/3.10.13/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:466: UserWarning: When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init it will be set to `512` by default, but you should do it yourself in the future.\n",
384
+ " warnings.warn(\n",
385
+ "/Users/riddhib/.pyenv/versions/3.10.13/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:476: UserWarning: When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init it will be set to `128` by default, but you should do it yourself in the future.\n",
386
+ " warnings.warn(\n",
387
+ "/Users/riddhib/.pyenv/versions/3.10.13/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:506: UserWarning: When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig we have set it for you, but you should do it yourself in the future.\n",
388
+ " warnings.warn(\n"
389
+ ]
390
+ },
391
+ {
392
+ "data": {
393
+ "application/vnd.jupyter.widget-view+json": {
394
+ "model_id": "a0546a58479a4cf3ae8c14ead7d2f21a",
395
+ "version_major": 2,
396
+ "version_minor": 0
397
+ },
398
+ "text/plain": [
399
+ "Tokenizing train dataset: 0%| | 0/13500 [00:00<?, ? examples/s]"
400
+ ]
401
+ },
402
+ "metadata": {},
403
+ "output_type": "display_data"
404
+ },
405
+ {
406
+ "data": {
407
+ "application/vnd.jupyter.widget-view+json": {
408
+ "model_id": "93957abece9b440181de6bcd3d7ac9d4",
409
+ "version_major": 2,
410
+ "version_minor": 0
411
+ },
412
+ "text/plain": [
413
+ "Processing tokenized train dataset: 0%| | 0/13500 [00:00<?, ? examples/s]"
414
+ ]
415
+ },
416
+ "metadata": {},
417
+ "output_type": "display_data"
418
+ },
419
+ {
420
+ "data": {
421
+ "application/vnd.jupyter.widget-view+json": {
422
+ "model_id": "6307b14ccc3b455db5610bf269f054eb",
423
+ "version_major": 2,
424
+ "version_minor": 0
425
+ },
426
+ "text/plain": [
427
+ "Tokenizing eval dataset: 0%| | 0/1500 [00:00<?, ? examples/s]"
428
+ ]
429
+ },
430
+ "metadata": {},
431
+ "output_type": "display_data"
432
+ },
433
+ {
434
+ "data": {
435
+ "application/vnd.jupyter.widget-view+json": {
436
+ "model_id": "8d8ae59b643b4ba99258f2e579e24614",
437
+ "version_major": 2,
438
+ "version_minor": 0
439
+ },
440
+ "text/plain": [
441
+ "Processing tokenized eval dataset: 0%| | 0/1500 [00:00<?, ? examples/s]"
442
+ ]
443
+ },
444
+ "metadata": {},
445
+ "output_type": "display_data"
446
+ },
447
+ {
448
+ "data": {
449
+ "application/vnd.jupyter.widget-view+json": {
450
+ "model_id": "4a3a9f330cc849339c6bd99fffae40ff",
451
+ "version_major": 2,
452
+ "version_minor": 0
453
+ },
454
+ "text/plain": [
455
+ "Extracting KL train dataset: 0%| | 0/13500 [00:00<?, ? examples/s]"
456
+ ]
457
+ },
458
+ "metadata": {},
459
+ "output_type": "display_data"
460
+ },
461
+ {
462
+ "data": {
463
+ "application/vnd.jupyter.widget-view+json": {
464
+ "model_id": "0c77beb7b355417c91b5b1b974f01a22",
465
+ "version_major": 2,
466
+ "version_minor": 0
467
+ },
468
+ "text/plain": [
469
+ "Processing tokenized train KL dataset: 0%| | 0/13500 [00:00<?, ? examples/s]"
470
+ ]
471
+ },
472
+ "metadata": {},
473
+ "output_type": "display_data"
474
+ },
475
+ {
476
+ "data": {
477
+ "application/vnd.jupyter.widget-view+json": {
478
+ "model_id": "9da59bf2f1e14a6aad34be6b6dcd56c4",
479
+ "version_major": 2,
480
+ "version_minor": 0
481
+ },
482
+ "text/plain": [
483
+ "Extracting eval KL dataset: 0%| | 0/1500 [00:00<?, ? examples/s]"
484
+ ]
485
+ },
486
+ "metadata": {},
487
+ "output_type": "display_data"
488
+ },
489
+ {
490
+ "data": {
491
+ "application/vnd.jupyter.widget-view+json": {
492
+ "model_id": "0a8772792384450394df61957660aa56",
493
+ "version_major": 2,
494
+ "version_minor": 0
495
+ },
496
+ "text/plain": [
497
+ "Processing tokenized eval KL dataset: 0%| | 0/1500 [00:00<?, ? examples/s]"
498
+ ]
499
+ },
500
+ "metadata": {},
501
+ "output_type": "display_data"
502
+ },
503
+ {
504
+ "data": {
505
+ "application/vnd.jupyter.widget-view+json": {
506
+ "model_id": "8afe00509fe14b16b163b38b8774a4c6",
507
+ "version_major": 2,
508
+ "version_minor": 0
509
+ },
510
+ "text/plain": [
511
+ " 0%| | 0/844 [00:00<?, ?it/s]"
512
+ ]
513
+ },
514
+ "metadata": {},
515
+ "output_type": "display_data"
516
+ },
517
+ {
518
+ "ename": "RuntimeError",
519
+ "evalue": "MPS backend out of memory (MPS allocated: 17.37 GB, other allocations: 664.64 MB, max allowed: 18.13 GB). Tried to allocate 172.34 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).",
520
+ "output_type": "error",
521
+ "traceback": [
522
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
523
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
524
+ "Cell \u001b[0;32mIn[8], line 13\u001b[0m\n\u001b[1;32m 2\u001b[0m trainer \u001b[38;5;241m=\u001b[39m KTOTrainer(\n\u001b[1;32m 3\u001b[0m model,\n\u001b[1;32m 4\u001b[0m ref_model,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 9\u001b[0m peft_config\u001b[38;5;241m=\u001b[39mget_peft_config(model_args),\n\u001b[1;32m 10\u001b[0m )\n\u001b[1;32m 12\u001b[0m \u001b[38;5;66;03m# Train and push the model to the Hub\u001b[39;00m\n\u001b[0;32m---> 13\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;66;03m# Save and push to hub\u001b[39;00m\n\u001b[1;32m 16\u001b[0m trainer\u001b[38;5;241m.\u001b[39msave_model(training_args\u001b[38;5;241m.\u001b[39moutput_dir)\n",
525
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/trainer.py:2052\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 2050\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 2051\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2052\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2053\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2054\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2055\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2056\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2057\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
526
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/trainer.py:2388\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2385\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m 2387\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 2388\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2390\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 2391\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 2392\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m 2393\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 2394\u001b[0m ):\n\u001b[1;32m 2395\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 2396\u001b[0m tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
527
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/trainer.py:3485\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 3482\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss_mb\u001b[38;5;241m.\u001b[39mreduce_mean()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 3484\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m-> 3485\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3487\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m inputs\n\u001b[1;32m 3488\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 3489\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mtorch_empty_cache_steps \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 3490\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m%\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mtorch_empty_cache_steps \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 3491\u001b[0m ):\n",
528
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:1237\u001b[0m, in \u001b[0;36mKTOTrainer.compute_loss\u001b[0;34m(self, model, inputs, return_outputs)\u001b[0m\n\u001b[1;32m 1234\u001b[0m compute_loss_context_manager \u001b[38;5;241m=\u001b[39m amp\u001b[38;5;241m.\u001b[39mautocast(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_peft_has_been_casted_to_bf16 \u001b[38;5;28;01melse\u001b[39;00m nullcontext()\n\u001b[1;32m 1236\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m compute_loss_context_manager:\n\u001b[0;32m-> 1237\u001b[0m loss, metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_batch_loss_metrics\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1239\u001b[0m \u001b[38;5;66;03m# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:\u001b[39;00m\n\u001b[1;32m 1240\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdevice)\n",
529
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:1143\u001b[0m, in \u001b[0;36mKTOTrainer.get_batch_loss_metrics\u001b[0;34m(self, model, batch)\u001b[0m\n\u001b[1;32m 1140\u001b[0m metrics \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 1141\u001b[0m batch \u001b[38;5;241m=\u001b[39m {k: (v\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mdevice) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(v, torch\u001b[38;5;241m.\u001b[39mTensor) \u001b[38;5;28;01melse\u001b[39;00m v) \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m batch\u001b[38;5;241m.\u001b[39mitems()}\n\u001b[0;32m-> 1143\u001b[0m forward_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1144\u001b[0m (\n\u001b[1;32m 1145\u001b[0m policy_chosen_logps,\n\u001b[1;32m 1146\u001b[0m policy_rejected_logps,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1149\u001b[0m policy_KL_logps,\n\u001b[1;32m 1150\u001b[0m ) \u001b[38;5;241m=\u001b[39m forward_output[:\u001b[38;5;241m5\u001b[39m]\n\u001b[1;32m 1151\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maux_loss_enabled:\n",
530
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:1002\u001b[0m, in \u001b[0;36mKTOTrainer.forward\u001b[0;34m(self, model, batch)\u001b[0m\n\u001b[1;32m 988\u001b[0m KL_model_kwargs \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 989\u001b[0m {\n\u001b[1;32m 990\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m: batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mKL_prompt_input_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 999\u001b[0m }\n\u001b[1;32m 1000\u001b[0m )\n\u001b[1;32m 1001\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m-> 1002\u001b[0m KL_logits \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1003\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mKL_model_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1004\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mlogits\n\u001b[1;32m 1006\u001b[0m KL_logps \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_batch_logps(\n\u001b[1;32m 1007\u001b[0m KL_logits,\n\u001b[1;32m 1008\u001b[0m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mKL_completion_labels\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1011\u001b[0m label_pad_token_id\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlabel_pad_token_id,\n\u001b[1;32m 1012\u001b[0m )\n\u001b[1;32m 1013\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
531
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
532
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
533
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:1167\u001b[0m, in \u001b[0;36mQwen2ForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)\u001b[0m\n\u001b[1;32m 1164\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 1166\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1167\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1168\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1169\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1170\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1171\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1172\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1173\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1174\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1175\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1176\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1177\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1178\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1180\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1181\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m labels \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling():\n",
534
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
535
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
536
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:976\u001b[0m, in \u001b[0;36mQwen2Model.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)\u001b[0m\n\u001b[1;32m 964\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 965\u001b[0m decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 966\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 973\u001b[0m position_embeddings,\n\u001b[1;32m 974\u001b[0m )\n\u001b[1;32m 975\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 976\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 977\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 978\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 979\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 980\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 981\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 982\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 983\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 984\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 985\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 987\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 989\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n",
537
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
538
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
539
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:717\u001b[0m, in \u001b[0;36mQwen2DecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)\u001b[0m\n\u001b[1;32m 715\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[1;32m 716\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpost_attention_layernorm(hidden_states)\n\u001b[0;32m--> 717\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmlp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 718\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[1;32m 720\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (hidden_states,)\n",
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+ "File \u001b[0;32m~/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:276\u001b[0m, in \u001b[0;36mQwen2MLP.forward\u001b[0;34m(self, hidden_state)\u001b[0m\n\u001b[1;32m 275\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, hidden_state):\n\u001b[0;32m--> 276\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdown_proj(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mact_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgate_proj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_state\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mup_proj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_state\u001b[49m\u001b[43m)\u001b[49m)\n",
+ "\u001b[0;31mRuntimeError\u001b[0m: MPS backend out of memory (MPS allocated: 17.37 GB, other allocations: 664.64 MB, max allowed: 18.13 GB). Tried to allocate 172.34 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure)."
+ ]
+ }
+ ],
+ "source": [
+ "# Initialize the KTO trainer\n",
+ "trainer = KTOTrainer(\n",
+ " model,\n",
+ " ref_model,\n",
+ " args=training_args,\n",
+ " train_dataset=dataset[\"train\"],\n",
+ " eval_dataset=dataset[\"test\"],\n",
+ " tokenizer=tokenizer,\n",
+ " peft_config=get_peft_config(model_args),\n",
+ ")\n",
+ "\n",
+ "# Train and push the model to the Hub\n",
+ "trainer.train()\n",
+ "\n",
+ "# Save and push to hub\n",
+ "trainer.save_model(training_args.output_dir)\n",
+ "if training_args.push_to_hub:\n",
+ " trainer.push_to_hub()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "rlhf",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
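Note that the training cell above dies in an MPS out-of-memory error before the run completes. A minimal mitigation sketch, assuming an Apple-silicon host like the one in the traceback: the environment variable is the escape hatch the traceback itself suggests, the KTOConfig fields are standard transformers.TrainingArguments options, and every value here is an illustrative assumption rather than a tuned setting.

import os

# Named in the traceback: removes the MPS allocator's upper bound. Risky
# (the traceback warns it may cause system failure), so prefer the
# batch-size and checkpointing changes below; must be set before torch
# first touches the MPS device.
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

from trl import KTOConfig

training_args = KTOConfig(
    output_dir="kto-qwen-mps",          # hypothetical output path
    per_device_train_batch_size=1,      # smallest possible per-step footprint
    gradient_accumulation_steps=8,      # keep the effective batch size at 8
    gradient_checkpointing=True,        # recompute activations instead of caching them
    max_length=512,                     # cap the tokenized sequence length
)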
trl_rlhf_data.py ADDED
@@ -0,0 +1,97 @@
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional
+
+ from datasets import load_dataset
+ from transformers import HfArgumentParser
+
+
+ @dataclass
+ class ScriptArguments:
+     r"""
+     Arguments for the script.
+
+     Args:
+         push_to_hub (`bool`, *optional*, defaults to `False`):
+             Whether to push the dataset to the Hugging Face Hub.
+         repo_id (`str`, *optional*, defaults to `"trl-lib/hh-rlhf-helpful-base"`):
+             Hugging Face repository ID to push the dataset to.
+         dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
+             Number of workers to use for dataset processing.
+     """
+
+     push_to_hub: bool = False
+     repo_id: str = "trl-lib/hh-rlhf-helpful-base"
+     dataset_num_proc: Optional[int] = None
+
+
+ def common_start(str1: str, str2: str) -> str:
+     """Return the longest common prefix of two strings."""
+     # Zip the two strings and iterate over them together
+     common_chars = []
+     for c1, c2 in zip(str1, str2):
+         if c1 == c2:
+             common_chars.append(c1)
+         else:
+             break
+     # Join the common characters and return as a string
+     return "".join(common_chars)
+
+
+ def extract_dialogue(example: Dict[str, str]) -> Dict[str, List[Dict[str, str]]]:
+     # Extract the prompt, which corresponds to the common start of the chosen and rejected dialogues
+     prompt_text = common_start(example["chosen"], example["rejected"])
+
+     # The chosen and rejected completions may themselves share a common start, so the
+     # common prefix can overshoot the prompt; trim it back to the last "\n\nAssistant: "
+     if not prompt_text.endswith("\n\nAssistant: "):
+         prompt_text = prompt_text[: prompt_text.rfind("\n\nAssistant: ")] + "\n\nAssistant: "
+
+     # Extract the chosen and rejected completions
+     chosen_line = example["chosen"][len(prompt_text) :]
+     rejected_line = example["rejected"][len(prompt_text) :]
+
+     # Remove the generation prompt ("\n\nAssistant: ") from the prompt
+     prompt_text = prompt_text[: -len("\n\nAssistant: ")]
+
+     # Split the string at every occurrence of "Human: " or "Assistant: "
+     prompt_lines = re.split(r"(\n\nAssistant: |\n\nHuman: )", prompt_text)
+
+     # Remove the first element as it's empty
+     prompt_lines = prompt_lines[1:]
+
+     prompt = []
+     for idx in range(0, len(prompt_lines), 2):
+         role = "user" if prompt_lines[idx] == "\n\nHuman: " else "assistant"
+         content = prompt_lines[idx + 1]
+         prompt.append({"role": role, "content": content})
+
+     # Wrap each completion as a single assistant turn
+     chosen = [{"role": "assistant", "content": chosen_line}]
+     rejected = [{"role": "assistant", "content": rejected_line}]
+
+     return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
+
+
+ def runner(arguments):
+     parser = HfArgumentParser(arguments)
+     script_args = parser.parse_args_into_dataclasses()[0]
+
+     dataset = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-base")
+     dataset = dataset.map(extract_dialogue, num_proc=script_args.dataset_num_proc)
+
+     if script_args.push_to_hub:
+         dataset.push_to_hub(script_args.repo_id)
+     return dataset
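For context, a hedged end-to-end sketch of how trl_rlhf_data.py could be exercised; the sample row and the __main__ guard below are illustrative additions, not part of the committed file. extract_dialogue converts one Anthropic/hh-rlhf row into the prompt/chosen/rejected conversational format that TRL trainers consume.

# Hypothetical smoke test for extract_dialogue on an hh-rlhf-style row.
example = {
    "chosen": "\n\nHuman: What is 2+2?\n\nAssistant: 2+2 is 4.",
    "rejected": "\n\nHuman: What is 2+2?\n\nAssistant: 5.",
}
print(extract_dialogue(example))
# {'prompt': [{'role': 'user', 'content': 'What is 2+2?'}],
#  'chosen': [{'role': 'assistant', 'content': '2+2 is 4.'}],
#  'rejected': [{'role': 'assistant', 'content': '5.'}]}

# Hypothetical CLI entry point: parse flags into ScriptArguments and run,
# e.g. python trl_rlhf_data.py --push_to_hub --repo_id my-org/hh-rlhf-helpful-base
if __name__ == "__main__":
    runner(ScriptArguments)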