RamAnanth1 commited on
Commit
c6e2b5c
1 Parent(s): 77e73c7

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +64 -3
model.py CHANGED
@@ -68,25 +68,31 @@ class Model:
68
 
69
  base_model_file = "https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt"
70
  sketch_adapter_file = "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_sketch_sd14v1.pth"
 
71
  pidinet_file = model_path+"table5_pidinet.pth"
72
  clip_file = "https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/*"
73
 
74
  subprocess.run(shlex.split(f'wget {base_model_file} -O models/sd-v1-4.ckpt'))
75
  subprocess.run(shlex.split(f'wget {sketch_adapter_file} -O models/t2iadapter_sketch_sd14v1.pth'))
 
76
  subprocess.run(shlex.split(f'wget {pidinet_file} -O models/table5_pidinet.pth'))
77
 
78
 
79
  self.model = load_model_from_config(config, "models/sd-v1-4.ckpt").to(device)
80
  current_base = 'sd-v1-4.ckpt'
81
- self.model_ad = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
82
- self.model_ad.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth"))
83
  net_G = pidinet()
84
  ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict']
85
  net_G.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
86
  net_G.to(device)
87
- self.sampler = PLMSSampler(self.model)
88
  save_memory=True
89
 
 
 
 
 
90
  @torch.inference_mode()
91
  def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
92
  global current_base
@@ -152,3 +158,58 @@ class Model:
152
  x_samples_ddim = x_samples_ddim.astype(np.uint8)
153
 
154
  return [im_edge, x_samples_ddim]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  base_model_file = "https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt"
70
  sketch_adapter_file = "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_sketch_sd14v1.pth"
71
+ pose_adapter_file = "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_keypose_sd14v1.pth"
72
  pidinet_file = model_path+"table5_pidinet.pth"
73
  clip_file = "https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/*"
74
 
75
  subprocess.run(shlex.split(f'wget {base_model_file} -O models/sd-v1-4.ckpt'))
76
  subprocess.run(shlex.split(f'wget {sketch_adapter_file} -O models/t2iadapter_sketch_sd14v1.pth'))
77
+ subprocess.run(shlex.split(f'wget {pose_adapter_file} -O models/t2iadapter_keypose_sd14v1.pth'))
78
  subprocess.run(shlex.split(f'wget {pidinet_file} -O models/table5_pidinet.pth'))
79
 
80
 
81
  self.model = load_model_from_config(config, "models/sd-v1-4.ckpt").to(device)
82
  current_base = 'sd-v1-4.ckpt'
83
+ self.model_ad_sketch = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
84
+ self.model_ad_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth"))
85
  net_G = pidinet()
86
  ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict']
87
  net_G.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
88
  net_G.to(device)
89
+ self.sampler= PLMSSampler(self.model)
90
  save_memory=True
91
 
92
+ self.model_ad_pose = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
93
+ self.model_ad_pose.load_state_dict(torch.load("models/t2iadapter_keypose_sd14v1.pth"))
94
+
95
+
96
  @torch.inference_mode()
97
  def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
98
  global current_base
 
158
  x_samples_ddim = x_samples_ddim.astype(np.uint8)
159
 
160
  return [im_edge, x_samples_ddim]
161
+
162
+ @torch.inference_mode()
163
+ def process_pose(self, input_img, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
164
+ global current_base
165
+ device = 'cuda'
166
+ # if current_base != base_model:
167
+ # ckpt = os.path.join("models", base_model)
168
+ # pl_sd = torch.load(ckpt, map_location="cpu")
169
+ # if "state_dict" in pl_sd:
170
+ # sd = pl_sd["state_dict"]
171
+ # else:
172
+ # sd = pl_sd
173
+ # model.load_state_dict(sd, strict=False) #load_model_from_config(config, os.path.join("models", base_model)).to(device)
174
+ # current_base = base_model
175
+ con_strength = int((1-con_strength)*50)
176
+ if fix_sample == 'True':
177
+ seed_everything(42)
178
+
179
+ im = cv2.resize(input_img,(512,512))
180
+ pose = img2tensor(im, bgr2rgb=True, float32=True)/255.
181
+ pose = pose.unsqueeze(0)
182
+
183
+ im_pose = tensor2img(pose)
184
+
185
+ c = self.model.get_learned_conditioning([prompt])
186
+ nc = self.model.get_learned_conditioning([neg_prompt])
187
+
188
+ with torch.no_grad():
189
+ # extract condition features
190
+ features_adapter = self.model_ad_pose(im.to(device))
191
+
192
+ shape = [4, 64, 64]
193
+
194
+ # sampling
195
+ samples_ddim, _ = self.sampler.sample(S=50,
196
+ conditioning=c,
197
+ batch_size=1,
198
+ shape=shape,
199
+ verbose=False,
200
+ unconditional_guidance_scale=scale,
201
+ unconditional_conditioning=nc,
202
+ eta=0.0,
203
+ x_T=None,
204
+ features_adapter1=features_adapter,
205
+ mode = 'pose',
206
+ con_strength = con_strength)
207
+
208
+ x_samples_ddim = self.model.decode_first_stage(samples_ddim)
209
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
210
+ x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).cpu().numpy()[0]
211
+ x_samples_ddim = 255.*x_samples_ddim
212
+ x_samples_ddim = x_samples_ddim.astype(np.uint8)
213
+
214
+ return [im_pose, x_samples_ddim]
215
+