Jen Ben Arye committed on
Commit
6498d39
·
1 Parent(s): 7ee0f15
Files changed (1) hide show
  1. kto_pipeline.py +226 -80
kto_pipeline.py CHANGED
@@ -1,116 +1,262 @@
1
  import torch
2
  from dataclasses import dataclass
3
-
4
  from accelerate import PartialState
5
- from datasets import load_dataset, DatasetDict
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
7
-
8
  from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format
 
 
 
9
 
10
 
 
 
 
11
 
12
- # Define and parse arguments.
13
  @dataclass
14
  class ScriptArguments:
15
  """
16
- The arguments for the KTO training script.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  """
18
 
19
- dataset_name: str = "trl-lib/kto-mix-14k"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
 
 
 
 
 
 
 
21
 
22
- # Initialize the arguments directly
23
- script_args = ScriptArguments(
24
- dataset_name="trl-lib/kto-mix-14k"
25
- )
26
 
27
- training_args = KTOConfig(
28
- output_dir="/raid/lingo/jen_ben/HF-RLHF/kto_nov_2", # MODFIFY
29
- num_train_epochs=100,
30
- per_device_train_batch_size=4,
31
- learning_rate=5e-7,
32
- lr_scheduler_type="cosine",
33
- gradient_accumulation_steps=8,
34
- logging_steps=10,
35
- eval_steps=500,
36
- warmup_ratio=0.1,
37
- bf16=True,
38
- logging_first_step=True
39
- )
40
 
41
- model_args = ModelConfig(
42
- model_name_or_path="trl-lib/qwen1.5-1.8b-sft",
43
- # any additional model-specific arguments
44
- )
45
 
46
- # Load a pretrained model
47
- model = AutoModelForCausalLM.from_pretrained(
48
- model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
49
- )
50
- ref_model = AutoModelForCausalLM.from_pretrained(
51
- model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
52
- )
53
- print(f'loaded model')
 
 
54
 
55
- # Load a tokenizer
56
- tokenizer = AutoTokenizer.from_pretrained(
57
- model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
58
- )
 
 
 
 
59
 
60
- if tokenizer.pad_token is None:
61
- tokenizer.pad_token = tokenizer.eos_token
 
62
 
63
- # If we are aligning a base model, we use ChatML as the default template
64
- if tokenizer.chat_template is None:
65
- model, tokenizer = setup_chat_format(model, tokenizer)
66
- print(f'loaded tokenizer')
67
 
68
- # Load the dataset
69
- dataset = load_dataset(script_args.dataset_name)
70
 
71
- # If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label)
72
- dataset = maybe_unpair_preference_dataset(dataset, num_proc=training_args.dataset_num_proc)
73
- print(f'loaded dataset')
 
74
 
 
75
 
76
- # Apply chat template
77
- def format_dataset(example):
78
- example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
79
- example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
80
- return example
81
 
 
 
 
82
 
83
- # Compute that only on the main process for faster data processing.
84
- # see: https://github.com/huggingface/trl/pull/1255
85
- with PartialState().local_main_process_first():
86
- dataset = dataset.map(format_dataset, num_proc=training_args.dataset_num_proc)
 
87
 
 
 
 
 
 
88
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- # Initialize the KTO trainer
91
- trainer = KTOTrainer(
92
- model,
93
- ref_model,
94
- args=training_args,
95
- train_dataset=dataset["train"],
96
- eval_dataset=dataset["test"],
97
- tokenizer=tokenizer,
98
- peft_config=get_peft_config(model_args),
99
- )
100
 
101
- print(f'start training')
 
 
 
 
 
102
 
103
- trainer.train()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
- print(f'finished training')
 
 
 
106
 
107
- metrics = trainer.evaluate()
108
- print(f'metrics: \n {metrics}')
109
- trainer.log_metrics("eval", metrics)
110
- trainer.save_metrics("eval", metrics)
111
 
 
 
112
 
113
- # Save and push to hub
114
- trainer.save_model(training_args.output_dir)
115
- if training_args.push_to_hub:
116
- trainer.push_to_hub(dataset_name=script_args.dataset_name)
 
1
  import torch
2
  from dataclasses import dataclass
 
3
  from accelerate import PartialState
4
+ from datasets import load_dataset, DatasetDict, Dataset
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
 
6
  from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format
7
+ from dataloaders.data_loader import get_oasst
8
+ from pdb import set_trace as st
9
+ import wandb
10
 
11
 
12
+ ####################################
13
+ # CONFIGURATION
14
+ ####################################
15
 
 
@dataclass
class ScriptArguments:
    """
    Script-level settings that live outside the trainer and model configs.

    Every field has a default, so ``ScriptArguments()`` yields a fully
    populated configuration.
    """

    # Dataset name or path.
    dataset_name: str = "OpenAssistant/oasst1"
    # Output directory.
    output_dir: str = "/raid/lingo/jen_ben/HF-RLHF/kto_nov_24_2_epochs"
    # Pretrained model name or path.
    pretrained_model_name: str = "mistralai/Mistral-7B-v0.1"
    # Checkpoint path (defaults to the same location as ``output_dir``).
    checkpoint_path: str = "/raid/lingo/jen_ben/HF-RLHF/kto_nov_24_2_epochs"
    # Whether to push the model to the Hugging Face hub.
    push_to_hub: bool = False
26
+
27
@dataclass
class TrainingArguments(KTOConfig):
    """
    KTO trainer configuration with this experiment's defaults.

    Only the fields listed here are overridden; everything else keeps the
    default inherited from ``KTOConfig``.
    """

    output_dir: str = "/raid/lingo/jen_ben/HF-RLHF/kto_nov_24_2_epochs"

    # Did 1 epoch, then maybe try 2 epochs.
    num_train_epochs: int = 2
    # 4 is the highest that runs well.
    per_device_train_batch_size: int = 4
    gradient_accumulation_steps: int = 1

    # Optimizer schedule.
    learning_rate: float = 5e-7
    lr_scheduler_type: str = "cosine"
    warmup_ratio: float = 0.1

    # Logging / evaluation cadence.
    # NOTE(review): no evaluation strategy is set in this class; confirm that
    # ``eval_steps`` actually takes effect with the inherited defaults.
    logging_steps: int = 10
    logging_first_step: bool = True
    eval_steps: int = 500

    # Mixed precision.
    bf16: bool = True
43
+
44
@dataclass
class ModelArguments(ModelConfig):
    """
    Model configuration with this experiment's defaults.

    Overrides the base checkpoint and the LoRA-related fields; all other
    fields keep the default inherited from ``ModelConfig``.
    """

    # Base checkpoint to fine-tune.
    model_name_or_path: str = "mistralai/Mistral-7B-v0.1"

    # Parameter-efficient fine-tuning (LoRA) settings.
    use_peft: bool = True
    lora_r: int = 16
    lora_alpha: int = 16
    lora_target_modules: str = "all-linear"
54
+
55
# Module-level configuration objects shared by the helper functions and main().
# The script-level settings are built first because the other two read from them.
script_args = ScriptArguments()
model_args = ModelArguments(model_name_or_path=script_args.pretrained_model_name)
training_args = TrainingArguments(output_dir=script_args.output_dir)
59
+
60
+ ####################################
61
+ # HELPER FUNCTIONS
62
+ ####################################
63
+
64
def load_model_and_tokenizer(model_args, torch_dtype=None):
    """
    Load a causal language model and its tokenizer from ``model_args``.

    Args:
        model_args: Object exposing ``model_name_or_path`` and
            ``trust_remote_code`` (e.g. the module-level ``model_args``).
        torch_dtype: Optional dtype for the model weights. Defaults to
            ``torch.bfloat16`` so the weights match the ``bf16=True`` setting
            in this file's training configuration; the weights were
            previously loaded as ``torch.float16``, which disagreed with it.

    Returns:
        tuple: ``(model, tokenizer)``. The tokenizer always has a pad token,
        and both objects have been passed through ``setup_chat_format`` when
        the tokenizer shipped without a chat template.
    """
    if torch_dtype is None:
        torch_dtype = torch.bfloat16

    # NOTE(review): device_map="auto" is kept from the original; confirm it is
    # what you want when launching with more than one process.
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=torch_dtype,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
    )

    # Fall back to the EOS token for padding when no pad token is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Install a chat format when the tokenizer has no chat template.
    if tokenizer.chat_template is None:
        model, tokenizer = setup_chat_format(model, tokenizer)

    return model, tokenizer
88
+
89
+
90
+
91
def _collect_preference_pairs(split):
    """Build a (prompt, chosen, rejected) ``Dataset`` from one OpenAssistant split."""
    columns = {"prompt": [], "chosen": [], "rejected": []}
    for record in get_oasst(split=split).data.values():
        # Records without a non-empty ``pairs`` attribute contribute no rows.
        for chosen_idx, rejected_idx in getattr(record, "pairs", None) or ():
            columns["prompt"].append(record.prompt)
            columns["chosen"].append(record.generations[chosen_idx])
            columns["rejected"].append(record.generations[rejected_idx])
    return Dataset.from_dict(columns)


