Commit 7b134d4
Aryan Wadhawan committed
Parent(s): 365f352

Add model
Files changed:
- __pycache__/data_transforms.cpython-312.pyc +0 -0
- __pycache__/u2net.cpython-312.pyc +0 -0
- app.py +31 -11
__pycache__/data_transforms.cpython-312.pyc  ADDED
Binary file (17.5 kB)

__pycache__/u2net.cpython-312.pyc  ADDED
Binary file (27.6 kB)
app.py  CHANGED
@@ -37,7 +37,7 @@ def preprocess(image):
 
     return sample
 
-#
+# Generate the mask
 def generate_mask(image):
     # Preprocess the image
     image = np.array(image.convert("RGB"))
@@ -46,13 +46,15 @@
     input_size = [1024, 1024]
     im_shp = image.shape[0:2]
     im_tensor = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
-    im_tensor = F.upsample(torch.unsqueeze(im_tensor, 0), input_size, mode="bilinear").type(torch.uint8)
+
+    # Replace F.upsample with F.interpolate
+    im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), input_size, mode="bilinear").type(torch.uint8)
     image = torch.divide(im_tensor, 255.0)
     image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
 
     with torch.no_grad():
         result = model(image)
-        result = torch.squeeze(F.upsample(result[0][0], im_shp, mode='bilinear'), 0)
+        result = torch.squeeze(F.interpolate(result[0][0], im_shp, mode='bilinear'), 0)
     ma = torch.max(result)
     mi = torch.min(result)
     result = (result - mi) / (ma - mi)
@@ -61,17 +63,35 @@
     output_mask = result[0]
     output_mask = (output_mask - output_mask.min()) / (output_mask.max() - output_mask.min()) * 255
     output_mask = output_mask.astype(np.uint8)
-    output_image = Image.fromarray(output_mask)
 
-    return output_image
+    return output_mask
+
+# Define the final predict method to overlay the mask
+def predict(image):
+    # Generate the mask
+    mask = generate_mask(image)
+
+    # Convert the image to RGBA (to support transparency)
+    image = image.convert("RGBA")
+
+    # Convert the mask into a binary mask where 255 is kept and 0 is transparent
+    mask = Image.fromarray(mask).resize(image.size).convert("L")  # Convert to grayscale (L mode)
+
+    # Create a new image with transparency (RGBA)
+    transparent_image = Image.new("RGBA", image.size)
+
+    # Use the mask as transparency: paste the original image where the mask is white
+    transparent_image.paste(image, mask=mask)
+
+    return transparent_image
 
-# Create the Gradio interface
+# Create the Gradio interface with custom output size for the display only (not affecting the saved image)
 iface = gr.Interface(
-    fn=
-    inputs=gr.
-    outputs=gr.
-    title="
-    description="Upload an image and
+    fn=predict,
+    inputs=gr.Image(type="pil"),
+    outputs=gr.Image(type="pil", tool="editor", label="Edited Image"),  # Adjust the box size
+    title="Background Removal with U2NET",
+    description="Upload an image and remove the background"
 )
 
 if __name__ == "__main__":
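
For anyone skimming the diff, here is a minimal standalone sketch of the compositing step the new predict() performs, with the U2NET inference replaced by a dummy mask so it runs on its own. The overlay_mask helper, the dummy inputs, and the cutout.png filename are illustrative placeholders, not part of the Space.

import numpy as np
from PIL import Image

def overlay_mask(image: Image.Image, mask: np.ndarray) -> Image.Image:
    """Composite `image` onto a transparent canvas, using `mask` as per-pixel opacity."""
    image = image.convert("RGBA")
    # L-mode mask: 255 keeps a pixel fully opaque, 0 leaves it fully transparent
    mask_img = Image.fromarray(mask).resize(image.size).convert("L")
    transparent = Image.new("RGBA", image.size)  # starts fully transparent
    transparent.paste(image, mask=mask_img)
    return transparent

if __name__ == "__main__":
    # Dummy stand-ins for the Gradio upload and the U2NET prediction
    img = Image.new("RGB", (256, 256), (200, 80, 40))
    dummy_mask = np.zeros((256, 256), dtype=np.uint8)
    dummy_mask[64:192, 64:192] = 255  # pretend the subject occupies this square
    overlay_mask(img, dummy_mask).save("cutout.png")

Saving to PNG preserves the alpha channel; JPEG would silently drop it.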