import {
  SamModel,
  AutoProcessor,
  RawImage,
  Tensor,
} from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0-alpha.5";

// ---------------------------------------------------------------------------
// Interactive SAM (Segment Anything) demo: the user uploads an image, clicks
// positive/negative prompt points, and the model overlays the predicted mask.
// ---------------------------------------------------------------------------

// Reference the elements we will use
const statusLabel = document.getElementById("status");
const fileUpload = document.getElementById("upload");
const imageContainer = document.getElementById("container");
const example = document.getElementById("example");
const uploadButton = document.getElementById("upload-button");
const resetButton = document.getElementById("reset-image");
const clearButton = document.getElementById("clear-points");
const cutButton = document.getElementById("cut-mask");
const starIcon = document.getElementById("star-icon");
const crossIcon = document.getElementById("cross-icon");
const maskCanvas = document.getElementById("mask-output");
const maskContext = maskCanvas.getContext("2d");

const EXAMPLE_URL =
  "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/corgi.jpg";

// State variables
let isEncoding = false; // true while the image embedding is being computed
let isDecoding = false; // true while a mask decode is in flight
let decodePending = false; // a decode was requested while one was running
let lastPoints = null; // array of { position: [x, y] in 0..1, label: 0|1 }
let isMultiMaskMode = false; // true after the first click pins the points
let imageInput = null; // RawImage of the current picture
let imageProcessed = null; // processor output for the current picture
let imageEmbeddings = null; // cached SAM image embeddings

/**
 * Run the SAM mask decoder for the current prompt points and paint the
 * resulting mask onto the overlay canvas.
 *
 * Concurrent calls are coalesced: if a decode is already running, we only
 * note that another one is wanted and run it when the current one finishes.
 */
async function decode() {
  if (isDecoding) {
    decodePending = true;
    return;
  }
  isDecoding = true;

  try {
    // Prepare inputs for decoding. Point positions are stored normalised
    // (0..1) and scaled up to the model's reshaped input size here.
    const reshaped = imageProcessed.reshaped_input_sizes[0];
    const points = lastPoints
      .map((x) => [x.position[0] * reshaped[1], x.position[1] * reshaped[0]])
      .flat(Infinity);
    const labels = lastPoints.map((x) => BigInt(x.label)).flat(Infinity);

    const num_points = lastPoints.length;
    const input_points = new Tensor("float32", points, [1, 1, num_points, 2]);
    const input_labels = new Tensor("int64", labels, [1, 1, num_points]);

    // Generate the mask (reuses the cached image embeddings)
    const { pred_masks, iou_scores } = await model({
      ...imageEmbeddings,
      input_points,
      input_labels,
    });

    // Post-process the mask back to the original image resolution
    const masks = await processor.post_process_masks(
      pred_masks,
      imageProcessed.original_sizes,
      imageProcessed.reshaped_input_sizes,
    );

    updateMaskOverlay(RawImage.fromTensor(masks[0][0]), iou_scores.data);
  } finally {
    // Always clear the busy flag — the original left it stuck on error,
    // permanently disabling further decodes. Then service any request
    // that arrived while this decode was in flight.
    isDecoding = false;
    if (decodePending) {
      decodePending = false;
      decode();
    }
  }
}

/**
 * Paint the best of the candidate masks onto the overlay canvas and report
 * its IoU score in the status label.
 *
 * @param {RawImage} mask - multi-channel mask image; pixel i of candidate m
 *   is at mask.data[numMasks * i + m] (channel-interleaved layout).
 * @param {Float32Array|number[]} scores - one IoU score per candidate mask.
 */
function updateMaskOverlay(mask, scores) {
  // Update canvas dimensions (if different)
  if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
    maskCanvas.width = mask.width;
    maskCanvas.height = mask.height;
  }

  // Allocate buffer for pixel data
  const imageData = maskContext.createImageData(
    maskCanvas.width,
    maskCanvas.height,
  );

  // Select best mask
  const numMasks = scores.length; // 3
  let bestIndex = 0;
  for (let i = 1; i < numMasks; ++i) {
    if (scores[i] > scores[bestIndex]) {
      bestIndex = i;
    }
  }
  statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;

  // Fill mask with colour.
  // NOTE: iterate over the number of PIXELS — the original looped up to
  // pixelData.length (4x the pixel count, RGBA bytes), reading past the
  // end of the mask buffer for three quarters of the iterations.
  const numPixels = maskCanvas.width * maskCanvas.height;
  const pixelData = imageData.data;
  for (let i = 0; i < numPixels; ++i) {
    if (mask.data[numMasks * i + bestIndex] === 1) {
      const offset = 4 * i;
      pixelData[offset] = 0; // red
      pixelData[offset + 1] = 114; // green
      pixelData[offset + 2] = 189; // blue
      pixelData[offset + 3] = 255; // alpha
    }
  }

  // Draw image data to context
  maskContext.putImageData(imageData, 0, 0);
}

/** Remove all prompt points and wipe the mask overlay. */
function clearPointsAndMask() {
  // Reset state
  isMultiMaskMode = false;
  lastPoints = null;

  // Remove points from previous mask (if any)
  document.querySelectorAll(".icon").forEach((e) => e.remove());

  // Disable cut button
  cutButton.disabled = true;

  // Reset mask canvas
  maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);
}
clearButton.addEventListener("click", clearPointsAndMask);

