Spaces:
Running
Running
Merge branch 'add-tokenizer-save' into feat-model
Browse files
Former-commit-id: 2cfaef4a020f43332a8f33b6a9bd8221ec9fae34
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -818,13 +818,16 @@ def main():
|
|
818 |
params=params,
|
819 |
)
|
820 |
|
|
|
|
|
|
|
821 |
# save state
|
822 |
state = unreplicate(state)
|
823 |
with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
|
824 |
f.write(to_bytes(state.opt_state))
|
825 |
with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
|
826 |
json.dump({'step': state.step.item()}, f)
|
827 |
-
|
828 |
# save to W&B
|
829 |
if data_args.log_model:
|
830 |
metadata = {'step': step, 'epoch': epoch}
|
@@ -834,6 +837,11 @@ def main():
|
|
834 |
name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
|
835 |
)
|
836 |
artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
|
|
|
|
|
|
|
|
|
|
|
837 |
artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
|
838 |
artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
|
839 |
artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
|
|
|
818 |
params=params,
|
819 |
)
|
820 |
|
821 |
+
# save tokenizer
|
822 |
+
tokenizer.save_pretrained(training_args.output_dir)
|
823 |
+
|
824 |
# save state
|
825 |
state = unreplicate(state)
|
826 |
with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
|
827 |
f.write(to_bytes(state.opt_state))
|
828 |
with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
|
829 |
json.dump({'step': state.step.item()}, f)
|
830 |
+
|
831 |
# save to W&B
|
832 |
if data_args.log_model:
|
833 |
metadata = {'step': step, 'epoch': epoch}
|
|
|
837 |
name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
|
838 |
)
|
839 |
artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
|
840 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'tokenizer_config.json'))
|
841 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'special_tokens_map.json'))
|
842 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'vocab.json'))
|
843 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'added_tokens.json'))
|
844 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'merges.txt'))
|
845 |
artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
|
846 |
artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
|
847 |
artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
|