File size: 4,694 Bytes
f21d996
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import argparse
import json
import os
from collections import defaultdict

from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm

def get_tensor_locations(input_dir, num_shards=51):
    """Map every tensor name in a sharded safetensors checkpoint to its shard number.

    Shards are expected to be named
    ``model-{i:05d}-of-{num_shards:05d}.safetensors`` for ``i`` in
    ``1..num_shards`` inside ``input_dir``.

    Args:
        input_dir: Directory containing the input shards.
        num_shards: Number of shards in the checkpoint. Defaults to 51, the
            layout this script was originally written for.

    Returns:
        dict mapping tensor key -> 1-based shard index. If the same key occurs
        in more than one shard, the highest-numbered shard wins.
    """
    tensor_locations = {}
    for shard_idx in tqdm(range(1, num_shards + 1), desc="Scanning input files"):
        file_path = os.path.join(
            input_dir, f"model-{shard_idx:05d}-of-{num_shards:05d}.safetensors"
        )
        # Only the header/key list is read here; tensors are loaded later.
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                tensor_locations[key] = shard_idx
    return tensor_locations

def create_merge_plan(tensor_locations, layer_config):
    """Build the list of tensor copy operations for a layer-stacking merge.

    Each entry of ``layer_config`` is a dict whose ``'layer_range'`` is a
    ``[start, end)`` pair of source layer indices (``end`` exclusive). Slices
    are stacked in order and renumbered consecutively from 0, so overlapping
    ranges duplicate source layers. Every emitted layer gets its own output
    shard, numbered consecutively from 1. Requested layers that do not exist in
    ``tensor_locations`` are skipped without consuming an output layer index or
    an output shard, so the output layer numbering stays contiguous.

    Args:
        tensor_locations: mapping of tensor key -> source shard index, as
            returned by ``get_tensor_locations``.
        layer_config: list of slice dicts as described above.

    Returns:
        list of dicts with keys ``old_key``, ``new_key``,
        ``original_file_index`` and ``new_file_index``, ordered as: layer
        tensors, then the special (embedding / final norm / lm_head) weights
        that exist in the checkpoint, then every other non-layer tensor.
    """
    layer_prefix = "model.layers."

    # Index layer tensors by their source layer number, i.e. the text between
    # "model.layers." and the next "." -- so each requested layer becomes a dict
    # lookup instead of a scan over every key. Insertion order is preserved.
    keys_by_layer = {}
    for key in tensor_locations:
        if key.startswith(layer_prefix):
            layer_num, sep, _ = key[len(layer_prefix):].partition(".")
            if sep:
                keys_by_layer.setdefault(layer_num, []).append(key)

    # Destination output shard for weights that live outside the layer stack.
    # NOTE(review): 48 is a hard-coded output shard index; confirm it is the
    # intended home for lm_head / final norm for the chosen layer_config.
    special_weights = {
        "model.embed_tokens.weight": 1,
        "lm_head.weight": 48,
        "model.norm.weight": 48,
    }

    merge_plan = []
    new_layer_idx = 0
    new_file_idx = 1

    for slice_config in layer_config:
        start, end = slice_config['layer_range']
        for i in range(start, end):
            source_keys = keys_by_layer.get(str(i))
            if not source_keys:
                # Layer absent from the checkpoint: skip it entirely so the
                # output layers stay contiguous (no hole in the numbering).
                continue
            old_prefix = f"{layer_prefix}{i}."
            new_prefix = f"{layer_prefix}{new_layer_idx}."
            for key in source_keys:
                merge_plan.append({
                    'old_key': key,
                    # Rewrite only the leading layer prefix, not any later
                    # occurrence of the same text inside the key.
                    'new_key': new_prefix + key[len(old_prefix):],
                    'original_file_index': tensor_locations[key],
                    'new_file_index': new_file_idx,
                })
            new_file_idx += 1
            new_layer_idx += 1

    # Special weights: read from wherever the checkpoint actually stores them,
    # written to their fixed destination shard.
    for key, dest_file_index in special_weights.items():
        source_file_index = tensor_locations.get(key)
        if source_file_index is None:
            # Not present in this checkpoint; emitting it would make the later
            # get_tensor() lookup fail.
            continue
        merge_plan.append({
            'old_key': key,
            'new_key': key,
            'original_file_index': source_file_index,
            'new_file_index': dest_file_index,
        })

    # Every other non-layer tensor goes to the first output shard.
    for key, file_index in tensor_locations.items():
        if not key.startswith(layer_prefix) and key not in special_weights:
            merge_plan.append({
                'old_key': key,
                'new_key': key,
                'original_file_index': file_index,
                'new_file_index': 1,
            })

    return merge_plan

