# T2I-Adapter / model.py
# File-viewer header captured with this source:
#   RamAnanth1 - "Update model.py" (e102afc); scan status: no virus; size: 10.7 kB
import os
import os.path as osp
import cv2
import numpy as np
import torch
from basicsr.utils import img2tensor, tensor2img
from pytorch_lightning import seed_everything
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.adapter import Adapter
from ldm.util import instantiate_from_config
from model_edge import pidinet
import gradio as gr
from omegaconf import OmegaConf
import pathlib
import random
import shlex
import subprocess
import sys
import mmcv
from mmdet.apis import inference_detector, init_detector
from mmpose.apis import (inference_top_down_pose_model, init_pose_model, process_mmdet_results, vis_pose_result)
# Pose-drawing tables. None of them is referenced in the visible part of this
# file; presumably they describe a 17-keypoint layout (pose_kpt_color has 17
# entries, skeleton and pose_link_color have 19 each) — TODO confirm.
skeleton = [
    list(pair)
    for pair in (
        (15, 13), (13, 11), (16, 14), (14, 12), (11, 12), (5, 11), (6, 12),
        (5, 6), (5, 7), (6, 8), (7, 9), (8, 10), (1, 2), (0, 1), (0, 2),
        (1, 3), (2, 4), (3, 5), (4, 6),
    )
]

# The three colours used by both colour tables below.
_BLUE = (51, 153, 255)
_GREEN = (0, 255, 0)
_ORANGE = (255, 128, 0)

# One colour per keypoint: five blue entries, then green/orange alternating.
# Each entry is materialised as its own list so no two rows share an object.
pose_kpt_color = [
    list(rgb) for rgb in (_BLUE,) * 5 + (_GREEN, _ORANGE) * 6
]

# One colour per skeleton link, in the same order as `skeleton`.
pose_link_color = [
    list(rgb)
    for rgb in (
        (_GREEN, _GREEN, _ORANGE, _ORANGE)
        + (_BLUE,) * 4
        + (_GREEN, _ORANGE) * 2
        + (_BLUE,) * 7
    )
]

# Make the local T2I-Adapter checkout importable.
sys.path.append('T2I-Adapter')

