| | <!doctype html> |
| | <html lang="en"> |
| |
|
| | <head> |
| | <meta charset="UTF-8" /> |
| | <meta name="viewport" content="width=device-width, initial-scale=1.0" /> |
| | <title>SAM3 WebGPU | Transformers.js</title> |
| |
|
| | <script src="https://cdn.tailwindcss.com"></script> |
| |
|
<style>
  @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');

  body {
    font-family: 'Inter', sans-serif;
  }

  /* Mask overlay: sized and positioned over the displayed image from JS; never intercepts clicks. */
  canvas {
    position: absolute;
    top: 0;
    left: 0;
    opacity: 0.6;
    pointer-events: none;
  }

  /* Prompt-point markers, centred on their click position. */
  .icon {
    position: absolute;
    transform: translate(-50%, -50%);
    font-size: 24px;
    user-select: none;
    pointer-events: none;
    text-shadow: 0 0 4px white;
  }

  /*
   * Minimal aspect-ratio utilities (padding-bottom technique): the ::before
   * pseudo-element reserves a height of width * h / w (padding % is relative
   * to width) and the first child is stretched over that box.
   * The structural rules are shared by every width class so that any of the
   * w/h pairs declared below actually takes effect.
   */
  .aspect-w-3,
  .aspect-w-4 {
    position: relative;
    width: 100%;
  }

  .aspect-w-3::before,
  .aspect-w-4::before {
    content: '';
    display: block;
    padding-bottom: calc(var(--aspect-h) / var(--aspect-w) * 100%);
  }

  .aspect-w-3 > :first-child,
  .aspect-w-4 > :first-child {
    position: absolute;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
  }

  .aspect-w-3 {
    --aspect-w: 3;
  }

  .aspect-h-2 {
    --aspect-h: 2;
  }

  .aspect-w-4 {
    --aspect-w: 4;
  }

  .aspect-h-3 {
    --aspect-h: 3;
  }
</style>
| | </head> |
| |
|
| | <body class="bg-gray-100 text-gray-800 min-h-screen flex flex-col items-center justify-center p-4 sm:p-8"> |
| |
|
| | <div class="w-full max-w-3xl bg-white rounded-xl shadow-2xl overflow-hidden"> |
| |
|
| | <div class="p-6 sm:p-10"> |
| | <h1 class="text-3xl sm:text-4xl font-bold text-center text-gray-900">SAM3 WebGPU</h1> |
| | <h3 class="text-lg sm:text-xl text-gray-500 text-center mb-6"> |
| | In-browser image segmentation w/ |
| | <a href="https://hf.co/docs/transformers.js" target="_blank" class="text-blue-600 hover:underline">🤗 |
| | Transformers.js</a> |
| | </h3> |
| |
|
| | <div id="container" |
| | class="relative w-full max-w-2xl mx-auto border border-gray-200 rounded-lg overflow-hidden cursor-pointer bg-gray-50 shadow-sm transition-all aspect-w-3 aspect-h-2"> |
| |
|
| | <label id="upload-area" for="upload" |
| | class="absolute inset-0 z-10 flex flex-col justify-center items-center p-10 transition-all cursor-pointer border-2 border-dashed border-gray-300 rounded-lg hover:bg-gray-50/50 hover:border-blue-500"> |
| | <div class="flex flex-col items-center justify-center p-6 transition-colors w-full max-w-sm"> |
| | <svg class="w-12 h-12 text-gray-400" fill="currentColor" viewBox="0 0 25 25" |
| | xmlns="http://www.w3.org/2000/svg"> |
| | <path |
| | d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"> |
| | </path> |
| | </svg> |
| | <span class="text-lg font-medium text-gray-500 mt-2">Click to upload image</span> |
| | <span class="text-sm text-gray-400">or drag and drop</span> |
| | </div> |
| |
|
| | <p class="text-gray-500 text-sm my-4">...or try an example:</p> |
| |
|
| | <div id="example-gallery" class="flex gap-4"> |
| | <img src="https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg" |
| | class="example-image w-20 h-20 sm:w-24 sm:h-24 object-cover rounded-lg shadow-md cursor-pointer hover:opacity-80 transition-opacity" |
| | alt="Example of a truck"> |
| | <img src="https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/corgi.jpg" |
| | class="example-image w-20 h-20 sm:w-24 sm:h-24 object-cover rounded-lg shadow-md cursor-pointer hover:opacity-80 transition-opacity" |
| | alt="Example of a corgi"> |
| | <img src="https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/groceries.jpg" |
| | class="example-image w-20 h-20 sm:w-24 sm:h-24 object-cover rounded-lg shadow-md cursor-pointer hover:opacity-80 transition-opacity" |
| | alt="Example of groceries"> |
| | </div> |
| | </label> |
| |
|
| | <img id="image-display" class="absolute inset-0 w-full h-full object-contain block hidden z-0" /> |
| |
|
| | <canvas id="mask-output"></canvas> |
| | </div> |
| |
|
<!-- Status line, updated from JS. role="status" makes assistive technology announce changes. -->
<p id="status" role="status" aria-live="polite"
  class="text-base text-center text-gray-600 min-h-[1.5rem] mt-6 mb-4 block w-full">Loading model...</p>
