File size: 4,694 Bytes
f21d996 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import os
import argparse
from safetensors import safe_open
from safetensors.torch import save_file
import json
from tqdm import tqdm
def get_tensor_locations(input_dir, num_shards=51):
    """Map every tensor name in the input checkpoint to the shard that holds it.

    Input shards are expected to be named
    ``model-<i>-of-<num_shards>.safetensors`` (both numbers zero-padded to
    five digits) with ``i`` running from 1 to ``num_shards``.

    Args:
        input_dir: Directory containing the input shards.
        num_shards: Number of input shards; defaults to 51, the value this
            script was written for.

    Returns:
        Dict mapping tensor name to its 1-based shard index. If a name occurs
        in more than one shard, the highest-numbered shard wins.
    """
    tensor_locations = {}
    for shard_index in tqdm(range(1, num_shards + 1), desc="Scanning input files"):
        file_path = os.path.join(
            input_dir, f"model-{shard_index:05d}-of-{num_shards:05d}.safetensors"
        )
        # Only the key list is read here; tensors are loaded later on demand.
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                tensor_locations[key] = shard_index
    return tensor_locations
def create_merge_plan(tensor_locations, layer_config):
    """Build the list of tensor copy operations that produce the merged model.

    Each source layer index named by ``layer_config`` is copied in order and
    renumbered consecutively from 0. The ranges are half-open and may overlap,
    so one source layer can be copied several times. Every copied layer gets
    its own output shard, numbered consecutively from 1. Source layers with no
    tensors are skipped without leaving gaps in either numbering. Non-layer
    tensors keep their names.

    Args:
        tensor_locations: Mapping of tensor name to 1-based input shard index,
            as returned by ``get_tensor_locations``.
        layer_config: Sequence of dicts whose ``'layer_range'`` is a
            ``[start, end)`` pair of source layer indices.

    Returns:
        List of dicts with keys ``old_key``, ``new_key``,
        ``original_file_index`` (input shard) and ``new_file_index``
        (output shard).
    """
    layer_prefix = "model.layers."
    # Preferred output shard for tensors that live outside the layer stack.
    special_weights = {
        "model.embed_tokens.weight": 1,
        "lm_head.weight": 48,
        "model.norm.weight": 48
    }

    # Group layer tensors by their layer-index string once, instead of
    # rescanning every key for every requested layer. Matching on the exact
    # string segment keeps "model.layers.1." distinct from "model.layers.10.".
    keys_by_layer = {}
    for key in tensor_locations:
        if not key.startswith(layer_prefix):
            continue
        layer_str, sep, suffix = key[len(layer_prefix):].partition(".")
        if sep:
            keys_by_layer.setdefault(layer_str, []).append((key, suffix))

    merge_plan = []
    new_layer_idx = 0
    new_file_idx = 1
    for slice_config in layer_config:
        start, end = slice_config['layer_range']
        for i in range(start, end):
            layer_keys = keys_by_layer.get(str(i))
            if not layer_keys:
                # Absent source layer: skip it without consuming a new layer
                # number or output shard.
                continue
            for old_key, suffix in layer_keys:
                merge_plan.append({
                    'old_key': old_key,
                    'new_key': f"{layer_prefix}{new_layer_idx}.{suffix}",
                    'original_file_index': tensor_locations[old_key],
                    'new_file_index': new_file_idx
                })
            new_file_idx += 1
            new_layer_idx += 1

    # Never target a shard past the last layer shard, so a short layer_config
    # cannot create a hole in the output shard numbering.
    last_layer_file = max(new_file_idx - 1, 1)
    for key, target_file in special_weights.items():
        if key not in tensor_locations:
            # The source checkpoint does not contain this tensor; planning it
            # anyway would make the merge fail when reading it.
            continue
        merge_plan.append({
            'old_key': key,
            'new_key': key,
            # Read from the shard the tensor actually lives in rather than
            # from an assumed fixed index.
            'original_file_index': tensor_locations[key],
            'new_file_index': min(target_file, last_layer_file)
        })

    # Any other non-layer tensor goes into the first output shard.
    for key, file_index in tensor_locations.items():
        if not key.startswith(layer_prefix) and key not in special_weights:
            merge_plan.append({
                'old_key': key,
                'new_key': key,
                'original_file_index': file_index,
                'new_file_index': 1
            })
    return merge_plan
