QuietImpostor committed
Commit
1b52fc3
1 Parent(s): 5ad024a

Conversion script!


Should work the same, but much speedier and fixes the tensor issue.

Files changed (1)
  1. convert.py +327 -0
convert.py ADDED
@@ -0,0 +1,327 @@
# convert.py

import sys
import torch
import safetensors.torch as st
import logging
import math
import tflite.Model
import tflite.SubGraph
from tflite.TensorType import TensorType

# Set up logging
logger = logging.getLogger(__name__)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    level=logging.INFO
)

# Define scale and size mappings
name_of_tensor_type = {
    0: "FLOAT32",
    9: "INT8 ",
    17: "INT4 ",
}

dtype_for_tensor_type = {
    0: torch.float32,
    9: torch.int8,
    17: torch.uint8,  # Because torch.int4 doesn't exist
}

size_for_tensor_type = {
    0: 4,
    9: 1,
    17: 0.5,
}

# Function to update target tensor names
def update_target_name(target_name: str) -> str:
    """Updates the target name to match the tensor name convention."""
    def reverse_replace(theStr: str, a, b):
        return theStr.replace(b, a)

    target_name = reverse_replace(target_name, ".weight", ".w")
    target_name = reverse_replace(target_name,
        "model.layers.", "params.lm.transformer.x_layers_"
    )

    target_name = reverse_replace(target_name,
        "mlp.gate_proj", "ff_layer.ffn_layer1_gate"
    )
    target_name = reverse_replace(target_name, "mlp.up_proj", "ff_layer.ffn_layer1")
    target_name = reverse_replace(target_name, "mlp.down_proj", "ff_layer.ffn_layer2")

    target_name = reverse_replace(target_name,
        "post_layer_norm.weight", "post_layer_norm.scale"
    )
    target_name = reverse_replace(target_name,
        "post_attention_layernorm", "post_layer_norm"
    )

    target_name = reverse_replace(target_name,
        "pre_layer_norm.weight", "pre_layer_norm.scale"
    )
    target_name = reverse_replace(target_name, "input_layernorm", "pre_layer_norm")

    target_name = reverse_replace(target_name, "self_attn.q_proj", "self_attention.q")
    target_name = reverse_replace(target_name, "self_attn.k_proj", "self_attention.k")
    target_name = reverse_replace(target_name, "self_attn.v_proj", "self_attention.v")
    target_name = reverse_replace(target_name, "self_attn.o_proj", "self_attention.post")
    target_name = reverse_replace(target_name,
        "model.embed_tokens", "params.lm.softmax.logits_ffn"
    )
    target_name = reverse_replace(target_name, "final_ln.weight", "final_ln.scale")
    target_name = reverse_replace(target_name, "model.norm", "params.lm.final_ln")

    return target_name

# Optimized dequantization for INT4
def convert_quantized_int4_to_fp(quantized_data, scale_data, dims, dim_scale, dtype):
    zero_point = 8

    # Reshape quantized data to 1D tensor
    quantized_data = quantized_data.view(-1)

    # Extract low and high 4 bits
    low_bits = (quantized_data & 0x0F).type(torch.int8)
    high_bits = (quantized_data >> 4).type(torch.int8)

    # Interleave low and high nibbles to restore the original value order
    int4_values = torch.stack((low_bits, high_bits), dim=1).view(-1)
    int4_values = int4_values - zero_point  # Adjust zero point

    # Apply scaling
    scaled_data = int4_values.type(dtype) * scale_data

    # Reshape to original dimensions
    scaled_data = scaled_data.view(dims[0], dims[1])

    return scaled_data
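
# Worked example for convert_quantized_int4_to_fp (illustrative values): a packed
# byte 0xAB holds the low nibble 0xB (11) and the high nibble 0xA (10); after the
# zero point of 8 is subtracted they become 3 and 2, and each is multiplied by its
# scale factor to recover a floating-point weight.
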
# Function to dequantize INT8
def convert_quantized_int8_to_fp(quantized_data, scale_data, dims, dim_scale, dtype):
    zero_point = 0  # Assuming zero_point=0 for int8

    # Reshape quantized data to 1D tensor
    quantized_data = quantized_data.view(-1).type(torch.int8)

    # Handle scale_data based on dim_scale
    if dim_scale:
        # Per-column scaling
        scale_data = scale_data.repeat_interleave(2)
    else:
        # Per-row scaling
        scale_data = scale_data.repeat_interleave(2)

    # Convert scale_data to the same dtype
    scale_data = scale_data.to(dtype=dtype)

    # Apply scaling
    scaled_data = (quantized_data - zero_point).type(dtype) * scale_data

    # Reshape to original dimensions
    scaled_data = scaled_data.view(dims[0], dims[1])

    return scaled_data

def main():
    # Check command-line arguments
    if len(sys.argv) < 3:
        print("Usage: python convert.py <path_to_tflite_model> <output_safetensors_file> [fp32|fp16|bf16]")
        sys.exit(1)

    tflite_model_path = sys.argv[1]
    output_safetensors_path = sys.argv[2]
    dtype_arg = sys.argv[3] if len(sys.argv) >= 4 else "fp32"

    if dtype_arg == "fp32":
        TARGET_DTYPE = torch.float32
    elif dtype_arg == "fp16":
        TARGET_DTYPE = torch.float16
    elif dtype_arg == "bf16":
        TARGET_DTYPE = torch.bfloat16
    else:
        print("Unsupported dtype. Choose from fp32, fp16, bf16.")
        sys.exit(1)

    logger.info(f"Starting conversion with TARGET_DTYPE={TARGET_DTYPE}")

    # Read the TFLite model
    with open(tflite_model_path, "rb") as input_file:
        buf = bytearray(input_file.read())

    model: tflite.Model.Model = tflite.Model.Model.GetRootAs(buf)
    graph: tflite.SubGraph.SubGraph = model.Subgraphs(0)

    # Initialize dictionaries to hold tensors
    i4_tensors = {}
    i8_tensors = {}
    fp32_tensors = {}
    scale_tensors = {}
    tensor_dims = {}

    # Read and sort tensors
    for i in range(graph.TensorsLength()):
        tensor = graph.Tensors(i)
        tensor_name = tensor.Name().decode("utf-8")
        tensor_type: TensorType = tensor.Type()

        if tensor_name.endswith(".w_quantized_scale"):
            scale_tensors[tensor_name] = tensor
        elif tensor_type == TensorType.INT4:
            i4_tensors[tensor_name] = tensor
        elif tensor_type == TensorType.INT8:
            i8_tensors[tensor_name] = tensor
        elif tensor_type == TensorType.FLOAT32:
            fp32_tensors[tensor_name] = tensor

        tensor_buf_size = tensor.Shape(0)
        tensor_size = tensor_buf_size // size_for_tensor_type[tensor_type]

        shape = None
        if (".self_attention.q." in tensor_name
                or ".self_attention.post." in tensor_name) and tensor_size == 4_194_304:
            shape = (2048, 2048)
        elif (".self_attention.k." in tensor_name
                or ".self_attention.v." in tensor_name) and tensor_size == 524_288:
            shape = (256, 2048)
        elif (".ff_layer.ffn_layer1_gate." in tensor_name
                or ".ff_layer.ffn_layer1." in tensor_name) and tensor_size == 25_165_824:
            shape = (12_288, 2048)
        elif ".ff_layer.ffn_layer2." in tensor_name and tensor_size == 25_165_824:
            shape = (2048, 12_288)
        elif "params.lm.softmax.logits_ffn.w" == tensor_name and tensor_size == 524_550_144:
            shape = (256_128, 2048)
        # LayerNorm weights are of shape {1, 1, 2048}
        elif "layer_norm" in tensor_name and tensor_size == 2048:
            shape = (1, 1, 2048)
        else:
            # Default to 1D if shape is unknown
            pass

        tensor_dims[tensor_name] = shape

    # Dictionary to hold dequantized tensors
    tensor_dict = {}

    # Dequantize FP32 tensors
    for tensor_name, tensor in fp32_tensors.items():
        logger.info(f"Saving fp32 {tensor_name}...")
        buffer_meta = model.Buffers(tensor.Buffer())
        dims = tensor_dims.get(tensor_name)

        target_name = update_target_name(tensor_name)

        tensor_data = torch.frombuffer(buffer=buf,
                                       dtype=torch.float32,
                                       offset=buffer_meta.Offset(),
                                       count=buffer_meta.Size() // 4)

        # Assign reshaped tensor back
        if dims is not None:
            tensor_data = tensor_data.reshape(dims)

        if TARGET_DTYPE != torch.float32:
            tensor_data = tensor_data.to(dtype=TARGET_DTYPE)

        tensor_dict[target_name] = tensor_data

    del fp32_tensors

    # Dequantize INT8 tensors
    for tensor_name, quantized_tensor in i8_tensors.items():
        buffer_meta = model.Buffers(quantized_tensor.Buffer())
        scale_tensor_name = tensor_name + "_quantized_scale"
        scale_buf_meta = model.Buffers(scale_tensors[scale_tensor_name].Buffer())
        dims = tensor_dims.get(tensor_name)

        logger.info(f"Dequantizing int8 {dims} {tensor_name}...")

        target_name = update_target_name(tensor_name)

        quantized_buf = torch.frombuffer(buffer=buf,
                                         dtype=torch.int8,
                                         offset=buffer_meta.Offset(),
                                         count=buffer_meta.Size())

        scale_buf = torch.frombuffer(buffer=buf,
                                     dtype=torch.float32,
                                     offset=scale_buf_meta.Offset(),
                                     count=scale_buf_meta.Size() // 4)

        # MediaPipe TfLiteWeightAccessor::BuildWeightsMapFromTfliteModel sets
        # dim_scale=0, so we do the same.
        tensor_data = convert_quantized_int8_to_fp(
            quantized_data=quantized_buf,
            scale_data=scale_buf,
            dims=dims,
            dim_scale=0,
            dtype=TARGET_DTYPE
        )

        tensor_dict[target_name] = tensor_data

        del quantized_buf, scale_buf

    del i8_tensors

    # Dequantize INT4 tensors
    for tensor_name, quantized_tensor in i4_tensors.items():
        buffer_meta = model.Buffers(quantized_tensor.Buffer())
        scale_tensor_name = tensor_name + "_quantized_scale"
        scale_buf_meta = model.Buffers(scale_tensors[scale_tensor_name].Buffer())
        dims = tensor_dims.get(tensor_name)

        logger.info(f"Dequantizing int4 {dims} {tensor_name}...")

        target_name = update_target_name(tensor_name)

        quantized_buf = torch.frombuffer(buffer=buf,
                                         dtype=torch.uint8,
                                         offset=buffer_meta.Offset(),
                                         count=buffer_meta.Size())

        scale_buf = torch.frombuffer(buffer=buf,
                                     dtype=torch.float32,
                                     offset=scale_buf_meta.Offset(),
                                     count=scale_buf_meta.Size() // 4)

        # Special handling for 'logits_ffn.w_quantized_scale'
        if 'logits_ffn.w_quantized_scale' in scale_tensor_name:
            # Assuming two scale factors per row, average them
            if scale_buf.numel() % 2 != 0:
                logger.error(f"Scale data size for {tensor_name} is not even. Cannot average.")
                sys.exit(1)
            scale_data = scale_buf.view(-1, 2).mean(dim=1)  # Average every two scale factors
            # Repeat each scale factor twice to match the two int4 values
            scale_data = scale_data.repeat_interleave(2)
        else:
            # General handling: per-row scaling, repeat each scale factor twice
            scale_data = scale_buf.repeat_interleave(2)

        # Convert and reshape quantized_data
        tensor_data = convert_quantized_int4_to_fp(
            quantized_data=quantized_buf,
            scale_data=scale_data,
            dims=dims,
            dim_scale=0,
            dtype=TARGET_DTYPE
        )

        tensor_dict[target_name] = tensor_data

        del quantized_buf, scale_buf

    del i4_tensors
    del scale_tensors

    del buf, model, graph

    # Save all tensors to the safetensors file
    logger.info(f"Saving to {output_safetensors_path}...")
    st.save_file(tensor_dict, output_safetensors_path)
    logger.info(f"Success! Saved to {output_safetensors_path}")


if __name__ == "__main__":
    main()
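
For reference, a minimal sketch of how the converted weights might be produced and then inspected; the paths gemma.tflite and gemma.safetensors are placeholders, and the check only assumes the safetensors package the script already imports.

# Hypothetical invocation (placeholder paths):
#   python convert.py gemma.tflite gemma.safetensors bf16
#
# Quick sanity check of the produced file:
from safetensors.torch import load_file

tensors = load_file("gemma.safetensors")  # placeholder output path
for name in sorted(tensors)[:5]:
    print(name, tuple(tensors[name].shape), tensors[name].dtype)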