Spaces:
Runtime error
Runtime error
from typing import IO, List | |
import cv2 | |
import torch | |
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator | |
from PIL import Image | |
import numpy as np | |
import io | |
def to_file(item) -> IO[bytes]:
    """Serialize *item* into an in-memory binary buffer positioned at offset 0.

    PIL images are written in PNG format; numpy arrays are written with
    ``np.save``. For any other type nothing is written and an empty buffer
    is returned.
    """
    buffer = io.BytesIO()
    if isinstance(item, Image.Image):
        item.save(buffer, format='PNG')
    elif isinstance(item, np.ndarray):
        np.save(buffer, item)
    # Rewind so the caller can read the serialized bytes immediately.
    buffer.seek(0)
    return buffer
def get_sam(model_type, checkpoint_path, device=None):
    """Build the SAM model registered under *model_type* and move it to a device.

    Args:
        model_type: Key into ``sam_model_registry``.
        checkpoint_path: Checkpoint passed to the registry constructor.
        device: Target device; when None, ``cuda:0`` is used if CUDA is
            available, otherwise ``cpu``.

    Returns:
        The constructed model after ``.to(device=...)`` has been called on it.
    """
    if device is None:
        use_cuda = torch.cuda.is_available()
        device = torch.device('cuda:0' if use_cuda else 'cpu')
    build_model = sam_model_registry[model_type]
    model = build_model(checkpoint=checkpoint_path)
    model.to(device=device)
    return model
def draw_mask(img: Image.Image, boolean_mask: np.ndarray, color: tuple, mask_alpha: float) -> Image.Image:
    """Overlay *boolean_mask* on *img* in a solid *color* with opacity *mask_alpha*.

    Args:
        img: Base image. Converted to RGBA when it is in another mode, because
            ``Image.alpha_composite`` only accepts RGBA images.
        boolean_mask: Mask whose shape must match *img* (height, width);
            truthy entries are painted.
        color: RGB (or RGBA) color; components are coerced to Python ints so
            numpy scalars are accepted.
        mask_alpha: Overlay opacity, clamped to [0, 1].

    Returns:
        A new RGBA image with the colored overlay composited on top.
    """
    if img.mode != 'RGBA':
        img = img.convert('RGBA')
    # Clamp so the alpha value always fits in uint8 (truncation kept as before).
    int_alpha = int(min(max(mask_alpha, 0.0), 1.0) * 255)
    color_mask = Image.new('RGBA', img.size, color=tuple(int(c) for c in color))
    # Build the alpha from a bool view so masks with values other than 0/1
    # cannot overflow the uint8 channel.
    alpha = np.where(np.asarray(boolean_mask, dtype=bool), int_alpha, 0).astype(np.uint8)
    # A 2-D uint8 array is inferred as mode 'L' by Image.fromarray.
    color_mask.putalpha(Image.fromarray(alpha))
    return Image.alpha_composite(img, color_mask)
def random_color():
    """Return a random RGB color as a tuple of three Python ints in [0, 255]."""
    # randint's upper bound is exclusive, so 256 is needed for 255 to be
    # reachable; int() keeps numpy scalar types out of downstream PIL calls.
    return tuple(int(c) for c in np.random.randint(0, 256, 3))
def draw_masks(img: Image.Image, boolean_masks: np.ndarray, mask_alpha: float = 0.2) -> Image.Image:
    """Overlay every mask in *boolean_masks* on *img*, each in a random color.

    Args:
        img: Base image; it is not modified.
        boolean_masks: Iterable of per-object masks, each matching *img*'s size.
        mask_alpha: Opacity used for every overlay (default 0.2, the previous
            hard-coded value).

    Returns:
        A new RGBA image with all overlays composited in iteration order.
    """
    # convert() always returns a new image, so the caller's image is never
    # mutated; RGBA is required by Image.alpha_composite inside draw_mask.
    result = img.convert('RGBA')
    for boolean_mask in boolean_masks:
        result = draw_mask(result, boolean_mask, random_color(), mask_alpha)
    return result
def cutout(img: Image.Image, boolean_mask: np.ndarray):
    """Return an RGBA copy of *img* whose alpha channel comes from *boolean_mask*.

    The mask is turned into a PIL image and converted to mode "L"; for a
    boolean mask this makes True pixels opaque and False pixels transparent.
    """
    result = img.convert('RGBA')
    alpha_channel = Image.fromarray(boolean_mask).convert("L")
    result.putalpha(alpha_channel)
    return result
def predict_conditioned(sam, pil_img, **kwargs):
    """Run prompt-conditioned SAM prediction on *pil_img*.

    The image is converted with ``pil_image_to_rgb_array`` and fed to a fresh
    ``SamPredictor``; *kwargs* are forwarded unchanged to ``predictor.predict``.

    Returns:
        ``(masks, quality)``, the first two values returned by ``predict``;
        the third value (presumably low-resolution logits) is discarded.
    """
    image_rgb = pil_image_to_rgb_array(pil_img)
    predictor = SamPredictor(sam)
    predictor.set_image(image_rgb)
    masks, scores, _discarded = predictor.predict(**kwargs)
    return masks, scores
def predict_all(sam, pil_img):
    """Segment *pil_img* with SAM's automatic mask generator.

    Returns:
        ``(masks, quality)`` as numpy arrays built from each generated
        result's ``'segmentation'`` and ``'stability_score'`` entries, in
        the order the generator produced them.
    """
    image_rgb = pil_image_to_rgb_array(pil_img)
    generator = SamAutomaticMaskGenerator(sam)
    annotations = generator.generate(image_rgb)
    # Single pass over the results, pairing each mask with its score.
    pairs = [(ann['segmentation'], ann['stability_score']) for ann in annotations]
    masks = np.array([segmentation for segmentation, _ in pairs])
    quality = np.array([score for _, score in pairs])
    return masks, quality
def pil_image_to_rgb_array(image):
    """Return *image* as an RGB numpy array.

    RGBA images are flattened onto a white background, using their alpha
    channel as the paste mask. Every other mode goes through
    ``convert("RGB")``.
    """
    if image.mode != "RGBA":
        return np.array(image.convert("RGB"))
    white_background = Image.new("RGB", image.size, (255, 255, 255))
    alpha_mask = image.split()[3]
    white_background.paste(image, mask=alpha_mask)
    return np.array(white_background)
def box_pts_to_xyxy(pt1, pt2):
    """Convert a box given by two opposite corners into XYXY format.

    Args:
        pt1: (x1, y1), one corner of the box.
        pt2: (x2, y2), the corner diagonally opposite *pt1*.

    Returns:
        (x_min, y_min, x_max, y_max)
    """
    (ax, ay), (bx, by) = pt1, pt2
    left, right = min(ax, bx), max(ax, bx)
    top, bottom = min(ay, by), max(ay, by)
    return (left, top, right, bottom)
def crop_empty(image: Image.Image) -> Image.Image:
    """Crop *image* to the bounding box of its non-transparent pixels.

    The image is converted to RGBA first. Inputs without an alpha channel
    (e.g. RGB) are therefore treated as fully opaque instead of raising
    IndexError. If every pixel is fully transparent there is no box to crop
    to, and the RGBA image is returned uncropped instead of raising.

    Returns:
        A new RGBA PIL image.
    """
    rgba = np.array(image.convert("RGBA"))
    # Pixels with any opacity count as content.
    visible = rgba[:, :, 3] > 0
    rows = np.flatnonzero(visible.any(axis=1))
    cols = np.flatnonzero(visible.any(axis=0))
    if rows.size == 0:
        # Fully transparent: nothing to crop to.
        return Image.fromarray(rgba)
    # Bounds are inclusive, hence the +1 on the slice ends.
    cropped = rgba[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1, :]
    return Image.fromarray(cropped)