import os
import os.path as osp

import cv2
import numpy as np
import torch
from basicsr.utils import img2tensor, tensor2img
from pytorch_lightning import seed_everything
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.adapter import Adapter
from ldm.util import instantiate_from_config
from model_edge import pidinet
import gradio as gr
from omegaconf import OmegaConf


def load_model_from_config(config, ckpt, verbose=False):
    """Build the model described by ``config.model`` and load weights from ``ckpt``.

    The checkpoint is read onto the CPU first. If it contains a
    ``"state_dict"`` entry, that entry is used as the weights; otherwise the
    loaded object itself is treated as the state dict. Weights are applied
    with ``strict=False``, so keys that do not match are tolerated.

    Args:
        config: Configuration object whose ``model`` attribute is passed to
            ``instantiate_from_config``.
        ckpt: Checkpoint location handed to ``torch.load``.
        verbose: When true, print the missing and unexpected state-dict keys
            reported by ``load_state_dict``.

    Returns:
        The instantiated model, moved to CUDA and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load unpickles arbitrary objects; only load checkpoints
    # that come from a trusted source.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    # Some checkpoints wrap the weights under "state_dict"; others are the
    # bare state dict.
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd

    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose:
        # strict=False hides key mismatches, so surface them on request.
        if missing:
            print("missing keys:")
            print(missing)
        if unexpected:
            print("unexpected keys:")
            print(unexpected)

    model.cuda()
    model.eval()
    return model