Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
import os | |
import gradio as gr | |
import numpy as np | |
import PIL.Image | |
import spaces | |
import torch | |
from transformers import VitMatteForImageMatting, VitMatteImageProcessor | |
# Markdown heading rendered at the top of the demo page.
DESCRIPTION = "# [ViTMatte](https://github.com/hustvl/ViTMatte)"

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Both settings can be overridden through environment variables of the same name.
MAX_IMAGE_SIZE = int(os.environ.get("MAX_IMAGE_SIZE", "1500"))
MODEL_ID = os.environ.get("MODEL_ID", "hustvl/vitmatte-small-distinctions-646")

# The preprocessor and the matting model are loaded once, at import time.
processor = VitMatteImageProcessor.from_pretrained(MODEL_ID)
model = VitMatteForImageMatting.from_pretrained(MODEL_ID)
model = model.to(device)
def check_image_size(image: PIL.Image.Image) -> None:
    """Reject input images whose longer side exceeds ``MAX_IMAGE_SIZE``.

    Args:
        image: The current input image. ``None`` is tolerated and treated as
            "nothing to validate", because a change handler can be invoked
            with no image (for example after the input has been cleared).

    Raises:
        gr.Error: If the larger of the image's width and height is greater
            than ``MAX_IMAGE_SIZE`` pixels.
    """
    if image is None:
        # No image to check; previously this crashed on ``None.size``.
        return
    if max(image.size) > MAX_IMAGE_SIZE:
        raise gr.Error(f"Image size is too large. Max image size is {MAX_IMAGE_SIZE} pixels.")
def binarize_mask(mask: np.ndarray) -> np.ndarray:
    """Threshold a mask into a 0/1 array.

    Elements greater than or equal to 128 become 1; all other elements
    become 0.

    Args:
        mask: Array of intensities to threshold. It is NOT modified; the
            result is a new array.

    Returns:
        A new array with the same shape and dtype as ``mask`` containing
        only the values 0 and 1.
    """
    # Build a fresh array instead of assigning into ``mask`` so the caller's
    # data (e.g. the value handed over by a UI component) stays intact.
    return (mask >= 128).astype(mask.dtype)
def update_trimap(foreground_mask: dict[str, np.ndarray], unknown_mask: dict[str, np.ndarray]) -> np.ndarray:
    """Combine the foreground and unknown sketches into a single trimap.

    Args:
        foreground_mask: Sketch value whose ``"mask"`` entry marks the
            region that is definitely foreground.
        unknown_mask: Sketch value whose ``"mask"`` entry marks the region
            whose alpha is unknown.

    Returns:
        An array shaped like the foreground mask holding 255 where the
        foreground mask is set, 128 where only the unknown mask is set, and
        0 everywhere else.

    Raises:
        gr.Error: If either sketch value is missing, or if the two masks do
            not have the same shape.
    """
    if foreground_mask is None or unknown_mask is None:
        # Happens when "Set trimap" is pressed before any image was loaded
        # into the sketch components; report it instead of crashing.
        raise gr.Error("Load an image and draw the masks before setting the trimap.")

    foreground = binarize_mask(foreground_mask["mask"])
    unknown = binarize_mask(unknown_mask["mask"])
    if foreground.shape != unknown.shape:
        raise gr.Error("Foreground and unknown masks must have the same size.")

    trimap = np.zeros_like(foreground)
    trimap[unknown > 0] = 128
    # Foreground is written last so it wins where the two masks overlap.
    trimap[foreground > 0] = 255
    return trimap
