"""Convert a NeMo UL2 checkpoint back to Hugging Face T5 format and verify it.

Steps:
  1. Remap the NeMo state dict to HF T5 parameter names/layout.
  2. Pull the original HF model's config + tokenizer into a local export dir.
  3. Write the converted weights next to them as ``pytorch_model.bin``.
  4. Reload the exported model and check every tensor matches the original.
"""
import os

import torch
from transformers import T5ForConditionalGeneration, AutoTokenizer

from convert_nemo_ul2_checkpoint import convert_nemo_to_hf

# Local working folder holding the NeMo state dict; the HF export goes in a subdir.
LOCAL_DIR = "ul2-base-nl36-finnish"
HF_EXPORT_DIR = os.path.join(LOCAL_DIR, "hf_t5_ul2")
HF_MODEL_ID = "Finnish-NLP/ul2-base-nl36-finnish"

#### Step 1: Convert the NeMo-converted weights back to HF naming/layout.
# map_location="cpu" so the script works on CPU-only machines even if the
# checkpoint was saved from CUDA tensors.
nemo_weights = torch.load(os.path.join(LOCAL_DIR, "nemo_state_dict.pt"), map_location="cpu")
hf_weights = convert_nemo_to_hf(nemo_weights)

#### Step 2: Load the original HF model and save its config/tokenizer locally.
hf_model = T5ForConditionalGeneration.from_pretrained(HF_MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID)
tokenizer.save_pretrained(HF_EXPORT_DIR)
hf_model.config.save_pretrained(HF_EXPORT_DIR)

#### Step 3: Save our converted weights to the local export folder.
torch.save(hf_weights, os.path.join(HF_EXPORT_DIR, "pytorch_model.bin"))

#### Step 4: Reload from the export folder and check the weights round-trip.
converted_model = T5ForConditionalGeneration.from_pretrained(HF_EXPORT_DIR)

# Hoist the state dicts out of the loop: .state_dict() rebuilds the whole
# mapping on every call, so invoking it per key (as before) was quadratic work.
original_sd = hf_model.state_dict()
converted_sd = converted_model.state_dict()

equal = []
for key in original_sd:
    same = torch.allclose(original_sd[key], converted_sd[key])
    print(key)
    print(same)
    equal.append(same)
print(f"All weights are equal: {all(equal)}")