|
import torch |
|
|
|
from transformers import T5ForConditionalGeneration |
|
|
|
|
|
def rename_checkpoint_keys(state_dict):
    """Rename FiD-style checkpoint keys to plain T5 naming, in place.

    Applies two substring rewrites to every key:
      * ``encoder.encoder`` -> ``encoder``  (FiD wraps the T5 encoder)
      * ``module.layer``    -> ``layer``    (DataParallel-style prefix)

    Args:
        state_dict: mapping of parameter name -> tensor; mutated in place.

    Returns:
        The same mapping, with keys renamed.
    """
    # Iterate over a snapshot of the keys: we pop/re-insert while looping,
    # so we must not iterate the live dict view.
    for key in sorted(state_dict.keys()):
        new_key = key.replace('encoder.encoder', 'encoder')
        new_key = new_key.replace('module.layer', 'layer')
        state_dict[new_key] = state_dict.pop(key)
    return state_dict


def main():
    """Convert a FiD T5 checkpoint to vanilla-T5 key naming and save it."""
    # NOTE(review): torch.load unpickles arbitrary objects — only run this
    # on trusted checkpoint files; prefer weights_only=True on torch>=1.13.
    model_old = torch.load('models/ckpt/fidt5-base-nq/pytorch_model.bin', map_location='cpu')

    # Reference model whose state-dict keys are the conversion target.
    model_new = T5ForConditionalGeneration.from_pretrained('t5-base')
    model_new_keys = sorted(model_new.state_dict().keys())

    rename_checkpoint_keys(model_old)

    # Fix: the reference key list was previously computed but never used.
    # Report any disagreement between the renamed checkpoint and the target
    # T5 model so a silent mis-conversion is caught before saving.
    missing = sorted(set(model_new_keys) - set(model_old.keys()))
    unexpected = sorted(set(model_old.keys()) - set(model_new_keys))
    if missing:
        print(f'keys expected by t5-base but absent from checkpoint: {missing}')
    if unexpected:
        print(f'checkpoint keys not present in t5-base: {unexpected}')

    torch.save(model_old, '/home/jhju/models/fidt5-base-nq/pytorch_model.bin')


if __name__ == '__main__':
    main()
|
|