hysts HF staff commited on
Commit
4c57c14
1 Parent(s): 206e60d
Files changed (3) hide show
  1. app_normal.py +3 -0
  2. app_seg.py +3 -0
  3. model.py +28 -16
app_normal.py CHANGED
@@ -13,6 +13,8 @@ def create_demo(process, max_images=12, default_num_images=3):
13
  prompt = gr.Textbox(label='Prompt')
14
  run_button = gr.Button(label='Run')
15
  with gr.Accordion('Advanced options', open=False):
 
 
16
  num_samples = gr.Slider(label='Images',
17
  minimum=1,
18
  maximum=max_images,
@@ -74,6 +76,7 @@ def create_demo(process, max_images=12, default_num_images=3):
74
  guidance_scale,
75
  seed,
76
  bg_threshold,
 
77
  ]
78
  prompt.submit(fn=process, inputs=inputs, outputs=result)
79
  run_button.click(fn=process,
 
13
  prompt = gr.Textbox(label='Prompt')
14
  run_button = gr.Button(label='Run')
15
  with gr.Accordion('Advanced options', open=False):
16
+ is_normal_image = gr.Checkbox(label='Is normal image',
17
+ value=False)
18
  num_samples = gr.Slider(label='Images',
19
  minimum=1,
20
  maximum=max_images,
 
76
  guidance_scale,
77
  seed,
78
  bg_threshold,
79
+ is_normal_image,
80
  ]
81
  prompt.submit(fn=process, inputs=inputs, outputs=result)
82
  run_button.click(fn=process,
app_seg.py CHANGED
@@ -13,6 +13,8 @@ def create_demo(process, max_images=12, default_num_images=3):
13
  prompt = gr.Textbox(label='Prompt')
14
  run_button = gr.Button(label='Run')
15
  with gr.Accordion('Advanced options', open=False):
 
 
16
  num_samples = gr.Slider(label='Images',
17
  minimum=1,
18
  maximum=max_images,
@@ -68,6 +70,7 @@ def create_demo(process, max_images=12, default_num_images=3):
68
  num_steps,
69
  guidance_scale,
70
  seed,
 
71
  ]
72
  prompt.submit(fn=process, inputs=inputs, outputs=result)
73
  run_button.click(fn=process,
 
13
  prompt = gr.Textbox(label='Prompt')
14
  run_button = gr.Button(label='Run')
15
  with gr.Accordion('Advanced options', open=False):
16
+ is_segmentation_map = gr.Checkbox(
17
+ label='Is segmentation map', value=False)
18
  num_samples = gr.Slider(label='Images',
19
  minimum=1,
20
  maximum=max_images,
 
70
  num_steps,
71
  guidance_scale,
72
  seed,
73
+ is_segmentation_map,
74
  ]
75
  prompt.submit(fn=process, inputs=inputs, outputs=result)
76
  run_button.click(fn=process,
model.py CHANGED
@@ -494,14 +494,18 @@ class Model:
494
  input_image: np.ndarray,
495
  image_resolution: int,
496
  detect_resolution: int,
 
497
  ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
498
  input_image = HWC3(input_image)
499
- control_image = apply_uniformer(
500
- resize_image(input_image, detect_resolution))
501
- image = resize_image(input_image, image_resolution)
502
- H, W = image.shape[:2]
503
- control_image = cv2.resize(control_image, (W, H),
504
- interpolation=cv2.INTER_NEAREST)
 
 
 
505
  return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
506
  control_image)
507
 
@@ -518,11 +522,13 @@ class Model:
518
  num_steps: int,
519
  guidance_scale: float,
520
  seed: int,
 
521
  ) -> list[PIL.Image.Image]:
522
  control_image, vis_control_image = self.preprocess_seg(
523
  input_image=input_image,
524
  image_resolution=image_resolution,
525
  detect_resolution=detect_resolution,
 
526
  )
527
  return self.process(
528
  task_name='seg',
@@ -597,17 +603,21 @@ class Model:
597
  input_image: np.ndarray,
598
  image_resolution: int,
599
  detect_resolution: int,
600
- bg_threshold,
 
601
  ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
602
  input_image = HWC3(input_image)
603
- _, control_image = apply_midas(resize_image(input_image,
604
- detect_resolution),
605
- bg_th=bg_threshold)
606
- control_image = HWC3(control_image)
607
- image = resize_image(input_image, image_resolution)
608
- H, W = image.shape[:2]
609
- control_image = cv2.resize(control_image, (W, H),
610
- interpolation=cv2.INTER_LINEAR)
 
 
 
611
  return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
612
  control_image)
613
 
@@ -624,13 +634,15 @@ class Model:
624
  num_steps: int,
625
  guidance_scale: float,
626
  seed: int,
627
- bg_threshold,
 
628
  ) -> list[PIL.Image.Image]:
629
  control_image, vis_control_image = self.preprocess_normal(
630
  input_image=input_image,
631
  image_resolution=image_resolution,
632
  detect_resolution=detect_resolution,
633
  bg_threshold=bg_threshold,
 
634
  )
635
  return self.process(
636
  task_name='normal',
 
494
  input_image: np.ndarray,
495
  image_resolution: int,
496
  detect_resolution: int,
497
+ is_segmentation_map: bool,
498
  ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
499
  input_image = HWC3(input_image)
500
+ if not is_segmentation_map:
501
+ control_image = apply_uniformer(
502
+ resize_image(input_image, detect_resolution))
503
+ image = resize_image(input_image, image_resolution)
504
+ H, W = image.shape[:2]
505
+ control_image = cv2.resize(control_image, (W, H),
506
+ interpolation=cv2.INTER_NEAREST)
507
+ else:
508
+ control_image = input_image
509
  return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
510
  control_image)
511
 
 
522
  num_steps: int,
523
  guidance_scale: float,
524
  seed: int,
525
+ is_segmentation_map: bool,
526
  ) -> list[PIL.Image.Image]:
527
  control_image, vis_control_image = self.preprocess_seg(
528
  input_image=input_image,
529
  image_resolution=image_resolution,
530
  detect_resolution=detect_resolution,
531
+ is_segmentation_map=is_segmentation_map,
532
  )
533
  return self.process(
534
  task_name='seg',
 
603
  input_image: np.ndarray,
604
  image_resolution: int,
605
  detect_resolution: int,
606
+ bg_threshold: float,
607
+ is_normal_image: bool,
608
  ) -> tuple[PIL.Image.Image, PIL.Image.Image]:
609
  input_image = HWC3(input_image)
610
+ if not is_normal_image:
611
+ _, control_image = apply_midas(resize_image(
612
+ input_image, detect_resolution),
613
+ bg_th=bg_threshold)
614
+ control_image = HWC3(control_image)
615
+ image = resize_image(input_image, image_resolution)
616
+ H, W = image.shape[:2]
617
+ control_image = cv2.resize(control_image, (W, H),
618
+ interpolation=cv2.INTER_LINEAR)
619
+ else:
620
+ control_image = input_image
621
  return PIL.Image.fromarray(control_image), PIL.Image.fromarray(
622
  control_image)
623
 
 
634
  num_steps: int,
635
  guidance_scale: float,
636
  seed: int,
637
+ bg_threshold: float,
638
+ is_normal_image: bool,
639
  ) -> list[PIL.Image.Image]:
640
  control_image, vis_control_image = self.preprocess_normal(
641
  input_image=input_image,
642
  image_resolution=image_resolution,
643
  detect_resolution=detect_resolution,
644
  bg_threshold=bg_threshold,
645
+ is_normal_image=is_normal_image,
646
  )
647
  return self.process(
648
  task_name='normal',