File size: 2,287 Bytes
7945f87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "coremltools",
#     "torch",
#     "transformers",
# ]
#
# [tool.uv.sources]
# transformers = { git = "https://github.com/huggingface/transformers.git", branch = "main" }
# ///

from transformers import AutoModelForMaskedLM
import torch
import coremltools as ct
import numpy as np
import argparse


def log(text):
    """Print *text* to stdout highlighted in bold green via ANSI escapes."""
    bold_green = "\033[92m\033[1m"
    reset = "\033[0m"
    print(f"{bold_green}{text}{reset}")


# Command-line interface: choose the checkpoint to convert and whether to
# apply linear weight quantization after conversion.
cli = argparse.ArgumentParser(
    prog="convert.py",
    description="Convert ModernBERT to CoreML",
)
cli.add_argument("--model", type=str, default="ModernBERT-base", help="Model name")
cli.add_argument("--quantize", action="store_true", help="Linear quantize model")
args = cli.parse_args()


class Model(torch.nn.Module):
    """Thin wrapper around a Hugging Face masked-LM checkpoint.

    Exposes a single-tensor ``forward`` (input_ids only) so the module can
    be traced with ``torch.jit.trace``; the attention mask is synthesised
    inside ``forward`` instead of being a second traced input.
    """

    def __init__(self):
        super().__init__()
        # Checkpoint name is taken from the CLI; all answerdotai ModernBERT
        # variants live under the same Hub organisation.
        self.model = AutoModelForMaskedLM.from_pretrained(f"answerdotai/{args.model}")

    def forward(self, input_ids):
        # Attend to every position: an all-ones mask with input_ids' shape.
        return self.model(
            input_ids=input_ids,
            attention_mask=torch.ones_like(input_ids),
        ).logits


log("Loading model…")
model = Model().eval()

log("Tracing model…")
# Trace with a minimal (1, 1) int32 token tensor; the flexible sequence
# length is declared at conversion time via RangeDim below.
trace_inputs = (torch.zeros((1, 1), dtype=torch.int32),)
traced_model = torch.jit.trace(model, trace_inputs)

log("Converting model…")
# Batch size is fixed at 1; sequence length may vary from a single token
# up to the checkpoint's maximum position embeddings.
seq_dim = ct.RangeDim(
    lower_bound=1,
    upper_bound=model.model.config.max_position_embeddings,
    default=1,
)
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(name="input_ids", shape=(1, seq_dim), dtype=np.int32)],
    outputs=[ct.TensorType(name="logits")],
    minimum_deployment_target=ct.target.macOS15,
)

if args.quantize:
    log("Quantizing model…")
    # 4-bit symmetric linear quantization of the weights, computed
    # per-block with 32 weights per block.
    op_config = ct.optimize.coreml.OpLinearQuantizerConfig(
        mode="linear_symmetric",
        dtype="int4",
        granularity="per_block",
        block_size=32,
    )
    opt_config = ct.optimize.coreml.OptimizationConfig(global_config=op_config)
    mlmodel = ct.optimize.coreml.linear_quantize_weights(mlmodel, config=opt_config)

mlmodel.author = "Finn Voorhees"
mlmodel.short_description = "https://hf.co/finnvoorhees/ModernBERT-CoreML"

log("Saving mlmodel…")
# Quantized packages carry a "-4bit" suffix so both variants can coexist.
suffix = "-4bit" if args.quantize else ""
mlmodel.save(f"{args.model}{suffix}.mlpackage")

log("Done!")