# T2I-Adapter / model.py
# Web file-viewer header captured together with the source (not program code):
#   RamAnanth1's picture | Update model.py | aff3d3b
#   raw | history blame | No virus | 9 kB
import os
import os.path as osp
import cv2
import numpy as np
import torch
from basicsr.utils import img2tensor, tensor2img
from pytorch_lightning import seed_everything
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.adapter import Adapter
from ldm.util import instantiate_from_config
from model_edge import pidinet
import gradio as gr
from omegaconf import OmegaConf
import pathlib
import random
import shlex
import subprocess
import sys
# Add the local 'T2I-Adapter' directory to the module search path.
# NOTE(review): this runs after the ldm / model_edge imports above, so it
# cannot help resolve those imports -- confirm the intended ordering.
sys.path.append('T2I-Adapter')

# Base URLs for files served from the upstream T2I-Adapter GitHub repository.
_T2I_ADAPTER_RAW_URL = 'https://github.com/TencentARC/T2I-Adapter/raw/main/'
config_path = _T2I_ADAPTER_RAW_URL + 'configs/stable-diffusion/'
model_path = _T2I_ADAPTER_RAW_URL + 'models/'
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config.model`` and load ``ckpt`` into it.

    Args:
        config: Configuration object whose ``model`` attribute is handed to
            ``instantiate_from_config``.
        ckpt: Checkpoint location passed to ``torch.load``.
        verbose: When true, print the missing and unexpected state-dict keys
            reported by the non-strict weight load.

    Returns:
        The instantiated model, moved to CUDA and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    # The weights may be wrapped under a "state_dict" entry; otherwise the
    # loaded object itself is used as the state dict.
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd

    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; surface them on request instead
    # of silently discarding the report.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose:
        if len(missing) > 0:
            print("missing keys:")
            print(missing)
        if len(unexpected) > 0:
            print("unexpected keys:")
            print(unexpected)

    # The model is always placed on CUDA here; callers may move it afterwards.
    model.cuda()
    model.eval()
    return model
