import argparse
import json
import os
import re
import zipfile

import torch


def recursive_print(name, val, spaces=0):
    """Recursively print the names and shapes/values of a (possibly nested) state dict."""
    # Format the message.
    if name is None:
        msg = None
    else:
        fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}"
        msg = fmt.format(name)

    # Print and recurse (if needed).
    if isinstance(val, dict):
        if msg is not None:
            print(msg)
        for k in val.keys():
            recursive_print(k, val[k], spaces + 2)
    elif isinstance(val, torch.Tensor):
        print(msg, ":", val.size())
    else:
        print(msg, ":", val)
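
# Illustrative usage (not exercised in this file): given a raw Megatron checkpoint
# already loaded with torch.load, its structure can be dumped with
#
#     recursive_print(None, input_state_dict)

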
def convert_megatron_checkpoint(input_state_dict, head_model=True):
    """Convert a Megatron-LM BERT checkpoint into a Transformers-style state dict and config."""
    # The converted output model.
    output_state_dict = {}

    # The model, its language model and its embeddings.
    model = input_state_dict["model"]
    lm = model["language_model"]
    embeddings = lm["embedding"]

    # The word embeddings.
    word_embeddings = embeddings["word_embeddings"]["weight"]
    output_state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings

    # The position embeddings. This converter assumes a fixed geometry (512 positions,
    # hidden size 1024), which is also reflected in the config built below.
    pos_embeddings = embeddings["position_embeddings"]["weight"]
    assert pos_embeddings.size(0) == 512 and pos_embeddings.size(1) == 1024
    output_state_dict["bert.embeddings.position_embeddings.weight"] = pos_embeddings

    # The token-type embeddings.
    tokentype_embeddings = embeddings["tokentype_embeddings"]["weight"]
    output_state_dict["bert.embeddings.token_type_embeddings.weight"] = tokentype_embeddings

    # The transformer block.
    transformer = lm["transformer"]

    # The regex to extract the layer index, the operation name and the weight/bias suffix.
    layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
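    # Example: "layers.3.attention.dense.weight" yields the groups
    # ("3", "attention.dense", "weight"), i.e. layer index, op name, weight-or-bias.
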
    # The simple map of names for operations whose layout does not change.
    megatron_to_transformers = {
        "attention.dense": ".attention.output.dense.",
        "mlp.dense_h_to_4h": ".intermediate.dense.",
        "mlp.dense_4h_to_h": ".output.dense.",
    }

    # Holds the fused query/key/value weight until the matching bias is seen.
    attention_qkv_weight = None
    # Extract the layers.
    for key, val in transformer.items():
        # Match the name.
        m = layer_re.match(key)

        # Stop at the first key that is not a per-layer weight (the remaining keys,
        # e.g. the final layernorm, are handled below).
        if m is None:
            break

        # The layer index, the operation name and whether this is a weight or a bias.
        layer_idx = int(m.group(1))
        op_name = m.group(2)
        weight_or_bias = m.group(3)

        # The name of the corresponding Transformers layer.
        layer_name = f"bert.encoder.layer.{layer_idx}"

        # For layernorm(s), simply store the layer norm.
        if op_name.endswith("layernorm"):
            ln_name = "attention.ln" if op_name.startswith("input") else "ln"
            output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val

        # The fused QKV weight: keep it until the matching bias arrives.
        elif op_name == "attention.query_key_value" and weight_or_bias == "weight":
            assert attention_qkv_weight is None, "Found a second QKV weight before the matching bias"
            attention_qkv_weight = val

        # The fused QKV bias: split both the stored weight and the bias into Q, K and V.
        elif op_name == "attention.query_key_value" and weight_or_bias == "bias":
            assert attention_qkv_weight is not None, "Found the QKV bias before the QKV weight"

            # The fused weight has shape [3 * 1024, 1024] and the bias [3 * 1024] for the
            # hidden size of 1024 assumed by this converter; slice out Q, K and V in order.
            q = attention_qkv_weight[0 * 1024 : 1 * 1024, :]
            k = attention_qkv_weight[1 * 1024 : 2 * 1024, :]
            v = attention_qkv_weight[2 * 1024 : 3 * 1024, :]

            q_bias = val[0 * 1024 : 1 * 1024]
            k_bias = val[1 * 1024 : 2 * 1024]
            v_bias = val[2 * 1024 : 3 * 1024]

            # Store the slices.
            output_state_dict[f"{layer_name}.attention.self.query.weight"] = q
            output_state_dict[f"{layer_name}.attention.self.query.bias"] = q_bias
            output_state_dict[f"{layer_name}.attention.self.key.weight"] = k
            output_state_dict[f"{layer_name}.attention.self.key.bias"] = k_bias
            output_state_dict[f"{layer_name}.attention.self.value.weight"] = v
            output_state_dict[f"{layer_name}.attention.self.value.bias"] = v_bias

            # Clear the stored weight.
            attention_qkv_weight = None

        # Copy the remaining weights and biases as-is, using the name map above.
        elif weight_or_bias in ["weight", "bias"]:
            out_name = megatron_to_transformers[op_name]
            output_state_dict[layer_name + out_name + weight_or_bias] = val

    # The final layernorm.
    output_state_dict["bert.encoder.ln.weight"] = transformer["final_layernorm.weight"]
    output_state_dict["bert.encoder.ln.bias"] = transformer["final_layernorm.bias"]

    output_config = {
        "vocab_size": word_embeddings.size(0),
        "hidden_size": 1024,
        "num_hidden_layers": 24,
        "num_attention_heads": 16,
        "hidden_act": "gelu_new",
        "intermediate_size": 4096,
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "max_position_embeddings": 512,
        "type_vocab_size": 2,
"initializer_range": 0.2, |
|
"layer_norm_eps": 1e-12, |
|
"position_embedding_type": "absolute", |
|
"use_cache": False, |
|
"model_type": "megatron-bert", |
|
} |
|
|
|
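    # Note: this plain dict mirrors the fields of a BERT-style Transformers config; the
    # caller is expected to serialize it (e.g. as a config.json) alongside the weights.
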
    if head_model:
        # The pooler.
        pooler = lm["pooler"]

        # Store the pooler matrix and bias.
        output_state_dict["bert.pooler.dense.weight"] = pooler["dense.weight"]
        output_state_dict["bert.pooler.dense.bias"] = pooler["dense.bias"]

        # The LM head from Megatron.
        lm_head = model["lm_head"]

        # The transform matrix and bias.
        output_state_dict["cls.predictions.transform.dense.weight"] = lm_head["dense.weight"]
        output_state_dict["cls.predictions.transform.dense.bias"] = lm_head["dense.bias"]

        # The transform layernorm.
        output_state_dict["cls.predictions.transform.LayerNorm.weight"] = lm_head["layernorm.weight"]
        output_state_dict["cls.predictions.transform.LayerNorm.bias"] = lm_head["layernorm.bias"]

        # The decoder shares its weight with the word embeddings.
        output_state_dict["cls.predictions.decoder.weight"] = word_embeddings
        output_state_dict["cls.predictions.bias"] = lm_head["bias"]

        # The next-sentence-prediction (binary) head.
        binary_head = model["binary_head"]

        # Store the classifier.
        output_state_dict["cls.seq_relationship.weight"] = binary_head["weight"]
        output_state_dict["cls.seq_relationship.bias"] = binary_head["bias"]

    # Return the converted weights and the config.
    return output_state_dict, output_config
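

# A minimal usage sketch (the checkpoint path, output directory and file names below
# are placeholders chosen for illustration, not values defined in this module):
#
#     input_state_dict = torch.load("model_optim_rng.pt", map_location="cpu")
#     output_state_dict, output_config = convert_megatron_checkpoint(input_state_dict)
#     os.makedirs("converted", exist_ok=True)
#     with open(os.path.join("converted", "config.json"), "w") as f:
#         json.dump(output_config, f)
#     torch.save(output_state_dict, os.path.join("converted", "pytorch_model.bin"))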