eliphatfs committed on
Commit
a22ab8b
1 Parent(s): 1059e8f

Use gradio UI.

Browse files
Files changed (3) hide show
  1. Dockerfile +1 -1
  2. gradio_app.py +204 -0
  3. requirements.txt +2 -0
Dockerfile CHANGED
@@ -37,4 +37,4 @@ COPY --chown=user . $HOME/app
37
 
38
  RUN python3 download_checkpoints.py
39
 
40
- CMD ["streamlit", "run", "--server.enableXsrfProtection", "false", "app.py"]
 
37
 
38
  RUN python3 download_checkpoints.py
39
 
40
+ CMD ["python", "gradio_app.py"]
gradio_app.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import fire
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from functools import partial
7
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
8
+
9
+ import cv2
10
+ import time
11
+ import numpy as np
12
+ from rembg import remove
13
+ from segment_anything import sam_model_registry, SamPredictor
14
+
15
# Page title: used both as the Blocks window title and as the top-level heading.
_TITLE = '''Zero123++: a Single Image to Consistent Multi-view Diffusion Base Model'''
# HTML badge strip (arXiv link and GitHub stars) rendered below the heading.
_DESCRIPTION = '''
<div>
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2310.15110"><img src="https://img.shields.io/badge/2310.15110-f9f7f7?logo="></a>
<a style="display:inline-block; margin-left: .5em" href='https://github.com/SUDO-AI-3D/zero123plus'><img src='https://img.shields.io/github/stars/SUDO-AI-3D/zero123plus?style=social' /></a>
</div>
'''
# Index of the CUDA device; interpolated into f"cuda:{_GPU_ID}" when the
# SAM model and the diffusion pipeline are moved to the GPU.
_GPU_ID = 0


# Compatibility shim: when the installed Pillow has no `Image.Resampling`
# attribute, alias it to the `Image` module itself so that
# `Image.Resampling.LANCZOS` resolves to `Image.LANCZOS`.
if not hasattr(Image, 'Resampling'):
    Image.Resampling = Image
27
+
28
+
29
def sam_init():
    """Create a Segment Anything predictor on the configured CUDA device.

    The ``vit_h`` checkpoint is read from ``tmp/sam_vit_h_4b8939.pth`` in the
    directory that contains this file.

    Returns:
        A ``SamPredictor`` wrapping the loaded model.
    """
    checkpoint_path = os.path.join(
        os.path.dirname(__file__), "tmp", "sam_vit_h_4b8939.pth"
    )
    build_model = sam_model_registry["vit_h"]
    model = build_model(checkpoint=checkpoint_path)
    model = model.to(device=f"cuda:{_GPU_ID}")
    return SamPredictor(model)
36
+
37
def sam_segment(predictor, input_image, *bbox_coords):
    """Segment ``input_image`` with SAM using a box prompt.

    Args:
        predictor: Object exposing ``set_image`` and ``predict`` (the
            ``SamPredictor`` built by ``sam_init`` in this file).
        input_image: Image convertible with ``np.asarray``; its array is
            copied into the first three channels of the result.
        *bbox_coords: Box prompt values, forwarded to ``predict`` as one
            array (``preprocess`` passes x_min, y_min, x_max, y_max).

    Returns:
        An RGBA ``PIL.Image`` whose colour channels are the input pixels and
        whose alpha channel is 255 where the selected mask is set, else 0.
    """
    box = np.array(bbox_coords)
    pixels = np.asarray(input_image)

    tic = time.time()
    predictor.set_image(pixels)
    masks, _scores, _logits = predictor.predict(box=box, multimask_output=True)
    print(f"SAM Time: {time.time() - tic:.3f}s")

    rgba = np.zeros((pixels.shape[0], pixels.shape[1], 4), dtype=np.uint8)
    rgba[:, :, :3] = pixels
    # The last of the returned masks is the one used as alpha.
    rgba[:, :, 3] = masks[-1].astype(np.uint8) * 255
    torch.cuda.empty_cache()
    return Image.fromarray(rgba, mode='RGBA')