class Model:
    """Stable Diffusion v1.4 generator with T2I-Adapter sketch and keypose guidance.

    Constructing an instance downloads the required weight files into
    ``model_dir`` and loads the base model, both adapters, the pidinet edge
    network and a PLMS sampler.
    """

    # Number of sampler steps; also the scale used to turn the user-facing
    # ``con_strength`` fraction into the integer handed to the sampler.
    _SAMPLING_STEPS = 50

    def __init__(self,
                 model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
                 model_dir: str = 'models',
                 use_lightweight: bool = True):
        """Download the weights and load every network.

        Args:
            model_config_path: Accepted for interface compatibility; not used.
            model_dir: Directory in which weight files are stored and loaded.
            use_lightweight: Accepted for interface compatibility; not used.
        """
        # NOTE(review): load_model_from_config calls ``model.cuda()``
        # unconditionally, so the CPU fallback below is not a working
        # CPU mode on its own.
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model_dir = pathlib.Path(model_dir)
        self.download_models()

    def _model_file(self, name: str) -> str:
        """Return the path of weight file ``name`` inside ``self.model_dir``."""
        return str(self.model_dir / name)

    @classmethod
    def _control_steps(cls, con_strength) -> int:
        """Map a ``con_strength`` fraction to the integer value passed to the sampler."""
        return int((1 - con_strength) * cls._SAMPLING_STEPS)

    def download_models(self) -> None:
        """Fetch all weight files with ``wget`` and load the networks.

        Sets ``self.model``, ``self.sampler``, ``self.model_ad_sketch``,
        ``self.model_ad_pose`` and ``self.net_G``.

        Raises:
            subprocess.CalledProcessError: If a ``wget`` download fails.
        """
        self.model_dir.mkdir(exist_ok=True, parents=True)
        device = self.device

        config = OmegaConf.load("configs/stable-diffusion/test_sketch.yaml")
        # The config stores the device as a plain string.
        config.model.params.cond_stage_config.params.device = str(device)

        base_name = 'sd-v1-4.ckpt'
        sketch_name = 't2iadapter_sketch_sd14v1.pth'
        pose_name = 't2iadapter_keypose_sd14v1.pth'
        pidinet_name = 'table5_pidinet.pth'
        remote_files = {
            base_name: "https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt",
            sketch_name: "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_sketch_sd14v1.pth",
            pose_name: "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_keypose_sd14v1.pth",
            pidinet_name: model_path + "table5_pidinet.pth",
        }
        for name, url in remote_files.items():
            # check=True turns a failed download into an immediate error
            # instead of a confusing failure when the file is loaded later.
            subprocess.run(['wget', url, '-O', self._model_file(name)],
                           check=True)

        self.model = load_model_from_config(
            config, self._model_file(base_name)).to(device)
        self.sampler = PLMSSampler(self.model)

        self.model_ad_sketch = Adapter(
            channels=[320, 640, 1280, 1280], nums_rb=2, ksize=1, sk=True,
            use_conv=False).to(device)
        self.model_ad_sketch.load_state_dict(
            torch.load(self._model_file(sketch_name), map_location='cpu'))

        self.model_ad_pose = Adapter(
            cin=int(3 * 64), channels=[320, 640, 1280, 1280], nums_rb=2,
            ksize=1, sk=True, use_conv=False).to(device)
        self.model_ad_pose.load_state_dict(
            torch.load(self._model_file(pose_name), map_location='cpu'))

        # Kept on the instance because process_sketch needs it for the
        # 'Image' input type.
        self.net_G = pidinet()
        ckp = torch.load(self._model_file(pidinet_name),
                         map_location='cpu')['state_dict']
        # The checkpoint keys carry a 'module.' prefix that the bare network
        # does not use.
        self.net_G.load_state_dict(
            {k.replace('module.', ''): v for k, v in ckp.items()})
        self.net_G.to(device)

    def _generate(self, adapter, cond_input, mode, prompt, neg_prompt, scale,
                  con_strength):
        """Run adapter-guided PLMS sampling and return the decoded image as a uint8 array.

        ``cond_input`` is the conditioning tensor fed to ``adapter``;
        ``con_strength`` is the already converted integer value. The result
        is laid out as height x width x channels.
        """
        c = self.model.get_learned_conditioning([prompt])
        nc = self.model.get_learned_conditioning([neg_prompt])
        with torch.no_grad():
            # extract condition features
            features_adapter = adapter(cond_input.to(self.device))
            # sampling
            samples, _ = self.sampler.sample(
                S=self._SAMPLING_STEPS,
                conditioning=c,
                batch_size=1,
                shape=[4, 64, 64],
                verbose=False,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=nc,
                eta=0.0,
                x_T=None,
                features_adapter1=features_adapter,
                mode=mode,
                con_strength=con_strength)
            decoded = self.model.decode_first_stage(samples)
            # Map the decoder output from [-1, 1] to [0, 1].
            decoded = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
            image = decoded.permute(0, 2, 3, 1).cpu().numpy()[0]
        return (255. * image).astype(np.uint8)

    @torch.inference_mode()
    def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
        """Generate an image guided by a sketch or by edges extracted from a photo.

        Args:
            input_img: Input image array; it is resized to 512x512.
            type_in: 'Sketch' to threshold the input directly, or 'Image' to
                extract an edge map with the pidinet network first.
            color_back: 'White' inverts a sketch (``255 - im``) before
                thresholding; only consulted for ``type_in == 'Sketch'``.
            prompt: Text prompt.
            neg_prompt: Negative prompt used as the unconditional conditioning.
            fix_sample: The string 'True' seeds the RNGs with 42.
            scale: Unconditional guidance scale.
            con_strength: Adapter strength fraction, converted via
                ``int((1 - con_strength) * 50)`` before sampling.
            base_model: Currently unused; the base model is never switched.

        Returns:
            ``[edge_map, generated_image]``.

        Raises:
            ValueError: If ``type_in`` is neither 'Sketch' nor 'Image'.
        """
        con_strength = self._control_steps(con_strength)
        if fix_sample == 'True':
            seed_everything(42)

        im = cv2.resize(input_img, (512, 512))
        if type_in == 'Sketch':
            if color_back == 'White':
                im = 255 - im
            im_edge = im.copy()
            # Only the first channel is used as the single-channel edge map.
            im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
            im = (im > 0.5).float()
        elif type_in == 'Image':
            im = img2tensor(im).unsqueeze(0) / 255.
            im = self.net_G(im.to(self.device))[-1]
            im = (im > 0.5).float()
            im_edge = tensor2img(im)
        else:
            raise ValueError(
                f"Unsupported type_in {type_in!r}; expected 'Sketch' or 'Image'")

        generated = self._generate(self.model_ad_sketch, im, 'sketch',
                                   prompt, neg_prompt, scale, con_strength)
        return [im_edge, generated]

    @torch.inference_mode()
    def process_pose(self, input_img, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
        """Generate an image guided by a keypose image.

        Args:
            input_img: Keypose image array; it is resized to 512x512.
            prompt: Text prompt.
            neg_prompt: Negative prompt used as the unconditional conditioning.
            fix_sample: The string 'True' seeds the RNGs with 42.
            scale: Unconditional guidance scale.
            con_strength: Adapter strength fraction, converted via
                ``int((1 - con_strength) * 50)`` before sampling.
            base_model: Currently unused; the base model is never switched.

        Returns:
            ``[pose_image, generated_image]``.
        """
        con_strength = self._control_steps(con_strength)
        if fix_sample == 'True':
            seed_everything(42)

        im = cv2.resize(input_img, (512, 512))
        pose = img2tensor(im, bgr2rgb=True, float32=True) / 255.
        pose = pose.unsqueeze(0)
        im_pose = tensor2img(pose)

        # The adapter must receive the tensor ``pose``; the resized NumPy
        # array ``im`` has no ``.to`` method.
        generated = self._generate(self.model_ad_pose, pose, 'pose',
                                   prompt, neg_prompt, scale, con_strength)
        return [im_pose, generated]