Xenova HF staff commited on
Commit
099bf4d
1 Parent(s): 1e0706b

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +2 -2
  2. index.css +119 -119
  3. index.html +65 -42
  4. index.js +234 -263
README.md CHANGED
@@ -6,9 +6,9 @@ colorTo: yellow
6
  sdk: static
7
  pinned: false
8
  models:
9
- - Xenova/slimsam-77-uniform
10
  license: apache-2.0
11
  short_description: In-browser image segmentation w/ 🤗 Transformers.js
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
6
  sdk: static
7
  pinned: false
8
  models:
9
+ - Xenova/slimsam-77-uniform
10
  license: apache-2.0
11
  short_description: In-browser image segmentation w/ 🤗 Transformers.js
12
  ---
13
 
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
index.css CHANGED
@@ -1,119 +1,119 @@
1
- * {
2
- box-sizing: border-box;
3
- padding: 0;
4
- margin: 0;
5
- font-family: sans-serif;
6
- }
7
-
8
- html,
9
- body {
10
- height: 100%;
11
- }
12
-
13
- body {
14
- padding: 16px 32px;
15
- }
16
-
17
- body,
18
- #container,
19
- #upload-button {
20
- display: flex;
21
- flex-direction: column;
22
- justify-content: center;
23
- align-items: center;
24
- }
25
-
26
- h1,
27
- h3 {
28
- text-align: center;
29
- }
30
-
31
- #container {
32
- position: relative;
33
- width: 640px;
34
- height: 420px;
35
- max-width: 100%;
36
- max-height: 100%;
37
- border: 2px dashed #D1D5DB;
38
- border-radius: 0.75rem;
39
- overflow: hidden;
40
- cursor: pointer;
41
- margin-top: 1rem;
42
- background-size: 100% 100%;
43
- background-position: center;
44
- background-repeat: no-repeat;
45
- }
46
-
47
- #mask-output {
48
- position: absolute;
49
- width: 100%;
50
- height: 100%;
51
- pointer-events: none;
52
- }
53
-
54
- #upload-button {
55
- gap: 0.4rem;
56
- font-size: 18px;
57
- cursor: pointer;
58
- opacity: 0.2;
59
- }
60
-
61
- #upload {
62
- display: none;
63
- }
64
-
65
- svg {
66
- pointer-events: none;
67
- }
68
-
69
- #example {
70
- font-size: 14px;
71
- text-decoration: underline;
72
- cursor: pointer;
73
- pointer-events: none;
74
- }
75
-
76
- #example:hover {
77
- color: #2563EB;
78
- }
79
-
80
- canvas {
81
- position: absolute;
82
- width: 100%;
83
- height: 100%;
84
- opacity: 0.6;
85
- }
86
-
87
- #status {
88
- min-height: 16px;
89
- margin: 8px 0;
90
- }
91
-
92
- .icon {
93
- height: 16px;
94
- width: 16px;
95
- position: absolute;
96
- transform: translate(-50%, -50%);
97
- }
98
-
99
- #controls>button {
100
- padding: 6px 12px;
101
- background-color: #3498db;
102
- color: white;
103
- border: 1px solid #2980b9;
104
- border-radius: 5px;
105
- cursor: pointer;
106
- font-size: 16px;
107
- }
108
-
109
- #controls>button:disabled {
110
- background-color: #d1d5db;
111
- color: #6b7280;
112
- border: 1px solid #9ca3af;
113
- cursor: not-allowed;
114
- }
115
-
116
- #information {
117
- margin-top: 0.25rem;
118
- font-size: 15px;
119
- }
 