def run(image: PIL.Image.Image, trimap: PIL.Image.Image) -> tuple[np.ndarray, PIL.Image.Image]:
    """Predict an alpha matte for ``image`` and composite it over white.

    Args:
        image: RGB input image.
        trimap: Grayscale (mode ``"L"``) trimap with the same size as
            ``image``.

    Returns:
        A pair ``(alpha, foreground)``: ``alpha`` is the predicted matte as a
        2-D array cropped to the image's height and width, and ``foreground``
        is the image blended onto a white background using that matte.

    Raises:
        gr.Error: If an input is missing, the sizes differ, the image is
            larger than ``MAX_IMAGE_SIZE``, or a mode is not the expected one.
    """
    if image is None or trimap is None:
        raise gr.Error("Both an image and a trimap are required.")
    if image.size != trimap.size:
        raise gr.Error("Image and trimap must have the same size.")
    if max(image.size) > MAX_IMAGE_SIZE:
        raise gr.Error(f"Image size is too large. Max image size is {MAX_IMAGE_SIZE} pixels.")
    if image.mode != "RGB":
        raise gr.Error("Image must be RGB.")
    if trimap.mode != "L":
        raise gr.Error("Trimap must be grayscale.")

    pixel_values = processor(images=image, trimaps=trimap, return_tensors="pt").to(device).pixel_values
    # Inference only: with autograd enabled the output would track gradients
    # and the ``.numpy()`` conversion below would raise a RuntimeError.
    with torch.inference_mode():
        out = model(pixel_values=pixel_values)
    alpha = out.alphas[0, 0].to("cpu").numpy()

    # Crop the matte back to the input resolution (PIL reports width first).
    # Presumably the processor pads its input - TODO confirm.
    w, h = image.size
    alpha = alpha[:h, :w]

    # Alpha-blend onto white: rgb * alpha + 1 * (1 - alpha), in the 0..1 range.
    foreground = np.array(image).astype(float) / 255 * alpha[:, :, None] + (1 - alpha[:, :, None])
    foreground = (foreground * 255).astype(np.uint8)
    foreground = PIL.Image.fromarray(foreground)
    return alpha, foreground
# Build the demo UI: inputs (image + trimap) on the left, results on the right.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from the statement order - confirm it against the intended page layout.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        # Only shown when explicitly enabled through the environment.
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )

    with gr.Row():
        with gr.Column():
            with gr.Box():
                image = gr.Image(label="Input image", type="pil", height=500)
                with gr.Tabs():
                    # Option 1: upload a ready-made trimap.
                    with gr.Tab(label="Trimap"):
                        trimap = gr.Image(label="Trimap", type="pil", image_mode="L", height=500)
                    # Option 2: sketch the foreground and unknown regions by hand.
                    with gr.Tab(label="Draw trimap"):
                        load_image_button = gr.Button("Load image")
                        foreground_mask = gr.Image(
                            label="Foreground",
                            tool="sketch",
                            type="numpy",
                            brush_color="green",
                            mask_opacity=0.7,
                            height=500,
                        )
                        unknown_mask = gr.Image(
                            label="Unknown",
                            tool="sketch",
                            type="numpy",
                            brush_color="green",
                            mask_opacity=0.7,
                            height=500,
                        )
                        set_trimap_button = gr.Button("Set trimap")
            run_button = gr.Button("Run")
        with gr.Column():
            with gr.Box():
                out_alpha = gr.Image(label="Alpha", height=500)
                out_foreground = gr.Image(label="Foreground", height=500)

    gr.Examples(
        examples=[
            ["assets/bulb_rgb.png", "assets/bulb_trimap.png"],
            ["assets/retriever_rgb.png", "assets/retriever_trimap.png"],
        ],
        inputs=[image, trimap],
        outputs=[out_alpha, out_foreground],
        fn=run,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )

    # Validate the size as soon as the input image changes.
    image.change(
        fn=check_image_size,
        inputs=image,
        queue=False,
        api_name=False,
    )
    # Copy the input image into both sketch canvases.
    load_image_button.click(
        fn=lambda image: (image, image),
        inputs=image,
        outputs=[foreground_mask, unknown_mask],
        queue=False,
        api_name=False,
    )
    # Turn the two sketches into a trimap and place it in the trimap input.
    set_trimap_button.click(
        fn=update_trimap,
        inputs=[foreground_mask, unknown_mask],
        outputs=trimap,
        queue=False,
        api_name=False,
    )
    # Run the matting model; this is the only handler exposed through the API.
    run_button.click(
        fn=run,
        inputs=[image, trimap],
        outputs=[out_alpha, out_foreground],
        api_name="run",
    )
# Script entry point: enable the request queue (at most 20 pending jobs),
# then start the server.
if __name__ == "__main__":
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()