A checkpoint like this can be fully reloaded using the [`~PreTrainedModel.from_pretrained`] method:
```py
import tempfile

from transformers import AutoModel

with tempfile.TemporaryDirectory() as tmp_dir:
    # Save the checkpoint in shards of at most 200MB each
    model.save_pretrained(tmp_dir, max_shard_size="200MB")
    new_model = AutoModel.from_pretrained(tmp_dir)
```
The main advantage of doing this for big models is that during step 2 of the workflow shown above, each shard of the checkpoint is loaded after the previous one, capping the peak RAM usage at the model size plus the size of the biggest shard.
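If you want to see how a sharded checkpoint is organized, you can inspect the index file that `save_pretrained` writes next to the shard files. Below is a minimal sketch, assuming the checkpoint was saved to a persistent directory named `sharded_ckpt` (a hypothetical path) with safetensors serialization; checkpoints saved in the older PyTorch format use `pytorch_model.bin.index.json` instead:

```py
import json
import os

# "sharded_ckpt" is a hypothetical directory, assumed to have been created with
# model.save_pretrained("sharded_ckpt", max_shard_size="200MB")
index_path = os.path.join("sharded_ckpt", "model.safetensors.index.json")
with open(index_path) as f:
    index = json.load(f)

# "metadata" records the total checkpoint size; "weight_map" maps each
# parameter name to the shard file that stores it.
print(index["metadata"])
print(sorted(set(index["weight_map"].values())))  # the list of shard files
```

Because loading only ever needs the model plus one shard in memory at a time, choosing a smaller `max_shard_size` lowers the peak RAM needed to restore the checkpoint, at the cost of more files on disk.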