Spaces:
Runtime error
Runtime error
import os | |
import os.path as osp | |
import cv2 | |
import numpy as np | |
import torch | |
from basicsr.utils import img2tensor, tensor2img | |
from pytorch_lightning import seed_everything | |
from ldm.models.diffusion.plms import PLMSSampler | |
from ldm.modules.encoders.adapter import Adapter | |
from ldm.util import instantiate_from_config | |
from model_edge import pidinet | |
import gradio as gr | |
from omegaconf import OmegaConf | |
import pathlib | |
import random | |
import shlex | |
import subprocess | |
import sys | |
# Add the local 'T2I-Adapter' directory to the module search path.
# NOTE(review): this runs after the ldm/model_edge imports at the top of the
# file, so it only affects imports performed later.
sys.path.append('T2I-Adapter')

# Raw-file URL prefixes on the T2I-Adapter GitHub repository (main branch).
config_path = ('https://github.com/TencentARC/T2I-Adapter/raw/main/'
               'configs/stable-diffusion/')
model_path = ('https://github.com/TencentARC/T2I-Adapter/raw/main/'
              'models/')
def load_model_from_config(config, ckpt, verbose=False, device="cuda"):
    """Instantiate ``config.model`` and load weights from the checkpoint ``ckpt``.

    Args:
        config: config object whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path to a file readable by ``torch.load``; either a dict holding
            a ``"state_dict"`` entry or a bare state dict.
        verbose: when True, print the missing and unexpected keys reported by
            the non-strict ``load_state_dict``.
        device: device the model is moved to. Defaults to ``"cuda"``, which was
            previously hard-coded.

    Returns:
        The instantiated model on ``device``, switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    # Training checkpoints wrap the weights; plain state dicts are used as-is.
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; report them only on request.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose:
        if missing:
            print("missing keys:")
            print(missing)
        if unexpected:
            print("unexpected keys:")
            print(unexpected)
    model.to(device)
    model.eval()
    return model
class Model:
    """Sketch-to-image pipeline built on T2I-Adapter and Stable Diffusion v1.4.

    Construction downloads (when missing) the SD v1.4 checkpoint, the sketch
    adapter and the PiDiNet edge detector into ``model_dir`` and loads them
    onto ``self.device``. ``process_sketch`` is the inference entry point.
    """

    def __init__(self,
                 model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
                 model_dir: str = 'models',
                 use_lightweight: bool = True):
        # NOTE(review): model_config_path and use_lightweight are accepted for
        # signature compatibility but are not used by this class.
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model_dir = pathlib.Path(model_dir)
        self.download_models()

    def _download(self, url: str, filename: str) -> pathlib.Path:
        """Fetch ``url`` into ``model_dir/filename`` unless a non-empty copy exists.

        The data is written to a ``.part`` sibling first and renamed only after
        wget succeeds, so an interrupted download is not reused later.

        Returns:
            Path of the downloaded (or already present) file.

        Raises:
            subprocess.CalledProcessError: if ``wget`` exits non-zero.
        """
        dest = self.model_dir / filename
        if dest.exists() and dest.stat().st_size > 0:
            return dest
        partial = dest.with_name(dest.name + '.part')
        # List-form arguments: no shell, no re-splitting of the URL.
        subprocess.run(['wget', url, '-O', str(partial)], check=True)
        partial.replace(dest)
        return dest

    def download_models(self) -> None:
        """Download (if needed) and load every network the pipeline uses.

        Sets ``self.config``, ``self.model``, ``self.current_base``,
        ``self.model_ad``, ``self.net_G`` and ``self.sampler``.
        """
        self.model_dir.mkdir(exist_ok=True, parents=True)
        device = self.device

        config = OmegaConf.load("configs/stable-diffusion/test_sketch.yaml")
        # OmegaConf stores primitive values only, so pass the device as a string.
        config.model.params.cond_stage_config.params.device = str(device)
        self.config = config

        base_ckpt = self._download(
            "https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt",
            "sd-v1-4.ckpt")
        adapter_ckpt = self._download(
            "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_sketch_sd14v1.pth",
            "t2iadapter_sketch_sd14v1.pth")
        pidinet_ckpt = self._download(
            model_path + "table5_pidinet.pth", "table5_pidinet.pth")

        self.model = load_model_from_config(config, str(base_ckpt)).to(device)
        # Filename of the base weights currently loaded into self.model.
        self.current_base = base_ckpt.name

        self.model_ad = Adapter(channels=[320, 640, 1280, 1280], nums_rb=2,
                                ksize=1, sk=True, use_conv=False).to(device)
        # map_location keeps loading working on CPU-only hosts; load_state_dict
        # copies the tensors into the adapter's parameters on `device`.
        self.model_ad.load_state_dict(
            torch.load(str(adapter_ckpt), map_location='cpu'))

        self.net_G = pidinet()
        ckp = torch.load(str(pidinet_ckpt), map_location='cpu')['state_dict']
        # Strip the 'module.' key prefix (presumably from a DataParallel save).
        self.net_G.load_state_dict(
            {k.replace('module.', ''): v for k, v in ckp.items()})
        self.net_G.to(device)

        self.sampler = PLMSSampler(self.model)

    @staticmethod
    def _to_uint8_image(x: torch.Tensor) -> np.ndarray:
        """Convert the first image of an NCHW batch into a uint8 HWC array.

        Values are mapped with ``(x + 1) / 2``, clamped to [0, 1], scaled by
        255 and truncated to uint8.
        """
        x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
        # .cpu() is required: Tensor.numpy() fails on CUDA tensors.
        x = x.permute(0, 2, 3, 1).cpu().numpy()[0]
        return (255. * x).astype(np.uint8)

    @torch.inference_mode()
    def process_sketch(self, input_img, type_in, color_back, prompt,
                       neg_prompt, fix_sample, scale, con_strength, base_model):
        """Generate an image from a sketch (or an image's edges) and a prompt.

        Args:
            input_img: input image array accepted by ``cv2.resize``; it is
                resized to 512x512.
            type_in: 'Sketch' to use the input as the edge map directly, or
                'Image' to extract edges with PiDiNet.
            color_back: 'White' inverts a sketch before use; ignored for 'Image'.
            prompt: positive text prompt.
            neg_prompt: negative prompt, used as the unconditional conditioning.
            fix_sample: the string 'True' calls ``seed_everything(42)``.
            scale: classifier-free guidance scale.
            con_strength: float converted to ``int((1 - con_strength) * 50)``
                before being passed to the sampler.
            base_model: checkpoint filename inside ``model_dir``; it is loaded
                when it differs from the one currently in use.

        Returns:
            ``[edge_map, generated_image]``; the generated image is uint8 HWC.

        Raises:
            ValueError: if ``input_img`` is None or ``type_in`` is unsupported.
        """
        if input_img is None:
            raise ValueError("No input image provided.")
        device = self.device
        model = self.model

        if self.current_base != base_model:
            ckpt = self.model_dir / base_model
            pl_sd = torch.load(str(ckpt), map_location="cpu")
            sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
            model.load_state_dict(sd, strict=False)
            self.current_base = base_model

        steps = 50
        con_strength = int((1 - con_strength) * steps)
        if fix_sample == 'True':
            seed_everything(42)

        im = cv2.resize(input_img, (512, 512))
        if type_in == 'Sketch':
            if color_back == 'White':
                # Invert a white-background sketch.
                im = 255 - im
            im_edge = im.copy()
            # Keep a single channel as a (1, 1, H, W) tensor, then binarize.
            im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
            im = (im > 0.5).float()
        elif type_in == 'Image':
            im = img2tensor(im).unsqueeze(0) / 255.
            im = self.net_G(im.to(device))[-1]
            im = (im > 0.5).float()
            im_edge = tensor2img(im)
        else:
            raise ValueError(f"Unsupported input type: {type_in!r}")

        c = model.get_learned_conditioning([prompt])
        nc = model.get_learned_conditioning([neg_prompt])

        # Extract adapter condition features from the binary edge map.
        features_adapter = self.model_ad(im.to(device))
        # Latent shape for a 512x512 output (presumably 8x VAE downsampling).
        shape = [4, 64, 64]
        samples_ddim, _ = self.sampler.sample(
            S=steps,
            conditioning=c,
            batch_size=1,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=nc,
            eta=0.0,
            x_T=None,
            features_adapter1=features_adapter,
            mode='sketch',
            con_strength=con_strength)

        x_samples = model.decode_first_stage(samples_ddim)
        return [im_edge, self._to_uint8_image(x_samples)]