Jen Ben Arye committed on
Commit 71053f2 · 1 Parent(s): c8a2d4e

updated kto pipeline to work with a general dataset

Files changed (1): kto_pipeline.py (+52 -44)
kto_pipeline.py CHANGED
@@ -1,8 +1,8 @@
 import torch
 from dataclasses import dataclass
 from accelerate import PartialState
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from trl import KTOConfig, KTOTrainer, get_peft_config
+from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
+from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config
 from kto_dataset_processor import process_dataset_ultrafeedback
 from datetime import datetime
 import wandb
@@ -12,24 +12,33 @@ import wandb
 ####################################
 
 @dataclass
-class Config:
+class ScriptArguments:
     """
     Configuration for the script.
     """
-    # Dataset settings
-    process_dataset_func: callable = process_dataset_ultrafeedback # Dataset processing function
+    process_dataset_func: callable = process_dataset_ultrafeedback # process_dataset function from kto_dataset_processor.py
+    checkpoint_path: str = None # Checkpoint path
+    push_to_hub: bool = False # Whether to push the model to the Hugging Face hub
 
-    # Model settings
-    model_name: str = "HuggingFaceH4/zephyr-7b-beta" # Pretrained model name or path
+@dataclass
+class ModelArguments(ModelConfig):
+    """
+    Configuration for the model.
+    """
+    model_name: str = "HuggingFaceH4/zephyr-7b-beta"
     use_peft: bool = True
     lora_target_modules: str = "all-linear"
     lora_r: int = 16
     lora_alpha: int = 16
 
-    # Training settings
-    output_dir: str = f"kto_{model_name}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
+@dataclass
+class TrainingArguments(KTOConfig):
+    """
+    Configuration for the KTO trainer.
+    """
+    output_dir: str = f"kto_{ModelArguments.model_name}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
     num_train_epochs: int = 1
-    per_device_train_batch_size: int = 4
+    per_device_train_batch_size: int = 4 # Highest that runs well
     learning_rate: float = 5e-7
     lr_scheduler_type: str = "cosine"
     gradient_accumulation_steps: int = 1
@@ -39,28 +48,30 @@ class Config:
     bf16: bool = True
     logging_first_step: bool = True
 
-    # Checkpoint and hub settings
-    checkpoint_path: str = None
-    push_to_hub: bool = False
 
-# Initialize the unified configuration
-config = Config()
+
+# Initialize configurations
+script_args = ScriptArguments()
+training_args = TrainingArguments()
+model_args = ModelArguments()
 
 ####################################
 # HELPER FUNCTIONS
 ####################################
 
-def load_model_and_tokenizer(config):
+def load_model_and_tokenizer(model_args):
     """
     Load a model and tokenizer from a specified path.
     """
     model = AutoModelForCausalLM.from_pretrained(
-        config.model_name,
+        model_args.model_name,
+        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=torch.float16,
         device_map="auto"
     )
     tokenizer = AutoTokenizer.from_pretrained(
-        config.model_name
+        model_args.model_name,
+        trust_remote_code=model_args.trust_remote_code
     )
 
     # Set pad token if missing
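
Note: HfArgumentParser is added to the imports above, but the three dataclasses are still instantiated with their hard-coded defaults. A minimal sketch (not part of this commit) of how they could instead be filled from the command line, assuming every field has a CLI-friendly type; the callable default on process_dataset_func would likely need special handling:

# Hypothetical wiring: parse all three dataclasses from CLI flags in one pass.
parser = HfArgumentParser((ScriptArguments, ModelArguments, TrainingArguments))
script_args, model_args, training_args = parser.parse_args_into_dataclasses()
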
@@ -79,13 +90,13 @@ def main():
 
     # Load models and tokenizer
     print("Loading models and tokenizer...")
-    model, tokenizer = load_model_and_tokenizer(config)
-    ref_model, _ = load_model_and_tokenizer(config)
+    model, tokenizer = load_model_and_tokenizer(model_args)
+    ref_model, _ = load_model_and_tokenizer(model_args)
     print("Models and tokenizer loaded.")
 
-    # Load and process datasets using the specified function
+    # Load and process datasets using external function
     print("Processing dataset...")
-    dataset = config.process_dataset_func()
+    dataset = process_dataset_ultrafeedback()
     print("Dataset processed.")
 
     # Initialize trainer
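
Note: the pipeline now calls process_dataset_ultrafeedback() directly instead of going through the configurable process_dataset_func hook. A rough sketch of the shape it is assumed to return, based on how dataset["train"] and dataset["test"] are consumed below and on TRL's unpaired KTO format (prompt / completion / boolean label); the exact columns produced by kto_dataset_processor.py are an assumption here:

from datasets import Dataset, DatasetDict

# Illustrative rows only: one desirable and one undesirable completion.
rows = [
    {"prompt": "What is the capital of France?", "completion": "Paris.", "label": True},
    {"prompt": "What is the capital of France?", "completion": "Berlin.", "label": False},
]
dataset = DatasetDict({"train": Dataset.from_list(rows), "test": Dataset.from_list(rows)})
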
@@ -93,28 +104,11 @@ def main():
     trainer = KTOTrainer(
         model=model,
         ref_model=ref_model,
-        args=KTOConfig(
-            output_dir=config.output_dir,
-            num_train_epochs=config.num_train_epochs,
-            per_device_train_batch_size=config.per_device_train_batch_size,
-            learning_rate=config.learning_rate,
-            lr_scheduler_type=config.lr_scheduler_type,
-            gradient_accumulation_steps=config.gradient_accumulation_steps,
-            logging_steps=config.logging_steps,
-            eval_steps=config.eval_steps,
-            warmup_ratio=config.warmup_ratio,
-            bf16=config.bf16,
-            logging_first_step=config.logging_first_step,
-        ),
+        args=training_args,
         train_dataset=dataset["train"],
         eval_dataset=dataset["test"],
         tokenizer=tokenizer,
-        peft_config=get_peft_config({
-            "use_peft": config.use_peft,
-            "lora_target_modules": config.lora_target_modules,
-            "lora_r": config.lora_r,
-            "lora_alpha": config.lora_alpha,
-        }),
+        peft_config=get_peft_config(model_args),
     )
 
     # Training
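
Note: the hand-built dict passed to get_peft_config is replaced by the ModelConfig subclass itself, which is what TRL's helper expects. With use_peft=True it returns a PEFT LoRA configuration built from those fields; a rough, approximate equivalent (not the library's exact construction) would be:

from peft import LoraConfig

# Approximate equivalent of get_peft_config(model_args) when use_peft is True.
peft_config = LoraConfig(
    r=model_args.lora_r,
    lora_alpha=model_args.lora_alpha,
    target_modules=model_args.lora_target_modules,  # "all-linear"
    task_type="CAUSAL_LM",
)
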
@@ -130,11 +124,25 @@ def main():
     trainer.save_metrics("eval", metrics)
 
     # Log metrics to wandb
-    wandb.log(metrics)
+    wandb.log({
+        "epoch": metrics.get("epoch"),
+        "grad_norm": metrics.get("grad_norm"),
+        "kl": metrics.get("kl"),
+        "learning_rate": metrics.get("learning_rate"),
+        "logits/chosen": metrics.get("logits/chosen"),
+        "logits/rejected": metrics.get("logits/rejected"),
+        "logps/chosen": metrics.get("logps/chosen"),
+        "logps/rejected": metrics.get("logps/rejected"),
+        "loss": metrics.get("loss"),
+        "rewards/chosen": metrics.get("rewards/chosen"),
+        "rewards/margins": metrics.get("rewards/margins"),
+        "rewards/rejected": metrics.get("rewards/rejected"),
+        "step": metrics.get("step")
+    })
 
     # Save model and optionally push to hub
-    trainer.save_model(config.output_dir)
-    if config.push_to_hub:
+    trainer.save_model(training_args.output_dir)
+    if script_args.push_to_hub:
         trainer.push_to_hub()
 
     print("Process completed.")