Merge branch 'master'
- .gitignore +2 -0
- app.py +98 -0
- app_configs.py +4 -0
- examples/cat-256.png +0 -0
- requirements.txt +5 -0
- service.py +90 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
+__pycache__/
+model/
app.py
ADDED
@@ -0,0 +1,98 @@
+import os
+import logging
+
+import app_configs as configs
+import service
+import gradio as gr
+import numpy as np
+import cv2
+from PIL import Image
+from huggingface_hub import hf_hub_download
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger()
+
+sam = None  # lazily initialised in load_sam_instance()
+red = (255, 0, 0)
+blue = (0, 0, 255)
+
+def load_sam_instance():
+    global sam
+    if sam is None:
+        gr.Info('Initialising SAM, hang in there...')
+        if not os.path.exists(configs.model_ckpt_path):
+            gr.Info('Downloading weights from the Hugging Face Hub')
+            chkpt_path = hf_hub_download("ybelkada/segment-anything", configs.model_ckpt_path)
+        else:
+            chkpt_path = configs.model_ckpt_path
+        sam = service.get_sam(configs.model_type, chkpt_path, configs.device)
+    return sam
+
+block = gr.Blocks()
+with block:
+    # states
+    def point_coords_empty():
+        return []
+    def point_labels_empty():
+        return []
+    point_coords = gr.State(point_coords_empty)
+    point_labels = gr.State(point_labels_empty)
+    raw_image = gr.Image(type='pil', visible=False)  # pristine copy, without point markers
+
+    # UI
+    with gr.Row():
+        with gr.Column():
+            input_image = gr.Image(label='Input', height=512, type='pil')
+            with gr.Row():
+                point_label_radio = gr.Radio(label='Point Label', choices=[1, 0], value=1)
+                reset_btn = gr.Button('Reset')
+                run_btn = gr.Button('Run', variant='primary')
+        with gr.Column():
+            with gr.Tab('Cutout'):
+                cutout_gallery = gr.Gallery()
+            with gr.Tab('Annotation'):
+                masks_annotated_image = gr.AnnotatedImage(label='Segments')
+    gr.Examples(examples=[['examples/cat-256.png', 'examples/cat-256.png']], inputs=[input_image, raw_image])
+
+    # components (a set, so event handlers receive a single dict keyed by component)
+    components = {point_coords, point_labels, raw_image, input_image, point_label_radio, reset_btn, run_btn, cutout_gallery, masks_annotated_image}
+
+    # event - init coords
+    def on_reset_btn_click(raw_image):
+        # restore the unmarked image and clear all clicked points
+        # (three return values, matching the three outputs registered below)
+        return raw_image, point_coords_empty(), point_labels_empty()
+    reset_btn.click(on_reset_btn_click, [raw_image], [input_image, point_coords, point_labels], queue=False)
+
+    def on_input_image_upload(input_image):
+        return input_image, point_coords_empty(), point_labels_empty()
+    input_image.upload(on_input_image_upload, [input_image], [raw_image, point_coords, point_labels], queue=False)
+
+    # event - set coords
+    def on_input_image_select(input_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
+        x, y = evt.index
+        color = red if point_label_radio == 0 else blue
+        img = np.array(input_image)
+        cv2.circle(img, (x, y), 10, color, -1)  # mark the clicked point on the displayed image
+        img = Image.fromarray(img)
+        point_coords.append([x, y])
+        point_labels.append(point_label_radio)
+        return img, point_coords, point_labels
+    input_image.select(on_input_image_select, [input_image, point_coords, point_labels, point_label_radio], [input_image, point_coords, point_labels], queue=False)
+
+    # event - inference
+    def on_run_btn_click(data):
+        sam = load_sam_instance()
+        image = data[raw_image]  # use the pristine image, not the one with point markers
+        if len(data[point_coords]) == 0:
+            # no clicks: segment everything automatically
+            masks, _ = service.predict_all(sam, image)
+        else:
+            # clicks present: condition the prediction on them
+            masks, _ = service.predict_conditioned(sam,
+                                                   image,
+                                                   point_coords=np.array(data[point_coords]),
+                                                   point_labels=np.array(data[point_labels]))
+        annotated = (image, [(masks[i], f'Mask {i}') for i in range(len(masks))])
+        cutouts = [service.cutout(image, mask) for mask in masks]
+        # two return values, matching the two outputs registered below
+        return cutouts, annotated
+    run_btn.click(on_run_btn_click, components, [cutout_gallery, masks_annotated_image], queue=True)
+
+if __name__ == '__main__':
+    block.queue()
+    block.launch()
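Note: because `components` is a Python set rather than a list, Gradio passes the event handler a single dict keyed by component, which is why `on_run_btn_click` reads `data[raw_image]` and `data[point_coords]`. A minimal, self-contained sketch of that pattern (the components here are illustrative, not part of this app):

import gradio as gr

with gr.Blocks() as demo:
    a = gr.Number(value=2)
    b = gr.Number(value=3)
    total = gr.Number(label='Total')

    def add(data):
        # with a set of inputs, the handler receives one dict keyed by component
        return data[a] + data[b]

    gr.Button('Add').click(add, {a, b}, total)

demo.launch()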
app_configs.py
ADDED
@@ -0,0 +1,4 @@
+model_type = 'vit_b'
+# model_ckpt_path = None
+model_ckpt_path = "checkpoints/sam_vit_b_01ec64.pth"
+device = 'cpu'
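`model_type` must be a key of segment-anything's `sam_model_registry` ('vit_b', 'vit_l', 'vit_h', or 'default'), and the checkpoint file must match that architecture. For example, a hypothetical switch to the larger ViT-H variant might look like this (the filename follows Meta's released checkpoints; the file would need to be downloaded or present locally):

model_type = 'vit_h'
model_ckpt_path = "checkpoints/sam_vit_h_4b8939.pth"
device = 'cuda'  # 'cpu' also works, but ViT-H inference will be slow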
examples/cat-256.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+huggingface_hub
+gradio
+torch
+torchvision
+opencv-python-headless  # app.py imports cv2, which was missing from this list
+git+https://github.com/facebookresearch/segment-anything.git
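With these installed (`pip install -r requirements.txt`), the app starts with `python app.py` via the `__main__` guard in app.py.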
service.py
ADDED
@@ -0,0 +1,90 @@
+from typing import IO
+import io
+
+import torch
+import numpy as np
+from PIL import Image
+from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
+
+def to_file(item) -> IO[bytes]:
+    """Serialise a PIL image (as PNG) or a numpy array (as .npy) into an in-memory file."""
+    file_obj = io.BytesIO()
+    if isinstance(item, Image.Image):
+        item.save(file_obj, format='PNG')
+    if isinstance(item, np.ndarray):
+        np.save(file_obj, item)
+    # Rewind so the caller can read from the beginning
+    file_obj.seek(0)
+    return file_obj
+
+def get_sam(model_type, checkpoint_path, device=None):
+    if device is None:
+        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
+    sam.to(device=device)
+    return sam
+
+def draw_mask(img: Image.Image, boolean_mask: np.ndarray, color: tuple, mask_alpha: float) -> Image.Image:
+    int_alpha = int(mask_alpha * 255)
+    color_mask = Image.new('RGBA', img.size, color=color)
+    color_mask.putalpha(Image.fromarray(boolean_mask.astype(np.uint8) * int_alpha, mode='L'))
+    # alpha_composite requires both operands to be RGBA
+    result = Image.alpha_composite(img.convert('RGBA'), color_mask)
+    return result
+
+def random_color():
+    # cast to built-in ints so PIL accepts the values as a colour tuple
+    return tuple(int(v) for v in np.random.randint(0, 256, 3))
+
+def draw_masks(img: Image.Image, boolean_masks: np.ndarray) -> Image.Image:
+    img = img.copy()
+    for boolean_mask in boolean_masks:
+        img = draw_mask(img, boolean_mask, random_color(), 0.2)
+    return img
+
+def cutout(img: Image.Image, boolean_mask: np.ndarray):
+    """Return the image with everything outside the mask made transparent."""
+    rgba_img = img.convert('RGBA')
+    mask = Image.fromarray(boolean_mask.astype(np.uint8) * 255, mode='L')
+    rgba_img.putalpha(mask)
+    return rgba_img
+
+def predict_conditioned(sam, pil_img, **kwargs):
+    """Predict masks conditioned on prompts (point_coords, point_labels, boxes, ...)."""
+    rgb_arr = pil_image_to_rgb_array(pil_img)
+    predictor = SamPredictor(sam)
+    predictor.set_image(rgb_arr)
+    masks, quality, _ = predictor.predict(**kwargs)
+    return masks, quality
+
+def predict_all(sam, pil_img):
+    """Segment everything in the image with the automatic mask generator."""
+    rgb_arr = pil_image_to_rgb_array(pil_img)
+    mask_generator = SamAutomaticMaskGenerator(sam)
+    results = mask_generator.generate(rgb_arr)
+    masks = []
+    quality = []
+    for result in results:
+        masks.append(result['segmentation'])
+        quality.append(result['stability_score'])
+    masks = np.array(masks)
+    quality = np.array(quality)
+    return masks, quality
+
+def pil_image_to_rgb_array(image):
+    if image.mode == "RGBA":
+        rgb_image = Image.new("RGB", image.size, (255, 255, 255))
+        rgb_image.paste(image, mask=image.split()[3])  # apply the alpha channel as the paste mask
+        rgb_array = np.array(rgb_image)
+    else:
+        rgb_array = np.array(image.convert("RGB"))
+    return rgb_array
+
+def box_pts_to_xyxy(pt1, pt2):
+    """Convert a box from two-corner format to XYXY.
+
+    Args:
+        pt1: (x1, y1) first corner of the box
+        pt2: (x2, y2) second corner, diagonal to pt1
+
+    Returns:
+        xyxy: (x_min, y_min, x_max, y_max)
+    """
+    x1, y1 = pt1
+    x2, y2 = pt2
+    return (min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2))
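For reference, a minimal sketch of using these helpers outside the Gradio app (the image path and click point are illustrative):

import numpy as np
from PIL import Image
import app_configs as configs
import service

# load the model, then segment one object at a clicked point
sam = service.get_sam(configs.model_type, configs.model_ckpt_path, configs.device)
img = Image.open('examples/cat-256.png')
masks, quality = service.predict_conditioned(
    sam, img,
    point_coords=np.array([[128, 128]]),  # one foreground click near the centre
    point_labels=np.array([1]),           # 1 = foreground, 0 = background
)
best = masks[np.argmax(quality)]          # keep the highest-scoring mask
service.cutout(img, best).save('cutout.png')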