"""Interactive inpainting demo: SAM picks a region from user-supplied points,
then SDXL inpainting regenerates everything outside that region."""
import functools

import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import AutoPipelineForInpainting
# NOTE(review): VQEncoderOutput / VQModel are not used in this file; kept only in
# case something imports them from here. Safe to drop if nothing does.
from diffusers.models.autoencoders.vq_model import VQEncoderOutput, VQModel  # noqa: F401
from transformers import SamModel, SamProcessor

# Check if GPU is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Model and Processor setup (SAM is loaded once, at import time)
model_name = "facebook/sam-vit-huge"
model = SamModel.from_pretrained(model_name).to(device)
processor = SamProcessor.from_pretrained(model_name)


def mask_to_rgb(mask):
    """Render a mask as a semi-transparent green RGBA overlay.

    Args:
        mask: array (H, W as used in this file); pixels equal to 1 (``True``
            for boolean masks) are painted green with alpha 127.

    Returns:
        ``uint8`` array of shape ``mask.shape + (4,)``; all other pixels are
        fully transparent (all zeros).
    """
    bg_transparent = np.zeros(mask.shape + (4,), dtype=np.uint8)
    bg_transparent[mask == 1] = [0, 255, 0, 127]
    return bg_transparent


def _parse_points(points_str):
    """Parse ``"x1,y1,x2,y2,..."`` into ``[[x1, y1], [x2, y2], ...]``.

    Surrounding whitespace and empty entries (e.g. a trailing comma) are ignored.

    Raises:
        ValueError: if no values are given, a value is not an integer, or the
            number of values is odd (an incomplete x,y pair).
    """
    tokens = [tok.strip() for tok in (points_str or "").split(",") if tok.strip()]
    if not tokens:
        raise ValueError("Enter at least one point as 'x,y'.")
    try:
        values = [int(tok) for tok in tokens]
    except ValueError as err:
        raise ValueError(f"Point coordinates must be integers, got {points_str!r}.") from err
    if len(values) % 2:
        # zip() would otherwise silently drop the unpaired last coordinate.
        raise ValueError(f"Points must come in x,y pairs; got {len(values)} values.")
    return [[x, y] for x, y in zip(values[::2], values[1::2])]


def get_processed_inputs(image, points_str):
    """Segment ``image`` with SAM from prompt points and return the inverted mask.

    All points are fed to SAM as a single prompt. Of the candidate masks SAM
    returns, the one with the highest predicted IoU score is kept.

    Args:
        image: PIL image to segment.
        points_str: comma-separated integer coordinates ``"x1,y1,x2,y2,..."``,
            in the pixel space of ``image``.

    Returns:
        Boolean numpy array that is ``True`` OUTSIDE the selected segment, so the
        selected object is preserved and the rest is what gets inpainted.

    Raises:
        ValueError: if ``points_str`` is malformed (see ``_parse_points``).
    """
    input_points = [_parse_points(points_str)]
    inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # Plain int index: iou_scores may live on the GPU while masks are on the CPU.
    best_idx = int(outputs.iou_scores.argmax())
    best_mask = masks[0][0][best_idx]
    return ~best_mask.cpu().numpy()


@functools.lru_cache(maxsize=1)
def _get_inpainting_pipeline():
    """Load the SDXL inpainting pipeline once and reuse it for every request."""
    pipeline = AutoPipelineForInpainting.from_pretrained(
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    # enable_model_cpu_offload() shuttles sub-models onto an accelerator on
    # demand, so it has nothing to offload to on a CPU-only host; the pipeline
    # simply stays on `device`.
    return pipeline.to(device)


def inpaint(raw_image, input_mask, prompt, negative_prompt=None, seed=74294536, cfgs=7):
    """Regenerate the masked region of ``raw_image`` from a text prompt.

    Args:
        raw_image: PIL RGB image to edit.
        input_mask: numpy mask; ``True``/white pixels are the region the
            diffusers inpainting pipeline repaints.
        prompt: positive text prompt.
        negative_prompt: optional negative text prompt.
        seed: RNG seed, so the same inputs reproduce the same output.
        cfgs: classifier-free guidance scale.

    Returns:
        The inpainted PIL image.
    """
    mask_image = Image.fromarray(input_mask)
    # Dedicated generator: reproducible per seed without reseeding torch's global RNG.
    rand_gen = torch.Generator(device="cpu").manual_seed(seed)
    pipeline = _get_inpainting_pipeline()
    image = pipeline(
        prompt=prompt,
        image=raw_image,
        mask_image=mask_image,
        guidance_scale=cfgs,
        negative_prompt=negative_prompt,
        generator=rand_gen,
    ).images[0]
    return image


def gradio_interface(image, points, positive_prompt, negative_prompt):
    """Gradio callback: segment from points, inpaint, and return (result, mask overlay).

    Raises:
        gr.Error: if no image was uploaded or the points string is malformed.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")
    # NOTE(review): points are interpreted in the coordinate space of this
    # 512x512 resized image, not of the uploaded original.
    raw_image = Image.fromarray(image).convert("RGB").resize((512, 512))
    try:
        mask = get_processed_inputs(raw_image, points)
    except ValueError as err:
        raise gr.Error(str(err)) from err
    processed_image = inpaint(raw_image, mask, positive_prompt, negative_prompt)
    return processed_image, mask_to_rgb(mask)


iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="numpy", label="Input Image"),
        gr.Textbox(label="Points (format: x1,y1,x2,y2,...)", placeholder="e.g., 100,100,200,200"),
        gr.Textbox(label="Positive Prompt", placeholder="Enter positive prompt here"),
        gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here"),
    ],
    outputs=[
        gr.Image(label="Inpainted Image"),
        gr.Image(label="Segmentation Mask"),
    ],
    title="Interactive Image Inpainting",
    description="Enter points as 'x1,y1,x2,y2,...' for segmentation, provide prompts, and see the inpainted result.",
)

if __name__ == "__main__":
    iface.launch(share=True)