RamAnanth1 committed
Commit
cb8d2af
Parent: 600ccab

Add seg adapter

Files changed (1): model.py (+69 −2)
model.py CHANGED
@@ -177,7 +177,8 @@ class Model:
         base_model_file = "https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt"
         base_model_file_anything = "https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0-pruned.ckpt"
         sketch_adapter_file = "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_sketch_sd14v1.pth"
-        pose_adapter_file = "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_keypose_sd14v1.pth"
+        pose_adapter_file = "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_keypose_sd14v1.pth"
+        seg_adapter_file = "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_seg_sd14v1.pth"
         pidinet_file = model_path+"table5_pidinet.pth"
         clip_file = "https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/*"

@@ -185,6 +186,7 @@ class Model:
         subprocess.run(shlex.split(f'wget {base_model_file_anything} -O models/anything-v4.0-pruned.ckpt'))
         subprocess.run(shlex.split(f'wget {sketch_adapter_file} -O models/t2iadapter_sketch_sd14v1.pth'))
         subprocess.run(shlex.split(f'wget {pose_adapter_file} -O models/t2iadapter_keypose_sd14v1.pth'))
+        subprocess.run(shlex.split(f'wget {seg_adapter_file} -O models/t2iadapter_seg_sd14v1.pth'))
         subprocess.run(shlex.split(f'wget {pidinet_file} -O models/table5_pidinet.pth'))
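The checkpoints are fetched with wget via subprocess. A minimal alternative sketch, assuming the huggingface_hub client is available (it is not a dependency of this commit), resolves the same segmentation-adapter checkpoint through the Hub cache:

    # Sketch only: huggingface_hub is an assumption, not part of this commit.
    from huggingface_hub import hf_hub_download

    # Downloads (or reuses from the local cache) the same file the wget call
    # above fetches, and returns its local path.
    seg_ckpt_path = hf_hub_download(
        repo_id="TencentARC/T2I-Adapter",
        filename="models/t2iadapter_seg_sd14v1.pth",
    )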
@@ -204,6 +206,9 @@ class Model:
         self.model_ad_pose = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
         self.model_ad_pose.load_state_dict(torch.load("models/t2iadapter_keypose_sd14v1.pth"))

+        self.model_ad_seg = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
+        self.model_ad_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth"))
+

     @torch.inference_mode()
     def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
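As committed, torch.load restores the adapter weights onto whatever device is recorded in the checkpoint. A hedged variant for CPU-only or mixed-device setups (an assumption on my part, not what the commit does) pins the destination explicitly:

    # Assumption: `device` is the torch device the adapter will actually run on.
    state = torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device)
    self.model_ad_seg.load_state_dict(state)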
 
@@ -370,4 +375,66 @@ class Model:
         x_samples_ddim = 255.*x_samples_ddim
         x_samples_ddim = x_samples_ddim.astype(np.uint8)

-        return [im_pose[:,:,::-1].astype(np.uint8), x_samples_ddim]
+        return [im_pose[:,:,::-1].astype(np.uint8), x_samples_ddim]
+
+    @torch.inference_mode()
+    def process_seg(self, input_img, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
+        global current_base
+        device = 'cuda'
+        if base_model == 'sd-v1-4.ckpt':
+            model = self.model
+            sampler = self.sampler
+        else:
+            model = self.model_anything
+            sampler = self.sampler_anything
+        # if current_base != base_model:
+        #     ckpt = os.path.join("models", base_model)
+        #     pl_sd = torch.load(ckpt, map_location="cpu")
+        #     if "state_dict" in pl_sd:
+        #         sd = pl_sd["state_dict"]
+        #     else:
+        #         sd = pl_sd
+        #     model.load_state_dict(sd, strict=False)  # load_model_from_config(config, os.path.join("models", base_model)).to(device)
+        #     current_base = base_model
+        con_strength = int((1 - con_strength) * 50)
+        if fix_sample == 'True':
+            seed_everything(42)
+
+        im = cv2.resize(input_img, (512, 512))
+        mask = im.copy()
+        mask = img2tensor(mask, bgr2rgb=True, float32=True) / 255.
+        mask = mask.unsqueeze(0)
+
+        im_mask = tensor2img(mask)
+
+        c = model.get_learned_conditioning([prompt])
+        nc = model.get_learned_conditioning([neg_prompt])
+
+        with torch.no_grad():
+            # extract condition features
+            features_adapter = self.model_ad_seg(mask.to(device))
+
+        shape = [4, 64, 64]
+
+        # sampling
+        samples_ddim, _ = sampler.sample(S=50,
+                                         conditioning=c,
+                                         batch_size=1,
+                                         shape=shape,
+                                         verbose=False,
+                                         unconditional_guidance_scale=scale,
+                                         unconditional_conditioning=nc,
+                                         eta=0.0,
+                                         x_T=None,
+                                         features_adapter1=features_adapter,
+                                         mode='mask',
+                                         con_strength=con_strength)
+
+        x_samples_ddim = model.decode_first_stage(samples_ddim)
+        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).cpu().numpy()[0]
+        x_samples_ddim = 255. * x_samples_ddim
+        x_samples_ddim = x_samples_ddim.astype(np.uint8)
+
+        return [im_mask, x_samples_ddim]  # im_mask, not the undefined im_edge
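
For reference, a hypothetical driver for the new method (none of this is in the commit; the file name, prompt, and parameter values are illustrative only):

    # Hypothetical usage; assumes Model() has already downloaded and loaded all weights.
    import cv2

    m = Model()
    seg_map = cv2.imread("seg_map.png")  # BGR color-coded segmentation map; resized to 512x512 inside process_seg

    mask_vis, generated = m.process_seg(
        input_img=seg_map,
        prompt="a photo of a living room",  # illustrative prompt
        neg_prompt="",
        fix_sample='True',                  # string flag, matching the committed signature
        scale=7.5,                          # classifier-free guidance scale
        con_strength=0.4,                   # mapped internally to int((1 - x) * 50)
        base_model='sd-v1-4.ckpt',
    )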