def merge_layers(input_dir, output_dir, merge_plan, num_shards=51):
    """Execute a merge plan by copying tensors from input shards into new output shards.

    For each output index ``k`` in ``1..max(new_file_index)``, every plan entry
    with ``new_file_index == k`` is read (by ``old_key``) from
    ``model-{original_file_index:05d}-of-{num_shards:05d}.safetensors`` in
    ``input_dir``. It is then stored under its ``new_key`` in
    ``model-{k:05d}-of-{max:05d}.safetensors`` in ``output_dir``. Output
    indices with no entries produce no file, so the "-of-" total can exceed
    the number of files actually written. This function does not write a
    ``model.safetensors.index.json``.

    Args:
        input_dir: directory holding the source shards.
        output_dir: directory to write output shards into (the caller is
            expected to have created it).
        merge_plan: list of dicts with ``old_key``, ``new_key``,
            ``original_file_index`` and ``new_file_index``.
        num_shards: shard count embedded in the input file names (default 51).
    """
    if not merge_plan:
        # max() over an empty plan would raise; there is simply nothing to do.
        print("Merge plan is empty; no files written.")
        return

    # Bucket entries by destination shard once instead of rescanning the whole
    # plan for every output file.
    by_new_file = defaultdict(list)
    for item in merge_plan:
        by_new_file[item['new_file_index']].append(item)
    max_file_index = max(by_new_file)

    with tqdm(total=len(merge_plan), desc="Merging layers") as pbar:
        for file_index in range(1, max_file_index + 1):
            items = by_new_file.get(file_index)
            if not items:
                continue

            # Group by source shard so each input file is opened once per output
            # file rather than once per tensor.
            by_source = defaultdict(list)
            for item in items:
                by_source[item['original_file_index']].append(item)

            output_tensors = {}
            for source_index, source_items in by_source.items():
                input_file = os.path.join(
                    input_dir,
                    f"model-{source_index:05d}-of-{num_shards:05d}.safetensors",
                )
                with safe_open(input_file, framework="pt", device="cpu") as f:
                    for item in source_items:
                        output_tensors[item['new_key']] = f.get_tensor(item['old_key'])
                        pbar.update(1)

            output_file = os.path.join(
                output_dir, f"model-{file_index:05d}-of-{max_file_index:05d}.safetensors"
            )
            save_file(output_tensors, output_file)

    print(f"Merged model saved to {output_dir}")

def main():
    """Command-line entry point: build the merge plan, then dump it or execute it.

    With ``--dry-run`` the plan is printed and saved to
    ``merge_plan_large.json`` in the current working directory; otherwise the
    output directory is created if needed and the plan is executed via
    ``merge_layers``.
    """
    parser = argparse.ArgumentParser(description="Merge and split Mistral model")
    parser.add_argument("input_dir", help="Directory containing input safetensors files")
    parser.add_argument("output_dir", help="Directory for output safetensors files")
    parser.add_argument("--dry-run", action="store_true", help="Perform a dry run and output merge plan")
    args = parser.parse_args()

    # Source-layer slices stacked in order. create_merge_plan iterates
    # range(start, end), so `end` is exclusive; overlapping ranges duplicate
    # layers in the output.
    layer_config = [
        {'layer_range': [0, 20]},
        {'layer_range': [10, 30]},
        {'layer_range': [20, 40]},
        {'layer_range': [30, 50]},
        {'layer_range': [40, 60]},
        {'layer_range': [50, 70]},
        {'layer_range': [60, 80]},
        {'layer_range': [70, 87]},
    ]

    tensor_locations = get_tensor_locations(args.input_dir)
    merge_plan = create_merge_plan(tensor_locations, layer_config)

    if args.dry_run:
        plan_file = "merge_plan_large.json"
        print("Merge plan:")
        print(json.dumps(merge_plan, indent=2))
        with open(plan_file, "w", encoding="utf-8") as f:
            json.dump(merge_plan, f, indent=2)
        # Report the file that was actually written.
        print(f"Merge plan saved to {plan_file}")
    else:
        os.makedirs(args.output_dir, exist_ok=True)
        # merge_layers() reports the output location itself.
        merge_layers(args.input_dir, args.output_dir, merge_plan)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()