# Artrajz's picture
# update
# 14e19a5
# raw
# history blame
# 2.29 kB
import os
import sys
import logging
import torch
# Module-level flag, initialised to False. It is not referenced in the visible
# part of this file -- presumably a lazy-initialisation guard for matplotlib;
# TODO confirm against the rest of the project.
MATPLOTLIB_FLAG = False
# Import-time side effect: configure the root logger to write to stdout at
# DEBUG level.
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# `logger` is an alias for the `logging` module itself (not a Logger instance),
# so logger.info()/warning()/error() are the module-level functions, which log
# through the root logger.
logger = logging
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
    """Load model weights (and optionally optimizer state) from a checkpoint.

    The checkpoint is read with ``torch.load`` onto the CPU and is expected to
    be a mapping with the keys ``'iteration'``, ``'learning_rate'`` and
    ``'model'``. The ``'optimizer'`` entry is only consulted when an optimizer
    is actually being restored and may be absent or ``None``.

    Every parameter of the model is looked up in the checkpoint by name. A
    parameter that is missing from the checkpoint, or whose stored shape
    differs from the model's, keeps the model's current value and is reported
    through ``logger.error``. Parameters whose name contains ``"ja_bert_proj"``
    are zero-filled instead (with a warning) to support older checkpoints.

    Args:
        checkpoint_path: Path of the checkpoint file.
        model: Module to load the weights into. If it has a ``module``
            attribute, the weights are loaded into ``model.module``.
        optimizer: Optional optimizer. Its state is restored only when it is
            not ``None``, ``skip_optimizer`` is false and the checkpoint
            contains optimizer state; otherwise it is left untouched.
        skip_optimizer: If true, never touch the optimizer.

    Returns:
        The tuple ``(model, optimizer, learning_rate, iteration)``, where the
        last two values are read from the checkpoint.

    Raises:
        AssertionError: If ``checkpoint_path`` is not an existing file.
        KeyError: If the checkpoint lacks ``'iteration'``,
            ``'learning_rate'`` or ``'model'``.
    """
    assert os.path.isfile(checkpoint_path), checkpoint_path
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version/settings -- only load checkpoints from trusted sources.
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    iteration = checkpoint_dict['iteration']
    learning_rate = checkpoint_dict['learning_rate']

    # The previous fallback branch (`elif optimizer is None and not
    # skip_optimizer`) dereferenced `optimizer` while it was None, so it could
    # only ever raise AttributeError. It has been removed: when there is no
    # optimizer to restore, the optimizer is simply left as it is.
    saved_optimizer_state = checkpoint_dict.get('optimizer')
    if (optimizer is not None and not skip_optimizer
            and saved_optimizer_state is not None):
        optimizer.load_state_dict(saved_optimizer_state)

    saved_state_dict = checkpoint_dict['model']
    # Wrapped models expose the real network as `.module`.
    target = model.module if hasattr(model, 'module') else model
    state_dict = target.state_dict()

    new_state_dict = {}
    for k, v in state_dict.items():
        missing = k not in saved_state_dict
        saved_shape = None
        if not missing:
            saved = saved_state_dict[k]
            saved_shape = getattr(saved, 'shape', None)
            if saved_shape == v.shape:
                new_state_dict[k] = saved
                continue

        # Fallback: keep the model's own value for this parameter.
        # For upgrading from the old version
        if "ja_bert_proj" in k:
            v = torch.zeros_like(v)
            logger.warning(
                "If you are using an older version of the model, you should add the parameter \"legacy\":true to the data of the model's config.json")
        if missing:
            logger.error("%s is not in the checkpoint", k)
        else:
            logger.error(
                "%s has shape %s in the checkpoint but the model expects %s",
                k, saved_shape, v.shape)
        new_state_dict[k] = v

    # new_state_dict holds an entry for every key of the model's state dict.
    target.load_state_dict(new_state_dict, strict=False)
    logger.info("Loaded checkpoint '%s' (iteration %s)", checkpoint_path, iteration)
    return model, optimizer, learning_rate, iteration