We can load that index like any JSON file and get a dictionary:
import json
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp_dir:
    # Saving with a small max_shard_size splits the weights and writes the index file.
    model.save_pretrained(tmp_dir, max_shard_size="200MB")
    with open(os.path.join(tmp_dir, "pytorch_model.bin.index.json"), "r") as f:
        index = json.load(f)

print(index.keys())
dict_keys(['metadata', 'weight_map'])
The metadata just consists of the total size of the model for now.
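As a minimal sketch of how the two entries might be read, the snippet below parses a small hand-written index rather than a real checkpoint. It assumes the metadata stores the size under a key named total_size and that weight_map maps each parameter name to the shard file holding it. The parameter names, file names, size value, and the group_by_shard helper are all hypothetical and only serve the illustration.

```python
import json
from collections import defaultdict

# Hypothetical index contents, written by hand for illustration only.
# Assumes "metadata" holds a "total_size" entry and "weight_map" maps
# parameter names to shard file names.
sample_index_json = """
{
  "metadata": {"total_size": 1000},
  "weight_map": {
    "embeddings.weight": "pytorch_model-00001-of-00002.bin",
    "encoder.layer.0.weight": "pytorch_model-00001-of-00002.bin",
    "classifier.weight": "pytorch_model-00002-of-00002.bin"
  }
}
"""

index = json.loads(sample_index_json)

# Total size of the model as recorded in the metadata.
total_size = index["metadata"]["total_size"]
print("total size:", total_size)


def group_by_shard(weight_map):
    """Invert the weight map: shard file name -> list of parameter names."""
    shards = defaultdict(list)
    for name, shard_file in weight_map.items():
        shards[shard_file].append(name)
    return dict(shards)


shards = group_by_shard(index["weight_map"])
for shard_file, names in sorted(shards.items()):
    print(shard_file, names)
```

Grouping the map by file makes it easy to see which shard must be opened to find a given parameter, without loading any weights.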