feat(train): handle distributed_shampoo in pjit
tools/train/train.py (+40 -36)
@@ -25,7 +25,7 @@ import sys
 import time
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
-from typing import Callable, Optional
+from typing import Any, Callable, NamedTuple, Optional
 
 import datasets
 import jax
@@ -36,7 +36,7 @@ import transformers
 import wandb
 from datasets import Dataset
 from distributed_shampoo import GraftingType, distributed_shampoo
-from flax.core.frozen_dict import freeze
+from flax.core.frozen_dict import FrozenDict, freeze
 from flax.serialization import from_bytes, to_bytes
 from flax.training import train_state
 from flax.training.common_utils import onehot, stack_forest
@@ -523,6 +523,12 @@ def main():
         use_fast=True,
     )
 
+    # get PartitionSpec for model params (required to be a dict)
+    param_spec = set_partitions(model.params)
+
+    # convert params to frozen dict
+    model._params = freeze(model.params)
+
     # Preprocessing the datasets.
     # We need to normalize and tokenize inputs and targets.
 
@@ -620,6 +626,13 @@ def main():
             precision=jax.lax.Precision.HIGHEST,
             best_effort_memory_usage_reduction=training_args.optim_quantized,
         )
+        # get the real optimizer and helper functions
+        update_fn = optimizer.update
+        optimizer = optimizer.init(model.params)
+        opt_fn = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)(
+            optimizer.pspec_fn, optimizer.shape_and_dtype_fn
+        )
+        optimizer = optax.GradientTransformation(optimizer.init_fn, update_fn)
 
     elif training_args.optim == "adam":
         optimizer = optax.adamw(
@@ -636,43 +649,40 @@ def main():
             clipping_threshold=training_args.max_grad_norm,
         )
 
-    # get PartitionSpec for model params
-    param_spec = set_partitions(model.params)
-
     # get PartitionSpec for optimizer state
     def get_opt_state_spec_and_shape(param_spec):
-        if training_args.optim == "adam":
+        if training_args.optim in ["adam", "adafactor"]:
             # get opt_state shape without actual init
             opt_state_shape = jax.eval_shape(optimizer.init, model.params)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if training_args.optim == "adam":
+
+                def _opt_state_spec_per_leaf(x):
+                    if isinstance(x, FrozenDict):
+                        # variables with same structure as params
+                        return param_spec
+                    else:
+                        # other variables such as count
+                        return None
+
+                opt_state_spec = jax.tree_map(
+                    _opt_state_spec_per_leaf,
+                    opt_state_shape,
+                    # return None spec for empty elements
+                    is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
+                )
 
-
-
-
+            elif training_args.optim == "adafactor":
+                # factorized state must be replicated (rank different than params)
+                opt_state_spec = None
 
         elif training_args.optim == "distributed_shampoo":
-
-            _opt_state = optimizer.init(model.params)
-            opt_state_spec = _opt_state.pspec_fn(
+            opt_state_spec = opt_fn.pspec_fn(
                 params=model.params,
-                params_partition_spec=
+                params_partition_spec=param_spec,
                 partition_spec_for_statistics=PartitionSpec(None, "batch", None),
             )
-            opt_state_shape =
+            opt_state_shape = opt_fn.shape_and_dtype_fn(model.params)
         else:
             raise NotImplementedError
         return opt_state_spec, opt_state_shape
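In the hunk above, the `adam` branch relies on a pattern worth seeing in isolation: `jax.eval_shape` returns the optimizer-state pytree as shapes only (nothing is allocated on device), and `jax.tree_map` with a custom `is_leaf` maps every params-shaped sub-tree to the params' PartitionSpec tree and everything else (step counters, empty states) to `None`, i.e. replicated. Below is a minimal standalone sketch of that pattern, not the script's real code: `toy_params`, `toy_spec`, and the `"mp"` axis name are invented stand-ins for `model.params` and `set_partitions(model.params)`, and the `jax.experimental` import matches the JAX generation this script targets.

import jax
import jax.numpy as jnp
import optax
from flax.core.frozen_dict import FrozenDict, freeze
from jax.experimental import PartitionSpec  # current JAX: from jax.sharding import PartitionSpec

# Toy parameters and a hand-written spec standing in for set_partitions(model.params).
toy_params = freeze({"dense": {"kernel": jnp.ones((8, 8)), "bias": jnp.ones((8,))}})
toy_spec = {"dense": {"kernel": PartitionSpec(None, "mp"), "bias": None}}

optimizer = optax.adamw(learning_rate=1e-3)

# Optimizer-state pytree as ShapeDtypeStructs only; no buffers are materialized.
opt_state_shape = jax.eval_shape(optimizer.init, toy_params)

def _spec_per_leaf(x):
    # mu/nu mirror the params, so they reuse the param spec;
    # scalars such as the Adam step count are replicated (None).
    return toy_spec if isinstance(x, FrozenDict) else None

opt_state_spec = jax.tree_map(
    _spec_per_leaf,
    opt_state_shape,
    # treat whole param-shaped sub-trees and empty states as leaves
    is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
)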
@@ -714,18 +724,12 @@ def main():
             in_axis_resources=(param_spec,),
             out_axis_resources=state_spec,
             donate_argnums=(0,),
-        )(
+        )(model.params)
 
     else:
         # restore opt_state
         with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
             opt_state = from_bytes(opt_state_shape, f.read())
-        # need to freeze dict for pjit
-        opt_state = jax.tree_map(
-            lambda x: freeze(x) if isinstance(x, dict) else x,
-            opt_state,
-            is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
-        )
 
         # restore other attributes
         with (Path(artifact_dir) / "training_state.json").open("r") as f:
@@ -746,7 +750,7 @@ def main():
             in_axis_resources=(param_spec, opt_state_spec),
             out_axis_resources=state_spec,
             donate_argnums=(0, 1),
-        )(
+        )(model.params, opt_state)
 
         # remove opt_state from CPU
         del opt_state
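One idiom in the Shampoo branch is easy to misread: `NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)(...)` uses `typing.NamedTuple`'s functional form to build a throwaway record type and instantiate it in one expression, so the two helper functions survive after `optimizer` itself is rebound to a plain `optax.GradientTransformation`. The same idiom in isolation, with stand-in functions (the real ones come from the object returned by `distributed_shampoo(...).init`):

from typing import Any, NamedTuple

# Stand-ins for the helpers exposed by the Shampoo init object.
def fake_pspec_fn(params, params_partition_spec, partition_spec_for_statistics):
    return params_partition_spec

def fake_shape_and_dtype_fn(params):
    return params

# Build the record type and instantiate it in one go.
opt_fn = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)(
    fake_pspec_fn, fake_shape_and_dtype_fn
)

assert opt_fn.pspec_fn is fake_pspec_fn
assert opt_fn.shape_and_dtype_fn is fake_shape_and_dtype_fn

A class-based `class OptFns(NamedTuple)` would be the more conventional spelling; the inline form simply keeps the change local to the optimizer branch.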