Spaces:
Runtime error
Runtime error
File size: 1,313 Bytes
34fb220 62745e5 34fb220 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import os
import yaml
import logging
import torch
def parse_configs(config: str):
    """ Parse a YAML config file and return a dictionary of configs

    :param config: path to the YAML config file
    :returns: the mapping produced by ``yaml.safe_load``; ``{}`` when the
        file is empty or is not valid YAML (the parse error is logged)
    :raises SystemExit: with exit status 1 when ``config`` does not exist
    """
    if not os.path.exists(config):
        logging.error('Cannot find the config file: {}'.format(config))
        # A missing config is fatal. Raise SystemExit with a non-zero status:
        # a bare ``exit()`` reports success (status 0) and relies on the
        # ``site``-injected builtin, which is absent under ``python -S``.
        raise SystemExit(1)
    # YAML documents are Unicode; don't depend on the platform locale.
    with open(config, 'r', encoding='utf-8') as stream:
        try:
            configs = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            logging.error(exc)
            return {}
    # safe_load returns None for an empty document; normalise to an empty
    # dict so callers always receive a mapping, matching the error path.
    return configs if configs is not None else {}
class _MissingFileError(FileNotFoundError, AssertionError):
    """Raised by ``load_model`` when a required input file does not exist.

    Also subclasses ``AssertionError`` so callers written against the
    earlier ``assert``-based checks keep catching it.
    """


def load_model(config: str, weight: str, model_def, device):
    """ Load the model from the config file and the weight file

    :param config: path to the YAML config file, parsed with ``parse_configs``
    :param weight: path to the checkpoint file read with ``torch.load``
    :param model_def: model class (or factory) called as ``model_def(opt)``
        with the parsed config
    :param device: pytorch device; used as ``map_location`` when loading the
        checkpoint and as the target of ``.to(device)`` for every sub-model
    :returns: the object built by ``model_def``, after each sub-model from
        ``get_models()`` has loaded its state dict, been moved to ``device``
        and been handed back through ``set_models``
    :raises FileNotFoundError: (also an ``AssertionError``) if ``weight`` or
        ``config`` does not exist
    """
    # Explicit raises instead of ``assert``: asserts are stripped under
    # ``python -O``, so a missing file would only surface later as a less
    # specific error from torch.load / open.
    if not os.path.exists(weight):
        raise _MissingFileError('Cannot find the weight file: {}'.format(weight))
    if not os.path.exists(config):
        raise _MissingFileError('Cannot find the config file: {}'.format(config))
    opt = parse_configs(config)
    model = model_def(opt)
    # NOTE(review): torch.load unpickles the checkpoint, which can execute
    # arbitrary code. Only load weight files from trusted sources; consider
    # ``weights_only=True`` if checkpoints contain only state dicts.
    cp = torch.load(weight, map_location=device)
    # assumes get_models() returns name -> nn.Module and that every name is
    # a key in the checkpoint (cp[k] raises KeyError otherwise) — TODO confirm
    models = model.get_models()
    for k, m in models.items():
        m.load_state_dict(cp[k])
        m.to(device)
    model.set_models(models)
    return model
|