56
+
57
def expand2square(pil_img, background_color):
    """Pad an image to a square, centring it along the shorter axis.

    Args:
        pil_img: Source image (needs ``size``, ``mode``; pasted onto the canvas).
        background_color: Fill colour for the new canvas.

    Returns:
        ``pil_img`` itself when it is already square; otherwise a new image
        whose side equals the longer dimension, with the source pasted in
        the middle of the shorter axis.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    if width > height:
        offset = (0, (width - height) // 2)
    else:
        offset = ((height - width) // 2, 0)

    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, offset)
    return canvas
69
+
70
def preprocess(predictor, input_image, chk_group=None, segment=True, rescale=False):
    """Prepare an image for the multi-view pipeline.

    Steps, in order:
      1. Shrink a copy of the image so neither side exceeds 1024 px.
      2. If segmenting: run ``rembg`` to get a rough alpha, take the bounding
         box of its non-zero pixels and use it as the box prompt for
         ``sam_segment``.
      3. If rescaling: crop to the alpha bounding box, centre it on a
         transparent square so the longer side fills 75% of it, resize, and
         composite onto a white background. Otherwise pad to a square with
         ``expand2square``.

    Args:
        predictor: SAM predictor passed through to ``sam_segment``.
        input_image: ``PIL.Image`` to process; it is not modified.
        chk_group: Optional collection of option labels. When not ``None`` it
            overrides ``segment`` (``"Background Removal"`` present) and
            ``rescale`` (``"Rescale"`` present).
        segment: Whether to remove the background.
        rescale: Whether to rescale and recentre the foreground.

    Returns:
        A tuple ``(processed_image, preview)`` where ``preview`` is
        ``processed_image`` resized to 320x320.
    """
    RES = 1024
    # thumbnail() resizes in place; copy first so the caller's image is untouched.
    input_image = input_image.copy()
    input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
    if chk_group is not None:
        segment = "Background Removal" in chk_group
        rescale = "Rescale" in chk_group
    if segment:
        image_nobg = remove(input_image.convert('RGBA'), alpha_matting=True)
        alpha = np.asarray(image_nobg)[:, :, -1]
        cols = np.nonzero(alpha.sum(axis=0))[0]
        rows = np.nonzero(alpha.sum(axis=1))[0]
        if cols.size and rows.size:
            x_min, x_max = int(cols.min()), int(cols.max())
            y_min, y_max = int(rows.min()), int(rows.max())
        else:
            # rembg kept nothing: min()/max() would raise on the empty index
            # arrays, so prompt SAM with the whole frame instead.
            x_min, y_min = 0, 0
            x_max, y_max = input_image.width - 1, input_image.height - 1
        input_image = sam_segment(predictor, input_image.convert('RGB'), x_min, y_min, x_max, y_max)
    # Rescale and recenter
    if rescale:
        # The steps below read the alpha band and copy 4-channel pixels, so
        # make sure the image really is RGBA (no-op when it already is).
        input_image = input_image.convert('RGBA')
        image_arr = np.array(input_image)
        in_h, in_w = image_arr.shape[:2]
        out_res = min(RES, max(in_w, in_h))
        _, mask = cv2.threshold(np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY)
        x, y, w, h = cv2.boundingRect(mask)
        if w == 0 or h == 0:
            # Fully transparent image: keep the whole frame rather than
            # building a zero-sized canvas that cannot be resized.
            x, y, w, h = 0, 0, in_w, in_h
        max_size = max(w, h)
        ratio = 0.75  # fraction of the square's side taken by the longer object side
        side_len = int(max_size / ratio)
        padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
        center = side_len // 2
        top = center - h // 2
        left = center - w // 2
        padded_image[top:top + h, left:left + w] = image_arr[y:y + h, x:x + w]
        rgba = Image.fromarray(padded_image).resize((out_res, out_res), Image.Resampling.LANCZOS)

        # Alpha-composite onto white: rgb * a + (1 - a), with values in [0, 1].
        rgba_arr = np.array(rgba) / 255.0
        rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
        input_image = Image.fromarray((rgb * 255).astype(np.uint8))
    else:
        input_image = expand2square(input_image, (127, 127, 127, 0))
    return input_image, input_image.resize((320, 320), Image.Resampling.LANCZOS)
108
+
109
def gen_multiview(pipeline, predictor, input_image, scale_slider, steps_slider, seed, output_processing=False):
    """Run the diffusion pipeline and split its output grid into tiles.

    The pipeline's first output image is cut into square tiles whose side is
    half the image width, scanned row by row (left to right, top to bottom).

    Args:
        pipeline: Callable pipeline with a ``device`` attribute; called with
            the image plus ``num_inference_steps``, ``guidance_scale`` and
            ``generator`` and expected to return an object with ``images``.
        predictor: SAM predictor, used only when background removal of the
            outputs is requested.
        input_image: Conditioning image passed to the pipeline.
        scale_slider: Classifier-free guidance scale.
        steps_slider: Number of inference steps; coerced to ``int`` because
            UI sliders may deliver a float.
        seed: Random seed; coerced to ``int``.
        output_processing: Collection of post-processing labels, or a falsy
            value for none. ``"Background Removal"`` runs every tile through
            ``preprocess`` with segmentation enabled.

    Returns:
        A list of tile images.
    """
    seed = int(seed)
    torch.manual_seed(seed)
    generator = torch.Generator(pipeline.device).manual_seed(seed)
    image = pipeline(
        input_image,
        num_inference_steps=int(steps_slider),
        guidance_scale=scale_slider,
        generator=generator,
    ).images[0]
    side_len = image.width // 2
    subimages = [
        image.crop((x, y, x + side_len, y + side_len))
        for y in range(0, image.height, side_len)
        for x in range(0, image.width, side_len)
    ]
    # Guard against the falsy default (False): `in` on a bool raises TypeError.
    if output_processing and "Background Removal" in output_processing:
        return [
            preprocess(predictor, sub_image.convert('RGB'), segment=True, rescale=False)[0]
            for sub_image in subimages
        ]
    return subimages
125
+
126
+
127
def run_demo():
    """Load the models, build the Gradio interface and launch the demo.

    Loads the Zero123++ diffusion pipeline and the SAM predictor onto the
    configured CUDA device, lays out the UI, wires the "Generate" button to
    ``preprocess`` followed by ``gen_multiview``, and starts the server.
    """
    # Load the pipeline
    pipeline = DiffusionPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.1", custom_pipeline="sudo-ai/zero123plus-pipeline",
        torch_dtype=torch.float16
    )
    # Feel free to tune the scheduler
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config, timestep_spacing='trailing'
    )
    pipeline.to(f'cuda:{_GPU_ID}')

    predictor = sam_init()

    custom_theme = gr.themes.Soft(primary_hue="blue").set(
        button_secondary_background_fill="*neutral_100",
        button_secondary_background_fill_hover="*neutral_200")
    # Targets the component created with elem_id="disp_image" below.
    custom_css = '''#disp_image {
        text-align: center; /* Horizontally center the content */
    }'''

    # NOTE(review): the nesting of the layout blocks below was reconstructed
    # from a flattened listing -- confirm against the rendered page.
    with gr.Blocks(title=_TITLE, theme=custom_theme, css=custom_css) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
        gr.Markdown(_DESCRIPTION)
        with gr.Row(variant='panel'):
            # Left column: input image, examples, options and the run button.
            with gr.Column(scale=1):
                # NOTE(review): `tool=` looks like a Gradio 3.x `gr.Image`
                # argument while requirements allow gradio>=3.50 -- confirm
                # the installed major version accepts it.
                input_image = gr.Image(type='pil', image_mode='RGBA', height=320, label='Input image', tool=None)

                # Every file in resources/examples (next to this module) is
                # offered as a clickable example; order is whatever
                # os.listdir returns.
                example_folder = os.path.join(os.path.dirname(__file__), "./resources/examples")
                example_fns = [os.path.join(example_folder, example) for example in os.listdir(example_folder)]
                gr.Examples(
                    examples=example_fns,
                    inputs=[input_image],
                    outputs=[input_image],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=10
                )
                with gr.Accordion('Advanced options', open=False):
                    with gr.Row():
                        with gr.Column():
                            # Labels here are the strings `preprocess` looks for in chk_group.
                            input_processing = gr.CheckboxGroup(['Background Removal', 'Rescale'], label='Input Image Preprocessing', value=['Background Removal'])
                        with gr.Column():
                            # Label here is the string `gen_multiview` looks for.
                            output_processing = gr.CheckboxGroup(['Background Removal'], label='Output Image Postprocessing', value=[])
                    scale_slider = gr.Slider(1, 10, value=4, step=1,
                                             label='Classifier Free Guidance Scale')
                    steps_slider = gr.Slider(15, 100, value=75, step=1,
                                             label='Number of Diffusion Inference Steps',
                                             info="For general real or synthetic objects, around 28 is enough. For objects with delicate details such as faces (either realistic or illustration), you may need 75 or more steps.")
                    seed = gr.Number(42, label='Seed')
                run_btn = gr.Button('Generate', variant='primary', interactive=True)
            # Right column: processed input preview and the generated views.
            with gr.Column(scale=1):
                # Receives the 320x320 preview returned by `preprocess`.
                processed_image = gr.Image(type='pil', label="Processed Image", interactive=False, height=320, tool=None, image_mode='RGBA', elem_id="disp_image")
                # Hidden component holding the first (non-preview) image
                # returned by `preprocess`; it is the input to `gen_multiview`.
                processed_image_highres = gr.Image(type='pil', image_mode='RGBA', visible=False, tool=None)
                with gr.Row():
                    view_1 = gr.Image(interactive=False, height=240, show_label=False)
                    view_2 = gr.Image(interactive=False, height=240, show_label=False)
                    view_3 = gr.Image(interactive=False, height=240, show_label=False)
                with gr.Row():
                    view_4 = gr.Image(interactive=False, height=240, show_label=False)
                    view_5 = gr.Image(interactive=False, height=240, show_label=False)
                    view_6 = gr.Image(interactive=False, height=240, show_label=False)

        # Two-stage chain: preprocess first; generation runs only if the
        # preprocessing step succeeded. `partial` binds the loaded models so
        # Gradio supplies only the UI values.
        # NOTE(review): six outputs are declared, so `gen_multiview` is
        # presumably expected to return six tiles -- confirm against the
        # pipeline's output grid size.
        run_btn.click(fn=partial(preprocess, predictor),
                      inputs=[input_image, input_processing],
                      outputs=[processed_image_highres, processed_image], queue=True
                      ).success(fn=partial(gen_multiview, pipeline, predictor),
                                inputs=[processed_image_highres, scale_slider, steps_slider, seed, output_processing],
                                outputs=[view_1, view_2, view_3, view_4, view_5, view_6])

    demo.queue().launch(share=True, max_threads=80)
201
+
202
+
203
# Script entry point: hand `run_demo` to python-fire, which invokes it from
# the command line.
if __name__ == '__main__':
    fire.Fire(run_demo)
requirements.txt CHANGED
@@ -9,3 +9,5 @@ streamlit==1.22.0
9
  altair<5
10
  huggingface_hub
11
  git+https://github.com/facebookresearch/segment-anything.git
 
 
 
9
  altair<5
10
  huggingface_hub
11
  git+https://github.com/facebookresearch/segment-anything.git
12
+ gradio>=3.50
13
+ fire