clearbg_space / app.py
Aryan Wadhawan
Add model
973996a
raw
history blame
2.46 kB
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import os
from u2net import U2NET
import data_transforms
import torch.nn.functional as F
from skimage import io
from torchvision.transforms.functional import normalize
# Load the model
model_path = "u2net.pth"
model = U2NET(3, 1)
# Map the checkpoint onto the CPU so loading works without the device it
# was saved from, then switch the network to inference mode.
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
model.eval()
# Preprocess the image
def preprocess(image):
    """Build the transformed sample dict for ``image``.

    An all-zero label array is created to accompany the image.  For 2-D
    and 3-D images both arrays end up with a trailing channel axis; other
    ranks are passed through unchanged.  The pair is then run through
    ``RescaleT(320)`` followed by ``ToTensorLab(flag=0)``.

    Args:
        image: Array exposing ``.shape`` and NumPy-style indexing.

    Returns:
        Whatever the composed transform returns for the sample dict with
        keys ``"imidx"``, ``"image"`` and ``"label"``.
    """
    rank = len(image.shape)
    blank = np.zeros(image.shape)

    # Pick a zero-filled label matching the image's spatial extent.
    if rank == 3:
        label = blank[:, :, 0]
    elif rank == 2:
        label = blank
    else:
        label = np.zeros(blank.shape[0:2])

    # Add the trailing channel axis expected by the transforms.
    if len(label.shape) == 2:
        if rank == 3:
            label = label[:, :, np.newaxis]
        elif rank == 2:
            image = image[:, :, np.newaxis]
            label = label[:, :, np.newaxis]

    pipeline = transforms.Compose(
        [data_transforms.RescaleT(320), data_transforms.ToTensorLab(flag=0)]
    )
    return pipeline({"imidx": np.array([0]), "image": image, "label": label})
# Define the function to generate the mask
def generate_mask(image):
    """Run the segmentation model on a PIL image and return its mask.

    The image is converted to RGB, resized to 1024x1024, scaled to
    [0, 1], shifted by 0.5 per channel and fed to the module-level
    ``model``.  The prediction is resized back to the input resolution,
    min-max normalised and scaled to 0-255.

    Args:
        image: PIL image (anything providing ``convert("RGB")``).

    Returns:
        A single-channel 8-bit PIL image with the same height and width
        as the input.  If the prediction is constant (no dynamic range),
        an all-zero mask is returned instead of dividing by zero.
    """
    rgb = np.array(image.convert("RGB"))
    input_size = [1024, 1024]
    original_size = rgb.shape[0:2]

    # HWC uint8 array -> 1xCxHxW float tensor.
    batch = torch.tensor(rgb, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    # Resize to the network resolution; the result is truncated back to
    # 8-bit values before being scaled to [0, 1].
    batch = F.interpolate(
        batch, size=input_size, mode="bilinear", align_corners=False
    ).type(torch.uint8)
    batch = torch.divide(batch, 255.0)
    batch = normalize(batch, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

    with torch.no_grad():
        result = model(batch)

    prediction = result[0][0]
    # NOTE(review): bilinear interpolation needs a 4-D (N, C, H, W) tensor.
    # If the model returns a flat tuple of 4-D maps, ``result[0][0]`` is
    # only 3-D, so restore the batch axis in that case -- confirm against
    # U2NET.forward.
    if prediction.dim() == 3:
        prediction = prediction.unsqueeze(0)
    prediction = F.interpolate(
        prediction, size=original_size, mode="bilinear", align_corners=False
    )
    prediction = torch.squeeze(prediction, 0)

    # Min-max normalise to [0, 1]; guard against a constant prediction,
    # which would otherwise produce NaNs and an undefined uint8 cast.
    lowest = torch.min(prediction)
    value_range = torch.max(prediction) - lowest
    if value_range > 0:
        prediction = (prediction - lowest) / value_range
    else:
        prediction = torch.zeros_like(prediction)

    # First channel only; truncating cast to 8-bit.
    output_mask = (prediction.numpy()[0] * 255).astype(np.uint8)
    return Image.fromarray(output_mask)
# Create the Gradio interface
# ``gr.Image`` is used for both input and output: the ``gr.inputs`` and
# ``gr.outputs`` namespaces are deprecated in Gradio 3 and removed in
# Gradio 4.
iface = gr.Interface(
    fn=generate_mask,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="U2NET Background Removal",
    description="Upload an image and get the background mask",
)

if __name__ == "__main__":
    # Start the web UI only when executed as a script, not on import.
    iface.launch()