// Author: Xenova (HF staff)
// Commit: 91efc35 ("Upload 4 files", verified)
import {
SamModel,
AutoProcessor,
RawImage,
Tensor,
} from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0-alpha.5";
// Look up every DOM element this script reads from or writes to.
const byId = (id) => document.getElementById(id);
const statusLabel = byId("status");
const fileUpload = byId("upload");
const imageContainer = byId("container");
const example = byId("example");
const uploadButton = byId("upload-button");
const resetButton = byId("reset-image");
const clearButton = byId("clear-points");
const cutButton = byId("cut-mask");
const starIcon = byId("star-icon");
const crossIcon = byId("cross-icon");
// Canvas (and its 2D context) that the predicted mask is painted onto.
const maskCanvas = byId("mask-output");
const maskContext = maskCanvas.getContext("2d");
// Image that is encoded when the "example" link is clicked.
const EXAMPLE_URL =
  "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/corgi.jpg";
// State variables
// The loaded image, its processor output and its SAM image embeddings.
// All three are null until an image has been encoded (and again after a reset).
let imageInput = null;
let imageProcessed = null;
let imageEmbeddings = null;
// Busy flags for encode() / decode().
let isEncoding = false;
let isDecoding = false;
// Set when decode() is requested while one is already running, so that a
// single follow-up decode runs once the current one finishes.
let decodePending = false;
// Current prompt points: { position: [x, y] in [0, 1], label: 1 (positive) or 0 (negative) }.
let lastPoints = null;
// false: a single point follows the cursor; true: clicked points accumulate.
let isMultiMaskMode = false;
/**
 * Run the SAM mask decoder on the current prompt points (`lastPoints`) and
 * draw the highest-scoring mask onto the overlay canvas.
 *
 * Calls are coalesced. While a decode is in flight, further calls only set
 * `decodePending`. A single follow-up decode, using whatever points are
 * current at that time, runs once the in-flight one finishes.
 *
 * Returns early without decoding when no image has been encoded or there
 * are no points (e.g. they were cleared while a decode was queued).
 * Model/processor errors propagate to the caller, but the `isDecoding`
 * lock is always released so later decodes can still run.
 */
async function decode() {
  // Only proceed if we are not already decoding
  if (isDecoding) {
    decodePending = true;
    return;
  }
  if (!imageEmbeddings || !imageProcessed || !lastPoints?.length) {
    return;
  }
  isDecoding = true;

  // Snapshot the shared state so a reset/clear during the awaits below
  // cannot swap it out from under this decode.
  const embeddings = imageEmbeddings;
  const processed = imageProcessed;
  const prompts = lastPoints;
  try {
    // Positions are normalised to [0, 1]; scale them to the processor's
    // resized input. Index 1 scales x and index 0 scales y, so
    // reshaped_input_sizes[0] is presumably [height, width].
    const reshaped = processed.reshaped_input_sizes[0];
    const points = prompts.flatMap(({ position: [x, y] }) => [
      x * reshaped[1],
      y * reshaped[0],
    ]);
    const labels = prompts.map(({ label }) => BigInt(label));
    const numPoints = prompts.length;
    const input_points = new Tensor("float32", points, [1, 1, numPoints, 2]);
    const input_labels = new Tensor("int64", labels, [1, 1, numPoints]);

    // Generate the mask
    const { pred_masks, iou_scores } = await model({
      ...embeddings,
      input_points,
      input_labels,
    });

    // Post-process the mask using the original and resized image sizes
    const masks = await processor.post_process_masks(
      pred_masks,
      processed.original_sizes,
      processed.reshaped_input_sizes,
    );

    // Drop the result if the image was reset/replaced or the points were
    // cleared while we were decoding: it no longer matches the screen.
    if (embeddings !== imageEmbeddings || !lastPoints) {
      return;
    }
    updateMaskOverlay(RawImage.fromTensor(masks[0][0]), iou_scores.data);
  } finally {
    isDecoding = false;
    // Check if another decode is pending
    if (decodePending) {
      decodePending = false;
      // Intentionally not awaited: the follow-up decode runs on its own.
      void decode();
    }
  }
}
/**
 * Paint the highest-scoring candidate mask onto the overlay canvas and show
 * its score in the status label.
 *
 * @param {RawImage} mask - mask image whose pixel data interleaves one value
 *   per candidate mask (candidate k of pixel i at `data[numMasks * i + k]`);
 *   a value of 1 marks the pixel as inside that mask.
 * @param {ArrayLike<number>} scores - one score per candidate mask; the
 *   highest-scoring candidate is drawn (the first one wins on ties).
 */
