|
import { |
|
SamModel, |
|
AutoProcessor, |
|
RawImage, |
|
Tensor, |
|
} from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0-alpha.5"; |
|
|
|
|
|
/** Shorthand for looking up one of the page's fixed elements by id. */
const byId = (id) => document.getElementById(id);

// Page elements the demo reads and updates.
const statusLabel = byId("status");
const fileUpload = byId("upload");
const imageContainer = byId("container");
const example = byId("example");
const uploadButton = byId("upload-button");
const resetButton = byId("reset-image");
const clearButton = byId("clear-points");
const cutButton = byId("cut-mask");
const starIcon = byId("star-icon");
const crossIcon = byId("cross-icon");
const maskCanvas = byId("mask-output");
const maskContext = maskCanvas.getContext("2d");

// Image that encode() loads when the "example" link is clicked.
const EXAMPLE_URL =
  "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/corgi.jpg";

// ---- Mutable application state ----

// True while encode() is running; a second encode() call returns immediately.
let isEncoding = false;
// True while decode() is running; overlapping calls only set decodePending.
let isDecoding = false;
// Set when decode() was requested during an in-flight decode.
let decodePending = false;
// Prompt points ({ position: [x, y] in [0, 1], label }), or null when cleared.
let lastPoints = null;
// False: hovering previews a one-point mask. True: the user is placing points by clicking.
let isMultiMaskMode = false;
// Image loaded by encode() via RawImage.fromURL.
let imageInput = null;
// Processor output for imageInput (its sizes are used to scale points and masks).
let imageProcessed = null;
// Image embeddings computed once per image and reused by every decode().
let imageEmbeddings = null;
|
|
|
/**
 * Decode a mask for the current prompt points (`lastPoints`) against the
 * current image embeddings and draw it with updateMaskOverlay().
 *
 * At most one decode runs at a time. A call made while a decode is in flight
 * only sets `decodePending`. When the running decode finishes, one follow-up
 * decode is started, so the most recent points are always rendered.
 * Errors are logged and shown in the status label; they never leave the
 * decoder stuck in the "decoding" state.
 */
async function decode() {
  // Nothing to decode: no image encoded yet (or it was reset), or no points.
  if (!imageProcessed || !imageEmbeddings || !lastPoints?.length) {
    return;
  }
  if (isDecoding) {
    decodePending = true;
    return;
  }
  isDecoding = true;

  // Snapshot shared state so changes made while awaiting the model can be detected.
  const processed = imageProcessed;
  const embeddings = imageEmbeddings;
  const points = lastPoints;

  try {
    // Points are stored normalised to [0, 1]. Scale them to the model's resized
    // input: x by entry [1] and y by entry [0], so the size is presumably [height, width].
    const [height, width] = processed.reshaped_input_sizes[0];
    const coords = points.flatMap(({ position: [x, y] }) => [x * width, y * height]);
    const labels = points.map(({ label }) => BigInt(label));
    const numPoints = points.length;

    const input_points = new Tensor("float32", coords, [1, 1, numPoints, 2]);
    const input_labels = new Tensor("int64", labels, [1, 1, numPoints]);

    const { pred_masks, iou_scores } = await model({
      ...embeddings,
      input_points,
      input_labels,
    });

    // Upscale the predicted masks back to the original image size.
    const masks = await processor.post_process_masks(
      pred_masks,
      processed.original_sizes,
      processed.reshaped_input_sizes,
    );

    // Do not draw a stale result if the image was reset or replaced,
    // or the points were cleared, while the model was running.
    if (processed === imageProcessed && lastPoints !== null) {
      updateMaskOverlay(RawImage.fromTensor(masks[0][0]), iou_scores.data);
    }
  } catch (err) {
    console.error(err);
    statusLabel.textContent = `Decoding failed: ${err?.message ?? err}`;
  } finally {
    isDecoding = false;
  }

  if (decodePending) {
    decodePending = false;
    // decode() handles its own errors, so the promise is intentionally not awaited.
    void decode();
  }
}
|
|
|
/**
 * Draw the highest-scoring candidate mask onto the mask canvas and show its score.
 *
 * @param mask   RawImage-like object ({ width, height, data }) whose data interleaves
 *               one channel per candidate mask: pixel i of mask k is at
 *               data[numMasks * i + k], and a value of 1 marks a masked pixel.
 * @param scores One score per candidate mask; the highest-scoring mask is drawn.
 */
