Spaces:
Build error
Build error
lemonaddie
commited on
Commit
•
fbf7415
1
Parent(s):
47e2130
Upload app_recon.py
Browse files- app_recon.py +295 -0
app_recon.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
import sys
|
5 |
+
import git
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import numpy as np
|
9 |
+
import torch as torch
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
from gradio_imageslider import ImageSlider
|
13 |
+
from bilateral_normal_integration.bilateral_normal_integration_cupy import bilateral_normal_integration_function
|
14 |
+
|
15 |
+
import spaces
|
16 |
+
|
17 |
+
import fire
|
18 |
+
|
19 |
+
import argparse
|
20 |
+
import os
|
21 |
+
import logging
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
import torch
|
25 |
+
from PIL import Image
|
26 |
+
from tqdm.auto import tqdm
|
27 |
+
import glob
|
28 |
+
import json
|
29 |
+
import cv2
|
30 |
+
|
31 |
+
from rembg import remove
|
32 |
+
from segment_anythi ng import sam_model_registry, SamPredictor
|
33 |
+
from datetime import datetime
|
34 |
+
import time
|
35 |
+
|
36 |
+
|
37 |
+
import sys
|
38 |
+
sys.path.append("../")
|
39 |
+
from models.geowizard_pipeline import DepthNormalEstimationPipeline
|
40 |
+
from utils.seed_all import seed_all
|
41 |
+
import matplotlib.pyplot as plt
|
42 |
+
from utils.de_normalized import align_scale_shift
|
43 |
+
from utils.depth2normal import *
|
44 |
+
|
45 |
+
from diffusers import DiffusionPipeline, DDIMScheduler, AutoencoderKL
|
46 |
+
from models.unet_2d_condition import UNet2DConditionModel
|
47 |
+
|
48 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
49 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
50 |
+
import torchvision.transforms.functional as TF
|
51 |
+
from torchvision.transforms import InterpolationMode
|
52 |
+
|
53 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
54 |
+
|
55 |
+
stable_diffusion_repo_path = "stabilityai/stable-diffusion-2-1-unclip"
|
56 |
+
vae = AutoencoderKL.from_pretrained(stable_diffusion_repo_path, subfolder='vae')
|
57 |
+
scheduler = DDIMScheduler.from_pretrained(stable_diffusion_repo_path, subfolder='scheduler')
|
58 |
+
sd_image_variations_diffusers_path = 'lambdalabs/sd-image-variations-diffusers'
|
59 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(sd_image_variations_diffusers_path, subfolder="image_encoder")
|
60 |
+
feature_extractor = CLIPImageProcessor.from_pretrained(sd_image_variations_diffusers_path, subfolder="feature_extractor")
|
61 |
+
unet = UNet2DConditionModel.from_pretrained('.', subfolder="unet")
|
62 |
+
|
63 |
+
pipe = DepthNormalEstimationPipeline(vae=vae,
|
64 |
+
image_encoder=image_encoder,
|
65 |
+
feature_extractor=feature_extractor,
|
66 |
+
unet=unet,
|
67 |
+
scheduler=scheduler)
|
68 |
+
|
69 |
+
try:
|
70 |
+
import xformers
|
71 |
+
pipe.enable_xformers_memory_efficient_attention()
|
72 |
+
except:
|
73 |
+
pass # run without xformers
|
74 |
+
|
75 |
+
pipe = pipe.to(device)
|
76 |
+
|
77 |
+
def sam_init():
|
78 |
+
sam_checkpoint = os.path.join(os.path.dirname(__file__), "sam_pt", "sam_vit_l_0b3195.pth")
|
79 |
+
model_type = "vit_l"
|
80 |
+
|
81 |
+
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=f"cuda")
|
82 |
+
predictor = SamPredictor(sam)
|
83 |
+
return predictor
|
84 |
+
|
85 |
+
sam_predictor = sam_init()
|
86 |
+
|
87 |
+
def sam_segment(predictor, input_image, *bbox_coords):
|
88 |
+
bbox = np.array(bbox_coords)
|
89 |
+
image = np.asarray(input_image)
|
90 |
+
|
91 |
+
start_time = time.time()
|
92 |
+
predictor.set_image(image)
|
93 |
+
|
94 |
+
masks_bbox, scores_bbox, logits_bbox = predictor.predict(
|
95 |
+
box=bbox,
|
96 |
+
multimask_output=True
|
97 |
+
)
|
98 |
+
|
99 |
+
print(f"SAM Time: {time.time() - start_time:.3f}s")
|
100 |
+
out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
|
101 |
+
out_image[:, :, :3] = image
|
102 |
+
out_image_bbox = out_image.copy()
|
103 |
+
out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255
|
104 |
+
torch.cuda.empty_cache()
|
105 |
+
return Image.fromarray(out_image_bbox, mode='RGBA'), masks_bbox
|
106 |
+
|
107 |
+
@spaces.GPU
|
108 |
+
def depth_normal(img_path,
|
109 |
+
denoising_steps,
|
110 |
+
ensemble_size,
|
111 |
+
processing_res,
|
112 |
+
seed,
|
113 |
+
domain):
|
114 |
+
|
115 |
+
seed = int(seed)
|
116 |
+
if seed >= 0:
|
117 |
+
torch.manual_seed(seed)
|
118 |
+
|
119 |
+
img = Image.open(img_path)
|
120 |
+
|
121 |
+
pipe_out = pipe(
|
122 |
+
img,
|
123 |
+
denoising_steps=denoising_steps,
|
124 |
+
ensemble_size=ensemble_size,
|
125 |
+
processing_res=processing_res,
|
126 |
+
batch_size=0,
|
127 |
+
domain=domain,
|
128 |
+
show_progress_bar=True,
|
129 |
+
)
|
130 |
+
|
131 |
+
depth_colored = pipe_out.depth_colored
|
132 |
+
normal_colored = pipe_out.normal_colored
|
133 |
+
|
134 |
+
depth_np = pipe_out.depth_np
|
135 |
+
normal_np = pipe_out.normal_np
|
136 |
+
|
137 |
+
path_output_dir = os.path.splitext(os.path.basename(img_path))[0] + datetime.now().strftime('%Y%m%d-%H%M%S')
|
138 |
+
os.makedirs(path_output_dir, exist_ok=True)
|
139 |
+
|
140 |
+
name_base = os.path.splitext(os.path.basename(img_path))[0]
|
141 |
+
depth_path = os.path.join(path_output_dir, f"{name_base}_depth.npy")
|
142 |
+
normal_path = os.path.join(path_output_dir, f"{name_base}_normal.npy")
|
143 |
+
|
144 |
+
np.save(normal_path, normal_np)
|
145 |
+
np.save(depth_path, depth_np)
|
146 |
+
|
147 |
+
return depth_colored, normal_colored, [depth_path, normal_path]
|
148 |
+
|
149 |
+
def reconstruction(image, files):
|
150 |
+
|
151 |
+
torch.cuda.empty_cache()
|
152 |
+
|
153 |
+
img = Image.open(image)
|
154 |
+
|
155 |
+
image_rem = img.convert('RGBA')
|
156 |
+
image_nobg = remove(image_rem, alpha_matting=True)
|
157 |
+
arr = np.asarray(image_nobg)[:,:,-1]
|
158 |
+
x_nonzero = np.nonzero(arr.sum(axis=0))
|
159 |
+
y_nonzero = np.nonzero(arr.sum(axis=1))
|
160 |
+
x_min = int(x_nonzero[0].min())
|
161 |
+
y_min = int(y_nonzero[0].min())
|
162 |
+
x_max = int(x_nonzero[0].max())
|
163 |
+
y_max = int(y_nonzero[0].max())
|
164 |
+
masked_image, mask = sam_segment(sam_predictor, img.convert('RGB'), x_min, y_min, x_max, y_max)
|
165 |
+
|
166 |
+
depth_np = np.load(files[0])
|
167 |
+
normal_np = np.load(files[1])
|
168 |
+
|
169 |
+
dir_name = os.path.dirname(os.path.realpath(files[0]))
|
170 |
+
mask_output_temp = mask[-1]
|
171 |
+
name_base = os.path.splitext(os.path.basename(files[0]))[0][:-6]
|
172 |
+
|
173 |
+
normal_np[:, :, 0] *= -1
|
174 |
+
_, surface, _, _, _ = bilateral_normal_integration_function(normal_np, mask_output_temp, k=2, K=None, max_iter=100, tol=1e-4, cg_max_iter=5000, cg_tol=1e-3)
|
175 |
+
ply_path = os.path.join(dir_name, f"{name_base}_mask.ply")
|
176 |
+
surface.save(ply_path, binary=False)
|
177 |
+
return ply_path
|
178 |
+
|
179 |
+
def run_demo():
|
180 |
+
|
181 |
+
|
182 |
+
custom_theme = gr.themes.Soft(primary_hue="blue").set(
|
183 |
+
button_secondary_background_fill="*neutral_100",
|
184 |
+
button_secondary_background_fill_hover="*neutral_200")
|
185 |
+
custom_css = '''#disp_image {
|
186 |
+
text-align: center; /* Horizontally center the content */
|
187 |
+
}'''
|
188 |
+
|
189 |
+
_TITLE = '''GeoWizard: Unleashing the Diffusion Priors for 3D Geometry Estimation from a Single Image'''
|
190 |
+
_DESCRIPTION = '''
|
191 |
+
<div>
|
192 |
+
Generate consistent depth and normal from single image. High quality and rich details. (PS: We find the demo running on ZeroGPU output slightly inferior results compared to A100 or 3060 with everything exactly the same.)
|
193 |
+
<a style="display:inline-block; margin-left: .5em" href='https://github.com/fuxiao0719/GeoWizard/'><img src='https://img.shields.io/github/stars/fuxiao0719/GeoWizard?style=social' /></a>
|
194 |
+
</div>
|
195 |
+
'''
|
196 |
+
_GPU_ID = 0
|
197 |
+
|
198 |
+
with gr.Blocks(title=_TITLE, theme=custom_theme, css=custom_css) as demo:
|
199 |
+
with gr.Row():
|
200 |
+
with gr.Column(scale=1):
|
201 |
+
gr.Markdown('# ' + _TITLE)
|
202 |
+
gr.Markdown(_DESCRIPTION)
|
203 |
+
with gr.Row(variant='panel'):
|
204 |
+
with gr.Column(scale=1):
|
205 |
+
input_image = gr.Image(type='filepath', height=320, label='Input image')
|
206 |
+
|
207 |
+
example_folder = os.path.join(os.path.dirname(__file__), "./files")
|
208 |
+
example_fns = [os.path.join(example_folder, example) for example in os.listdir(example_folder)]
|
209 |
+
gr.Examples(
|
210 |
+
examples=example_fns,
|
211 |
+
inputs=[input_image],
|
212 |
+
cache_examples=False,
|
213 |
+
label='Examples (click one of the images below to start)',
|
214 |
+
examples_per_page=30
|
215 |
+
)
|
216 |
+
with gr.Column(scale=1):
|
217 |
+
|
218 |
+
with gr.Accordion('Advanced options', open=True):
|
219 |
+
with gr.Column():
|
220 |
+
|
221 |
+
domain = gr.Radio(
|
222 |
+
[
|
223 |
+
("Outdoor", "outdoor"),
|
224 |
+
("Indoor", "indoor"),
|
225 |
+
("Object", "object"),
|
226 |
+
],
|
227 |
+
label="Data Type (Must Select One matches your image)",
|
228 |
+
value="indoor",
|
229 |
+
)
|
230 |
+
denoising_steps = gr.Slider(
|
231 |
+
label="Number of denoising steps (More steps, better quality)",
|
232 |
+
minimum=1,
|
233 |
+
maximum=50,
|
234 |
+
step=1,
|
235 |
+
value=10,
|
236 |
+
)
|
237 |
+
ensemble_size = gr.Slider(
|
238 |
+
label="Ensemble size (More steps, higher accuracy)",
|
239 |
+
minimum=1,
|
240 |
+
maximum=15,
|
241 |
+
step=1,
|
242 |
+
value=3,
|
243 |
+
)
|
244 |
+
seed = gr.Number(0, label='Random Seed. Negative values for not specifying')
|
245 |
+
|
246 |
+
processing_res = gr.Radio(
|
247 |
+
[
|
248 |
+
("Native", 0),
|
249 |
+
("Recommended", 768),
|
250 |
+
],
|
251 |
+
label="Processing resolution",
|
252 |
+
value=768,
|
253 |
+
)
|
254 |
+
|
255 |
+
|
256 |
+
run_btn = gr.Button('Generate', variant='primary', interactive=True)
|
257 |
+
with gr.Row():
|
258 |
+
with gr.Column():
|
259 |
+
depth = gr.Image(interactive=False, show_label=False)
|
260 |
+
with gr.Column():
|
261 |
+
normal = gr.Image(interactive=False, show_label=False)
|
262 |
+
|
263 |
+
with gr.Row():
|
264 |
+
files = gr.Files(
|
265 |
+
label = "Depth and Normal (numpy)",
|
266 |
+
elem_id = "download",
|
267 |
+
interactive=False,
|
268 |
+
)
|
269 |
+
|
270 |
+
with gr.Row():
|
271 |
+
recon_btn = gr.Button('Is there a salient foreground object? If yes, Click here to Reconstruct its 3D model.', variant='primary', interactive=True)
|
272 |
+
|
273 |
+
with gr.Row():
|
274 |
+
reconstructed_3d = gr.Model3D(
|
275 |
+
label = 'Bini post-processed 3D model', height=320, interactive=False,
|
276 |
+
)
|
277 |
+
|
278 |
+
|
279 |
+
run_btn.click(fn=depth_normal,
|
280 |
+
inputs=[input_image, denoising_steps,
|
281 |
+
ensemble_size,
|
282 |
+
processing_res,
|
283 |
+
seed,
|
284 |
+
domain],
|
285 |
+
outputs=[depth, normal, files]
|
286 |
+
)
|
287 |
+
recon_btn.click(fn=reconstruction,
|
288 |
+
inputs=[input_image, files],
|
289 |
+
outputs=[reconstructed_3d]
|
290 |
+
)
|
291 |
+
demo.queue().launch(share=True, max_threads=80)
|
292 |
+
|
293 |
+
|
294 |
+
if __name__ == '__main__':
|
295 |
+
fire.Fire(run_demo)
|