File size: 3,996 Bytes
9aa8e63
 
 
 
 
 
973996a
9aa8e63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b134d4
9aa8e63
 
 
 
6030ffa
9aa8e63
 
 
6030ffa
7b134d4
 
9aa8e63
 
 
 
 
7b134d4
9aa8e63
 
 
 
 
 
 
 
 
7b134d4
 
 
 
 
6030ffa
7b134d4
6030ffa
 
 
 
 
 
7b134d4
6030ffa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9aa8e63
6030ffa
9aa8e63
6030ffa
 
 
 
 
 
7b134d4
6030ffa
9aa8e63
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import os
from u2net import U2NET
import data_transforms
import torch.nn.functional as F
from skimage import io
from torchvision.transforms.functional import normalize

# Load the model
# U2NET(3, 1): 3-channel RGB input, 1-channel saliency/mask output.
model = U2NET(3, 1)
model_path = "u2net.pth"
# map_location="cpu" lets the checkpoint load on machines without a GPU.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
# Inference mode: disables dropout / freezes batch-norm statistics.
model.eval()

# Preprocess the image
def preprocess(image):
    """Build a U2NET input sample from an ndarray image.

    Constructs an all-zero placeholder label matching the image, then runs
    the project's RescaleT(320) + ToTensorLab(flag=0) pipeline and returns
    the resulting sample dict.
    """
    dummy = np.zeros(image.shape)
    placeholder = np.zeros(dummy.shape[0:2])

    # Collapse the dummy label to a single 2-D channel.
    if len(dummy.shape) == 3:
        placeholder = dummy[:, :, 0]
    elif len(dummy.shape) == 2:
        placeholder = dummy

    # Ensure image and label carry an explicit channel axis.
    if len(image.shape) == 3 and len(placeholder.shape) == 2:
        placeholder = placeholder[:, :, np.newaxis]
    elif len(image.shape) == 2 and len(placeholder.shape) == 2:
        image = image[:, :, np.newaxis]
        placeholder = placeholder[:, :, np.newaxis]

    pipeline = transforms.Compose(
        [data_transforms.RescaleT(320), data_transforms.ToTensorLab(flag=0)]
    )
    return pipeline({"imidx": np.array([0]), "image": image, "label": placeholder})

# Generate the mask
def generate_mask(image):
    """Run U2NET on a PIL image and return a uint8 saliency mask.

    The image is resized to the model's 1024x1024 working resolution, the
    prediction is interpolated back to the original (H, W), min-max
    normalized, and rescaled to 0-255.

    NOTE(review): a previous `img = preprocess(image)` call whose result was
    never used has been removed — it only wasted work.
    """
    # PIL -> HxWx3 uint8 ndarray.
    image = np.array(image.convert("RGB"))

    im_shp = image.shape[0:2]  # original (H, W), used to restore output size
    input_size = [1024, 1024]

    # HWC -> CHW float tensor, resized to the network's input resolution.
    # The uint8 cast after interpolation mirrors the reference inference code.
    im_tensor = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), input_size, mode="bilinear").type(torch.uint8)

    # Scale to [0, 1] and normalize (mean 0.5, std 1.0 per channel).
    image = torch.divide(im_tensor, 255.0)
    image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

    with torch.no_grad():
        outputs = model(image)
        # BUG FIX: outputs[0] (the fused side output) has shape (1, 1, H, W).
        # The old code indexed outputs[0][0], a 3-D tensor, which
        # F.interpolate rejects for bilinear mode ("needs 4D input").
        pred = F.interpolate(outputs[0], im_shp, mode="bilinear")
        pred = torch.squeeze(pred, 0)  # -> (1, H, W)
        # Min-max normalize the prediction to [0, 1].
        ma = torch.max(pred)
        mi = torch.min(pred)
        pred = (pred - mi) / (ma - mi)
        result = pred.numpy()

    # Drop the channel axis and rescale to 0-255 uint8.
    output_mask = result[0]
    output_mask = (output_mask - output_mask.min()) / (output_mask.max() - output_mask.min()) * 255
    return output_mask.astype(np.uint8)

def predict(image):
    """Return (cutout with transparent background, red/blue overlay blend).

    The U2NET mask drives both outputs: it is the alpha channel of the
    cutout, and it decides where the red (foreground) tint replaces the
    blue (background) tint in the visualization.
    """
    # Saliency mask resized to the input size, as an 8-bit grayscale alpha.
    raw_mask = generate_mask(image)
    alpha = Image.fromarray(raw_mask).resize(image.size).convert("L")  # Convert to grayscale

    rgba_input = image.convert("RGBA")

    # Cutout: paste the image through the mask onto a transparent canvas.
    cutout = Image.new("RGBA", image.size)
    cutout.paste(rgba_input, mask=alpha)

    # Half-opacity tints: red marks foreground, blue marks background.
    fg_tint = Image.new("RGBA", image.size, (255, 0, 0, 128))
    bg_tint = Image.new("RGBA", image.size, (0, 0, 255, 128))

    # Fill the overlay with blue, then paste red wherever the mask is white.
    overlay = Image.new("RGBA", image.size)
    overlay.paste(bg_tint, (0, 0))
    overlay.paste(fg_tint, (0, 0), mask=alpha)

    # 50/50 blend of the original photo and the colored overlay.
    blended = Image.blend(rgba_input, overlay, alpha=0.5)

    return cutout, blended

# Create the Gradio interface with two outputs
# predict() returns a (transparent cutout, colored overlay) pair, so two
# PIL image outputs are declared; RGBA + PNG preserve the alpha channel.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Transparent Background", image_mode="RGBA", format="png"),  # Transparent output
        gr.Image(type="pil", label="Overlay with Colors", image_mode="RGBA", format="png"),  # Colored overlay output
    ],
    title="Background Removal with U2NET",
    description="Upload an image to remove the background and visualize it with an overlay."
)

# Launch the app only when run as a script, not when imported as a module.
if __name__ == "__main__":
    iface.launch()