boris committed
Commit fa5b058
Parent: 605df32

feat(train): split artifact into model/state

Files changed (1):
  tools/train/train.py  +92 -97
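
The commit replaces the single checkpoint artifact with two W&B artifacts per run: a model artifact holding the weights, config, and tokenizer files, and a state artifact holding the optimizer and training state. A minimal sketch of the resulting split (the run id is a hypothetical placeholder; file names and artifact types are taken from the diff below):

# Sketch only: contents of the two artifacts logged per run after this change.
run_id = "1a2b3c4d"  # hypothetical W&B run id
model_artifact = {
    "name": f"model-{run_id}",  # type "DalleBart_model"
    "files": [
        "config.json",
        "flax_model.msgpack",
        "merges.txt",
        "special_tokens_map.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "vocab.json",
    ],
}
state_artifact = {
    "name": f"state-{run_id}",  # type "DalleBart_state"
    "files": ["opt_state.msgpack", "training_state.json"],
}
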
tools/train/train.py CHANGED
@@ -88,6 +88,24 @@ class ModelArguments:
             "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`."
         },
     )
+    restore_state: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Restore optimizer and training state associated with a wandb checkpoint."
+        },
+    )
+
+    state_artifact: str = field(init=False)
+
+    def __post_init__(self):
+        if self.restore_state:
+            assert (
+                "/model-" in self.model_name_or_path
+            ), "Restoring state only available with W&B artifact reference"
+            self.state_artifact = self.model_name_or_path.replace(
+                "/model-", "/state-", 1
+            )
+            raise ValueError("Need a dataset repository or path.")
 
 
 @dataclass
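
When `restore_state` is set, the state artifact reference is derived from the model artifact reference by swapping the `/model-` segment for `/state-`, as in `__post_init__` above. A minimal sketch of that mapping, using a hypothetical artifact reference:

# Hypothetical W&B artifact reference; mirrors the replace() call above.
model_name_or_path = "my-entity/my-project/model-1a2b3c4d:latest"
state_artifact = model_name_or_path.replace("/model-", "/state-", 1)
assert state_artifact == "my-entity/my-project/state-1a2b3c4d:latest"
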
@@ -319,11 +337,6 @@ class TrainingArguments:
         },
     )
 
-    resume_from_checkpoint: Optional[str] = field(
-        default=None,
-        metadata={"help": "Reference to a wandb artifact for resuming training."},
-    )
-
     wandb_entity: Optional[str] = field(
         default=None,
         metadata={"help": "The wandb entity to use (for teams)."},
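
The next hunk declares `dp_devices: int = field(init=False)`, i.e. not a command-line argument but a value filled in by `__post_init__`. The computation itself is outside this diff; a plausible sketch, assuming a model-parallel device count is available (the `mp_devices` name here is an assumption, not shown in the hunk):

import jax

# Assumption: mp_devices comes from another TrainingArguments field.
mp_devices = 1
assert jax.device_count() % mp_devices == 0, "devices must split evenly"
dp_devices = jax.device_count() // mp_devices  # data-parallel device count
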
@@ -349,6 +362,8 @@ class TrainingArguments:
         },
     )
 
+    dp_devices: int = field(init=False)
+
     def __post_init__(self):
         assert self.optim in [
             "distributed_shampoo",
@@ -470,62 +485,40 @@ def main():
         config=parser.parse_args(),
     )
 
-    if training_args.resume_from_checkpoint is not None:
-        if jax.process_index() == 0:
-            artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
-        else:
-            artifact = wandb.Api().artifact(training_args.resume_from_checkpoint)
-        artifact_dir = artifact.download()
+    # Set up our new model config
+    if model_args.config_name:
+        config = DalleBartConfig.from_pretrained(model_args.config_name)
+    else:
+        config = None
 
-        # load model
+    # Load or create new model
+    if model_args.model_name_or_path:
         model = DalleBart.from_pretrained(
-            artifact_dir,
+            model_args.model_name_or_path,
+            config=config,
+            seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
             abstract_init=True,
             load_on_cpu=True,
         )
+    else:
+        model = DalleBart(
+            config,
+            seed=training_args.seed_model,
+            dtype=getattr(jnp, model_args.dtype),
+            load_on_cpu=True,
+        )
 
-        # load tokenizer
+    # Load tokenizer
+    if model_args.tokenizer_name is not None:
         tokenizer = DalleBartTokenizer.from_pretrained(
-            artifact_dir,
-            use_fast=True,
+            model_args.tokenizer_name, use_fast=True
         )
-
     else:
-        # Set up our new model config
-        if model_args.config_name:
-            config = DalleBartConfig.from_pretrained(model_args.config_name)
-        else:
-            config = None
-
-        # Load or create new model
-        if model_args.model_name_or_path:
-            model = DalleBart.from_pretrained(
-                model_args.model_name_or_path,
-                config=config,
-                seed=training_args.seed_model,
-                dtype=getattr(jnp, model_args.dtype),
-                abstract_init=True,
-                load_on_cpu=True,
-            )
-        else:
-            model = DalleBart(
-                config,
-                seed=training_args.seed_model,
-                dtype=getattr(jnp, model_args.dtype),
-                load_on_cpu=True,
-            )
-
-        # Load tokenizer
-        if model_args.tokenizer_name is not None:
-            tokenizer = DalleBartTokenizer.from_pretrained(
-                model_args.tokenizer_name, use_fast=True
-            )
-        else:
-            tokenizer = DalleBartTokenizer.from_pretrained(
-                model_args.model_name_or_path,
-                use_fast=True,
-            )
+        tokenizer = DalleBartTokenizer.from_pretrained(
+            model_args.model_name_or_path,
+            use_fast=True,
+        )
 
     # get PartitionSpec for model params (required to be a dict)
     param_spec = set_partitions(model.params)
@@ -698,7 +691,7 @@ def main():
     devices = np.asarray(jax.devices()).reshape(*mesh_shape)
     mesh = maps.Mesh(devices, ("batch", "mp"))
 
-    # Create state spec
+    # define state spec
     state_spec = TrainState(
         params=param_spec,
         opt_state=opt_state_spec,
@@ -713,7 +706,7 @@ def main():
 
     # create training state
     with maps.mesh(mesh.devices, mesh.axis_names):
-        if training_args.resume_from_checkpoint is None:
+        if not model_args.restore_state:
 
             def init_state(params):
                 return TrainState.create(
@@ -731,6 +724,13 @@ def main():
             )(model.params)
 
         else:
+            # get state files from artifact
+            if jax.process_index() == 0:
+                artifact = wandb.run.use_artifact(model_args.state_artifact)
+            else:
+                artifact = wandb.Api().artifact(model_args.state_artifact)
+            artifact_dir = artifact.download()
+
             # restore opt_state
             with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
                 opt_state = from_bytes(opt_state_shape, f.read())
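
`opt_state.msgpack` is read back with flax's msgpack serialization; `from_bytes` needs a target pytree with the right structure (`opt_state_shape`, defined earlier in the script). A small self-contained round-trip sketch with a toy pytree, independent of the training code:

import numpy as np
from flax.serialization import from_bytes, to_bytes

# Toy stand-in for the optimizer state pytree.
opt_state = {"mu": np.zeros(3), "nu": np.ones(3)}
blob = to_bytes(opt_state)  # what gets written to opt_state.msgpack
restored = from_bytes({"mu": np.empty(3), "nu": np.empty(3)}, blob)
assert np.allclose(restored["nu"], 1.0)
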
@@ -998,51 +998,46 @@ def main():
                     f,
                 )
 
-            if jax.process_index() == 0:
-                # save to W&B
-                if training_args.log_model:
-                    # save some space
-                    c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
-                    c.cleanup(wandb.util.from_human_size("10GB"))
-
-                    metadata = dict(state_dict)
-                    metadata["num_params"] = num_params
-                    if eval_metrics is not None:
-                        metadata["eval"] = eval_metrics
-                    artifact = wandb.Artifact(
-                        name=f"model-{wandb.run.id}",
-                        type="bart_model",
-                        metadata=metadata,
-                    )
-                    artifact.add_file(
-                        str(Path(training_args.output_dir) / "flax_model.msgpack")
-                    )
-                    artifact.add_file(
-                        str(Path(training_args.output_dir) / "config.json")
-                    )
-                    artifact.add_file(
-                        str(Path(training_args.output_dir) / "tokenizer.json")
-                    )
-                    artifact.add_file(
-                        str(Path(training_args.output_dir) / "tokenizer_config.json")
-                    )
-                    artifact.add_file(
-                        str(Path(training_args.output_dir) / "vocab.json")
-                    )
-                    artifact.add_file(
-                        str(Path(training_args.output_dir) / "merges.txt")
-                    )
-                    artifact.add_file(
-                        str(Path(training_args.output_dir) / "special_tokens_map.json")
-                    )
-                    artifact.add_file(
-                        str(Path(training_args.output_dir) / "opt_state.msgpack")
-                    )
-                    artifact.add_file(
-                        str(Path(training_args.output_dir) / "training_state.json")
+            # save to W&B
+            if training_args.log_model:
+                # save some space
+                c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
+                c.cleanup(wandb.util.from_human_size("10GB"))
+
+                metadata = dict(state_dict)
+                metadata["num_params"] = num_params
+                if eval_metrics is not None:
+                    metadata["eval"] = eval_metrics
+
+                # create model artifact
+                artifact = wandb.Artifact(
+                    name=f"model-{wandb.run.id}",
+                    type="DalleBart_model",
+                    metadata=metadata,
+                )
+                for filename in [
+                    "config.json",
+                    "flax_model.msgpack",
+                    "merges.txt",
+                    "special_tokens_map.json",
+                    "tokenizer.json",
+                    "tokenizer_config.json",
+                    "vocab.json",
+                ]:
+                    artifact.add_file(f"{Path(training_args.output_dir) / filename}")
+                wandb.run.log_artifact(artifact)
+
+                # create state artifact
+                artifact_state = wandb.Artifact(
+                    name=f"state-{wandb.run.id}",
+                    type="DalleBart_state",
+                    metadata=metadata,
+                )
+                for filename in ["opt_state.msgpack", "training_state.json"]:
+                    artifact_state.add_file(
+                        f"{Path(training_args.output_dir) / filename}"
                     )
-
-                    wandb.run.log_artifact(artifact)
+                wandb.run.log_artifact(artifact_state)
 
     # init variables
     last_time = time.perf_counter()
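
Resuming a run therefore means pointing `model_name_or_path` at a model artifact and setting `restore_state`; the matching state artifact is resolved automatically (see `__post_init__` in the first hunk). A hedged sketch of fetching both artifacts by hand through the public wandb API, with placeholder entity, project, and run id:

import wandb

model_ref = "my-entity/my-project/model-1a2b3c4d:latest"  # placeholder reference
state_ref = model_ref.replace("/model-", "/state-", 1)

api = wandb.Api()
model_dir = api.artifact(model_ref).download()  # config, weights, tokenizer files
state_dir = api.artifact(state_ref).download()  # opt_state.msgpack, training_state.json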
 