Jen Ben Arye
commited on
Commit
·
0e83f57
1
Parent(s):
6498d39
updated kto pipeline to work with general dataset
Browse files- kto_pipeline.py +55 -171
kto_pipeline.py
CHANGED
@@ -1,37 +1,35 @@
|
|
1 |
import torch
|
2 |
from dataclasses import dataclass
|
3 |
from accelerate import PartialState
|
4 |
-
from
|
5 |
-
from
|
6 |
-
from
|
7 |
-
from
|
8 |
-
from pdb import set_trace as st
|
9 |
import wandb
|
10 |
|
11 |
-
|
12 |
####################################
|
13 |
# CONFIGURATION
|
14 |
####################################
|
15 |
|
16 |
@dataclass
|
17 |
-
class
|
18 |
"""
|
19 |
Configuration for the script.
|
20 |
"""
|
21 |
-
|
22 |
-
|
23 |
-
pretrained_model_name: str = "mistralai/Mistral-7B-v0.1" # Pretrained model name or path
|
24 |
-
checkpoint_path: str = "/raid/lingo/jen_ben/HF-RLHF/kto_nov_24_2_epochs" # Checkpoint path
|
25 |
-
push_to_hub: bool = False # Whether to push the model to the Hugging Face hub
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
35 |
learning_rate: float = 5e-7
|
36 |
lr_scheduler_type: str = "cosine"
|
37 |
gradient_accumulation_steps: int = 1
|
@@ -41,152 +39,36 @@ class TrainingArguments(KTOConfig):
|
|
41 |
bf16: bool = True
|
42 |
logging_first_step: bool = True
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
Configuration for the model.
|
48 |
-
"""
|
49 |
-
model_name_or_path: str = "mistralai/Mistral-7B-v0.1"
|
50 |
-
use_peft: bool = True
|
51 |
-
lora_target_modules: str = "all-linear"
|
52 |
-
lora_r: int = 16
|
53 |
-
lora_alpha: int = 16
|
54 |
|
55 |
-
# Initialize
|
56 |
-
|
57 |
-
training_args = TrainingArguments(output_dir=script_args.output_dir)
|
58 |
-
model_args = ModelArguments(model_name_or_path=script_args.pretrained_model_name)
|
59 |
|
60 |
####################################
|
61 |
# HELPER FUNCTIONS
|
62 |
####################################
|
63 |
|
64 |
-
def load_model_and_tokenizer(
|
65 |
"""
|
66 |
Load a model and tokenizer from a specified path.
|
67 |
"""
|
68 |
model = AutoModelForCausalLM.from_pretrained(
|
69 |
-
|
70 |
-
trust_remote_code=model_args.trust_remote_code,
|
71 |
torch_dtype=torch.float16,
|
72 |
device_map="auto"
|
73 |
)
|
74 |
tokenizer = AutoTokenizer.from_pretrained(
|
75 |
-
|
76 |
-
trust_remote_code=model_args.trust_remote_code
|
77 |
)
|
78 |
|
79 |
# Set pad token if missing
|
80 |
if tokenizer.pad_token is None:
|
81 |
tokenizer.pad_token = tokenizer.eos_token
|
82 |
|
83 |
-
# Setup chat format if not present
|
84 |
-
if tokenizer.chat_template is None:
|
85 |
-
model, tokenizer = setup_chat_format(model, tokenizer)
|
86 |
-
|
87 |
return model, tokenizer
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
def load_and_format_oasst_dataset(tokenizer):
|
92 |
-
"""
|
93 |
-
Load, process, and format the OpenAssistant dataset into DPO-compatible format.
|
94 |
-
|
95 |
-
Args:
|
96 |
-
split (str): The dataset split to load ('train' or 'test').
|
97 |
-
tokenizer (AutoTokenizer): Tokenizer to apply chat templates.
|
98 |
-
num_proc (int, optional): Number of processes for parallel processing.
|
99 |
-
|
100 |
-
Returns:
|
101 |
-
Dataset: Processed and formatted dataset.
|
102 |
-
"""
|
103 |
-
|
104 |
-
# Load oasst dataset
|
105 |
-
train_dataset = get_oasst(split='train')
|
106 |
-
|
107 |
-
# Initialize lists for DPO dataset
|
108 |
-
dpo_train_data = {
|
109 |
-
"prompt": [],
|
110 |
-
"chosen": [],
|
111 |
-
"rejected": []
|
112 |
-
}
|
113 |
-
|
114 |
-
# Process the dataset
|
115 |
-
for prompt, key in train_dataset.data.items(): # Iterate over dataset
|
116 |
-
if hasattr(key, "pairs") and key.pairs: # Check if pairs exist
|
117 |
-
for i, j in key.pairs: # Process each preference pair
|
118 |
-
# Add prompt and corresponding chosen/rejected completions
|
119 |
-
dpo_train_data["prompt"].append(key.prompt)
|
120 |
-
dpo_train_data["chosen"].append(key.generations[i]) # Chosen generation
|
121 |
-
dpo_train_data["rejected"].append(key.generations[j]) # Rejected generation
|
122 |
-
|
123 |
-
# Convert DPO data into a Dataset
|
124 |
-
dpo_train_dataset = Dataset.from_dict(dpo_train_data)
|
125 |
-
|
126 |
-
# Wrap it in a DatasetDict
|
127 |
-
dataset_dict = DatasetDict({
|
128 |
-
"train": dpo_train_dataset
|
129 |
-
})
|
130 |
-
|
131 |
-
|
132 |
-
test_dataset = get_oasst(split='test')
|
133 |
-
|
134 |
-
dpo_test_data = {
|
135 |
-
"prompt": [],
|
136 |
-
"chosen": [],
|
137 |
-
"rejected": []
|
138 |
-
}
|
139 |
-
|
140 |
-
for prompt, key in test_dataset.data.items(): # Iterate over dataset
|
141 |
-
if hasattr(key, "pairs") and key.pairs: # Check if pairs exist
|
142 |
-
for i, j in key.pairs: # Process each preference pair
|
143 |
-
# Add prompt and corresponding chosen/rejected completions
|
144 |
-
dpo_test_data["prompt"].append(key.prompt)
|
145 |
-
dpo_test_data["chosen"].append(key.generations[i]) # Chosen generation
|
146 |
-
dpo_test_data["rejected"].append(key.generations[j]) # Rejected generation
|
147 |
-
|
148 |
-
dpo_test_dataset = Dataset.from_dict(dpo_test_data)
|
149 |
-
dataset_dict["test"] = dpo_test_dataset
|
150 |
-
|
151 |
-
# If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label)
|
152 |
-
dataset_dict = maybe_unpair_preference_dataset(dataset_dict, num_proc=training_args.dataset_num_proc)
|
153 |
-
print(f'loaded dataset')
|
154 |
-
|
155 |
-
|
156 |
-
# Apply chat template
|
157 |
-
def format_dataset(example):
|
158 |
-
# Ensure prompt is in the correct structure
|
159 |
-
if isinstance(example["prompt"], str):
|
160 |
-
example["prompt"] = [{"role": "user", "content": example["prompt"]}]
|
161 |
-
elif isinstance(example["prompt"], list):
|
162 |
-
# If it's already a list, ensure each element has the "role" and "content" keys
|
163 |
-
for item in example["prompt"]:
|
164 |
-
if "role" not in item or "content" not in item:
|
165 |
-
raise ValueError(f"Each item in 'prompt' must have 'role' and 'content': {item}")
|
166 |
-
|
167 |
-
# Ensure completion is in the correct structure
|
168 |
-
if isinstance(example["completion"], str):
|
169 |
-
example["completion"] = [{"role": "assistant", "content": example["completion"]}] # Wrap as a list of dictionaries
|
170 |
-
elif isinstance(example["completion"], list):
|
171 |
-
# If it's already a list, ensure each element has the "role" and "content" keys
|
172 |
-
for item in example["completion"]:
|
173 |
-
if "role" not in item or "content" not in item:
|
174 |
-
raise ValueError(f"Each item in 'completion' must have 'role' and 'content': {item}")
|
175 |
-
|
176 |
-
# Now apply the chat template
|
177 |
-
example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
|
178 |
-
example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
|
179 |
-
|
180 |
-
return example
|
181 |
-
|
182 |
-
|
183 |
-
# Compute that only on the main process for faster data processing.
|
184 |
-
# see: https://github.com/huggingface/trl/pull/1255
|
185 |
-
with PartialState().local_main_process_first():
|
186 |
-
dataset = dataset_dict.map(format_dataset, num_proc=training_args.dataset_num_proc)
|
187 |
-
|
188 |
-
return dataset
|
189 |
-
|
190 |
####################################
|
191 |
# MAIN LOGIC
|
192 |
####################################
|
@@ -197,26 +79,42 @@ def main():
|
|
197 |
|
198 |
# Load models and tokenizer
|
199 |
print("Loading models and tokenizer...")
|
200 |
-
model, tokenizer = load_model_and_tokenizer(
|
201 |
-
ref_model, _ = load_model_and_tokenizer(
|
202 |
print("Models and tokenizer loaded.")
|
203 |
|
204 |
-
# Load and process datasets
|
205 |
-
print("
|
206 |
-
dataset =
|
207 |
-
|
208 |
-
)
|
209 |
|
210 |
# Initialize trainer
|
211 |
print("Initializing trainer...")
|
212 |
trainer = KTOTrainer(
|
213 |
model=model,
|
214 |
ref_model=ref_model,
|
215 |
-
args=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
train_dataset=dataset["train"],
|
217 |
eval_dataset=dataset["test"],
|
218 |
tokenizer=tokenizer,
|
219 |
-
peft_config=get_peft_config(
|
|
|
|
|
|
|
|
|
|
|
220 |
)
|
221 |
|
222 |
# Training
|
@@ -232,26 +130,12 @@ def main():
|
|
232 |
trainer.save_metrics("eval", metrics)
|
233 |
|
234 |
# Log metrics to wandb
|
235 |
-
wandb.log(
|
236 |
-
"epoch": metrics.get("epoch"),
|
237 |
-
"grad_norm": metrics.get("grad_norm"),
|
238 |
-
"kl": metrics.get("kl"),
|
239 |
-
"learning_rate": metrics.get("learning_rate"),
|
240 |
-
"logits/chosen": metrics.get("logits/chosen"),
|
241 |
-
"logits/rejected": metrics.get("logits/rejected"),
|
242 |
-
"logps/chosen": metrics.get("logps/chosen"),
|
243 |
-
"logps/rejected": metrics.get("logps/rejected"),
|
244 |
-
"loss": metrics.get("loss"),
|
245 |
-
"rewards/chosen": metrics.get("rewards/chosen"),
|
246 |
-
"rewards/margins": metrics.get("rewards/margins"),
|
247 |
-
"rewards/rejected": metrics.get("rewards/rejected"),
|
248 |
-
"step": metrics.get("step")
|
249 |
-
})
|
250 |
|
251 |
# Save model and optionally push to hub
|
252 |
-
trainer.save_model(
|
253 |
-
if
|
254 |
-
trainer.push_to_hub(
|
255 |
|
256 |
print("Process completed.")
|
257 |
|
|
|
1 |
import torch
|
2 |
from dataclasses import dataclass
|
3 |
from accelerate import PartialState
|
4 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
5 |
+
from trl import KTOConfig, KTOTrainer, get_peft_config
|
6 |
+
from kto_dataset_processor import process_dataset_ultrafeedback
|
7 |
+
from datetime import datetime
|
|
|
8 |
import wandb
|
9 |
|
|
|
10 |
####################################
|
11 |
# CONFIGURATION
|
12 |
####################################
|
13 |
|
14 |
@dataclass
|
15 |
+
class Config:
|
16 |
"""
|
17 |
Configuration for the script.
|
18 |
"""
|
19 |
+
# Dataset settings
|
20 |
+
process_dataset_func: callable = process_dataset_ultrafeedback # Dataset processing function
|
|
|
|
|
|
|
21 |
|
22 |
+
# Model settings
|
23 |
+
model_name: str = "HuggingFaceH4/zephyr-7b-beta" # Pretrained model name or path
|
24 |
+
use_peft: bool = True
|
25 |
+
lora_target_modules: str = "all-linear"
|
26 |
+
lora_r: int = 16
|
27 |
+
lora_alpha: int = 16
|
28 |
+
|
29 |
+
# Training settings
|
30 |
+
output_dir: str = f"kto_{model_name}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
|
31 |
+
num_train_epochs: int = 1
|
32 |
+
per_device_train_batch_size: int = 4
|
33 |
learning_rate: float = 5e-7
|
34 |
lr_scheduler_type: str = "cosine"
|
35 |
gradient_accumulation_steps: int = 1
|
|
|
39 |
bf16: bool = True
|
40 |
logging_first_step: bool = True
|
41 |
|
42 |
+
# Checkpoint and hub settings
|
43 |
+
checkpoint_path: str = None
|
44 |
+
push_to_hub: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
+
# Initialize the unified configuration
|
47 |
+
config = Config()
|
|
|
|
|
48 |
|
49 |
####################################
|
50 |
# HELPER FUNCTIONS
|
51 |
####################################
|
52 |
|
53 |
+
def load_model_and_tokenizer(config):
|
54 |
"""
|
55 |
Load a model and tokenizer from a specified path.
|
56 |
"""
|
57 |
model = AutoModelForCausalLM.from_pretrained(
|
58 |
+
config.model_name,
|
|
|
59 |
torch_dtype=torch.float16,
|
60 |
device_map="auto"
|
61 |
)
|
62 |
tokenizer = AutoTokenizer.from_pretrained(
|
63 |
+
config.model_name
|
|
|
64 |
)
|
65 |
|
66 |
# Set pad token if missing
|
67 |
if tokenizer.pad_token is None:
|
68 |
tokenizer.pad_token = tokenizer.eos_token
|
69 |
|
|
|
|
|
|
|
|
|
70 |
return model, tokenizer
|
71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
####################################
|
73 |
# MAIN LOGIC
|
74 |
####################################
|
|
|
79 |
|
80 |
# Load models and tokenizer
|
81 |
print("Loading models and tokenizer...")
|
82 |
+
model, tokenizer = load_model_and_tokenizer(config)
|
83 |
+
ref_model, _ = load_model_and_tokenizer(config)
|
84 |
print("Models and tokenizer loaded.")
|
85 |
|
86 |
+
# Load and process datasets using the specified function
|
87 |
+
print("Processing dataset...")
|
88 |
+
dataset = config.process_dataset_func()
|
89 |
+
print("Dataset processed.")
|
|
|
90 |
|
91 |
# Initialize trainer
|
92 |
print("Initializing trainer...")
|
93 |
trainer = KTOTrainer(
|
94 |
model=model,
|
95 |
ref_model=ref_model,
|
96 |
+
args=KTOConfig(
|
97 |
+
output_dir=config.output_dir,
|
98 |
+
num_train_epochs=config.num_train_epochs,
|
99 |
+
per_device_train_batch_size=config.per_device_train_batch_size,
|
100 |
+
learning_rate=config.learning_rate,
|
101 |
+
lr_scheduler_type=config.lr_scheduler_type,
|
102 |
+
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
103 |
+
logging_steps=config.logging_steps,
|
104 |
+
eval_steps=config.eval_steps,
|
105 |
+
warmup_ratio=config.warmup_ratio,
|
106 |
+
bf16=config.bf16,
|
107 |
+
logging_first_step=config.logging_first_step,
|
108 |
+
),
|
109 |
train_dataset=dataset["train"],
|
110 |
eval_dataset=dataset["test"],
|
111 |
tokenizer=tokenizer,
|
112 |
+
peft_config=get_peft_config({
|
113 |
+
"use_peft": config.use_peft,
|
114 |
+
"lora_target_modules": config.lora_target_modules,
|
115 |
+
"lora_r": config.lora_r,
|
116 |
+
"lora_alpha": config.lora_alpha,
|
117 |
+
}),
|
118 |
)
|
119 |
|
120 |
# Training
|
|
|
130 |
trainer.save_metrics("eval", metrics)
|
131 |
|
132 |
# Log metrics to wandb
|
133 |
+
wandb.log(metrics)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
# Save model and optionally push to hub
|
136 |
+
trainer.save_model(config.output_dir)
|
137 |
+
if config.push_to_hub:
|
138 |
+
trainer.push_to_hub()
|
139 |
|
140 |
print("Process completed.")
|
141 |
|