1
+ * {
2
+ box-sizing: border-box;
3
+ padding: 0;
4
+ margin: 0;
5
+ font-family: sans-serif;
6
+ }
7
+
8
+ html,
9
+ body {
10
+ height: 100%;
11
+ }
12
+
13
+ body {
14
+ padding: 16px 32px;
15
+ }
16
+
17
+ body,
18
+ #container,
19
+ #upload-button {
20
+ display: flex;
21
+ flex-direction: column;
22
+ justify-content: center;
23
+ align-items: center;
24
+ }
25
+
26
+ h1,
27
+ h3 {
28
+ text-align: center;
29
+ }
30
+
31
+ #container {
32
+ position: relative;
33
+ width: 640px;
34
+ height: 420px;
35
+ max-width: 100%;
36
+ max-height: 100%;
37
+ border: 2px dashed #d1d5db;
38
+ border-radius: 0.75rem;
39
+ overflow: hidden;
40
+ cursor: pointer;
41
+ margin-top: 1rem;
42
+ background-size: 100% 100%;
43
+ background-position: center;
44
+ background-repeat: no-repeat;
45
+ }
46
+
47
+ #mask-output {
48
+ position: absolute;
49
+ width: 100%;
50
+ height: 100%;
51
+ pointer-events: none;
52
+ }
53
+
54
+ #upload-button {
55
+ gap: 0.4rem;
56
+ font-size: 18px;
57
+ cursor: pointer;
58
+ opacity: 0.2;
59
+ }
60
+
61
+ #upload {
62
+ display: none;
63
+ }
64
+
65
+ svg {
66
+ pointer-events: none;
67
+ }
68
+
69
+ #example {
70
+ font-size: 14px;
71
+ text-decoration: underline;
72
+ cursor: pointer;
73
+ pointer-events: none;
74
+ }
75
+
76
+ #example:hover {
77
+ color: #2563eb;
78
+ }
79
+
80
+ canvas {
81
+ position: absolute;
82
+ width: 100%;
83
+ height: 100%;
84
+ opacity: 0.6;
85
+ }
86
+
87
+ #status {
88
+ min-height: 16px;
89
+ margin: 8px 0;
90
+ }
91
+
92
+ .icon {
93
+ height: 16px;
94
+ width: 16px;
95
+ position: absolute;
96
+ transform: translate(-50%, -50%);
97
+ }
98
+
99
+ #controls > button {
100
+ padding: 6px 12px;
101
+ background-color: #3498db;
102
+ color: white;
103
+ border: 1px solid #2980b9;
104
+ border-radius: 5px;
105
+ cursor: pointer;
106
+ font-size: 16px;
107
+ }
108
+
109
+ #controls > button:disabled {
110
+ background-color: #d1d5db;
111
+ color: #6b7280;
112
+ border: 1px solid #9ca3af;
113
+ cursor: not-allowed;
114
+ }
115
+
116
+ #information {
117
+ margin-top: 0.25rem;
118
+ font-size: 15px;
119
+ }
index.html CHANGED
@@ -1,42 +1,65 @@
1
- <!DOCTYPE html>
2
- <html lang="en">
3
-
4
- <head>
5
- <meta charset="UTF-8" />
6
- <link rel="stylesheet" href="index.css" />
7
-
8
- <meta name="viewport" content="width=device-width, initial-scale=1.0" />
9
- <title>Transformers.js - Segment Anything WebGPU</title>
10
- </head>
11
-
12
- <body>
13
- <h1>Segment Anything WebGPU</h1>
14
- <h3>In-browser image segmentation w/ <a href="https://hf.co/docs/transformers.js" target="_blank">🤗
15
- Transformers.js</a></h3>
16
- <div id="container">
17
- <label id="upload-button" for="upload">
18
- <svg width="25" height="25" viewBox="0 0 25 25" fill="none" xmlns="http://www.w3.org/2000/svg">
19
- <path fill="#000"
20
- d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z">
21
- </path>
22
- </svg>
23
- Click to upload image
24
- <label id="example">(or try example)</label>
25
- </label>
26
- <canvas id="mask-output"></canvas>
27
- </div>
28
- <label id="status"></label>
29
- <div id="controls">
30
- <button id="reset-image">Reset image</button>
31
- <button id="clear-points">Clear points</button>
32
- <button id="cut-mask" disabled>Cut mask</button>
33
- </div>
34
- <p id="information">
35
- Left click = positive points, right click = negative points.
36
- </p>
37
- <input id="upload" type="file" accept="image/*" disabled />
38
-
39
- <script src="index.js" type="module"></script>
40
- </body>
41
-
42
- </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <link rel="stylesheet" href="index.css" />
6
+
7
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
8
+ <title>Segment Anything WebGPU | Transformers.js</title>
9
+ </head>
10
+
11
+ <body>
12
+ <h1>Segment Anything WebGPU</h1>
13
+ <h3>
14
+ In-browser image segmentation w/
15
+ <a href="https://hf.co/docs/transformers.js" target="_blank"
16
+ >🤗 Transformers.js</a
17
+ >
18
+ </h3>
19
+ <div id="container">
20
+ <label id="upload-button" for="upload">
21
+ <svg
22
+ width="25"
23
+ height="25"
24
+ viewBox="0 0 25 25"
25
+ fill="none"
26
+ xmlns="http://www.w3.org/2000/svg"
27
+ >
28
+ <path
29
+ fill="#000"
30
+ d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
31
+ ></path>
32
+ </svg>
33
+ Click to upload image
34
+ <label id="example">(or try example)</label>
35
+ </label>
36
+ <image id="image"></image>
37
+ <canvas id="mask-output"></canvas>
38
+ </div>
39
+ <label id="status"></label>
40
+ <div id="controls">
41
+ <button id="reset-image">Reset image</button>
42
+ <button id="clear-points">Clear points</button>
43
+ <button id="cut-mask" disabled>Cut mask</button>
44
+ </div>
45
+ <p id="information">
46
+ Left click = positive points, right click = negative points.
47
+ </p>
48
+ <input id="upload" type="file" accept="image/*" disabled />
49
+
50
+ <div style="display: none">
51
+ <!-- Preload star and cross images to avoid lag on first click -->
52
+ <img
53
+ id="star-icon"
54
+ class="icon"
55
+ src="https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/star-icon.png"
56
+ />
57
+ <img
58
+ id="cross-icon"
59
+ class="icon"
60
+ src="https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cross-icon.png"
61
+ />
62
+ </div>
63
+ <script src="index.js" type="module"></script>
64
+ </body>
65
+ </html>
index.js CHANGED
@@ -1,325 +1,296 @@
1
- import { SamModel, AutoProcessor, RawImage, Tensor } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0-alpha.4';
 
 
 
 
 
