# gvij's picture
# added repo code
# 387d3ff
import os
import random
import numpy as np
import gradio as gr
import spaces
import torch
import supervision as sv
from PIL import Image
from typing import Optional, Tuple
from diffusers import FluxInpaintPipeline
from utils.florence import load_florence_model, run_florence_inference, FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference
# Runtime environment: prefer the GPU when one is visible to torch.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Optional Hugging Face credential read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Upper bound of the seed slider: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Image dimension cap in pixels (same value as resize_image_dimensions' default).
MAX_IMAGE_SIZE = 2048
# Load models once at import time so every request reuses the same weights:
# Florence-2 for open-vocabulary detection, SAM for turning detections into
# masks, and FLUX.1-schnell for the inpainting itself.
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
# Forward HF_TOKEN (None when unset, which is also the library default) so the
# download works when authenticated Hub access is required.
FLUX_PIPE = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
).to(DEVICE)
# Set up CUDA optimizations.
# Autocast is deliberately NOT entered globally here. Autocast state is
# thread-local, so a context entered at import time (and never exited) would
# not reach the threads that serve requests. process_image applies bfloat16
# autocast itself through its @torch.autocast decorator.
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    # Compute capability 8.x+ (Ampere and newer) supports TF32 matmul/conv.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = 2048,
    multiple: int = 32,
) -> Tuple[int, int]:
    """Compute FLUX-friendly output dimensions for an image.

    If either side exceeds ``maximum_dimension``, the size is scaled down with
    its aspect ratio preserved, so the longer side becomes ``maximum_dimension``
    (before rounding). Both sides are then rounded down to a multiple of
    ``multiple``. They are never rounded below ``multiple`` itself, so very
    small or very thin images do not collapse to a zero-sized side.

    Args:
        original_resolution_wh: ``(width, height)`` of the source image in pixels.
        maximum_dimension: Largest allowed length of either side before rounding.
        multiple: Granularity that both output sides are snapped to.

    Returns:
        ``(width, height)`` of the target size.
    """
    width, height = original_resolution_wh
    longest = max(width, height)
    if longest > maximum_dimension:
        # Scale with exact integer arithmetic. A float factor such as
        # ``maximum_dimension / width`` can multiply back to 2047.999...,
        # which int() truncates; the snap below then loses a whole step.
        width = width * maximum_dimension // longest
        height = height * maximum_dimension // longest

    def _snap(side: int) -> int:
        # Round down to the grid, but keep at least one grid step.
        side = int(side)
        return max(multiple, side - side % multiple)

    return _snap(width), _snap(height)
def _segment_mask(image_input: Image.Image, segmentation_text: str) -> Optional[Image.Image]:
    """Return a 0/255 mask covering every object that matches ``segmentation_text``.

    Runs Florence-2 open-vocabulary detection, then SAM on the detections.
    Returns None when nothing is detected.
    """
    _, florence_result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image_input,
        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
        text=segmentation_text,
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=florence_result,
        resolution_wh=image_input.size,
    )
    detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
    if len(detections) == 0:
        return None
    # Union of all per-object masks, so every match is inpainted (not only the
    # first one). supervision stores masks as an (N, H, W) boolean array.
    combined = np.any(detections.mask, axis=0)
    return Image.fromarray(combined.astype("uint8") * 255)


@spaces.GPU(duration=150)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(
    image_input,
    segmentation_text,
    inpaint_text,
    seed_slicer: int,
    randomize_seed: bool,
    strength: float,
    num_inference_steps: int,
    progress=gr.Progress(track_tqdm=True)
) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
    """Segment the objects named by a prompt and inpaint them with FLUX.

    Args:
        image_input: Uploaded PIL image, or None when nothing was uploaded.
        segmentation_text: Prompt describing what to segment (and replace).
        inpaint_text: Prompt describing what FLUX should paint into the mask.
        seed_slicer: Seed for the generator; ignored when ``randomize_seed``.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` instead.
        strength: Inpainting strength in (0, 1].
        num_inference_steps: Number of FLUX denoising steps.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors tqdm bars.

    Returns:
        ``(inpainted_image, mask)``, both at the resized FLUX resolution, or
        ``(None, None)`` after showing a gr.Info message when the inputs are
        incomplete or nothing was detected.
    """
    if image_input is None:
        gr.Info("Please upload an image.")
        return None, None
    if not segmentation_text or not segmentation_text.strip():
        gr.Info("Please enter a text prompt for segmentation.")
        return None, None
    if not inpaint_text or not inpaint_text.strip():
        gr.Info("Please enter a text prompt for inpainting.")
        return None, None
    if strength <= 0:
        # strength 0 leaves no denoising steps, and diffusers rejects that.
        gr.Info("Strength must be greater than 0.")
        return None, None

    # Florence-SAM segmentation
    mask = _segment_mask(image_input, segmentation_text.strip())
    if mask is None:
        gr.Info("No objects detected.")
        return None, None

    # Resize image and mask to FLUX-compatible dimensions. NEAREST keeps the
    # mask strictly binary.
    width, height = resize_image_dimensions(
        original_resolution_wh=image_input.size,
        maximum_dimension=MAX_IMAGE_SIZE,
    )
    resized_image = image_input.resize((width, height), Image.LANCZOS)
    resized_mask = mask.resize((width, height), Image.NEAREST)

    # FLUX inpainting. UI sliders may deliver floats; manual_seed needs an int.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed_slicer)
    generator = torch.Generator().manual_seed(seed)
    result = FLUX_PIPE(
        prompt=inpaint_text.strip(),
        image=resized_image,
        mask_image=resized_mask,
        width=width,
        height=height,
        strength=strength,
        generator=generator,
        num_inference_steps=int(num_inference_steps),
    ).images[0]
    return result, resized_mask
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# MonsterAPI Prompt Guided Inpainting")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                label='Upload image',
                type='pil',
                image_mode='RGB',
            )
            segmentation_text = gr.Textbox(
                label='Segmentation text prompt',
                placeholder='Enter text for segmentation'
            )
            inpaint_text = gr.Textbox(
                label='Inpainting text prompt',
                placeholder='Enter text for inpainting'
            )
            with gr.Accordion("Advanced Settings", open=False):
                seed_slicer = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                # Minimum is one step above 0: strength 0 leaves the FLUX
                # pipeline with no denoising steps, and diffusers rejects that.
                strength = gr.Slider(
                    label="Strength",
                    minimum=0.01,
                    maximum=1,
                    step=0.01,
                    value=0.75,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=20,
                )
            submit_button = gr.Button(value='Process', variant='primary')
        with gr.Column():
            output_image = gr.Image(label='Output image')
            with gr.Accordion("Generated Mask", open=False):
                output_mask = gr.Image(label='Segmentation mask')
    # Event wiring must happen inside the Blocks context.
    submit_button.click(
        fn=process_image,
        inputs=[
            image_input,
            segmentation_text,
            inpaint_text,
            seed_slicer,
            randomize_seed,
            strength,
            num_inference_steps
        ],
        outputs=[output_image, output_mask]
    )

if __name__ == "__main__":
    # NOTE(review): server_name="0.0.0.0" plus share=True exposes this GPU
    # app on all interfaces AND through a public share link. Confirm that is
    # intended outside of hosted deployments.
    demo.launch(debug=True, show_error=True, server_name="0.0.0.0", share=True)