# WizardCoder-Guanaco-15B-V1.1 / split_pytorch.py
# Uploaded by LoupGarou ("Upload 10 files"), commit 3e32822, 791 bytes.
import os
import torch
import json
# Split a monolithic PyTorch checkpoint into one shard file per tensor plus a
# JSON index.
#
# The original version instantiated a placeholder `SomeLargeModel` only to
# read its state dict back. Splitting needs nothing but the state dict itself,
# so the checkpoint is loaded directly.
def split_checkpoint(checkpoint_path='pytorch_model.bin', output_dir='.',
                     index_name='pytorch_model.bin.index.json'):
    """Split *checkpoint_path* into one ``.bin`` shard per tensor.

    Each tensor ``key`` is saved as ``{key: tensor}`` in its own file named
    ``pytorch_model-XXXXX-of-YYYYY.bin``. Here XXXXX is the 1-based position
    of the tensor in the state dict and YYYYY is the total number of tensors,
    both zero-padded to 5 digits. An index mapping each key to its shard file
    is written as JSON next to the shards.

    Args:
        checkpoint_path: Path to a file produced by ``torch.save`` that holds
            a dict of parameter name -> tensor.
        output_dir: Directory that receives the shards and the index. It is
            created if missing.
        index_name: File name of the JSON index inside *output_dir*. The
            default uses the ``.index.json`` suffix that Hugging Face
            ``transformers`` looks for when loading sharded PyTorch weights.

    Returns:
        The index that was written, in the form
        ``{"metadata": {"total_size": <bytes>}, "weight_map": {key: file}}``.

    Raises:
        TypeError: If the loaded checkpoint is not a dict.
        ValueError: If the state dict is empty.
    """
    # NOTE(review): torch.load unpickles the file. Only run this on
    # checkpoints you trust. map_location='cpu' lets weights that were saved
    # from a GPU load on a machine without that device.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    if not isinstance(state_dict, dict):
        raise TypeError(
            f'expected a state dict in {checkpoint_path!r}, '
            f'got {type(state_dict).__name__}')
    total_files = len(state_dict)
    if total_files == 0:
        raise ValueError(f'state dict in {checkpoint_path!r} is empty')

    os.makedirs(output_dir, exist_ok=True)

    index = {"metadata": {"total_size": 0}, "weight_map": {}}
    for i, (key, tensor) in enumerate(state_dict.items(), start=1):
        chunk_file = f'pytorch_model-{i:05d}-of-{total_files:05d}.bin'
        # torch.save writes a tensor's entire underlying storage. clone()
        # keeps a view into a larger buffer from bloating its shard.
        torch.save({key: tensor.clone()}, os.path.join(output_dir, chunk_file))
        # The index records bare file names: shards live beside the index.
        index["weight_map"][key] = chunk_file
        index["metadata"]["total_size"] += tensor.nelement() * tensor.element_size()

    # NOTE(review): if the unsplit pytorch_model.bin also sits in output_dir,
    # a loader may pick it up instead of the index. Consider moving it away.
    with open(os.path.join(output_dir, index_name), 'w', encoding='utf-8') as f:
        json.dump(index, f, indent=2)
    return index


if __name__ == "__main__":
    split_checkpoint()