Spaces:
Runtime error
Runtime error
import torch | |
from collections import OrderedDict | |
def extract(ckpt):
    """Wrap a training checkpoint's model weights in an inference-style layout.

    Reads ``ckpt["model"]`` and returns an ``OrderedDict`` whose single entry
    ``"weight"`` maps to a plain dict with every model entry except those whose
    key contains the substring ``"enc_q"``. Insertion order of the kept entries
    follows the source mapping.
    """
    state = ckpt["model"]
    kept = {}
    for name, tensor in state.items():
        # Entries named with "enc_q" are excluded (presumably the training-only
        # posterior encoder — TODO confirm against the model definition).
        if "enc_q" in name:
            continue
        kept[name] = tensor
    result = OrderedDict()
    result["weight"] = kept
    return result
def _load_weights(path):
    """Load the checkpoint at *path* and return its flat name -> tensor dict.

    Two layouts are accepted: a checkpoint with a "model" entry (filtered via
    ``extract``) and one that already stores its weights under "weight".
    """
    # NOTE(review): torch.load unpickles the file. If *path* can come from an
    # untrusted upload, consider torch.load(..., weights_only=True); confirm the
    # checkpoints contain only tensors/plain containers before switching.
    ckpt = torch.load(path, map_location="cpu")
    if "model" in ckpt:
        # extract() returns {"weight": {...}}; unwrap it so both layouts yield
        # the same flat mapping (the old code kept the wrapper, which made the
        # later .float() call fail on a dict).
        return extract(ckpt)["weight"]
    return ckpt["weight"]


def model_fusion(model_name, pth_path_1, pth_path_2, alpha=1):
    """Blend two checkpoints parameter-wise and save the result.

    Every output tensor is ``alpha * w1 + (1 - alpha) * w2``, computed after
    converting both inputs with ``.float()``. The merged checkpoint is written
    to ``logs/<model_name>.pth`` as ``{"weight": {...}, "info": str}``. The
    ``logs`` directory is created if it is missing.

    Args:
        model_name: file stem for the saved checkpoint.
        pth_path_1: path to the first checkpoint (weighted by ``alpha``).
        pth_path_2: path to the second checkpoint (weighted by ``1 - alpha``).
        alpha: blend weight for the first model. The default of 1 reproduces
            the previous hard-coded behaviour (a copy of the first model).

    Returns:
        None on success. On failure, a message string when the two checkpoints
        do not share the same parameter names and shapes.
    """
    import os

    weights1 = _load_weights(pth_path_1)
    weights2 = _load_weights(pth_path_2)

    fail_msg = "Fail to merge the models. The model architectures are not the same."
    if weights1.keys() != weights2.keys():
        return fail_msg
    # Mismatched shapes would either raise mid-merge or silently broadcast into
    # a wrongly shaped parameter, so reject them up front.
    if any(weights1[key].shape != weights2[key].shape for key in weights1):
        return fail_msg

    opt = OrderedDict(
        weight={
            key: alpha * value.float() + (1 - alpha) * weights2[key].float()
            for key, value in weights1.items()
        }
    )
    opt["info"] = f"Model fusion of {pth_path_1} and {pth_path_2}"
    # torch.save does not create missing parent directories.
    os.makedirs("logs", exist_ok=True)
    torch.save(opt, f"logs/{model_name}.pth")
    print(f"Model fusion of {pth_path_1} and {pth_path_2} is done.")