Aryan Wadhawan commited on
Commit
7b134d4
1 Parent(s): 365f352
__pycache__/data_transforms.cpython-312.pyc ADDED
Binary file (17.5 kB). View file
 
__pycache__/u2net.cpython-312.pyc ADDED
Binary file (27.6 kB). View file
 
app.py CHANGED
@@ -37,7 +37,7 @@ def preprocess(image):
37
 
38
  return sample
39
 
40
- # Define the function to generate the mask
41
  def generate_mask(image):
42
  # Preprocess the image
43
  image = np.array(image.convert("RGB"))
@@ -46,13 +46,15 @@ def generate_mask(image):
46
  input_size = [1024, 1024]
47
  im_shp = image.shape[0:2]
48
  im_tensor = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
49
- im_tensor = F.upsample(torch.unsqueeze(im_tensor, 0), input_size, mode="bilinear").type(torch.uint8)
 
 
50
  image = torch.divide(im_tensor, 255.0)
51
  image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
52
 
53
  with torch.no_grad():
54
  result = model(image)
55
- result = torch.squeeze(F.upsample(result[0][0], im_shp, mode='bilinear'), 0)
56
  ma = torch.max(result)
57
  mi = torch.min(result)
58
  result = (result - mi) / (ma - mi)
@@ -61,17 +63,35 @@ def generate_mask(image):
61
  output_mask = result[0]
62
  output_mask = (output_mask - output_mask.min()) / (output_mask.max() - output_mask.min()) * 255
63
  output_mask = output_mask.astype(np.uint8)
64
- output_image = Image.fromarray(output_mask)
65
 
66
- return output_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- # Create the Gradio interface
69
  iface = gr.Interface(
70
- fn=generate_mask,
71
- inputs=gr.inputs.Image(type="pil"),
72
- outputs=gr.outputs.Image(type="pil"),
73
- title="U2NET Background Removal",
74
- description="Upload an image and get the background mask"
75
  )
76
 
77
  if __name__ == "__main__":
 
37
 
38
  return sample
39
 
40
+ # Generate the mask
41
  def generate_mask(image):
42
  # Preprocess the image
43
  image = np.array(image.convert("RGB"))
 
46
  input_size = [1024, 1024]
47
  im_shp = image.shape[0:2]
48
  im_tensor = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
49
+
50
+ # Replace F.upsample with F.interpolate
51
+ im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), input_size, mode="bilinear").type(torch.uint8)
52
  image = torch.divide(im_tensor, 255.0)
53
  image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
54
 
55
  with torch.no_grad():
56
  result = model(image)
57
+ result = torch.squeeze(F.interpolate(result[0][0], im_shp, mode='bilinear'), 0)
58
  ma = torch.max(result)
59
  mi = torch.min(result)
60
  result = (result - mi) / (ma - mi)
 
63
  output_mask = result[0]
64
  output_mask = (output_mask - output_mask.min()) / (output_mask.max() - output_mask.min()) * 255
65
  output_mask = output_mask.astype(np.uint8)
 
66
 
67
+ return output_mask
68
+
69
+ # Define the final predict method to overlay the mask
70
+ def predict(image):
71
+ # Generate the mask
72
+ mask = generate_mask(image)
73
+
74
+ # Convert the image to RGBA (to support transparency)
75
+ image = image.convert("RGBA")
76
+
77
+ # Convert the mask into a binary mask where 255 is kept and 0 is transparent
78
+ mask = Image.fromarray(mask).resize(image.size).convert("L") # Convert to grayscale (L mode)
79
+
80
+ # Create a new image with transparency (RGBA)
81
+ transparent_image = Image.new("RGBA", image.size)
82
+
83
+ # Use the mask as transparency: paste the original image where the mask is white
84
+ transparent_image.paste(image, mask=mask)
85
+
86
+ return transparent_image
87
 
88
+ # Create the Gradio interface with custom output size for the display only (not affecting the saved image)
89
  iface = gr.Interface(
90
+ fn=predict,
91
+ inputs=gr.Image(type="pil"),
92
+ outputs=gr.Image(type="pil", tool="editor", label="Edited Image"), # Adjust the box size
93
+ title="Background Removal with U2NET",
94
+ description="Upload an image and remove the background"
95
  )
96
 
97
  if __name__ == "__main__":