bababababooey committed on
Commit eab6e7f
1 Parent(s): 3be96ad

Upload hotswap.py

Files changed (1)
  1. swapper/hotswap.py +158 -0
swapper/hotswap.py ADDED
@@ -0,0 +1,158 @@
import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor
import os
import json
from safetensors import safe_open
import re

# apologies in advance for shitty gpt-assisted code

# this script should also work with 70b/90b if you change `cross_attention_layers` and `total_layers`
# accordingly (a commented sketch for reading those values from config.json follows `model_id` below),
# but I don't have enough dedicated RAM to test it and I don't feel like spinning up RunPod

cross_attention_layers = [3, 8, 13, 18, 23, 28, 33, 38]

#b8 = './models/mlabonne_Meta-Llama-3.1-8B-Instruct-abliterated'
b8 = './models/v000000_L3-8B-Stheno-v3.2-abliterated'
#b8 = './models/arcee-ai_Llama-3.1-SuperNova-Lite'
print(b8)

model_id = "./models/meta-llama_Llama-3.2-11B-Vision-Instruct"
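# (untested) for the 70B/90B vision checkpoints, the layer layout can be read from the checkpoint
# itself instead of hard-coded; the field names below are assumed from the Mllama config format,
# so verify them against the actual config.json before relying on this:
#
#   with open(os.path.join(model_id, 'config.json')) as f:
#       _text_cfg = json.load(f)['text_config']
#   cross_attention_layers = _text_cfg['cross_attention_layers']
#   total_layers = _text_cfg['num_hidden_layers'] - len(cross_attention_layers)
#
# and then call create_layer_mapping(total_layers, cross_attention_layers) explicitly below.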

def create_layer_mapping(total_layers=32, cross_attn_layers=cross_attention_layers):
    """
    Creates a mapping from llama-3.1-8b layer indices to llama-3.2-11b layer indices.
    """
    mapping = {}
    shift = 0
    next_cross_attn_idx = 0
    for X in range(total_layers):
        # Check if a cross-attention layer is inserted before this layer
        if next_cross_attn_idx < len(cross_attn_layers) and (X + shift) == cross_attn_layers[next_cross_attn_idx]:
            shift += 1
            next_cross_attn_idx += 1
        Y = X + shift
        mapping[X] = Y
    return mapping
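# for the 11B defaults this works out to:
#   {0: 0, 1: 1, 2: 2, 3: 4, 4: 5, 5: 6, 6: 7, 7: 9, ..., 31: 39}
# i.e. donor layer indices are shifted past each inserted cross-attention slot, so 8B layer 31
# lands on 11B layer 39, the last self-attention layer.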

def load_sharded_state_dict(model_dir):
    index_file = os.path.join(model_dir, 'model.safetensors.index.json')
    with open(index_file, 'r') as f:
        index_data = json.load(f)
    weight_map = index_data['weight_map']
    state_dict = {}
    shard_to_params = {}
    for param_name, shard_file in weight_map.items():
        if shard_file not in shard_to_params:
            shard_to_params[shard_file] = []
        shard_to_params[shard_file].append(param_name)
    for shard_file, params_in_shard in shard_to_params.items():
        shard_path = os.path.join(model_dir, shard_file)
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for name in params_in_shard:
                state_dict[name] = f.get_tensor(name)
    return state_dict

def compare_model_states(model, new_state_dict):
    current_state = model.state_dict()
    unchanged_params = []
    changed_params = []
    missing_params = []

    for name, param in current_state.items():
        if name not in new_state_dict:
            missing_params.append(name)
        elif torch.equal(param.cpu(), new_state_dict[name].cpu()):
            unchanged_params.append(name)
        else:
            changed_params.append(name)

    return {
        'unchanged': unchanged_params,
        'changed': changed_params,
        'missing': missing_params
    }


layer_mapping = create_layer_mapping()
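# optional sanity check: no donor layer should be remapped onto one of the inserted
# cross-attention slots (for the 11B defaults, layer_mapping[31] works out to 39)
assert set(layer_mapping.values()).isdisjoint(cross_attention_layers)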

# Load Llama 3.2 state dict
llama_3_2_state_dict = load_sharded_state_dict(model_id)

# Extract the embedding matrix from Llama 3.2
llama_3_2_embeddings = llama_3_2_state_dict['language_model.model.embed_tokens.weight']  # Shape: [128264, 4096]

llama_3_2_state_dict.clear()

b8dict = load_sharded_state_dict(b8)

embed_tokens_weight = b8dict['model.embed_tokens.weight']  # Shape: [128256, 4096]
new_vocab_size = 128264  # From Llama 3.2
new_embed_tokens_weight = torch.zeros((new_vocab_size, 4096), dtype=embed_tokens_weight.dtype)

# Copy the existing embeddings
new_embed_tokens_weight[:128256, :] = embed_tokens_weight
# Copy the additional embeddings from Llama 3.2
new_embed_tokens_weight[128256:, :] = llama_3_2_embeddings[128256:, :]

b8dict['model.embed_tokens.weight'] = new_embed_tokens_weight
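# optional shape check: the merged table should cover the 128256 text tokens plus the 8 extra
# rows (e.g. the <|image|> token) that exist only in the vision model's input embedding
assert new_embed_tokens_weight.shape == (new_vocab_size, embed_tokens_weight.shape[1])
# note: lm_head is not padded here — the extra rows appear to be input-only tokens, so the output
# head presumably stays at [128256, 4096] on both models (assumption; if loading complains about
# lm_head shapes, it would need the same treatment as embed_tokens)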


llama_3_2_embeddings = None

# Adjust Llama 3.1 parameter names to match Llama 3.2 language model
st8dict = {}
for name, param in b8dict.items():
    # Prefix non-layer parameters with 'language_model.'
    if not re.match(r'model\.layers\.\d+\.', name):
        new_name = 'language_model.' + name
    else:
        # Extract the layer index X from 'model.layers.X.'
        match = re.match(r'model\.layers\.(\d+)\.(.+)', name)
        if match:
            X = int(match.group(1))
            suffix = match.group(2)
            # Get the corresponding Y in llama-3.2-11b
            Y = layer_mapping.get(X, X + len(cross_attention_layers))
            new_name = f'language_model.model.layers.{Y}.{suffix}'
        else:
            # If the pattern doesn't match, just prefix with 'language_model.'
            new_name = 'language_model.' + name
    st8dict[new_name] = param

# write st8dict keys to file for verification
with open('st8dict.txt', 'w') as f:
    f.write('\n'.join(st8dict.keys()))


model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)

#original_state = {k: v.clone() for k, v in model.state_dict().items()}

model.load_state_dict(st8dict, strict=False)
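# with strict=False, only the keys present in st8dict are overwritten: the donor 8B weights land
# on the 11B self-attention layers, embed_tokens, final norm, and lm_head, while the vision tower,
# multi_modal_projector, and cross-attention layers keep their original 11B weights.
# to double-check the rename mapping, the return value can be captured instead, e.g.:
#   missing, unexpected = model.load_state_dict(st8dict, strict=False)
#   print(len(unexpected))  # should be 0 if every renamed key exists in the 11B model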

b8dict.clear()
st8dict.clear()


'''
result = compare_model_states(model, original_state)

print("Unchanged parameters:", len(result['unchanged']))
print("Changed parameters:", len(result['changed']))
print("Missing parameters:", len(result['missing']))

# write result to file
with open('result.txt', 'w') as f:
    f.write(json.dumps(result, indent=2))
'''


processor = AutoProcessor.from_pretrained(model_id)


model.save_pretrained("llama-3.2-11b-vision-stheno-abliterated")
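# the processor is loaded above but never written out; if the goal is a self-contained output
# directory, saving it alongside the model (assumption about intent) would look like:
#   processor.save_pretrained("llama-3.2-11b-vision-stheno-abliterated")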