| |
|
| | <div id="controls" class="flex flex-col sm:flex-row justify-center gap-3"> |
| | <button id="reset-image" disabled |
| | class="w-full sm:w-auto inline-flex items-center justify-center px-4 py-2 bg-gray-200 text-gray-800 font-medium rounded-lg shadow-sm hover:bg-gray-300 transition-colors disabled:bg-gray-100 disabled:text-gray-400 disabled:cursor-not-allowed"> |
| | <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" |
| | stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" |
| | class="w-4 h-4 mr-2"> |
| | <path d="M3 12a9 9 0 1 0 9-9 9.75 9.75 0 0 0-6.74 2.74L3 8" /> |
| | <path d="M3 3v5h5" /> |
| | </svg> |
| | Reset image |
| | </button> |
| | <button id="clear-points" disabled |
| | class="w-full sm:w-auto inline-flex items-center justify-center px-4 py-2 bg-gray-200 text-gray-800 font-medium rounded-lg shadow-sm hover:bg-gray-300 transition-colors disabled:bg-gray-100 disabled:text-gray-400 disabled:cursor-not-allowed"> |
| | <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" |
| | stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" |
| | class="w-4 h-4 mr-2"> |
| | <line x1="18" y1="6" x2="6" y2="18"></line> |
| | <line x1="6" y1="6" x2="18" y2="18"></line> |
| | </svg> |
| | Clear points |
| | </button> |
| | <button id="cut-mask" disabled |
| | class="w-full sm:w-auto inline-flex items-center justify-center px-4 py-2 bg-blue-600 text-white font-medium rounded-lg shadow-sm hover:bg-blue-700 transition-colors disabled:bg-gray-300 disabled:text-gray-500 disabled:cursor-not-allowed"> |
| | <svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" |
| | stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" |
| | class="w-4 h-4 mr-2"> |
| | <path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4" /> |
| | <polyline points="7 10 12 15 17 10" /> |
| | <line x1="12" y1="15" x2="12" y2="3" /> |
| | </svg> |
| | Cut & Download |
| | </button> |
| | </div> |
| |
|
| | <p id="information" class="text-sm text-gray-500 mt-4 text-center"> |
| | Left click = positive (⭐), Right click = negative (❌). |
| | </p> |
| | </div> |
| | </div> |
| |
|
| | <input id="upload" type="file" accept="image/*" disabled class="hidden" /> |
| |
|
| |
|
| | <script type="module"> |
| | import { |
| | Sam3TrackerModel, |
| | AutoProcessor, |
| | RawImage, |
| | Tensor, |
| | } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@next"; |
| | |
// --- DOM references ---------------------------------------------------------
const statusLabel = document.getElementById("status");
const fileUpload = document.getElementById("upload");
const imageContainer = document.getElementById("container");
const uploadArea = document.getElementById("upload-area");
const imageDisplay = document.getElementById("image-display");
const resetButton = document.getElementById("reset-image");
const clearButton = document.getElementById("clear-points");
const cutButton = document.getElementById("cut-mask");
const exampleImages = document.querySelectorAll(".example-image");
const maskCanvas = document.getElementById("mask-output");
const maskContext = maskCanvas.getContext("2d");

// --- Mutable application state ------------------------------------------------
// Busy flags so only one encode / decode runs at a time; decodePending records
// that a decode was requested while another one was still running.
let isEncoding = false;
let isDecoding = false;
let decodePending = false;