function updateMaskOverlay(mask, scores) {
  // Update canvas dimensions (if different)
  if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
    maskCanvas.width = mask.width;
    maskCanvas.height = mask.height;
  }
  // Allocate a fresh, fully transparent buffer for pixel data
  const imageData = maskContext.createImageData(
    maskCanvas.width,
    maskCanvas.height,
  );

  // Select best mask
  const numMasks = scores.length;
  let bestIndex = 0;
  for (let i = 1; i < numMasks; ++i) {
    if (scores[i] > scores[bestIndex]) {
      bestIndex = i;
    }
  }
  statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;

  // Fill the selected mask with colour. Iterate over pixels, not bytes:
  // imageData.data holds 4 bytes (RGBA) per pixel.
  const pixelData = imageData.data;
  const numPixels = pixelData.length / 4;
  for (let i = 0; i < numPixels; ++i) {
    if (mask.data[numMasks * i + bestIndex] === 1) {
      const offset = 4 * i;
      pixelData[offset] = 0; // red
      pixelData[offset + 1] = 114; // green
      pixelData[offset + 2] = 189; // blue
      pixelData[offset + 3] = 255; // alpha
    }
  }
  // Draw image data to context
  maskContext.putImageData(imageData, 0, 0);
}
/**
 * Remove all prompt points and the current mask, and return to hover
 * (single-point) mode.
 */
function clearPointsAndMask() {
  // Reset state
  isMultiMaskMode = false;
  lastPoints = null;
  // Drop any queued decode: otherwise it would run against the cleared
  // (null) point list once the in-flight decode finishes.
  decodePending = false;
  // Remove every `.icon` element (the point markers, presumably cloned from
  // the star/cross templates)
  document.querySelectorAll(".icon").forEach((e) => e.remove());
  // Disable cut button
  cutButton.disabled = true;
  // Reset mask canvas
  maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);
}
clearButton.addEventListener("click", clearPointsAndMask);

// Forget the current image entirely and return to the upload screen.
resetButton.addEventListener("click", () => {
  // Reset the state
  imageInput = null;
  imageProcessed = null;
  imageEmbeddings = null;
  isEncoding = false;
  isDecoding = false;
  // A queued decode would otherwise run against the discarded image/points
  decodePending = false;
  // Clear points and mask (if present); this also disables the cut button
  clearPointsAndMask();
  // Clear the file input so choosing the same file again still fires "change"
  fileUpload.value = "";
  // Update UI
  imageContainer.style.backgroundImage = "none";
  uploadButton.style.display = "flex";
  statusLabel.textContent = "Ready";
});
/**
 * Load the image at `url`, show it in the container and compute its SAM
 * image embeddings (stored in `imageInput`, `imageProcessed` and
 * `imageEmbeddings`). Calls made while an encode is in progress are ignored.
 *
 * @param {string} url - image URL (a data: URL for uploaded files)
 * @throws rethrows any load/model error after reporting it in the status
 *   label; the `isEncoding` lock is always released.
 */
async function encode(url) {
  if (isEncoding) return;
  isEncoding = true;
  statusLabel.textContent = "Extracting image embedding...";
  try {
    imageInput = await RawImage.fromURL(url);
    // Update UI
    imageContainer.style.backgroundImage = `url(${url})`;
    uploadButton.style.display = "none";
    cutButton.disabled = true;
    // Recompute image embeddings
    imageProcessed = await processor(imageInput);
    imageEmbeddings = await model.get_image_embeddings(imageProcessed);
    statusLabel.textContent = "Embedding extracted!";
  } catch (err) {
    // Without this the label would stay on "Extracting image embedding..."
    statusLabel.textContent = "Failed to extract image embedding";
    throw err;
  } finally {
    isEncoding = false;
  }
}
// Encode the image the user picks in the file input (read as a data: URL).
fileUpload.addEventListener("change", (event) => {
  const [file] = event.target.files;
  if (file) {
    const reader = new FileReader();
    // Once the file has been read, encode its contents
    reader.onload = () => encode(reader.result);
    reader.readAsDataURL(file);
  }
});

// Clicking the example link encodes the bundled example image instead.
example.addEventListener("click", (event) => {
  event.preventDefault();
  encode(EXAMPLE_URL);
});
// Add a prompt point on click: left button = positive, right button = negative.
// The first click switches from hover mode into multi-mask mode, in which
// clicked points accumulate until they are cleared.
imageContainer.addEventListener("mousedown", (e) => {
  if (e.button !== 0 && e.button !== 2) {
    return; // Ignore other buttons
  }
  if (!imageEmbeddings) {
    return; // Ignore if not encoded yet
  }
  if (!isMultiMaskMode) {
    // Discard the hover point and start a fresh list of clicked points
    lastPoints = [];
    isMultiMaskMode = true;
    cutButton.disabled = false;
  }
  const point = getPoint(e);
  lastPoints.push(point);
  // Mark the point on the image. Deep-clone the template so any child nodes
  // (e.g. the paths of an inline SVG icon) are copied along with it.
  const template = point.label === 1 ? starIcon : crossIcon;
  const icon = template.cloneNode(true);
  icon.style.left = `${point.position[0] * 100}%`;
  icon.style.top = `${point.position[1] * 100}%`;
  imageContainer.appendChild(icon);
  // Run decode
  decode();
});
/**
 * Clamp `x` into the range [min, max] (defaults to [0, 1]).
 * The upper bound is applied first, so if min > max the result is min.
 */