function updateMaskOverlay(mask, scores) {
  // Keep the canvas the same size as the mask.
  if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
    maskCanvas.width = mask.width;
    maskCanvas.height = mask.height;
  }

  // Fresh, fully transparent buffer.
  const imageData = maskContext.createImageData(maskCanvas.width, maskCanvas.height);

  // Pick the candidate mask with the highest score.
  const numMasks = scores.length;
  let bestIndex = 0;
  for (let i = 1; i < numMasks; ++i) {
    if (scores[i] > scores[bestIndex]) {
      bestIndex = i;
    }
  }
  statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;

  // The buffer holds 4 bytes (RGBA) per pixel, so iterate over
  // pixelData.length / 4 pixels, not pixelData.length bytes.
  const pixelData = imageData.data;
  const pixelCount = pixelData.length / 4;
  const maskData = mask.data;
  for (let i = 0; i < pixelCount; ++i) {
    if (maskData[numMasks * i + bestIndex] === 1) {
      const offset = 4 * i;
      pixelData[offset] = 0; // red
      pixelData[offset + 1] = 114; // green
      pixelData[offset + 2] = 189; // blue
      pixelData[offset + 3] = 255; // alpha (opaque)
    }
  }

  maskContext.putImageData(imageData, 0, 0);
}
|
|
|
/**
 * Remove all prompt points, their icons and the drawn mask, and go back to
 * hover-preview mode (isMultiMaskMode = false).
 */
function clearPointsAndMask() {
  isMultiMaskMode = false;
  lastPoints = null;
  // Drop any decode queued for the old points: with lastPoints now null there
  // is nothing left to decode, and running it would dereference null.
  decodePending = false;

  // querySelectorAll returns a static list, so removing while iterating is safe.
  for (const icon of document.querySelectorAll(".icon")) {
    icon.remove();
  }

  // Nothing to cut out without a mask.
  cutButton.disabled = true;

  maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);
}
|
clearButton.addEventListener("click", clearPointsAndMask);

/**
 * Forget the current image and everything derived from it, clear the points
 * and mask, and show the upload button again.
 *
 * NOTE(review): forcing isEncoding/isDecoding to false while an encode() or
 * decode() is still awaiting the model lets a second one start concurrently,
 * and an in-flight encode() will still assign imageEmbeddings when it finishes.
 */
function resetImage() {
  imageInput = imageProcessed = imageEmbeddings = null;
  isEncoding = isDecoding = false;

  clearPointsAndMask();

  cutButton.disabled = true;
  imageContainer.style.backgroundImage = "none";
  uploadButton.style.display = "flex";
  statusLabel.textContent = "Ready";
}

resetButton.addEventListener("click", resetImage);
|
|
|
/**
 * Load the image at `url`, show it, and compute its embeddings for decode().
 *
 * Ignored while another encode is running. Failures (bad image, model error)
 * are logged and reported in the status label. `isEncoding` is always reset,
 * so a failed encode does not block later ones.
 *
 * @param {string} url Image URL or data URL.
 */
async function encode(url) {
  if (isEncoding) return;
  isEncoding = true;
  statusLabel.textContent = "Extracting image embedding...";

  try {
    imageInput = await RawImage.fromURL(url);

    // Show the image and hide the upload prompt.
    imageContainer.style.backgroundImage = `url(${url})`;
    uploadButton.style.display = "none";
    cutButton.disabled = true;

    // Compute embeddings once; every decode() of this image reuses them.
    imageProcessed = await processor(imageInput);
    imageEmbeddings = await model.get_image_embeddings(imageProcessed);

    statusLabel.textContent = "Embedding extracted!";
  } catch (err) {
    console.error(err);
    statusLabel.textContent = `Failed to extract image embedding: ${err?.message ?? err}`;
  } finally {
    isEncoding = false;
  }
}
|
|
|
|
|
/** Read the selected file as a data URL and pass it to encode(). */
function handleFileSelected(event) {
  const selected = event.target.files[0];
  if (!selected) return;

  const reader = new FileReader();
  reader.addEventListener("load", () => encode(reader.result));
  reader.readAsDataURL(selected);
}

/** Encode the bundled example image instead of following the link. */
function loadExample(event) {
  event.preventDefault();
  encode(EXAMPLE_URL);
}

fileUpload.addEventListener("change", handleFileSelected);
example.addEventListener("click", loadExample);
|
|
|
|
|
/**
 * Left (button 0) or right (button 2) click on an encoded image adds a prompt
 * point and redecodes. The first click leaves hover-preview mode, starts a
 * fresh point list and enables the cut button. Each point gets a star icon
 * (label 1) or a cross icon (label 0) at its position.
 */
imageContainer.addEventListener("mousedown", (event) => {
  const isPromptButton = event.button === 0 || event.button === 2;
  if (!isPromptButton || !imageEmbeddings) return;

  if (!isMultiMaskMode) {
    lastPoints = [];
    isMultiMaskMode = true;
    cutButton.disabled = false;
  }

  const point = getPoint(event);
  lastPoints.push(point);

  const template = point.label === 1 ? starIcon : crossIcon;
  const marker = template.cloneNode();
  const [x, y] = point.position;
  marker.style.left = `${x * 100}%`;
  marker.style.top = `${y * 100}%`;
  imageContainer.appendChild(marker);

  decode();
});
|
|
|
|
|
/** Restrict `x` to the range [min, max] (by default the unit interval). */
function clamp(x, min = 0, max = 1) {
  const capped = Math.min(x, max);
  return Math.max(capped, min);
}