// Current prompt: array of { position: [x, y] in [0, 1], label: 1 | 0 }, or null.
let lastPoints = null;
// Becomes true on the first click; while true, hover previews are disabled.
let isMultiMaskMode = false;

// The loaded image, its processor output, and its cached image embeddings.
let imageInput = null;
let imageProcessed = null;
let imageEmbeddings = null;

// Filled in by loadModel().
let model;
let processor;
| | |
| | |
| | |
| | |
/**
 * Loads the image at `url`, displays it, and computes its image embeddings,
 * caching them in `imageEmbeddings` so every later click only needs the
 * prompt decoder.
 * Does nothing while another encode is running or before the model and
 * processor have loaded.
 *
 * @param {string} url - Image URL (an example's src or a FileReader data URL).
 */
async function encode(url) {
  if (isEncoding || !processor || !model) return;
  isEncoding = true;
  statusLabel.textContent = "Extracting image embedding...";

  try {
    imageInput = await RawImage.fromURL(url);

    // Size the mask overlay once the <img> knows its natural dimensions.
    imageDisplay.onload = updateCanvasGeometry;
    imageDisplay.src = url;
    imageDisplay.classList.remove("hidden");
    uploadArea.classList.add("hidden");
    cutButton.disabled = true;

    imageProcessed = await processor(imageInput);
    imageEmbeddings = await model.get_image_embeddings(imageProcessed);

    statusLabel.textContent = "Embedding extracted! Click on the image.";
    resetButton.disabled = false;
    clearButton.disabled = false;
  } catch (error) {
    console.error("Error during encoding:", error);
    // Reset first: resetUI() sets the status to "Ready", which would otherwise
    // overwrite this error message before the user ever sees it.
    resetUI();
    statusLabel.textContent = "Error loading image. Please try again.";
  } finally {
    isEncoding = false;
  }
}
| | |
| | |
| | |
| | |
/**
 * Runs the prompt encoder / mask decoder for the current `lastPoints` against
 * the cached image embeddings, then draws the best-scoring mask.
 *
 * Calls made while a decode is in flight are coalesced: `decodePending` is set,
 * and once this decode finishes one follow-up decode runs with the points
 * current at that time.
 */
async function decode() {
  if (isDecoding) {
    decodePending = true;
    return;
  }
  if (!imageEmbeddings || !lastPoints || lastPoints.length === 0) return;

  isDecoding = true;

  // Snapshot shared state. resetUI() / clearPointsAndMask() may null these out
  // while the model runs; the snapshot lets us detect that and avoids
  // dereferencing null after the awaits.
  const embeddings = imageEmbeddings;
  const processed = imageProcessed;
  const prompt = lastPoints;

  try {
    // Points are normalised to [0, 1]; scale to the processor's resized input,
    // whose size is stored as [height, width].
    const [inputHeight, inputWidth] = processed.reshaped_input_sizes[0];
    const points = prompt.flatMap(({ position: [x, y] }) => [x * inputWidth, y * inputHeight]);
    const labels = prompt.map(({ label }) => BigInt(label));

    const numPoints = prompt.length;
    const input_points = new Tensor("float32", points, [1, 1, numPoints, 2]);
    const input_labels = new Tensor("int64", labels, [1, 1, numPoints]);

    const { pred_masks, iou_scores } = await model({
      ...embeddings,
      input_points,
      input_labels,
    });

    const masks = await processor.post_process_masks(
      pred_masks,
      processed.original_sizes,
      processed.reshaped_input_sizes,
    );

    // Skip drawing if the image was reset/replaced or the points were cleared
    // while the model was running; otherwise a stale mask would reappear.
    if (embeddings === imageEmbeddings && lastPoints !== null) {
      updateMaskOverlay(RawImage.fromTensor(masks[0][0]), iou_scores.data);
    }
  } catch (error) {
    console.error("Error during decoding:", error);
    statusLabel.textContent = "Error generating mask.";
  } finally {
    isDecoding = false;
  }

  if (decodePending) {
    decodePending = false;
    setTimeout(decode, 0);
  }
}
| | |
| | |
| | |
| | |
/**
 * Sizes and positions the mask canvas (via CSS) so it exactly covers the
 * visible image. The <img> is object-contain'ed inside the container, so the
 * image is scaled to fit one axis and centred along the other.
 */
