Jen Ben Arye committed on
Commit
0e83f57
·
1 Parent(s): 6498d39

updated kto pipeline to work with general dataset

Browse files
Files changed (1) hide show
  1. kto_pipeline.py +55 -171
kto_pipeline.py CHANGED
@@ -1,37 +1,35 @@
1
  import torch
2
  from dataclasses import dataclass
3
  from accelerate import PartialState
4
- from datasets import load_dataset, DatasetDict, Dataset
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
6
- from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format
7
- from dataloaders.data_loader import get_oasst
8
- from pdb import set_trace as st
9
  import wandb
10
 
11
-
12
  ####################################
13
  # CONFIGURATION
14
  ####################################
15
 
16
  @dataclass
17
- class ScriptArguments:
18
  """
19
  Configuration for the script.
20
  """
21
- dataset_name: str = "OpenAssistant/oasst1" # Dataset name or path
22
- output_dir: str = "/raid/lingo/jen_ben/HF-RLHF/kto_nov_24_2_epochs" # Output directory
23
- pretrained_model_name: str = "mistralai/Mistral-7B-v0.1" # Pretrained model name or path
24
- checkpoint_path: str = "/raid/lingo/jen_ben/HF-RLHF/kto_nov_24_2_epochs" # Checkpoint path
25
- push_to_hub: bool = False # Whether to push the model to the Hugging Face hub
26
 
27
- @dataclass
28
- class TrainingArguments(KTOConfig):
29
- """
30
- Configuration for the KTO trainer.
31
- """
32
- output_dir: str = "/raid/lingo/jen_ben/HF-RLHF/kto_nov_24_2_epochs"
33
- num_train_epochs: int = 2 # did 1 epoch, then maybe try 2 epochs
34
- per_device_train_batch_size: int = 4 # 4 is the highes that runs well.
 
 
 
35
  learning_rate: float = 5e-7
36
  lr_scheduler_type: str = "cosine"
37
  gradient_accumulation_steps: int = 1
@@ -41,152 +39,36 @@ class TrainingArguments(KTOConfig):
41
  bf16: bool = True
42
  logging_first_step: bool = True
43
 
44
- @dataclass
45
- class ModelArguments(ModelConfig):
46
- """
47
- Configuration for the model.
48
- """
49
- model_name_or_path: str = "mistralai/Mistral-7B-v0.1"
50
- use_peft: bool = True
51
- lora_target_modules: str = "all-linear"
52
- lora_r: int = 16
53
- lora_alpha: int = 16
54
 
55
- # Initialize configurations
56
- script_args = ScriptArguments()
57
- training_args = TrainingArguments(output_dir=script_args.output_dir)
58
- model_args = ModelArguments(model_name_or_path=script_args.pretrained_model_name)
59
 
60
  ####################################
61
  # HELPER FUNCTIONS
62
  ####################################
63
 
64
- def load_model_and_tokenizer(model_args):
65
  """
66
  Load a model and tokenizer from a specified path.
67
  """
68
  model = AutoModelForCausalLM.from_pretrained(
69
- model_args.model_name_or_path,
70
- trust_remote_code=model_args.trust_remote_code,
71
  torch_dtype=torch.float16,
72
  device_map="auto"
73
  )
74
  tokenizer = AutoTokenizer.from_pretrained(
75
- model_args.model_name_or_path,
76
- trust_remote_code=model_args.trust_remote_code
77
  )
78
 
79
  # Set pad token if missing
80
  if tokenizer.pad_token is None:
81
  tokenizer.pad_token = tokenizer.eos_token
82
 
83
- # Setup chat format if not present
84
- if tokenizer.chat_template is None:
85
- model, tokenizer = setup_chat_format(model, tokenizer)
86
-
87
  return model, tokenizer
88
 