2
 
3
  // Reference the elements we will use
4
- const statusLabel = document.getElementById('status');
5
- const fileUpload = document.getElementById('upload');
6
- const imageContainer = document.getElementById('container');
7
- const example = document.getElementById('example');
8
- const maskCanvas = document.getElementById('mask-output');
9
- const uploadButton = document.getElementById('upload-button');
10
- const resetButton = document.getElementById('reset-image');
11
- const clearButton = document.getElementById('clear-points');
12
- const cutButton = document.getElementById('cut-mask');
13
-
14
- // Constants
15
- const BASE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/';
16
- const EXAMPLE_URL = BASE_URL + 'corgi.jpg';
17
-
18
- // Preload star and cross images to avoid lag on first click
19
- const star = new Image();
20
- star.src = BASE_URL + 'star-icon.png';
21
- star.className = 'icon';
22
-
23
- const cross = new Image();
24
- cross.src = BASE_URL + 'cross-icon.png';
25
- cross.className = 'icon';
26
 
27
  // State variables
28
- let lastPoints = null;
29
  let isDecoding = false;
 
 
30
  let isMultiMaskMode = false;
31
- let imageDataURI = null;
32
- let imageInputs = null;
33
  let imageEmbeddings = null;
34
 
35
  async function decode() {
36
- if (!imageInputs || !imageEmbeddings) {
37
- return;
38
- }
39
- isDecoding = true;
40
-
41
- // Prepare inputs for decoding
42
- const reshaped = imageInputs.reshaped_input_sizes[0];
43
- const points = lastPoints.map(x => [x.point[0] * reshaped[1], x.point[1] * reshaped[0]])
44
- const labels = lastPoints.map(x => BigInt(x.label));
45
-
46
- const input_points = new Tensor(
47
- 'float32',
48
- points.flat(Infinity),
49
- [1, 1, points.length, 2],
50
- )
51
- const input_labels = new Tensor(
52
- 'int64',
53
- labels.flat(Infinity),
54
- [1, 1, labels.length],
55
- )
56
-
57
- // Generate the mask
58
- const { pred_masks, iou_scores } = await model({
59
- ...imageEmbeddings,
60
- input_points,
61
- input_labels,
62
- })
63
-
64
- // Post-process the mask
65
- const masks = await processor.post_process_masks(
66
- pred_masks,
67
- imageInputs.original_sizes,
68
- imageInputs.reshaped_input_sizes,
69
- );
70
-
71
- const data = {
72
- mask: RawImage.fromTensor(masks[0][0]),
73
- scores: iou_scores.data,
74
- };
75
- isDecoding = false;
76
-
77
- if (!isMultiMaskMode && lastPoints) {
78
- // Perform decoding with the last point
79
- decode();
80
- lastPoints = null;
81
- }
82
-
83
- const { mask, scores } = data;
84
-
85
- // Update canvas dimensions (if different)
86
- if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
87
- maskCanvas.width = mask.width;
88
- maskCanvas.height = mask.height;
89
- }
90
 