function updateCanvasGeometry() {
  if (!imageDisplay.src || imageDisplay.classList.contains("hidden")) return;

  const { naturalWidth, naturalHeight } = imageDisplay;
  const { width: containerWidth, height: containerHeight } = imageContainer.getBoundingClientRect();

  // Nothing to fit yet (image not decoded or container collapsed); bail out
  // instead of producing NaN/Infinity sizes.
  if (!naturalWidth || !naturalHeight || !containerWidth || !containerHeight) return;

  const imageAspectRatio = naturalWidth / naturalHeight;
  const containerAspectRatio = containerWidth / containerHeight;

  let newWidth;
  let newHeight;
  let newTop;
  let newLeft;

  if (imageAspectRatio > containerAspectRatio) {
    // Wider than the container: full width, letterboxed vertically.
    newWidth = containerWidth;
    newHeight = newWidth / imageAspectRatio;
    newTop = (containerHeight - newHeight) / 2;
    newLeft = 0;
  } else {
    // Taller (or equal): full height, pillarboxed horizontally.
    newHeight = containerHeight;
    newWidth = newHeight * imageAspectRatio;
    newLeft = (containerWidth - newWidth) / 2;
    newTop = 0;
  }

  maskCanvas.style.width = `${newWidth}px`;
  maskCanvas.style.height = `${newHeight}px`;
  maskCanvas.style.top = `${newTop}px`;
  maskCanvas.style.left = `${newLeft}px`;
}
| | |
| | |
| | |
| | |
/**
 * Paints the highest-scoring candidate mask onto the overlay canvas and shows
 * its score in the status line.
 *
 * @param {RawImage} mask - Mask image whose channels are the candidate masks,
 *   interleaved per pixel; a value of 1 marks a pixel inside the mask.
 * @param {ArrayLike<number>} scores - One score per candidate mask.
 */
function updateMaskOverlay(mask, scores) {
  if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
    maskCanvas.width = mask.width;
    maskCanvas.height = mask.height;
  }

  const imageData = maskContext.createImageData(maskCanvas.width, maskCanvas.height);

  // Pick the candidate with the highest score.
  let bestIndex = 0;
  for (let i = 1; i < scores.length; ++i) {
    if (scores[i] > scores[bestIndex]) bestIndex = i;
  }
  statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;

  // Candidate c of pixel p lives at p * channels + c. Iterate pixels, not
  // RGBA bytes: pixelData.length is 4x the pixel count.
  const channels = mask.channels ?? scores.length;
  const pixelData = imageData.data;
  const numPixels = maskCanvas.width * maskCanvas.height;
  for (let p = 0; p < numPixels; ++p) {
    if (mask.data[channels * p + bestIndex] === 1) {
      const offset = 4 * p;
      pixelData[offset] = 0;
      pixelData[offset + 1] = 114;
      pixelData[offset + 2] = 189;
      pixelData[offset + 3] = 255;
    }
  }

  maskContext.putImageData(imageData, 0, 0);
}
| | |
/**
 * Drops every prompt point and its on-image marker, erases the mask overlay,
 * and returns to hover-preview mode.
 */
function clearPointsAndMask() {
  // No prompt, back to hover mode.
  lastPoints = null;
  isMultiMaskMode = false;

  for (const marker of document.querySelectorAll(".icon")) {
    marker.remove();
  }

  maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);
  cutButton.disabled = true;
  statusLabel.textContent = "Points cleared. Click to add new points.";
}
| | |
/**
 * Returns the page to its initial "pick an image" state.
 * It drops the cached image and embeddings, clears points and mask, disables
 * the image controls, and shows the upload area and examples again.
 */
