Peng Shiya committed on
Commit 9d85e44
2 Parent(s): c2aab2b 0dd537b

Merge branch 'master'

Files changed (6)
  1. .gitignore +2 -0
  2. app.py +98 -0
  3. app_configs.py +4 -0
  4. examples/cat-256.png +0 -0
  5. requirements.txt +6 -0
  6. service.py +90 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__/
+ model/
app.py ADDED
@@ -0,0 +1,98 @@
+ import os
+ import app_configs as configs
+ import service
+ import gradio as gr
+ import numpy as np
+ import cv2
+ from PIL import Image
+ import logging
+ from huggingface_hub import hf_hub_download
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger()
+
+ sam = None  # loaded lazily by load_sam_instance() rather than service.get_sam(configs.model_type, configs.model_ckpt_path, configs.device) at import time
+ red = (255, 0, 0)
+ blue = (0, 0, 255)
+
+ def load_sam_instance():
+     global sam
+     if sam is None:
+         gr.Info('Initialising SAM, hang in there...')
+         if not os.path.exists(configs.model_ckpt_path):
+             gr.Info('Downloading weights from the Hugging Face hub')
+             chkpt_path = hf_hub_download("ybelkada/segment-anything", configs.model_ckpt_path)
+         else:
+             chkpt_path = configs.model_ckpt_path
+         sam = service.get_sam(configs.model_type, chkpt_path, configs.device)
+     return sam
+
+ block = gr.Blocks()
+ with block:
+     # states
+     def point_coords_empty():
+         return []
+     def point_labels_empty():
+         return []
+     point_coords = gr.State(point_coords_empty)
+     point_labels = gr.State(point_labels_empty)
+     raw_image = gr.Image(type='pil', visible=False)
+
+     # UI
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(label='Input', height=512, type='pil')
+             with gr.Row():
+                 point_label_radio = gr.Radio(label='Point Label', choices=[1, 0], value=1)
+                 reset_btn = gr.Button('Reset')
+                 run_btn = gr.Button('Run', variant='primary')
+         with gr.Column():
+             with gr.Tab('Cutout'):
+                 cutout_gallery = gr.Gallery()
+             with gr.Tab('Annotation'):
+                 masks_annotated_image = gr.AnnotatedImage(label='Segments')
+     gr.Examples(examples=[['examples/cat-256.png', 'examples/cat-256.png']], inputs=[input_image, raw_image])
+
+     # components
+     components = {point_coords, point_labels, raw_image, input_image, point_label_radio, reset_btn, run_btn, cutout_gallery, masks_annotated_image}
+
+     # event - reset coords
+     def on_reset_btn_click(raw_image):
+         return raw_image, point_coords_empty(), point_labels_empty()
+     reset_btn.click(on_reset_btn_click, [raw_image], [input_image, point_coords, point_labels], queue=False)
+
+     def on_input_image_upload(input_image):
+         return input_image, point_coords_empty(), point_labels_empty()
+     input_image.upload(on_input_image_upload, [input_image], [raw_image, point_coords, point_labels], queue=False)
+
+     # event - set coords
+     def on_input_image_select(input_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
+         x, y = evt.index
+         color = red if point_label_radio == 0 else blue
+         img = np.array(input_image)
+         cv2.circle(img, (x, y), 10, color, -1)
+         img = Image.fromarray(img)
+         point_coords.append([x, y])
+         point_labels.append(point_label_radio)
+         return img, point_coords, point_labels
+     input_image.select(on_input_image_select, [input_image, point_coords, point_labels, point_label_radio], [input_image, point_coords, point_labels], queue=False)
+
+     # event - inference
+     def on_run_btn_click(data):
+         sam = load_sam_instance()
+         image = data[raw_image]
+         if len(data[point_coords]) == 0:
+             masks, _ = service.predict_all(sam, image)
+         else:
+             masks, _ = service.predict_conditioned(sam,
+                                                    image,
+                                                    point_coords=np.array(data[point_coords]),
+                                                    point_labels=np.array(data[point_labels]))
+         annotated = (image, [(masks[i], f'Mask {i}') for i in range(len(masks))])
+         cutouts = [service.cutout(image, mask) for mask in masks]
+         return cutouts, annotated
+     run_btn.click(on_run_btn_click, components, [cutout_gallery, masks_annotated_image], queue=True)
+
+ if __name__ == '__main__':
+     block.queue()
+     block.launch()
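
A note on the lazy-loading pattern in load_sam_instance: the model is only constructed on the first press of Run, so the Space boots quickly and the roughly 375 MB ViT-B checkpoint is fetched once and then cached. The same resolution logic in isolation, as a minimal sketch (resolve_checkpoint is a hypothetical name; the repo id and filename are the ones used above):

    import os
    from huggingface_hub import hf_hub_download

    def resolve_checkpoint(local_path: str, repo_id: str = "ybelkada/segment-anything") -> str:
        # Prefer a checkpoint that already exists locally.
        if os.path.exists(local_path):
            return local_path
        # Otherwise download the same relative path from the Hub;
        # hf_hub_download caches the file and returns the cached path.
        return hf_hub_download(repo_id, local_path)

    # resolve_checkpoint("checkpoints/sam_vit_b_01ec64.pth")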
app_configs.py ADDED
@@ -0,0 +1,4 @@
+ model_type = r'vit_b'
+ # model_ckpt_path = None
+ model_ckpt_path = "checkpoints/sam_vit_b_01ec64.pth"
+ device = 'cpu'
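
model_type must match the architecture the checkpoint was trained with. For reference, the pairs published with the segment-anything release are sketched below; swapping in a larger model assumes the same checkpoints/ naming is available and that the configured device has enough memory:

    # 'vit_b' -> checkpoints/sam_vit_b_01ec64.pth  (smallest; used above)
    # 'vit_l' -> checkpoints/sam_vit_l_0b3195.pth
    # 'vit_h' -> checkpoints/sam_vit_h_4b8939.pth  (largest; slow on 'cpu')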
examples/cat-256.png ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ huggingface_hub
+ gradio
+ torch
+ torchvision
+ opencv-python  # app.py imports cv2; none of the other packages provide it
+ git+https://github.com/facebookresearch/segment-anything.git
service.py ADDED
@@ -0,0 +1,90 @@
+ from typing import IO
+ import torch
+ from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
+ from PIL import Image
+ import numpy as np
+ import io
+
+ def to_file(item) -> IO[bytes]:
+     # Create a BytesIO object
+     file_obj = io.BytesIO()
+     if isinstance(item, Image.Image):
+         item.save(file_obj, format='PNG')
+     if isinstance(item, np.ndarray):
+         np.save(file_obj, item)
+     # Reset the file object's position to the beginning
+     file_obj.seek(0)
+     # Return the file object
+     return file_obj
+
+ def get_sam(model_type, checkpoint_path, device=None):
+     if device is None:
+         device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+     sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
+     sam.to(device=device)
+     return sam
+
+ def draw_mask(img: Image.Image, boolean_mask: np.ndarray, color: tuple, mask_alpha: float) -> Image.Image:
+     int_alpha = int(mask_alpha * 255)
+     color_mask = Image.new('RGBA', img.size, color=color)
+     color_mask.putalpha(Image.fromarray(boolean_mask.astype(np.uint8) * int_alpha, mode='L'))
+     result = Image.alpha_composite(img, color_mask)
+
+     return result
+ def random_color():
+     return tuple(int(c) for c in np.random.randint(0, 256, 3))  # plain ints: PIL rejects numpy integer types in some versions
+
+ def draw_masks(img: Image.Image, boolean_masks: np.ndarray) -> Image.Image:
+     img = img.convert('RGBA')  # alpha_composite in draw_mask needs an RGBA base; convert also copies
+     for boolean_mask in boolean_masks:
+         img = draw_mask(img, boolean_mask, random_color(), 0.2)
+     return img
+
+ def cutout(img: Image.Image, boolean_mask: np.ndarray):
+     rgba_img = img.convert('RGBA')
+     mask = Image.fromarray(boolean_mask.astype(np.uint8) * 255, mode='L')  # Image.fromarray cannot take a bool array directly
+     rgba_img.putalpha(mask)
+     return rgba_img
+
+
+ def predict_conditioned(sam, pil_img, **kwargs):
+     rgb_arr = pil_image_to_rgb_array(pil_img)
+     predictor = SamPredictor(sam)
+     predictor.set_image(rgb_arr)
+     masks, quality, _ = predictor.predict(**kwargs)
+     return masks, quality
+
+ def predict_all(sam, pil_img):
+     rgb_arr = pil_image_to_rgb_array(pil_img)
+     mask_generator = SamAutomaticMaskGenerator(sam)
+     results = mask_generator.generate(rgb_arr)
+     masks = []
+     quality = []
+     for result in results:
+         masks.append(result['segmentation'])
+         quality.append(result['stability_score'])
+     masks = np.array(masks)
+     quality = np.array(quality)
+     return masks, quality
+
+ def pil_image_to_rgb_array(image):
+     if image.mode == "RGBA":
+         rgb_image = Image.new("RGB", image.size, (255, 255, 255))
+         rgb_image.paste(image, mask=image.split()[3])  # Apply alpha channel as the mask
+         rgb_array = np.array(rgb_image)
+     else:
+         rgb_array = np.array(image.convert("RGB"))
+     return rgb_array
+
+ def box_pts_to_xyxy(pt1, pt2):
+     """Convert a box from two-corner (pts) format to XYXY.
+     Args:
+         pt1 : (x1, y1) first corner of a box
+         pt2 : (x2, y2) second corner, diagonal to pt1
+
+     Returns:
+         xyxy: (x_min, y_min, x_max, y_max)
+     """
+     x1, y1 = pt1
+     x2, y2 = pt2
+     return (min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2))
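
For reference, a minimal sketch of driving service.py without the Gradio UI, assuming the ViT-B weights already sit at the path from app_configs.py (the single centre click and the output filename are illustrative):

    import numpy as np
    from PIL import Image
    import service

    sam = service.get_sam('vit_b', 'checkpoints/sam_vit_b_01ec64.pth', 'cpu')
    img = Image.open('examples/cat-256.png')

    # one foreground click in the image centre (label 1 = foreground, 0 = background)
    w, h = img.size
    masks, quality = service.predict_conditioned(
        sam, img,
        point_coords=np.array([[w // 2, h // 2]]),
        point_labels=np.array([1]),
    )

    # keep the highest-scoring candidate mask and save it as a transparent cutout
    best = masks[int(np.argmax(quality))]
    service.cutout(img, best).save('cat-cutout.png')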