91
- // Create context and allocate buffer for pixel data
92
- const context = maskCanvas.getContext('2d');
93
- const imageData = context.createImageData(maskCanvas.width, maskCanvas.height);
94
-
95
- // Select best mask
96
- const numMasks = scores.length; // 3
97
- let bestIndex = 0;
98
- for (let i = 1; i < numMasks; ++i) {
99
- if (scores[i] > scores[bestIndex]) {
100
- bestIndex = i;
101
- }
 
 
 
 
 
 
 
 
102
  }
103
- statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;
104
-
105
- // Fill mask with colour
106
- const pixelData = imageData.data;
107
- for (let i = 0; i < pixelData.length; ++i) {
108
- if (mask.data[numMasks * i + bestIndex] === 1) {
109
- const offset = 4 * i;
110
- pixelData[offset] = 0; // red
111
- pixelData[offset + 1] = 114; // green
112
- pixelData[offset + 2] = 189; // blue
113
- pixelData[offset + 3] = 255; // alpha
114
- }
115
  }
 
116
 
117
- // Draw image data to context
118
- context.putImageData(imageData, 0, 0);
119
  }
120
 
121
  function clearPointsAndMask() {
122
- // Reset state
123
- isMultiMaskMode = false;
124
- lastPoints = null;
125
 
126
- // Remove points from previous mask (if any)
127
- document.querySelectorAll('.icon').forEach(e => e.remove());
128
 
129
- // Disable cut button
130
- cutButton.disabled = true;
131
 
132
- // Reset mask canvas
133
- maskCanvas.getContext('2d').clearRect(0, 0, maskCanvas.width, maskCanvas.height);
134
  }
135
- clearButton.addEventListener('click', clearPointsAndMask);
136
-
137
- resetButton.addEventListener('click', () => {
138
- // Update state
139
- imageEmbeddings = null;
140
- imageDataURI = null;
141
-
142
- // Reset the state
143
- imageInputs = null;
144
- imageEmbeddings = null;
145
- isDecoding = false;
146
-
147
- // Clear points and mask (if present)
148
- clearPointsAndMask();
149
-
150
- // Update UI
151
- cutButton.disabled = true;
152
- imageContainer.style.backgroundImage = 'none';
153
- uploadButton.style.display = 'flex';
154
- statusLabel.textContent = 'Ready';
155
  });
156
 
157
- async function segment(data) {
158
- statusLabel.textContent = 'Extracting image embedding...';
159
 
160
- // Update state
161
- imageEmbeddings = null;
162
- imageDataURI = data;
163
 
164
- // Update UI
165
- imageContainer.style.backgroundImage = `url(${data})`;
166
- uploadButton.style.display = 'none';
167
- cutButton.disabled = true;
168
 
169
- // Read the image and recompute image embeddings
170
- const image = await RawImage.read(data);
171
- imageInputs = await processor(image);
172
- imageEmbeddings = await model.get_image_embeddings(imageInputs)
173
 
174
- statusLabel.textContent = 'Embedding extracted!';
175
  }
176
 
