import logging

import torch
from torch import nn
from transformers import HubertConfig, HubertModel

logging.getLogger("fairseq").setLevel(logging.WARNING)
logging.getLogger("torch.distributed.nn.jit.instantiator").setLevel(logging.WARNING)
from fairseq import checkpoint_utils

models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
    ["content-vec-best-legacy-500.pt"], suffix=""
)
model = models[0]
model.eval()
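

# The stock transformers HubertModel has no "final_proj" head, so a small
# subclass adds the projection that the fairseq ContentVec checkpoint carries.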
class HubertModelWithFinalProj(HubertModel):
    def __init__(self, config):
        super().__init__(config)

        self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
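

# A freshly initialised HF model with the default (base) HubertConfig; its
# weights are replaced with the fairseq weights below.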
hubert = HubertModelWithFinalProj(HubertConfig())
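
# Parameter-name mapping: keys are transformers (HF) names, values are the
# corresponding fairseq names.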
mapping = {
    "masked_spec_embed": "mask_emb",
    "encoder.layer_norm.bias": "encoder.layer_norm.bias",
    "encoder.layer_norm.weight": "encoder.layer_norm.weight",
    "encoder.pos_conv_embed.conv.bias": "encoder.pos_conv.0.bias",
    "encoder.pos_conv_embed.conv.weight_g": "encoder.pos_conv.0.weight_g",
    "encoder.pos_conv_embed.conv.weight_v": "encoder.pos_conv.0.weight_v",
    "feature_projection.layer_norm.bias": "layer_norm.bias",
    "feature_projection.layer_norm.weight": "layer_norm.weight",
    "feature_projection.projection.bias": "post_extract_proj.bias",
    "feature_projection.projection.weight": "post_extract_proj.weight",
    "final_proj.bias": "final_proj.bias",
    "final_proj.weight": "final_proj.weight",
}
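
# Transformer encoder layers (12 in the base model): attention projections,
# layer norms, and feed-forward weights.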
for layer in range(12):
    for j in ["q", "k", "v"]:
        mapping[
            f"encoder.layers.{layer}.attention.{j}_proj.weight"
        ] = f"encoder.layers.{layer}.self_attn.{j}_proj.weight"
        mapping[
            f"encoder.layers.{layer}.attention.{j}_proj.bias"
        ] = f"encoder.layers.{layer}.self_attn.{j}_proj.bias"

    mapping[
        f"encoder.layers.{layer}.final_layer_norm.bias"
    ] = f"encoder.layers.{layer}.final_layer_norm.bias"
    mapping[
        f"encoder.layers.{layer}.final_layer_norm.weight"
    ] = f"encoder.layers.{layer}.final_layer_norm.weight"

    mapping[
        f"encoder.layers.{layer}.layer_norm.bias"
    ] = f"encoder.layers.{layer}.self_attn_layer_norm.bias"
    mapping[
        f"encoder.layers.{layer}.layer_norm.weight"
    ] = f"encoder.layers.{layer}.self_attn_layer_norm.weight"

    mapping[
        f"encoder.layers.{layer}.attention.out_proj.bias"
    ] = f"encoder.layers.{layer}.self_attn.out_proj.bias"
    mapping[
        f"encoder.layers.{layer}.attention.out_proj.weight"
    ] = f"encoder.layers.{layer}.self_attn.out_proj.weight"

    mapping[
        f"encoder.layers.{layer}.feed_forward.intermediate_dense.bias"
    ] = f"encoder.layers.{layer}.fc1.bias"
    mapping[
        f"encoder.layers.{layer}.feed_forward.intermediate_dense.weight"
    ] = f"encoder.layers.{layer}.fc1.weight"

    mapping[
        f"encoder.layers.{layer}.feed_forward.output_dense.bias"
    ] = f"encoder.layers.{layer}.fc2.bias"
    mapping[
        f"encoder.layers.{layer}.feed_forward.output_dense.weight"
    ] = f"encoder.layers.{layer}.fc2.weight"
for layer in range(7):
    mapping[
        f"feature_extractor.conv_layers.{layer}.conv.weight"
    ] = f"feature_extractor.conv_layers.{layer}.0.weight"

    if layer != 0:
        continue

    mapping[
        f"feature_extractor.conv_layers.{layer}.layer_norm.weight"
    ] = f"feature_extractor.conv_layers.{layer}.2.weight"
    mapping[
        f"feature_extractor.conv_layers.{layer}.layer_norm.bias"
    ] = f"feature_extractor.conv_layers.{layer}.2.bias"
hf_keys = set(hubert.state_dict().keys())
fair_keys = set(model.state_dict().keys())

hf_keys -= set(mapping.keys())
fair_keys -= set(mapping.values())

for i, j in zip(sorted(hf_keys), sorted(fair_keys)):
    print(i, j)

print(hf_keys, fair_keys)
print(len(hf_keys), len(fair_keys))
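
# Copy the fairseq weights into a state dict keyed by HF names and load it;
# strict=False so any missing/unexpected keys are returned and printed rather
# than raising.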
new_state_dict = {}
for k, v in mapping.items():
    new_state_dict[k] = model.state_dict()[v]

x = hubert.load_state_dict(new_state_dict, strict=False)
print(x)
hubert.eval()
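
# Sanity check: features from encoder layer 9, followed by final_proj, should
# match between the converted HF model and the original fairseq model.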
with torch.no_grad():
    new_input = torch.randn(1, 16384)

    result1 = hubert(new_input, output_hidden_states=True)["hidden_states"][9]
    result1 = hubert.final_proj(result1)

    result2 = model.extract_features(
        **{
            "source": new_input,
            "padding_mask": torch.zeros(1, 16384, dtype=torch.bool),
            "output_layer": 9,
        }
    )[0]
    result2 = model.final_proj(result2)

    assert torch.allclose(result1, result2, atol=1e-3)
    print("Sanity check passed")
hubert.save_pretrained(".")
print("Saved model")
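
# A minimal sketch (not part of the conversion itself) of how the saved
# checkpoint could be loaded back and used; the "." directory and layer 9
# mirror the assumptions made above, and the random tensor stands in for a
# 16 kHz waveform:
#
#     converted = HubertModelWithFinalProj.from_pretrained(".")
#     converted.eval()
#     with torch.no_grad():
#         wav = torch.randn(1, 16384)  # placeholder for real audio
#         hidden = converted(wav, output_hidden_states=True)["hidden_states"][9]
#         units = converted.final_proj(hidden)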