89
-
90
-
91
- def load_and_format_oasst_dataset(tokenizer):
92
- """
93
- Load, process, and format the OpenAssistant dataset into DPO-compatible format.
94
-
95
- Args:
96
- split (str): The dataset split to load ('train' or 'test').
97
- tokenizer (AutoTokenizer): Tokenizer to apply chat templates.
98
- num_proc (int, optional): Number of processes for parallel processing.
99
-
100
- Returns:
101
- Dataset: Processed and formatted dataset.
102
- """
103
-
104
- # Load oasst dataset
105
- train_dataset = get_oasst(split='train')
106
-
107
- # Initialize lists for DPO dataset
108
- dpo_train_data = {
109
- "prompt": [],
110
- "chosen": [],
111
- "rejected": []
112
- }
113
-
114
- # Process the dataset
115
- for prompt, key in train_dataset.data.items(): # Iterate over dataset
116
- if hasattr(key, "pairs") and key.pairs: # Check if pairs exist
117
- for i, j in key.pairs: # Process each preference pair
118
- # Add prompt and corresponding chosen/rejected completions
119
- dpo_train_data["prompt"].append(key.prompt)
120
- dpo_train_data["chosen"].append(key.generations[i]) # Chosen generation
121
- dpo_train_data["rejected"].append(key.generations[j]) # Rejected generation
122
-
123
- # Convert DPO data into a Dataset
124
- dpo_train_dataset = Dataset.from_dict(dpo_train_data)
125
-
126
- # Wrap it in a DatasetDict
127
- dataset_dict = DatasetDict({
128
- "train": dpo_train_dataset
129
- })
130
-
131
-
132
- test_dataset = get_oasst(split='test')
133
-
134
- dpo_test_data = {
135
- "prompt": [],
136
- "chosen": [],
137
- "rejected": []
138
- }
139
-
140
- for prompt, key in test_dataset.data.items(): # Iterate over dataset
141
- if hasattr(key, "pairs") and key.pairs: # Check if pairs exist
142
- for i, j in key.pairs: # Process each preference pair
143
- # Add prompt and corresponding chosen/rejected completions
144
- dpo_test_data["prompt"].append(key.prompt)
145
- dpo_test_data["chosen"].append(key.generations[i]) # Chosen generation
146
- dpo_test_data["rejected"].append(key.generations[j]) # Rejected generation
147
-
148
- dpo_test_dataset = Dataset.from_dict(dpo_test_data)
149
- dataset_dict["test"] = dpo_test_dataset
150
-
151
- # If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label)
152
- dataset_dict = maybe_unpair_preference_dataset(dataset_dict, num_proc=training_args.dataset_num_proc)
153
- print(f'loaded dataset')
154
-
155
-
156
- # Apply chat template
157
- def format_dataset(example):
158
- # Ensure prompt is in the correct structure
159
- if isinstance(example["prompt"], str):
160
- example["prompt"] = [{"role": "user", "content": example["prompt"]}]
161
- elif isinstance(example["prompt"], list):
162
- # If it's already a list, ensure each element has the "role" and "content" keys
163
- for item in example["prompt"]:
164
- if "role" not in item or "content" not in item:
165
- raise ValueError(f"Each item in 'prompt' must have 'role' and 'content': {item}")
166
-
167
- # Ensure completion is in the correct structure
168
- if isinstance(example["completion"], str):
169
- example["completion"] = [{"role": "assistant", "content": example["completion"]}] # Wrap as a list of dictionaries
170
- elif isinstance(example["completion"], list):
171
- # If it's already a list, ensure each element has the "role" and "content" keys
172
- for item in example["completion"]:
173
- if "role" not in item or "content" not in item:
174
- raise ValueError(f"Each item in 'completion' must have 'role' and 'content': {item}")
175
-
176
- # Now apply the chat template
177
- example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
178
- example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
179
-
180
- return example
181
-
182
-
183
- # Compute that only on the main process for faster data processing.
184
- # see: https://github.com/huggingface/trl/pull/1255
185
- with PartialState().local_main_process_first():
186
- dataset = dataset_dict.map(format_dataset, num_proc=training_args.dataset_num_proc)
187
-
188
- return dataset
189
-
190
  ####################################
191
  # MAIN LOGIC
192
  ####################################
@@ -197,26 +79,42 @@ def main():
197
 
198
  # Load models and tokenizer
199
  print("Loading models and tokenizer...")
200
- model, tokenizer = load_model_and_tokenizer(model_args)
201
- ref_model, _ = load_model_and_tokenizer(model_args)
202
  print("Models and tokenizer loaded.")
203
 
204
- # Load and process datasets
205
- print("Loading, processing, and formatting dataset...")
206
- dataset = load_and_format_oasst_dataset(
207
- tokenizer=tokenizer,
208
- )
209
 
210
  # Initialize trainer
211
  print("Initializing trainer...")
212
  trainer = KTOTrainer(
213
  model=model,
214
  ref_model=ref_model,
215
- args=training_args,
 
 
 
 
 
 
 
 
 
 
 
 
216
  train_dataset=dataset["train"],
217
  eval_dataset=dataset["test"],
218
  tokenizer=tokenizer,
219
- peft_config=get_peft_config(model_args),
 
 
 
 
 
220
  )
221
 
222
  # Training
@@ -232,26 +130,12 @@ def main():
232
  trainer.save_metrics("eval", metrics)
233
 
234
  # Log metrics to wandb
235
- wandb.log({
236
- "epoch": metrics.get("epoch"),
237
- "grad_norm": metrics.get("grad_norm"),
238
- "kl": metrics.get("kl"),
239
- "learning_rate": metrics.get("learning_rate"),
240
- "logits/chosen": metrics.get("logits/chosen"),
241
- "logits/rejected": metrics.get("logits/rejected"),
242
- "logps/chosen": metrics.get("logps/chosen"),
243
- "logps/rejected": metrics.get("logps/rejected"),
244
- "loss": metrics.get("loss"),
245
- "rewards/chosen": metrics.get("rewards/chosen"),
246
- "rewards/margins": metrics.get("rewards/margins"),
247
- "rewards/rejected": metrics.get("rewards/rejected"),
248
- "step": metrics.get("step")
249
- })
250
 
