# Source: Hugging Face Space file by jiuface (commit 2008ad3, 8.09 kB).
# Commit message: "add lora configs as gradio parameters"
from typing import Tuple
import requests
import random
import numpy as np
import gradio as gr
import spaces
import torch
from PIL import Image
from diffusers import FluxInpaintPipeline
from huggingface_hub import login
import os
import time
# Page header rendered at the top of the Gradio UI.
MARKDOWN = """
# FLUX.1 Inpainting with lora
"""

# Upper bound of the seed slider and of randomly drawn seeds (max signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max
# Default longest side, in pixels, that input images are rescaled to.
IMAGE_SIZE = 1024
# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face access token, read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Token appended to every prompt before it is sent to the pipeline.
trigger_word = "a_photo_of_TOK"
# Example LoRA settings (these are now entered through the UI):
#   lora_path = "jiuface/boy-001"
#   lora_weights = "flux_train_replicate.safetensors"
# LoRA strength passed to the pipeline via joint_attention_kwargs.
lora_scale = 0.9

# Only authenticate when a token is configured: calling login() with
# token=None makes huggingface_hub fall back to an interactive prompt,
# which cannot be answered on a headless server.
if HF_TOKEN:
    login(token=HF_TOKEN)
class calculateDuration:
    """Context manager that prints the start, end and elapsed wall-clock time.

    Args:
        activity_name: Optional label included in the elapsed-time message.

    After the ``with`` block finishes, ``start_time``, ``end_time`` and
    ``elapsed_time`` (seconds, from ``time.time()``) plus their formatted
    counterparts are available as attributes. Exceptions raised inside the
    block are not suppressed.
    """

    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        self.start_time_formatted = time.strftime(
            self._TIME_FORMAT, time.localtime(self.start_time))
        print(f"Start time: {self.start_time_formatted}")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        self.end_time_formatted = time.strftime(
            self._TIME_FORMAT, time.localtime(self.end_time))
        # Print the END timestamp here; the start timestamp was previously
        # printed a second time under the "End time" label.
        print(f"End time: {self.end_time_formatted}")
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
    """Convert ``image`` to RGBA and make its dark pixels fully transparent.

    A pixel counts as background when the mean of its first three channels
    is below ``threshold``; such pixels are replaced with ``(0, 0, 0, 0)``.
    All other pixels are kept unchanged.

    Args:
        image: Image to process; it is converted to RGBA first.
        threshold: Brightness cut-off compared against the channel mean.

    Returns:
        The RGBA image produced by the conversion, with dark pixels cleared.
    """
    rgba = image.convert("RGBA")
    transparent = (0, 0, 0, 0)
    rgba.putdata([
        transparent if sum(pixel[:3]) / 3 < threshold else pixel
        for pixel in rgba.getdata()
    ])
    return rgba
# Build the shared inpainting pipeline once at import time: load the
# FLUX.1-schnell checkpoint in bfloat16 and move it to the selected device.
pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
).to(DEVICE)
def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = IMAGE_SIZE
) -> Tuple[int, int]:
    """Compute a target size whose longest side is about ``maximum_dimension``.

    The image is scaled (up or down) so that its longest side equals
    ``maximum_dimension``, then both sides are rounded down to a multiple
    of 32. Each side is kept at a minimum of 32 so that very elongated
    images never produce a zero-sized dimension.

    Args:
        original_resolution_wh: ``(width, height)`` of the source image.
        maximum_dimension: Desired length of the longest side in pixels.

    Returns:
        ``(new_width, new_height)``, both multiples of 32 and at least 32.

    Raises:
        ValueError: If the original width or height is not positive.
    """
    multiple = 32
    width, height = original_resolution_wh
    if width <= 0 or height <= 0:
        raise ValueError(
            f"Image dimensions must be positive, got {width}x{height}."
        )

    scaling_factor = maximum_dimension / max(width, height)

    def _snap(length: int) -> int:
        # Scale, round down to the nearest multiple, and never return 0:
        # a zero-sized side would break the later resize / pipeline call.
        scaled = int(length * scaling_factor)
        return max(multiple, scaled - (scaled % multiple))

    return _snap(width), _snap(height)