177
  // Handle file selection
178
- fileUpload.addEventListener('change', function (e) {
179
- const file = e.target.files[0];
180
- if (!file) {
181
- return;
182
- }
183
 
184
- const reader = new FileReader();
185
 
186
- // Set up a callback when the file is loaded
187
- reader.onload = e2 => segment(e2.target.result);
188
 
189
- reader.readAsDataURL(file);
190
  });
191
 
192
- example.addEventListener('click', (e) => {
193
- e.preventDefault();
194
- segment(EXAMPLE_URL);
195
  });
196
 
197
- function addIcon({ point, label }) {
198
- const icon = (label === 1 ? star : cross).cloneNode();
199
- icon.style.left = `${point[0] * 100}%`;
200
- icon.style.top = `${point[1] * 100}%`;
201
- imageContainer.appendChild(icon);
202
- }
203
-
204
  // Attach hover event to image container
205
- imageContainer.addEventListener('mousedown', e => {
206
- if (e.button !== 0 && e.button !== 2) {
207
- return; // Ignore other buttons
208
- }
209
- if (!imageEmbeddings) {
210
- return; // Ignore if not encoded yet
211
- }
212
- if (!isMultiMaskMode) {
213
- lastPoints = [];
214
- isMultiMaskMode = true;
215
- cutButton.disabled = false;
216
- }
217
-
218
- const point = getPoint(e);
219
- lastPoints.push(point);
220
-
221
- // add icon
222
- addIcon(point);
223
-
224
- decode();
 
 
 
 
225
  });
226
 
227
-
228
  // Clamp a value inside a range [min, max]
229
  function clamp(x, min = 0, max = 1) {
230
- return Math.max(Math.min(x, max), min)
231
  }
232
 
233
  function getPoint(e) {
234
- // Get bounding box
235
- const bb = imageContainer.getBoundingClientRect();
236
-
237
- // Get the mouse coordinates relative to the container
238
- const mouseX = clamp((e.clientX - bb.left) / bb.width);
239
- const mouseY = clamp((e.clientY - bb.top) / bb.height);
240
-
241
- return {
242
- point: [mouseX, mouseY],
243
- label: e.button === 2 // right click
244
- ? 0 // negative prompt
245
- : 1, // positive prompt
246
- }
 
247
  }
248
 
249
  // Do not show context menu on right click
250
- imageContainer.addEventListener('contextmenu', e => {
251
- e.preventDefault();
252
- });
253
 
254
  // Attach hover event to image container
255
- imageContainer.addEventListener('mousemove', e => {
256
- if (!imageEmbeddings || isMultiMaskMode) {
257
- // Ignore mousemove events if the image is not encoded yet,
258
- // or we are in multi-mask mode
259
- return;
260
- }
261
- lastPoints = [getPoint(e)];
262
-
263
- if (!isDecoding) {
264
- decode(); // Only decode if we are not already decoding
265
- }
266
  });
267
 
268
  // Handle cut button click
269
- cutButton.addEventListener('click', () => {
270
- const [w, h] = [maskCanvas.width, maskCanvas.height];
271
-
272
- // Get the mask pixel data
273
- const maskContext = maskCanvas.getContext('2d');
274
- const maskPixelData = maskContext.getImageData(0, 0, w, h);
275
-
276
- // Load the image
277
- const image = new Image();
278
- image.crossOrigin = 'anonymous';
279
- image.onload = async () => {
280
- // Create a new canvas to hold the image
281
- const imageCanvas = new OffscreenCanvas(w, h);
282
- const imageContext = imageCanvas.getContext('2d');
283
- imageContext.drawImage(image, 0, 0, w, h);
284
- const imagePixelData = imageContext.getImageData(0, 0, w, h);
285
-
286
- // Create a new canvas to hold the cut-out
287
- const cutCanvas = new OffscreenCanvas(w, h);
288
- const cutContext = cutCanvas.getContext('2d');
289
- const cutPixelData = cutContext.getImageData(0, 0, w, h);
290
-
291
- // Copy the image pixel data to the cut canvas
292
- for (let i = 3; i < maskPixelData.data.length; i += 4) {
293
- if (maskPixelData.data[i] > 0) {
294
- for (let j = 0; j < 4; ++j) {
295
- const offset = i - j;
296
- cutPixelData.data[offset] = imagePixelData.data[offset];
297
- }
298
- }
299
- }
300
- cutContext.putImageData(cutPixelData, 0, 0);
301
-
302
- // Download image
303
- const link = document.createElement('a');
304
- link.download = 'image.png';
305
- link.href = URL.createObjectURL(await cutCanvas.convertToBlob());
306
- link.click();
307
- link.remove();
308
  }
309
- image.src = imageDataURI;
 
 
 
 
 
 
 
 
310
  });