/**
 * Convert a mouse event over the image container into a prompt point.
 *
 * position: cursor offset inside the container, normalised to [0, 1] per axis
 *           (clamped, so positions just outside the edges snap to the border).
 * label:    0 for the right mouse button (button 2), 1 for any other button.
 */
function getPoint(e) {
  const { left, top, width, height } = imageContainer.getBoundingClientRect();
  const position = [
    clamp((e.clientX - left) / width),
    clamp((e.clientY - top) / height),
  ];
  const label = e.button === 2 ? 0 : 1;
  return { position, label };
}
|
|
|
|
|
// Suppress the browser context menu so right clicks can be used as prompt clicks.
imageContainer.addEventListener("contextmenu", (event) => event.preventDefault());

// Hover preview: until the user clicks to place points, decode a mask for the
// single point under the cursor.
imageContainer.addEventListener("mousemove", (event) => {
  const hoverPreviewActive = imageEmbeddings && !isMultiMaskMode;
  if (hoverPreviewActive) {
    lastPoints = [getPoint(event)];
    decode();
  }
});
|
|
|
|
|
// How long to keep the download's object URL alive before revoking it,
// giving the browser time to start reading the blob.
const DOWNLOAD_URL_REVOKE_DELAY_MS = 40000;

/**
 * Copy the source image's colour into every pixel that the RGBA mask buffer
 * marks as non-transparent (alpha > 0), modifying `maskPixels` in place.
 *
 * @param maskPixels  RGBA bytes (4 per pixel); also receives the result.
 * @param imagePixels Source image bytes with `channels` values per pixel.
 * @param channels    Values per source pixel. With fewer than 3 channels the
 *                    first one is replicated into R, G and B.
 * @param pixelCount  Number of pixels to process.
 */
function copyMaskedPixels(maskPixels, imagePixels, channels, pixelCount) {
  for (let i = 0; i < pixelCount; ++i) {
    const target = 4 * i;
    if (maskPixels[target + 3] > 0) {
      const source = channels * i;
      for (let j = 0; j < 3; ++j) {
        maskPixels[target + j] = imagePixels[source + (channels >= 3 ? j : 0)];
      }
    }
  }
}

/** Download the masked region of the image as a PNG with a transparent background. */
cutButton.addEventListener("click", async () => {
  if (!imageInput) return;

  const { width: w, height: h } = maskCanvas;

  // Reuse the mask's pixel buffer: its alpha channel marks which pixels to keep.
  const maskImageData = maskContext.getImageData(0, 0, w, h);
  // Use the image's real channel count instead of assuming RGB. Fall back to 3
  // (the previous assumption) if the image object does not report it.
  copyMaskedPixels(maskImageData.data, imageInput.data, imageInput.channels ?? 3, w * h);

  const cutCanvas = new OffscreenCanvas(w, h);
  cutCanvas.getContext("2d").putImageData(maskImageData, 0, 0);

  const url = URL.createObjectURL(await cutCanvas.convertToBlob());
  const link = document.createElement("a");
  link.download = "image.png";
  link.href = url;
  // Attach before clicking for browsers that ignore clicks on detached links.
  document.body.appendChild(link);
  link.click();
  link.remove();
  // Release the blob once the download has had time to start.
  setTimeout(() => URL.revokeObjectURL(url), DOWNLOAD_URL_REVOKE_DELAY_MS);
});
|
|
|
const model_id = "Xenova/slimsam-77-uniform";

/**
 * Load the SAM model (fp16 weights on the WebGPU device) and its processor in parallel.
 * On failure the error is shown in the status label and rethrown, so the
 * upload controls enabled below are never switched on.
 *
 * @param {string} modelId Hugging Face model id.
 * @returns {Promise<[object, object]>} [model, processor]
 */
async function loadModelAndProcessor(modelId) {
  statusLabel.textContent = "Loading model...";
  try {
    return await Promise.all([
      SamModel.from_pretrained(modelId, {
        dtype: "fp16",
        device: "webgpu",
      }),
      AutoProcessor.from_pretrained(modelId),
    ]);
  } catch (err) {
    statusLabel.textContent = `Failed to load model: ${err?.message ?? err}`;
    throw err;
  }
}

const [model, processor] = await loadModelAndProcessor(model_id);
statusLabel.textContent = "Ready";

// The model is usable: enable the image inputs.
fileUpload.disabled = false;
uploadButton.style.opacity = 1;
example.style.pointerEvents = "auto";
|
|