ViTMatte / app.py
hysts's picture
hysts HF staff
Add files
d578b5a
raw
history blame
5.03 kB
#!/usr/bin/env python
import os
import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
from transformers import VitMatteForImageMatting, VitMatteImageProcessor
# Markdown heading shown at the top of the demo page.
DESCRIPTION = "# [ViTMatte](https://github.com/hustvl/ViTMatte)"

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Both settings can be overridden through environment variables of the same name.
MAX_IMAGE_SIZE = int(os.environ.get("MAX_IMAGE_SIZE", "1500"))
MODEL_ID = os.environ.get("MODEL_ID", "hustvl/vitmatte-small-distinctions-646")

# Load the image processor and the matting model once, at import time, and
# move the model to the selected device.
processor = VitMatteImageProcessor.from_pretrained(MODEL_ID)
model = VitMatteForImageMatting.from_pretrained(MODEL_ID)
model = model.to(device)
def check_image_size(image: PIL.Image.Image) -> None:
    """Reject images whose longer side exceeds ``MAX_IMAGE_SIZE``.

    Args:
        image: The image to validate. ``None`` is accepted and ignored so the
            function can be attached to events that may deliver an empty
            input (for example when the image component is cleared).

    Raises:
        gr.Error: If the larger of the image's width and height is greater
            than ``MAX_IMAGE_SIZE``.
    """
    # Nothing to validate for an empty input; without this guard the
    # ``image.size`` access below raises AttributeError.
    if image is None:
        return
    if max(image.size) > MAX_IMAGE_SIZE:
        raise gr.Error(f"Image size is too large. Max image size is {MAX_IMAGE_SIZE} pixels.")
def binarize_mask(mask: np.ndarray) -> np.ndarray:
    """Return a 0/1 version of *mask*, thresholded at 128.

    Values below 128 map to 0 and all other values map to 1.

    Args:
        mask: Array of intensity values.

    Returns:
        A new array with the same shape and dtype as *mask* containing only
        0 and 1. The input array is left unmodified.
    """
    # Build a fresh array instead of assigning into ``mask`` so the caller's
    # data is not changed as a side effect.
    return (mask >= 128).astype(mask.dtype)
def update_trimap(foreground_mask: dict[str, np.ndarray], unknown_mask: dict[str, np.ndarray]) -> np.ndarray:
    """Combine a foreground sketch and an unknown-region sketch into a trimap.

    Pixels marked in the unknown mask are set to 128 and pixels marked in the
    foreground mask are set to 255; where both are marked, foreground wins
    because it is written last. All remaining pixels are 0.

    Args:
        foreground_mask: Mapping whose ``"mask"`` entry marks the foreground.
        unknown_mask: Mapping whose ``"mask"`` entry marks the unknown region.

    Returns:
        An array shaped like the binarized foreground mask holding the values
        0, 128 and 255.

    Raises:
        gr.Error: If either sketch (or its ``"mask"`` entry) is missing, or
            if the two masks do not have the same shape.
    """
    binarized = []
    for name, sketch in (("foreground", foreground_mask), ("unknown", unknown_mask)):
        # A missing sketch would otherwise fail with an opaque TypeError when
        # subscripted or compared; report it to the user instead.
        mask = None if sketch is None else sketch.get("mask")
        if mask is None:
            raise gr.Error(f"Please load the image and draw the {name} mask first.")
        binarized.append(binarize_mask(mask))
    foreground, unknown = binarized

    # Boolean indexing below requires both masks to share one shape.
    if foreground.shape != unknown.shape:
        raise gr.Error("Foreground and unknown masks must have the same size.")

    trimap = np.zeros_like(foreground)
    trimap[unknown > 0] = 128
    trimap[foreground > 0] = 255
    return trimap
@spaces.GPU
@torch.inference_mode()
def run(image: PIL.Image.Image, trimap: PIL.Image.Image) -> tuple[np.ndarray, PIL.Image.Image]:
    """Predict an alpha matte for *image* and composite its foreground on white.

    Args:
        image: Input image; must be in ``"RGB"`` mode and no larger than
            ``MAX_IMAGE_SIZE`` on its longer side.
        trimap: Trimap in ``"L"`` mode with the same size as *image*.

    Returns:
        A pair ``(alpha, foreground)``. ``alpha`` is the model's predicted
        matte as a NumPy array cropped to the image's height and width.
        ``foreground`` is a PIL image of the input blended over a white
        background using ``alpha``.

    Raises:
        gr.Error: If an input is missing, the sizes differ, the image is too
            large, or either input has the wrong mode.
    """
    # Empty inputs would otherwise fail with AttributeError on ``.size``.
    if image is None or trimap is None:
        raise gr.Error("Please provide both an image and a trimap.")
    if image.size != trimap.size:
        raise gr.Error("Image and trimap must have the same size.")
    # Reuse the shared size check so the limit and message stay in one place.
    check_image_size(image)
    if image.mode != "RGB":
        raise gr.Error("Image must be RGB.")
    if trimap.mode != "L":
        raise gr.Error("Trimap must be grayscale.")

    inputs = processor(images=image, trimaps=trimap, return_tensors="pt").to(device)
    out = model(pixel_values=inputs.pixel_values)
    alpha = out.alphas[0, 0].to("cpu").numpy()

    # Crop the prediction back to the input's height and width.
    # NOTE(review): presumably needed because the processor pads its input;
    # confirm against the VitMatteImageProcessor documentation.
    w, h = image.size
    alpha = alpha[:h, :w]

    # Blend in normalized units: pixel * alpha + 1.0 * (1 - alpha), i.e. over white.
    alpha_3d = alpha[:, :, None]
    foreground = np.asarray(image).astype(float) / 255 * alpha_3d + (1 - alpha_3d)
    # Round to nearest (a bare uint8 cast truncates and can darken values by
    # one level) and clip so out-of-range values cannot wrap around.
    foreground = np.clip(np.rint(foreground * 255), 0, 255).astype(np.uint8)
    return alpha, PIL.Image.fromarray(foreground)
# Build the demo UI and wire its events.
# NOTE(review): the nesting below is reconstructed from the `with` statements;
# confirm the exact layout (e.g. which container holds the Run button).
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        # Shown only when SHOW_DUPLICATE_BUTTON is set to "1".
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            with gr.Box():
                image = gr.Image(label="Input image", type="pil", height=500)
                with gr.Tabs():
                    # Tab 1: provide a ready-made trimap.
                    with gr.Tab(label="Trimap"):
                        trimap = gr.Image(label="Trimap", type="pil", image_mode="L", height=500)
                    # Tab 2: sketch foreground / unknown regions and build a trimap from them.
                    with gr.Tab(label="Draw trimap"):
                        load_image_button = gr.Button("Load image")
                        foreground_mask = gr.Image(
                            label="Foreground",
                            tool="sketch",
                            type="numpy",
                            brush_color="green",
                            mask_opacity=0.7,
                            height=500,
                        )
                        unknown_mask = gr.Image(
                            label="Unknown",
                            tool="sketch",
                            type="numpy",
                            brush_color="green",
                            mask_opacity=0.7,
                            height=500,
                        )
                        set_trimap_button = gr.Button("Set trimap")
            run_button = gr.Button("Run")
        # Right column: outputs.
        with gr.Column():
            with gr.Box():
                out_alpha = gr.Image(label="Alpha", height=500)
                out_foreground = gr.Image(label="Foreground", height=500)

    gr.Examples(
        examples=[
            ["assets/bulb_rgb.png", "assets/bulb_trimap.png"],
            ["assets/retriever_rgb.png", "assets/retriever_trimap.png"],
        ],
        inputs=[image, trimap],
        outputs=[out_alpha, out_foreground],
        fn=run,
        # Example outputs are cached only when CACHE_EXAMPLES is set to "1".
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )

    # Validate the size of the input image whenever it changes.
    image.change(
        fn=check_image_size,
        inputs=image,
        queue=False,
        api_name=False,
    )
    # Copy the input image into both sketch canvases.
    load_image_button.click(
        fn=lambda image: (image, image),
        inputs=image,
        outputs=[foreground_mask, unknown_mask],
        queue=False,
        api_name=False,
    )
    # Turn the two sketches into a trimap and place it in the "Trimap" tab.
    set_trimap_button.click(
        fn=update_trimap,
        inputs=[foreground_mask, unknown_mask],
        outputs=trimap,
        queue=False,
        api_name=False,
    )
    # Run matting; this is the only handler registered with an API name.
    run_button.click(
        fn=run,
        inputs=[image, trimap],
        outputs=[out_alpha, out_foreground],
        api_name="run",
    )
if __name__ == "__main__":
    # Enable the request queue with at most 20 entries, then start the app.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()