@spaces.GPU(duration=100)
def process(
    input_image_editor: dict,
    lora_path: str,
    lora_weights: str,
    input_text: str,
    seed_slicer: int,
    randomize_seed_checkbox: bool,
    strength_slider: float,
    num_inference_steps_slider: int,
    progress=gr.Progress(track_tqdm=True)
):
    """Inpaint the masked area of the uploaded image using a LoRA.

    Args:
        input_image_editor: Editor payload; its ``'background'`` entry is
            used as the image and the first ``'layers'`` entry as the mask.
        lora_path: LoRA location passed to ``pipe.load_lora_weights``.
        lora_weights: Weight file name inside ``lora_path``; when empty the
            loader is left to pick the file itself.
        input_text: Prompt; ``trigger_word`` is appended to it.
        seed_slicer: Seed used when ``randomize_seed_checkbox`` is false.
        randomize_seed_checkbox: Draw a random seed in ``[0, MAX_SEED]``.
        strength_slider: ``strength`` argument forwarded to the pipeline.
        num_inference_steps_slider: Number of denoising steps.
        progress: Gradio progress tracker (tracks tqdm output).

    Returns:
        ``(generated_image, resized_mask)``, or ``(None, None)`` after an
        info message when a required input is missing.
    """
    if not input_text:
        gr.Info("Please enter a text prompt.")
        return None, None

    # A missing payload is handled like a missing image instead of crashing
    # on the subscript below.
    if not input_image_editor:
        gr.Info("Please upload an image.")
        return None, None

    image = input_image_editor['background']
    # The layer list can be missing or empty; treat that as "no mask" rather
    # than raising IndexError before the friendly message can be shown.
    layers = input_image_editor.get('layers') or []
    mask = layers[0] if layers else None

    if not image:
        gr.Info("Please upload an image.")
        return None, None
    if not mask:
        gr.Info("Please draw a mask on the image.")
        return None, None

    lora_path = (lora_path or "").strip()
    lora_weights = (lora_weights or "").strip()
    if not lora_path:
        gr.Info("Please enter a lora path.")
        return None, None

    with calculateDuration("resize image"):
        width, height = resize_image_dimensions(original_resolution_wh=image.size)
        resized_image = image.resize((width, height), Image.LANCZOS)
        resized_mask = mask.resize((width, height), Image.LANCZOS)

    with calculateDuration("load lora"):
        # The pipeline is a shared global: drop adapters left over from
        # earlier requests so LoRAs do not accumulate across calls.
        pipe.unload_lora_weights()
        # An empty weight name is passed as None (the loader's default).
        pipe.load_lora_weights(lora_path, weight_name=lora_weights or None)

    if randomize_seed_checkbox:
        seed = random.randint(0, MAX_SEED)
    else:
        # Sliders may deliver floats; manual_seed needs an int.
        seed = int(seed_slicer)
    generator = torch.Generator().manual_seed(seed)

    with calculateDuration("run pipe"):
        result = pipe(
            prompt=f"{input_text} {trigger_word}",
            image=resized_image,
            mask_image=resized_mask,
            width=width,
            height=height,
            strength=strength_slider,
            generator=generator,
            num_inference_steps=int(num_inference_steps_slider),
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]

    return result, resized_mask
# Gradio UI: image editor, LoRA settings, prompt and sampling controls on the
# left; generated image and a debug view of the mask on the right.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            # Editor with a fixed white brush, used to paint the mask.
            input_image_editor_component = gr.ImageEditor(
                label='Image',
                type='pil',
                sources=["upload", "webcam"],
                image_mode='RGB',
                layers=False,
                brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))

            with gr.Accordion("Lora Settings", open=True):
                lora_path_component = gr.Text(
                    label="Lora path",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your lora path",
                    container=False,
                )
                lora_weights_component = gr.Text(
                    label="Lora weights",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your lora weights",
                    container=False,
                )

            with gr.Row():
                input_text_component = gr.Text(
                    label="Prompt",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                )
                submit_button_component = gr.Button(
                    value='Submit', variant='primary', scale=0)

            with gr.Accordion("Advanced Settings", open=True):
                seed_slicer_component = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )
                randomize_seed_checkbox_component = gr.Checkbox(
                    label="Randomize seed", value=True)

                with gr.Row():
                    strength_slider_component = gr.Slider(
                        label="Strength",
                        info="Indicates extent to transform the reference `image`. "
                             "Must be between 0 and 1. `image` is used as a starting "
                             "point and more noise is added the higher the `strength`.",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0.85,
                    )
                    num_inference_steps_slider_component = gr.Slider(
                        label="Number of inference steps",
                        # The help text previously stopped mid-sentence.
                        info="The number of denoising steps. More denoising steps "
                             "usually lead to a higher quality image at the "
                             "expense of slower inference.",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=20,
                    )

        with gr.Column():
            output_image_component = gr.Image(
                type='pil', image_mode='RGB', label='Generated image', format="png")
            with gr.Accordion("Debug", open=False):
                output_mask_component = gr.Image(
                    type='pil', image_mode='RGB', label='Input mask', format="png")

    # Input order must match the positional parameters of process().
    submit_button_component.click(
        fn=process,
        inputs=[
            input_image_editor_component,
            lora_path_component,
            lora_weights_component,
            input_text_component,
            seed_slicer_component,
            randomize_seed_checkbox_component,
            strength_slider_component,
            num_inference_steps_slider_component
        ],
        outputs=[
            output_image_component,
            output_mask_component
        ]
    )

demo.launch(debug=False, show_error=True)