Spaces:
Runtime error
Runtime error
import os | |
import yaml | |
import logging | |
import torch | |
def parse_configs(config: str):
    """Parse a YAML config file and return its contents.

    :param config: path to the YAML config file
    :returns: the document parsed by ``yaml.safe_load`` (a dict for a
        mapping-style config); ``{}`` if the file is empty or is not valid YAML
    :raises SystemExit: with exit status 1 if ``config`` does not exist
    """
    if not os.path.exists(config):
        logging.error('Cannot find the config file: %s', config)
        # Raise directly rather than calling the ``exit()`` builtin: that
        # helper is injected by the ``site`` module and may be absent, and a
        # missing config is a failure, so report a non-zero exit status.
        raise SystemExit(1)
    # YAML documents are Unicode; don't depend on the platform locale.
    with open(config, 'r', encoding='utf-8') as stream:
        try:
            configs = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            logging.error(exc)
            return {}
    # safe_load returns None for an empty document; normalise to {} so callers
    # get the same "no settings" value as on the parse-error path.
    return {} if configs is None else configs
def load_model(config: str, weight: str, model_def, device):
    """Build a model from a config file and load its weights from a checkpoint.

    :param config: path to the YAML config file, parsed with ``parse_configs``
    :param weight: path to the checkpoint file read with ``torch.load``
    :param model_def: model class; called as ``model_def(opt)`` with the parsed
        config and expected to provide ``get_models()`` (returning a mapping of
        name -> module) and ``set_models()``
    :param device: pytorch device the checkpoint is mapped to and the
        sub-modules are moved to
    :returns: the ``model_def`` instance with every sub-module's state loaded
    :raises AssertionError: if ``weight`` or ``config`` does not exist
    :raises KeyError: if the checkpoint has no entry for one of the sub-modules
    """
    # Explicit raises instead of ``assert`` so the checks survive ``python -O``;
    # AssertionError is kept so existing callers that catch it are unaffected.
    if not os.path.exists(weight):
        raise AssertionError('Cannot find the weight file: {}'.format(weight))
    if not os.path.exists(config):
        raise AssertionError('Cannot find the config file: {}'.format(config))
    opt = parse_configs(config)
    model = model_def(opt)
    # map_location remaps tensors saved on another device (e.g. a CUDA
    # checkpoint opened on a CPU-only host); without it torch.load raises.
    checkpoint = torch.load(weight, map_location=device)
    models = model.get_models()
    for name, module in models.items():
        module.load_state_dict(checkpoint[name])
        module.to(device)
    model.set_models(models)
    return model