Spaces:
Running
on
Zero
Running
on
Zero
Update inference_i2mv_sdxl.py
Browse files- inference_i2mv_sdxl.py +98 -22
inference_i2mv_sdxl.py
CHANGED
@@ -151,28 +151,105 @@ def remove_bg(image: Image.Image, net, transform, device, mask: Image.Image = No
|
|
151 |
# return output_image
|
152 |
|
153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
|
|
|
|
|
|
|
|
|
|
155 |
|
|
|
|
|
|
|
|
|
156 |
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
H, W = alpha.shape
|
170 |
-
#
|
171 |
y, x = np.where(alpha)
|
172 |
y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
|
173 |
x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
|
174 |
-
image_center =
|
175 |
-
|
|
|
176 |
H, W, _ = image_center.shape
|
177 |
if H > W:
|
178 |
W = int(W * (height * 0.9) / H)
|
@@ -180,18 +257,17 @@ def preprocess_image(image: Image.Image, height, width):
|
|
180 |
else:
|
181 |
H = int(H * (width * 0.9) / W)
|
182 |
W = int(width * 0.9)
|
|
|
183 |
image_center = np.array(Image.fromarray(image_center).resize((W, H)))
|
184 |
-
|
|
|
185 |
start_h = (height - H) // 2
|
186 |
start_w = (width - W) // 2
|
187 |
-
|
188 |
-
|
189 |
-
image = image.astype(np.float32) / 255.0
|
190 |
-
image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
|
191 |
-
image = (image * 255).clip(0, 255).astype(np.uint8)
|
192 |
-
image = Image.fromarray(image)
|
193 |
|
194 |
-
|
|
|
195 |
|
196 |
|
197 |
def run_pipeline(
|
|
|
151 |
# return output_image
|
152 |
|
153 |
|
154 |
+
def remove_bg(image: Image.Image, net, transform, device, mask: np.ndarray = None):
|
155 |
+
"""
|
156 |
+
Applies a pre-existing mask to an image to make the background transparent.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
image (PIL.Image.Image): The input image.
|
160 |
+
net: Pre-trained neural network (not used but kept for compatibility).
|
161 |
+
transform: Image transformation object (not used but kept for compatibility).
|
162 |
+
device: Device used for inference (not used but kept for compatibility).
|
163 |
+
mask (np.ndarray, optional): The mask to use. Should be the same size
|
164 |
+
as the input image, with values between 0 and 255.
|
165 |
+
If None, will return image with no changes.
|
166 |
|
167 |
+
Returns:
|
168 |
+
PIL.Image.Image: The modified image with transparent background.
|
169 |
+
"""
|
170 |
+
if mask is None:
|
171 |
+
return image
|
172 |
|
173 |
+
# Ensure the mask is in the correct format
|
174 |
+
if mask.ndim == 2: # If mask is 2D (H, W)
|
175 |
+
mask = mask.astype(np.uint8) # Ensure mask is uint8
|
176 |
+
mask = np.expand_dims(mask, axis=-1) # Add channel dimension
|
177 |
|
178 |
+
# Convert the mask to PIL Image
|
179 |
+
mask_pil = Image.fromarray(mask.squeeze(2) * 255) # Convert to binary mask
|
180 |
+
|
181 |
+
# Resize the mask to match the original image size
|
182 |
+
mask_pil = mask_pil.resize(image.size, Image.LANCZOS)
|
183 |
+
|
184 |
+
# Create a new image with the same size and mode as the original
|
185 |
+
output_image = Image.new("RGBA", image.size)
|
186 |
+
|
187 |
+
# Apply the mask to the original image
|
188 |
+
image.putalpha(mask_pil)
|
189 |
+
|
190 |
+
# Composite the original image with the mask
|
191 |
+
output_image.paste(image, (0, 0), image)
|
192 |
+
|
193 |
+
return output_image
|
194 |
+
|
195 |
+
|
196 |
+
# def preprocess_image(image: Image.Image, height, width):
|
197 |
|
198 |
+
# alpha = image[..., 3] > 0
|
199 |
+
# # alpha = image
|
200 |
+
|
201 |
+
# #if image.mode in ("RGBA", "LA"):
|
202 |
+
# # image = np.array(image)
|
203 |
+
# # alpha = image[..., 3] # Extract the alpha channel
|
204 |
+
# #elif image.mode in ("RGB"):
|
205 |
+
# # image = np.array(image)
|
206 |
+
# # Create default alpha for non-alpha images
|
207 |
+
# # alpha = np.ones(image[..., 0].shape, dtype=np.uint8) * 255 # Create
|
208 |
+
# H, W = alpha.shape
|
209 |
+
# # get the bounding box of alpha
|
210 |
+
# y, x = np.where(alpha)
|
211 |
+
# y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
|
212 |
+
# x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
|
213 |
+
# image_center = image[y0:y1, x0:x1]
|
214 |
+
# # resize the longer side to H * 0.9
|
215 |
+
# H, W, _ = image_center.shape
|
216 |
+
# if H > W:
|
217 |
+
# W = int(W * (height * 0.9) / H)
|
218 |
+
# H = int(height * 0.9)
|
219 |
+
# else:
|
220 |
+
# H = int(H * (width * 0.9) / W)
|
221 |
+
# W = int(width * 0.9)
|
222 |
+
# image_center = np.array(Image.fromarray(image_center).resize((W, H)))
|
223 |
+
# # pad to H, W
|
224 |
+
# start_h = (height - H) // 2
|
225 |
+
# start_w = (width - W) // 2
|
226 |
+
# image = np.zeros((height, width, 4), dtype=np.uint8)
|
227 |
+
# image[start_h : start_h + H, start_w : start_w + W] = image_center
|
228 |
+
# image = image.astype(np.float32) / 255.0
|
229 |
+
# image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
|
230 |
+
# image = (image * 255).clip(0, 255).astype(np.uint8)
|
231 |
+
# image = Image.fromarray(image)
|
232 |
+
|
233 |
+
# return image
|
234 |
+
|
235 |
+
def preprocess_image(image: Image.Image, height, width):
|
236 |
+
# Convert image to numpy array
|
237 |
+
image_np = np.array(image)
|
238 |
+
|
239 |
+
# Extract the alpha channel if present
|
240 |
+
if image_np.shape[-1] == 4:
|
241 |
+
alpha = image_np[..., 3] > 0 # Create a binary mask from the alpha channel
|
242 |
+
else:
|
243 |
+
alpha = np.ones(image_np[..., 0].shape, dtype=bool) # Default to all true for RGB images
|
244 |
+
|
245 |
H, W = alpha.shape
|
246 |
+
# Get the bounding box of the alpha
|
247 |
y, x = np.where(alpha)
|
248 |
y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
|
249 |
x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
|
250 |
+
image_center = image_np[y0:y1, x0:x1]
|
251 |
+
|
252 |
+
# Resize the longer side to H * 0.9
|
253 |
H, W, _ = image_center.shape
|
254 |
if H > W:
|
255 |
W = int(W * (height * 0.9) / H)
|
|
|
257 |
else:
|
258 |
H = int(H * (width * 0.9) / W)
|
259 |
W = int(width * 0.9)
|
260 |
+
|
261 |
image_center = np.array(Image.fromarray(image_center).resize((W, H)))
|
262 |
+
|
263 |
+
# Pad to H, W
|
264 |
start_h = (height - H) // 2
|
265 |
start_w = (width - W) // 2
|
266 |
+
padded_image = np.zeros((height, width, 4), dtype=np.uint8)
|
267 |
+
padded_image[start_h:start_h + H, start_w:start_w + W] = image_center
|
|
|
|
|
|
|
|
|
268 |
|
269 |
+
# Convert back to PIL Image
|
270 |
+
return Image.fromarray(padded_image)
|
271 |
|
272 |
|
273 |
def run_pipeline(
|