multimodalart HF staff commited on
Commit
34b628f
1 Parent(s): 69c26b8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers.utils import load_image
4
+ from controlnet_flux import FluxControlNetModel
5
+ from transformer_flux import FluxTransformer2DModel
6
+ from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
7
+ from PIL import Image, ImageDraw
8
+
9
+ # Load models
10
+ controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16)
11
+ transformer = FluxTransformer2DModel.from_pretrained(
12
+ "black-forest-labs/FLUX.1-dev", subfolder='transformer', torch_dtype=torch.bfloat16
13
+ )
14
+ pipe = FluxControlNetInpaintingPipeline.from_pretrained(
15
+ "black-forest-labs/FLUX.1-dev",
16
+ controlnet=controlnet,
17
+ transformer=transformer,
18
+ torch_dtype=torch.bfloat16
19
+ ).to("cuda")
20
+ pipe.transformer.to(torch.bfloat16)
21
+ pipe.controlnet.to(torch.bfloat16)
22
+
23
+ def prepare_image_and_mask(image, width, height, overlap_percentage):
24
+ # Resize the input image to fit within the target size
25
+ image.thumbnail((width, height), Image.LANCZOS)
26
+
27
+ # Create a new white background image of the target size
28
+ background = Image.new('RGB', (width, height), (255, 255, 255))
29
+
30
+ # Paste the resized image onto the background
31
+ offset = ((width - image.width) // 2, (height - image.height) // 2)
32
+ background.paste(image, offset)
33
+
34
+ # Create a mask
35
+ mask = Image.new('L', (width, height), 255)
36
+ draw = ImageDraw.Draw(mask)
37
+
38
+ # Calculate the overlap area
39
+ overlap_x = int(image.width * overlap_percentage / 100)
40
+ overlap_y = int(image.height * overlap_percentage / 100)
41
+
42
+ # Draw the mask (black area is where we want to inpaint)
43
+ draw.rectangle([
44
+ (offset[0] + overlap_x, offset[1] + overlap_y),
45
+ (offset[0] + image.width - overlap_x, offset[1] + image.height - overlap_y)
46
+ ], fill=0)
47
+
48
+ return background, mask
49
+
50
+ def inpaint(image, prompt, width, height, overlap_percentage, num_inference_steps, guidance_scale):
51
+ # Prepare image and mask
52
+ image, mask = prepare_image_and_mask(image, width, height, overlap_percentage)
53
+
54
+ # Set up generator for reproducibility
55
+ generator = torch.Generator(device="cuda").manual_seed(42)
56
+
57
+ # Run inpainting
58
+ result = pipe(
59
+ prompt=prompt,
60
+ height=height,
61
+ width=width,
62
+ control_image=image,
63
+ control_mask=mask,
64
+ num_inference_steps=num_inference_steps,
65
+ generator=generator,
66
+ controlnet_conditioning_scale=0.9,
67
+ guidance_scale=guidance_scale,
68
+ negative_prompt="",
69
+ true_guidance_scale=guidance_scale
70
+ ).images[0]
71
+
72
+ return result
73
+
74
+ # Gradio interface
75
+ with gr.Blocks() as demo:
76
+ gr.Markdown("# FLUX Outpainting Demo")
77
+ with gr.Row():
78
+ with gr.Column():
79
+ input_image = gr.Image(type="pil", label="Input Image")
80
+ prompt_input = gr.Textbox(label="Prompt")
81
+ width_slider = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=768)
82
+ height_slider = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=768)
83
+ overlap_slider = gr.Slider(label="Overlap Percentage", minimum=0, maximum=50, step=1, value=10)
84
+ steps_slider = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=28)
85
+ guidance_slider = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=10.0, step=0.1, value=3.5)
86
+ run_button = gr.Button("Generate")
87
+ with gr.Column():
88
+ output_image = gr.Image(label="Output Image")
89
+
90
+ run_button.click(
91
+ fn=inpaint,
92
+ inputs=[input_image, prompt_input, width_slider, height_slider, overlap_slider, steps_slider, guidance_slider],
93
+ outputs=output_image
94
+ )
95
+
96
+ demo.launch()