Artrajz's picture
update
b0f5083
raw
history blame
3.09 kB
import os
import sys
import logging
import torch
# Module-level flag, initially False; presumably flipped by plotting helpers
# elsewhere in this module once matplotlib has been set up — confirm.
MATPLOTLIB_FLAG = False

# Send every record at DEBUG level or above from the root logger to stdout.
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)

# The logging module itself acts as this module's logger, so calls such as
# logger.info(...) go straight to the root logger configured above.
logger = logging
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False, version=None):
    """Restore model weights (and optionally optimizer state) from a checkpoint file.

    The file is read with ``torch.load`` onto the CPU and must contain the keys
    ``'model'``, ``'iteration'`` and ``'learning_rate'``. An ``'optimizer'`` entry
    is used only when an optimizer is supplied and ``skip_optimizer`` is false.

    Model weights are merged key by key over the model's own ``state_dict()``.
    A checkpoint tensor is used when it exists with the same shape. Otherwise the
    model keeps its current value (keys containing ``ja_bert_proj`` are
    zero-filled instead), and the key is logged together with a config.json hint
    for the legacy layouts recognised below.

    Args:
        checkpoint_path: Path to the checkpoint file.
        model: Module to load into. If it has a ``.module`` attribute (e.g. a
            ``DataParallel`` wrapper), the weights go into that inner module.
        optimizer: Optional optimizer to restore. ``None`` (e.g. for inference)
            leaves optimizer state untouched.
        skip_optimizer: If true, optimizer state is never restored.
        version: Model version string. When ``None``, an unloadable
            ``ja_bert_proj`` key is logged with a hint about setting the version
            in config.json. Otherwise that key is zero-filled silently.

    Returns:
        ``(model, optimizer, learning_rate, iteration)``, where ``model`` and
        ``optimizer`` are the same objects that were passed in.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is not an existing file.
    """
    # Explicit check rather than `assert`, which disappears under `python -O`.
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
    # NOTE(review): torch.load unpickles arbitrary objects; only load checkpoints
    # from trusted sources.
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    iteration = checkpoint_dict['iteration']
    learning_rate = checkpoint_dict['learning_rate']

    # Restore optimizer state only when there is an optimizer to load into and
    # saved state to load. With no optimizer (e.g. inference) this is a no-op
    # rather than an error.
    if optimizer is not None and not skip_optimizer:
        saved_optimizer = checkpoint_dict.get('optimizer')
        if saved_optimizer is not None:
            optimizer.load_state_dict(saved_optimizer)

    saved_state_dict = checkpoint_dict['model']
    # Load into the wrapped module when the model is a wrapper exposing `.module`.
    target = model.module if hasattr(model, 'module') else model

    new_state_dict = {}
    for k, v in target.state_dict().items():
        saved = saved_state_dict.get(k)
        if saved is not None and saved.shape == v.shape:
            new_state_dict[k] = saved
            continue

        # The checkpoint cannot supply this key. Keep the model's own value and
        # report the precise reason (missing key vs. shape mismatch).
        if saved is None:
            problem = "is not in the checkpoint"
        else:
            problem = (f"has shape {tuple(saved.shape)} in the checkpoint "
                       f"but {tuple(v.shape)} in the model")

        # Handle legacy model versions and provide appropriate warnings
        if "ja_bert_proj" in k:
            v = torch.zeros_like(v)
            if version is None:
                logger.error(f"{k} {problem}")
                # NOTE(review): this hint names the "version" parameter but its
                # example uses "legacy_version"; process_legacy_versions reads
                # "version" — confirm which key is intended.
                logger.warning(
                    f"If you're using an older version of the model, consider adding the \"version\" parameter to the model's config.json under the \"data\" section. For instance: \"legacy_version\": \"1.0.1\"")
        elif "flow.flows.0.enc.attn_layers.3" in k:
            logger.error(f"{k} {problem}")
            logger.warning(
                f"If you're using a transitional version, please add the \"version\": \"1.1.0-transition\" parameter within the \"data\" section of the model's config.json.")
        else:
            logger.error(f"{k} {problem}")
        new_state_dict[k] = v

    target.load_state_dict(new_state_dict, strict=False)
    logger.info("Loaded checkpoint '{}' (iteration {})".format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration
def process_legacy_versions(hps):
    """Return the model version declared in ``hps``, without a leading "v"/"V".

    ``hps.version`` takes precedence; otherwise ``hps.data.version`` is used.
    Returns ``None`` when neither is set. Falsy values (``None``, ``""``) are
    returned unchanged.

    Args:
        hps: Hyper-parameter object with attribute access (presumably the parsed
            config.json — confirm against callers). ``hps.data`` is optional.
    """
    if hasattr(hps, "version"):
        version = hps.version
    else:
        # hps.data is consulted only when needed, so configs without a `data`
        # section still work when hps.version is present or absent.
        version = getattr(getattr(hps, "data", None), "version", None)
    if version and version[:1] in ("v", "V"):
        version = version[1:]
    return version