def _as_chat_messages(value, role, field):
    """Wrap a bare string as a one-message conversation; validate message lists."""
    if isinstance(value, str):
        return [{"role": role, "content": value}]
    if isinstance(value, list):
        for item in value:
            if "role" not in item or "content" not in item:
                raise ValueError(f"Each item in '{field}' must have 'role' and 'content': {item}")
    # Any other type is passed through untouched.
    return value


def load_and_format_oasst_dataset(tokenizer, num_proc=None):
    """
    Load the OpenAssistant train and test splits and prepare them for KTO training.

    The preference pairs of each split are collected into
    (prompt, chosen, rejected) columns, passed through
    ``maybe_unpair_preference_dataset`` (which yields the KTO
    prompt/completion/label layout when unpairing is needed), and the
    ``prompt`` and ``completion`` fields are then rendered to text with the
    tokenizer's chat template.

    Args:
        tokenizer (AutoTokenizer): Tokenizer whose chat template is applied.
        num_proc (int, optional): Number of processes for the unpairing and
            mapping steps. Defaults to the module-level
            ``training_args.dataset_num_proc``.

    Returns:
        DatasetDict: Formatted dataset with ``"train"`` and ``"test"`` splits.

    Raises:
        ValueError: If a prompt or completion is a list containing an item
            without both ``"role"`` and ``"content"``.
    """
    if num_proc is None:
        num_proc = training_args.dataset_num_proc

    dataset_dict = DatasetDict({
        "train": _collect_preference_pairs("train"),
        "test": _collect_preference_pairs("test"),
    })

    # If needed, reformat the paired data (prompt, chosen, rejected) into the
    # KTO layout (prompt, completion, label).
    dataset_dict = maybe_unpair_preference_dataset(dataset_dict, num_proc=num_proc)
    print('loaded dataset')

    def format_dataset(example):
        # Validate/wrap both fields before rendering either of them.
        prompt_messages = _as_chat_messages(example["prompt"], "user", "prompt")
        completion_messages = _as_chat_messages(example["completion"], "assistant", "completion")

        example["prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
        example["completion"] = tokenizer.apply_chat_template(completion_messages, tokenize=False)
        return example

    # Compute that only on the main process for faster data processing.
    # see: https://github.com/huggingface/trl/pull/1255
    with PartialState().local_main_process_first():
        dataset = dataset_dict.map(format_dataset, num_proc=num_proc)

    return dataset
189
 
190
+ ####################################
191
+ # MAIN LOGIC
192
+ ####################################
 
 
193
 
194
def main():
    """
    Run the full KTO pipeline: load models, build the dataset, train, evaluate, save.

    Uses the module-level ``script_args``, ``training_args`` and
    ``model_args``. A Weights & Biases run is opened at the start and closed
    at the end. The trained model is written to ``training_args.output_dir``
    and is pushed to the hub only when ``script_args.push_to_hub`` is set.
    """
    # Initialize wandb
    wandb.init(project="kto")

    # Load the policy model, an independent reference copy, and the tokenizer.
    print("Loading models and tokenizer...")
    model, tokenizer = load_model_and_tokenizer(model_args)
    ref_model, _ = load_model_and_tokenizer(model_args)
    print("Models and tokenizer loaded.")

    # Load and process datasets
    print("Loading, processing, and formatting dataset...")
    dataset = load_and_format_oasst_dataset(
        tokenizer=tokenizer,
    )

    # Initialize trainer
    print("Initializing trainer...")
    trainer = KTOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        tokenizer=tokenizer,
        peft_config=get_peft_config(model_args),
    )

    # Training
    print("Starting training...")
    trainer.train()
    print("Training completed.")

    # Evaluation
    print("Evaluating model...")
    metrics = trainer.evaluate()
    print(f"Metrics: {metrics}")
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

    # Log the evaluation metrics to wandb under the keys evaluate() actually
    # returned. The previous code looked up training-time names such as
    # "loss" or "kl", which are absent from the evaluation dict (its metric
    # keys carry the "eval_" prefix), so it logged almost nothing but None.
    wandb.log({name: value for name, value in metrics.items() if value is not None})

    # Save model and optionally push to hub
    trainer.save_model(training_args.output_dir)
    if script_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)

    print("Process completed.")

    # Finish wandb run
    wandb.finish()


if __name__ == "__main__":
    main()