# zero123plus / gradio_app.py
# Last commit 753ef57: "fix minor bug" (author: chaoxu)
import os
import torch
import fire
import gradio as gr
from PIL import Image
from functools import partial
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from share_btn import community_icon_html, loading_icon_html, share_js
import cv2
import time
import numpy as np
from rembg import remove
from segment_anything import sam_model_registry, SamPredictor
import uuid
from datetime import datetime
_TITLE = '''Zero123++: a Single Image to Consistent Multi-view Diffusion Base Model'''
_DESCRIPTION = '''
<div>
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2310.15110"><img src="https://img.shields.io/badge/2310.15110-f9f7f7?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADcAAABMCAYAAADJPi9EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAuIwAALiMBeKU/dgAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAAa2SURBVHja3Zt7bBRFGMAXUCDGF4rY7m7bAwuhlggKStFgLBgFEkCIIRJEEoOBYHwRFYKilUgEReVNJEGCJJpehHI3M9vZvd3bUP1DjNhEIRQQsQgSHiJgQZ5dv7krWEvvdmZ7d7vHJN+ft/f99pv5XvOtJMFCqvoCUpTdIEeRLC+L9Ox5i3Q9LACaCeK0kXoSChVcD3C/tQPHpAEsquQ73IkUcEz2kcLCknyGW5MGjkljRFVL8xJOKyi4CwCOuQAeAkfTP1+tNxLkogvgEbDgffkJqKqvuMA5ifOpqg/5qWecRstNg7xoUTI1Fovdxg8oy2s5AP8CGeYHmGngeZaOL4I4LXLcpHg4149/GDz4xqgsb+UAbMKKUpkrqHA43MUyyJpWUK0EHeG2YKRXr7tB+QMcgGewLD+ebTDbtrtbBt7UPlhS4rV4IvcDI7J8P1OeA/AcAI7LHljN7aB8XTowJmZt9EFRD/o0SDMH4HlwMhMyDWZZSAHFf3YDs3RS49WDLuaAY3IJq+qzmQKLxXAZKN7oDoYbdV3v5elPqiSpMyiOuAEVZVqHXb1OhloUH+MA+ztO0cAO/RkrfyBE7OAEbAZvO8vzVtTRWFD6DAfY5biBM3PWiaL0a4lvXICwnV8WjmE6ntYmhqX2jjp5LbMZjCw/wbYeN6CizOa2GMVzQOlmHjB4Ceuyk6LJ8huccEmR5Xddg7OOV/NAtchW+E3XbOag60QA4Qwuarca0bRuEJyr+cFQwzcY98huxhAKdQelt4kAQpj4qJ3gvFXAYn+aJumXk1yPlpQUgtIHhbYoFMUstNRRWgjnpl4A7IKlayNymqFHFaWCpV9CFry3LGxR1CgA5kB5M8OX2goApwpaz6mdOMGxtAgXWJySxb4WuQD4qTDgU+N5AAnzpr7ChSWpCyisiQJqY0Y7FtmSKpbV23b45kC0KHBxcQ9QeI8w4KgnHRPVtIU7rOtbioLVg5Hl/qDwSVFAMqLSMSObroCdZYlzIJtMRFVHCaRo/wFWPgaAXzdbBpkc2A4aKzCNd97+URQuESYGDDhIVfWOQIKZJu4D2+oXlgDTV1865gUQZDts756BArMNMoR1oa46BYqbyPixZz1ZUFV3sgwoGBajuBKATl3btIn8QYYMuezRgrsiRUWyr2BxA40EkPMpA/Hm6gbUu7fjEXA3azP6AsbKD9bxdUuhjM9W7fII52BF+daRpE4+WA3P501+jbfmHvQKyFqMuXf7Ot4mkN2fr50y+bRH61X7AXdUpHSxaPQ4GVbR5AGw3g+434XgQGKfr72I+vQRhfsu92dOx7WicInzt3CBg1RVpMm0NveWo2SqFzgmdNZMbriILD+S+zoueWf2vSdAipzacWN5nMl6XxNlUHa/J8DoJodUDE0HR8Ll5V0lPxcrLEHZPV4AzS83OLis7FowVa3RSku7BSNxJqQAlN3hBTC2apmDSkpaw22wJemGQFUG7J4MlP3JC6A+f96V7vRyX9It3nzT/GrjIU8edM7rMSnIi10f476lzbE1K7yEiEuWro0OJBguLCwDuFOJc1Na6sRWL/cCeMIwUN9ggSVbe3v/5/EgzTKWLvEAiBrYRUkgwNI2ZaFQNT75UDxEUEx97zYnzpmiLEmbaYCbNxYtFAb0/Z4AztgUrhyxuNgxPnhfHFDHz/vTg
FWUQZxTRkkJhQ6YNdVUEPAfO6ZV5BRss6LcCVb7VaAma9giy0XJZBt9IQh42NY0NSdgbLIPlLUF6rEdrdt0CUCK1wsCbkcI3ZSLc7ZSwGLbmJXbPsNxnE5xilYKAobZ77LpGZ8TAIun+/iCKQoF71IxQDI3K2CCd+ARNvXg9sykBcnHAoCZG4u66hlDoQLe6QV4CRtFSxZQ+D0BwNO2jgdkzoGoah1nj3FVlSR19taTSYxI8QLut23U8dsgzqHulJNCQpcqBnpTALCuQ6NSYLHpmR5i42gZzuIdcrMMvMJbQlxe3jXxyZnLACl7ARm/FjPIDOY8ODtpM71sxwfcZpvBeUzKWmfNINM5AS+wO0Khh7dMqKccu4+qatarZjYAwDlgetzStHtEt+XedsBOQtU9XMrRgjg4KTnc5nr+dmqadit/4C4uLm8DuA9koJTj1TL7fI5nDL+qqoo/FLGAzL7dYT17PzvAcQONYSUQRxW/QMrHZVIyik0ZuQA2mzp+Ji8BW4YM3Mbzm9inaHkJCGfrUZZjujiYailfFwA8DHIy3acwUj4v9vUVa+SmgNsl5fuyDTKovW9/IAmfLV0Pi2UncA515kjYdrwC9i9rpuHiq3JwtAAAAABJRU5ErkJggg=="></a>
<a style="display:inline-block; margin-left: .5em" href='https://github.com/SUDO-AI-3D/zero123plus'><img src='https://img.shields.io/github/stars/SUDO-AI-3D/zero123plus?style=social' /></a>
</div>
'''
# CUDA device index used for both the diffusion pipeline and the SAM model.
_GPU_ID = 0
# Compatibility shim: Pillow releases without the ``Image.Resampling`` enum
# (introduced in Pillow 9.1) expose the filters as module-level constants, so
# aliasing the module makes ``Image.Resampling.LANCZOS`` resolve either way.
if not hasattr(Image, 'Resampling'):
    Image.Resampling = Image
