Commit: support CPU
app.py CHANGED
@@ -162,7 +162,7 @@ class GraphitPipeline(StableDiffusionInstructPix2PixPipeline):
 
         # 2. Encode input prompt
         cond_embeds = torch.cat([image_cond_embeds, negative_image_cond_embeds])
-        cond_embeds = einops.repeat(cond_embeds, 'b n d -> (b num) n d', num=num_images_per_prompt)
+        cond_embeds = einops.repeat(cond_embeds, 'b n d -> (b num) n d', num=num_images_per_prompt) #.to(torch_dtype)
         prompt_embeds = cond_embeds
 
         # 3. Preprocess image
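
For context, the line touched above repeats the concatenated conditional and negative image embeddings once per requested output image so the embedding batch matches the latent batch. A minimal, self-contained sketch of the einops.repeat pattern with dummy shapes in place of the pipeline's real tensors:

import torch
import einops

# Stand-in for the pipeline's CLIP-style embeddings: cond + negative, 1 token, 768 dims.
cond_embeds = torch.randn(2, 1, 768)

# Duplicate each row once per requested image, keeping the cond/negative rows grouped.
num_images_per_prompt = 4
cond_embeds = einops.repeat(cond_embeds, 'b n d -> (b num) n d', num=num_images_per_prompt)

print(cond_embeds.shape)  # torch.Size([8, 1, 768])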
@@ -312,38 +312,43 @@ class CustomRealESRGAN(RealESRGAN):
 
 def build_models(args):
     # Load scheduler, tokenizer and models.
+    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+    torch_dtype = torch.float16 if 'cuda' in device else torch.float32
 
     model_path = 'navervision/Graphit-SD'
     unet = UNet2DConditionModel.from_pretrained(
-        model_path, torch_dtype=
+        model_path, torch_dtype=torch_dtype,
     )
 
     vae_name = 'stabilityai/sd-vae-ft-ema'
-    vae = AutoencoderKL.from_pretrained(vae_name, torch_dtype=
+    vae = AutoencoderKL.from_pretrained(vae_name, torch_dtype=torch_dtype)
 
     model_name = 'timbrooks/instruct-pix2pix'
-    pipe = GraphitPipeline.from_pretrained(model_name, torch_dtype=
+    pipe = GraphitPipeline.from_pretrained(model_name, torch_dtype=torch_dtype, safety_checker=None,
         unet = unet,
         vae = vae,
     )
-    pipe = pipe.to(
+    pipe = pipe.to(device)
 
     ## load CompoDiff
     compodiff_model, clip_model, clip_preprocess, clip_tokenizer = compodiff.build_model()
-    compodiff_model, clip_model = compodiff_model.to(
+    compodiff_model, clip_model = compodiff_model.to(device), clip_model.to(device)
+
+    if device != 'cpu':
+        clip_model = clip_model.half()
 
     ## load third-party models
     model_name = 'Intel/dpt-large'
     depth_preprocess = DPTFeatureExtractor.from_pretrained(model_name)
-    depth_predictor = DPTForDepthEstimation.from_pretrained(model_name, torch_dtype=
-    depth_predictor = depth_predictor.to(
+    depth_predictor = DPTForDepthEstimation.from_pretrained(model_name, torch_dtype=torch_dtype)
+    depth_predictor = depth_predictor.to(device)
 
     if not os.path.exists('./third_party/remover_fast.pth'):
         model_file_url = hf_hub_url(repo_id='Geonmo/remover_fast', filename='remover_fast.pth')
         cached_download(model_file_url, cache_dir='./third_party', force_filename='remover_fast.pth')
-    remover = Remover(fast=True, jit=False, device=
+    remover = Remover(fast=True, jit=False, device=device, ckpt='./third_party/remover_fast.pth')
 
-    sr_model = CustomRealESRGAN(
+    sr_model = CustomRealESRGAN(device, scale=2)
     sr_model.load_weights('./third_party/RealESRGAN_x2.pth', download=True)
 
     dataset = datasets.load_dataset("FredZhang7/stable-diffusion-prompts-2.47M")
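
The pattern introduced in this hunk — pick the device once, derive the dtype from it, then thread both through every from_pretrained call and .to move — is what lets the Space fall back to CPU. A minimal sketch of the same idea against a stock diffusers pipeline; StableDiffusionInstructPix2PixPipeline stands in here for the Space's custom GraphitPipeline:

import torch
from diffusers import StableDiffusionInstructPix2PixPipeline

# Choose hardware once; half precision is only used on CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
torch_dtype = torch.float16 if 'cuda' in device else torch.float32

pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    'timbrooks/instruct-pix2pix',
    torch_dtype=torch_dtype,
    safety_checker=None,  # the commit also drops the safety checker when building the pipeline
)
pipe = pipe.to(device)

Loading directly in float16 on GPU roughly halves weight memory; on CPU the models stay in float32 because float16 inference on CPU is slow and not supported by every op.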
@@ -361,28 +366,31 @@ def build_models(args):
         'remover': remover,
         'sr_model': sr_model,
         'prompt_candidates': prompts,
+        'device': device,
+        'torch_dtype': torch_dtype,
     }
     return model_dict
 
 
 def predict_compodiff(image, text_input, negative_text, cfg_image_scale, cfg_text_scale, mask, random_seed):
+    device = model_dict['device']
     text_token_dict = model_dict['clip_tokenizer'](text=text_input, return_tensors='pt', padding='max_length', truncation=True)
-    text_tokens, text_attention_mask = text_token_dict['input_ids'].to(
+    text_tokens, text_attention_mask = text_token_dict['input_ids'].to(device), text_token_dict['attention_mask'].to(device)
 
     negative_text_token_dict = model_dict['clip_tokenizer'](text=negative_text, return_tensors='pt', padding='max_length', truncation=True)
-    negative_text_tokens, negative_text_attention_mask = negative_text_token_dict['input_ids'].to(
+    negative_text_tokens, negative_text_attention_mask = negative_text_token_dict['input_ids'].to(device), text_token_dict['attention_mask'].to(device)
 
     with torch.no_grad():
         if image is None:
-            image_cond = torch.zeros([1,1,768]).to(
-            mask = torch.tensor(np.zeros([64, 64], dtype='float32')).to(
+            image_cond = torch.zeros([1,1,768]).to(device)
+            mask = torch.tensor(np.zeros([64, 64], dtype='float32')).to(device).unsqueeze(0)
         else:
             image_source = image.resize((512, 512))
-            image_source = model_dict['clip_preprocess'](image_source, return_tensors='pt')['pixel_values'].to(
+            image_source = model_dict['clip_preprocess'](image_source, return_tensors='pt')['pixel_values'].to(device)
             mask = mask.resize((512, 512))
             mask = model_dict['clip_preprocess'](mask, do_normalize=False, return_tensors='pt')['pixel_values']
            mask = mask[:,:1,:,:]
-            mask = (mask > 0.5).float().to(
+            mask = (mask > 0.5).float().to(device)
             image_source = image_source * (1 - mask)
             image_cond = model_dict['clip_model'].encode_images(image_source)
             mask = transforms.Resize([64, 64])(mask)[:,0,:,:]
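
predict_compodiff now reads the device out of model_dict and moves the tokenizer outputs onto it explicitly. A small sketch of that step with a plain transformers tokenizer; the app's clip_tokenizer is assumed to behave the same way and return input_ids plus attention_mask:

import torch
from transformers import CLIPTokenizerFast

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
tokenizer = CLIPTokenizerFast.from_pretrained('openai/clip-vit-large-patch14')

token_dict = tokenizer(text='a photo of a corgi', return_tensors='pt',
                       padding='max_length', truncation=True)

# Move both tensors to the chosen device, as the commit does for the text and the negative text.
text_tokens = token_dict['input_ids'].to(device)
text_attention_mask = token_dict['attention_mask'].to(device)

transformers' BatchEncoding also exposes a .to(device) helper that moves every tensor in the dict at once, which would condense these two lines into one.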
@@ -396,7 +404,9 @@ def predict_compodiff(image, text_input, negative_text, cfg_image_scale, cfg_text_scale, mask, random_seed):
 
 
 def generate_depth_map(image, height, width):
-
+    device = model_dict['device']
+    torch_dtype = model_dict['torch_dtype']
+    depth_inputs = {k: v.to(device, dtype=torch_dtype) for k, v in model_dict['depth_preprocess'](images=image, return_tensors='pt').items()}
     depth_map = model_dict['depth_predictor'](**depth_inputs).predicted_depth.unsqueeze(1)
     depth_min = torch.amin(depth_map, dim=[1,2,3], keepdim=True)
     depth_max = torch.amax(depth_map, dim=[1,2,3], keepdim=True)
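
generate_depth_map now casts every tensor coming out of the DPT feature extractor to the stored device and dtype before running the depth predictor. A standalone sketch of that dict-comprehension cast; the dummy image and direct model loading are only there to make the snippet self-contained:

import torch
from PIL import Image
from transformers import DPTFeatureExtractor, DPTForDepthEstimation

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
torch_dtype = torch.float16 if 'cuda' in device else torch.float32

feature_extractor = DPTFeatureExtractor.from_pretrained('Intel/dpt-large')
depth_predictor = DPTForDepthEstimation.from_pretrained('Intel/dpt-large', torch_dtype=torch_dtype).to(device)

image = Image.new('RGB', (512, 512))  # placeholder input image

# Cast every preprocessed tensor to the model's device and dtype in one pass.
depth_inputs = {k: v.to(device, dtype=torch_dtype)
                for k, v in feature_extractor(images=image, return_tensors='pt').items()}

with torch.no_grad():
    depth_map = depth_predictor(**depth_inputs).predicted_depth.unsqueeze(1)

The cast has to cover dtype as well as device: a float16 model on GPU will reject float32 pixel values.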
@@ -421,6 +431,9 @@ def generate_color(image, compactness=30, n_segments=100, thresh=35, blur_kernel
 
 @torch.no_grad()
 def generate(image_source, image_reference, text_input, negative_prompt, steps, random_seed, cfg_image_scale, cfg_text_scale, cfg_image_space_scale, cfg_image_reference_mix_weight, cfg_image_source_mix_weight, mask_scale, use_edge, t2i_height, t2i_width, do_sr, mode):
+    device = model_dict['device']
+    torch_dtype = model_dict['torch_dtype']
+
     text_input = text_input.lower()
     if negative_prompt == '':
         print('running without a negative prompt')
@@ -513,10 +526,10 @@ def generate(image_source, image_reference, text_input, negative_prompt, steps,
     # do reference first
     if image_reference is not None:
         image_cond_reference = ImageOps.exif_transpose(image_reference)
-        image_cond_reference = model_dict['clip_preprocess'](image_cond_reference, return_tensors='pt')['pixel_values'].to(
+        image_cond_reference = model_dict['clip_preprocess'](image_cond_reference, return_tensors='pt')['pixel_values'].to(device)
         image_cond_reference = model_dict['clip_model'].encode_images(image_cond_reference)
     else:
-        image_cond_reference = torch.zeros([1, 1, 768]).to(
+        image_cond_reference = torch.zeros([1, 1, 768]).to(torch_dtype).to(device)
 
     # do source or knn
     image_cond_source = None
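
When no reference image is given, the app substitutes an all-zeros embedding and later uses its sum to detect that case; the commit's change is to put that placeholder on the same dtype and device as the real embeddings. A tiny sketch of just that placeholder, with the device/dtype values assumed to come from model_dict:

import torch

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
torch_dtype = torch.float16 if 'cuda' in device else torch.float32

# Zero embedding used as the "no reference image" signal; the [1, 1, 768] shape matches
# the 768-dimensional CLIP-style image embeddings used elsewhere in the app.
image_cond_reference = torch.zeros([1, 1, 768]).to(torch_dtype).to(device)

# Downstream, torch.sum(image_cond_reference).item() != 0.0 distinguishes a real
# reference embedding from this placeholder.
print(image_cond_reference.dtype, image_cond_reference.device)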
@@ -530,14 +543,14 @@ def generate(image_source, image_reference, text_input, negative_prompt, steps,
             image_cond, image_cond_source = predict_compodiff(None, text_input, negative_prompt, cfg_image_scale, cfg_text_scale, mask=mask_pil, random_seed=random_seed)
         else:
             image_cond, image_cond_source = predict_compodiff(image_source, text_input, negative_prompt, cfg_image_scale, cfg_text_scale, mask=mask_pil, random_seed=random_seed)
-        image_cond = image_cond.to(
-        image_cond_source = image_cond_source.to(
+        image_cond = image_cond.to(torch_dtype).to(device)
+        image_cond_source = image_cond_source.to(torch_dtype).to(device)
     else:
-        image_cond = torch.zeros([1, 1, 768]).to(
+        image_cond = torch.zeros([1, 1, 768]).to(torch_dtype).to(device)
 
     if image_cond_source is None and mode != 't2i':
         image_cond_source = image_source.resize((512, 512))
-        image_cond_source = model_dict['clip_preprocess'](image_cond_source, return_tensors='pt')['pixel_values'].to(
+        image_cond_source = model_dict['clip_preprocess'](image_cond_source, return_tensors='pt')['pixel_values'].to(device)
         image_cond_source = model_dict['clip_model'].encode_images(image_cond_source)
 
     if cfg_image_reference_mix_weight > 0.0 and torch.sum(image_cond_reference).item() != 0.0:
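
The embeddings coming back from predict_compodiff are cast with two chained .to calls, first to the pipeline dtype and then to the device. As an aside (not part of the commit), torch.Tensor.to accepts both in a single call, which gives the same result and can avoid an intermediate copy:

import torch

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
torch_dtype = torch.float16 if 'cuda' in device else torch.float32

image_cond = torch.randn(1, 1, 768)  # stand-in for a CompoDiff output on CPU

a = image_cond.to(torch_dtype).to(device)     # chained casts, as in the commit
b = image_cond.to(device, dtype=torch_dtype)  # single combined call

assert a.dtype == b.dtype and a.device == b.device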
@@ -551,7 +564,7 @@ def generate(image_source, image_reference, text_input, negative_prompt, steps,
 
     if negative_prompt != '':
         negative_image_cond, _ = predict_compodiff(None, negative_prompt, '', cfg_image_scale, cfg_text_scale, mask=mask_pil, random_seed=random_seed)
-        negative_image_cond = negative_image_cond.to(
+        negative_image_cond = negative_image_cond.to(torch_dtype).to(device)
     else:
         negative_image_cond = torch.zeros_like(image_cond)
 