def merge_layers(input_dir, output_dir, merge_plan, num_input_shards=51):
    """Materialise ``merge_plan`` as a new set of safetensors shards.

    Output shard ``k`` receives every plan entry whose ``new_file_index`` is
    ``k``. It is written as ``model-<k>-of-<N>.safetensors`` in
    ``output_dir``, where ``N`` is the largest ``new_file_index`` in the plan.
    A ``model.safetensors.index.json`` mapping each tensor name to its shard
    file is written alongside.

    Args:
        input_dir: Directory holding the input shards
            ``model-<i>-of-<num_input_shards>.safetensors``.
        output_dir: Existing directory that receives the output files.
        merge_plan: List of dicts with ``old_key``, ``new_key``,
            ``original_file_index`` and ``new_file_index`` entries, as built
            by ``create_merge_plan``.
        num_input_shards: Total number of input shards; only used to build
            input file names. Defaults to 51, the value this script was
            written for.
    """
    if not merge_plan:
        print("Merge plan is empty; nothing written.")
        return

    # Group plan entries by output shard once, preserving plan order,
    # instead of rescanning the whole plan for every output shard.
    items_by_output = {}
    for item in merge_plan:
        items_by_output.setdefault(item['new_file_index'], []).append(item)
    max_file_index = max(items_by_output)

    weight_map = {}
    total_size = 0
    with tqdm(total=len(merge_plan), desc="Merging layers") as pbar:
        for file_index in sorted(items_by_output):
            output_name = f"model-{file_index:05d}-of-{max_file_index:05d}.safetensors"
            # Within one output shard, open each source shard only once.
            items_by_source = {}
            for item in items_by_output[file_index]:
                items_by_source.setdefault(item['original_file_index'], []).append(item)
            # Only one output shard's tensors are held in memory at a time.
            output_tensors = {}
            for source_index, items in items_by_source.items():
                input_file = os.path.join(
                    input_dir,
                    f"model-{source_index:05d}-of-{num_input_shards:05d}.safetensors",
                )
                with safe_open(input_file, framework="pt", device="cpu") as f:
                    for item in items:
                        tensor = f.get_tensor(item['old_key'])
                        output_tensors[item['new_key']] = tensor
                        weight_map[item['new_key']] = output_name
                        total_size += tensor.numel() * tensor.element_size()
                        pbar.update(1)
            # Tag the shard as a PyTorch checkpoint in the safetensors header.
            save_file(
                output_tensors,
                os.path.join(output_dir, output_name),
                metadata={"format": "pt"},
            )

    # Sharded checkpoints need an index so loaders can find each tensor.
    index_path = os.path.join(output_dir, "model.safetensors.index.json")
    with open(index_path, "w", encoding="utf-8") as f:
        json.dump(
            {"metadata": {"total_size": total_size}, "weight_map": weight_map},
            f,
            indent=2,
        )
    print(f"Merged model saved to {output_dir}")
def main():
    """Parse CLI arguments, build the merge plan, and either save it or run it.

    With ``--dry-run``, the plan is printed and written to
    ``merge_plan_large.json`` in the current working directory. Otherwise the
    merged shards are written to ``output_dir``, which is created if needed.
    """
    parser = argparse.ArgumentParser(description="Merge and split Mistral model")
    parser.add_argument("input_dir", help="Directory containing input safetensors files")
    parser.add_argument("output_dir", help="Directory for output safetensors files")
    parser.add_argument("--dry-run", action="store_true", help="Perform a dry run and output merge plan")
    args = parser.parse_args()
    # Overlapping half-open windows of source layer indices, stacked in order:
    # 7 * 20 + 17 = 157 layers in the merged model.
    # NOTE(review): range(70, 87) excludes source layer 87 if the input has
    # one; confirm that is intended.
    layer_config = [
        {'layer_range': [0, 20]},
        {'layer_range': [10, 30]},
        {'layer_range': [20, 40]},
        {'layer_range': [30, 50]},
        {'layer_range': [40, 60]},
        {'layer_range': [50, 70]},
        {'layer_range': [60, 80]},
        {'layer_range': [70, 87]}
    ]
    tensor_locations = get_tensor_locations(args.input_dir)
    merge_plan = create_merge_plan(tensor_locations, layer_config)
    if args.dry_run:
        plan_file = "merge_plan_large.json"
        print("Merge plan:")
        print(json.dumps(merge_plan, indent=2))
        with open(plan_file, "w", encoding="utf-8") as f:
            json.dump(merge_plan, f, indent=2)
        # Report the file that was actually written.
        print(f"Merge plan saved to {plan_file}")
    else:
        os.makedirs(args.output_dir, exist_ok=True)
        # merge_layers reports the output location itself.
        merge_layers(args.input_dir, args.output_dir, merge_plan)


if __name__ == "__main__":
    main()
|