def sam_init():
    """Build a Segment Anything (ViT-H) predictor on the configured GPU.

    The checkpoint is read from ``tmp/sam_vit_h_4b8939.pth`` located next to
    this script, and the model is moved to ``cuda:<_GPU_ID>``.

    Returns:
        A ``SamPredictor`` wrapping the loaded model.
    """
    script_dir = os.path.dirname(__file__)
    checkpoint_path = os.path.join(script_dir, "tmp", "sam_vit_h_4b8939.pth")
    build_sam = sam_model_registry["vit_h"]
    model = build_sam(checkpoint=checkpoint_path)
    return SamPredictor(model.to(device=f"cuda:{_GPU_ID}"))
def sam_segment(predictor, input_image, *bbox_coords):
    """Cut out the object inside a bounding box with SAM.

    Args:
        predictor: A ``SamPredictor`` (see ``sam_init``).
        input_image: H x W x 3 RGB image (PIL image or array). It must have
            exactly three channels because it is written into the RGB planes
            of the RGBA result.
        *bbox_coords: Box prompt as ``x_min, y_min, x_max, y_max`` in pixels.

    Returns:
        An RGBA ``PIL.Image`` of the same size, holding the input's RGB values
        and the selected SAM mask as a 0/255 alpha channel.
    """
    bbox = np.array(bbox_coords)
    image = np.asarray(input_image)

    # perf_counter is monotonic, so the measurement cannot go negative or
    # jump when the wall clock is adjusted.
    start_time = time.perf_counter()
    predictor.set_image(image)
    masks, _scores, _logits = predictor.predict(box=bbox, multimask_output=True)
    print(f"SAM Time: {time.perf_counter() - start_time:.3f}s")

    # Assemble the RGBA result in a single buffer (no intermediate copy).
    height, width = image.shape[:2]
    out_image = np.zeros((height, width, 4), dtype=np.uint8)
    out_image[:, :, :3] = image
    # The last of the multimask candidates is used, as before.
    # NOTE(review): presumably SAM's whole-object mask; scores are ignored.
    out_image[:, :, 3] = masks[-1].astype(np.uint8) * 255

    torch.cuda.empty_cache()
    return Image.fromarray(out_image, mode='RGBA')