311
 
312
-
313
- const model_id = 'Xenova/slimsam-77-uniform';
314
- statusLabel.textContent = 'Loading model...';
315
  const model = await SamModel.from_pretrained(model_id, {
316
- dtype: 'fp16',
317
- device: 'webgpu',
318
  });
319
  const processor = await AutoProcessor.from_pretrained(model_id);
320
- statusLabel.textContent = 'Ready';
321
 
322
  // Enable the user interface
323
  fileUpload.disabled = false;
324
  uploadButton.style.opacity = 1;
325
- example.style.pointerEvents = 'auto';
 
1
+ import {
2
+ SamModel,
3
+ AutoProcessor,
4
+ RawImage,
5
+ Tensor,
6
+ } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0-alpha.5";
7
 
8
  // Reference the elements we will use
9
+ const statusLabel = document.getElementById("status");
10
+ const fileUpload = document.getElementById("upload");
11
+ const imageContainer = document.getElementById("container");
12
+ const example = document.getElementById("example");
13
+ const uploadButton = document.getElementById("upload-button");
14
+ const resetButton = document.getElementById("reset-image");
15
+ const clearButton = document.getElementById("clear-points");
16
+ const cutButton = document.getElementById("cut-mask");
17
+ const starIcon = document.getElementById("star-icon");
18
+ const crossIcon = document.getElementById("cross-icon");
19
+ const maskCanvas = document.getElementById("mask-output");
20
+ const maskContext = maskCanvas.getContext("2d");
21
+
22
+ const EXAMPLE_URL =
23
+ "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/corgi.jpg";
 
 
 
 
 
 
 
24
 
25
  // State variables
 
26
  let isDecoding = false;
27
+ let decodePending = false;
28
+ let lastPoints = null;
29
  let isMultiMaskMode = false;
30
+ let imageInput = null;
31
+ let imageProcessed = null;
32
  let imageEmbeddings = null;
33
 
34
  async function decode() {
35
+ // Only proceed if we are not already decoding
36
+ if (isDecoding) {
37
+ decodePending = true;
38
+ return;
39
+ }
40
+ isDecoding = true;
41
+
42
+ // Prepare inputs for decoding
43
+ const reshaped = imageProcessed.reshaped_input_sizes[0];
44
+ const points = lastPoints
45
+ .map((x) => [x.position[0] * reshaped[1], x.position[1] * reshaped[0]])
46
+ .flat(Infinity);
47
+ const labels = lastPoints.map((x) => BigInt(x.label)).flat(Infinity);
48
+
49
+ const num_points = lastPoints.length;
50
+ const input_points = new Tensor("float32", points, [1, 1, num_points, 2]);
51
+ const input_labels = new Tensor("int64", labels, [1, 1, num_points]);
52
+
53
+ // Generate the mask
54
+ const { pred_masks, iou_scores } = await model({
55
+ ...imageEmbeddings,
56
+ input_points,
57
+ input_labels,
58
+ });
59
+
60
+ // Post-process the mask
61
+ const masks = await processor.post_process_masks(
62
+ pred_masks,
63
+ imageProcessed.original_sizes,
64
+ imageProcessed.reshaped_input_sizes,
65
+ );
66
+
67
+ isDecoding = false;
68
+
69
+ updateMaskOverlay(RawImage.fromTensor(masks[0][0]), iou_scores.data);
70
+
71
+ // Check if another decode is pending
72
+ if (decodePending) {
73
+ decodePending = false;
74
+ decode();
75
+ }
76
+ }
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
+ function updateMaskOverlay(mask, scores) {
79
+ // Update canvas dimensions (if different)
80
+ if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
81
+ maskCanvas.width = mask.width;
82
+ maskCanvas.height = mask.height;
83
+ }
84
+
85
+ // Allocate buffer for pixel data
86
+ const imageData = maskContext.createImageData(
87
+ maskCanvas.width,
88
+ maskCanvas.height,
89
+ );
90
+
91
+ // Select best mask
92
+ const numMasks = scores.length; // 3
93
+ let bestIndex = 0;
94
+ for (let i = 1; i < numMasks; ++i) {
95
+ if (scores[i] > scores[bestIndex]) {
96
+ bestIndex = i;
97
  }
98
+ }
99
+ statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;
100
+
101
+ // Fill mask with colour
102
+ const pixelData = imageData.data;
103
+ for (let i = 0; i < pixelData.length; ++i) {
104
+ if (mask.data[numMasks * i + bestIndex] === 1) {
105
+ const offset = 4 * i;
106
+ pixelData[offset] = 0; // red
107
+ pixelData[offset + 1] = 114; // green
108
+ pixelData[offset + 2] = 189; // blue
109
+ pixelData[offset + 3] = 255; // alpha
110
  }
111
+ }
112
 