# Remote locations of the upstream configs and model weights.
config_path = 'https://github.com/TencentARC/T2I-Adapter/raw/main/configs/stable-diffusion/'
model_path = 'https://github.com/TencentARC/T2I-Adapter/raw/main/models/'
def load_model_from_config(config, ckpt, verbose=False):
    """Build the model described by ``config.model`` and load ``ckpt`` into it.

    Args:
        config: Configuration object whose ``model`` attribute is passed to
            ``instantiate_from_config``.
        ckpt: Path of the checkpoint file to read with ``torch.load``.
        verbose: When true, print the missing and unexpected state-dict keys
            reported by the non-strict load.

    Returns:
        The instantiated model in eval mode, moved to CUDA when a CUDA device
        is available and left on the CPU otherwise.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point this at checkpoints from trusted sources.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    # Checkpoints either wrap the weights under "state_dict" or are the
    # weights themselves.
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd

    model = instantiate_from_config(config.model)
    # strict=False: key mismatches are tolerated and only reported on request.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose:
        if missing:
            print("missing keys:")
            print(missing)
        if unexpected:
            print("unexpected keys:")
            print(unexpected)

    # Only move to the GPU when there is one, so CPU-only hosts do not fail
    # here.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model
class Model:
    """Loads two base diffusion checkpoints plus sketch/pose adapters and runs them.

    Construction downloads any missing weight files into ``model_dir``, loads
    both base checkpoints ('sd-v1-4.ckpt' and 'anything-v4.0-pruned.ckpt'),
    one PLMS sampler per checkpoint, the sketch and keypose adapters, and the
    pidinet network used to turn an image into a sketch-like map.

    ``model_config_path`` and ``use_lightweight`` are accepted for interface
    compatibility; the code in this class does not read them.
    """

    # Number of sampler steps; also the scale used to turn the fractional
    # ``con_strength`` argument into the integer handed to the sampler.
    _SAMPLING_STEPS = 50
    # Latent shape requested from the sampler.
    _LATENT_SHAPE = (4, 64, 64)

    _SD14_CKPT = 'sd-v1-4.ckpt'
    _ANYTHING_CKPT = 'anything-v4.0-pruned.ckpt'
    _SKETCH_ADAPTER = 't2iadapter_sketch_sd14v1.pth'
    _POSE_ADAPTER = 't2iadapter_keypose_sd14v1.pth'
    _PIDINET = 'table5_pidinet.pth'

    def __init__(self,
                 model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
                 model_dir: str = 'models',
                 use_lightweight: bool = True):
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model_dir = pathlib.Path(model_dir)
        self.download_models()

    @staticmethod
    def _fetch(url, destination):
        """Download ``url`` to ``destination`` with wget unless it already exists.

        The file is written under a temporary ``.part`` name and renamed on
        success, so an interrupted download is never mistaken for a complete
        file on the next start.

        Raises:
            subprocess.CalledProcessError: if wget exits with a non-zero status.
        """
        if destination.is_file() and destination.stat().st_size > 0:
            return destination
        partial = destination.with_name(destination.name + '.part')
        # Argument list (no shell) so the URL is never re-parsed.
        subprocess.run(['wget', url, '-O', str(partial)], check=True)
        partial.replace(destination)
        return destination

    @staticmethod
    def _control_steps(con_strength, steps=50):
        """Convert the fractional ``con_strength`` into the integer passed to the sampler."""
        return int((1 - con_strength) * steps)

    def _select_base(self, base_model):
        """Return the ``(model, sampler)`` pair for the requested checkpoint name.

        Any name other than 'sd-v1-4.ckpt' selects the "anything" checkpoint.
        The sampler must match the model: each sampler was constructed around
        one specific model instance.
        """
        if base_model == self._SD14_CKPT:
            return self.model, self.sampler
        return self.model_anything, self.sampler_anything

    def download_models(self) -> None:
        """Fetch missing weight files into ``self.model_dir`` and load every network.

        Sets ``model``, ``model_anything``, ``sampler``, ``sampler_anything``,
        ``model_ad_sketch``, ``model_ad_pose`` and ``net_G`` on the instance.
        Files already present in ``self.model_dir`` are reused rather than
        downloaded again.
        """
        self.model_dir.mkdir(exist_ok=True, parents=True)
        device = self.device

        config = OmegaConf.load("configs/stable-diffusion/test_sketch.yaml")
        config.model.params.cond_stage_config.params.device = str(device)

        remote_files = {
            self._SD14_CKPT: "https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt",
            self._ANYTHING_CKPT: "https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0-pruned.ckpt",
            self._SKETCH_ADAPTER: "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_sketch_sd14v1.pth",
            self._POSE_ADAPTER: "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_keypose_sd14v1.pth",
            self._PIDINET: model_path + "table5_pidinet.pth",
        }
        local = {
            name: str(self._fetch(url, self.model_dir / name))
            for name, url in remote_files.items()
        }

        # Base diffusion models, each with its own sampler.
        self.model = load_model_from_config(config, local[self._SD14_CKPT]).to(device)
        self.model_anything = load_model_from_config(config, local[self._ANYTHING_CKPT]).to(device)
        self.sampler = PLMSSampler(self.model)
        self.sampler_anything = PLMSSampler(self.model_anything)

        # NOTE(security): torch.load unpickles these downloaded files; they
        # must come from trusted sources. map_location='cpu' lets the files
        # load on hosts without CUDA; load_state_dict copies the weights onto
        # whatever device the module already lives on.
        self.model_ad_sketch = Adapter(channels=[320, 640, 1280, 1280], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
        self.model_ad_sketch.load_state_dict(
            torch.load(local[self._SKETCH_ADAPTER], map_location='cpu'))

        self.model_ad_pose = Adapter(cin=3 * 64, channels=[320, 640, 1280, 1280], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
        self.model_ad_pose.load_state_dict(
            torch.load(local[self._POSE_ADAPTER], map_location='cpu'))

        # Kept on the instance because process_sketch needs it for
        # type_in == 'Image'.
        self.net_G = pidinet()
        edge_weights = torch.load(local[self._PIDINET], map_location='cpu')['state_dict']
        # The checkpoint keys carry a 'module.' prefix that the bare network
        # does not use.
        self.net_G.load_state_dict(
            {key.replace('module.', ''): value for key, value in edge_weights.items()})
        self.net_G.to(device)

    def _generate(self, base_model, adapter, condition, prompt, neg_prompt, scale, con_strength):
        """Run adapter-guided sampling and return the decoded image.

        Only called from the ``process_*`` methods, which already run under
        ``torch.inference_mode``.

        Args:
            base_model: Checkpoint name used to pick the model/sampler pair.
            adapter: Adapter network applied to ``condition``.
            condition: Tensor fed to ``adapter`` after moving it to the device.
            prompt: Text for the positive conditioning.
            neg_prompt: Text for the unconditional (negative) conditioning.
            scale: Value passed as ``unconditional_guidance_scale``.
            con_strength: Fraction converted by ``_control_steps`` before being
                passed to the sampler.

        Returns:
            A ``uint8`` numpy array holding the first decoded sample with the
            channel axis moved last.
        """
        model, sampler = self._select_base(base_model)
        c = model.get_learned_conditioning([prompt])
        nc = model.get_learned_conditioning([neg_prompt])

        # Extract condition features.
        features_adapter = adapter(condition.to(self.device))

        # Sampling.
        samples, _ = sampler.sample(
            S=self._SAMPLING_STEPS,
            conditioning=c,
            batch_size=1,
            shape=list(self._LATENT_SHAPE),
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=nc,
            eta=0.0,
            x_T=None,
            features_adapter1=features_adapter,
            mode='sketch',
            con_strength=self._control_steps(con_strength, self._SAMPLING_STEPS))

        decoded = model.decode_first_stage(samples)
        # Map from [-1, 1] to [0, 1] before converting to 8-bit.
        decoded = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
        image = decoded.permute(0, 2, 3, 1).cpu().numpy()[0]
        return (255. * image).astype(np.uint8)

    @torch.inference_mode()
    def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
        """Generate an image guided by a sketch, or by a map derived from an image.

        Args:
            input_img: Input array; resized to 512x512 with ``cv2.resize``.
            type_in: 'Sketch' to use ``input_img`` directly, or 'Image' to
                first run it through ``self.net_G``.
            color_back: For sketches, 'White' inverts the input
                (``255 - im``) before thresholding.
            prompt: Positive text prompt.
            neg_prompt: Negative text prompt.
            fix_sample: The string 'True' seeds all RNGs with 42.
            scale: Guidance scale passed to the sampler.
            con_strength: Fraction converted to sampler steps.
            base_model: 'sd-v1-4.ckpt' or any other name for the "anything"
                checkpoint.

        Returns:
            ``[im_edge, generated]``: the map used as the condition and the
            generated ``uint8`` image.

        Raises:
            ValueError: if ``type_in`` is neither 'Sketch' nor 'Image'.
        """
        if fix_sample == 'True':
            seed_everything(42)

        im = cv2.resize(input_img, (512, 512))
        if type_in == 'Sketch':
            if color_back == 'White':
                im = 255 - im
            im_edge = im.copy()
            # Keep a single channel and binarise it.
            im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
            im = (im > 0.5).float()
        elif type_in == 'Image':
            im = img2tensor(im).unsqueeze(0) / 255.
            # Last output of the network, binarised the same way as a sketch.
            im = self.net_G(im.to(self.device))[-1]
            im = (im > 0.5).float()
            im_edge = tensor2img(im)
        else:
            raise ValueError(
                f"type_in must be 'Sketch' or 'Image', got {type_in!r}")

        generated = self._generate(
            base_model, self.model_ad_sketch, im,
            prompt, neg_prompt, scale, con_strength)
        return [im_edge, generated]

    @torch.inference_mode()
    def process_pose(self, input_img, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
        """Generate an image guided by a pose image.

        Args:
            input_img: Pose image array; resized to 512x512 with ``cv2.resize``.
            prompt: Positive text prompt.
            neg_prompt: Negative text prompt.
            fix_sample: The string 'True' seeds all RNGs with 42.
            scale: Guidance scale passed to the sampler.
            con_strength: Fraction converted to sampler steps.
            base_model: 'sd-v1-4.ckpt' or any other name for the "anything"
                checkpoint.

        Returns:
            ``[im_pose, generated]``: the pose image as converted back by
            ``tensor2img`` and the generated ``uint8`` image.
        """
        if fix_sample == 'True':
            seed_everything(42)

        im = cv2.resize(input_img, (512, 512))
        pose = img2tensor(im, bgr2rgb=True, float32=True) / 255.
        pose = pose.unsqueeze(0)
        im_pose = tensor2img(pose)

        generated = self._generate(
            base_model, self.model_ad_pose, pose,
            prompt, neg_prompt, scale, con_strength)
        return [im_pose, generated]