import torch
# from models import FiDT5_meta
from transformers import T5ForConditionalGeneration

# load the FiD model checkpoint (a raw state dict) on CPU
model_old = torch.load('models/ckpt/fidt5-base-nq/pytorch_model.bin', map_location='cpu')

model_new = T5ForConditionalGeneration.from_pretrained('t5-base')

# collect and sort the state-dict keys of both models for comparison
model_new_keys = sorted(list(model_new.state_dict().keys()))
model_old_keys = sorted(list(model_old.keys()))

# remap the FiD key names to the plain T5 naming scheme, e.g.
# 'encoder.encoder. ... .module.layer. ...' -> 'encoder. ... .layer. ...'
for k in model_old_keys:
    k_prime = k.replace('encoder.encoder', 'encoder')
    k_prime = k_prime.replace('module.layer', 'layer')
    model_old[k_prime] = model_old.pop(k)

# validate that the remapped keys align with the new model's keys
# model_old_keys = sorted(list(model_old.keys()))
#
# for k in model_new_keys:
#     if k not in model_old_keys:
#         print('missing from remapped checkpoint:', k)

# save the remapped state dict as the new checkpoint
torch.save(model_old, '/home/jhju/models/fidt5-base-nq/pytorch_model.bin')
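# Optional sanity check, as a sketch (not part of the original conversion): load the
# remapped state dict into the fresh t5-base model with strict=False and report any
# keys that still do not line up, instead of raising on missing/unexpected keys
# (this assumes the tensor shapes already match t5-base, which FiD inherits from).
incompatible = model_new.load_state_dict(model_old, strict=False)
print('missing keys:', incompatible.missing_keys)
print('unexpected keys:', incompatible.unexpected_keys)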