feat(train): split artifact into model/state (#128)

Files changed:
- src/dalle_mini/text.py  (+3 -3)
- tools/train/train.py  (+114 -122)

src/dalle_mini/text.py (CHANGED)

@@ -116,7 +116,7 @@ def remove_comma_numbers(t):
 
 
 def pre_process_dot_numbers(t):
-    return re.sub("(\w)\.(\w)",
+    return re.sub("(\w)\.(\w)", rf"\1{temp_token}dot{temp_token}\2", t)
 
 
 def post_process_dot_numbers(t):
@@ -126,7 +126,7 @@ def post_process_dot_numbers(t):
 def pre_process_quotes(t):
     # allows quotes only for 's, 't, 'd, 'm, 'll, 're, 've
     return re.sub(
-        r"'(?=([stdm]|(ll)|(re)|(ve)|(ll))\b)",
+        r"'(?=([stdm]|(ll)|(re)|(ve)|(ll))\b)", rf"{temp_token}quote{temp_token}", t
     )
 
 
@@ -135,7 +135,7 @@ def post_process_quotes(t):
 
 
 def pre_process_dates(t):
-    return re.sub("(\d)/(\d)",
+    return re.sub("(\d)/(\d)", rf"\1{temp_token}slash{temp_token}\2", t)
 
 
 def post_process_dates(t):
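
These helpers shield dots and slashes inside numbers (and, in `pre_process_quotes`, apostrophes in common contractions) behind a sentinel token before other normalization runs; the matching `post_process_*` functions restore them afterwards. A minimal sketch of the new behaviour, using a made-up value for the module-level `temp_token`, whose real definition is outside this diff:

```python
import re

temp_token = "xtokx"  # placeholder; the actual sentinel is defined elsewhere in text.py

def pre_process_dot_numbers(t):
    # protect the dot between word characters, e.g. "2.5" or "v1.2"
    return re.sub(r"(\w)\.(\w)", rf"\1{temp_token}dot{temp_token}\2", t)

def pre_process_dates(t):
    # protect the slash between digits, e.g. "12/25"
    return re.sub(r"(\d)/(\d)", rf"\1{temp_token}slash{temp_token}\2", t)

print(pre_process_dot_numbers("version 2.5"))  # version 2xtokxdotxtokx5
print(pre_process_dates("12/25"))              # 12xtokxslashxtokx25
```
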
tools/train/train.py (CHANGED)

@@ -88,6 +88,23 @@ class ModelArguments:
             "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`."
         },
     )
+    restore_state: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Restore optimizer and training state associated with a wandb checkpoint."
+        },
+    )
+
+    state_artifact: str = field(init=False)
+
+    def __post_init__(self):
+        if self.restore_state:
+            assert (
+                "/model-" in self.model_name_or_path
+            ), "Restoring state only available with W&B artifact reference"
+            self.state_artifact = self.model_name_or_path.replace(
+                "/model-", "/state-", 1
+            )
 
 
 @dataclass
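
With this change the model checkpoint (`model-<run id>`) and the training state checkpoint (`state-<run id>`) are separate W&B artifacts, and the state reference is derived from the model reference instead of being passed as its own argument. A small sketch of that derivation, using a hypothetical artifact name:

```python
# Hypothetical W&B artifact reference passed as --model_name_or_path
model_name_or_path = "dalle-mini/dalle-mini/model-1a2b3c4d:v7"

# Mirrors ModelArguments.__post_init__ above
assert "/model-" in model_name_or_path, "Restoring state only available with W&B artifact reference"
state_artifact = model_name_or_path.replace("/model-", "/state-", 1)
print(state_artifact)  # dalle-mini/dalle-mini/state-1a2b3c4d:v7
```
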
@@ -319,11 +336,6 @@ class TrainingArguments:
         },
     )
 
-    resume_from_checkpoint: Optional[str] = field(
-        default=None,
-        metadata={"help": "Reference to a wandb artifact for resuming training."},
-    )
-
     wandb_entity: Optional[str] = field(
         default=None,
         metadata={"help": "The wandb entity to use (for teams)."},
@@ -349,6 +361,8 @@ class TrainingArguments:
         },
     )
 
+    dp_devices: int = field(init=False)
+
     def __post_init__(self):
         assert self.optim in [
             "distributed_shampoo",
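
`dp_devices` is declared with `init=False`, so it is not parsed from the command line; `__post_init__` is expected to fill it in (the updated body is not shown in this hunk). A rough sketch of the pattern, with the derivation formula assumed rather than taken from the diff:

```python
from dataclasses import dataclass, field

import jax


@dataclass
class TrainingArgumentsSketch:
    # hypothetical model-parallel device count; the real class has many more fields
    mp_devices: int = 1
    dp_devices: int = field(init=False)

    def __post_init__(self):
        # assumed derivation: data-parallel devices = total devices / model-parallel devices
        self.dp_devices = jax.device_count() // self.mp_devices


print(TrainingArgumentsSketch().dp_devices)
```
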
@@ -470,62 +484,40 @@ def main():
             config=parser.parse_args(),
         )
 
-    if training_args.resume_from_checkpoint is not None:
-        if jax.process_index() == 0:
-            artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
-        else:
-            artifact = wandb.Api().artifact(training_args.resume_from_checkpoint)
-        artifact_dir = artifact.download()
+    # Set up our new model config
+    if model_args.config_name:
+        config = DalleBartConfig.from_pretrained(model_args.config_name)
+    else:
+        config = None
 
-        # load model
+    # Load or create new model
+    if model_args.model_name_or_path:
         model = DalleBart.from_pretrained(
-            artifact_dir,
+            model_args.model_name_or_path,
+            config=config,
+            seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
             abstract_init=True,
             load_on_cpu=True,
         )
+    else:
+        model = DalleBart(
+            config,
+            seed=training_args.seed_model,
+            dtype=getattr(jnp, model_args.dtype),
+            load_on_cpu=True,
+        )
 
-        # load tokenizer
+    # Load tokenizer
+    if model_args.tokenizer_name is not None:
         tokenizer = DalleBartTokenizer.from_pretrained(
-            artifact_dir,
-            use_fast=True,
+            model_args.tokenizer_name, use_fast=True
         )
-
     else:
-        # Set up our new model config
-        if model_args.config_name:
-            config = DalleBartConfig.from_pretrained(model_args.config_name)
-        else:
-            config = None
-
-        # Load or create new model
-        if model_args.model_name_or_path:
-            model = DalleBart.from_pretrained(
-                model_args.model_name_or_path,
-                config=config,
-                seed=training_args.seed_model,
-                dtype=getattr(jnp, model_args.dtype),
-                abstract_init=True,
-                load_on_cpu=True,
-            )
-        else:
-            model = DalleBart(
-                config,
-                seed=training_args.seed_model,
-                dtype=getattr(jnp, model_args.dtype),
-                load_on_cpu=True,
-            )
-
-        # Load tokenizer
-        if model_args.tokenizer_name is not None:
-            tokenizer = DalleBartTokenizer.from_pretrained(
-                model_args.tokenizer_name, use_fast=True
-            )
-        else:
-            tokenizer = DalleBartTokenizer.from_pretrained(
-                model_args.model_name_or_path,
-                use_fast=True,
-            )
+        tokenizer = DalleBartTokenizer.from_pretrained(
+            model_args.model_name_or_path,
+            use_fast=True,
+        )
 
     # get PartitionSpec for model params (required to be a dict)
     param_spec = set_partitions(model.params)
@@ -655,30 +647,29 @@ def main():
 
     # get PartitionSpec for optimizer state
     def get_opt_state_spec_and_shape(param_spec):
+        # get opt_state shape without actual init
+        opt_state_shape = jax.eval_shape(optimizer.init, model.params)
+
+        if training_args.optim == "adam":
+
+            def _opt_state_spec_per_leaf(x):
+                if isinstance(x, FrozenDict):
+                    # variables with same structure as params
+                    return param_spec
+                else:
+                    # other variables such as count
+                    return None
+
+            opt_state_spec = jax.tree_map(
+                _opt_state_spec_per_leaf,
+                opt_state_shape,
+                # return None spec for empty elements
+                is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
+            )
 
+        elif training_args.optim == "adafactor":
+            # factorized state must be replicated (rank different than params)
+            opt_state_spec = None
 
         elif training_args.optim == "distributed_shampoo":
             opt_state_spec = opt_fn.pspec_fn(
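
For Adam, the optimizer-state PartitionSpec tree is built by walking the shapes returned by `jax.eval_shape`: every `FrozenDict` leaf (mu, nu) reuses `param_spec`, while everything else (such as the step count) gets no spec. A self-contained sketch of the same idea with toy parameters; names and shapes are invented, and `jax.tree_util.tree_map` stands in for the script's `jax.tree_map`:

```python
import jax
import jax.numpy as jnp
import optax
from flax.core.frozen_dict import FrozenDict

# Toy parameter tree and a stand-in "partition spec" with the same structure
params = FrozenDict({"dense": {"kernel": jnp.ones((4, 2)), "bias": jnp.zeros((2,))}})
param_spec = {"dense": {"kernel": None, "bias": None}}

optimizer = optax.adamw(learning_rate=1e-3)

# Optimizer state shapes without allocating any buffers
opt_state_shape = jax.eval_shape(optimizer.init, params)

def _opt_state_spec_per_leaf(x):
    if isinstance(x, FrozenDict):
        # mu / nu mirror the parameter pytree
        return param_spec
    # other variables such as count
    return None

opt_state_spec = jax.tree_util.tree_map(
    _opt_state_spec_per_leaf,
    opt_state_shape,
    # stop descending at FrozenDict sub-trees and at empty optimizer states
    is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
)
print(opt_state_spec)
```
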
@@ -686,7 +677,6 @@ def main():
                 params_partition_spec=param_spec,
                 partition_spec_for_statistics=PartitionSpec(None, "batch", None),
             )
-            opt_state_shape = opt_fn.shape_and_dtype_fn(model.params)
         else:
             raise NotImplementedError
         return opt_state_spec, opt_state_shape
@@ -698,7 +688,7 @@ def main():
     devices = np.asarray(jax.devices()).reshape(*mesh_shape)
     mesh = maps.Mesh(devices, ("batch", "mp"))
 
-    #
+    # define state spec
     state_spec = TrainState(
         params=param_spec,
         opt_state=opt_state_spec,
@@ -713,7 +703,7 @@ def main():
 
     # create training state
     with maps.mesh(mesh.devices, mesh.axis_names):
-        if
+        if not model_args.restore_state:
 
             def init_state(params):
                 return TrainState.create(
@@ -731,6 +721,13 @@ def main():
             )(model.params)
 
         else:
+            # get state files from artifact
+            if jax.process_index() == 0:
+                artifact = wandb.run.use_artifact(model_args.state_artifact)
+            else:
+                artifact = wandb.Api().artifact(model_args.state_artifact)
+            artifact_dir = artifact.download()
+
             # restore opt_state
             with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
                 opt_state = from_bytes(opt_state_shape, f.read())
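
When `restore_state` is set, the state artifact is downloaded and `opt_state.msgpack` is deserialized with `flax.serialization.from_bytes` against the `opt_state_shape` template computed earlier, so the state never has to be materialized twice. A minimal round-trip sketch with a toy optimizer state (file contents are kept in memory here; names and shapes are illustrative only):

```python
import jax.numpy as jnp
import optax
from flax.serialization import from_bytes, to_bytes

params = {"w": jnp.zeros((2, 3)), "b": jnp.zeros((3,))}
optimizer = optax.adam(learning_rate=1e-3)

# What a training run would write into opt_state.msgpack
opt_state = optimizer.init(params)
blob = to_bytes(opt_state)

# Restore into a template with the same pytree structure
# (the training script uses the opt_state_shape from jax.eval_shape as the template)
restored = from_bytes(optimizer.init(params), blob)
print(type(restored))
```
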
@@ -760,7 +757,7 @@ def main():
             del opt_state
 
     # free memory
-    del model._params
+    del model._params, opt_state_spec, opt_state_shape
 
     # define batch specs
     keys = ["attention_mask", "decoder_input_ids", "input_ids", "labels"]
@@ -998,51 +995,46 @@ def main():
                     f,
                 )
 
-                        str(Path(training_args.output_dir) / "opt_state.msgpack")
-                    )
-                    artifact.add_file(
-                        str(Path(training_args.output_dir) / "training_state.json")
+            # save to W&B
+            if training_args.log_model:
+                # save some space
+                c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
+                c.cleanup(wandb.util.from_human_size("10GB"))
+
+                metadata = dict(state_dict)
+                metadata["num_params"] = num_params
+                if eval_metrics is not None:
+                    metadata["eval"] = eval_metrics
+
+                # create model artifact
+                artifact = wandb.Artifact(
+                    name=f"model-{wandb.run.id}",
+                    type="DalleBart_model",
+                    metadata=metadata,
+                )
+                for filename in [
+                    "config.json",
+                    "flax_model.msgpack",
+                    "merges.txt",
+                    "special_tokens_map.json",
+                    "tokenizer.json",
+                    "tokenizer_config.json",
+                    "vocab.json",
+                ]:
+                    artifact.add_file(f"{Path(training_args.output_dir) / filename}")
+                wandb.run.log_artifact(artifact)
+
+                # create state artifact
+                artifact_state = wandb.Artifact(
+                    name=f"state-{wandb.run.id}",
+                    type="DalleBart_state",
+                    metadata=metadata,
+                )
+                for filename in ["opt_state.msgpack", "training_state.json"]:
+                    artifact_state.add_file(
+                        f"{Path(training_args.output_dir) / filename}"
                     )
-
-                wandb.run.log_artifact(artifact)
+                wandb.run.log_artifact(artifact_state)
 
     # init variables
     last_time = time.perf_counter()
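
Checkpoints are now logged as two artifacts per run: `model-<run id>` with the weights and tokenizer files, and `state-<run id>` with `opt_state.msgpack` and `training_state.json`. A hedged sketch of the state half, assuming an output directory that already contains those two files and an active W&B run:

```python
from pathlib import Path

import wandb

output_dir = Path("output")  # stand-in for training_args.output_dir

run = wandb.init(project="dalle-mini", job_type="train")

artifact_state = wandb.Artifact(
    name=f"state-{run.id}",
    type="DalleBart_state",
    metadata={"step": 1000},  # the script stores the full training state_dict here
)
for filename in ["opt_state.msgpack", "training_state.json"]:
    artifact_state.add_file(f"{output_dir / filename}")
run.log_artifact(artifact_state)
```

A later run can then point `--model_name_or_path` at the model artifact and pass `--restore_state` to pull the matching state artifact.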