Spaces:
Running
Running
import os
import sys
import logging

import torch

# Module-level flag; it is not read in the visible part of this file.
# NOTE(review): presumably records whether matplotlib has already been
# configured by a plotting helper elsewhere — confirm against its users.
MATPLOTLIB_FLAG = False

# Configure the root logger to write everything (DEBUG and above) to stdout.
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

# The ``logging`` module itself is used as the logger object, so calls such
# as ``logger.info(...)`` go through the module-level helpers and therefore
# the root logger configured above.
logger = logging
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
    """Restore model weights (and optionally optimizer state) from a checkpoint.

    The checkpoint is read onto the CPU and must be a mapping with the keys
    ``'iteration'``, ``'learning_rate'`` and ``'model'``; ``'optimizer'`` is
    read only when an optimizer is supplied and ``skip_optimizer`` is false.

    Weights are matched against the model's own ``state_dict`` key by key.
    A parameter that is missing from the checkpoint, or whose shape differs,
    keeps the model's current value (zeros for keys containing
    ``"ja_bert_proj"``) and the problem is logged instead of raised.

    Args:
        checkpoint_path: Path of the checkpoint file to read.
        model: Module to load into. If it exposes a ``module`` attribute
            (as wrapper modules do), the wrapped module is used instead.
        optimizer: Optimizer whose state should be restored, or ``None`` to
            leave optimizer state alone.
        skip_optimizer: When true, the optimizer state stored in the
            checkpoint is ignored even if ``optimizer`` is given.

    Returns:
        A tuple ``(model, optimizer, learning_rate, iteration)`` where the
        last two values come from the checkpoint.

    Raises:
        AssertionError: If ``checkpoint_path`` is not an existing file.
        KeyError: If a required key is absent from the checkpoint.
    """
    assert os.path.isfile(checkpoint_path)
    # NOTE(review): torch.load unpickles the file, so only trusted
    # checkpoints should be passed in.
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    iteration = checkpoint_dict['iteration']
    learning_rate = checkpoint_dict['learning_rate']

    # Optimizer state is restored only when there is an optimizer to restore
    # into. The previous fallback branch ran only when ``optimizer`` was None
    # and then dereferenced it, so a plain ``load_checkpoint(path, model)``
    # always crashed with AttributeError; with no optimizer there is simply
    # nothing to do.
    if (
        optimizer is not None
        and not skip_optimizer
        and checkpoint_dict['optimizer'] is not None
    ):
        optimizer.load_state_dict(checkpoint_dict['optimizer'])

    saved_state_dict = checkpoint_dict['model']
    # Unwrap wrapper modules so keys carry no ``module.`` prefix.
    target = model.module if hasattr(model, 'module') else model

    new_state_dict = {}
    for k, v in target.state_dict().items():
        in_checkpoint = k in saved_state_dict
        saved = saved_state_dict[k] if in_checkpoint else None
        if in_checkpoint and getattr(saved, 'shape', None) == v.shape:
            new_state_dict[k] = saved
            continue

        # Fallback: keep the model's own tensor for this key.
        # For upgrading from the old version
        if "ja_bert_proj" in k:
            v = torch.zeros_like(v)
            logger.warning(
                "If you are using an older version of the model, you should add the "
                "parameter \"legacy\":true to the data of the model's config.json")
        if in_checkpoint:
            logger.error(
                f"{k} has shape {getattr(saved, 'shape', None)} in the checkpoint "
                f"but {v.shape} in the model")
        else:
            logger.error(f"{k} is not in the checkpoint")
        new_state_dict[k] = v

    # strict=False: new_state_dict is built from the model's own keys, so
    # this only relaxes PyTorch's key checking rather than hiding mismatches
    # (those are logged above).
    target.load_state_dict(new_state_dict, strict=False)

    logger.info("Loaded checkpoint '{}' (iteration {})".format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration