# clearbg_space/app.py
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from u2net import U2NET
import data_transforms
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
# Load the pretrained U2NET model (3-channel RGB input, 1-channel mask output) on the CPU
model = U2NET(3, 1)
model_path = "u2net.pth"
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()
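# A minimal sketch (assumption, not part of the original pipeline): inference could
# be moved to a GPU when one is available, e.g.
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = model.to(device)
#   # ...tensors passed to the model would then need a matching .to(device)
#
# The code below keeps everything on the CPU, consistent with map_location="cpu" above.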
# Preprocess the image with the U2NET data transforms (RescaleT + ToTensorLab)
def preprocess(image):
    # Build a dummy label array with the same spatial size as the image,
    # since the U2NET transforms expect both an image and a label.
    label_3 = np.zeros(image.shape)
    label = np.zeros(label_3.shape[0:2])
    if 3 == len(label_3.shape):
        label = label_3[:, :, 0]
    elif 2 == len(label_3.shape):
        label = label_3
    # Make sure both image and label have a channel dimension
    if 3 == len(image.shape) and 2 == len(label.shape):
        label = label[:, :, np.newaxis]
    elif 2 == len(image.shape) and 2 == len(label.shape):
        image = image[:, :, np.newaxis]
        label = label[:, :, np.newaxis]
    transform = transforms.Compose([data_transforms.RescaleT(320), data_transforms.ToTensorLab(flag=0)])
    sample = transform({"imidx": np.array([0]), "image": image, "label": label})
    return sample
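# For reference, a hedged sketch of how this sample could be fed to U2NET directly
# (assuming data_transforms.ToTensorLab returns a (3, H, W) float tensor under the
# "image" key, as in the upstream U2NET test script; pil_image is a hypothetical input):
#
#   sample = preprocess(np.array(pil_image.convert("RGB")))
#   inputs = sample["image"].unsqueeze(0).type(torch.FloatTensor)
#   with torch.no_grad():
#       d0 = model(inputs)[0]
#
# generate_mask() below instead resizes and normalizes the raw image itself.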
# Generate the saliency mask for a PIL image
def generate_mask(image):
    # Convert the PIL image to an RGB numpy array
    image = np.array(image.convert("RGB"))
    # Note: this U2NET-style sample is built but not used by the inference path below,
    # which feeds the model a directly resized and normalized tensor instead.
    img = preprocess(image)
    input_size = [1024, 1024]
    im_shp = image.shape[0:2]
    im_tensor = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
    # Resize to the network input size (F.interpolate replaces the deprecated F.upsample)
    im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), input_size, mode="bilinear").type(torch.uint8)
    # Scale to [0, 1] and normalize
    image = torch.divide(im_tensor, 255.0)
    image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    with torch.no_grad():
        result = model(image)
    # U2NET returns a tuple of side outputs; result[0] is the fused output of
    # shape (1, 1, H, W), which F.interpolate expects for bilinear resizing.
    result = torch.squeeze(F.interpolate(result[0], im_shp, mode="bilinear"), 0)
    # Min-max normalize the mask to [0, 1]
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)
    result = result.numpy()
    output_mask = result[0]
    # Rescale to 0-255 for an 8-bit grayscale mask
    output_mask = (output_mask - output_mask.min()) / (output_mask.max() - output_mask.min()) * 255
    output_mask = output_mask.astype(np.uint8)
    return output_mask
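# Optional variation (not used below): the soft 0-255 mask can be turned into a
# hard binary matte by thresholding, e.g.
#
#   binary_mask = (output_mask > 127).astype(np.uint8) * 255
#
# which gives crisper but less feathered edges than compositing with the soft mask.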
def predict(image):
    # Generate the mask
    mask = generate_mask(image)
    # Convert the image to RGBA (to support transparency)
    image_rgba = image.convert("RGBA")
    # Resize the generated mask to the image size and convert it to grayscale
    mask = Image.fromarray(mask).resize(image.size).convert("L")
    # Create a new RGBA image and paste the original through the mask,
    # leaving the background transparent
    transparent_image = Image.new("RGBA", image.size)
    transparent_image.paste(image_rgba, mask=mask)
    # Create foreground and background overlay layers
    red_foreground = Image.new("RGBA", image.size, (255, 0, 0, 128))  # Red foreground with 50% opacity
    blue_background = Image.new("RGBA", image.size, (0, 0, 255, 128))  # Blue background with 50% opacity
    # Create an empty overlay image
    overlay_image = Image.new("RGBA", image.size)
    # Overlay the red and blue layers based on the mask
    overlay_image.paste(blue_background, (0, 0))  # Fill the entire overlay with blue
    overlay_image.paste(red_foreground, (0, 0), mask=mask)  # Paste red where the mask is white
    # Blend the original image with the overlay at 50% opacity
    combined_image = Image.blend(image_rgba, overlay_image, alpha=0.5)
    return transparent_image, combined_image
# Create the Gradio interface with two outputs
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Transparent Background", image_mode="RGBA", format="png"),  # Transparent output
        gr.Image(type="pil", label="Overlay with Colors", image_mode="RGBA", format="png"),  # Colored overlay output
    ],
    title="Background Removal with U2NET",
    description="Upload an image to remove the background and visualize it with an overlay.",
)
if __name__ == "__main__":
    iface.launch()
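# Quick local smoke test (sketch; assumes a hypothetical sample image "example.jpg"
# next to this script):
#
#   img = Image.open("example.jpg")
#   cutout, overlay = predict(img)
#   cutout.save("cutout.png")
#   overlay.save("overlay.png")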