from typing import TypedDict

import diffusers.image_processor
import gradio as gr
import pillow_heif  # pyright: ignore[reportMissingTypeStubs]
import spaces  # pyright: ignore[reportMissingTypeStubs]
import torch
from PIL import Image

from pipeline import TryOffAnyone

pillow_heif.register_heif_opener()  # pyright: ignore[reportUnknownMemberType]
pillow_heif.register_avif_opener()  # pyright: ignore[reportUnknownMemberType]
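
# TF32 trades a little float32 matmul precision for a large speedup on
# Ampere-and-newer GPUs; a safe trade-off for diffusion inference.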
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
TITLE = """ | |
# Try Off Anyone | |
## ⚠️ Important | |
1. Choose an example image or upload your own | |
2. Use the Pen tool to draw a mask over the clothing area you want to extract | |
[![arxiv badge](https://img.shields.io/badge/arXiv-Paper-b31b1b.svg?style=for-the-badge)](https://arxiv.org/abs/2412.08573) | |
""" | |

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
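
# Instantiate the pipeline once at import time so the model weights are
# loaded before the first request arrives.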
pipeline_tryoff = TryOffAnyone(
    device=DEVICE,
    dtype=DTYPE,
)
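
# Two processors with vae_scale_factor=8 to match the VAE's 8x spatial
# downscaling: the first binarizes the hand-drawn mask into a single
# grayscale channel, the second normalizes the RGB image for the VAE.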
mask_processor = diffusers.image_processor.VaeImageProcessor(
    vae_scale_factor=8,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)
vae_processor = diffusers.image_processor.VaeImageProcessor(
    vae_scale_factor=8,
)
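
# Shape of the dict gr.ImageMask emits with type="pil": the uploaded photo
# ("background"), the drawn mask layers, and their composite.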
class ImageData(TypedDict):
    background: Image.Image
    composite: Image.Image
    layers: list[Image.Image]

@spaces.GPU  # request a ZeroGPU worker for the duration of this call
def process(
    image_data: ImageData,
    image_width: int,
    image_height: int,
    num_inference_steps: int,
    condition_scale: float,
    seed: int,
) -> Image.Image:
    assert image_width > 0
    assert image_height > 0
    assert num_inference_steps > 0
    assert condition_scale > 0
    assert seed >= 0

    # extract image and mask from image_data
    image = image_data["background"]
    mask = image_data["layers"][0]

    # preprocess image
    image = image.convert("RGB").resize((image_width, image_height))
    image_preprocessed = vae_processor.preprocess(  # pyright: ignore[reportUnknownMemberType,reportAssignmentType]
        image=image,
        width=image_width,
        height=image_height,
    )[0]

    # preprocess mask: the drawn strokes live in the alpha channel of the layer
    mask = mask.getchannel("A").resize((image_width, image_height))
    mask_preprocessed = mask_processor.preprocess(  # pyright: ignore[reportUnknownMemberType]
        image=mask,
        width=image_width,
        height=image_height,
    )[0]

    # generate the TryOff image with a seeded generator for reproducibility
    gen = torch.Generator(device=DEVICE).manual_seed(seed)
    tryoff_image = pipeline_tryoff(
        image_preprocessed,
        mask_preprocessed,
        inference_steps=num_inference_steps,
        scale=condition_scale,
        generator=gen,
    )[0]
    return tryoff_image
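
# Build the UI: mask editor and examples on the left, result and advanced
# settings on the right.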
with gr.Blocks() as demo:
    gr.Markdown(TITLE)

    with gr.Row():
        with gr.Column():
            input_image = gr.ImageMask(
                label="Input Image",
                height=1024,  # https://github.com/gradio-app/gradio/issues/10236
                type="pil",
                interactive=True,
            )
            run_button = gr.Button(
                value="Extract Clothing",
            )
            gr.Examples(
                examples=[
                    ["examples/model_1.jpg"],
                    ["examples/model_2.jpg"],
                    ["examples/model_3.jpg"],
                    ["examples/model_4.jpg"],
                    ["examples/model_5.jpg"],
                    ["examples/model_6.jpg"],
                    ["examples/model_7.jpg"],
                    ["examples/model_8.jpg"],
                    ["examples/model_9.jpg"],
                ],
                inputs=[input_image],
            )
        with gr.Column():
            output_image = gr.Image(
                label="TryOff result",
                height=1024,
                image_mode="RGB",
                type="pil",
            )
            with gr.Accordion("Advanced Settings", open=True):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=100_000,
                    value=69_420,
                    step=1,
                )
                scale = gr.Slider(
                    label="Scale",
                    minimum=0.5,
                    maximum=5,
                    value=2.5,
                    step=0.05,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    value=25,
                    step=1,
                )
                with gr.Row():
                    image_width = gr.Slider(
                        label="Image Width",
                        minimum=64,
                        maximum=1024,
                        value=384,
                        step=8,  # keep dimensions divisible by the VAE scale factor
                    )
                    image_height = gr.Slider(
                        label="Image Height",
                        minimum=64,
                        maximum=1024,
                        value=512,
                        step=8,
                    )
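
    # Wire the button to the inference function; the input order matches
    # process()'s positional parameters.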
    run_button.click(
        fn=process,
        inputs=[
            input_image,
            image_width,
            image_height,
            num_inference_steps,
            scale,
            seed,
        ],
        outputs=output_image,
    )

demo.launch()