"""Gradio demo for Zero123++ single-image to multi-view generation.

The app optionally removes the background of the input image (rembg for a
coarse bounding box, then Segment Anything for the mask), optionally rescales
and recenters the object, runs the Zero123++ diffusion pipeline, and shows the
six generated views.
"""
import os
import time
from functools import partial

import torch
import fire
import gradio as gr
from PIL import Image
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
import cv2
import numpy as np
from rembg import remove
from segment_anything import sam_model_registry, SamPredictor

_TITLE = '''Zero123++: a Single Image to Consistent Multi-view Diffusion Base Model'''
_DESCRIPTION = '''
'''
_GPU_ID = 0

# Older Pillow releases expose the resampling filters directly on ``Image``;
# alias the namespace so ``Image.Resampling.LANCZOS`` works everywhere.
if not hasattr(Image, 'Resampling'):
    Image.Resampling = Image


def sam_init():
    """Load the SAM ViT-H checkpoint from ``tmp/`` next to this file.

    Returns:
        A ``SamPredictor`` wrapping the model placed on ``cuda:{_GPU_ID}``.
    """
    sam_checkpoint = os.path.join(os.path.dirname(__file__), "tmp", "sam_vit_h_4b8939.pth")
    model_type = "vit_h"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=f"cuda:{_GPU_ID}")
    return SamPredictor(sam)