113
+ // Draw image data to context
114
+ maskContext.putImageData(imageData, 0, 0);
115
  }
116
 
117
  function clearPointsAndMask() {
118
+ // Reset state
119
+ isMultiMaskMode = false;
120
+ lastPoints = null;
121
 
122
+ // Remove points from previous mask (if any)
123
+ document.querySelectorAll(".icon").forEach((e) => e.remove());
124
 
125
+ // Disable cut button
126
+ cutButton.disabled = true;
127
 
128
+ // Reset mask canvas
129
+ maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);
130
  }
131
+ clearButton.addEventListener("click", clearPointsAndMask);
132
+
133
+ resetButton.addEventListener("click", () => {
134
+ // Reset the state
135
+ imageInput = null;
136
+ imageProcessed = null;
137
+ imageEmbeddings = null;
138
+ isDecoding = false;
139
+
140
+ // Clear points and mask (if present)
141
+ clearPointsAndMask();
142
+
143
+ // Update UI
144
+ cutButton.disabled = true;
145
+ imageContainer.style.backgroundImage = "none";
146
+ uploadButton.style.display = "flex";
147
+ statusLabel.textContent = "Ready";
 
 
 
148
  });
149
 
150
+ async function segment(url) {
151
+ imageInput = await RawImage.fromURL(url);
152
 
153
+ statusLabel.textContent = "Extracting image embedding...";
 
 
154
 
155
+ // Update UI
156
+ imageContainer.style.backgroundImage = `url(${url})`;
157
+ uploadButton.style.display = "none";
158
+ cutButton.disabled = true;
159
 
160
+ // Recompute image embeddings
161
+ imageProcessed = await processor(imageInput);
162
+ imageEmbeddings = await model.get_image_embeddings(imageProcessed);
 
163
 
164
+ statusLabel.textContent = "Embedding extracted!";
165
  }
166
 
167
  // Handle file selection
168
+ fileUpload.addEventListener("change", function (e) {
169
+ const file = e.target.files[0];
170
+ if (!file) return;
 
 
171
 
172
+ const reader = new FileReader();
173
 
174
+ // Set up a callback when the file is loaded
175
+ reader.onload = (e2) => segment(e2.target.result);
176
 
177
+ reader.readAsDataURL(file);
178
  });
179
 
180
+ example.addEventListener("click", (e) => {
181
+ e.preventDefault();
182
+ segment(EXAMPLE_URL);
183
  });
184
 
 
 
 
 
 
 
 
185
  // Attach hover event to image container
