A checkpoint like this can be fully reloaded using the [~PreTrainedModel.from_pretrained] method:
with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, max_shard_size="200MB")
    new_model = AutoModel.from_pretrained(tmp_dir)
The main advantage of sharding for big models is that during step 2 of the workflow shown above, the shards are loaded one after another, which caps RAM usage at the model size plus the size of the biggest shard.
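To make the memory bound concrete, here is a rough back-of-the-envelope sketch. The shard sizes are made up for illustration, `peak_ram_bytes` is a hypothetical helper and not part of Transformers, and the unsharded figure assumes the whole checkpoint has to sit in RAM next to the model while it is loaded.

```python
def peak_ram_bytes(shard_sizes):
    """Rough upper bound: the full model plus the largest shard being read."""
    return sum(shard_sizes) + max(shard_sizes)


MB = 1024 ** 2

# Hypothetical shard sizes in bytes (illustrative values only).
shard_sizes = [200 * MB, 200 * MB, 40 * MB]

# Sharded loading: model size + biggest shard.
sharded_peak = peak_ram_bytes(shard_sizes)

# Assumption: an unsharded checkpoint is held in full alongside the model.
unsharded_peak = 2 * sum(shard_sizes)

print(f"sharded:   ~{sharded_peak // MB} MB")
print(f"unsharded: ~{unsharded_peak // MB} MB")
```

The gap between the two figures grows with the model size, since the extra cost of sharded loading is bounded by `max_shard_size` rather than by the size of the whole checkpoint.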
Behind the scenes, the index file is used to determine which keys are in the checkpoint, and where the corresponding weights are stored.
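As a minimal sketch of how such an index can be read, the snippet below groups parameter keys by the shard that stores them. It assumes the index is a JSON file with a `weight_map` entry mapping each key to a shard file name; the file name, the parameter names, and the `keys_by_shard` helper are all made up for illustration, so check the index your own checkpoint produces.

```python
import json
import os
import tempfile
from collections import defaultdict


def keys_by_shard(index_path):
    """Group checkpoint keys by the shard file listed for them in the index."""
    with open(index_path) as f:
        index = json.load(f)
    grouped = defaultdict(list)
    for key, shard_file in index["weight_map"].items():
        grouped[shard_file].append(key)
    return dict(grouped)


# Hand-written index for illustration; real indexes are written by save_pretrained.
example_index = {
    "metadata": {"total_size": 1234},
    "weight_map": {
        "embeddings.weight": "model-00001-of-00002.safetensors",
        "encoder.layer.0.weight": "model-00001-of-00002.safetensors",
        "pooler.weight": "model-00002-of-00002.safetensors",
    },
}

with tempfile.TemporaryDirectory() as tmp_dir:
    index_path = os.path.join(tmp_dir, "model.safetensors.index.json")
    with open(index_path, "w") as f:
        json.dump(example_index, f)
    grouped = keys_by_shard(index_path)

for shard_file, keys in grouped.items():
    print(shard_file, keys)
```

Grouping by shard file mirrors the loading order described above: each shard can be opened once, its keys copied into the model, and the shard released before the next one is read.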