251
  # Save model and optionally push to hub
252
- trainer.save_model(training_args.output_dir)
253
- if script_args.push_to_hub:
254
- trainer.push_to_hub(dataset_name=script_args.dataset_name)
255
 
256
  print("Process completed.")
257
 
 
1
  import torch
2
  from dataclasses import dataclass
3
  from accelerate import PartialState
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ from trl import KTOConfig, KTOTrainer, get_peft_config
6
+ from kto_dataset_processor import process_dataset_ultrafeedback
7
+ from datetime import datetime
 
8
  import wandb
9
 
 
10
  ####################################
11
  # CONFIGURATION
12
  ####################################
13
 
14
  @dataclass
15
+ class Config:
16
  """
17
  Configuration for the script.
18
  """
19
+ # Dataset settings
20
+ process_dataset_func: callable = process_dataset_ultrafeedback # Dataset processing function
 
 
 
21
 
22
+ # Model settings
23
+ model_name: str = "HuggingFaceH4/zephyr-7b-beta" # Pretrained model name or path
24
+ use_peft: bool = True
25
+ lora_target_modules: str = "all-linear"
26
+ lora_r: int = 16
27
+ lora_alpha: int = 16
28
+
29
+ # Training settings
30
+ output_dir: str = f"kto_{model_name}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
31
+ num_train_epochs: int = 1
32
+ per_device_train_batch_size: int = 4
33
  learning_rate: float = 5e-7
34
  lr_scheduler_type: str = "cosine"
35
  gradient_accumulation_steps: int = 1
 
39
  bf16: bool = True
40
  logging_first_step: bool = True
41
 
42
+ # Checkpoint and hub settings
43
+ checkpoint_path: str = None
44
+ push_to_hub: bool = False
 
 
 
 
 
 
 
45
 
46
+ # Initialize the unified configuration
47
+ config = Config()
 
 
48
 
49
  ####################################
50
  # HELPER FUNCTIONS
51
  ####################################
52
 
53
+ def load_model_and_tokenizer(config):
54
  """
55
  Load a model and tokenizer from a specified path.
56
  """
57
  model = AutoModelForCausalLM.from_pretrained(
58
+ config.model_name,
 
59
  torch_dtype=torch.float16,
60
  device_map="auto"
61
  )
62
  tokenizer = AutoTokenizer.from_pretrained(
63
+ config.model_name
 
64
  )
65
 
66
  # Set pad token if missing
67
  if tokenizer.pad_token is None:
68
  tokenizer.pad_token = tokenizer.eos_token
69
 
 
 
 
 
70
  return model, tokenizer
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  ####################################
73
  # MAIN LOGIC
74
  ####################################
 
79
 
80
  # Load models and tokenizer
81
  print("Loading models and tokenizer...")
82
+ model, tokenizer = load_model_and_tokenizer(config)
83
+ ref_model, _ = load_model_and_tokenizer(config)
84
  print("Models and tokenizer loaded.")
85
 
86
+ # Load and process datasets using the specified function
87
+ print("Processing dataset...")
88
+ dataset = config.process_dataset_func()
89
+ print("Dataset processed.")
 
90
 
91
  # Initialize trainer
92
  print("Initializing trainer...")
93
  trainer = KTOTrainer(
94
  model=model,
95
  ref_model=ref_model,
96
+ args=KTOConfig(
97
+ output_dir=config.output_dir,
98
+ num_train_epochs=config.num_train_epochs,
99
+ per_device_train_batch_size=config.per_device_train_batch_size,
100
+ learning_rate=config.learning_rate,
101
+ lr_scheduler_type=config.lr_scheduler_type,
102
+ gradient_accumulation_steps=config.gradient_accumulation_steps,
103
+ logging_steps=config.logging_steps,
104
+ eval_steps=config.eval_steps,
105
+ warmup_ratio=config.warmup_ratio,
106
+ bf16=config.bf16,
107
+ logging_first_step=config.logging_first_step,
108
+ ),
109
  train_dataset=dataset["train"],
110
  eval_dataset=dataset["test"],
111
  tokenizer=tokenizer,
112
+ peft_config=get_peft_config({
113
+ "use_peft": config.use_peft,
114
+ "lora_target_modules": config.lora_target_modules,
115
+ "lora_r": config.lora_r,
116
+ "lora_alpha": config.lora_alpha,
117
+ }),
118
  )
119
 
120
  # Training
 
130
  trainer.save_metrics("eval", metrics)
131
 
132
  # Log metrics to wandb
133
+ wandb.log(metrics)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  # Save model and optionally push to hub
136
+ trainer.save_model(config.output_dir)
137
+ if config.push_to_hub:
138
+ trainer.push_to_hub()
139
 
140
  print("Process completed.")
141