Spaces:
Running
on
L40S
Running
on
L40S
import React, { useState, useEffect, useRef, useCallback } from 'react'; | |
import * as vision from '@mediapipe/tasks-vision'; | |
import { facePoke } from '@/lib/facePoke'; | |
import { useMainStore } from './useMainStore'; | |
import useThrottledCallback from 'beautiful-react-hooks/useThrottledCallback'; | |
import { landmarkGroups, FACEMESH_LIPS, FACEMESH_LEFT_EYE, FACEMESH_LEFT_EYEBROW, FACEMESH_RIGHT_EYE, FACEMESH_RIGHT_EYEBROW, FACEMESH_FACE_OVAL } from './landmarks'; | |
import type { ActionMode, ClosestLandmark, LandmarkCenter, LandmarkGroup, MediaPipeResources } from '@/types'; | |
export function useFaceLandmarkDetection() { | |
const setError = useMainStore(s => s.setError); | |
const previewImage = useMainStore(s => s.previewImage); | |
const handleServerResponse = useMainStore(s => s.handleServerResponse); | |
const faceLandmarks = useMainStore(s => s.faceLandmarks); | |
//////////////////////////////////////////////////////////////////////// | |
// if we only send the face/square then we can use 138ms | |
// unfortunately it doesn't work well yet | |
// const throttleInMs = 138ms | |
const throttleInMs = 220 | |
//////////////////////////////////////////////////////////////////////// | |
// State for face detection | |
const [isMediaPipeReady, setIsMediaPipeReady] = useState(false); | |
const [isDrawingUtilsReady, setIsDrawingUtilsReady] = useState(false); | |
// State for mouse interaction | |
const [dragStart, setDragStart] = useState<{ x: number; y: number } | null>(null); | |
const [isDragging, setIsDragging] = useState(false); | |
const dragStartRef = useRef<{ x: number; y: number } | null>(null); | |
const [currentLandmark, setCurrentLandmark] = useState<ClosestLandmark | null>(null); | |
const [previousLandmark, setPreviousLandmark] = useState<ClosestLandmark | null>(null); | |
const [currentOpacity, setCurrentOpacity] = useState(0); | |
const [previousOpacity, setPreviousOpacity] = useState(0); | |
// Refs | |
const canvasRef = useRef<HTMLCanvasElement>(null); | |
const mediaPipeRef = useRef<MediaPipeResources>({ | |
faceLandmarker: null, | |
drawingUtils: null, | |
}); | |
const setActiveLandmark = useCallback((newLandmark: ClosestLandmark | undefined) => { | |
//if (newLandmark && (!currentLandmark || newLandmark.group !== currentLandmark.group)) { | |
setPreviousLandmark(currentLandmark || null); | |
setCurrentLandmark(newLandmark || null); | |
setCurrentOpacity(0); | |
setPreviousOpacity(1); | |
//} | |
}, [currentLandmark, setPreviousLandmark, setCurrentLandmark, setCurrentOpacity, setPreviousOpacity]); | |
// Initialize MediaPipe | |
useEffect(() => { | |
console.log('Initializing MediaPipe...'); | |
let isMounted = true; | |
const initializeMediaPipe = async () => { | |
const { FaceLandmarker, FilesetResolver, DrawingUtils } = vision; | |
try { | |
console.log('Initializing FilesetResolver...'); | |
const filesetResolver = await FilesetResolver.forVisionTasks( | |
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@0.10.3/wasm" | |
); | |
console.log('Creating FaceLandmarker...'); | |
const faceLandmarker = await FaceLandmarker.createFromOptions(filesetResolver, { | |
baseOptions: { | |
modelAssetPath: `https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task`, | |
delegate: "GPU" | |
}, | |
outputFaceBlendshapes: true, | |
runningMode: "IMAGE", | |
numFaces: 1 | |
}); | |
if (isMounted) { | |
console.log('FaceLandmarker created successfully.'); | |
mediaPipeRef.current.faceLandmarker = faceLandmarker; | |
setIsMediaPipeReady(true); | |
} else { | |
faceLandmarker.close(); | |
} | |
} catch (error) { | |
console.error('Error during MediaPipe initialization:', error); | |
setError('Failed to initialize face detection. Please try refreshing the page.'); | |
} | |
}; | |
initializeMediaPipe(); | |
return () => { | |
isMounted = false; | |
if (mediaPipeRef.current.faceLandmarker) { | |
mediaPipeRef.current.faceLandmarker.close(); | |
} | |
}; | |
}, []); | |
// New state for storing landmark centers | |
const [landmarkCenters, setLandmarkCenters] = useState<Record<LandmarkGroup, LandmarkCenter>>({} as Record<LandmarkGroup, LandmarkCenter>); | |
// Function to compute the center of each landmark group | |
const computeLandmarkCenters = useCallback((landmarks: vision.NormalizedLandmark[]) => { | |
const centers: Record<LandmarkGroup, LandmarkCenter> = {} as Record<LandmarkGroup, LandmarkCenter>; | |
const computeGroupCenter = (group: Readonly<Set<number[]>>): LandmarkCenter => { | |
let sumX = 0, sumY = 0, sumZ = 0, count = 0; | |
group.forEach(([index]) => { | |
if (landmarks[index]) { | |
sumX += landmarks[index].x; | |
sumY += landmarks[index].y; | |
sumZ += landmarks[index].z || 0; | |
count++; | |
} | |
}); | |
return { x: sumX / count, y: sumY / count, z: sumZ / count }; | |
}; | |
centers.lips = computeGroupCenter(FACEMESH_LIPS); | |
centers.leftEye = computeGroupCenter(FACEMESH_LEFT_EYE); | |
centers.leftEyebrow = computeGroupCenter(FACEMESH_LEFT_EYEBROW); | |
centers.rightEye = computeGroupCenter(FACEMESH_RIGHT_EYE); | |
centers.rightEyebrow = computeGroupCenter(FACEMESH_RIGHT_EYEBROW); | |
centers.faceOval = computeGroupCenter(FACEMESH_FACE_OVAL); | |
centers.background = { x: 0.5, y: 0.5, z: 0 }; | |
setLandmarkCenters(centers); | |
// console.log('Landmark centers computed:', centers); | |
}, []); | |
// Function to find the closest landmark to the mouse position | |
const findClosestLandmark = useCallback((mouseX: number, mouseY: number, isGroup?: LandmarkGroup): ClosestLandmark => { | |
const defaultLandmark: ClosestLandmark = { | |
group: 'background', | |
distance: 0, | |
vector: { | |
x: mouseX, | |
y: mouseY, | |
z: 0 | |
} | |
} | |
if (Object.keys(landmarkCenters).length === 0) { | |
console.warn('Landmark centers not computed yet'); | |
return defaultLandmark; | |
} | |
let closestGroup: LandmarkGroup | null = null; | |
let minDistance = Infinity; | |
let closestVector = { x: 0, y: 0, z: 0 }; | |
let faceOvalDistance = Infinity; | |
let faceOvalVector = { x: 0, y: 0, z: 0 }; | |
Object.entries(landmarkCenters).forEach(([group, center]) => { | |
const dx = mouseX - center.x; | |
const dy = mouseY - center.y; | |
const distance = Math.sqrt(dx * dx + dy * dy); | |
if (group === 'faceOval') { | |
faceOvalDistance = distance; | |
faceOvalVector = { x: dx, y: dy, z: 0 }; | |
} | |
// filter to keep the group if it is belonging to `ofGroup` | |
if (isGroup) { | |
if (group !== isGroup) { | |
return | |
} | |
} | |
if (distance < minDistance) { | |
minDistance = distance; | |
closestGroup = group as LandmarkGroup; | |
closestVector = { x: dx, y: dy, z: 0 }; // Z is 0 as mouse interaction is 2D | |
} | |
}); | |
// Fallback to faceOval if no group found or distance is too large | |
if (minDistance > 0.05) { | |
// console.log('Distance is too high, so we use the faceOval group'); | |
closestGroup = 'background'; | |
minDistance = faceOvalDistance; | |
closestVector = faceOvalVector; | |
} | |
if (closestGroup) { | |
// console.log(`Closest landmark: ${closestGroup}, distance: ${minDistance.toFixed(4)}`); | |
return { group: closestGroup, distance: minDistance, vector: closestVector }; | |
} else { | |
// console.log('No group found, returning fallback'); | |
return defaultLandmark | |
} | |
}, [landmarkCenters]); | |
// Detect face landmarks | |
const detectFaceLandmarks = useCallback(async (imageDataUrl: string) => { | |
const { setFaceLandmarks,setBlendShapes } = useMainStore.getState(); | |
// console.log('Attempting to detect face landmarks...'); | |
if (!isMediaPipeReady) { | |
console.log('MediaPipe not ready. Skipping detection.'); | |
return; | |
} | |
const faceLandmarker = mediaPipeRef.current.faceLandmarker; | |
if (!faceLandmarker) { | |
console.error('FaceLandmarker is not initialized.'); | |
return; | |
} | |
const drawingUtils = mediaPipeRef.current.drawingUtils; | |
const image = new Image(); | |
image.src = imageDataUrl; | |
await new Promise((resolve) => { image.onload = resolve; }); | |
const faceLandmarkerResult = faceLandmarker.detect(image); | |
// console.log("Face landmarks detected:", faceLandmarkerResult); | |
setFaceLandmarks(faceLandmarkerResult.faceLandmarks); | |
setBlendShapes(faceLandmarkerResult.faceBlendshapes || []); | |
if (faceLandmarkerResult.faceLandmarks && faceLandmarkerResult.faceLandmarks[0]) { | |
computeLandmarkCenters(faceLandmarkerResult.faceLandmarks[0]); | |
} | |
if (canvasRef.current && drawingUtils) { | |
drawLandmarks(faceLandmarkerResult.faceLandmarks[0], canvasRef.current, drawingUtils); | |
} | |
}, [isMediaPipeReady, isDrawingUtilsReady, computeLandmarkCenters]); | |
const drawLandmarks = useCallback(( | |
landmarks: vision.NormalizedLandmark[], | |
canvas: HTMLCanvasElement, | |
drawingUtils: vision.DrawingUtils | |
) => { | |
const ctx = canvas.getContext('2d'); | |
if (!ctx) return; | |
ctx.clearRect(0, 0, canvas.width, canvas.height); | |
if (canvasRef.current && previewImage) { | |
const img = new Image(); | |
img.onload = () => { | |
canvas.width = img.width; | |
canvas.height = img.height; | |
const drawLandmarkGroup = (landmark: ClosestLandmark | null, opacity: number) => { | |
if (!landmark) return; | |
const connections = landmarkGroups[landmark.group]; | |
if (connections) { | |
ctx.globalAlpha = opacity; | |
drawingUtils.drawConnectors( | |
landmarks, | |
connections, | |
{ color: 'orange', lineWidth: 4 } | |
); | |
} | |
}; | |
drawLandmarkGroup(previousLandmark, previousOpacity); | |
drawLandmarkGroup(currentLandmark, currentOpacity); | |
ctx.globalAlpha = 1; | |
}; | |
img.src = previewImage; | |
} | |
}, [previewImage, currentLandmark, previousLandmark, currentOpacity, previousOpacity]); | |
useEffect(() => { | |
if (isMediaPipeReady && isDrawingUtilsReady && faceLandmarks.length > 0 && canvasRef.current && mediaPipeRef.current.drawingUtils) { | |
drawLandmarks(faceLandmarks[0], canvasRef.current, mediaPipeRef.current.drawingUtils); | |
} | |
}, [isMediaPipeReady, isDrawingUtilsReady, faceLandmarks, currentLandmark, previousLandmark, currentOpacity, previousOpacity, drawLandmarks]); | |
useEffect(() => { | |
let animationFrame: number; | |
const animate = () => { | |
setCurrentOpacity((prev) => Math.min(prev + 0.2, 1)); | |
setPreviousOpacity((prev) => Math.max(prev - 0.2, 0)); | |
if (currentOpacity < 1 || previousOpacity > 0) { | |
animationFrame = requestAnimationFrame(animate); | |
} | |
}; | |
animationFrame = requestAnimationFrame(animate); | |
return () => cancelAnimationFrame(animationFrame); | |
}, [currentLandmark]); | |
// Canvas ref callback | |
const canvasRefCallback = useCallback((node: HTMLCanvasElement | null) => { | |
if (node !== null) { | |
const ctx = node.getContext('2d'); | |
if (ctx) { | |
// Get device pixel ratio | |
const pixelRatio = window.devicePixelRatio || 1; | |
// Scale canvas based on the pixel ratio | |
node.width = node.clientWidth * pixelRatio; | |
node.height = node.clientHeight * pixelRatio; | |
ctx.scale(pixelRatio, pixelRatio); | |
mediaPipeRef.current.drawingUtils = new vision.DrawingUtils(ctx); | |
setIsDrawingUtilsReady(true); | |
} else { | |
console.error('Failed to get 2D context from canvas.'); | |
} | |
canvasRef.current = node; | |
} | |
}, []); | |
useEffect(() => { | |
if (!isMediaPipeReady) { | |
console.log('MediaPipe not ready. Skipping landmark detection.'); | |
return | |
} | |
if (!previewImage) { | |
console.log('Preview image not ready. Skipping landmark detection.'); | |
return | |
} | |
if (!isDrawingUtilsReady) { | |
console.log('DrawingUtils not ready. Skipping landmark detection.'); | |
return | |
} | |
detectFaceLandmarks(previewImage); | |
}, [isMediaPipeReady, isDrawingUtilsReady, previewImage]) | |
const modifyImageWithRateLimit = useThrottledCallback((params: { | |
landmark: ClosestLandmark | |
vector: { x: number; y: number; z: number } | |
mode: ActionMode | |
}) => { | |
useMainStore.getState().modifyImage(params); | |
}, [], throttleInMs); | |
useEffect(() => { | |
facePoke.setOnServerResponse(handleServerResponse); | |
}, [handleServerResponse]); | |
const handleStart = useCallback((x: number, y: number, mode: ActionMode) => { | |
if (!canvasRef.current) return; | |
const rect = canvasRef.current.getBoundingClientRect(); | |
const normalizedX = (x - rect.left) / rect.width; | |
const normalizedY = (y - rect.top) / rect.height; | |
const landmark = findClosestLandmark(normalizedX, normalizedY); | |
// console.log(`Interaction start on ${landmark.group}`); | |
setActiveLandmark(landmark); | |
setDragStart({ x: normalizedX, y: normalizedY }); | |
dragStartRef.current = { x: normalizedX, y: normalizedY }; | |
}, [findClosestLandmark, setActiveLandmark, setDragStart]); | |
const handleMove = useCallback((x: number, y: number, mode: ActionMode) => { | |
if (!canvasRef.current) return; | |
const rect = canvasRef.current.getBoundingClientRect(); | |
const normalizedX = (x - rect.left) / rect.width; | |
const normalizedY = (y - rect.top) / rect.height; | |
const landmark = findClosestLandmark( | |
normalizedX, | |
normalizedY, | |
dragStart && dragStartRef.current ? currentLandmark?.group : undefined | |
); | |
const landmarkData = landmarkCenters[landmark?.group] | |
const vector = landmarkData ? { | |
x: normalizedX - landmarkData.x, | |
y: normalizedY - landmarkData.y, | |
z: 0 | |
} : { | |
x: 0.5, | |
y: 0.5, | |
z: 0 | |
} | |
if (dragStart && dragStartRef.current) { | |
setIsDragging(true); | |
modifyImageWithRateLimit({ | |
landmark: currentLandmark || landmark, | |
vector, | |
mode | |
}); | |
} else { | |
if (!currentLandmark || (currentLandmark?.group !== landmark?.group)) { | |
setActiveLandmark(landmark); | |
} | |
/* | |
modifyImageWithRateLimit({ | |
landmark, | |
vector, | |
mode: 'HOVERING' | |
}); | |
*/ | |
} | |
}, [currentLandmark, dragStart, setActiveLandmark, setIsDragging, modifyImageWithRateLimit, landmarkCenters]); | |
const handleEnd = useCallback((x: number, y: number, mode: ActionMode) => { | |
if (!canvasRef.current) return; | |
const rect = canvasRef.current.getBoundingClientRect(); | |
const normalizedX = (x - rect.left) / rect.width; | |
const normalizedY = (y - rect.top) / rect.height; | |
if (dragStart && dragStartRef.current) { | |
const landmark = findClosestLandmark(normalizedX, normalizedY, currentLandmark?.group); | |
modifyImageWithRateLimit({ | |
landmark: currentLandmark || landmark, | |
vector: { | |
x: normalizedX - landmarkCenters[landmark.group].x, | |
y: normalizedY - landmarkCenters[landmark.group].y, | |
z: 0 | |
}, | |
mode | |
}); | |
} | |
setIsDragging(false); | |
dragStartRef.current = null; | |
setActiveLandmark(undefined); | |
}, [currentLandmark, isDragging, modifyImageWithRateLimit, findClosestLandmark, setActiveLandmark, landmarkCenters, setIsDragging]); | |
const handleMouseDown = useCallback((event: React.MouseEvent<HTMLCanvasElement>) => { | |
const mode: ActionMode = event.button === 0 ? 'PRIMARY' : 'SECONDARY'; | |
handleStart(event.clientX, event.clientY, mode); | |
}, [handleStart]); | |
const handleMouseMove = useCallback((event: React.MouseEvent<HTMLCanvasElement>) => { | |
const mode: ActionMode = event.buttons === 1 ? 'PRIMARY' : 'SECONDARY'; | |
handleMove(event.clientX, event.clientY, mode); | |
}, [handleMove]); | |
const handleMouseUp = useCallback((event: React.MouseEvent<HTMLCanvasElement>) => { | |
const mode: ActionMode = event.buttons === 1 ? 'PRIMARY' : 'SECONDARY'; | |
handleEnd(event.clientX, event.clientY, mode); | |
}, [handleEnd]); | |
const handleTouchStart = useCallback((event: React.TouchEvent<HTMLCanvasElement>) => { | |
const mode: ActionMode = event.touches.length === 1 ? 'PRIMARY' : 'SECONDARY'; | |
const touch = event.touches[0]; | |
handleStart(touch.clientX, touch.clientY, mode); | |
}, [handleStart]); | |
const handleTouchMove = useCallback((event: React.TouchEvent<HTMLCanvasElement>) => { | |
const mode: ActionMode = event.touches.length === 1 ? 'PRIMARY' : 'SECONDARY'; | |
const touch = event.touches[0]; | |
handleMove(touch.clientX, touch.clientY, mode); | |
}, [handleMove]); | |
const handleTouchEnd = useCallback((event: React.TouchEvent<HTMLCanvasElement>) => { | |
const mode: ActionMode = event.changedTouches.length === 1 ? 'PRIMARY' : 'SECONDARY'; | |
const touch = event.changedTouches[0]; | |
handleEnd(touch.clientX, touch.clientY, mode); | |
}, [handleEnd]); | |
return { | |
canvasRef, | |
canvasRefCallback, | |
mediaPipeRef, | |
isMediaPipeReady, | |
isDrawingUtilsReady, | |
handleMouseDown, | |
handleMouseUp, | |
handleMouseMove, | |
handleTouchStart, | |
handleTouchMove, | |
handleTouchEnd, | |
currentLandmark, | |
currentOpacity, | |
} | |
} | |