|
import os |
|
import torch |
|
import json |
|
|
|
|
|
# Re-shard a single-file checkpoint into one file per tensor and write an
# index that maps every parameter name to the file holding it.
#
# Inputs : 'pytorch_model.bin' in the current working directory.
# Outputs: 'pytorch_model-XXXXX-of-YYYYY.bin' (one per state-dict entry) and
#          'pytorch_model.bin.index', all in the current working directory.

# NOTE(review): SomeLargeModel is not defined or imported in this file; it
# must be provided by the surrounding project before this script can run.
model = SomeLargeModel('/mnt/e/ai_cache/output/wizardcoder_mmlu_2/merged')

# Load onto CPU so the script does not depend on the device the checkpoint
# was saved from; load_state_dict copies the values into the model's own
# parameters, so the resulting weights are the same.
model.load_state_dict(torch.load('pytorch_model.bin', map_location='cpu'))

state_dict = model.state_dict()

# "total_size" accumulates the byte size of every tensor written below.
index = {"metadata": {"total_size": 0}, "weight_map": {}}

# One output file per state-dict entry, so the file count equals the number
# of entries.
total_files = len(state_dict)

for shard_number, (key, tensor) in enumerate(state_dict.items(), start=1):
    # Both counters are zero-padded to at least five digits.
    chunk_file = f'pytorch_model-{shard_number:05d}-of-{total_files:05d}.bin'

    # Each shard is itself a tiny state dict holding exactly one tensor.
    torch.save({key: tensor}, chunk_file)

    index["weight_map"][key] = chunk_file
    index["metadata"]["total_size"] += tensor.nelement() * tensor.element_size()

# NOTE(review): Hugging Face loaders look for 'pytorch_model.bin.index.json';
# the name below is kept as-is to avoid changing what this script emits --
# confirm which name downstream consumers expect.
with open('pytorch_model.bin.index', 'w') as f:
    json.dump(index, f)
|
|