fix: define function before it is used
seq2seq/run_seq2seq_flax.py  (+32 -31)
@@ -779,6 +779,38 @@ def main():
 
         return eval_metrics
 
+    def run_save_model(step, epoch, eval_metrics=None):
+        if jax.process_index() == 0:
+            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+
+            # save model locally
+            model.save_pretrained(
+                training_args.output_dir,
+                params=params,
+            )
+
+            # save to W&B
+            if data_args.log_model:
+                metadata = {'epoch': epoch+1}
+                if eval_metrics is not None:
+                    metadata['eval/loss'] = eval_metrics['loss']
+                artifact = wandb.Artifact(
+                    name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
+                )
+                artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
+                artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
+                wandb.run.log_artifact(artifact)
+
+            # save to the hub
+            if training_args.push_to_hub:
+                model.save_pretrained(
+                    training_args.output_dir,
+                    params=params,
+                    push_to_hub=training_args.push_to_hub,
+                    commit_message=f"Saving weights and logs of epoch {epoch+1}",
+                    temp_dir=True # avoid issues with being in a repository
+                )
+
     for epoch in epochs:
         # ======================== Training ================================
         train_start = time.time()
@@ -820,37 +852,6 @@ def main():
         # save checkpoint after each epoch and push checkpoint to the hub
         run_save_model(global_step, epoch, eval_metrics)
 
-    def run_save_model(step, epoch, eval_metrics=None):
-        if jax.process_index() == 0:
-            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-
-            # save model locally
-            model.save_pretrained(
-                training_args.output_dir,
-                params=params,
-            )
-
-            # save to W&B
-            if data_args.log_model:
-                metadata = {'epoch': epoch+1}
-                if eval_metrics is not None:
-                    metadata['eval/loss'] = eval_metrics['loss']
-                artifact = wandb.Artifact(
-                    name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
-                )
-                artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
-                artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
-                wandb.run.log_artifact(artifact)
-
-            # save to the hub
-            if training_args.push_to_hub:
-                model.save_pretrained(
-                    training_args.output_dir,
-                    params=params,
-                    push_to_hub=training_args.push_to_hub,
-                    commit_message=f"Saving weights and logs of epoch {epoch+1}",
-                    temp_dir=True # avoid issues with being in a repository
-                )
 
     # ======================== Prediction loop ==============================
     if training_args.do_predict: