"""Helpers for running Segment Anything (SAM) on PIL images and for
visualising / extracting the resulting boolean masks."""
from typing import IO, List
import cv2
import torch
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
from PIL import Image
import numpy as np
import io


def to_file(item) -> IO[bytes]:
    """Serialize a PIL image or numpy array into an in-memory binary file.

    PIL images are written as PNG; numpy arrays are written in ``.npy``
    format via ``np.save``. The buffer is rewound to position 0 so it can be
    read immediately.

    Args:
        item: a ``PIL.Image.Image`` or an ``np.ndarray``.

    Returns:
        A ``BytesIO`` positioned at the start of the serialized data.

    Raises:
        TypeError: if ``item`` is neither an image nor an array (previously
            such input silently produced an empty buffer).
    """
    file_obj = io.BytesIO()
    if isinstance(item, Image.Image):
        item.save(file_obj, format='PNG')
    elif isinstance(item, np.ndarray):
        np.save(file_obj, item)
    else:
        raise TypeError(
            "to_file expects a PIL.Image.Image or np.ndarray, "
            f"got {type(item).__name__}"
        )
    # Rewind so callers read from the beginning rather than from EOF.
    file_obj.seek(0)
    return file_obj


def get_sam(model_type, checkpoint_path, device=None):
    """Build a SAM model from a checkpoint and move it to ``device``.

    Args:
        model_type: a key of ``sam_model_registry`` selecting the architecture.
        checkpoint_path: path to the weights file passed to the registry builder.
        device: target device; defaults to ``cuda:0`` when CUDA is available,
            otherwise ``cpu``.

    Returns:
        The model instance, already moved to ``device``.
    """
    if device is None:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
    sam.to(device=device)
    return sam


def draw_mask(img: Image.Image, boolean_mask: np.ndarray, color: tuple,
              mask_alpha: float) -> Image.Image:
    """Overlay one boolean mask on ``img`` as a translucent solid colour.

    Args:
        img: base image; converted to RGBA if needed, because
            ``Image.alpha_composite`` only accepts RGBA operands.
        boolean_mask: 2-D mask whose shape must equal ``(img.height, img.width)``;
            truthy pixels are tinted.
        color: fill colour tuple (e.g. RGB) for the overlay.
        mask_alpha: overlay opacity, clamped to ``[0, 1]``.

    Returns:
        A new RGBA image; ``img`` itself is not modified.
    """
    if img.mode != 'RGBA':
        img = img.convert('RGBA')
    # Clamp so the alpha value always fits in a uint8 channel.
    clamped_alpha = min(max(float(mask_alpha), 0.0), 1.0)
    int_alpha = int(clamped_alpha * 255)
    # Coerce numpy scalars (e.g. from np.random) to plain ints for PIL.
    color = tuple(int(c) for c in color)
    color_mask = Image.new('RGBA', img.size, color=color)
    alpha_plane = np.where(np.asarray(boolean_mask, dtype=bool), int_alpha, 0).astype(np.uint8)
    # A 2-D uint8 array maps to mode 'L' automatically.
    color_mask.putalpha(Image.fromarray(alpha_plane))
    return Image.alpha_composite(img, color_mask)


def random_color():
    """Return a random RGB colour as a tuple of three Python ints in [0, 255]."""
    # randint's upper bound is exclusive, so 256 is needed to include 255.
    return tuple(int(c) for c in np.random.randint(0, 256, 3))


def draw_masks(img: Image.Image, boolean_masks: np.ndarray,
               mask_alpha: float = 0.2) -> Image.Image:
    """Overlay every mask in ``boolean_masks`` on ``img``, each in a random colour.

    Args:
        img: base image (any mode); it is not modified.
        boolean_masks: iterable of 2-D boolean masks sized like ``img``.
        mask_alpha: opacity used for each overlay (default 0.2).

    Returns:
        A new RGBA image, also RGBA when ``boolean_masks`` is empty.
    """
    # convert() always returns a new image, so the caller's object is untouched.
    result = img.convert('RGBA')
    for boolean_mask in boolean_masks:
        result = draw_mask(result, boolean_mask, random_color(), mask_alpha)
    return result


def cutout(img: Image.Image, boolean_mask: np.ndarray) -> Image.Image:
    """Return an RGBA copy of ``img`` in which pixels outside the mask are transparent.

    Args:
        img: source image (any mode).
        boolean_mask: 2-D mask sized like ``img``; truthy pixels stay fully
            opaque and all others get alpha 0.

    Returns:
        A new RGBA image whose alpha channel is replaced by the mask.
    """
    rgba_img = img.convert('RGBA')
    alpha = np.where(np.asarray(boolean_mask, dtype=bool), 255, 0).astype(np.uint8)
    rgba_img.putalpha(Image.fromarray(alpha))
    return rgba_img


def predict_conditioned(sam, pil_img, **kwargs):
    """Run prompted SAM prediction on a single PIL image.

    Args:
        sam: a model as returned by ``get_sam``.
        pil_img: input image; converted with ``pil_image_to_rgb_array``.
        **kwargs: forwarded unchanged to ``SamPredictor.predict``; see its
            documentation for the accepted prompt arguments.

    Returns:
        ``(masks, quality)``: the first two outputs of ``SamPredictor.predict``.
        The third output is discarded.
    """
    rgb_arr = pil_image_to_rgb_array(pil_img)
    predictor = SamPredictor(sam)
    predictor.set_image(rgb_arr)
    masks, quality, _ = predictor.predict(**kwargs)
    return masks, quality


def predict_all(sam, pil_img):
    """Segment everything in ``pil_img`` with ``SamAutomaticMaskGenerator``.

    Args:
        sam: a model as returned by ``get_sam``.
        pil_img: input image; converted with ``pil_image_to_rgb_array``.

    Returns:
        ``(masks, quality)``:
        - ``masks`` stacks each result's ``'segmentation'`` along axis 0.
        - ``quality`` holds each result's ``'stability_score'``.

        When no masks are found, ``masks`` has shape ``(0, H, W)`` instead of
        ``(0,)``, so callers can rely on its rank.
    """
    rgb_arr = pil_image_to_rgb_array(pil_img)
    results = SamAutomaticMaskGenerator(sam).generate(rgb_arr)
    if not results:
        height, width = rgb_arr.shape[:2]
        # NOTE(review): assumes the generator's default binary-mask output,
        # i.e. 'segmentation' is a bool (H, W) array — confirm if configured otherwise.
        return np.zeros((0, height, width), dtype=bool), np.zeros((0,), dtype=np.float64)
    masks = np.array([result['segmentation'] for result in results])
    quality = np.array([result['stability_score'] for result in results])
    return masks, quality


def _has_alpha(image: Image.Image) -> bool:
    """Return True if ``image`` carries transparency that should be flattened."""
    return image.mode in ("RGBA", "LA") or (
        image.mode == "P" and "transparency" in image.info
    )


def pil_image_to_rgb_array(image):
    """Convert a PIL image to an ``(H, W, 3)`` uint8 RGB numpy array.

    Transparent images (RGBA, LA, or palette images with a transparency key)
    are composited onto a white background, so transparent regions become
    white instead of whatever colour the hidden pixels happen to hold. All
    other modes are converted directly to RGB.
    """
    if not _has_alpha(image):
        return np.array(image.convert("RGB"))
    rgba = image if image.mode == "RGBA" else image.convert("RGBA")
    background = Image.new("RGB", rgba.size, (255, 255, 255))
    # Use the alpha channel as the paste mask.
    background.paste(rgba, mask=rgba.getchannel("A"))
    return np.array(background)


def box_pts_to_xyxy(pt1, pt2):
    """Convert a box given by two opposite corners to XYXY format.

    Args:
        pt1: ``(x1, y1)``, first corner of the box.
        pt2: ``(x2, y2)``, the corner diagonally opposite ``pt1``.

    Returns:
        ``(x_min, y_min, x_max, y_max)``, regardless of which corners were given.
    """
    x1, y1 = pt1
    x2, y2 = pt2
    return (min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2))


def crop_empty(image: Image.Image):
    """Crop away the fully transparent border around an image's visible content.

    Args:
        image: input image. Non-RGBA images are converted to RGBA first;
            without an alpha channel they are fully opaque and are returned
            uncropped.

    Returns:
        A new RGBA image cropped to the bounding box of pixels with alpha > 0.
        If no pixel is visible, the (RGBA) image is returned uncropped instead
        of raising.
    """
    rgba = image if image.mode == "RGBA" else image.convert("RGBA")
    visible = np.array(rgba.getchannel("A")) > 0
    row_idx = np.flatnonzero(visible.any(axis=1))
    col_idx = np.flatnonzero(visible.any(axis=0))
    if row_idx.size == 0:
        # Nothing visible: there is no meaningful bounding box to crop to.
        return rgba.copy()
    # Image.crop takes an exclusive right/bottom edge, hence the +1.
    box = (int(col_idx[0]), int(row_idx[0]), int(col_idx[-1]) + 1, int(row_idx[-1]) + 1)
    return rgba.crop(box)