You can resume training from the sharded state dicts with the [`~accelerate.Accelerator.load_state`] method.
```py
# directory containing checkpoints
accelerator.load_state("ckpt")
```
However, when training ends, you want to save the full state dict because a sharded state dict is only compatible with FSDP.
```py
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model(script_args.output_dir)
```
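Because the saved checkpoint is a full (non-sharded) state dict, it can be reloaded without any FSDP setup through the standard `from_pretrained` API. A minimal sketch, assuming a causal language model and the same `script_args.output_dir` used above:

```py
from transformers import AutoModelForCausalLM

# load the consolidated checkpoint; no FSDP or distributed setup is required here
model = AutoModelForCausalLM.from_pretrained(script_args.output_dir)
```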
## TPU
PyTorch XLA supports FSDP training for TPUs, and it can be enabled by modifying the FSDP configuration file generated by `accelerate config`.
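For example, the XLA-related keys can be added to the generated configuration file. A minimal sketch; the `xla`, `xla_fsdp_settings`, and `xla_fsdp_grad_ckpt` key names are assumptions not shown in this section, so verify them against your version of the docs:

```yaml
xla: True                 # assumed key: enables FSDP via PyTorch/XLA
xla_fsdp_settings: {}     # assumed key: dict of XLA-specific FSDP parameters
xla_fsdp_grad_ckpt: True  # assumed key: apply gradient checkpointing
```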