function resetUI() {
  imageInput = null;
  imageProcessed = null;
  imageEmbeddings = null;
  isEncoding = false;
  isDecoding = false;
  decodePending = false;

  clearPointsAndMask();

  cutButton.disabled = true;
  resetButton.disabled = true;
  clearButton.disabled = true;
  imageDisplay.src = "";
  imageDisplay.classList.add("hidden");
  uploadArea.classList.remove("hidden");

  // Clear the file input. Otherwise choosing the same file again would not
  // fire a "change" event, and nothing would happen.
  fileUpload.value = "";

  // Collapse the overlay until the next image is laid out.
  maskCanvas.style.width = "0px";
  maskCanvas.style.height = "0px";

  exampleImages.forEach((img) => {
    img.style.pointerEvents = "auto";
  });

  statusLabel.textContent = "Ready";
}
/** Limits `x` to at most `max`, then to at least `min` (defaults: the unit interval). */
function clamp(x, min = 0, max = 1) {
  const upperBounded = Math.min(x, max);
  return Math.max(upperBounded, min);
}
/**
 * Converts a mouse event into a prompt point.
 *
 * @param {MouseEvent} e
 * @returns {{ position: [number, number], label: number }} `position` is the
 *   pointer location relative to the mask canvas, each coordinate clamped to
 *   [0, 1]. `label` is 0 for a right click (negative point), 1 otherwise.
 */
function getPoint(e) {
  const { left, top, width, height } = maskCanvas.getBoundingClientRect();

  const mouseX = clamp((e.clientX - left) / width);
  const mouseY = clamp((e.clientY - top) / height);

  return {
    position: [mouseX, mouseY],
    label: e.button === 2 ? 0 : 1,
  };
}
| | |
/**
 * Reads a user-supplied File as a data URL and hands it to encode().
 * Shared by the file picker and drag-and-drop.
 *
 * @param {File | undefined} file
 */
function loadImageFile(file) {
  if (!file) return;
  // Reject clearly non-image files. An empty type is let through, since the
  // browser cannot always determine one; encode() reports decode failures.
  if (file.type && !file.type.startsWith("image/")) {
    statusLabel.textContent = "Please choose an image file.";
    return;
  }

  const reader = new FileReader();
  reader.onload = () => encode(reader.result);
  reader.onerror = () => {
    console.error("Error reading file:", reader.error);
    statusLabel.textContent = "Error loading image. Please try again.";
  };
  reader.readAsDataURL(file);
}

fileUpload.addEventListener("change", (e) => loadImageFile(e.target.files[0]));

// The upload area advertises "or drag and drop", so accept dropped files too.
uploadArea.addEventListener("dragover", (e) => {
  e.preventDefault(); // required for the element to be a valid drop target
});
uploadArea.addEventListener("drop", (e) => {
  e.preventDefault(); // stop the browser from navigating to the dropped file
  // The picker stays disabled until the model is ready; mirror that here.
  if (fileUpload.disabled) return;
  loadImageFile(e.dataTransfer?.files?.[0]);
});
| | |
// Clicking an example image loads it directly.
exampleImages.forEach((img) => {
  img.addEventListener("click", (e) => {
    // The examples sit inside the upload <label>; keep the click from also
    // opening the file picker.
    e.preventDefault();
    e.stopPropagation();

    // encode() ignores calls before the model is ready or while it is busy.
    // Bail out here too: otherwise the examples would be disabled with nothing
    // loading, and they would stay disabled because only resetUI() re-enables them.
    if (isEncoding || !model || !processor) return;

    exampleImages.forEach((example) => {
      example.style.pointerEvents = "none";
    });
    encode(img.src);
  });
});
| | |
// Keep the overlay aligned with the image on resize, and wire up the toolbar.
const staticBindings = [
  [window, "resize", updateCanvasGeometry],
  [resetButton, "click", resetUI],
  [clearButton, "click", clearPointsAndMask],
];
for (const [target, type, handler] of staticBindings) {
  target.addEventListener(type, handler);
}
| | |
// Left click adds a positive point and right click a negative one. The first
// click leaves hover-preview mode; later clicks add to the same prompt.
imageContainer.addEventListener("mousedown", (e) => {
  // Only once an image is embedded and the upload area is out of the way.
  if (!imageEmbeddings || !uploadArea.classList.contains("hidden")) return;
  if (e.button !== 0 && e.button !== 2) return;

  if (!isMultiMaskMode) {
    lastPoints = [];
    isMultiMaskMode = true;
    cutButton.disabled = false;
  }

  const point = getPoint(e);
  lastPoints.push(point);

  const icon = document.createElement("span");
  icon.className = "icon";
  icon.textContent = point.label === 1 ? "⭐" : "❌";

  // Absolute children are positioned against the container's padding box,
  // which sits inside its border (clientLeft/clientTop) and measures
  // clientWidth x clientHeight.
  const canvasRect = maskCanvas.getBoundingClientRect();
  const containerRect = imageContainer.getBoundingClientRect();
  const left =
    canvasRect.left - containerRect.left - imageContainer.clientLeft + point.position[0] * canvasRect.width;
  const top =
    canvasRect.top - containerRect.top - imageContainer.clientTop + point.position[1] * canvasRect.height;

  // Use percentages, not pixels. The container keeps a fixed aspect ratio and
  // the image is object-contain'ed inside it, so the marker stays on its point
  // when the window is resized.
  icon.style.left = `${(100 * left) / imageContainer.clientWidth}%`;
  icon.style.top = `${(100 * top) / imageContainer.clientHeight}%`;
  imageContainer.appendChild(icon);

  decode();
});
| | |
// Right click places negative points, so suppress the browser's context menu.
imageContainer.addEventListener("contextmenu", (event) => {
  event.preventDefault();
});

