"""Gradio demo: click-to-segment with SAM, then ControlNet-guided inpainting."""
import colorsys
import os
import threading
import urllib.request

import gradio as gr
import numpy as np
import torch
from diffusers import ControlNetModel
from diffusers import StableDiffusionInpaintPipeline
from diffusers import UniPCMultistepScheduler
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry

from controlnet_inpaint import StableDiffusionControlNetInpaintPipeline

url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
filename = "sam_vit_h_4b8939.pth"
# The ViT-H checkpoint is large: only fetch it when missing, and download to a
# temporary name first so an interrupted download never leaves a truncated
# file that would be mistaken for a valid checkpoint on the next start.
if not os.path.exists(filename):
    _partial = filename + ".part"
    urllib.request.urlretrieve(url, _partial)
    os.replace(_partial, filename)

sam_checkpoint = filename
model_type = "vit_h"
device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
mask_generator = SamAutomaticMaskGenerator(sam)

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-seg",
    torch_dtype=torch.float16,
)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.enable_xformers_memory_efficient_attention()

# The queue below runs up to 50 handlers concurrently, but the SAM predictor
# (set_image() stores the embedding used by the following predict()) and the
# diffusion pipeline (its scheduler keeps per-run state) are single shared,
# stateful objects. Serialize access so requests cannot interleave.
_sam_lock = threading.Lock()
_pipe_lock = threading.Lock()


def _colorize_segments(boolean_masks, height, width):
    """Render boolean masks as one RGB image, giving each mask its own hue.

    Masks are painted in order, so where masks overlap the later one wins.
    Pixels covered by no mask stay black. An empty list yields an all-black
    image of the requested size.
    """
    finseg = np.zeros((height, width, 3), dtype=np.uint8)
    count = len(boolean_masks)
    for class_id, boolean_mask in enumerate(boolean_masks):
        hue = class_id / count
        rgb = [int(c * 255) for c in colorsys.hsv_to_rgb(hue, 1, 1)]
        # Assign rather than add: the automatic masks overlap, and summing
        # uint8 colours silently wraps around past 255.
        finseg[boolean_mask] = rgb
    return finseg


with gr.Blocks() as demo:
    selected_pixels = gr.State([])

    with gr.Row():
        input_img = gr.Image(label="Input")
        mask_img = gr.Image(label="Mask")
        seg_img = gr.Image(label="Segmentation")
        output_img = gr.Image(label="Output")

    with gr.Row():
        prompt_text = gr.Textbox(lines=1, label="Prompt")
        negative_prompt_text = gr.Textbox(lines=1, label="Negative Prompt")
        is_background = gr.Checkbox(label="Background")

    with gr.Row():
        submit = gr.Button("Submit")
        clear = gr.Button("Clear")

    def generate_mask(image, bg, sel_pix, evt: gr.SelectData):
        """Add the clicked point and recompute the object mask and segment map.

        Args:
            image: the input image array from the ``Input`` component.
            bg: when true, the SAM mask is inverted (select the background).
            sel_pix: per-session list of clicked points; mutated in place.
            evt: the Gradio select event; ``evt.index`` is the clicked point.

        Returns:
            ``(mask, finseg)``: the SAM mask as a PIL image and an RGB image
            colour-coding every automatically generated segment.
        """
        sel_pix.append(evt.index)
        input_point = np.array(sel_pix)
        input_label = np.ones(input_point.shape[0])  # every click is foreground

        with _sam_lock:
            predictor.set_image(image)
            mask, _, _ = predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=False,
            )
            segs = mask_generator.generate(image)

        if bg:
            mask = np.logical_not(mask)
        mask = Image.fromarray(mask[0, :, :])

        boolean_masks = [s["segmentation"] for s in segs]
        # Size from the image itself so an empty segment list cannot raise.
        finseg = _colorize_segments(boolean_masks, image.shape[0], image.shape[1])
        return mask, finseg

    def inpaint(image, mask, seg_img, prompt, negative_prompt):
        """Inpaint the masked region at 512x512, guided by the segment map.

        Raises:
            gr.Error: if the image, mask or segmentation has not been produced.
        """
        if image is None or mask is None or seg_img is None:
            raise gr.Error("Upload an image and click on it to create a mask first.")

        size = (512, 512)
        image = Image.fromarray(image).resize(size)
        # Nearest-neighbour keeps the mask binary and the segment colours exact
        # instead of blending them at edges.
        mask = Image.fromarray(mask).resize(size, Image.NEAREST)
        seg_img = Image.fromarray(seg_img).resize(size, Image.NEAREST)

        with _pipe_lock:
            result = pipe(prompt, image, mask, seg_img, negative_prompt=negative_prompt)
        return result.images[0]

    def _clear(sel_pix, img, mask, seg, out, prompt, neg_prompt, bg):
        """Reset the clicked points and every input/output component.

        The incoming values are ignored. The fresh empty list is routed back to
        ``selected_pixels`` so points from a previous image cannot leak into
        the next mask.
        """
        return [], None, None, None, None, "", "", False

    input_img.select(
        generate_mask,
        [input_img, is_background, selected_pixels],
        [mask_img, seg_img],
    )

    submit.click(
        inpaint,
        inputs=[input_img, mask_img, seg_img, prompt_text, negative_prompt_text],
        outputs=[output_img],
    )

    clear.click(
        _clear,
        inputs=[
            selected_pixels,
            input_img,
            mask_img,
            seg_img,
            output_img,
            prompt_text,
            negative_prompt_text,
            is_background,
        ],
        outputs=[
            selected_pixels,
            input_img,
            mask_img,
            seg_img,
            output_img,
            prompt_text,
            negative_prompt_text,
            is_background,
        ],
    )

if __name__ == "__main__":
    demo.queue(concurrency_count=50).launch()