Spaces:
Runtime error
Runtime error
import importlib | |
import torch | |
import yaml | |
from omegaconf import OmegaConf | |
from taming.models.vqgan import VQModel | |
def load_config(config_path, display=False):
    """Load a configuration with OmegaConf and return it.

    Args:
        config_path: Source handed directly to ``OmegaConf.load``.
        display: When true, the config is converted to a plain container
            with ``OmegaConf.to_container`` and printed as YAML before it
            is returned.

    Returns:
        The object produced by ``OmegaConf.load``.
    """
    cfg = OmegaConf.load(config_path)
    if not display:
        return cfg
    # yaml.dump cannot serialize OmegaConf nodes directly, so convert first.
    plain = OmegaConf.to_container(cfg)
    print(yaml.dump(plain))
    return cfg
def load_vqgan(device, conf_path=None, ckpt_path=None):
    """Build a ``VQModel`` from a config file and load its weights.

    Args:
        device: Device the finished model is moved to (anything accepted by
            ``model.to``).
        conf_path: Config whose ``model.params`` are passed as keyword
            arguments to ``VQModel``. Defaults to
            ``./model_checkpoints/vqgan_only.yaml``.
        ckpt_path: Checkpoint file passed to ``torch.load`` (str or
            path-like). Defaults to ``./model_checkpoints/vqgan_only.pt``.
            It may hold either a bare state dict or a dict that wraps the
            weights under a ``"state_dict"`` key; the wrapper is detected
            from the loaded contents, not from the file name.

    Returns:
        The ``VQModel`` with weights loaded (``strict=True``) and moved to
        *device*. The model is not switched to eval mode here.

    Note:
        ``torch.load`` unpickles the file; only pass trusted checkpoints.
    """
    if conf_path is None:
        conf_path = "./model_checkpoints/vqgan_only.yaml"
    config = load_config(conf_path, display=False)
    model = VQModel(**config.model.params)

    if ckpt_path is None:
        ckpt_path = "./model_checkpoints/vqgan_only.pt"
    # The model is only moved to `device` after the weights are copied in,
    # so loading the checkpoint onto `device` would just keep a redundant
    # second copy of every tensor there. Stage it on CPU instead.
    sd = torch.load(ckpt_path, map_location="cpu")
    # Unwrap checkpoints that nest the weights under "state_dict". Deciding
    # by contents (instead of `".ckpt" in ckpt_path`) works for path-like
    # objects and is not fooled by ".ckpt" appearing elsewhere in the path.
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    model.load_state_dict(sd, strict=True)
    model.to(device)
    del sd
    return model
def reconstruct_with_vqgan(x, model):
    """Round-trip *x* through ``model.encode`` and ``model.decode``.

    ``model.encode(x)`` must return three items, the last of which must
    itself unpack into exactly three items; only the first item (the
    latent) is used. The trailing dimensions of the latent's shape (from
    index 2 onward) are printed before decoding.

    Returns:
        Whatever ``model.decode`` returns for the latent.
    """
    latent, _, (_, _, _) = model.encode(x)
    print(f"VQGAN --- {model.__class__.__name__}: latent shape: {latent.shape[2:]}")
    return model.decode(latent)
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"pkg.module.Name"`` to an object.

    Args:
        string: Fully qualified name. Everything before the last dot is the
            module to import; the part after it is the attribute fetched
            from that module.
        reload: When true, the module is re-executed with
            ``importlib.reload`` before the attribute lookup.

    Returns:
        The attribute named by the final component of *string*.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    module_obj = importlib.import_module(module_path, package=None)
    return getattr(module_obj, attr_name)
def instantiate_from_config(config):
    """Build an object from a ``{"target": ..., "params": ...}`` mapping.

    Args:
        config: Mapping (e.g. an OmegaConf node) with a ``target`` key
            holding a dotted import path, and an optional ``params`` mapping
            whose items are passed as keyword arguments to the target.

    Returns:
        The result of calling the resolved target with ``**params``.

    Raises:
        KeyError: If ``target`` is missing from *config*.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    target = get_obj_from_str(config["target"])
    # A YAML entry written as a bare `params:` loads as None, and `**None`
    # raises TypeError; treat an explicit None the same as a missing key.
    params = config.get("params")
    if params is None:
        params = {}
    return target(**params)
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """Instantiate a model from *config* and optionally load/prepare it.

    Args:
        config: Config passed to ``instantiate_from_config``.
        sd: State dict handed to ``model.load_state_dict``, or None to keep
            the freshly constructed weights.
        gpu: When true, ``model.cuda()`` is called.
        eval_mode: When true, ``model.eval()`` is called.

    Returns:
        A dict with the single key ``"model"`` holding the model.
    """
    net = instantiate_from_config(config)
    if sd is not None:
        net.load_state_dict(sd)
    if gpu:
        net.cuda()
    if eval_mode:
        net.eval()
    return dict(model=net)
def load_model(config, ckpt, gpu, eval_mode):
    """Build ``config.model`` and, when *ckpt* is given, load its weights.

    Args:
        config: Config whose ``model`` entry is passed to
            ``load_model_from_config``.
        ckpt: Checkpoint path loaded with ``torch.load`` onto the CPU. It
            must contain ``"global_step"`` and ``"state_dict"`` keys. A
            falsy value means no checkpoint is loaded.
        gpu: Forwarded to ``load_model_from_config``.
        eval_mode: Forwarded to ``load_model_from_config``.

    Returns:
        ``(model, global_step)``; ``global_step`` is None when no checkpoint
        was loaded.

    Note:
        ``torch.load`` unpickles the file; only pass trusted checkpoints.
    """
    checkpoint = {"state_dict": None}
    global_step = None
    if ckpt:
        checkpoint = torch.load(ckpt, map_location="cpu")
        global_step = checkpoint["global_step"]
        print(f"loaded model from global step {global_step}.")
    net = load_model_from_config(
        config.model,
        checkpoint["state_dict"],
        gpu=gpu,
        eval_mode=eval_mode,
    )["model"]
    return net, global_step