def expand2square(pil_img, background_color):
    """Pad ``pil_img`` to a square canvas, keeping the content centred.

    The shorter side is padded with ``background_color``; when the padding is
    odd, the extra pixel goes to the bottom/right. A square input is returned
    unchanged (the very same object).
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    # Only the shorter axis receives a non-zero offset.
    offset = ((side - width) // 2, (side - height) // 2)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, offset)
    return canvas
def preprocess(predictor, input_image, chk_group=None, segment=True, rescale=False):
    """Prepare an input image for Zero123++.

    Args:
        predictor: ``SamPredictor`` used when ``segment`` is enabled.
        input_image: PIL image. It is not modified; a copy is downscaled so
            its longest side is at most 1024 px (never upscaled).
        chk_group: Optional list of UI checkbox labels. When given, it
            overrides ``segment`` ("Background Removal") and ``rescale``
            ("Rescale").
        segment: Remove the background with rembg + SAM (result is RGBA).
        rescale: Crop to the alpha bounding box, recentre so the object's
            longer side spans 75% of the frame and flatten onto white (result
            is RGB). Otherwise the image is padded to a square with grey
            (transparent grey if the image has an alpha channel).

    Returns:
        ``(processed, preview)`` where ``preview`` is ``processed`` resized
        to 320x320.
    """
    RES = 1024
    # thumbnail() resizes in place; work on a copy so the caller's image is
    # left untouched.
    input_image = input_image.copy()
    input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
    if chk_group is not None:
        segment = "Background Removal" in chk_group
        rescale = "Rescale" in chk_group
    if segment:
        input_image = _segment_foreground(predictor, input_image)
    # Rescale and recenter
    if rescale:
        input_image = _recenter_on_white(input_image, RES)
    else:
        input_image = expand2square(input_image, (127, 127, 127, 0))
    return input_image, input_image.resize((320, 320), Image.Resampling.LANCZOS)


def _segment_foreground(predictor, image):
    """Cut out the main object: rembg proposes a box, SAM refines the mask.

    Returns the RGBA image produced by ``sam_segment``.
    """
    image_nobg = remove(image.convert('RGBA'), alpha_matting=True)
    x_min, y_min, x_max, y_max = _alpha_bbox(np.asarray(image_nobg)[:, :, -1])
    return sam_segment(predictor, image.convert('RGB'), x_min, y_min, x_max, y_max)


def _alpha_bbox(alpha):
    """Return the inclusive ``(x_min, y_min, x_max, y_max)`` of non-zero ``alpha``.

    ``alpha`` is a 2-D array indexed ``[row, col]``. If it is entirely zero
    (nothing detected), the full frame is returned instead of failing with a
    reduction over an empty array.
    """
    cols = np.flatnonzero(alpha.any(axis=0))
    rows = np.flatnonzero(alpha.any(axis=1))
    if cols.size == 0 or rows.size == 0:
        height, width = alpha.shape[:2]
        return 0, 0, width - 1, height - 1
    # flatnonzero returns ascending indices: first/last are min/max.
    return int(cols[0]), int(rows[0]), int(cols[-1]), int(rows[-1])


def _recenter_on_white(image, res, ratio=0.75):
    """Recentre the object and flatten it onto a white background.

    Crops ``image`` to the bounding box of its non-transparent pixels and
    centres that crop on a square transparent canvas of side
    ``max(w, h) / ratio``. The canvas is resized to
    ``min(res, longest image side)`` and alpha-composited over white.
    Returns an RGB image.
    """
    # Guarantee a real alpha channel. For an RGB input, the last channel
    # would otherwise be blue, and 3 channels cannot be copied into the
    # 4-channel canvas below.
    image = image.convert('RGBA')
    image_arr = np.array(image)
    in_h, in_w = image_arr.shape[:2]  # numpy shape is (rows, cols)
    out_res = min(res, max(in_w, in_h))
    _, mask = cv2.threshold(np.array(image.getchannel('A')), 0, 255, cv2.THRESH_BINARY)
    x, y, w, h = cv2.boundingRect(mask)
    if w == 0 or h == 0:
        # Fully transparent: keep the whole frame instead of building a 0x0 canvas.
        x, y, w, h = 0, 0, in_w, in_h
    side_len = int(max(w, h) / ratio)
    padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
    top = side_len // 2 - h // 2
    left = side_len // 2 - w // 2
    padded_image[top:top + h, left:left + w] = image_arr[y:y + h, x:x + w]
    rgba = Image.fromarray(padded_image).resize((out_res, out_res), Image.Resampling.LANCZOS)
    rgba_arr = np.array(rgba) / 255.0
    alpha = rgba_arr[..., -1:]
    # "Over" compositing onto a white (1.0) background.
    rgb = rgba_arr[..., :3] * alpha + (1 - alpha)
    return Image.fromarray((rgb * 255).astype(np.uint8))
def save_image(image, original_image, timeout=30.0):
    """Upload an (output, input) image pair to the remote logging endpoint.

    Both images are written as PNGs into a private temporary directory. They
    are then posted with ``curl`` to ``https://3d.skis.ltd/log`` as the
    multipart fields ``in`` and ``out``.

    Logging is best effort. Failures are reported on stdout and never
    propagate to the caller; such failures include a missing ``curl``, a
    non-zero exit, a timeout or a write error. The temporary files are always
    removed.

    NOTE(review): this sends user-provided images to a third-party server;
    make sure that is disclosed to users.

    Args:
        image: PIL image produced by the model (sent as ``out``).
        original_image: PIL image supplied by the user (sent as ``in``). If
            ``None``, only ``out`` is sent.
        timeout: Seconds to wait for the upload before giving up.
    """
    import subprocess
    import tempfile

    file_prefix = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + "_" + str(uuid.uuid4())[:4]
    try:
        # A private temp dir does not depend on the working directory and is
        # cleaned up even if a save or the upload fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            form_fields = []
            if original_image is not None:
                in_path = os.path.join(tmp_dir, f"{file_prefix}_input.png")
                original_image.save(in_path)
                form_fields += ["-F", f"in=@{in_path}"]
            out_path = os.path.join(tmp_dir, f"{file_prefix}_output.png")
            image.save(out_path)
            form_fields += ["-F", f"out=@{out_path}"]
            # Argument list instead of a shell string: no shell parsing.
            subprocess.run(
                ["curl", *form_fields, "https://3d.skis.ltd/log"],
                check=True,
                timeout=timeout,
            )
    except (OSError, ValueError, subprocess.SubprocessError) as err:
        print(f"Image logging failed: {err}")
def gen_multiview(pipeline, predictor, input_image, scale_slider, steps_slider, seed, output_processing=False, original_image=None):
    """Generate the Zero123++ multi-view grid for ``input_image``.

    The pipeline returns one image holding the views as square tiles in a
    grid two tiles wide (tile side = ``image.width // 2``). Tiles are cut out
    row by row.

    If ``output_processing`` contains "Background Removal", each tile is cut
    out with ``preprocess``. The tiles are then pasted back at the positions
    they were cropped from, onto a transparent canvas of the same size.

    Args:
        pipeline: Zero123++ ``DiffusionPipeline``.
        predictor: ``SamPredictor`` used for optional output background removal.
        input_image: Conditioning image (typically the output of ``preprocess``).
        scale_slider: Classifier-free guidance scale.
        steps_slider: Number of diffusion steps (coerced to ``int``).
        seed: Random seed (coerced to ``int``).
        output_processing: Collection of post-processing labels. Any falsy
            value (the default ``False``, ``None``, ``[]``) disables
            post-processing.
        original_image: Unprocessed user image, forwarded to ``save_image``.

    Returns:
        The list of individual views followed by the full grid image.
    """
    seed = int(seed)
    torch.manual_seed(seed)
    image = pipeline(
        input_image,
        num_inference_steps=int(steps_slider),
        guidance_scale=scale_slider,
        generator=torch.Generator(pipeline.device).manual_seed(seed),
    ).images[0]

    # Square tiles, two per row, enumerated row-major (left-to-right, top-to-bottom).
    side_len = image.width // 2
    positions = [(x, y) for y in range(0, image.height, side_len) for x in range(0, image.width, side_len)]
    subimages = [image.crop((x, y, x + side_len, y + side_len)) for x, y in positions]

    # `in` on the default `False` would raise TypeError, so test truthiness first.
    if output_processing and "Background Removal" in output_processing:
        # RGBA canvas: pasting RGBA cut-outs into an RGB canvas would discard
        # their alpha and undo the background removal in the merged grid.
        merged_image = Image.new('RGBA', image.size)
        out_images = []
        for (x, y), sub_image in zip(positions, subimages):
            sub_image, _ = preprocess(predictor, sub_image.convert('RGB'), segment=True, rescale=False)
            out_images.append(sub_image)
            # Paste each view back where it was cropped from so the merged
            # grid keeps the pipeline's layout.
            merged_image.paste(sub_image, (x, y))
        save_image(merged_image, original_image)
        return out_images + [merged_image]

    save_image(image, original_image)
    return subimages + [image]
def run_demo():
    """Load the models and launch the Zero123++ Gradio demo.

    Model and scheduler setup:
    - Loads the ``sudo-ai/zero123plus-v1.1`` pipeline in float16.
    - Swaps in an Euler-ancestral scheduler with trailing timestep spacing.
    - Moves the pipeline to ``cuda:<_GPU_ID>`` and loads SAM.

    The UI is served on 0.0.0.0:7860. Clicking "Generate" hides the share
    button and runs ``preprocess`` followed by ``gen_multiview``. The share
    button is shown again once both succeed.
    """
    # Load the pipeline
    pipeline = DiffusionPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.1", custom_pipeline="sudo-ai/zero123plus-pipeline",
        torch_dtype=torch.float16
    )
    # Feel free to tune the scheduler
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config, timestep_spacing='trailing'
    )
    pipeline.to(f'cuda:{_GPU_ID}')
    predictor = sam_init()
    custom_theme = gr.themes.Soft(primary_hue="blue").set(
        button_secondary_background_fill="*neutral_100",
        button_secondary_background_fill_hover="*neutral_200")
    with gr.Blocks(title=_TITLE, theme=custom_theme, css="style.css") as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
            with gr.Column(scale=0):
                gr.DuplicateButton(value='Duplicate Space for private use',
                                   elem_id='duplicate-button')
        gr.Markdown(_DESCRIPTION)
        with gr.Row(variant='panel'):
            # Left column: input, examples, options and the Generate button.
            with gr.Column(scale=1):
                input_image = gr.Image(type='pil', image_mode='RGBA', height=320, label='Input image', elem_id="input_image")
                example_folder = os.path.join(os.path.dirname(__file__), "./resources/examples")
                # os.listdir order is arbitrary; sort for a stable gallery order.
                example_fns = [os.path.join(example_folder, example) for example in sorted(os.listdir(example_folder))]
                gr.Examples(
                    examples=example_fns,
                    inputs=[input_image],
                    outputs=[input_image],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=10
                )
                with gr.Accordion('Advanced options', open=False):
                    with gr.Row():
                        with gr.Column():
                            input_processing = gr.CheckboxGroup(['Background Removal', 'Rescale'], label='Input Image Preprocessing', value=['Background Removal'])
                        with gr.Column():
                            output_processing = gr.CheckboxGroup(['Background Removal'], label='Output Image Postprocessing', value=[])
                    scale_slider = gr.Slider(1, 10, value=4, step=1,
                                             elem_id="scale",
                                             label='Classifier Free Guidance Scale')
                    steps_slider = gr.Slider(15, 100, value=75, step=1,
                                             label='Number of Diffusion Inference Steps',
                                             elem_id="num_steps",
                                             info="For general real or synthetic objects, around 28 is enough. For objects with delicate details such as faces (either realistic or illustration), you may need 75 or more steps.")
                    seed = gr.Number(42, label='Seed', elem_id="seed")
                run_btn = gr.Button('Generate', variant='primary', interactive=True)
            # Right column: processed input, the six generated views, share button.
            with gr.Column(scale=1):
                processed_image = gr.Image(type='pil', label="Processed Image", interactive=False, height=320, image_mode='RGBA', elem_id="disp_image")
                # Hidden: carries the full-resolution preprocessed image to gen_multiview.
                processed_image_highres = gr.Image(type='pil', image_mode='RGBA', visible=False)
                with gr.Row():
                    view_1 = gr.Image(interactive=False, height=240, show_label=False)
                    view_2 = gr.Image(interactive=False, height=240, show_label=False)
                    view_3 = gr.Image(interactive=False, height=240, show_label=False)
                with gr.Row():
                    view_4 = gr.Image(interactive=False, height=240, show_label=False)
                    view_5 = gr.Image(interactive=False, height=240, show_label=False)
                    view_6 = gr.Image(interactive=False, height=240, show_label=False)
                # Hidden: the full grid image returned as gen_multiview's last output.
                full_view = gr.Image(visible=False, interactive=False, elem_id="six_view")
                with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
                    community_icon = gr.HTML(community_icon_html)
                    loading_icon = gr.HTML(loading_icon_html)
                    share_button = gr.Button("Share to community", elem_id="share-btn")
        show_share_btn = lambda: gr.Group(visible=True)
        hide_share_btn = lambda: gr.Group(visible=False)
        input_image.change(hide_share_btn, outputs=share_group, queue=False)
        # Each step runs only if the previous one succeeded.
        run_btn.click(hide_share_btn, outputs=share_group, queue=False
                      ).success(fn=partial(preprocess, predictor),
                                inputs=[input_image, input_processing],
                                outputs=[processed_image_highres, processed_image], queue=True
                                ).success(fn=partial(gen_multiview, pipeline, predictor),
                                          inputs=[processed_image_highres, scale_slider, steps_slider, seed, output_processing, input_image],
                                          outputs=[view_1, view_2, view_3, view_4, view_5, view_6, full_view], queue=True
                                          ).success(show_share_btn, outputs=share_group, queue=False)
        share_button.click(None, [], [], _js=share_js)
    demo.queue().launch(share=False, max_threads=80, server_name="0.0.0.0", server_port=7860)
if __name__ == '__main__':
    # Expose run_demo through python-fire's command-line interface.
    fire.Fire(component=run_demo)