186
+ imageContainer.addEventListener("mousedown", (e) => {
187
+ if (e.button !== 0 && e.button !== 2) {
188
+ return; // Ignore other buttons
189
+ }
190
+ if (!imageEmbeddings) {
191
+ return; // Ignore if not encoded yet
192
+ }
193
+ if (!isMultiMaskMode) {
194
+ lastPoints = [];
195
+ isMultiMaskMode = true;
196
+ cutButton.disabled = false;
197
+ }
198
+
199
+ const point = getPoint(e);
200
+ lastPoints.push(point);
201
+
202
+ // add icon
203
+ const icon = (point.label === 1 ? starIcon : crossIcon).cloneNode();
204
+ icon.style.left = `${point.position[0] * 100}%`;
205
+ icon.style.top = `${point.position[1] * 100}%`;
206
+ imageContainer.appendChild(icon);
207
+
208
+ // Run decode
209
+ decode();
210
  });
211
 
 
212
  // Clamp a value inside a range [min, max]
213
  function clamp(x, min = 0, max = 1) {
214
+ return Math.max(Math.min(x, max), min);
215
  }
216
 
217
  function getPoint(e) {
218
+ // Get bounding box
219
+ const bb = imageContainer.getBoundingClientRect();
220
+
221
+ // Get the mouse coordinates relative to the container
222
+ const mouseX = clamp((e.clientX - bb.left) / bb.width);
223
+ const mouseY = clamp((e.clientY - bb.top) / bb.height);
224
+
225
+ return {
226
+ position: [mouseX, mouseY],
227
+ label:
228
+ e.button === 2 // right click
229
+ ? 0 // negative prompt
230
+ : 1, // positive prompt
231
+ };
232
  }
233
 
234
  // Do not show context menu on right click
235
+ imageContainer.addEventListener("contextmenu", (e) => e.preventDefault());
 
 
236
 
237
  // Attach hover event to image container
238
+ imageContainer.addEventListener("mousemove", (e) => {
239
+ if (!imageEmbeddings || isMultiMaskMode) {
240
+ // Ignore mousemove events if the image is not encoded yet,
241
+ // or we are in multi-mask mode
242
+ return;
243
+ }
244
+ lastPoints = [getPoint(e)];
245
+
246
+ decode();
 
 
247
  });
248
 
249
  // Handle cut button click
250
+ cutButton.addEventListener("click", async () => {
251
+ const [w, h] = [maskCanvas.width, maskCanvas.height];
252
+
253
+ // Get the mask pixel data (and use this as a buffer)
254
+ const maskImageData = maskContext.getImageData(0, 0, w, h);
255
+
256
+ // Create a new canvas to hold the cut-out
257
+ const cutCanvas = new OffscreenCanvas(w, h);
258
+ const cutContext = cutCanvas.getContext("2d");
259
+
260
+ // Copy the image pixel data to the cut canvas
261
+ const maskPixelData = maskImageData.data;
262
+ const imagePixelData = imageInput.data;
263
+ for (let i = 0; i < w * h; ++i) {
264
+ const sourceOffset = 3 * i; // RGB
265
+ const targetOffset = 4 * i; // RGBA
266
+
267
+ if (maskPixelData[targetOffset + 3] > 0) {
268
+ // Only copy opaque pixels
269
+ for (let j = 0; j < 3; ++j) {
270
+ maskPixelData[targetOffset + j] = imagePixelData[sourceOffset + j];
271
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  }
273
+ }
274
+ cutContext.putImageData(maskImageData, 0, 0);
275
+
276
+ // Download image
277
+ const link = document.createElement("a");
278
+ link.download = "image.png";
279
+ link.href = URL.createObjectURL(await cutCanvas.convertToBlob());
280
+ link.click();
281
+ link.remove();
282
  });
283
 
284
+ const model_id = "Xenova/slimsam-77-uniform";
285
+ statusLabel.textContent = "Loading model...";
 
286
  const model = await SamModel.from_pretrained(model_id, {
287
+ dtype: "fp16", // or "fp32"
288
+ device: "webgpu",
289
  });
290
  const processor = await AutoProcessor.from_pretrained(model_id);
291
+ statusLabel.textContent = "Ready";
292
 
293
  // Enable the user interface
294
  fileUpload.disabled = false;
295
  uploadButton.style.opacity = 1;
296
+ example.style.pointerEvents = "auto";