A checkpoint like this can be fully reloaded using the [`~PreTrainedModel.from_pretrained`] method:
```py
import tempfile

from transformers import AutoModel

# `model` is the model saved with sharding earlier in the workflow
with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, max_shard_size="200MB")
    new_model = AutoModel.from_pretrained(tmp_dir)
```
The main advantage of doing this for big models is that during step 2 of the workflow shown above, each shard of the checkpoint is loaded after the previous one, capping the memory usage in RAM to the model size plus the size of the biggest shard.
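To make the shard-by-shard loading concrete, the sketch below saves a model with a small `max_shard_size` and inspects the resulting directory. It is a minimal illustration, assuming `bert-base-cased` purely as an example model and a `transformers` version that writes a `*.index.json` file alongside the shards; the index's `weight_map` is what lets [`~PreTrainedModel.from_pretrained`] locate and load one shard at a time:

```py
# Minimal sketch of the on-disk layout of a sharded checkpoint.
# Assumptions: "bert-base-cased" is only an example model, and the
# installed transformers version writes a *.index.json file next to
# the shards (both the safetensors and the legacy .bin formats do).
import json
import os
import tempfile

from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-cased")

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, max_shard_size="200MB")

    # The directory holds the shards plus an index file.
    print(sorted(os.listdir(tmp_dir)))

    # The index maps every parameter name to the shard that stores it,
    # which is what allows shards to be loaded one after the other.
    index_name = next(f for f in os.listdir(tmp_dir) if f.endswith(".index.json"))
    with open(os.path.join(tmp_dir, index_name)) as f:
        index = json.load(f)
    print(index["metadata"])                      # e.g. {"total_size": ...}
    print(list(index["weight_map"].items())[:3])  # parameter name -> shard file
```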