"""Model wrapper for the T2I-Adapter sketch / keypose generation demo.

Downloads the Stable Diffusion base checkpoints, the T2I-Adapter weights and
the PiDiNet edge detector, and exposes ``Model.process_sketch`` and
``Model.process_pose``, which turn a conditioning image plus a text prompt
into a generated image.
"""
import os
import os.path as osp
import cv2
import numpy as np
import torch
from basicsr.utils import img2tensor, tensor2img
from pytorch_lightning import seed_everything
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.adapter import Adapter
from ldm.util import instantiate_from_config
from model_edge import pidinet
import gradio as gr
from omegaconf import OmegaConf
import pathlib
import random
import shlex
import subprocess
import sys

sys.path.append('T2I-Adapter')

config_path = 'https://github.com/TencentARC/T2I-Adapter/raw/main/configs/stable-diffusion/'
model_path = 'https://github.com/TencentARC/T2I-Adapter/raw/main/models/'


def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config.model`` and load ``ckpt`` into it.

    The checkpoint may be either a raw state dict or a dict holding the
    weights under a ``"state_dict"`` key. Keys are loaded non-strictly.

    NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Path to the checkpoint file.
        verbose: When true, print the missing / unexpected state-dict keys.

    Returns:
        The model, moved to CUDA and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose:
        if missing:
            print("missing keys:")
            print(missing)
        if unexpected:
            print("unexpected keys:")
            print(unexpected)
    model.cuda()
    model.eval()
    return model


