feat: refactor TrainingArguments
tools/train/train.py CHANGED (+63 -42)
@@ -65,7 +65,7 @@ class ModelArguments:
     config_name: Optional[str] = field(
         default=None,
         metadata={
-            "help": "Pretrained config name or path if not the same as
+            "help": "Pretrained config name or path if not the same as model_name_or_path"
         },
     )
     tokenizer_name: Optional[str] = field(
@@ -77,7 +77,7 @@ class ModelArguments:
     dtype: Optional[str] = field(
         default="float32",
         metadata={
-            "help": "Floating-point format in which the
+            "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`."
         },
     )

@@ -106,11 +106,15 @@ class DataTrainingArguments:
     )
     train_file: Optional[str] = field(
         default=None,
-        metadata={
+        metadata={
+            "help": "The input training data file (glob & braceexpand acceptable)."
+        },
     )
     validation_file: Optional[str] = field(
         default=None,
-        metadata={
+        metadata={
+            "help": "An optional input evaluation data file (glob & braceexpand acceptable)."
+        },
     )
     # data loading should not be a bottleneck so we use "streaming" mode by default
     streaming: Optional[bool] = field(
@@ -132,15 +136,13 @@ class DataTrainingArguments:
     max_train_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of training examples
-            "value if set."
+            "help": "For debugging purposes or quicker training, truncate the number of training examples."
         },
     )
     max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples
-            "value if set."
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples."
         },
     )
     preprocessing_num_workers: Optional[int] = field(
@@ -191,42 +193,42 @@ class TrainingArguments:
 
     do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
     do_eval: bool = field(
-        default=False, metadata={"help": "Whether to run eval on the
+        default=False, metadata={"help": "Whether to run eval on the validation set."}
     )
 
     per_device_train_batch_size: int = field(
-        default=8, metadata={"help": "Batch size per GPU/TPU
+        default=8, metadata={"help": "Batch size per GPU/TPU/CPU for training."}
     )
     per_device_eval_batch_size: int = field(
-        default=8, metadata={"help": "Batch size per GPU/TPU
+        default=8, metadata={"help": "Batch size per GPU/TPU/CPU for evaluation."}
     )
 
     gradient_accumulation_steps: int = field(
         default=1,
         metadata={
-            "help": "Number of updates steps to accumulate before performing
+            "help": "Number of updates steps to accumulate before performing an update pass."
         },
     )
 
     learning_rate: float = field(
         default=5e-5, metadata={"help": "The initial learning rate."}
     )
-
-        default=
-        metadata={
-
-
-        default=False,
-        metadata={"help": "Use Distributed Shampoo optimizer instead of AdamW."},
+    optim: str = field(
+        default="distributed_shampoo",
+        metadata={
+            "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
+        },
     )
     weight_decay: float = field(
         default=None, metadata={"help": "Weight decay if we apply some."}
     )
-
-        default=0.9,
+    beta1: float = field(
+        default=0.9,
+        metadata={"help": "Beta1 for adam & distributed_shampoo optimizers"},
     )
-
-        default=0.999,
+    beta2: float = field(
+        default=0.999,
+        metadata={"help": "Beta2 for adam & distributed_shampoo optimizers"},
     )
     adam_epsilon: float = field(
         default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
@@ -234,6 +236,16 @@ class TrainingArguments:
     max_grad_norm: float = field(
         default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
     )
+    preconditioning_compute_steps: int = field(
+        default=10, metadata={"help": "Number of steps to update preconditioner."}
+    )
+    optim_quantized: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to quantize optimizer (only supported with distributed_shampoo)."
+        },
+    )
+
     use_decay: bool = field(
         default=False,
         metadata={"help": "Whether to use decay in the learning rate scheduler."},
@@ -272,6 +284,13 @@ class TrainingArguments:
         metadata={"help": "Reference to a wandb artifact for resuming training."},
     )
 
+    def __post_init__(self):
+        assert self.optim in [
+            "distributed_shampoo",
+            "adam",
+            "adafactor",
+        ], f"Selected optimizer not supported: {self.optim}"
+
 
 class TrainState(train_state.TrainState):
     dropout_rng: jnp.ndarray = None
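The new `optim`, `beta1`, `beta2`, `preconditioning_compute_steps` and `optim_quantized` entries above are ordinary dataclass fields, and the added `__post_init__` rejects unsupported optimizer names as soon as the arguments are parsed. Assuming the script feeds these dataclasses to transformers' HfArgumentParser (the usual pattern behind the ModelArguments / DataTrainingArguments / TrainingArguments trio), each field automatically becomes a CLI flag. The following is a minimal, self-contained sketch of that pattern; `OptimizerArguments` is a hypothetical stand-in, not a class from train.py:

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class OptimizerArguments:
    optim: str = field(
        default="distributed_shampoo",
        metadata={"help": 'Optimizer to use: "distributed_shampoo", "adam" or "adafactor".'},
    )
    beta1: float = field(default=0.9, metadata={"help": "Beta1 for adam & distributed_shampoo."})
    beta2: float = field(default=0.999, metadata={"help": "Beta2 for adam & distributed_shampoo."})

    def __post_init__(self):
        # Fail fast on typos such as --optim adamw, mirroring the assert in the diff.
        assert self.optim in [
            "distributed_shampoo",
            "adam",
            "adafactor",
        ], f"Selected optimizer not supported: {self.optim}"


if __name__ == "__main__":
    parser = HfArgumentParser(OptimizerArguments)
    # Every field becomes a flag, e.g. `--optim adam --beta2 0.99`.
    (opt_args,) = parser.parse_args_into_dataclasses(["--optim", "adam", "--beta2", "0.99"])
    print(opt_args)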
@@ -551,29 +570,22 @@ def main():
         return traverse_util.unflatten_dict(flat_mask)
 
     # create adam optimizer
-    if training_args.
-        # We use the default parameters here to initialize adafactor,
-        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
-        optimizer = optax.adafactor(
-            learning_rate=learning_rate_fn,
-            weight_decay_rate=training_args.weight_decay,
-            weight_decay_mask=decay_mask_fn,
-            clipping_threshold=training_args.max_grad_norm,
-        )
-    elif training_args.distributed_shampoo:
+    if training_args.optim == "distributed_shampoo":
         # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
         # Notes:
-        # - mask for weight decay is not implemented
+        # - mask for weight decay is not implemented
         optimizer = distributed_shampoo(
             learning_rate_fn,
             block_size=1024,  # recommended default for large LM is 1536
-            beta1=
-            beta2=
+            beta1=training_args.beta1,
+            beta2=training_args.beta2,
             diagonal_epsilon=1e-10,
             matrix_epsilon=1e-8,
-            weight_decay=
+            weight_decay=training_args.weight_decay
+            if training_args.weight_decay is not None
+            else 0.0,
             start_preconditioning_step=1001,
-            preconditioning_compute_steps=
+            preconditioning_compute_steps=training_args.preconditioning_compute_steps,
             statistics_compute_steps=1,
             best_effort_shape_interpretation=True,
             graft_type=GraftingType.RMSPROP_NORMALIZED,
@@ -585,20 +597,29 @@ def main():
             skip_preconditioning_dim_size_gt=4096,
             clip_by_scaled_gradient_norm=None,
             precision=jax.lax.Precision.HIGHEST,
-            best_effort_memory_usage_reduction=
+            best_effort_memory_usage_reduction=training_args.optim_quantized,
         )
 
-
+    elif training_args.optim == "adam":
         optimizer = optax.adamw(
             learning_rate=learning_rate_fn,
-            b1=training_args.
-            b2=training_args.
+            b1=training_args.beta1,
+            b2=training_args.beta2,
             eps=training_args.adam_epsilon,
             weight_decay=training_args.weight_decay
             if training_args.weight_decay is not None
             else 0.0,
             mask=decay_mask_fn,
         )
+    elif training_args.optim == "adafactor":
+        # We use the default parameters here to initialize adafactor,
+        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+        optimizer = optax.adafactor(
+            learning_rate=learning_rate_fn,
+            weight_decay_rate=training_args.weight_decay,
+            weight_decay_mask=decay_mask_fn,
+            clipping_threshold=training_args.max_grad_norm,
+        )
 
     # add gradient accumulation
     if training_args.gradient_accumulation_steps > 1:
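The two hunks above replace the old per-optimizer booleans with a single dispatch on `training_args.optim`. Below is a condensed sketch of that dispatch using only optax; the `distributed_shampoo` branch is stubbed because it relies on the external scalable-shampoo implementation that train.py imports, and the helper name `make_optimizer` is illustrative rather than something defined in the script:

import optax


def make_optimizer(optim, learning_rate_fn, weight_decay, beta1, beta2,
                   adam_epsilon, max_grad_norm, decay_mask_fn=None):
    if optim == "distributed_shampoo":
        # Stub: the real branch calls the scalable-shampoo `distributed_shampoo`
        # factory with the block size, grafting type, etc. shown in the diff.
        raise NotImplementedError("requires the scalable_shampoo implementation")
    elif optim == "adam":
        return optax.adamw(
            learning_rate=learning_rate_fn,
            b1=beta1,
            b2=beta2,
            eps=adam_epsilon,
            weight_decay=weight_decay if weight_decay is not None else 0.0,
            mask=decay_mask_fn,
        )
    elif optim == "adafactor":
        return optax.adafactor(
            learning_rate=learning_rate_fn,
            weight_decay_rate=weight_decay,
            weight_decay_mask=decay_mask_fn,
            clipping_threshold=max_grad_norm,
        )
    raise ValueError(f"Selected optimizer not supported: {optim}")


# A float learning rate works anywhere optax expects a schedule.
optimizer = make_optimizer("adam", 5e-5, weight_decay=None, beta1=0.9,
                           beta2=0.999, adam_epsilon=1e-8, max_grad_norm=1.0)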
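The trailing context only shows the condition `if training_args.gradient_accumulation_steps > 1:`; the wrapper applied inside that branch is not part of this diff. One common way to add gradient accumulation around any optax optimizer is `optax.MultiSteps`, sketched here purely as an illustration of what such a branch can do:

import optax


def with_gradient_accumulation(optimizer, gradient_accumulation_steps):
    # Accumulate gradients and apply an update only every k calls,
    # so the effective batch size is k * the per-step batch size.
    if gradient_accumulation_steps > 1:
        optimizer = optax.MultiSteps(
            optimizer, every_k_schedule=gradient_accumulation_steps
        ).gradient_transformation()
    return optimizer


optimizer = with_gradient_accumulation(optax.adamw(5e-5), gradient_accumulation_steps=4)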