feat(train): split artifact into model/state
tools/train/train.py (+92 -97)
@@ -88,6 +88,24 @@ class ModelArguments:
             "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`."
         },
     )
+    restore_state: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Restore optimizer and training state associated with a wandb checkpoint."
+        },
+    )
+
+    state_artifact: str = field(init=False)
+
+    def __post_init__(self):
+        if self.restore_state:
+            assert (
+                "/model-" in self.model_name_or_path
+            ), "Restoring state only available with W&B artifact reference"
+            self.state_artifact = self.model_name_or_path.replace(
+                "/model-", "/state-", 1
+            )
+
 
 
 @dataclass
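Note on the hunk above: `state_artifact` is derived purely by string substitution on the W&B model artifact reference, which is why restoring state requires a `/model-` reference. A minimal sketch of the naming convention (the artifact reference below is made up for illustration):

    # A model checkpoint reference maps to its training-state counterpart
    # by swapping the first "/model-" segment for "/state-".
    model_ref = "dalle-mini/dalle-mini/model-1abcd234:v12"  # hypothetical reference
    state_ref = model_ref.replace("/model-", "/state-", 1)
    assert state_ref == "dalle-mini/dalle-mini/state-1abcd234:v12"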
@@ -319,11 +337,6 @@ class TrainingArguments:
         },
     )
 
-    resume_from_checkpoint: Optional[str] = field(
-        default=None,
-        metadata={"help": "Reference to a wandb artifact for resuming training."},
-    )
-
     wandb_entity: Optional[str] = field(
         default=None,
         metadata={"help": "The wandb entity to use (for teams)."},
@@ -349,6 +362,8 @@ class TrainingArguments:
         },
     )
 
+    dp_devices: int = field(init=False)
+
     def __post_init__(self):
         assert self.optim in [
             "distributed_shampoo",
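Note on the hunk above: `dp_devices`, like `state_artifact` earlier, is declared with `field(init=False)`, so it is not a constructor argument and is expected to be filled in by `__post_init__`. A self-contained sketch of that dataclass pattern (class and field names invented here):

    from dataclasses import dataclass, field

    @dataclass
    class Example:
        model_ref: str
        state_ref: str = field(init=False)  # excluded from the generated __init__

        def __post_init__(self):
            # derived fields are computed after construction
            self.state_ref = self.model_ref.replace("/model-", "/state-", 1)

    ex = Example(model_ref="org/project/model-run:v3")
    print(ex.state_ref)  # org/project/state-run:v3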
@@ -470,62 +485,40 @@ def main():
         config=parser.parse_args(),
     )
 
-    if training_args.resume_from_checkpoint is not None:
-        if jax.process_index() == 0:
-            artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
-        else:
-            artifact = wandb.Api().artifact(training_args.resume_from_checkpoint)
-        artifact_dir = artifact.download()
+    # Set up our new model config
+    if model_args.config_name:
+        config = DalleBartConfig.from_pretrained(model_args.config_name)
+    else:
+        config = None
 
-        # load model
+    # Load or create new model
+    if model_args.model_name_or_path:
         model = DalleBart.from_pretrained(
-            artifact_dir,
+            model_args.model_name_or_path,
+            config=config,
+            seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
             abstract_init=True,
             load_on_cpu=True,
         )
+    else:
+        model = DalleBart(
+            config,
+            seed=training_args.seed_model,
+            dtype=getattr(jnp, model_args.dtype),
+            load_on_cpu=True,
+        )
 
-        # load tokenizer
+    # Load tokenizer
+    if model_args.tokenizer_name is not None:
         tokenizer = DalleBartTokenizer.from_pretrained(
-            artifact_dir,
-            use_fast=True,
+            model_args.tokenizer_name, use_fast=True
         )
-
     else:
-        # Set up our new model config
-        if model_args.config_name:
-            config = DalleBartConfig.from_pretrained(model_args.config_name)
-        else:
-            config = None
-
-        # Load or create new model
-        if model_args.model_name_or_path:
-            model = DalleBart.from_pretrained(
-                model_args.model_name_or_path,
-                config=config,
-                seed=training_args.seed_model,
-                dtype=getattr(jnp, model_args.dtype),
-                abstract_init=True,
-                load_on_cpu=True,
-            )
-        else:
-            model = DalleBart(
-                config,
-                seed=training_args.seed_model,
-                dtype=getattr(jnp, model_args.dtype),
-                load_on_cpu=True,
-            )
-
-        # Load tokenizer
-        if model_args.tokenizer_name is not None:
-            tokenizer = DalleBartTokenizer.from_pretrained(
-                model_args.tokenizer_name, use_fast=True
-            )
-        else:
-            tokenizer = DalleBartTokenizer.from_pretrained(
-                model_args.model_name_or_path,
-                use_fast=True,
-            )
+        tokenizer = DalleBartTokenizer.from_pretrained(
+            model_args.model_name_or_path,
+            use_fast=True,
+        )
 
     # get PartitionSpec for model params (required to be a dict)
     param_spec = set_partitions(model.params)
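Note on the hunk above: model and tokenizer loading no longer live inside the resume branch, and the assertion in `ModelArguments.__post_init__` suggests `model_name_or_path` may itself be a W&B artifact reference. The tokenizer falls back to the model reference when no explicit `tokenizer_name` is given; a sketch of that fallback rule (function name illustrative):

    def pick_tokenizer_source(tokenizer_name, model_name_or_path):
        # an explicit tokenizer wins; otherwise reuse the model reference
        return tokenizer_name if tokenizer_name is not None else model_name_or_path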
@@ -698,7 +691,7 @@ def main():
     devices = np.asarray(jax.devices()).reshape(*mesh_shape)
     mesh = maps.Mesh(devices, ("batch", "mp"))
 
-    #
+    # define state spec
     state_spec = TrainState(
         params=param_spec,
         opt_state=opt_state_spec,
@@ -713,7 +706,7 @@ def main():
 
     # create training state
     with maps.mesh(mesh.devices, mesh.axis_names):
-        if training_args.resume_from_checkpoint is None:
+        if not model_args.restore_state:
 
             def init_state(params):
                 return TrainState.create(
@@ -731,6 +724,13 @@ def main():
             )(model.params)
 
         else:
+            # get state files from artifact
+            if jax.process_index() == 0:
+                artifact = wandb.run.use_artifact(model_args.state_artifact)
+            else:
+                artifact = wandb.Api().artifact(model_args.state_artifact)
+            artifact_dir = artifact.download()
+
             # restore opt_state
             with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
                 opt_state = from_bytes(opt_state_shape, f.read())
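Note on the restore hunk above: only process 0 calls `wandb.run.use_artifact`, which records the artifact in the run's lineage; the other JAX processes fetch the same files through the public API so the usage is logged once. The optimizer state is then rebuilt with `flax.serialization.from_bytes` against a template pytree. A minimal sketch of that restore step (function name and arguments are illustrative):

    from pathlib import Path
    from flax.serialization import from_bytes

    def restore_opt_state(artifact_dir, opt_state_shape):
        # from_bytes deserializes the msgpack payload into the same
        # pytree structure as the opt_state_shape template
        with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
            return from_bytes(opt_state_shape, f.read())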
@@ -998,51 +998,46 @@ def main():
                     f,
                 )
 
-                [37 deleted lines of the previous combined-artifact block were not recoverable from the page]
-                    artifact.add_file(
-                        str(Path(training_args.output_dir) / "opt_state.msgpack")
-                    )
-                    artifact.add_file(
-                        str(Path(training_args.output_dir) / "training_state.json")
+            # save to W&B
+            if training_args.log_model:
+                # save some space
+                c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
+                c.cleanup(wandb.util.from_human_size("10GB"))
+
+                metadata = dict(state_dict)
+                metadata["num_params"] = num_params
+                if eval_metrics is not None:
+                    metadata["eval"] = eval_metrics
+
+                # create model artifact
+                artifact = wandb.Artifact(
+                    name=f"model-{wandb.run.id}",
+                    type="DalleBart_model",
+                    metadata=metadata,
+                )
+                for filename in [
+                    "config.json",
+                    "flax_model.msgpack",
+                    "merges.txt",
+                    "special_tokens_map.json",
+                    "tokenizer.json",
+                    "tokenizer_config.json",
+                    "vocab.json",
+                ]:
+                    artifact.add_file(f"{Path(training_args.output_dir) / filename}")
+                wandb.run.log_artifact(artifact)
+
+                # create state artifact
+                artifact_state = wandb.Artifact(
+                    name=f"state-{wandb.run.id}",
+                    type="DalleBart_state",
+                    metadata=metadata,
+                )
+                for filename in ["opt_state.msgpack", "training_state.json"]:
+                    artifact_state.add_file(
+                        f"{Path(training_args.output_dir) / filename}"
                     )
-
-                wandb.run.log_artifact(artifact)
+                wandb.run.log_artifact(artifact_state)
 
     # init variables
     last_time = time.perf_counter()
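Net effect of the final hunk: each checkpoint now logs two artifacts, `model-{run_id}` (weights, config, tokenizer files) and `state-{run_id}` (optimizer and training state), so the model can be pulled without the optimizer state. A sketch of how a later run would consume them (project and artifact names are hypothetical):

    import wandb

    run = wandb.init(project="dalle-mini")
    model_dir = run.use_artifact("org/project/model-1abcd234:v12").download()
    state_dir = run.use_artifact("org/project/state-1abcd234:v12").download()
    # model_dir: config.json, flax_model.msgpack, tokenizer files
    # state_dir: opt_state.msgpack, training_state.json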