TryOffAnyone / src /app.py
1aurent's picture
init
74a242e unverified
raw
history blame
5.16 kB
from typing import TypedDict
import diffusers.image_processor
import gradio as gr
import pillow_heif # pyright: ignore[reportMissingTypeStubs]
import spaces # pyright: ignore[reportMissingTypeStubs]
import torch
from PIL import Image
from pipeline import TryOffAnyone
# Register the HEIF opener so Pillow can open HEIF/HEIC uploads.
pillow_heif.register_heif_opener()  # pyright: ignore[reportUnknownMemberType]
# NOTE(review): `register_avif_opener` does not appear to exist in every
# pillow_heif release (AVIF support was reportedly dropped upstream), so only
# call it when the installed version provides it instead of crashing at import.
if hasattr(pillow_heif, "register_avif_opener"):
    pillow_heif.register_avif_opener()  # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue]

# Float32 matmul settings: "high" precision mode and TF32 allowed on CUDA matmuls.
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
# Markdown header of the demo; handed to `gr.Markdown` when the UI is built.
# Assembled line by line; the empty first/last entries keep the leading and
# trailing newline of the text.
TITLE = "\n".join(
    (
        "",
        "# Try Off Anyone",
        "## ⚠️ Important",
        "1. Choose an example image or upload your own",
        "2. Use the Pen tool to draw a mask over the clothing area you want to extract",
        "[![arxiv badge](https://img.shields.io/badge/arXiv-Paper-b31b1b.svg?style=for-the-badge)](https://arxiv.org/abs/2412.08573)",
        "",
    )
)
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Only ask CUDA about bfloat16 support when CUDA is actually the selected
# device: on hosts without a usable GPU the query can raise in some torch
# releases, and a CPU run should use float32 here regardless.
DTYPE = (
    torch.bfloat16
    if DEVICE.type == "cuda" and torch.cuda.is_bf16_supported()
    else torch.float32
)
# One TryOffAnyone pipeline instance, created at import time with the
# module-level DEVICE and DTYPE and reused by every `process` call.
pipeline_tryoff = TryOffAnyone(device=DEVICE, dtype=DTYPE)
# Pre-processor for the input photo: only the VAE scale factor is configured.
vae_processor = diffusers.image_processor.VaeImageProcessor(vae_scale_factor=8)

# Pre-processor for the drawn mask: same scale factor, with normalization
# switched off and grayscale conversion plus binarization switched on.
mask_processor = diffusers.image_processor.VaeImageProcessor(
    do_convert_grayscale=True,
    do_binarize=True,
    do_normalize=False,
    vae_scale_factor=8,
)
class ImageData(TypedDict):
    """Dictionary payload that `process` receives as `image_data`.

    `process` takes the photo from ``background`` and the mask from the alpha
    channel of ``layers[0]``; ``composite`` is not read by `process`.
    """

    # Drawn layers; the first one carries the mask.
    layers: list[Image.Image]
    # The photo the mask was drawn on.
    background: Image.Image
    # Declared for completeness; unused by `process`.
    composite: Image.Image
