Geonmo committed on
Commit 14e037d
1 Parent(s): 62ab537

support CPU

Files changed (1)
  1. app.py +37 -24
app.py CHANGED
@@ -162,7 +162,7 @@ class GraphitPipeline(StableDiffusionInstructPix2PixPipeline):
 
         # 2. Encode input prompt
         cond_embeds = torch.cat([image_cond_embeds, negative_image_cond_embeds])
-        cond_embeds = einops.repeat(cond_embeds, 'b n d -> (b num) n d', num=num_images_per_prompt).to(torch.float16)
+        cond_embeds = einops.repeat(cond_embeds, 'b n d -> (b num) n d', num=num_images_per_prompt) #.to(torch_dtype)
         prompt_embeds = cond_embeds
 
         # 3. Preprocess image
@@ -312,38 +312,43 @@ class CustomRealESRGAN(RealESRGAN):
 
 def build_models(args):
     # Load scheduler, tokenizer and models.
+    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+    torch_dtype = torch.float16 if 'cuda' in device else torch.float32
 
     model_path = 'navervision/Graphit-SD'
     unet = UNet2DConditionModel.from_pretrained(
-        model_path, torch_dtype=torch.float16,
+        model_path, torch_dtype=torch_dtype,
     )
 
     vae_name = 'stabilityai/sd-vae-ft-ema'
-    vae = AutoencoderKL.from_pretrained(vae_name, torch_dtype=torch.float16)
+    vae = AutoencoderKL.from_pretrained(vae_name, torch_dtype=torch_dtype)
 
     model_name = 'timbrooks/instruct-pix2pix'
-    pipe = GraphitPipeline.from_pretrained(model_name, torch_dtype=torch.float16, safety_checker=None,
+    pipe = GraphitPipeline.from_pretrained(model_name, torch_dtype=torch_dtype, safety_checker=None,
                                            unet = unet,
                                            vae = vae,
                                            )
-    pipe = pipe.to('cuda:0')
+    pipe = pipe.to(device)
 
     ## load CompoDiff
     compodiff_model, clip_model, clip_preprocess, clip_tokenizer = compodiff.build_model()
-    compodiff_model, clip_model = compodiff_model.to('cuda:0'), clip_model.to('cuda:0')
+    compodiff_model, clip_model = compodiff_model.to(device), clip_model.to(device)
+
+    if device != 'cpu':
+        clip_model = clip_model.half()
 
     ## load third-party models
     model_name = 'Intel/dpt-large'
     depth_preprocess = DPTFeatureExtractor.from_pretrained(model_name)
-    depth_predictor = DPTForDepthEstimation.from_pretrained(model_name, torch_dtype=torch.float16)
-    depth_predictor = depth_predictor.to('cuda:0')
+    depth_predictor = DPTForDepthEstimation.from_pretrained(model_name, torch_dtype=torch_dtype)
+    depth_predictor = depth_predictor.to(device)
 
     if not os.path.exists('./third_party/remover_fast.pth'):
         model_file_url = hf_hub_url(repo_id='Geonmo/remover_fast', filename='remover_fast.pth')
         cached_download(model_file_url, cache_dir='./third_party', force_filename='remover_fast.pth')
-    remover = Remover(fast=True, jit=False, device='cuda:0', ckpt='./third_party/remover_fast.pth')
+    remover = Remover(fast=True, jit=False, device=device, ckpt='./third_party/remover_fast.pth')
 
-    sr_model = CustomRealESRGAN('cuda:0', scale=2)
+    sr_model = CustomRealESRGAN(device, scale=2)
     sr_model.load_weights('./third_party/RealESRGAN_x2.pth', download=True)
 
     dataset = datasets.load_dataset("FredZhang7/stable-diffusion-prompts-2.47M")
@@ -361,28 +366,31 @@ def build_models(args):
         'remover': remover,
         'sr_model': sr_model,
         'prompt_candidates': prompts,
+        'device': device,
+        'torch_dtype': torch_dtype,
     }
     return model_dict
 
 
 def predict_compodiff(image, text_input, negative_text, cfg_image_scale, cfg_text_scale, mask, random_seed):
+    device = model_dict['device']
     text_token_dict = model_dict['clip_tokenizer'](text=text_input, return_tensors='pt', padding='max_length', truncation=True)
-    text_tokens, text_attention_mask = text_token_dict['input_ids'].to('cuda:0'), text_token_dict['attention_mask'].to('cuda:0')
+    text_tokens, text_attention_mask = text_token_dict['input_ids'].to(device), text_token_dict['attention_mask'].to(device)
 
     negative_text_token_dict = model_dict['clip_tokenizer'](text=negative_text, return_tensors='pt', padding='max_length', truncation=True)
-    negative_text_tokens, negative_text_attention_mask = negative_text_token_dict['input_ids'].to('cuda:0'), text_token_dict['attention_mask'].to('cuda:0')
+    negative_text_tokens, negative_text_attention_mask = negative_text_token_dict['input_ids'].to(device), text_token_dict['attention_mask'].to(device)
 
     with torch.no_grad():
         if image is None:
-            image_cond = torch.zeros([1,1,768]).to('cuda:0')
-            mask = torch.tensor(np.zeros([64, 64], dtype='float32')).to('cuda:0').unsqueeze(0)
+            image_cond = torch.zeros([1,1,768]).to(device)
+            mask = torch.tensor(np.zeros([64, 64], dtype='float32')).to(device).unsqueeze(0)
         else:
             image_source = image.resize((512, 512))
-            image_source = model_dict['clip_preprocess'](image_source, return_tensors='pt')['pixel_values'].to('cuda:0')
+            image_source = model_dict['clip_preprocess'](image_source, return_tensors='pt')['pixel_values'].to(device)
             mask = mask.resize((512, 512))
             mask = model_dict['clip_preprocess'](mask, do_normalize=False, return_tensors='pt')['pixel_values']
            mask = mask[:,:1,:,:]
-            mask = (mask > 0.5).float().to('cuda:0')
+            mask = (mask > 0.5).float().to(device)
            image_source = image_source * (1 - mask)
            image_cond = model_dict['clip_model'].encode_images(image_source)
            mask = transforms.Resize([64, 64])(mask)[:,0,:,:]
@@ -396,7 +404,9 @@ def predict_compodiff(image, text_input, negative_text, cfg_image_scale, cfg_tex
 
 
 def generate_depth_map(image, height, width):
-    depth_inputs = {k: v.to('cuda:0', dtype=torch.float16) for k, v in model_dict['depth_preprocess'](images=image, return_tensors='pt').items()}
+    device = model_dict['device']
+    torch_dtype = model_dict['torch_dtype']
+    depth_inputs = {k: v.to(device, dtype=torch_dtype) for k, v in model_dict['depth_preprocess'](images=image, return_tensors='pt').items()}
     depth_map = model_dict['depth_predictor'](**depth_inputs).predicted_depth.unsqueeze(1)
     depth_min = torch.amin(depth_map, dim=[1,2,3], keepdim=True)
     depth_max = torch.amax(depth_map, dim=[1,2,3], keepdim=True)
@@ -421,6 +431,9 @@ def generate_color(image, compactness=30, n_segments=100, thresh=35, blur_kernel
 
 @torch.no_grad()
 def generate(image_source, image_reference, text_input, negative_prompt, steps, random_seed, cfg_image_scale, cfg_text_scale, cfg_image_space_scale, cfg_image_reference_mix_weight, cfg_image_source_mix_weight, mask_scale, use_edge, t2i_height, t2i_width, do_sr, mode):
+    device = model_dict['device']
+    torch_dtype = model_dict['torch_dtype']
+
     text_input = text_input.lower()
     if negative_prompt == '':
         print('running without a negative prompt')
@@ -513,10 +526,10 @@ def generate(image_source, image_reference, text_input, negative_prompt, steps,
     # do reference first
     if image_reference is not None:
         image_cond_reference = ImageOps.exif_transpose(image_reference)
-        image_cond_reference = model_dict['clip_preprocess'](image_cond_reference, return_tensors='pt')['pixel_values'].to('cuda:0')
+        image_cond_reference = model_dict['clip_preprocess'](image_cond_reference, return_tensors='pt')['pixel_values'].to(device)
         image_cond_reference = model_dict['clip_model'].encode_images(image_cond_reference)
     else:
-        image_cond_reference = torch.zeros([1, 1, 768]).to(torch.float16).to('cuda:0')
+        image_cond_reference = torch.zeros([1, 1, 768]).to(torch_dtype).to(device)
 
     # do source or knn
     image_cond_source = None
@@ -530,14 +543,14 @@ def generate(image_source, image_reference, text_input, negative_prompt, steps,
             image_cond, image_cond_source = predict_compodiff(None, text_input, negative_prompt, cfg_image_scale, cfg_text_scale, mask=mask_pil, random_seed=random_seed)
         else:
             image_cond, image_cond_source = predict_compodiff(image_source, text_input, negative_prompt, cfg_image_scale, cfg_text_scale, mask=mask_pil, random_seed=random_seed)
-        image_cond = image_cond.to(torch.float16).to('cuda:0')
-        image_cond_source = image_cond_source.to(torch.float16).to('cuda:0')
+        image_cond = image_cond.to(torch_dtype).to(device)
+        image_cond_source = image_cond_source.to(torch_dtype).to(device)
     else:
-        image_cond = torch.zeros([1, 1, 768]).to(torch.float16).to('cuda:0')
+        image_cond = torch.zeros([1, 1, 768]).to(torch_dtype).to(device)
 
     if image_cond_source is None and mode != 't2i':
         image_cond_source = image_source.resize((512, 512))
-        image_cond_source = model_dict['clip_preprocess'](image_cond_source, return_tensors='pt')['pixel_values'].to('cuda:0')
+        image_cond_source = model_dict['clip_preprocess'](image_cond_source, return_tensors='pt')['pixel_values'].to(device)
         image_cond_source = model_dict['clip_model'].encode_images(image_cond_source)
 
     if cfg_image_reference_mix_weight > 0.0 and torch.sum(image_cond_reference).item() != 0.0:
@@ -551,7 +564,7 @@ def generate(image_source, image_reference, text_input, negative_prompt, steps,
 
     if negative_prompt != '':
        negative_image_cond, _ = predict_compodiff(None, negative_prompt, '', cfg_image_scale, cfg_text_scale, mask=mask_pil, random_seed=random_seed)
-        negative_image_cond = negative_image_cond.to(torch.float16).to('cuda:0')
+        negative_image_cond = negative_image_cond.to(torch_dtype).to(device)
    else:
        negative_image_cond = torch.zeros_like(image_cond)
 
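
The whole commit reduces to one pattern: decide the device and dtype once, then reuse them for every from_pretrained(...) and .to(...) call instead of hard-coding 'cuda:0' and torch.float16. Keeping float32 on CPU matters because many PyTorch CPU kernels do not implement half precision. A minimal, self-contained sketch of that pattern follows; the toy nn.Linear model and the helper names (pick_device_and_dtype, build_toy_models, run) are illustrative stand-ins, not part of app.py.

import torch
from torch import nn


def pick_device_and_dtype():
    # Same fallback rule the commit adds to build_models():
    # fp16 only when a GPU is present; CPU math stays in fp32.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    torch_dtype = torch.float16 if 'cuda' in device else torch.float32
    return device, torch_dtype


def build_toy_models():
    device, torch_dtype = pick_device_and_dtype()
    # Stand-in for the real pipeline components; the commit moves each of them
    # with .to(device) and torch_dtype instead of .to('cuda:0') and torch.float16.
    encoder = nn.Linear(768, 768).to(device=device, dtype=torch_dtype)
    # Stash device/dtype next to the models (as the commit does in model_dict)
    # so inference helpers can move their inputs without hard-coding a device.
    return {'encoder': encoder, 'device': device, 'torch_dtype': torch_dtype}


@torch.no_grad()
def run(model_dict, x):
    # Inputs follow the models: same device, same dtype.
    x = x.to(model_dict['device'], dtype=model_dict['torch_dtype'])
    return model_dict['encoder'](x)


if __name__ == '__main__':
    out = run(build_toy_models(), torch.randn(1, 768))
    print(out.dtype, out.device)  # float16 / cuda:0 on GPU, float32 / cpu otherwise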