RamAnanth1 committed on
Commit
47163f5
1 Parent(s): 98d37ce

Update model.py

Files changed (1)
  1. model.py +111 -1
model.py CHANGED
@@ -13,6 +13,16 @@ from model_edge import pidinet
 import gradio as gr
 from omegaconf import OmegaConf
 
+import pathlib
+import random
+import shlex
+import subprocess
+import sys
+
+sys.path.append('T2I-Adapter')
+
+config_path = 'https://github.com/TencentARC/T2I-Adapter/raw/main/configs/stable-diffusion/'
+model_path = 'https://github.com/TencentARC/T2I-Adapter/raw/main/models/'
 
 def load_model_from_config(config, ckpt, verbose=False):
     print(f"Loading model from {ckpt}")
@@ -34,4 +44,104 @@ def load_model_from_config(config, ckpt, verbose=False):
 
     model.cuda()
     model.eval()
-    return model
+    return model
+
+class Model:
+    def __init__(self,
+                 model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
+                 model_dir: str = 'models',
+                 use_lightweight: bool = True):
+        self.device = torch.device(
+            'cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.model_dir = pathlib.Path(model_dir)
+
+        self.download_models()
+
+    def download_models(self) -> None:
+        self.model_dir.mkdir(exist_ok=True, parents=True)
+        device = str(self.device)
+        # Fetch the sketch-adapter config and point its conditioning stage at the target device.
+        subprocess.run(shlex.split(f'wget {config_path + "test_sketch.yaml"} -O config_sketch.yaml'))
+        config = OmegaConf.load('config_sketch.yaml')
+        config.model.params.cond_stage_config.params.device = device
+
+        # Download the SD 1.4 base checkpoint, the sketch adapter, and the PiDiNet edge detector.
+        subprocess.run(shlex.split(f'wget {model_path + "sd-v1-4.ckpt"} -O models/sd-v1-4.ckpt'))
+        subprocess.run(shlex.split(f'wget {model_path + "t2iadapter_sketch_sd14v1.pth"} -O models/t2iadapter_sketch_sd14v1.pth'))
+        subprocess.run(shlex.split(f'wget {model_path + "table5_pidinet.pth"} -O models/table5_pidinet.pth'))
+
+        self.model = load_model_from_config(config, 'models/sd-v1-4.ckpt').to(device)
+        self.current_base = 'sd-v1-4.ckpt'
+        self.model_ad = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
+        self.model_ad.load_state_dict(torch.load('models/t2iadapter_sketch_sd14v1.pth'))
+        self.net_G = pidinet()
+        ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict']
+        # The PiDiNet checkpoint was saved from a DataParallel model; strip the 'module.' prefix.
+        self.net_G.load_state_dict({k.replace('module.', ''): v for k, v in ckp.items()})
+        self.net_G.to(device)
+        self.sampler = PLMSSampler(self.model)
+        self.save_memory = True
+
+    @torch.inference_mode()
+    def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
+        # Hot-swap the base diffusion checkpoint if the user selected a different one.
+        if self.current_base != base_model:
+            ckpt = os.path.join('models', base_model)
+            pl_sd = torch.load(ckpt, map_location='cpu')
+            sd = pl_sd['state_dict'] if 'state_dict' in pl_sd else pl_sd
+            self.model.load_state_dict(sd, strict=False)
+            self.current_base = base_model
+        # Convert the [0, 1] strength slider into a PLMS step index (50 steps total).
+        con_strength = int((1 - con_strength) * 50)
+        if fix_sample == 'True':
+            seed_everything(42)
+
+        im = cv2.resize(input_img, (512, 512))
+
+        if type_in == 'Sketch':
+            if color_back == 'White':
+                im = 255 - im  # invert so strokes are white on a black background
+            im_edge = im.copy()
+            im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
+            im = (im > 0.5).float()  # binarize the sketch
+        elif type_in == 'Image':
+            im = img2tensor(im).unsqueeze(0) / 255.
+            im = self.net_G(im.to(self.device))[-1]  # extract edges with PiDiNet
+            im = (im > 0.5).float()
+            im_edge = tensor2img(im)
+
+        c = self.model.get_learned_conditioning([prompt])
+        nc = self.model.get_learned_conditioning([neg_prompt])
+
+        # extract condition features
+        features_adapter = self.model_ad(im.to(self.device))
+
+        shape = [4, 64, 64]
+
+        # sampling
+        samples_ddim, _ = self.sampler.sample(S=50,
+                                              conditioning=c,
+                                              batch_size=1,
+                                              shape=shape,
+                                              verbose=False,
+                                              unconditional_guidance_scale=scale,
+                                              unconditional_conditioning=nc,
+                                              eta=0.0,
+                                              x_T=None,
+                                              features_adapter1=features_adapter,
+                                              mode='sketch',
+                                              con_strength=con_strength)
+
+        x_samples_ddim = self.model.decode_first_stage(samples_ddim)
+        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+        x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()[0]
+        x_samples_ddim = (255. * x_samples_ddim).astype(np.uint8)
+
+        return [im_edge, x_samples_ddim]
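
For reference, a minimal sketch of how the reworked Model class might be hooked into the Gradio demo. Only the Model constructor and the process_sketch signature come from this commit; the component choices, labels, and default values below are illustrative assumptions:

# Hypothetical app wiring (not part of this commit); assumes Gradio 3.x components.
import gradio as gr

from model import Model

model = Model()

demo = gr.Interface(
    fn=model.process_sketch,
    inputs=[
        gr.Image(label='Input sketch or image'),                    # input_img
        gr.Radio(['Sketch', 'Image'], label='Input type'),          # type_in
        gr.Radio(['White', 'Black'], label='Sketch background'),    # color_back
        gr.Textbox(label='Prompt'),                                 # prompt
        gr.Textbox(label='Negative prompt'),                        # neg_prompt
        gr.Radio(['True', 'False'], label='Fix random seed'),       # fix_sample
        gr.Slider(0.1, 30.0, value=9.0, label='Guidance scale'),    # scale
        gr.Slider(0.0, 1.0, value=1.0, label='Control strength'),   # con_strength
        gr.Textbox(value='sd-v1-4.ckpt', label='Base checkpoint'),  # base_model
    ],
    outputs=[gr.Image(label='Edge map'), gr.Image(label='Result')],
)
demo.launch()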