@spaces.GPU
def process(
    image_data: ImageData,
    image_width: int,
    image_height: int,
    num_inference_steps: int,
    condition_scale: float,
    seed: int,
) -> Image.Image:
    """Run the TryOff pipeline on a photo and its hand-drawn mask.

    Args:
        image_data: Editor payload; ``background`` is the photo and the alpha
            channel of ``layers[0]`` is the mask.
        image_width: Width the photo and mask are resized to; must be > 0.
        image_height: Height the photo and mask are resized to; must be > 0.
        num_inference_steps: Forwarded to the pipeline as ``inference_steps``;
            must be > 0.
        condition_scale: Forwarded to the pipeline as ``scale``; must be > 0.
        seed: Seed for the torch generator handed to the pipeline; must be >= 0.

    Returns:
        The first item of the pipeline output.

    Raises:
        gr.Error: If a numeric setting is out of range, no image was provided,
            or no mask layer is present.
    """
    # Validate with explicit raises: `assert` statements disappear under
    # `python -O`, and `gr.Error` gives the user a readable message.
    # NOTE(review): values coming from sliders are coerced with int() in case
    # they arrive as floats — PIL sizes and `manual_seed` need integers.
    image_width = int(image_width)
    image_height = int(image_height)
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)
    if image_width <= 0 or image_height <= 0:
        raise gr.Error("Image width and height must be positive.")
    if num_inference_steps <= 0:
        raise gr.Error("Number of inference steps must be positive.")
    if condition_scale <= 0:
        raise gr.Error("Scale must be positive.")
    if seed < 0:
        raise gr.Error("Seed must be non-negative.")

    # extract image and mask from image_data, failing early when either is missing
    image = image_data.get("background")
    if image is None:
        raise gr.Error("Please choose an example image or upload your own first.")
    layers = image_data.get("layers") or []
    if not layers:
        raise gr.Error("Please draw a mask over the clothing area first.")
    mask = layers[0]

    # preprocess image
    target_size = (image_width, image_height)
    image = image.convert("RGB").resize(target_size)
    image_preprocessed = vae_processor.preprocess(  # pyright: ignore[reportUnknownMemberType,reportAssignmentType]
        image=image,
        width=image_width,
        height=image_height,
    )[0]

    # preprocess mask: only the alpha channel of the drawn layer is used
    mask = mask.getchannel("A").resize(target_size)
    mask_preprocessed = mask_processor.preprocess(  # pyright: ignore[reportUnknownMemberType]
        image=mask,
        width=image_width,
        height=image_height,
    )[0]

    # generate the TryOff image with a seeded generator
    gen = torch.Generator(device=DEVICE).manual_seed(seed)
    tryoff_image = pipeline_tryoff(
        image_preprocessed,
        mask_preprocessed,
        inference_steps=num_inference_steps,
        scale=condition_scale,
        generator=gen,
    )[0]

    return tryoff_image
# Build the Gradio UI: a mask editor with examples, the result image, the
# advanced settings, and the click handler that calls `process`.
# NOTE(review): nesting below is reconstructed from statement order; confirm
# the intended layout (which widgets live in which column) before relying on it.
with gr.Blocks() as demo:
    gr.Markdown(TITLE)
    with gr.Row():
        # First column: input editor, run button and example images.
        with gr.Column():
            input_image = gr.ImageMask(
                label="Input Image",
                height=1024,  # https://github.com/gradio-app/gradio/issues/10236
                type="pil",
                interactive=True,
            )
            run_button = gr.Button(
                value="Extract Clothing",
            )
            # Clicking an example fills `input_image`.
            gr.Examples(
                examples=[
                    ["examples/model_1.jpg"],
                    ["examples/model_2.jpg"],
                    ["examples/model_3.jpg"],
                    ["examples/model_4.jpg"],
                    ["examples/model_5.jpg"],
                    ["examples/model_6.jpg"],
                    ["examples/model_7.jpg"],
                    ["examples/model_8.jpg"],
                    ["examples/model_9.jpg"],
                ],
                inputs=[input_image],
            )
        # Second column: generated result and the tunable settings.
        with gr.Column():
            output_image = gr.Image(
                label="TryOff result",
                height=1024,
                image_mode="RGB",
                type="pil",
            )
            with gr.Accordion("Advanced Settings", open=True):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=100_000,
                    value=69_420,
                    step=1,
                )
                scale = gr.Slider(
                    label="Scale",
                    minimum=0.5,
                    maximum=5,
                    value=2.5,
                    step=0.05,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    value=25,
                    step=1,
                )
                # Output resolution sliders; both move in steps of 8.
                with gr.Row():
                    image_width = gr.Slider(
                        label="Image Width",
                        minimum=64,
                        maximum=1024,
                        value=384,
                        step=8,
                    )
                    image_height = gr.Slider(
                        label="Image Height",
                        minimum=64,
                        maximum=1024,
                        value=512,
                        step=8,
                    )
    # The order of `inputs` must match the positional parameters of `process`.
    run_button.click(
        fn=process,
        inputs=[
            input_image,
            image_width,
            image_height,
            num_inference_steps,
            scale,
            seed,
        ],
        outputs=output_image,
    )
# Start the Gradio server for the app.
demo.launch()