class Model:
    """Holds every network needed by the sketch and keypose demos."""

    # File name of the alternative base checkpoint. Any other ``base_model``
    # value passed to the process_* methods selects Stable Diffusion v1.4.
    _ANYTHING_CKPT = 'anything-v4.0-pruned.ckpt'

    def __init__(self,
                 model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
                 model_dir: str = 'models',
                 use_lightweight: bool = True):
        """Download (if needed) and load all models.

        Args:
            model_config_path: Not used by this class; kept for interface
                compatibility.
            model_dir: Directory checkpoints are downloaded to and loaded from.
            use_lightweight: Not used by this class; kept for interface
                compatibility.
        """
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model_dir = pathlib.Path(model_dir)
        self.download_models()

    @staticmethod
    def _download(url: str, dest: pathlib.Path) -> None:
        """Fetch ``url`` into ``dest`` with wget, skipping non-empty existing files."""
        if dest.is_file() and dest.stat().st_size > 0:
            return
        # Argument list (no shell). check=True surfaces a failed download here
        # instead of as a confusing torch.load error later on.
        subprocess.run(['wget', url, '-O', str(dest)], check=True)

    def download_models(self) -> None:
        """Download missing checkpoints and load every network.

        Populates ``self.model`` / ``self.model_anything`` (base diffusion
        models) with their PLMS samplers, the sketch and keypose adapters
        (``self.model_ad_sketch`` / ``self.model_ad_pose``) and the PiDiNet
        edge detector ``self.net_G``.
        """
        self.model_dir.mkdir(exist_ok=True, parents=True)
        device = 'cuda'
        config = OmegaConf.load("configs/stable-diffusion/test_sketch.yaml")
        config.model.params.cond_stage_config.params.device = device

        downloads = {
            'sd-v1-4.ckpt':
                "https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt",
            self._ANYTHING_CKPT:
                "https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0-pruned.ckpt",
            't2iadapter_sketch_sd14v1.pth':
                "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_sketch_sd14v1.pth",
            't2iadapter_keypose_sd14v1.pth':
                "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_keypose_sd14v1.pth",
            'table5_pidinet.pth': model_path + "table5_pidinet.pth",
        }
        for file_name, url in downloads.items():
            self._download(url, self.model_dir / file_name)

        # Base diffusion models and their samplers.
        self.model = load_model_from_config(
            config, str(self.model_dir / 'sd-v1-4.ckpt')).to(device)
        self.model_anything = load_model_from_config(
            config, str(self.model_dir / self._ANYTHING_CKPT)).to(device)
        self.sampler = PLMSSampler(self.model)
        self.sampler_anything = PLMSSampler(self.model_anything)

        # T2I-Adapters. Weights are loaded on CPU and copied into the
        # already-moved modules by load_state_dict.
        self.model_ad_sketch = Adapter(
            channels=[320, 640, 1280, 1280], nums_rb=2, ksize=1, sk=True,
            use_conv=False).to(device)
        self.model_ad_sketch.load_state_dict(torch.load(
            str(self.model_dir / 't2iadapter_sketch_sd14v1.pth'),
            map_location='cpu'))
        self.model_ad_pose = Adapter(
            cin=int(3 * 64), channels=[320, 640, 1280, 1280], nums_rb=2,
            ksize=1, sk=True, use_conv=False).to(device)
        self.model_ad_pose.load_state_dict(torch.load(
            str(self.model_dir / 't2iadapter_keypose_sd14v1.pth'),
            map_location='cpu'))

        # PiDiNet edge detector. Stored on self because process_sketch needs
        # it to extract edges from 'Image' inputs.
        self.net_G = pidinet()
        ckp = torch.load(str(self.model_dir / 'table5_pidinet.pth'),
                         map_location='cpu')['state_dict']
        # Strip the 'module.' key prefix (presumably left by DataParallel
        # training) so the keys match the bare network.
        self.net_G.load_state_dict(
            {k.replace('module.', ''): v for k, v in ckp.items()})
        self.net_G.to(device)
        self.net_G.eval()

    def _select_base(self, base_model):
        """Return the ``(model, sampler)`` pair for the requested base checkpoint.

        ``base_model`` is treated as a checkpoint file name (assumption based
        on the checkpoint names used above; confirm against the UI). Only
        ``'anything-v4.0-pruned.ckpt'`` selects the Anything model; any other
        value falls back to Stable Diffusion v1.4.
        """
        if base_model is not None and pathlib.Path(str(base_model)).name == self._ANYTHING_CKPT:
            return self.model_anything, self.sampler_anything
        return self.model, self.sampler

    def _generate(self, features_adapter, prompt, neg_prompt, scale,
                  con_strength, base_model):
        """Run 50-step PLMS sampling guided by adapter features.

        Args:
            features_adapter: Output of one of the adapters.
            prompt / neg_prompt: Conditioning and negative text prompts.
            scale: Classifier-free guidance scale.
            con_strength: Already-converted integer value forwarded to the
                sampler's ``con_strength`` argument.
            base_model: Checkpoint name; see ``_select_base``.

        Returns:
            The first generated image as a channel-last ``np.uint8`` array
            in ``[0, 255]``.
        """
        model, sampler = self._select_base(base_model)
        c = model.get_learned_conditioning([prompt])
        nc = model.get_learned_conditioning([neg_prompt])
        shape = [4, 64, 64]
        samples, _ = sampler.sample(
            S=50, conditioning=c, batch_size=1, shape=shape, verbose=False,
            unconditional_guidance_scale=scale, unconditional_conditioning=nc,
            eta=0.0, x_T=None, features_adapter1=features_adapter,
            mode='sketch', con_strength=con_strength)
        x = model.decode_first_stage(samples)
        # Map decoder output from [-1, 1] to [0, 1], then to uint8 channel-last.
        x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
        x = x.permute(0, 2, 3, 1).cpu().numpy()[0]
        return (255. * x).astype(np.uint8)

    @torch.inference_mode()
    def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt,
                       fix_sample, scale, con_strength, base_model):
        """Generate an image guided by a sketch (or by edges extracted from a photo).

        Args:
            input_img: Image array accepted by ``cv2.resize``; resized to 512x512.
            type_in: ``'Sketch'`` to use the image as the edge map directly,
                ``'Image'`` to extract edges with PiDiNet first.
            color_back: ``'White'`` inverts a sketch drawn on white background.
            prompt / neg_prompt: Text and negative text prompts.
            fix_sample: The string ``'True'`` seeds all RNGs with 42.
            scale: Classifier-free guidance scale.
            con_strength: Converted to ``int((1 - con_strength) * 50)``
                before being passed to the sampler.
            base_model: Checkpoint file name selecting the base model.

        Returns:
            ``[edge_map_image, generated_image]``.

        Raises:
            ValueError: If ``type_in`` is neither ``'Sketch'`` nor ``'Image'``.
        """
        device = 'cuda'
        con_strength = int((1 - con_strength) * 50)
        if fix_sample == 'True':
            seed_everything(42)

        im = cv2.resize(input_img, (512, 512))
        if type_in == 'Sketch':
            if color_back == 'White':
                im = 255 - im
            im_edge = im.copy()
            # Keep only the first channel as a (1, 1, H, W) binary edge map.
            im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
            im = (im > 0.5).float()
        elif type_in == 'Image':
            im = img2tensor(im).unsqueeze(0) / 255.
            im = self.net_G(im.to(device))[-1]
            im = (im > 0.5).float()
            im_edge = tensor2img(im)
        else:
            raise ValueError(f"Unsupported input type: {type_in!r}")

        features_adapter = self.model_ad_sketch(im.to(device))
        result = self._generate(features_adapter, prompt, neg_prompt, scale,
                                con_strength, base_model)
        return [im_edge, result]

    @torch.inference_mode()
    def process_pose(self, input_img, prompt, neg_prompt, fix_sample, scale,
                     con_strength, base_model):
        """Generate an image guided by a keypose image.

        Args:
            input_img: Image array accepted by ``cv2.resize``; resized to 512x512.
            prompt / neg_prompt: Text and negative text prompts.
            fix_sample: The string ``'True'`` seeds all RNGs with 42.
            scale: Classifier-free guidance scale.
            con_strength: Converted to ``int((1 - con_strength) * 50)``
                before being passed to the sampler.
            base_model: Checkpoint file name selecting the base model.

        Returns:
            ``[pose_image, generated_image]``.
        """
        device = 'cuda'
        con_strength = int((1 - con_strength) * 50)
        if fix_sample == 'True':
            seed_everything(42)

        im = cv2.resize(input_img, (512, 512))
        pose = img2tensor(im, bgr2rgb=True, float32=True) / 255.
        pose = pose.unsqueeze(0)
        im_pose = tensor2img(pose)

        features_adapter = self.model_ad_pose(pose.to(device))
        result = self._generate(features_adapter, prompt, neg_prompt, scale,
                                con_strength, base_model)
        return [im_pose, result]