We can load that index like any JSON file and get a dictionary:
import json
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp_dir:
    # Save the model in shards of at most 200MB, then read the generated index file.
    model.save_pretrained(tmp_dir, max_shard_size="200MB")
    with open(os.path.join(tmp_dir, "pytorch_model.bin.index.json"), "r") as f:
        index = json.load(f)

print(index.keys())
dict_keys(['metadata', 'weight_map'])
The metadata just consists of the total size of the model for now.
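To make the two top-level keys more concrete, here is a minimal sketch that reads a hand-written index with the same layout, so it runs without a model. It rests on assumptions not stated above: that the size is stored under a `total_size` key as a byte count, and that `weight_map` maps each parameter name to the shard file holding it. The parameter names, shard file names, sizes, and the `summarize_index` helper are all hypothetical.

```python
import json
import os
import tempfile
from collections import defaultdict

# Hypothetical index contents, written by hand for illustration only.
# Assumed layout: "metadata" holds "total_size"; "weight_map" maps
# parameter names to shard file names.
sample_index = {
    "metadata": {"total_size": 1200},
    "weight_map": {
        "embeddings.weight": "pytorch_model-00001-of-00002.bin",
        "encoder.layer.0.weight": "pytorch_model-00001-of-00002.bin",
        "encoder.layer.1.weight": "pytorch_model-00002-of-00002.bin",
    },
}


def summarize_index(index):
    """Return the total size and the parameter names grouped by shard file."""
    shards = defaultdict(list)
    for param_name, shard_file in index["weight_map"].items():
        shards[shard_file].append(param_name)
    return index["metadata"]["total_size"], dict(shards)


with tempfile.TemporaryDirectory() as tmp_dir:
    index_path = os.path.join(tmp_dir, "pytorch_model.bin.index.json")
    # Write the sample index to disk, then load it back like any JSON file.
    with open(index_path, "w") as f:
        json.dump(sample_index, f)
    with open(index_path, "r") as f:
        index = json.load(f)

total_size, shards = summarize_index(index)
print(f"total size: {total_size} bytes")
for shard_file, names in sorted(shards.items()):
    print(f"{shard_file}: {len(names)} parameters")
```

With a real checkpoint, the same helper could be applied to the `index` dictionary loaded earlier, provided the assumed key names match the file on disk.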