fix: no gradient checkpointing for new model
tools/train/train.py  +3 -1  CHANGED
@@ -531,6 +531,9 @@ def main():
     # Set up our new model config
     if model_args.config_name:
         config = DalleBartConfig.from_pretrained(model_args.config_name)
+        # initializing params with gradient checkpointing creates issues
+        # we correctly set it later per training_args
+        config.gradient_checkpointing = False
     else:
         config = None

@@ -553,7 +556,6 @@
             seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
             load_on_cpu=True,
-            gradient_checkpointing=False,
         )

     # update model config per training args
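For context, a minimal sketch of the pattern this commit implements: keep gradient checkpointing disabled while the new model's parameters are initialized, then apply the training-time setting afterwards. DalleBartConfig, model_args, and training_args appear in the diff; the import path, the DalleBart constructor name, and the training_args.gradient_checkpointing flag are assumptions for illustration only.

# Minimal sketch (assumed names noted below), not the exact train.py code.
import jax.numpy as jnp

from dalle_mini.model import DalleBart, DalleBartConfig  # assumed import path


def build_model(model_args, training_args):
    # Set up our new model config
    if model_args.config_name:
        config = DalleBartConfig.from_pretrained(model_args.config_name)
        # initializing params with gradient checkpointing creates issues,
        # so keep it off during init and set it later per training_args
        config.gradient_checkpointing = False
    else:
        config = None

    model = DalleBart(  # constructor name assumed from DalleBartConfig
        config,
        seed=training_args.seed_model,
        dtype=getattr(jnp, model_args.dtype),
        load_on_cpu=True,
    )

    # update model config per training args
    # (assumed flag name: training_args.gradient_checkpointing)
    model.config.gradient_checkpointing = training_args.gradient_checkpointing
    return model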