resetButton.addEventListener("click", () => {
  // Reset the state
  imageInput = null;
  imageProcessed = null;
  imageEmbeddings = null;
  isEncoding = false;
  isDecoding = false;

  // Clear points and mask (if present)
  clearPointsAndMask();

  // Update UI
  cutButton.disabled = true;
  imageContainer.style.backgroundImage = "none";
  uploadButton.style.display = "flex";
  statusLabel.textContent = "Ready";
});

/**
 * Load the image at `url`, show it in the container and compute (and cache)
 * its SAM image embeddings. No-op if an encode is already running.
 *
 * @param {string} url - image URL or data: URL.
 */
async function encode(url) {
  if (isEncoding) return;
  isEncoding = true;
  statusLabel.textContent = "Extracting image embedding...";

  try {
    imageInput = await RawImage.fromURL(url);

    // Update UI
    imageContainer.style.backgroundImage = `url(${url})`;
    uploadButton.style.display = "none";
    cutButton.disabled = true;

    // Recompute image embeddings
    imageProcessed = await processor(imageInput);
    imageEmbeddings = await model.get_image_embeddings(imageProcessed);
    statusLabel.textContent = "Embedding extracted!";
  } finally {
    // Reset the flag even when loading/encoding fails, so the user can
    // retry (the original left it stuck on error).
    isEncoding = false;
  }
}

// Handle file selection
fileUpload.addEventListener("change", function (e) {
  const file = e.target.files[0];
  if (!file) return;

  const reader = new FileReader();

  // Set up a callback when the file is loaded
  reader.onload = (e2) => encode(e2.target.result);

  reader.readAsDataURL(file);
});

example.addEventListener("click", (e) => {
  e.preventDefault();
  encode(EXAMPLE_URL);
});

// Attach click handler to the image container: each left/right click adds a
// positive/negative prompt point and triggers a decode.
imageContainer.addEventListener("mousedown", (e) => {
  if (e.button !== 0 && e.button !== 2) {
    return; // Ignore other buttons
  }
  if (!imageEmbeddings) {
    return; // Ignore if not encoded yet
  }

  if (!isMultiMaskMode) {
    lastPoints = [];
    isMultiMaskMode = true;
    cutButton.disabled = false;
  }

  const point = getPoint(e);
  lastPoints.push(point);

  // add icon
  const icon = (point.label === 1 ? starIcon : crossIcon).cloneNode();
  icon.style.left = `${point.position[0] * 100}%`;
  icon.style.top = `${point.position[1] * 100}%`;
  imageContainer.appendChild(icon);

  // Run decode
  decode();
});

// Clamp a value inside a range [min, max]
function clamp(x, min = 0, max = 1) {
  return Math.max(Math.min(x, max), min);
}

/**
 * Build a prompt point from a mouse event: position normalised to 0..1
 * within the container, label 1 for left click, 0 for right click.
 */
function getPoint(e) {
  // Get bounding box
  const bb = imageContainer.getBoundingClientRect();

  // Get the mouse coordinates relative to the container
  const mouseX = clamp((e.clientX - bb.left) / bb.width);
  const mouseY = clamp((e.clientY - bb.top) / bb.height);

  return {
    position: [mouseX, mouseY],
    label:
      e.button === 2 // right click
        ? 0 // negative prompt
        : 1, // positive prompt
  };
}

// Do not show context menu on right click
imageContainer.addEventListener("contextmenu", (e) => e.preventDefault());

// Attach hover event to image container: before the first click, the mask
// follows the cursor as a live preview.
imageContainer.addEventListener("mousemove", (e) => {
  if (!imageEmbeddings || isMultiMaskMode) {
    // Ignore mousemove events if the image is not encoded yet,
    // or we are in multi-mask mode
    return;
  }
  lastPoints = [getPoint(e)];

  decode();
});

// Handle cut button click: download the masked region as a transparent PNG.
cutButton.addEventListener("click", async () => {
  const [w, h] = [maskCanvas.width, maskCanvas.height];

  // Get the mask pixel data (and use this as a buffer)
  const maskImageData = maskContext.getImageData(0, 0, w, h);

  // Create a new canvas to hold the cut-out
  const cutCanvas = new OffscreenCanvas(w, h);
  const cutContext = cutCanvas.getContext("2d");

  // Copy the image pixel data to the cut canvas
  const maskPixelData = maskImageData.data;
  // NOTE(review): assumes imageInput.data is 3-channel RGB — confirm the
  // RawImage channel count for RGBA sources.
  const imagePixelData = imageInput.data;
  for (let i = 0; i < w * h; ++i) {
    const sourceOffset = 3 * i; // RGB
    const targetOffset = 4 * i; // RGBA
    if (maskPixelData[targetOffset + 3] > 0) {
      // Only copy opaque pixels
      for (let j = 0; j < 3; ++j) {
        maskPixelData[targetOffset + j] = imagePixelData[sourceOffset + j];
      }
    }
  }
  cutContext.putImageData(maskImageData, 0, 0);

  // Download image
  const link = document.createElement("a");
  link.download = "image.png";
  link.href = URL.createObjectURL(await cutCanvas.convertToBlob());
  link.click();
  link.remove();
  // Release the blob URL once the download has been triggered — the
  // original leaked one object URL per cut.
  URL.revokeObjectURL(link.href);
});

const model_id = "Xenova/slimsam-77-uniform";
statusLabel.textContent = "Loading model...";
const model = await SamModel.from_pretrained(model_id, {
  dtype: "fp16", // or "fp32"
  device: "webgpu",
});
const processor = await AutoProcessor.from_pretrained(model_id);
statusLabel.textContent = "Ready";

// Enable the user interface
fileUpload.disabled = false;
uploadButton.style.opacity = 1;
example.style.pointerEvents = "auto";