def sam_segment(predictor, input_image, *bbox_coords):
    """Segment ``input_image`` with SAM using a box prompt.

    Args:
        predictor: The ``SamPredictor`` to run.
        input_image: Image convertible to an array via ``np.asarray``; it is
            copied into the first three channels of the output, so it is
            expected to be 3-channel (callers pass an RGB PIL image).
        *bbox_coords: Box prompt coordinates, forwarded as a single array
            (callers pass ``x_min, y_min, x_max, y_max``).

    Returns:
        An RGBA PIL image whose colour channels are the input and whose alpha
        channel is the last mask returned by the predictor (255 inside the
        mask, 0 outside).
    """
    bbox = np.array(bbox_coords)
    image = np.asarray(input_image)

    start_time = time.time()
    predictor.set_image(image)
    masks_bbox, _scores, _logits = predictor.predict(box=bbox, multimask_output=True)
    print(f"SAM Time: {time.time() - start_time:.3f}s")

    out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
    out_image[:, :, :3] = image
    out_image[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255
    torch.cuda.empty_cache()
    return Image.fromarray(out_image, mode='RGBA')


def expand2square(pil_img, background_color):
    """Pad ``pil_img`` to a square, centring it on ``background_color``.

    A square input is returned unchanged (same object). Otherwise a new image
    of the same mode is created with side ``max(width, height)`` and the input
    is pasted centred along the shorter axis.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    if width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    result = Image.new(pil_img.mode, (height, height), background_color)
    result.paste(pil_img, ((height - width) // 2, 0))
    return result


def preprocess(predictor, input_image, chk_group=None, segment=True, rescale=False):
    """Prepare an input image for the diffusion pipeline.

    Args:
        predictor: ``SamPredictor`` used when background removal is enabled.
        input_image: PIL image. NOTE: it is downscaled in place (via
            ``thumbnail``) so that neither side exceeds 1024 pixels.
        chk_group: Optional collection of option labels. When not ``None`` it
            overrides ``segment`` and ``rescale``: segmentation is enabled iff
            it contains ``"Background Removal"``, rescaling iff it contains
            ``"Rescale"``.
        segment: Remove the background (rembg gives a coarse alpha whose
            bounding box is used as the SAM box prompt).
        rescale: Crop to the alpha bounding box, pad so the object spans 75%
            of a square canvas, and composite it onto a white background.
            When false the image is instead padded to a square.

    Returns:
        A tuple ``(image, preview)`` where ``preview`` is ``image`` resized to
        320x320.

    Raises:
        ValueError: If background removal yields no foreground pixels.
    """
    RES = 1024
    input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
    if chk_group is not None:
        segment = "Background Removal" in chk_group
        rescale = "Rescale" in chk_group

    if segment:
        image_rem = input_image.convert('RGBA')
        image_nobg = remove(image_rem, alpha_matting=True)
        # Columns / rows that contain any non-transparent pixel.
        arr = np.asarray(image_nobg)[:, :, -1]
        x_nonzero = np.nonzero(arr.sum(axis=0))
        y_nonzero = np.nonzero(arr.sum(axis=1))
        if x_nonzero[0].size == 0 or y_nonzero[0].size == 0:
            # Without this guard ``.min()`` fails with an opaque numpy error.
            raise ValueError("Background removal found no foreground pixels in the input image.")
        x_min = int(x_nonzero[0].min())
        y_min = int(y_nonzero[0].min())
        x_max = int(x_nonzero[0].max())
        y_max = int(y_nonzero[0].max())
        input_image = sam_segment(predictor, input_image.convert('RGB'), x_min, y_min, x_max, y_max)

    # Rescale and recenter
    if rescale:
        # The steps below read the alpha channel and copy 4-channel pixels, so
        # make sure the image really is RGBA (no-op for RGBA input).
        input_image = input_image.convert('RGBA')
        image_arr = np.array(input_image)
        in_h, in_w = image_arr.shape[:2]
        out_res = min(RES, max(in_w, in_h))
        _, mask = cv2.threshold(np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY)
        x, y, w, h = cv2.boundingRect(mask)
        max_size = max(w, h)
        ratio = 0.75  # fraction of the canvas side occupied by the object
        side_len = int(max_size / ratio)
        padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
        center = side_len // 2
        top = center - h // 2
        left = center - w // 2
        padded_image[top:top + h, left:left + w] = image_arr[y:y + h, x:x + w]
        rgba = Image.fromarray(padded_image).resize((out_res, out_res), Image.Resampling.LANCZOS)
        # Alpha-composite onto a white background.
        rgba_arr = np.array(rgba) / 255.0
        rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
        input_image = Image.fromarray((rgb * 255).astype(np.uint8))
    else:
        input_image = expand2square(input_image, (127, 127, 127, 0))
    return input_image, input_image.resize((320, 320), Image.Resampling.LANCZOS)


def gen_multiview(pipeline, predictor, input_image, scale_slider, steps_slider, seed,
                  output_processing=False):
    """Run the diffusion pipeline and split its output grid into single views.

    Args:
        pipeline: The diffusion pipeline to call.
        predictor: ``SamPredictor`` used for optional output post-processing.
        input_image: Conditioning image passed to the pipeline.
        scale_slider: Guidance scale.
        steps_slider: Number of inference steps.
        seed: Random seed (converted with ``int``).
        output_processing: Collection of option labels, or a falsy value for
            no post-processing. If it contains ``"Background Removal"`` each
            view is passed through :func:`preprocess` with segmentation on.

    Returns:
        A list of PIL images: the output image cut into square tiles of side
        ``width // 2``, in row-major order.
    """
    seed = int(seed)
    torch.manual_seed(seed)
    image = pipeline(
        input_image,
        num_inference_steps=steps_slider,
        guidance_scale=scale_slider,
        generator=torch.Generator(pipeline.device).manual_seed(seed),
    ).images[0]
    side_len = image.width // 2
    subimages = [
        image.crop((x, y, x + side_len, y + side_len))
        for y in range(0, image.height, side_len)
        for x in range(0, image.width, side_len)
    ]
    # ``output_processing`` defaults to False; guard before the membership
    # test, which would otherwise raise TypeError on a bool.
    if output_processing and "Background Removal" in output_processing:
        return [
            preprocess(predictor, sub_image.convert('RGB'), segment=True, rescale=False)[0]
            for sub_image in subimages
        ]
    return subimages


def run_demo():
    """Load the models, build the Gradio UI and launch the server on port 7860."""
    # Load the pipeline
    pipeline = DiffusionPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.1",
        custom_pipeline="sudo-ai/zero123plus-pipeline",
        torch_dtype=torch.float16
    )
    # Feel free to tune the scheduler
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config, timestep_spacing='trailing'
    )
    pipeline.to(f'cuda:{_GPU_ID}')

    predictor = sam_init()

    custom_theme = gr.themes.Soft(primary_hue="blue").set(
        button_secondary_background_fill="*neutral_100",
        button_secondary_background_fill_hover="*neutral_200")
    custom_css = '''#disp_image { text-align: center; /* Horizontally center the content */ }'''

    with gr.Blocks(title=_TITLE, theme=custom_theme, css=custom_css) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
                gr.Markdown(_DESCRIPTION)
        with gr.Row(variant='panel'):
            with gr.Column(scale=1):
                input_image = gr.Image(type='pil', image_mode='RGBA', height=320,
                                       label='Input image', tool=None)

                example_folder = os.path.join(os.path.dirname(__file__), "./resources/examples")
                # os.listdir order is arbitrary; sort for a stable gallery.
                example_fns = [os.path.join(example_folder, example)
                               for example in sorted(os.listdir(example_folder))]
                gr.Examples(
                    examples=example_fns,
                    inputs=[input_image],
                    outputs=[input_image],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=10
                )
                with gr.Accordion('Advanced options', open=False):
                    with gr.Row():
                        with gr.Column():
                            input_processing = gr.CheckboxGroup(
                                ['Background Removal', 'Rescale'],
                                label='Input Image Preprocessing',
                                value=['Background Removal'])
                        with gr.Column():
                            output_processing = gr.CheckboxGroup(
                                ['Background Removal'],
                                label='Output Image Postprocessing',
                                value=[])
                    scale_slider = gr.Slider(1, 10, value=4, step=1,
                                             label='Classifier Free Guidance Scale')
                    steps_slider = gr.Slider(
                        15, 100, value=75, step=1,
                        label='Number of Diffusion Inference Steps',
                        info=("For general real or synthetic objects, around 28 is enough. "
                              "For objects with delicate details such as faces "
                              "(either realistic or illustration), "
                              "you may need 75 or more steps."))
                    seed = gr.Number(42, label='Seed')
                run_btn = gr.Button('Generate', variant='primary', interactive=True)
            with gr.Column(scale=1):
                processed_image = gr.Image(type='pil', label="Processed Image", interactive=False,
                                           height=320, tool=None, image_mode='RGBA',
                                           elem_id="disp_image")
                # Hidden component carrying the full-resolution preprocessed
                # image from the first callback to the second.
                processed_image_highres = gr.Image(type='pil', image_mode='RGBA',
                                                   visible=False, tool=None)
                with gr.Row():
                    view_1 = gr.Image(interactive=False, height=240, show_label=False)
                    view_2 = gr.Image(interactive=False, height=240, show_label=False)
                    view_3 = gr.Image(interactive=False, height=240, show_label=False)
                with gr.Row():
                    view_4 = gr.Image(interactive=False, height=240, show_label=False)
                    view_5 = gr.Image(interactive=False, height=240, show_label=False)
                    view_6 = gr.Image(interactive=False, height=240, show_label=False)

        # Preprocess first; only if that succeeds, generate the views.
        run_btn.click(
            fn=partial(preprocess, predictor),
            inputs=[input_image, input_processing],
            outputs=[processed_image_highres, processed_image],
            queue=True
        ).success(
            fn=partial(gen_multiview, pipeline, predictor),
            inputs=[processed_image_highres, scale_slider, steps_slider, seed, output_processing],
            outputs=[view_1, view_2, view_3, view_4, view_5, view_6]
        )

    demo.queue().launch(share=False, max_threads=80, server_name="0.0.0.0", server_port=7860)


if __name__ == '__main__':
    fire.Fire(run_demo)