|
You can resume training with the sharded state dicts with the [`~accelerate.Accelerator.load_state`] method.
|
|
|
```py
# directory containing checkpoints
accelerator.load_state("ckpt")
```
|
|
|
However, when training ends, you want to save the full state dict because the sharded state dict is only compatible with FSDP.
|
|
|
```py
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

trainer.save_model(script_args.output_dir)
```
|
|
|
## TPU
|
PyTorch XLA supports FSDP training for TPUs, and it can be enabled by modifying the FSDP configuration file generated by `accelerate config`.
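As a rough sketch, the XLA-related entries in the generated config file look something like the snippet below. The exact key names (`xla`, `xla_fsdp_grad_ckpt`, `xla_fsdp_settings`) come from Accelerate's FSDP plugin and may vary between versions, so verify them against the file `accelerate config` actually produces.

```yaml
# Illustrative excerpt of an accelerate config file with XLA FSDP enabled
# (keys assumed from Accelerate's FSDP plugin; check your generated file)
fsdp_config:
  xla: true                 # hand FSDP sharding off to PyTorch/XLA
  xla_fsdp_grad_ckpt: true  # enable gradient checkpointing on TPU
  xla_fsdp_settings: {}     # extra kwargs forwarded to the XLA FSDP wrapper
```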