function clamp(x, min = 0, max = 1) {
  const atMostMax = Math.min(x, max);
  return Math.max(min, atMostMax);
}
/**
 * Convert a mouse event into a prompt point for the decoder.
 *
 * @param {MouseEvent} e - event on the image container
 * @returns {{position: number[], label: number}} the cursor position relative
 *   to the container, normalised and clamped to [0, 1] per axis, and the
 *   label: 0 (negative) for a right click, otherwise 1 (positive).
 */
function getPoint(e) {
  const { left, top, width, height } = imageContainer.getBoundingClientRect();
  const x = clamp((e.clientX - left) / width);
  const y = clamp((e.clientY - top) / height);
  const isRightClick = e.button === 2;
  return { position: [x, y], label: isRightClick ? 0 : 1 };
}
// Suppress the browser context menu so right click can add negative points.
imageContainer.addEventListener("contextmenu", (event) => {
  event.preventDefault();
});

// Hover mode: while no point has been clicked, a single prompt point follows
// the cursor and the mask is re-decoded on every move.
imageContainer.addEventListener("mousemove", (event) => {
  const isHoverActive = imageEmbeddings && !isMultiMaskMode;
  if (isHoverActive) {
    lastPoints = [getPoint(event)];
    decode();
  }
});
// Download a PNG cut-out of the image: pixels covered by the current mask
// take the image's colour, all other pixels stay fully transparent.
cutButton.addEventListener("click", async () => {
  if (!imageInput) {
    return; // No image loaded (e.g. after a reset)
  }
  const [w, h] = [maskCanvas.width, maskCanvas.height];
  // Get the mask pixel data (and use this as the output buffer)
  const maskImageData = maskContext.getImageData(0, 0, w, h);
  const maskPixelData = maskImageData.data;

  // Read the source image with its real channel count instead of assuming
  // RGB. Images with fewer than 3 channels (grey) replicate channel 0 into
  // R, G and B. Falls back to 3 if the image does not report `channels`.
  const imagePixelData = imageInput.data;
  const channels = imageInput.channels ?? 3;
  for (let i = 0; i < w * h; ++i) {
    const sourceOffset = channels * i;
    const targetOffset = 4 * i; // RGBA
    if (maskPixelData[targetOffset + 3] > 0) {
      // Only copy pixels that are part of the (opaque) mask
      for (let j = 0; j < 3; ++j) {
        maskPixelData[targetOffset + j] =
          imagePixelData[sourceOffset + (channels >= 3 ? j : 0)];
      }
    }
  }

  // Render the cut-out to an offscreen canvas
  const cutCanvas = new OffscreenCanvas(w, h);
  cutCanvas.getContext("2d").putImageData(maskImageData, 0, 0);

  // Download image
  const url = URL.createObjectURL(await cutCanvas.convertToBlob());
  const link = document.createElement("a");
  link.download = "image.png";
  link.href = url;
  link.click();
  // Release the blob once the browser has had time to start the download;
  // revoking synchronously after click() can abort it in some browsers.
  setTimeout(() => URL.revokeObjectURL(url), 10000);
});
const model_id = "Xenova/slimsam-77-uniform";
statusLabel.textContent = "Loading model...";

/**
 * Load the SAM model (fp16 weights on the WebGPU device) and its processor.
 * The two loads are independent, so they run in parallel.
 * If either fails (e.g. WebGPU is unavailable), the status label reports it
 * instead of staying on "Loading model...", and the error is rethrown.
 *
 * @param {string} id - Hugging Face model id
 * @returns {Promise<[SamModel, AutoProcessor]>}
 */
async function loadModelAndProcessor(id) {
  try {
    return await Promise.all([
      SamModel.from_pretrained(id, {
        dtype: "fp16", // or "fp32"
        device: "webgpu",
      }),
      AutoProcessor.from_pretrained(id),
    ]);
  } catch (err) {
    statusLabel.textContent = "Failed to load model";
    throw err;
  }
}

const [model, processor] = await loadModelAndProcessor(model_id);
statusLabel.textContent = "Ready";
// Enable the user interface
fileUpload.disabled = false;
uploadButton.style.opacity = 1;
example.style.pointerEvents = "auto";