T2I-Adapter / model.py
RamAnanth1's picture
Create model.py
7da7768
raw
history blame
1.05 kB
import os
import os.path as osp
import cv2
import numpy as np
import torch
from basicsr.utils import img2tensor, tensor2img
from pytorch_lightning import seed_everything
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.adapter import Adapter
from ldm.util import instantiate_from_config
from model_edge import pidinet
import gradio as gr
from omegaconf import OmegaConf
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    The checkpoint is loaded onto the CPU first. Its weights are copied into a
    freshly instantiated model with ``strict=False``. The model is then moved to
    CUDA and switched to eval mode.

    Args:
        config: Config object with a ``model`` entry that
            ``instantiate_from_config`` can build (presumably an OmegaConf
            config -- confirm against callers).
        ckpt: Checkpoint path (or file-like object) accepted by ``torch.load``.
        verbose: If True, print the state-dict keys that were missing from the
            checkpoint and the keys the model did not expect. Loading is
            non-strict, so this is the only place such mismatches are reported.

    Returns:
        The instantiated model, moved to CUDA and in eval mode.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # checkpoints that come from a trusted source.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    # Some checkpoints (e.g. PyTorch Lightning ones) nest the weights under
    # "state_dict"; otherwise the loaded object is used as the state dict itself.
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd

    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; report them when asked so that a
    # wrong or partial checkpoint does not go unnoticed.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing and verbose:
        print("missing keys:")
        print(missing)
    if unexpected and verbose:
        print("unexpected keys:")
        print(unexpected)

    model.cuda()
    model.eval()
    return model