// Hover preview: until the first click, decode a single point that follows
// the cursor.
imageContainer.addEventListener("mousemove", (event) => {
  const canPreview = imageEmbeddings && !isMultiMaskMode && uploadArea.classList.contains("hidden");
  if (!canPreview) return;

  lastPoints = [getPoint(event)];
  decode();
});
| | |
// Export the masked region of the original image as a PNG with a transparent
// background.
cutButton.addEventListener("click", async () => {
  if (!imageInput) return;

  const { width: w, height: h } = maskCanvas;
  // The overlay only matches the image once a mask has been drawn for it.
  // Before that, the canvas may still have its default size or a previous image's.
  if (w !== imageInput.width || h !== imageInput.height) {
    statusLabel.textContent = "No mask to cut yet.";
    return;
  }

  const maskPixelData = maskContext.getImageData(0, 0, w, h).data;

  const cutCanvas = new OffscreenCanvas(w, h);
  const cutContext = cutCanvas.getContext("2d");
  const cutImageData = cutContext.createImageData(w, h);
  const cutPixelData = cutImageData.data;

  // Image pixels are interleaved with `channels` values each. Use the actual
  // count instead of hard-coding 3 (RGB).
  const imagePixelData = imageInput.data;
  const channels = imageInput.channels ?? 3;

  for (let i = 0; i < w * h; ++i) {
    const maskOffset = 4 * i;
    // Outside the mask: leave the pixel fully transparent.
    if (maskPixelData[maskOffset + 3] === 0) continue;

    const imageOffset = channels * i;
    cutPixelData[maskOffset] = imagePixelData[imageOffset];
    cutPixelData[maskOffset + 1] = imagePixelData[imageOffset + 1];
    cutPixelData[maskOffset + 2] = imagePixelData[imageOffset + 2];
    cutPixelData[maskOffset + 3] = 255;
  }
  cutContext.putImageData(cutImageData, 0, 0);

  const url = URL.createObjectURL(await cutCanvas.convertToBlob());
  const link = document.createElement("a");
  link.download = "mask-cutout.png";
  link.href = url;
  link.click();
  // Release the blob once the download has been handed off, instead of
  // leaking one object URL per export.
  setTimeout(() => URL.revokeObjectURL(url), 0);
});
| | |
/**
 * Downloads the SAM3 tracker model and its processor in parallel, and stores
 * them in the module-level `model` / `processor`.
 *
 * @returns {Promise<boolean>} true on success. On failure the error is logged,
 *   shown in the status line, and false is returned.
 */
async function loadModel() {
  const model_id = "onnx-community/sam3-tracker-ONNX";
  try {
    const [loadedModel, loadedProcessor] = await Promise.all([
      Sam3TrackerModel.from_pretrained(model_id, {
        dtype: {
          vision_encoder: "q4",
          prompt_encoder_mask_decoder: "fp32",
        },
        device: "webgpu",
      }),
      AutoProcessor.from_pretrained(model_id),
    ]);
    model = loadedModel;
    processor = loadedProcessor;
    return true;
  } catch (error) {
    console.error("Error loading model:", error);
    statusLabel.textContent = "Error loading model. Please refresh the page.";
    return false;
  }
}

// Enable uploads only once the model is usable. On failure, leave the error
// message visible instead of overwriting it with "Ready".
if (await loadModel()) {
  fileUpload.disabled = false;
  statusLabel.textContent = "Ready";
}
| | |
| | </script> |
| | </body> |
| |
|
| | </html> |