diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..4522d57dc7d6a100de8c39a6039a4a125ac61c4d --- /dev/null +++ b/.dockerignore @@ -0,0 +1,17 @@ +# The .dockerignore file excludes files from the container build process. +# +# https://docs.docker.com/engine/reference/builder/#dockerignore-file + +# Exclude Git files +.git +.github +.gitignore + +# Exclude Python cache files +__pycache__ +.mypy_cache +.pytest_cache +.ruff_cache + +# Exclude Python virtual environment +/venv diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..2036cc5e400d33657f862b703fcdc8f11666131f --- /dev/null +++ b/.gitattributes @@ -0,0 +1,13 @@ +*.jpg filter=lfs diff=lfs merge=lfs -text +*.jpeg filter=lfs diff=lfs merge=lfs -text +*.mp4 filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.xml filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.pdf filter=lfs diff=lfs merge=lfs -text +*.mp3 filter=lfs diff=lfs merge=lfs -text +*.wav filter=lfs diff=lfs merge=lfs -text +*.mpg filter=lfs diff=lfs merge=lfs -text +*.webp filter=lfs diff=lfs merge=lfs -text +*.webm filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ee89744a4abc2b0907096f7d7db1a1c0fbe1ed11 --- /dev/null +++ b/.gitignore @@ -0,0 +1,31 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +**/__pycache__/ +*.py[cod] +**/*.py[cod] +*$py.class + +# Model weights +**/*.pth +**/*.onnx + +# Ipython notebook +*.ipynb + +# Temporary files or benchmark resources +animations/* +tmp/* + +# more ignores +.DS_Store +*.log +.idea/ +.vscode/ +*.pyc +.ipynb_checkpoints +results/ +data/audio/*.wav +data/video/*.mp4 +ffmpeg-7.0-amd64-static +venv/ +.cog/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..e37074ec26864bc0d63ae978d58bbd2c379ee7cd --- /dev/null +++ b/Dockerfile @@ -0,0 +1,67 @@ +FROM nvidia/cuda:12.4.0-devel-ubuntu22.04 + +ARG DEBIAN_FRONTEND=noninteractive + +ENV PYTHONUNBUFFERED=1 + +RUN apt-get update && apt-get install --no-install-recommends -y \ + build-essential \ + python3.11 \ + python3-pip \ + python3-dev \ + git \ + curl \ + ffmpeg \ + libglib2.0-0 \ + libsm6 \ + libxrender1 \ + libxext6 \ + && apt-get clean && rm -rf /var/lib/apt/lists/* + +WORKDIR /code + +COPY ./requirements.txt /code/requirements.txt + +# Install pget as root +RUN echo "Installing pget" && \ + curl -o /usr/local/bin/pget -L 'https://github.com/replicate/pget/releases/download/v0.2.1/pget' && \ + chmod +x /usr/local/bin/pget + +# Set up a new user named "user" with user ID 1000 +RUN useradd -m -u 1000 user +# Switch to the "user" user +USER user +# Set home to the user's home directory +ENV HOME=/home/user \ + PATH=/home/user/.local/bin:$PATH + + +# Set home to the user's home directory +ENV PYTHONPATH=$HOME/app \ + PYTHONUNBUFFERED=1 \ + DATA_ROOT=/tmp/data + +RUN echo "Installing requirements.txt" +RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt + +# yeah.. 
this is manual for now +#RUN cd client +#RUN bun i +#RUN bun build ./src/index.tsx --outdir ../public/ + +RUN echo "Installing openmim and mim dependencies" +RUN pip3 install --no-cache-dir -U openmim +RUN mim install mmengine +RUN mim install "mmcv>=2.0.1" +RUN mim install "mmdet>=3.3.0" +RUN mim install "mmpose>=1.3.2" + +WORKDIR $HOME/app + +COPY --chown=user . $HOME/app + +EXPOSE 8080 + +ENV PORT 8080 + +CMD python3 app.py diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..e251a69d97bd1e640bb375911bf771487f90da2d --- /dev/null +++ b/LICENSE @@ -0,0 +1,56 @@ +## For FacePoke (the modifications I made + the server itself) + +MIT License + +Copyright (c) 2024 Julian Bilcke + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +## For LivePortrait + +MIT License + +Copyright (c) 2024 Kuaishou Visual Generation and Interaction Center + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +--- + +The code of InsightFace is released under the MIT License. +The models of InsightFace are for non-commercial research purposes only. + +If you want to use the LivePortrait project for commercial purposes, you +should remove and replace InsightFace’s detection models to fully comply with +the MIT license. 
diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e61cf4a4069b5309e8ff8bafbe10f0cbfbaa989b --- /dev/null +++ b/README.md @@ -0,0 +1,119 @@ +--- +title: FacePoke +emoji: 💬 +colorFrom: yellow +colorTo: red +sdk: docker +pinned: true +license: mit +header: mini +app_file: app.py +app_port: 8080 +--- + +# FacePoke + +![FacePoke Demo](https://your-demo-image-url-here.gif) + +## Table of Contents + +- [Introduction](#introduction) +- [Acknowledgements](#acknowledgements) +- [Installation](#installation) + - [Local Setup](#local-setup) + - [Docker Deployment](#docker-deployment) +- [Development](#development) +- [Contributing](#contributing) +- [License](#license) + +## Introduction + +A real-time head transformation app. + +For best performance please run the app from your own machine (local or in the cloud). + +**Repository**: [GitHub - jbilcke-hf/FacePoke](https://github.com/jbilcke-hf/FacePoke) + +You can try the demo but it is a shared space, latency may be high if there are multiple users or if you live far from the datacenter hosting the Hugging Face Space. + +**Live Demo**: [FacePoke on Hugging Face Spaces](https://huggingface.co/spaces/jbilcke-hf/FacePoke) + +## Acknowledgements + +This project is based on LivePortrait: https://arxiv.org/abs/2407.03168 + +It uses the face transformation routines from https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait + +## Installation + +### Local Setup + +1. Clone the repository: + ```bash + git clone https://github.com/jbilcke-hf/FacePoke.git + cd FacePoke + ``` + +2. Install Python dependencies: + ```bash + pip install -r requirements.txt + ``` + +3. Install frontend dependencies: + ```bash + cd client + bun install + ``` + +4. Build the frontend: + ```bash + bun build ./src/index.tsx --outdir ../public/ + ``` + +5. Start the backend server: + ```bash + python app.py + ``` + +6. Open `http://localhost:8080` in your web browser. + +### Docker Deployment + +1. Build the Docker image: + ```bash + docker build -t facepoke . + ``` + +2. Run the container: + ```bash + docker run -p 8080:8080 facepoke + ``` + +3. To deploy to Hugging Face Spaces: + - Fork the repository on GitHub. + - Create a new Space on Hugging Face. + - Connect your GitHub repository to the Space. + - Configure the Space to use the Docker runtime. + +## Development + +The project structure is organized as follows: + +- `app.py`: Main backend server handling WebSocket connections. +- `engine.py`: Core logic. +- `loader.py`: Initializes and loads AI models. +- `client/`: Frontend React application. + - `src/`: TypeScript source files. + - `public/`: Static assets and built files. + +## Contributing + +Contributions to FacePoke are welcome! Please read our [Contributing Guidelines](CONTRIBUTING.md) for details on how to submit pull requests, report issues, or request features. + +## License + +FacePoke is released under the MIT License. See the [LICENSE](LICENSE) file for details. 
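
For reference, `app.py` (shown later in this diff) exposes a single WebSocket endpoint at `/ws` that accepts `modify_image` messages and answers with `modified_image` messages. The sketch below is a minimal, hypothetical Python client for that protocol, assuming the server is running locally on port 8080; the exact keys expected inside `params` are defined by `engine.py`, which is not part of this section, so they are left as an empty placeholder.

```python
# Minimal sketch of the /ws protocol implemented in app.py (assumes a local
# server on port 8080). The "params" payload is engine-specific and left empty
# here as a placeholder.
import asyncio
import base64
import uuid

import aiohttp


async def poke(image_path: str) -> None:
    # The web client sends the source image as a base64 data URI.
    with open(image_path, "rb") as f:
        data_uri = "data:image/jpeg;base64," + base64.b64encode(f.read()).decode()

    async with aiohttp.ClientSession() as session:
        async with session.ws_connect("http://localhost:8080/ws") as ws:
            await ws.send_json({
                "type": "modify_image",
                "uuid": str(uuid.uuid4()),  # echoed back in the response
                "image": data_uri,          # or "image_hash" for an already-processed image
                "params": {},               # engine-specific modification parameters
            })
            reply = await ws.receive_json()
            # On success the reply carries: type, success, image (base64), image_hash, uuid.
            print(reply["success"], reply.get("image_hash"))


if __name__ == "__main__":
    asyncio.run(poke("face.jpg"))
```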
+ +--- + +Developed with ❤️ by Julian Bilcke at Hugging Face diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..5499d985b6adf21810d1c80781b681d5ce87ccbc --- /dev/null +++ b/app.py @@ -0,0 +1,264 @@ +""" +FacePoke API + +Author: Julian Bilcke +Date: September 30, 2024 +""" + +import sys +import asyncio +import hashlib +from aiohttp import web, WSMsgType +import json +import uuid +import logging +import os +import zipfile +import signal +from typing import Dict, Any, List, Optional +import base64 +import io +from PIL import Image +import numpy as np + +# Configure logging +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# Set asyncio logger to DEBUG level +logging.getLogger("asyncio").setLevel(logging.DEBUG) + +logger.debug(f"Python version: {sys.version}") + +# SIGSEGV handler +def SIGSEGV_signal_arises(signalNum, stack): + logger.critical(f"{signalNum} : SIGSEGV arises") + logger.critical(f"Stack trace: {stack}") + +signal.signal(signal.SIGSEGV, SIGSEGV_signal_arises) + +from loader import initialize_models +from engine import Engine, base64_data_uri_to_PIL_Image, create_engine + +# Global constants +DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data') +MODELS_DIR = os.path.join(DATA_ROOT, "models") + +image_cache: Dict[str, Image.Image] = {} + +async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + """ + Handle WebSocket connections for the FacePoke application. + + Args: + request (web.Request): The incoming request object. + + Returns: + web.WebSocketResponse: The WebSocket response object. + """ + ws = web.WebSocketResponse() + await ws.prepare(request) + + session: Optional[FacePokeSession] = None + try: + logger.info("New WebSocket connection established") + + while True: + msg = await ws.receive() + + if msg.type == WSMsgType.TEXT: + data = json.loads(msg.data) + + # let's not log user requests, they are heavy + #logger.debug(f"Received message: {data}") + + if data['type'] == 'modify_image': + uuid = data.get('uuid') + if not uuid: + logger.warning("Received message without UUID") + + await handle_modify_image(request, ws, data, uuid) + + + elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR): + logger.warning(f"WebSocket connection closed: {msg.type}") + break + + except Exception as e: + logger.error(f"Error in websocket_handler: {str(e)}") + logger.exception("Full traceback:") + finally: + if session: + await session.stop() + del active_sessions[session.session_id] + logger.info("WebSocket connection closed") + return ws + +async def handle_modify_image(request: web.Request, ws: web.WebSocketResponse, msg: Dict[str, Any], uuid: str): + """ + Handle the 'modify_image' request. + + Args: + request (web.Request): The incoming request object. + ws (web.WebSocketResponse): The WebSocket response object. + msg (Dict[str, Any]): The message containing the image or image_hash and modification parameters. + uuid: A unique identifier for the request. 
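    Schematically, the message read by this handler is:
        {"type": "modify_image", "uuid": "...", "params": {...}, "image": "<base64 data URI>"}
    where "image_hash" may be sent instead of "image" to reuse a previously processed image.
    The reply is a "modified_image" JSON object carrying the modified base64 image, its hash,
    a success flag, and the original uuid.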
+ """ + logger.info("Received modify_image request") + try: + engine = request.app['engine'] + image_hash = msg.get('image_hash') + + if image_hash: + image_or_hash = image_hash + else: + image_data = msg['image'] + image_or_hash = image_data + + modified_image_base64 = await engine.modify_image(image_or_hash, msg['params']) + + await ws.send_json({ + "type": "modified_image", + "image": modified_image_base64, + "image_hash": engine.get_image_hash(image_or_hash), + "success": True, + "uuid": uuid # Include the UUID in the response + }) + logger.info("Successfully sent modified image") + except Exception as e: + logger.error(f"Error in modify_image: {str(e)}") + await ws.send_json({ + "type": "modified_image", + "success": False, + "error": str(e), + "uuid": uuid # Include the UUID even in error responses + }) + +async def index(request: web.Request) -> web.Response: + """Serve the index.html file""" + content = open(os.path.join(os.path.dirname(__file__), "public", "index.html"), "r").read() + return web.Response(content_type="text/html", text=content) + +async def js_index(request: web.Request) -> web.Response: + """Serve the index.js file""" + content = open(os.path.join(os.path.dirname(__file__), "public", "index.js"), "r").read() + return web.Response(content_type="application/javascript", text=content) + +async def hf_logo(request: web.Request) -> web.Response: + """Serve the hf-logo.svg file""" + content = open(os.path.join(os.path.dirname(__file__), "public", "hf-logo.svg"), "r").read() + return web.Response(content_type="image/svg+xml", text=content) + +async def on_shutdown(app: web.Application): + """Cleanup function to be called on server shutdown.""" + logger.info("Server shutdown initiated, cleaning up resources...") + for session in list(active_sessions.values()): + await session.stop() + active_sessions.clear() + logger.info("All active sessions have been closed") + + if 'engine' in app: + await app['engine'].cleanup() + logger.info("Engine instance cleaned up") + + logger.info("Server shutdown complete") + +async def initialize_app() -> web.Application: + """Initialize and configure the web application.""" + try: + logger.info("Initializing application...") + models = await initialize_models() + logger.info("🚀 Creating Engine instance...") + engine = create_engine(models) + logger.info("✅ Engine instance created.") + + app = web.Application() + app['engine'] = engine + + app.on_shutdown.append(on_shutdown) + + # Configure routes + app.router.add_get("/", index) + app.router.add_get("/index.js", js_index) + app.router.add_get("/hf-logo.svg", hf_logo) + app.router.add_get("/ws", websocket_handler) + + logger.info("Application routes configured") + + return app + except Exception as e: + logger.error(f"🚨 Error during application initialization: {str(e)}") + logger.exception("Full traceback:") + raise + +async def start_background_tasks(app: web.Application): + """ + Start background tasks for the application. + + Args: + app (web.Application): The web application instance. + """ + app['cleanup_task'] = asyncio.create_task(periodic_cleanup(app)) + +async def cleanup_background_tasks(app: web.Application): + """ + Clean up background tasks when the application is shutting down. + + Args: + app (web.Application): The web application instance. + """ + app['cleanup_task'].cancel() + await app['cleanup_task'] + +async def periodic_cleanup(app: web.Application): + """ + Perform periodic cleanup tasks for the application. 
+ + Args: + app (web.Application): The web application instance. + """ + while True: + try: + await asyncio.sleep(3600) # Run cleanup every hour + await cleanup_inactive_sessions(app) + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error in periodic cleanup: {str(e)}") + logger.exception("Full traceback:") + +async def cleanup_inactive_sessions(app: web.Application): + """ + Clean up inactive sessions. + + Args: + app (web.Application): The web application instance. + """ + logger.info("Starting cleanup of inactive sessions") + inactive_sessions = [ + session_id for session_id, session in active_sessions.items() + if not session.is_running.is_set() + ] + for session_id in inactive_sessions: + session = active_sessions.pop(session_id) + await session.stop() + logger.info(f"Cleaned up inactive session: {session_id}") + logger.info(f"Cleaned up {len(inactive_sessions)} inactive sessions") + +def main(): + """ + Main function to start the FacePoke application. + """ + try: + logger.info("Starting FacePoke application") + app = asyncio.run(initialize_app()) + app.on_startup.append(start_background_tasks) + app.on_cleanup.append(cleanup_background_tasks) + logger.info("Application initialized, starting web server") + web.run_app(app, host="0.0.0.0", port=8080) + except Exception as e: + logger.critical(f"🚨 FATAL: Failed to start the app: {str(e)}") + logger.exception("Full traceback:") + +if __name__ == "__main__": + main() diff --git a/build.sh b/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..0d88a73562f0262957177921f17c456fce8cc3de --- /dev/null +++ b/build.sh @@ -0,0 +1,3 @@ +cd client +bun i +bun build ./src/index.tsx --outdir ../public/ diff --git a/client/.gitignore b/client/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9b1ee42e8482ee945f6adc02615ef33db69a0c2f --- /dev/null +++ b/client/.gitignore @@ -0,0 +1,175 @@ +# Based on https://raw.githubusercontent.com/github/gitignore/main/Node.gitignore + +# Logs + +logs +_.log +npm-debug.log_ +yarn-debug.log* +yarn-error.log* +lerna-debug.log* +.pnpm-debug.log* + +# Caches + +.cache + +# Diagnostic reports (https://nodejs.org/api/report.html) + +report.[0-9]_.[0-9]_.[0-9]_.[0-9]_.json + +# Runtime data + +pids +_.pid +_.seed +*.pid.lock + +# Directory for instrumented libs generated by jscoverage/JSCover + +lib-cov + +# Coverage directory used by tools like istanbul + +coverage +*.lcov + +# nyc test coverage + +.nyc_output + +# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) + +.grunt + +# Bower dependency directory (https://bower.io/) + +bower_components + +# node-waf configuration + +.lock-wscript + +# Compiled binary addons (https://nodejs.org/api/addons.html) + +build/Release + +# Dependency directories + +node_modules/ +jspm_packages/ + +# Snowpack dependency directory (https://snowpack.dev/) + +web_modules/ + +# TypeScript cache + +*.tsbuildinfo + +# Optional npm cache directory + +.npm + +# Optional eslint cache + +.eslintcache + +# Optional stylelint cache + +.stylelintcache + +# Microbundle cache + +.rpt2_cache/ +.rts2_cache_cjs/ +.rts2_cache_es/ +.rts2_cache_umd/ + +# Optional REPL history + +.node_repl_history + +# Output of 'npm pack' + +*.tgz + +# Yarn Integrity file + +.yarn-integrity + +# dotenv environment variable files + +.env +.env.development.local +.env.test.local +.env.production.local +.env.local + +# parcel-bundler cache (https://parceljs.org/) + +.parcel-cache + +# Next.js build output 
+ +.next +out + +# Nuxt.js build / generate output + +.nuxt +dist + +# Gatsby files + +# Comment in the public line in if your project uses Gatsby and not Next.js + +# https://nextjs.org/blog/next-9-1#public-directory-support + +# public + +# vuepress build output + +.vuepress/dist + +# vuepress v2.x temp and cache directory + +.temp + +# Docusaurus cache and generated files + +.docusaurus + +# Serverless directories + +.serverless/ + +# FuseBox cache + +.fusebox/ + +# DynamoDB Local files + +.dynamodb/ + +# TernJS port file + +.tern-port + +# Stores VSCode versions used for testing VSCode extensions + +.vscode-test + +# yarn v2 + +.yarn/cache +.yarn/unplugged +.yarn/build-state.yml +.yarn/install-state.gz +.pnp.* + +# IntelliJ based IDEs +.idea + +# Finder (MacOS) folder config +.DS_Store diff --git a/client/README.md b/client/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b4dabd8f19d2545c23550eeb8e11595ad86f3955 --- /dev/null +++ b/client/README.md @@ -0,0 +1,13 @@ +# FacePoke.js + +To install dependencies: + +```bash +bun i +``` + +To build: + +```bash +bun build ./src/index.tsx --outdir ../public +``` diff --git a/client/bun.lockb b/client/bun.lockb new file mode 100755 index 0000000000000000000000000000000000000000..2d76d6e717ba59a69fa2fa05f97da435a3fddcbd Binary files /dev/null and b/client/bun.lockb differ diff --git a/client/package.json b/client/package.json new file mode 100644 index 0000000000000000000000000000000000000000..3d3c437d70b7ab1b1201806616b5e52e01b205cc --- /dev/null +++ b/client/package.json @@ -0,0 +1,35 @@ +{ + "name": "@aitube/facepoke", + "module": "src/index.ts", + "type": "module", + "scripts": { + "build": "bun build ./src/index.tsx --outdir ../public/" + }, + "devDependencies": { + "@types/bun": "latest" + }, + "peerDependencies": { + "typescript": "^5.0.0" + }, + "dependencies": { + "@mediapipe/tasks-vision": "^0.10.16", + "@radix-ui/react-icons": "^1.3.0", + "@types/lodash": "^4.17.10", + "@types/react": "^18.3.9", + "@types/react-dom": "^18.3.0", + "@types/uuid": "^10.0.0", + "beautiful-react-hooks": "^5.0.2", + "class-variance-authority": "^0.7.0", + "clsx": "^2.1.1", + "lodash": "^4.17.21", + "lucide-react": "^0.446.0", + "react": "^18.3.1", + "react-dom": "^18.3.1", + "tailwind-merge": "^2.5.2", + "tailwindcss": "^3.4.13", + "tailwindcss-animate": "^1.0.7", + "usehooks-ts": "^3.1.0", + "uuid": "^10.0.0", + "zustand": "^5.0.0-rc.2" + } +} diff --git a/client/src/app.tsx b/client/src/app.tsx new file mode 100644 index 0000000000000000000000000000000000000000..f7ea7b827f1f3724fed5309befa8b93c82536dd2 --- /dev/null +++ b/client/src/app.tsx @@ -0,0 +1,190 @@ +import React, { useState, useEffect, useRef, useCallback, useMemo } from 'react'; +import { RotateCcw } from 'lucide-react'; + +import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert'; +import { truncateFileName } from './lib/utils'; +import { useFaceLandmarkDetection } from './hooks/useFaceLandmarkDetection'; +import { PoweredBy } from './components/PoweredBy'; +import { Spinner } from './components/Spinner'; +import { DoubleCard } from './components/DoubleCard'; +import { useFacePokeAPI } from './hooks/useFacePokeAPI'; +import { Layout } from './layout'; +import { useMainStore } from './hooks/useMainStore'; +import { convertImageToBase64 } from './lib/convertImageToBase64'; + +export function App() { + const error = useMainStore(s => s.error); + const setError = useMainStore(s => s.setError); + const imageFile = useMainStore(s => s.imageFile); + 
const setImageFile = useMainStore(s => s.setImageFile); + const originalImage = useMainStore(s => s.originalImage); + const setOriginalImage = useMainStore(s => s.setOriginalImage); + const previewImage = useMainStore(s => s.previewImage); + const setPreviewImage = useMainStore(s => s.setPreviewImage); + const resetImage = useMainStore(s => s.resetImage); + + const { + status, + setStatus, + isDebugMode, + setIsDebugMode, + interruptMessage, + } = useFacePokeAPI() + + // State for face detection + const { + canvasRef, + canvasRefCallback, + mediaPipeRef, + faceLandmarks, + isMediaPipeReady, + blendShapes, + + setFaceLandmarks, + setBlendShapes, + + handleMouseDown, + handleMouseUp, + handleMouseMove, + handleMouseEnter, + handleMouseLeave, + currentOpacity + } = useFaceLandmarkDetection() + + // Refs + const videoRef = useRef(null); + + // Handle file change + const handleFileChange = useCallback(async (event: React.ChangeEvent) => { + const files = event.target.files; + if (files && files[0]) { + setImageFile(files[0]); + setStatus(`File selected: ${truncateFileName(files[0].name, 16)}`); + + try { + const image = await convertImageToBase64(files[0]); + setPreviewImage(image); + setOriginalImage(image); + } catch (err) { + console.log(`failed to convert the image: `, err); + setImageFile(null); + setStatus(''); + setPreviewImage(''); + setOriginalImage(''); + setFaceLandmarks([]); + setBlendShapes([]); + } + } else { + setImageFile(null); + setStatus(''); + setPreviewImage(''); + setOriginalImage(''); + setFaceLandmarks([]); + setBlendShapes([]); + } + }, [isMediaPipeReady, setImageFile, setPreviewImage, setOriginalImage, setFaceLandmarks, setBlendShapes, setStatus]); + + const canDisplayBlendShapes = false + + // Display blend shapes + const displayBlendShapes = useMemo(() => ( +
+

Blend Shapes

+
    + {(blendShapes?.[0]?.categories || []).map((shape, index) => ( +
  • + {shape.categoryName || shape.displayName} +
    +
    +
    + {shape.score.toFixed(2)} +
  • + ))} +
+
+ ), [JSON.stringify(blendShapes)]) + + // JSX + return ( + + {error && ( + + Error + {error} + + )} + {interruptMessage && ( + + Notice + {interruptMessage} + + )} +
+
+
+ + +
+ {previewImage && } +
+ {previewImage && ( +
+ Preview + +
+ )} + {canDisplayBlendShapes && displayBlendShapes} +
+ + +
+ ); +} diff --git a/client/src/components/DoubleCard.tsx b/client/src/components/DoubleCard.tsx new file mode 100644 index 0000000000000000000000000000000000000000..247ecc959f78b2f26602bd3e437fabd7e69e7b3f --- /dev/null +++ b/client/src/components/DoubleCard.tsx @@ -0,0 +1,18 @@ +import React, { type ReactNode } from 'react'; + +export function DoubleCard({ children }: { children: ReactNode }) { + return ( + <> +
+
+
+
+
+ {children} +
+
+
+
+ + ); +} diff --git a/client/src/components/PoweredBy.tsx b/client/src/components/PoweredBy.tsx new file mode 100644 index 0000000000000000000000000000000000000000..5dddccb160a223a9851bc80eec2881ae0c9b59d7 --- /dev/null +++ b/client/src/components/PoweredBy.tsx @@ -0,0 +1,17 @@ +export function PoweredBy() { + return ( +
+ {/* + Best hosted on + */} + + Hugging Face + + + Hugging Face + +
+ ) +} diff --git a/client/src/components/Spinner.tsx b/client/src/components/Spinner.tsx new file mode 100644 index 0000000000000000000000000000000000000000..008f9a394aedcc4e7e77430279cead8f373bcfdc --- /dev/null +++ b/client/src/components/Spinner.tsx @@ -0,0 +1,7 @@ +export function Spinner() { + return ( + + + + ) +} diff --git a/client/src/components/Title.tsx b/client/src/components/Title.tsx new file mode 100644 index 0000000000000000000000000000000000000000..408a6c71b525a593407465890262b3a17635b22c --- /dev/null +++ b/client/src/components/Title.tsx @@ -0,0 +1,8 @@ +export function Title() { + return ( +

+ 💬 FacePoke +

+ ) +} diff --git a/client/src/components/ui/alert.tsx b/client/src/components/ui/alert.tsx new file mode 100644 index 0000000000000000000000000000000000000000..41fa7e0561a3fdb5f986c1213a35e563de740e96 --- /dev/null +++ b/client/src/components/ui/alert.tsx @@ -0,0 +1,59 @@ +import * as React from "react" +import { cva, type VariantProps } from "class-variance-authority" + +import { cn } from "@/lib/utils" + +const alertVariants = cva( + "relative w-full rounded-lg border p-4 [&>svg~*]:pl-7 [&>svg+div]:translate-y-[-3px] [&>svg]:absolute [&>svg]:left-4 [&>svg]:top-4 [&>svg]:text-foreground", + { + variants: { + variant: { + default: "bg-background text-foreground", + destructive: + "border-destructive/50 text-destructive dark:border-destructive [&>svg]:text-destructive", + }, + }, + defaultVariants: { + variant: "default", + }, + } +) + +const Alert = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes & VariantProps +>(({ className, variant, ...props }, ref) => ( +
+)) +Alert.displayName = "Alert" + +const AlertTitle = React.forwardRef< + HTMLParagraphElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)) +AlertTitle.displayName = "AlertTitle" + +const AlertDescription = React.forwardRef< + HTMLParagraphElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)) +AlertDescription.displayName = "AlertDescription" + +export { Alert, AlertTitle, AlertDescription } diff --git a/client/src/hooks/landmarks.ts b/client/src/hooks/landmarks.ts new file mode 100644 index 0000000000000000000000000000000000000000..e8b30c694925b4f7d8b890a1959daa5a3a6cc14b --- /dev/null +++ b/client/src/hooks/landmarks.ts @@ -0,0 +1,520 @@ +import * as vision from '@mediapipe/tasks-vision'; + +// Define unique colors for each landmark group +export const landmarkColors: { [key: string]: string } = { + lips: '#FF0000', + leftEye: '#00FF00', + leftEyebrow: '#0000FF', + leftIris: '#FFFF00', + rightEye: '#FF00FF', + rightEyebrow: '#00FFFF', + rightIris: '#FFA500', + faceOval: '#800080', + tesselation: '#C0C0C0', +}; + +// Define landmark groups with their semantic names +export const landmarkGroups: { [key: string]: any } = { + lips: vision.FaceLandmarker.FACE_LANDMARKS_LIPS, + leftEye: vision.FaceLandmarker.FACE_LANDMARKS_LEFT_EYE, + leftEyebrow: vision.FaceLandmarker.FACE_LANDMARKS_LEFT_EYEBROW, + leftIris: vision.FaceLandmarker.FACE_LANDMARKS_LEFT_IRIS, + rightEye: vision.FaceLandmarker.FACE_LANDMARKS_RIGHT_EYE, + rightEyebrow: vision.FaceLandmarker.FACE_LANDMARKS_RIGHT_EYEBROW, + rightIris: vision.FaceLandmarker.FACE_LANDMARKS_RIGHT_IRIS, + faceOval: vision.FaceLandmarker.FACE_LANDMARKS_FACE_OVAL, + // tesselation: vision.FaceLandmarker.FACE_LANDMARKS_TESSELATION, +}; + +export const FACEMESH_LIPS = Object.freeze(new Set([[61, 146], [146, 91], [91, 181], [181, 84], [84, 17], + [17, 314], [314, 405], [405, 321], [321, 375], + [375, 291], [61, 185], [185, 40], [40, 39], [39, 37], + [37, 0], [0, 267], + [267, 269], [269, 270], [270, 409], [409, 291], + [78, 95], [95, 88], [88, 178], [178, 87], [87, 14], + [14, 317], [317, 402], [402, 318], [318, 324], + [324, 308], [78, 191], [191, 80], [80, 81], [81, 82], + [82, 13], [13, 312], [312, 311], [311, 310], + [310, 415], [415, 308]])) + +export const FACEMESH_LEFT_EYE = Object.freeze(new Set([[263, 249], [249, 390], [390, 373], [373, 374], + [374, 380], [380, 381], [381, 382], [382, 362], + [263, 466], [466, 388], [388, 387], [387, 386], + [386, 385], [385, 384], [384, 398], [398, 362]])) + +export const FACEMESH_LEFT_IRIS = Object.freeze(new Set([[474, 475], [475, 476], [476, 477], + [477, 474]])) + +export const FACEMESH_LEFT_EYEBROW = Object.freeze(new Set([[276, 283], [283, 282], [282, 295], + [295, 285], [300, 293], [293, 334], + [334, 296], [296, 336]])) + +export const FACEMESH_RIGHT_EYE = Object.freeze(new Set([[33, 7], [7, 163], [163, 144], [144, 145], + [145, 153], [153, 154], [154, 155], [155, 133], + [33, 246], [246, 161], [161, 160], [160, 159], + [159, 158], [158, 157], [157, 173], [173, 133]])) + +export const FACEMESH_RIGHT_EYEBROW = Object.freeze(new Set([[46, 53], [53, 52], [52, 65], [65, 55], + [70, 63], [63, 105], [105, 66], [66, 107]])) + +export const FACEMESH_RIGHT_IRIS = Object.freeze(new Set([[469, 470], [470, 471], [471, 472], + [472, 469]])) + +export const FACEMESH_FACE_OVAL = Object.freeze(new Set([[10, 338], [338, 297], [297, 332], [332, 284], + [284, 251], [251, 389], [389, 356], [356, 454], + [454, 323], [323, 361], [361, 288], [288, 397], + [397, 365], [365, 379], [379, 378], [378, 400], + [400, 377], [377, 152], [152, 148], [148, 176], + [176, 149], [149, 150], [150, 136], [136, 172], + [172, 58], [58, 132], [132, 93], [93, 234], + [234, 127], [127, 162], [162, 21], [21, 54], + [54, 103], [103, 67], [67, 109], [109, 10]])) + +export const FACEMESH_NOSE = Object.freeze(new Set([[168, 6], [6, 197], 
[197, 195], [195, 5], + [5, 4], [4, 1], [1, 19], [19, 94], [94, 2], [98, 97], + [97, 2], [2, 326], [326, 327], [327, 294], + [294, 278], [278, 344], [344, 440], [440, 275], + [275, 4], [4, 45], [45, 220], [220, 115], [115, 48], + [48, 64], [64, 98]])) + +export const FACEMESH_CONTOURS = Object.freeze(new Set([ + ...FACEMESH_LIPS, + ...FACEMESH_LEFT_EYE, + ...FACEMESH_LEFT_EYEBROW, + ...FACEMESH_RIGHT_EYE, + ...FACEMESH_RIGHT_EYEBROW, + ...FACEMESH_FACE_OVAL +])); + +export const FACEMESH_IRISES = Object.freeze(new Set([ + ...FACEMESH_LEFT_IRIS, + ...FACEMESH_RIGHT_IRIS +])); + +export const FACEMESH_TESSELATION = Object.freeze(new Set([ + [127, 34], [34, 139], [139, 127], [11, 0], [0, 37], [37, 11], + [232, 231], [231, 120], [120, 232], [72, 37], [37, 39], [39, 72], + [128, 121], [121, 47], [47, 128], [232, 121], [121, 128], [128, 232], + [104, 69], [69, 67], [67, 104], [175, 171], [171, 148], [148, 175], + [118, 50], [50, 101], [101, 118], [73, 39], [39, 40], [40, 73], + [9, 151], [151, 108], [108, 9], [48, 115], [115, 131], [131, 48], + [194, 204], [204, 211], [211, 194], [74, 40], [40, 185], [185, 74], + [80, 42], [42, 183], [183, 80], [40, 92], [92, 186], [186, 40], + [230, 229], [229, 118], [118, 230], [202, 212], [212, 214], [214, 202], + [83, 18], [18, 17], [17, 83], [76, 61], [61, 146], [146, 76], + [160, 29], [29, 30], [30, 160], [56, 157], [157, 173], [173, 56], + [106, 204], [204, 194], [194, 106], [135, 214], [214, 192], [192, 135], + [203, 165], [165, 98], [98, 203], [21, 71], [71, 68], [68, 21], + [51, 45], [45, 4], [4, 51], [144, 24], [24, 23], [23, 144], + [77, 146], [146, 91], [91, 77], [205, 50], [50, 187], [187, 205], + [201, 200], [200, 18], [18, 201], [91, 106], [106, 182], [182, 91], + [90, 91], [91, 181], [181, 90], [85, 84], [84, 17], [17, 85], + [206, 203], [203, 36], [36, 206], [148, 171], [171, 140], [140, 148], + [92, 40], [40, 39], [39, 92], [193, 189], [189, 244], [244, 193], + [159, 158], [158, 28], [28, 159], [247, 246], [246, 161], [161, 247], + [236, 3], [3, 196], [196, 236], [54, 68], [68, 104], [104, 54], + [193, 168], [168, 8], [8, 193], [117, 228], [228, 31], [31, 117], + [189, 193], [193, 55], [55, 189], [98, 97], [97, 99], [99, 98], + [126, 47], [47, 100], [100, 126], [166, 79], [79, 218], [218, 166], + [155, 154], [154, 26], [26, 155], [209, 49], [49, 131], [131, 209], + [135, 136], [136, 150], [150, 135], [47, 126], [126, 217], [217, 47], + [223, 52], [52, 53], [53, 223], [45, 51], [51, 134], [134, 45], + [211, 170], [170, 140], [140, 211], [67, 69], [69, 108], [108, 67], + [43, 106], [106, 91], [91, 43], [230, 119], [119, 120], [120, 230], + [226, 130], [130, 247], [247, 226], [63, 53], [53, 52], [52, 63], + [238, 20], [20, 242], [242, 238], [46, 70], [70, 156], [156, 46], + [78, 62], [62, 96], [96, 78], [46, 53], [53, 63], [63, 46], + [143, 34], [34, 227], [227, 143], [123, 117], [117, 111], [111, 123], + [44, 125], [125, 19], [19, 44], [236, 134], [134, 51], [51, 236], + [216, 206], [206, 205], [205, 216], [154, 153], [153, 22], [22, 154], + [39, 37], [37, 167], [167, 39], [200, 201], [201, 208], [208, 200], + [36, 142], [142, 100], [100, 36], [57, 212], [212, 202], [202, 57], + [20, 60], [60, 99], [99, 20], [28, 158], [158, 157], [157, 28], + [35, 226], [226, 113], [113, 35], [160, 159], [159, 27], [27, 160], + [204, 202], [202, 210], [210, 204], [113, 225], [225, 46], [46, 113], + [43, 202], [202, 204], [204, 43], [62, 76], [76, 77], [77, 62], + [137, 123], [123, 116], [116, 137], [41, 38], [38, 72], [72, 41], + [203, 129], [129, 142], [142, 
203], [64, 98], [98, 240], [240, 64], + [49, 102], [102, 64], [64, 49], [41, 73], [73, 74], [74, 41], + [212, 216], [216, 207], [207, 212], [42, 74], [74, 184], [184, 42], + [169, 170], [170, 211], [211, 169], [170, 149], [149, 176], [176, 170], + [105, 66], [66, 69], [69, 105], [122, 6], [6, 168], [168, 122], + [123, 147], [147, 187], [187, 123], [96, 77], [77, 90], [90, 96], + [65, 55], [55, 107], [107, 65], [89, 90], [90, 180], [180, 89], + [101, 100], [100, 120], [120, 101], [63, 105], [105, 104], [104, 63], + [93, 137], [137, 227], [227, 93], [15, 86], [86, 85], [85, 15], + [129, 102], [102, 49], [49, 129], [14, 87], [87, 86], [86, 14], + [55, 8], [8, 9], [9, 55], [100, 47], [47, 121], [121, 100], + [145, 23], [23, 22], [22, 145], [88, 89], [89, 179], [179, 88], + [6, 122], [122, 196], [196, 6], [88, 95], [95, 96], [96, 88], + [138, 172], [172, 136], [136, 138], [215, 58], [58, 172], [172, 215], + [115, 48], [48, 219], [219, 115], [42, 80], [80, 81], [81, 42], + [195, 3], [3, 51], [51, 195], [43, 146], [146, 61], [61, 43], + [171, 175], [175, 199], [199, 171], [81, 82], [82, 38], [38, 81], + [53, 46], [46, 225], [225, 53], [144, 163], [163, 110], [110, 144], + [52, 65], [65, 66], [66, 52], [229, 228], [228, 117], [117, 229], + [34, 127], [127, 234], [234, 34], [107, 108], [108, 69], [69, 107], + [109, 108], [108, 151], [151, 109], [48, 64], [64, 235], [235, 48], + [62, 78], [78, 191], [191, 62], [129, 209], [209, 126], [126, 129], + [111, 35], [35, 143], [143, 111], [117, 123], [123, 50], [50, 117], + [222, 65], [65, 52], [52, 222], [19, 125], [125, 141], [141, 19], + [221, 55], [55, 65], [65, 221], [3, 195], [195, 197], [197, 3], + [25, 7], [7, 33], [33, 25], [220, 237], [237, 44], [44, 220], + [70, 71], [71, 139], [139, 70], [122, 193], [193, 245], [245, 122], + [247, 130], [130, 33], [33, 247], [71, 21], [21, 162], [162, 71], + [170, 169], [169, 150], [150, 170], [188, 174], [174, 196], [196, 188], + [216, 186], [186, 92], [92, 216], [2, 97], [97, 167], [167, 2], + [141, 125], [125, 241], [241, 141], [164, 167], [167, 37], [37, 164], + [72, 38], [38, 12], [12, 72], [38, 82], [82, 13], [13, 38], + [63, 68], [68, 71], [71, 63], [226, 35], [35, 111], [111, 226], + [101, 50], [50, 205], [205, 101], [206, 92], [92, 165], [165, 206], + [209, 198], [198, 217], [217, 209], [165, 167], [167, 97], [97, 165], + [220, 115], [115, 218], [218, 220], [133, 112], [112, 243], [243, 133], + [239, 238], [238, 241], [241, 239], [214, 135], [135, 169], [169, 214], + [190, 173], [173, 133], [133, 190], [171, 208], [208, 32], [32, 171], + [125, 44], [44, 237], [237, 125], [86, 87], [87, 178], [178, 86], + [85, 86], [86, 179], [179, 85], [84, 85], [85, 180], [180, 84], + [83, 84], [84, 181], [181, 83], [201, 83], [83, 182], [182, 201], + [137, 93], [93, 132], [132, 137], [76, 62], [62, 183], [183, 76], + [61, 76], [76, 184], [184, 61], [57, 61], [61, 185], [185, 57], + [212, 57], [57, 186], [186, 212], [214, 207], [207, 187], [187, 214], + [34, 143], [143, 156], [156, 34], [79, 239], [239, 237], [237, 79], + [123, 137], [137, 177], [177, 123], [44, 1], [1, 4], [4, 44], + [201, 194], [194, 32], [32, 201], [64, 102], [102, 129], [129, 64], + [213, 215], [215, 138], [138, 213], [59, 166], [166, 219], [219, 59], + [242, 99], [99, 97], [97, 242], [2, 94], [94, 141], [141, 2], + [75, 59], [59, 235], [235, 75], [24, 110], [110, 228], [228, 24], + [25, 130], [130, 226], [226, 25], [23, 24], [24, 229], [229, 23], + [22, 23], [23, 230], [230, 22], [26, 22], [22, 231], [231, 26], + [112, 26], [26, 232], [232, 112], 
[189, 190], [190, 243], [243, 189], + [221, 56], [56, 190], [190, 221], [28, 56], [56, 221], [221, 28], + [27, 28], [28, 222], [222, 27], [29, 27], [27, 223], [223, 29], + [30, 29], [29, 224], [224, 30], [247, 30], [30, 225], [225, 247], + [238, 79], [79, 20], [20, 238], [166, 59], [59, 75], [75, 166], + [60, 75], [75, 240], [240, 60], [147, 177], [177, 215], [215, 147], + [20, 79], [79, 166], [166, 20], [187, 147], [147, 213], [213, 187], + [112, 233], [233, 244], [244, 112], [233, 128], [128, 245], [245, 233], + [128, 114], [114, 188], [188, 128], [114, 217], [217, 174], [174, 114], + [131, 115], [115, 220], [220, 131], [217, 198], [198, 236], [236, 217], + [198, 131], [131, 134], [134, 198], [177, 132], [132, 58], [58, 177], + [143, 35], [35, 124], [124, 143], [110, 163], [163, 7], [7, 110], + [228, 110], [110, 25], [25, 228], [356, 389], [389, 368], [368, 356], + [11, 302], [302, 267], [267, 11], [452, 350], [350, 349], [349, 452], + [302, 303], [303, 269], [269, 302], [357, 343], [343, 277], [277, 357], + [452, 453], [453, 357], [357, 452], [333, 332], [332, 297], [297, 333], + [175, 152], [152, 377], [377, 175], [347, 348], [348, 330], [330, 347], + [303, 304], [304, 270], [270, 303], [9, 336], [336, 337], [337, 9], + [278, 279], [279, 360], [360, 278], [418, 262], [262, 431], [431, 418], + [304, 408], [408, 409], [409, 304], [310, 415], [415, 407], [407, 310], + [270, 409], [409, 410], [410, 270], [450, 348], [348, 347], [347, 450], + [422, 430], [430, 434], [434, 422], [313, 314], [314, 17], [17, 313], + [306, 307], [307, 375], [375, 306], [387, 388], [388, 260], [260, 387], + [286, 414], [414, 398], [398, 286], [335, 406], [406, 418], [418, 335], + [364, 367], [367, 416], [416, 364], [423, 358], [358, 327], [327, 423], + [251, 284], [284, 298], [298, 251], [281, 5], [5, 4], [4, 281], + [373, 374], [374, 253], [253, 373], [307, 320], [320, 321], [321, 307], + [425, 427], [427, 411], [411, 425], [421, 313], [313, 18], [18, 421], + [321, 405], [405, 406], [406, 321], [320, 404], [404, 405], [405, 320], + [315, 16], [16, 17], [17, 315], [426, 425], [425, 266], [266, 426], + [377, 400], [400, 369], [369, 377], [322, 391], [391, 269], [269, 322], + [417, 465], [465, 464], [464, 417], [386, 257], [257, 258], [258, 386], + [466, 260], [260, 388], [388, 466], [456, 399], [399, 419], [419, 456], + [284, 332], [332, 333], [333, 284], [417, 285], [285, 8], [8, 417], + [346, 340], [340, 261], [261, 346], [413, 441], [441, 285], [285, 413], + [327, 460], [460, 328], [328, 327], [355, 371], [371, 329], [329, 355], + [392, 439], [439, 438], [438, 392], [382, 341], [341, 256], [256, 382], + [429, 420], [420, 360], [360, 429], [364, 394], [394, 379], [379, 364], + [277, 343], [343, 437], [437, 277], [443, 444], [444, 283], [283, 443], + [275, 440], [440, 363], [363, 275], [431, 262], [262, 369], [369, 431], + [297, 338], [338, 337], [337, 297], [273, 375], [375, 321], [321, 273], + [450, 451], [451, 349], [349, 450], [446, 342], [342, 467], [467, 446], + [293, 334], [334, 282], [282, 293], [458, 461], [461, 462], [462, 458], + [276, 353], [353, 383], [383, 276], [308, 324], [324, 325], [325, 308], + [276, 300], [300, 293], [293, 276], [372, 345], [345, 447], [447, 372], + [352, 345], [345, 340], [340, 352], [274, 1], [1, 19], [19, 274], + [456, 248], [248, 281], [281, 456], [436, 427], [427, 425], [425, 436], + [381, 256], [256, 252], [252, 381], [269, 391], [391, 393], [393, 269], + [200, 199], [199, 428], [428, 200], [266, 330], [330, 329], [329, 266], + [287, 273], [273, 422], [422, 287], [250, 
462], [462, 328], [328, 250], + [258, 286], [286, 384], [384, 258], [265, 353], [353, 342], [342, 265], + [387, 259], [259, 257], [257, 387], [424, 431], [431, 430], [430, 424], + [342, 353], [353, 276], [276, 342], [273, 335], [335, 424], [424, 273], + [292, 325], [325, 307], [307, 292], [366, 447], [447, 345], [345, 366], + [271, 303], [303, 302], [302, 271], [423, 266], [266, 371], [371, 423], + [294, 455], [455, 460], [460, 294], [279, 278], [278, 294], [294, 279], + [271, 272], [272, 304], [304, 271], [432, 434], [434, 427], [427, 432], + [272, 407], [407, 408], [408, 272], [394, 430], [430, 431], [431, 394], + [395, 369], [369, 400], [400, 395], [334, 333], [333, 299], [299, 334], + [351, 417], [417, 168], [168, 351], [352, 280], [280, 411], [411, 352], + [325, 319], [319, 320], [320, 325], [295, 296], [296, 336], [336, 295], + [319, 403], [403, 404], [404, 319], [330, 348], [348, 349], [349, 330], + [293, 298], [298, 333], [333, 293], [323, 454], [454, 447], [447, 323], + [15, 16], [16, 315], [315, 15], [358, 429], [429, 279], [279, 358], + [14, 15], [15, 316], [316, 14], [285, 336], [336, 9], [9, 285], + [329, 349], [349, 350], [350, 329], [374, 380], [380, 252], [252, 374], + [318, 402], [402, 403], [403, 318], [6, 197], [197, 419], [419, 6], + [318, 319], [319, 325], [325, 318], [367, 364], [364, 365], [365, 367], + [435, 367], [367, 397], [397, 435], [344, 438], [438, 439], [439, 344], + [272, 271], [271, 311], [311, 272], [195, 5], [5, 281], [281, 195], + [273, 287], [287, 291], [291, 273], [396, 428], [428, 199], [199, 396], + [311, 271], [271, 268], [268, 311], [283, 444], [444, 445], [445, 283], + [373, 254], [254, 339], [339, 373], [282, 334], [334, 296], [296, 282], + [449, 347], [347, 346], [346, 449], [264, 447], [447, 454], [454, 264], + [336, 296], [296, 299], [299, 336], [338, 10], [10, 151], [151, 338], + [278, 439], [439, 455], [455, 278], [292, 407], [407, 415], [415, 292], + [358, 371], [371, 355], [355, 358], [340, 345], [345, 372], [372, 340], + [346, 347], [347, 280], [280, 346], [442, 443], [443, 282], [282, 442], + [19, 94], [94, 370], [370, 19], [441, 442], [442, 295], [295, 441], + [248, 419], [419, 197], [197, 248], [263, 255], [255, 359], [359, 263], + [440, 275], [275, 274], [274, 440], [300, 383], [383, 368], [368, 300], + [351, 412], [412, 465], [465, 351], [263, 467], [467, 466], [466, 263], + [301, 368], [368, 389], [389, 301], [395, 378], [378, 379], [379, 395], + [412, 351], [351, 419], [419, 412], [436, 426], [426, 322], [322, 436], + [2, 164], [164, 393], [393, 2], [370, 462], [462, 461], [461, 370], + [164, 0], [0, 267], [267, 164], [302, 11], [11, 12], [12, 302], + [268, 12], [12, 13], [13, 268], [293, 300], [300, 301], [301, 293], + [446, 261], [261, 340], [340, 446], [330, 266], [266, 425], [425, 330], + [426, 423], [423, 391], [391, 426], [429, 355], [355, 437], [437, 429], + [391, 327], [327, 326], [326, 391], [440, 457], [457, 438], [438, 440], + [341, 382], [382, 362], [362, 341], [459, 457], [457, 461], [461, 459], + [434, 430], [430, 394], [394, 434], [414, 463], [463, 362], [362, 414], + [396, 369], [369, 262], [262, 396], [354, 461], [461, 457], [457, 354], + [316, 403], [403, 402], [402, 316], [315, 404], [404, 403], [403, 315], + [314, 405], [405, 404], [404, 314], [313, 406], [406, 405], [405, 313], + [421, 418], [418, 406], [406, 421], [366, 401], [401, 361], [361, 366], + [306, 408], [408, 407], [407, 306], [291, 409], [409, 408], [408, 291], + [287, 410], [410, 409], [409, 287], [432, 436], [436, 410], [410, 432], + [434, 416], 
[416, 411], [411, 434], [264, 368], [368, 383], [383, 264], + [309, 438], [438, 457], [457, 309], [352, 376], [376, 401], [401, 352], + [274, 275], [275, 4], [4, 274], [421, 428], [428, 262], [262, 421], + [294, 327], [327, 358], [358, 294], [433, 416], [416, 367], [367, 433], + [289, 455], [455, 439], [439, 289], [462, 370], [370, 326], [326, 462], + [2, 326], [326, 370], [370, 2], [305, 460], [460, 455], [455, 305], + [254, 449], [449, 448], [448, 254], [255, 261], [261, 446], [446, 255], + [253, 450], [450, 449], [449, 253], [252, 451], [451, 450], [450, 252], + [256, 452], [452, 451], [451, 256], [341, 453], [453, 452], [452, 341], + [413, 464], [464, 463], [463, 413], [441, 413], [413, 414], [414, 441], + [258, 442], [442, 441], [441, 258], [257, 443], [443, 442], [442, 257], + [259, 444], [444, 443], [443, 259], [260, 445], [445, 444], [444, 260], + [467, 342], [342, 445], [445, 467], [459, 458], [458, 250], [250, 459], + [289, 392], [392, 290], [290, 289], [290, 328], [328, 460], [460, 290], + [376, 433], [433, 435], [435, 376], [250, 290], [290, 392], [392, 250], + [411, 416], [416, 433], [433, 411], [341, 463], [463, 464], [464, 341], + [453, 464], [464, 465], [465, 453], [357, 465], [465, 412], [412, 357], + [343, 412], [412, 399], [399, 343], [360, 363], [363, 440], [440, 360], + [437, 399], [399, 456], [456, 437], [420, 456], [456, 363], [363, 420], + [401, 435], [435, 288], [288, 401], [372, 383], [383, 353], [353, 372], + [339, 255], [255, 249], [249, 339], [448, 261], [261, 255], [255, 448], + [133, 243], [243, 190], [190, 133], [133, 155], [155, 112], [112, 133], + [33, 246], [246, 247], [247, 33], [33, 130], [130, 25], [25, 33], + [398, 384], [384, 286], [286, 398], [362, 398], [398, 414], [414, 362], + [362, 463], [463, 341], [341, 362], [263, 359], [359, 467], [467, 263], + [263, 249], [249, 255], [255, 263], [466, 467], [467, 260], [260, 466], + [75, 60], [60, 166], [166, 75], [238, 239], [239, 79], [79, 238], + [162, 127], [127, 139], [139, 162], [72, 11], [11, 37], [37, 72], + [121, 232], [232, 120], [120, 121], [73, 72], [72, 39], [39, 73], + [114, 128], [128, 47], [47, 114], [233, 232], [232, 128], [128, 233], + [103, 104], [104, 67], [67, 103], [152, 175], [175, 148], [148, 152], + [119, 118], [118, 101], [101, 119], [74, 73], [73, 40], [40, 74], + [107, 9], [9, 108], [108, 107], [49, 48], [48, 131], [131, 49], + [32, 194], [194, 211], [211, 32], [184, 74], [74, 185], [185, 184], + [191, 80], [80, 183], [183, 191], [185, 40], [40, 186], [186, 185], + [119, 230], [230, 118], [118, 119], [210, 202], [202, 214], [214, 210], + [84, 83], [83, 17], [17, 84], [77, 76], [76, 146], [146, 77], + [161, 160], [160, 30], [30, 161], [190, 56], [56, 173], [173, 190], + [182, 106], [106, 194], [194, 182], [138, 135], [135, 192], [192, 138], + [129, 203], [203, 98], [98, 129], [54, 21], [21, 68], [68, 54], + [5, 51], [51, 4], [4, 5], [145, 144], [144, 23], [23, 145], + [90, 77], [77, 91], [91, 90], [207, 205], [205, 187], [187, 207], + [83, 201], [201, 18], [18, 83], [181, 91], [91, 182], [182, 181], + [180, 90], [90, 181], [181, 180], [16, 85], [85, 17], [17, 16], + [205, 206], [206, 36], [36, 205], [176, 148], [148, 140], [140, 176], + [165, 92], [92, 39], [39, 165], [245, 193], [193, 244], [244, 245], + [27, 159], [159, 28], [28, 27], [30, 247], [247, 161], [161, 30], + [174, 236], [236, 196], [196, 174], [103, 54], [54, 104], [104, 103], + [55, 193], [193, 8], [8, 55], [111, 117], [117, 31], [31, 111], + [221, 189], [189, 55], [55, 221], [240, 98], [98, 99], [99, 240], + [142, 
126], [126, 100], [100, 142], [219, 166], [166, 218], [218, 219], + [112, 155], [155, 26], [26, 112], [198, 209], [209, 131], [131, 198], + [169, 135], [135, 150], [150, 169], [114, 47], [47, 217], [217, 114], + [224, 223], [223, 53], [53, 224], [220, 45], [45, 134], [134, 220], + [32, 211], [211, 140], [140, 32], [109, 67], [67, 108], [108, 109], + [146, 43], [43, 91], [91, 146], [231, 230], [230, 120], [120, 231], + [113, 226], [226, 247], [247, 113], [105, 63], [63, 52], [52, 105], + [241, 238], [238, 242], [242, 241], [124, 46], [46, 156], [156, 124], + [95, 78], [78, 96], [96, 95], [70, 46], [46, 63], [63, 70], + [116, 143], [143, 227], [227, 116], [116, 123], [123, 111], [111, 116], + [1, 44], [44, 19], [19, 1], [3, 236], [236, 51], [51, 3], + [207, 216], [216, 205], [205, 207], [26, 154], [154, 22], [22, 26], + [165, 39], [39, 167], [167, 165], [199, 200], [200, 208], [208, 199], + [101, 36], [36, 100], [100, 101], [43, 57], [57, 202], [202, 43], + [242, 20], [20, 99], [99, 242], [56, 28], [28, 157], [157, 56], + [124, 35], [35, 113], [113, 124], [29, 160], [160, 27], [27, 29], + [211, 204], [204, 210], [210, 211], [124, 113], [113, 46], [46, 124], + [106, 43], [43, 204], [204, 106], [96, 62], [62, 77], [77, 96], + [227, 137], [137, 116], [116, 227], [73, 41], [41, 72], [72, 73], + [36, 203], [203, 142], [142, 36], [235, 64], [64, 240], [240, 235], + [48, 49], [49, 64], [64, 48], [42, 41], [41, 74], [74, 42], + [214, 212], [212, 207], [207, 214], [183, 42], [42, 184], [184, 183], + [210, 169], [169, 211], [211, 210], [140, 170], [170, 176], [176, 140], + [104, 105], [105, 69], [69, 104], [193, 122], [122, 168], [168, 193], + [50, 123], [123, 187], [187, 50], [89, 96], [96, 90], [90, 89], + [66, 65], [65, 107], [107, 66], [179, 89], [89, 180], [180, 179], + [119, 101], [101, 120], [120, 119], [68, 63], [63, 104], [104, 68], + [234, 93], [93, 227], [227, 234], [16, 15], [15, 85], [85, 16], + [209, 129], [129, 49], [49, 209], [15, 14], [14, 86], [86, 15], + [107, 55], [55, 9], [9, 107], [120, 100], [100, 121], [121, 120], + [153, 145], [145, 22], [22, 153], [178, 88], [88, 179], [179, 178], + [197, 6], [6, 196], [196, 197], [89, 88], [88, 96], [96, 89], + [135, 138], [138, 136], [136, 135], [138, 215], [215, 172], [172, 138], + [218, 115], [115, 219], [219, 218], [41, 42], [42, 81], [81, 41], + [5, 195], [195, 51], [51, 5], [57, 43], [43, 61], [61, 57], + [208, 171], [171, 199], [199, 208], [41, 81], [81, 38], [38, 41], + [224, 53], [53, 225], [225, 224], [24, 144], [144, 110], [110, 24], + [105, 52], [52, 66], [66, 105], [118, 229], [229, 117], [117, 118], + [227, 34], [34, 234], [234, 227], [66, 107], [107, 69], [69, 66], + [10, 109], [109, 151], [151, 10], [219, 48], [48, 235], [235, 219], + [183, 62], [62, 191], [191, 183], [142, 129], [129, 126], [126, 142], + [116, 111], [111, 143], [143, 116], [118, 117], [117, 50], [50, 118], + [223, 222], [222, 52], [52, 223], [94, 19], [19, 141], [141, 94], + [222, 221], [221, 65], [65, 222], [196, 3], [3, 197], [197, 196], + [45, 220], [220, 44], [44, 45], [156, 70], [70, 139], [139, 156], + [188, 122], [122, 245], [245, 188], [139, 71], [71, 162], [162, 139], + [149, 170], [170, 150], [150, 149], [122, 188], [188, 196], [196, 122], + [206, 216], [216, 92], [92, 206], [164, 2], [2, 167], [167, 164], + [242, 141], [141, 241], [241, 242], [0, 164], [164, 37], [37, 0], + [11, 72], [72, 12], [12, 11], [12, 38], [38, 13], [13, 12], + [70, 63], [63, 71], [71, 70], [31, 226], [226, 111], [111, 31], + [36, 101], [101, 205], [205, 36], [203, 206], 
[206, 165], [165, 203], + [126, 209], [209, 217], [217, 126], [98, 165], [165, 97], [97, 98], + [237, 220], [220, 218], [218, 237], [237, 239], [239, 241], [241, 237], + [210, 214], [214, 169], [169, 210], [140, 171], [171, 32], [32, 140], + [241, 125], [125, 237], [237, 241], [179, 86], [86, 178], [178, 179], + [180, 85], [85, 179], [179, 180], [181, 84], [84, 180], [180, 181], + [182, 83], [83, 181], [181, 182], [194, 201], [201, 182], [182, 194], + [177, 137], [137, 132], [132, 177], [184, 76], [76, 183], [183, 184], + [185, 61], [61, 184], [184, 185], [186, 57], [57, 185], [185, 186], + [216, 212], [212, 186], [186, 216], [192, 214], [214, 187], [187, 192], + [139, 34], [34, 156], [156, 139], [218, 79], [79, 237], [237, 218], + [147, 123], [123, 177], [177, 147], [45, 44], [44, 4], [4, 45], + [208, 201], [201, 32], [32, 208], [98, 64], [64, 129], [129, 98], + [192, 213], [213, 138], [138, 192], [235, 59], [59, 219], [219, 235], + [141, 242], [242, 97], [97, 141], [97, 2], [2, 141], [141, 97], + [240, 75], [75, 235], [235, 240], [229, 24], [24, 228], [228, 229], + [31, 25], [25, 226], [226, 31], [230, 23], [23, 229], [229, 230], + [231, 22], [22, 230], [230, 231], [232, 26], [26, 231], [231, 232], + [233, 112], [112, 232], [232, 233], [244, 189], [189, 243], [243, 244], + [189, 221], [221, 190], [190, 189], [222, 28], [28, 221], [221, 222], + [223, 27], [27, 222], [222, 223], [224, 29], [29, 223], [223, 224], + [225, 30], [30, 224], [224, 225], [113, 247], [247, 225], [225, 113], + [99, 60], [60, 240], [240, 99], [213, 147], [147, 215], [215, 213], + [60, 20], [20, 166], [166, 60], [192, 187], [187, 213], [213, 192], + [243, 112], [112, 244], [244, 243], [244, 233], [233, 245], [245, 244], + [245, 128], [128, 188], [188, 245], [188, 114], [114, 174], [174, 188], + [134, 131], [131, 220], [220, 134], [174, 217], [217, 236], [236, 174], + [236, 198], [198, 134], [134, 236], [215, 177], [177, 58], [58, 215], + [156, 143], [143, 124], [124, 156], [25, 110], [110, 7], [7, 25], + [31, 228], [228, 25], [25, 31], [264, 356], [356, 368], [368, 264], + [0, 11], [11, 267], [267, 0], [451, 452], [452, 349], [349, 451], + [267, 302], [302, 269], [269, 267], [350, 357], [357, 277], [277, 350], + [350, 452], [452, 357], [357, 350], [299, 333], [333, 297], [297, 299], + [396, 175], [175, 377], [377, 396], [280, 347], [347, 330], [330, 280], + [269, 303], [303, 270], [270, 269], [151, 9], [9, 337], [337, 151], + [344, 278], [278, 360], [360, 344], [424, 418], [418, 431], [431, 424], + [270, 304], [304, 409], [409, 270], [272, 310], [310, 407], [407, 272], + [322, 270], [270, 410], [410, 322], [449, 450], [450, 347], [347, 449], + [432, 422], [422, 434], [434, 432], [18, 313], [313, 17], [17, 18], + [291, 306], [306, 375], [375, 291], [259, 387], [387, 260], [260, 259], + [424, 335], [335, 418], [418, 424], [434, 364], [364, 416], [416, 434], + [391, 423], [423, 327], [327, 391], [301, 251], [251, 298], [298, 301], + [275, 281], [281, 4], [4, 275], [254, 373], [373, 253], [253, 254], + [375, 307], [307, 321], [321, 375], [280, 425], [425, 411], [411, 280], + [200, 421], [421, 18], [18, 200], [335, 321], [321, 406], [406, 335], + [321, 320], [320, 405], [405, 321], [314, 315], [315, 17], [17, 314], + [423, 426], [426, 266], [266, 423], [396, 377], [377, 369], [369, 396], + [270, 322], [322, 269], [269, 270], [413, 417], [417, 464], [464, 413], + [385, 386], [386, 258], [258, 385], [248, 456], [456, 419], [419, 248], + [298, 284], [284, 333], [333, 298], [168, 417], [417, 8], [8, 168], + [448, 346], [346, 
261], [261, 448], [417, 413], [413, 285], [285, 417], + [326, 327], [327, 328], [328, 326], [277, 355], [355, 329], [329, 277], + [309, 392], [392, 438], [438, 309], [381, 382], [382, 256], [256, 381], + [279, 429], [429, 360], [360, 279], [365, 364], [364, 379], [379, 365], + [355, 277], [277, 437], [437, 355], [282, 443], [443, 283], [283, 282], + [281, 275], [275, 363], [363, 281], [395, 431], [431, 369], [369, 395], + [299, 297], [297, 337], [337, 299], [335, 273], [273, 321], [321, 335], + [348, 450], [450, 349], [349, 348], [359, 446], [446, 467], [467, 359], + [283, 293], [293, 282], [282, 283], [250, 458], [458, 462], [462, 250], + [300, 276], [276, 383], [383, 300], [292, 308], [308, 325], [325, 292], + [283, 276], [276, 293], [293, 283], [264, 372], [372, 447], [447, 264], + [346, 352], [352, 340], [340, 346], [354, 274], [274, 19], [19, 354], + [363, 456], [456, 281], [281, 363], [426, 436], [436, 425], [425, 426], + [380, 381], [381, 252], [252, 380], [267, 269], [269, 393], [393, 267], + [421, 200], [200, 428], [428, 421], [371, 266], [266, 329], [329, 371], + [432, 287], [287, 422], [422, 432], [290, 250], [250, 328], [328, 290], + [385, 258], [258, 384], [384, 385], [446, 265], [265, 342], [342, 446], + [386, 387], [387, 257], [257, 386], [422, 424], [424, 430], [430, 422], + [445, 342], [342, 276], [276, 445], [422, 273], [273, 424], [424, 422], + [306, 292], [292, 307], [307, 306], [352, 366], [366, 345], [345, 352], + [268, 271], [271, 302], [302, 268], [358, 423], [423, 371], [371, 358], + [327, 294], [294, 460], [460, 327], [331, 279], [279, 294], [294, 331], + [303, 271], [271, 304], [304, 303], [436, 432], [432, 427], [427, 436], + [304, 272], [272, 408], [408, 304], [395, 394], [394, 431], [431, 395], + [378, 395], [395, 400], [400, 378], [296, 334], [334, 299], [299, 296], + [6, 351], [351, 168], [168, 6], [376, 352], [352, 411], [411, 376], + [307, 325], [325, 320], [320, 307], [285, 295], [295, 336], [336, 285], + [320, 319], [319, 404], [404, 320], [329, 330], [330, 349], [349, 329], + [334, 293], [293, 333], [333, 334], [366, 323], [323, 447], [447, 366], + [316, 15], [15, 315], [315, 316], [331, 358], [358, 279], [279, 331], + [317, 14], [14, 316], [316, 317], [8, 285], [285, 9], [9, 8], + [277, 329], [329, 350], [350, 277], [253, 374], [374, 252], [252, 253], + [319, 318], [318, 403], [403, 319], [351, 6], [6, 419], [419, 351], + [324, 318], [318, 325], [325, 324], [397, 367], [367, 365], [365, 397], + [288, 435], [435, 397], [397, 288], [278, 344], [344, 439], [439, 278], + [310, 272], [272, 311], [311, 310], [248, 195], [195, 281], [281, 248], + [375, 273], [273, 291], [291, 375], [175, 396], [396, 199], [199, 175], + [312, 311], [311, 268], [268, 312], [276, 283], [283, 445], [445, 276], + [390, 373], [373, 339], [339, 390], [295, 282], [282, 296], [296, 295], + [448, 449], [449, 346], [346, 448], [356, 264], [264, 454], [454, 356], + [337, 336], [336, 299], [299, 337], [337, 338], [338, 151], [151, 337], + [294, 278], [278, 455], [455, 294], [308, 292], [292, 415], [415, 308], + [429, 358], [358, 355], [355, 429], [265, 340], [340, 372], [372, 265], + [352, 346], [346, 280], [280, 352], [295, 442], [442, 282], [282, 295], + [354, 19], [19, 370], [370, 354], [285, 441], [441, 295], [295, 285], + [195, 248], [248, 197], [197, 195], [457, 440], [440, 274], [274, 457], + [301, 300], [300, 368], [368, 301], [417, 351], [351, 465], [465, 417], + [251, 301], [301, 389], [389, 251], [394, 395], [395, 379], [379, 394], + [399, 412], [412, 419], [419, 399], [410, 
436], [436, 322], [322, 410], + [326, 2], [2, 393], [393, 326], [354, 370], [370, 461], [461, 354], + [393, 164], [164, 267], [267, 393], [268, 302], [302, 12], [12, 268], + [312, 268], [268, 13], [13, 312], [298, 293], [293, 301], [301, 298], + [265, 446], [446, 340], [340, 265], [280, 330], [330, 425], [425, 280], + [322, 426], [426, 391], [391, 322], [420, 429], [429, 437], [437, 420], + [393, 391], [391, 326], [326, 393], [344, 440], [440, 438], [438, 344], + [458, 459], [459, 461], [461, 458], [364, 434], [434, 394], [394, 364], + [428, 396], [396, 262], [262, 428], [274, 354], [354, 457], [457, 274], + [317, 316], [316, 402], [402, 317], [316, 315], [315, 403], [403, 316], + [315, 314], [314, 404], [404, 315], [314, 313], [313, 405], [405, 314], + [313, 421], [421, 406], [406, 313], [323, 366], [366, 361], [361, 323], + [292, 306], [306, 407], [407, 292], [306, 291], [291, 408], [408, 306], + [291, 287], [287, 409], [409, 291], [287, 432], [432, 410], [410, 287], + [427, 434], [434, 411], [411, 427], [372, 264], [264, 383], [383, 372], + [459, 309], [309, 457], [457, 459], [366, 352], [352, 401], [401, 366], + [1, 274], [274, 4], [4, 1], [418, 421], [421, 262], [262, 418], + [331, 294], [294, 358], [358, 331], [435, 433], [433, 367], [367, 435], + [392, 289], [289, 439], [439, 392], [328, 462], [462, 326], [326, 328], + [94, 2], [2, 370], [370, 94], [289, 305], [305, 455], [455, 289], + [339, 254], [254, 448], [448, 339], [359, 255], [255, 446], [446, 359], + [254, 253], [253, 449], [449, 254], [253, 252], [252, 450], [450, 253], + [252, 256], [256, 451], [451, 252], [256, 341], [341, 452], [452, 256], + [414, 413], [413, 463], [463, 414], [286, 441], [441, 414], [414, 286], + [286, 258], [258, 441], [441, 286], [258, 257], [257, 442], [442, 258], + [257, 259], [259, 443], [443, 257], [259, 260], [260, 444], [444, 259], + [260, 467], [467, 445], [445, 260], [309, 459], [459, 250], [250, 309], + [305, 289], [289, 290], [290, 305], [305, 290], [290, 460], [460, 305], + [401, 376], [376, 435], [435, 401], [309, 250], [250, 392], [392, 309], + [376, 411], [411, 433], [433, 376], [453, 341], [341, 464], [464, 453], + [357, 453], [453, 465], [465, 357], [343, 357], [357, 412], [412, 343], + [437, 343], [343, 399], [399, 437], [344, 360], [360, 440], [440, 344], + [420, 437], [437, 456], [456, 420], [360, 420], [420, 363], [363, 360], + [361, 401], [401, 288], [288, 361], [265, 372], [372, 353], [353, 265], + [390, 339], [339, 249], [249, 390], [339, 448], [448, 255], [255, 339]])) diff --git a/client/src/hooks/useFaceLandmarkDetection.tsx b/client/src/hooks/useFaceLandmarkDetection.tsx new file mode 100644 index 0000000000000000000000000000000000000000..ace35a6cc0ca01e55c1aa39948cb23227f9d0aab --- /dev/null +++ b/client/src/hooks/useFaceLandmarkDetection.tsx @@ -0,0 +1,632 @@ +import React, { useState, useEffect, useRef, useCallback } from 'react'; +import * as vision from '@mediapipe/tasks-vision'; + +import { facePoke } from '@/lib/facePoke'; +import { useMainStore } from './useMainStore'; +import useThrottledCallback from 'beautiful-react-hooks/useThrottledCallback'; + +import { landmarkGroups, FACEMESH_LIPS, FACEMESH_LEFT_EYE, FACEMESH_LEFT_EYEBROW, FACEMESH_RIGHT_EYE, FACEMESH_RIGHT_EYEBROW, FACEMESH_FACE_OVAL } from './landmarks'; + +// New types for improved type safety +export type LandmarkGroup = 'lips' | 'leftEye' | 'leftEyebrow' | 'rightEye' | 'rightEyebrow' | 'faceOval' | 'background'; +export type LandmarkCenter = { x: number; y: number; z: number }; +export type 
ClosestLandmark = { group: LandmarkGroup; distance: number; vector: { x: number; y: number; z: number } }; + +export type MediaPipeResources = { + faceLandmarker: vision.FaceLandmarker | null; + drawingUtils: vision.DrawingUtils | null; +}; + +export function useFaceLandmarkDetection() { + const error = useMainStore(s => s.error); + const setError = useMainStore(s => s.setError); + const imageFile = useMainStore(s => s.imageFile); + const setImageFile = useMainStore(s => s.setImageFile); + const originalImage = useMainStore(s => s.originalImage); + const originalImageHash = useMainStore(s => s.originalImageHash); + const setOriginalImageHash = useMainStore(s => s.setOriginalImageHash); + const previewImage = useMainStore(s => s.previewImage); + const setPreviewImage = useMainStore(s => s.setPreviewImage); + const resetImage = useMainStore(s => s.resetImage); + + ;(window as any).debugJuju = useMainStore; + //////////////////////////////////////////////////////////////////////// + // ok so apparently I cannot vary the latency, or else there is a bug + // const averageLatency = useMainStore(s => s.averageLatency); + const averageLatency = 220 + //////////////////////////////////////////////////////////////////////// + + // State for face detection + const [faceLandmarks, setFaceLandmarks] = useState([]); + const [isMediaPipeReady, setIsMediaPipeReady] = useState(false); + const [isDrawingUtilsReady, setIsDrawingUtilsReady] = useState(false); + const [blendShapes, setBlendShapes] = useState([]); + + // State for mouse interaction + const [dragStart, setDragStart] = useState<{ x: number; y: number } | null>(null); + const [dragEnd, setDragEnd] = useState<{ x: number; y: number } | null>(null); + + const [isDragging, setIsDragging] = useState(false); + const [isWaitingForResponse, setIsWaitingForResponse] = useState(false); + const dragStartRef = useRef<{ x: number; y: number } | null>(null); + const currentMousePosRef = useRef<{ x: number; y: number } | null>(null); + const lastModifiedImageHashRef = useRef(null); + + const [currentLandmark, setCurrentLandmark] = useState(null); + const [previousLandmark, setPreviousLandmark] = useState(null); + const [currentOpacity, setCurrentOpacity] = useState(0); + const [previousOpacity, setPreviousOpacity] = useState(0); + + const [isHovering, setIsHovering] = useState(false); + + // Refs + const canvasRef = useRef(null); + const mediaPipeRef = useRef({ + faceLandmarker: null, + drawingUtils: null, + }); + + const setActiveLandmark = useCallback((newLandmark: ClosestLandmark | undefined) => { + //if (newLandmark && (!currentLandmark || newLandmark.group !== currentLandmark.group)) { + setPreviousLandmark(currentLandmark || null); + setCurrentLandmark(newLandmark || null); + setCurrentOpacity(0); + setPreviousOpacity(1); + //} + }, [currentLandmark, setPreviousLandmark, setCurrentLandmark, setCurrentOpacity, setPreviousOpacity]); + + // Initialize MediaPipe + useEffect(() => { + console.log('Initializing MediaPipe...'); + let isMounted = true; + + const initializeMediaPipe = async () => { + const { FaceLandmarker, FilesetResolver, DrawingUtils } = vision; + + try { + console.log('Initializing FilesetResolver...'); + const filesetResolver = await FilesetResolver.forVisionTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@0.10.3/wasm" + ); + + console.log('Creating FaceLandmarker...'); + const faceLandmarker = await FaceLandmarker.createFromOptions(filesetResolver, { + baseOptions: { + modelAssetPath: 
`https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task`, + delegate: "GPU" + }, + outputFaceBlendshapes: true, + runningMode: "IMAGE", + numFaces: 1 + }); + + if (isMounted) { + console.log('FaceLandmarker created successfully.'); + mediaPipeRef.current.faceLandmarker = faceLandmarker; + setIsMediaPipeReady(true); + } else { + faceLandmarker.close(); + } + } catch (error) { + console.error('Error during MediaPipe initialization:', error); + setError('Failed to initialize face detection. Please try refreshing the page.'); + } + }; + + initializeMediaPipe(); + + + return () => { + isMounted = false; + if (mediaPipeRef.current.faceLandmarker) { + mediaPipeRef.current.faceLandmarker.close(); + } + }; + }, []); + + // New state for storing landmark centers + const [landmarkCenters, setLandmarkCenters] = useState>({} as Record); + + // Function to compute the center of each landmark group + const computeLandmarkCenters = useCallback((landmarks: vision.NormalizedLandmark[]) => { + const centers: Record = {} as Record; + + const computeGroupCenter = (group: Readonly>): LandmarkCenter => { + let sumX = 0, sumY = 0, sumZ = 0, count = 0; + group.forEach(([index]) => { + if (landmarks[index]) { + sumX += landmarks[index].x; + sumY += landmarks[index].y; + sumZ += landmarks[index].z || 0; + count++; + } + }); + return { x: sumX / count, y: sumY / count, z: sumZ / count }; + }; + + centers.lips = computeGroupCenter(FACEMESH_LIPS); + centers.leftEye = computeGroupCenter(FACEMESH_LEFT_EYE); + centers.leftEyebrow = computeGroupCenter(FACEMESH_LEFT_EYEBROW); + centers.rightEye = computeGroupCenter(FACEMESH_RIGHT_EYE); + centers.rightEyebrow = computeGroupCenter(FACEMESH_RIGHT_EYEBROW); + centers.faceOval = computeGroupCenter(FACEMESH_FACE_OVAL); + centers.background = { x: 0.5, y: 0.5, z: 0 }; + + setLandmarkCenters(centers); + // console.log('Landmark centers computed:', centers); + }, []); + + // Function to find the closest landmark to the mouse position + const findClosestLandmark = useCallback((mouseX: number, mouseY: number, isGroup?: LandmarkGroup): ClosestLandmark => { + const defaultLandmark: ClosestLandmark = { + group: 'background', + distance: 0, + vector: { + x: mouseX, + y: mouseY, + z: 0 + } + } + + if (Object.keys(landmarkCenters).length === 0) { + console.warn('Landmark centers not computed yet'); + return defaultLandmark; + } + + let closestGroup: LandmarkGroup | null = null; + let minDistance = Infinity; + let closestVector = { x: 0, y: 0, z: 0 }; + let faceOvalDistance = Infinity; + let faceOvalVector = { x: 0, y: 0, z: 0 }; + + Object.entries(landmarkCenters).forEach(([group, center]) => { + const dx = mouseX - center.x; + const dy = mouseY - center.y; + const distance = Math.sqrt(dx * dx + dy * dy); + + if (group === 'faceOval') { + faceOvalDistance = distance; + faceOvalVector = { x: dx, y: dy, z: 0 }; + } + + // filter to keep the group if it is belonging to `ofGroup` + if (isGroup) { + if (group !== isGroup) { + return + } + } + + if (distance < minDistance) { + minDistance = distance; + closestGroup = group as LandmarkGroup; + closestVector = { x: dx, y: dy, z: 0 }; // Z is 0 as mouse interaction is 2D + } + }); + + // Fallback to faceOval if no group found or distance is too large + if (minDistance > 0.05) { + // console.log('Distance is too high, so we use the faceOval group'); + closestGroup = 'background'; + minDistance = faceOvalDistance; + closestVector = faceOvalVector; + } + + if (closestGroup) { + // 
console.log(`Closest landmark: ${closestGroup}, distance: ${minDistance.toFixed(4)}`); + return { group: closestGroup, distance: minDistance, vector: closestVector }; + } else { + // console.log('No group found, returning fallback'); + return defaultLandmark + } + }, [landmarkCenters]); + + // Detect face landmarks + const detectFaceLandmarks = useCallback(async (imageDataUrl: string) => { + // console.log('Attempting to detect face landmarks...'); + if (!isMediaPipeReady) { + console.log('MediaPipe not ready. Skipping detection.'); + return; + } + + const faceLandmarker = mediaPipeRef.current.faceLandmarker; + + if (!faceLandmarker) { + console.error('FaceLandmarker is not initialized.'); + return; + } + + const drawingUtils = mediaPipeRef.current.drawingUtils; + + const image = new Image(); + image.src = imageDataUrl; + await new Promise((resolve) => { image.onload = resolve; }); + + const faceLandmarkerResult = faceLandmarker.detect(image); + // console.log("Face landmarks detected:", faceLandmarkerResult); + + setFaceLandmarks(faceLandmarkerResult.faceLandmarks); + setBlendShapes(faceLandmarkerResult.faceBlendshapes || []); + + if (faceLandmarkerResult.faceLandmarks && faceLandmarkerResult.faceLandmarks[0]) { + computeLandmarkCenters(faceLandmarkerResult.faceLandmarks[0]); + } + + if (canvasRef.current && drawingUtils) { + drawLandmarks(faceLandmarkerResult.faceLandmarks[0], canvasRef.current, drawingUtils); + } + }, [isMediaPipeReady, isDrawingUtilsReady, computeLandmarkCenters]); + + const drawLandmarks = useCallback(( + landmarks: vision.NormalizedLandmark[], + canvas: HTMLCanvasElement, + drawingUtils: vision.DrawingUtils + ) => { + const ctx = canvas.getContext('2d'); + if (!ctx) return; + + ctx.clearRect(0, 0, canvas.width, canvas.height); + + if (canvasRef.current && previewImage) { + const img = new Image(); + img.onload = () => { + canvas.width = img.width; + canvas.height = img.height; + + const drawLandmarkGroup = (landmark: ClosestLandmark | null, opacity: number) => { + if (!landmark) return; + const connections = landmarkGroups[landmark.group]; + if (connections) { + ctx.globalAlpha = opacity; + drawingUtils.drawConnectors( + landmarks, + connections, + { color: 'orange', lineWidth: 4 } + ); + } + }; + + drawLandmarkGroup(previousLandmark, previousOpacity); + drawLandmarkGroup(currentLandmark, currentOpacity); + + ctx.globalAlpha = 1; + }; + img.src = previewImage; + } + }, [previewImage, currentLandmark, previousLandmark, currentOpacity, previousOpacity]); + + useEffect(() => { + if (isMediaPipeReady && isDrawingUtilsReady && faceLandmarks.length > 0 && canvasRef.current && mediaPipeRef.current.drawingUtils) { + drawLandmarks(faceLandmarks[0], canvasRef.current, mediaPipeRef.current.drawingUtils); + } + }, [isMediaPipeReady, isDrawingUtilsReady, faceLandmarks, currentLandmark, previousLandmark, currentOpacity, previousOpacity, drawLandmarks]); + useEffect(() => { + let animationFrame: number; + const animate = () => { + setCurrentOpacity((prev) => Math.min(prev + 0.2, 1)); + setPreviousOpacity((prev) => Math.max(prev - 0.2, 0)); + + if (currentOpacity < 1 || previousOpacity > 0) { + animationFrame = requestAnimationFrame(animate); + } + }; + animationFrame = requestAnimationFrame(animate); + return () => cancelAnimationFrame(animationFrame); + }, [currentLandmark]); + + // Canvas ref callback + const canvasRefCallback = useCallback((node: HTMLCanvasElement | null) => { + if (node !== null) { + const ctx = node.getContext('2d'); + if (ctx) { + // Get device pixel ratio + 
const pixelRatio = window.devicePixelRatio || 1; + + // Scale canvas based on the pixel ratio + node.width = node.clientWidth * pixelRatio; + node.height = node.clientHeight * pixelRatio; + ctx.scale(pixelRatio, pixelRatio); + + mediaPipeRef.current.drawingUtils = new vision.DrawingUtils(ctx); + setIsDrawingUtilsReady(true); + } else { + console.error('Failed to get 2D context from canvas.'); + } + canvasRef.current = node; + } + }, []); + + + useEffect(() => { + if (!isMediaPipeReady) { + console.log('MediaPipe not ready. Skipping landmark detection.'); + return + } + if (!previewImage) { + console.log('Preview image not ready. Skipping landmark detection.'); + return + } + if (!isDrawingUtilsReady) { + console.log('DrawingUtils not ready. Skipping landmark detection.'); + return + } + detectFaceLandmarks(previewImage); + }, [isMediaPipeReady, isDrawingUtilsReady, previewImage]) + + + + const modifyImage = useCallback(({ landmark, vector }: { + landmark: ClosestLandmark + vector: { x: number; y: number; z: number } + }) => { + + const { + originalImage, + originalImageHash, + params: previousParams, + setParams, + setError + } = useMainStore.getState() + + + if (!originalImage) { + console.error('Image file or facePoke not available'); + return; + } + + const params = { + ...previousParams + } + + const minX = -0.50; + const maxX = 0.50; + const minY = -0.50; + const maxY = 0.50; + + // Function to map a value from one range to another + const mapRange = (value: number, inMin: number, inMax: number, outMin: number, outMax: number): number => { + return Math.min(outMax, Math.max(outMin, ((value - inMin) * (outMax - outMin)) / (inMax - inMin) + outMin)); + }; + + console.log("modifyImage:", { + originalImage, + originalImageHash, + landmark, + vector, + minX, + maxX, + minY, + maxY, + }) + + // Map landmarks to ImageModificationParams + switch (landmark.group) { + case 'leftEye': + case 'rightEye': + // eyebrow (min: -20, max: 5, default: 0) + const eyesMin = 210 + const eyesMax = 5 + params.eyes = mapRange(vector.x, minX, maxX, eyesMin, eyesMax); + + break; + case 'leftEyebrow': + case 'rightEyebrow': + // moving the mouse vertically for the eyebrow + // should make them up/down + // eyebrow (min: -10, max: 15, default: 0) + const eyebrowMin = -10 + const eyebrowMax = 15 + params.eyebrow = mapRange(vector.y, minY, maxY, eyebrowMin, eyebrowMax); + + break; + case 'lips': + // aaa (min: -30, max: 120, default: 0) + //const aaaMin = -30 + //const aaaMax = 120 + //params.aaa = mapRange(vector.x, minY, maxY, aaaMin, aaaMax); + + // eee (min: -20, max: 15, default: 0) + const eeeMin = -20 + const eeeMax = 15 + params.eee = mapRange(vector.y, minY, maxY, eeeMin, eeeMax); + + + // woo (min: -20, max: 15, default: 0) + const wooMin = -20 + const wooMax = 15 + params.woo = mapRange(vector.x, minX, maxX, wooMin, wooMax); + + break; + case 'faceOval': + // displacing the face horizontally by moving the mouse on the X axis + // should perform a yaw rotation + // rotate_roll (min: -20, max: 20, default: 0) + const rollMin = -40 + const rollMax = 40 + + // note: we invert the axis here + params.rotate_roll = mapRange(vector.x, minX, maxX, rollMin, rollMax); + break; + + case 'background': + // displacing the face horizontally by moving the mouse on the X axis + // should perform a yaw rotation + // rotate_yaw (min: -20, max: 20, default: 0) + const yawMin = -40 + const yawMax = 40 + + // note: we invert the axis here + params.rotate_yaw = mapRange(-vector.x, minX, maxX, yawMin, yawMax); + + // displacing 
the face vertically by moving the mouse on the Y axis + // should perform a pitch rotation + // rotate_pitch (min: -20, max: 20, default: 0) + const pitchMin = -40 + const pitchMax = 40 + params.rotate_pitch = mapRange(vector.y, minY, maxY, pitchMin, pitchMax); + break; + default: + return + } + + for (const [key, value] of Object.entries(params)) { + if (isNaN(value as any) || !isFinite(value as any)) { + console.log(`${key} is NaN, aborting`) + return + } + } + console.log(`PITCH=${params.rotate_pitch || 0}, YAW=${params.rotate_yaw || 0}, ROLL=${params.rotate_roll || 0}`); + + setParams(params) + try { + // For the first request or when the image file changes, send the full image + if (!lastModifiedImageHashRef.current || lastModifiedImageHashRef.current !== originalImageHash) { + lastModifiedImageHashRef.current = originalImageHash; + facePoke.modifyImage(originalImage, null, params); + } else { + // For subsequent requests, send only the hash + facePoke.modifyImage(null, lastModifiedImageHashRef.current, params); + } + } catch (error) { + // console.error('Error modifying image:', error); + setError('Failed to modify image'); + } + }, []); + + // this is throttled by our average latency + const modifyImageWithRateLimit = useThrottledCallback((params: { + landmark: ClosestLandmark + vector: { x: number; y: number; z: number } + }) => { + modifyImage(params); + }, [modifyImage], averageLatency); + + const handleMouseEnter = useCallback(() => { + setIsHovering(true); + }, []); + + const handleMouseLeave = useCallback(() => { + setIsHovering(false); + }, []); + + // Update mouse event handlers + const handleMouseDown = useCallback((event: React.MouseEvent) => { + if (!canvasRef.current) return; + + const rect = canvasRef.current.getBoundingClientRect(); + const x = (event.clientX - rect.left) / rect.width; + const y = (event.clientY - rect.top) / rect.height; + + const landmark = findClosestLandmark(x, y); + console.log(`Mouse down on ${landmark.group}`); + setActiveLandmark(landmark); + setDragStart({ x, y }); + dragStartRef.current = { x, y }; + }, [findClosestLandmark, setActiveLandmark, setDragStart]); + + const handleMouseMove = useCallback((event: React.MouseEvent) => { + if (!canvasRef.current) return; + + const rect = canvasRef.current.getBoundingClientRect(); + const x = (event.clientX - rect.left) / rect.width; + const y = (event.clientY - rect.top) / rect.height; + + // only send an API request to modify the image if we are actively dragging + if (dragStart && dragStartRef.current) { + + const landmark = findClosestLandmark(x, y, currentLandmark?.group); + + console.log(`Dragging mouse (was over ${currentLandmark?.group || 'nothing'}, now over ${landmark.group})`); + + // Compute the vector from the landmark center to the current mouse position + modifyImageWithRateLimit({ + landmark: currentLandmark || landmark, // this will still use the initially selected landmark + vector: { + x: x - landmarkCenters[landmark.group].x, + y: y - landmarkCenters[landmark.group].y, + z: 0 // Z is 0 as mouse interaction is 2D + } + }); + setIsDragging(true); + } else { + const landmark = findClosestLandmark(x, y); + + //console.log(`Moving mouse over ${landmark.group}`); + // console.log(`Simple mouse move over ${landmark.group}`); + + // we need to be careful here, we don't want to change the active + // landmark dynamically if we are busy dragging + + if (!currentLandmark || (currentLandmark?.group !== landmark?.group)) { + // console.log("setting activeLandmark to ", landmark); + 
setActiveLandmark(landmark); + } + setIsHovering(true); // Ensure hovering state is maintained during movement + } + }, [currentLandmark, dragStart, setIsHovering, setActiveLandmark, setIsDragging, modifyImageWithRateLimit, landmarkCenters]); + + const handleMouseUp = useCallback((event: React.MouseEvent) => { + if (!canvasRef.current) return; + + const rect = canvasRef.current.getBoundingClientRect(); + const x = (event.clientX - rect.left) / rect.width; + const y = (event.clientY - rect.top) / rect.height; + + // only send an API request to modify the image if we are actively dragging + if (dragStart && dragStartRef.current) { + + const landmark = findClosestLandmark(x, y, currentLandmark?.group); + + console.log(`Mouse up (was over ${currentLandmark?.group || 'nothing'}, now over ${landmark.group})`); + + // Compute the vector from the landmark center to the current mouse position + modifyImageWithRateLimit({ + landmark: currentLandmark || landmark, // this will still use the initially selected landmark + vector: { + x: x - landmarkCenters[landmark.group].x, + y: y - landmarkCenters[landmark.group].y, + z: 0 // Z is 0 as mouse interaction is 2D + } + }); + } + + setIsDragging(false); + dragStartRef.current = null; + setActiveLandmark(undefined); + }, [currentLandmark, isDragging, modifyImageWithRateLimit, findClosestLandmark, setActiveLandmark, landmarkCenters, modifyImageWithRateLimit, setIsDragging]); + + useEffect(() => { + facePoke.setOnModifiedImage((image: string, image_hash: string) => { + if (image) { + setPreviewImage(image); + } + setOriginalImageHash(image_hash); + lastModifiedImageHashRef.current = image_hash; + }); + }, [setPreviewImage, setOriginalImageHash]); + + return { + canvasRef, + canvasRefCallback, + mediaPipeRef, + faceLandmarks, + isMediaPipeReady, + isDrawingUtilsReady, + blendShapes, + + //dragStart, + //setDragStart, + //dragEnd, + //setDragEnd, + setFaceLandmarks, + setBlendShapes, + + handleMouseDown, + handleMouseUp, + handleMouseMove, + handleMouseEnter, + handleMouseLeave, + + currentLandmark, + currentOpacity, + } +} diff --git a/client/src/hooks/useFacePokeAPI.ts b/client/src/hooks/useFacePokeAPI.ts new file mode 100644 index 0000000000000000000000000000000000000000..5ebad7fc096da5b813df54d28ddeed9dd6c95936 --- /dev/null +++ b/client/src/hooks/useFacePokeAPI.ts @@ -0,0 +1,44 @@ +import { useEffect, useState } from "react"; + +import { facePoke } from "../lib/facePoke"; +import { useMainStore } from "./useMainStore"; + +export function useFacePokeAPI() { + + // State for FacePoke + const [status, setStatus] = useState(''); + const [isDebugMode, setIsDebugMode] = useState(false); + const [interruptMessage, setInterruptMessage] = useState(null); + + const [isLoading, setIsLoading] = useState(false); + + // Initialize FacePoke + useEffect(() => { + const urlParams = new URLSearchParams(window.location.search); + setIsDebugMode(urlParams.get('debug') === 'true'); + }, []); + + // Handle WebSocket interruptions + useEffect(() => { + const handleInterruption = (event: CustomEvent) => { + setInterruptMessage(event.detail.message); + }; + + window.addEventListener('websocketInterruption' as any, handleInterruption); + + return () => { + window.removeEventListener('websocketInterruption' as any, handleInterruption); + }; + }, []); + + return { + facePoke, + status, + setStatus, + isDebugMode, + setIsDebugMode, + interruptMessage, + isLoading, + setIsLoading, + } +} diff --git a/client/src/hooks/useMainStore.ts b/client/src/hooks/useMainStore.ts new file mode 
100644 index 0000000000000000000000000000000000000000..6c973a00c20f540c13d54b4a470b14f86c583f8d --- /dev/null +++ b/client/src/hooks/useMainStore.ts @@ -0,0 +1,58 @@ +import { create } from 'zustand' +import type { ClosestLandmark } from './useFaceLandmarkDetection' +import type { ImageModificationParams } from '@/lib/facePoke' + +interface ImageState { + error: string + imageFile: File | null + originalImage: string + previewImage: string + originalImageHash: string + minLatency: number + averageLatency: number + maxLatency: number + activeLandmark?: ClosestLandmark + params: Partial + setError: (error?: string) => void + setImageFile: (file: File | null) => void + setOriginalImage: (url: string) => void + setOriginalImageHash: (hash: string) => void + setPreviewImage: (url: string) => void + resetImage: () => void + setAverageLatency: (averageLatency: number) => void + setActiveLandmark: (activeLandmark?: ClosestLandmark) => void + setParams: (params: Partial) => void +} + +export const useMainStore = create((set, get) => ({ + error: '', + imageFile: null, + originalImage: '', + originalImageHash: '', + previewImage: '', + minLatency: 20, // min time between requests + averageLatency: 190, // this should be the average for most people + maxLatency: 4000, // max time between requests + activeLandmark: undefined, + params: {}, + setError: (error: string = '') => set({ error }), + setImageFile: (file) => set({ imageFile: file }), + setOriginalImage: (url) => set({ originalImage: url }), + setOriginalImageHash: (originalImageHash) => set({ originalImageHash }), + setPreviewImage: (url) => set({ previewImage: url }), + resetImage: () => { + const { originalImage } = get() + if (originalImage) { + set({ previewImage: originalImage }) + } + }, + setAverageLatency: (averageLatency: number) => set({ averageLatency }), + setActiveLandmark: (activeLandmark?: ClosestLandmark) => set({ activeLandmark }), + setParams: (params: Partial) => { + const {params: previousParams } = get() + set({ params: { + ...previousParams, + ...params + }}) + }, +})) diff --git a/client/src/index.tsx b/client/src/index.tsx new file mode 100644 index 0000000000000000000000000000000000000000..460767af6b3be1b6c158d7bd82971aab7349d245 --- /dev/null +++ b/client/src/index.tsx @@ -0,0 +1,6 @@ +import { createRoot } from 'react-dom/client'; + +import { App } from './app'; + +const root = createRoot(document.getElementById('root')!); +root.render(); diff --git a/client/src/layout.tsx b/client/src/layout.tsx new file mode 100644 index 0000000000000000000000000000000000000000..396a91da985257878db9ece6dbbe1649676233d4 --- /dev/null +++ b/client/src/layout.tsx @@ -0,0 +1,14 @@ +import React, { type ReactNode } from 'react'; + +export function Layout({ children }: { children: ReactNode }) { + return ( +
+ <div>
+ <div>
+ {children}
+ </div>
+ </div>
+ ); +} diff --git a/client/src/lib/circularBuffer.ts b/client/src/lib/circularBuffer.ts new file mode 100644 index 0000000000000000000000000000000000000000..d368caad7767242b494bbfc8baa21a88c8d097cd --- /dev/null +++ b/client/src/lib/circularBuffer.ts @@ -0,0 +1,31 @@ + + +/** + * Circular buffer for storing and managing response times. + */ +export class CircularBuffer { + private buffer: T[]; + private pointer: number; + + constructor(private capacity: number) { + this.buffer = new Array(capacity); + this.pointer = 0; + } + + /** + * Adds an item to the buffer, overwriting the oldest item if full. + * @param item - The item to add to the buffer. + */ + push(item: T): void { + this.buffer[this.pointer] = item; + this.pointer = (this.pointer + 1) % this.capacity; + } + + /** + * Retrieves all items currently in the buffer. + * @returns An array of all items in the buffer. + */ + getAll(): T[] { + return this.buffer.filter(item => item !== undefined); + } +} diff --git a/client/src/lib/convertImageToBase64.ts b/client/src/lib/convertImageToBase64.ts new file mode 100644 index 0000000000000000000000000000000000000000..9c260075dcebf75efc3ac4f596f038a1e339573f --- /dev/null +++ b/client/src/lib/convertImageToBase64.ts @@ -0,0 +1,19 @@ +export async function convertImageToBase64(imageFile: File): Promise { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + + reader.onload = () => { + if (typeof reader.result === 'string') { + resolve(reader.result); + } else { + reject(new Error('Failed to convert image to base64')); + } + }; + + reader.onerror = () => { + reject(new Error('Error reading file')); + }; + + reader.readAsDataURL(imageFile); + }); +} diff --git a/client/src/lib/facePoke.ts b/client/src/lib/facePoke.ts new file mode 100644 index 0000000000000000000000000000000000000000..14aaacacf46d5911b774b0d1e5e45ea2f8e3648e --- /dev/null +++ b/client/src/lib/facePoke.ts @@ -0,0 +1,398 @@ +import { v4 as uuidv4 } from 'uuid'; +import { CircularBuffer } from './circularBuffer'; +import { useMainStore } from '@/hooks/useMainStore'; + +/** + * Represents a tracked request with its UUID and timestamp. + */ +export interface TrackedRequest { + uuid: string; + timestamp: number; +} + +/** + * Represents the parameters for image modification. + */ +export interface ImageModificationParams { + eyes: number; + eyebrow: number; + wink: number; + pupil_x: number; + pupil_y: number; + aaa: number; + eee: number; + woo: number; + smile: number; + rotate_pitch: number; + rotate_yaw: number; + rotate_roll: number; +} + +/** + * Represents a message to modify an image. + */ +export interface ModifyImageMessage { + type: 'modify_image'; + image?: string; + image_hash?: string; + params: Partial; +} + + +/** + * Callback type for handling modified images. + */ +type OnModifiedImage = (image: string, image_hash: string) => void; + +/** + * Enum representing the different states of a WebSocket connection. 
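+ * The numeric values mirror the browser's WebSocket.readyState constants (CONNECTING=0, OPEN=1, CLOSING=2, CLOSED=3), which is what allows ws.readyState to be compared against these members directly.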
+ */ +enum WebSocketState { + CONNECTING = 0, + OPEN = 1, + CLOSING = 2, + CLOSED = 3 +} + +/** + * FacePoke class manages the WebSocket connection + */ +export class FacePoke { + private ws: WebSocket | null = null; + private readonly connectionId: string = uuidv4(); + private isUnloading: boolean = false; + private onModifiedImage: OnModifiedImage = () => {}; + private reconnectAttempts: number = 0; + private readonly maxReconnectAttempts: number = 5; + private readonly reconnectDelay: number = 5000; + private readonly eventListeners: Map> = new Map(); + + private requestTracker: Map = new Map(); + private responseTimeBuffer: CircularBuffer; + private readonly MAX_TRACKED_TIMES = 5; // Number of recent response times to track + + /** + * Creates an instance of FacePoke. + * Initializes the WebSocket connection. + */ + constructor() { + console.log(`[FacePoke] Initializing FacePoke instance with connection ID: ${this.connectionId}`); + this.initializeWebSocket(); + this.setupUnloadHandler(); + + this.responseTimeBuffer = new CircularBuffer(this.MAX_TRACKED_TIMES); + console.log(`[FacePoke] Initialized response time tracker with capacity: ${this.MAX_TRACKED_TIMES}`); + } + + + /** + * Generates a unique UUID for a request and starts tracking it. + * @returns The generated UUID for the request. + */ + private trackRequest(): string { + const uuid = uuidv4(); + this.requestTracker.set(uuid, { uuid, timestamp: Date.now() }); + // console.log(`[FacePoke] Started tracking request with UUID: ${uuid}`); + return uuid; + } + + /** + * Completes tracking for a request and updates response time statistics. + * @param uuid - The UUID of the completed request. + */ + private completeRequest(uuid: string): void { + const request = this.requestTracker.get(uuid); + if (request) { + const responseTime = Date.now() - request.timestamp; + this.responseTimeBuffer.push(responseTime); + this.requestTracker.delete(uuid); + this.updateThrottleTime(); + console.log(`[FacePoke] Completed request ${uuid}. Response time: ${responseTime}ms`); + } else { + console.warn(`[FacePoke] Attempted to complete unknown request: ${uuid}`); + } + } + + /** + * Calculates the average response time from recent requests. + * @returns The average response time in milliseconds. + */ + private calculateAverageResponseTime(): number { + const times = this.responseTimeBuffer.getAll(); + + const averageLatency = useMainStore.getState().averageLatency; + + if (times.length === 0) return averageLatency; + const sum = times.reduce((acc, time) => acc + time, 0); + return sum / times.length; + } + + /** + * Updates the throttle time based on recent response times. + */ + private updateThrottleTime(): void { + const { minLatency, maxLatency, averageLatency, setAverageLatency } = useMainStore.getState(); + const avgResponseTime = this.calculateAverageResponseTime(); + const newLatency = Math.min(minLatency, Math.max(minLatency, avgResponseTime)); + + if (newLatency !== averageLatency) { + setAverageLatency(newLatency) + console.log(`[FacePoke] Updated throttle time (latency is ${newLatency}ms)`); + } + } + + /** + * Sets the callback function for handling modified images. + * @param handler - The function to be called when a modified image is received. + */ + public setOnModifiedImage(handler: OnModifiedImage): void { + this.onModifiedImage = handler; + console.log(`[FacePoke] onModifiedImage handler set`); + } + + /** + * Starts or restarts the WebSocket connection. 
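+ * Safe to call repeatedly: a new connection is only initialized when no socket exists or the current one is not OPEN.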
+ */ + public async startWebSocket(): Promise { + console.log(`[FacePoke] Starting WebSocket connection.`); + if (!this.ws || this.ws.readyState !== WebSocketState.OPEN) { + await this.initializeWebSocket(); + } + } + + /** + * Initializes the WebSocket connection. + * Implements exponential backoff for reconnection attempts. + */ + private async initializeWebSocket(): Promise { + console.log(`[FacePoke][${this.connectionId}] Initializing WebSocket connection`); + + const connect = () => { + this.ws = new WebSocket(`wss://${window.location.host}/ws`); + + this.ws.onopen = this.handleWebSocketOpen.bind(this); + this.ws.onmessage = this.handleWebSocketMessage.bind(this); + this.ws.onclose = this.handleWebSocketClose.bind(this); + this.ws.onerror = this.handleWebSocketError.bind(this); + }; + + // const debouncedConnect = debounce(connect, this.reconnectDelay, { leading: true, trailing: false }); + + connect(); // Initial connection attempt + } + + /** + * Handles the WebSocket open event. + */ + private handleWebSocketOpen(): void { + console.log(`[FacePoke][${this.connectionId}] WebSocket connection opened`); + this.reconnectAttempts = 0; // Reset reconnect attempts on successful connection + this.emitEvent('websocketOpen'); + } + + // Update handleWebSocketMessage to complete request tracking + private handleWebSocketMessage(event: MessageEvent): void { + try { + const data = JSON.parse(event.data); + // console.log(`[FacePoke][${this.connectionId}] Received JSON data:`, data); + + if (data.uuid) { + this.completeRequest(data.uuid); + } + + if (data.type === 'modified_image') { + if (data?.image) { + this.onModifiedImage(data.image, data.image_hash); + } + } + + this.emitEvent('message', data); + } catch (error) { + console.error(`[FacePoke][${this.connectionId}] Error parsing WebSocket message:`, error); + } + } + + /** + * Handles WebSocket close events. + * Implements reconnection logic with exponential backoff. + * @param event - The CloseEvent containing close information. + */ + private handleWebSocketClose(event: CloseEvent): void { + if (event.wasClean) { + console.log(`[FacePoke][${this.connectionId}] WebSocket connection closed cleanly, code=${event.code}, reason=${event.reason}`); + } else { + console.warn(`[FacePoke][${this.connectionId}] WebSocket connection abruptly closed`); + } + + this.emitEvent('websocketClose', event); + + // Attempt to reconnect after a delay, unless the page is unloading or max attempts reached + if (!this.isUnloading && this.reconnectAttempts < this.maxReconnectAttempts) { + this.reconnectAttempts++; + const delay = Math.min(1000 * (2 ** this.reconnectAttempts), 30000); // Exponential backoff, max 30 seconds + console.log(`[FacePoke][${this.connectionId}] Attempting to reconnect in ${delay}ms (Attempt ${this.reconnectAttempts}/${this.maxReconnectAttempts})...`); + setTimeout(() => this.initializeWebSocket(), delay); + } else if (this.reconnectAttempts >= this.maxReconnectAttempts) { + console.error(`[FacePoke][${this.connectionId}] Max reconnect attempts reached. Please refresh the page.`); + this.emitEvent('maxReconnectAttemptsReached'); + } + } + + /** + * Handles WebSocket errors. + * @param error - The error event. + */ + private handleWebSocketError(error: Event): void { + console.error(`[FacePoke][${this.connectionId}] WebSocket error:`, error); + this.emitEvent('websocketError', error); + } + + /** + * Handles interruption messages from the server. + * @param message - The interruption message. 
+ */ + private handleInterruption(message: string): void { + console.warn(`[FacePoke] Interruption: ${message}`); + this.emitEvent('interruption', message); + } + + /** + * Toggles the microphone on or off. + * @param isOn - Whether to turn the microphone on (true) or off (false). + */ + public async toggleMicrophone(isOn: boolean): Promise { + console.log(`[FacePoke] Attempting to ${isOn ? 'start' : 'stop'} microphone`); + try { + if (isOn) { + await this.startMicrophone(); + } else { + this.stopMicrophone(); + } + this.emitEvent('microphoneToggled', isOn); + } catch (error) { + console.error(`[FacePoke] Error toggling microphone:`, error); + this.emitEvent('microphoneError', error); + throw error; + } + } + + +/** + * Cleans up resources and closes connections. + */ +public cleanup(): void { + console.log('[FacePoke] Starting cleanup process'); + if (this.ws) { + this.ws.close(); + this.ws = null; + } + this.eventListeners.clear(); + console.log('[FacePoke] Cleanup completed'); + this.emitEvent('cleanup'); +} + + /** + * Modifies an image based on the provided parameters + * @param image - The data-uri base64 image to modify. + * @param imageHash - The hash of the image to modify. + * @param params - The parameters for image modification. + */ + public modifyImage(image: string | null, imageHash: string | null, params: Partial): void { + try { + const message: ModifyImageMessage = { + type: 'modify_image', + params: params + }; + + if (image) { + message.image = image; + } else if (imageHash) { + message.image_hash = imageHash; + } else { + throw new Error('Either image or imageHash must be provided'); + } + + this.sendJsonMessage(message); + // console.log(`[FacePoke] Sent modify image request with UUID: ${uuid}`); + } catch (err) { + console.error(`[FacePoke] Failed to modify the image:`, err); + } + } + + /** + * Sends a JSON message through the WebSocket connection with request tracking. + * @param message - The message to send. + * @throws Error if the WebSocket is not open. + */ + private sendJsonMessage(message: T): void { + if (!this.ws || this.ws.readyState !== WebSocketState.OPEN) { + const error = new Error('WebSocket connection is not open'); + console.error('[FacePoke] Error sending JSON message:', error); + this.emitEvent('sendJsonMessageError', error); + throw error; + } + + const uuid = this.trackRequest(); + const messageWithUuid = { ...message, uuid }; + // console.log(`[FacePoke] Sending JSON message with UUID ${uuid}:`, messageWithUuid); + this.ws.send(JSON.stringify(messageWithUuid)); + } + +/** + * Sets up the unload handler to clean up resources when the page is unloading. + */ +private setupUnloadHandler(): void { + window.addEventListener('beforeunload', () => { + console.log('[FacePoke] Page is unloading, cleaning up resources'); + this.isUnloading = true; + if (this.ws) { + this.ws.close(1000, 'Page is unloading'); + } + this.cleanup(); + }); +} + +/** + * Adds an event listener for a specific event type. + * @param eventType - The type of event to listen for. + * @param listener - The function to be called when the event is emitted. + */ +public addEventListener(eventType: string, listener: Function): void { + if (!this.eventListeners.has(eventType)) { + this.eventListeners.set(eventType, new Set()); + } + this.eventListeners.get(eventType)!.add(listener); + console.log(`[FacePoke] Added event listener for '${eventType}'`); +} + +/** + * Removes an event listener for a specific event type. + * @param eventType - The type of event to remove the listener from. 
+ * @param listener - The function to be removed from the listeners. + */ +public removeEventListener(eventType: string, listener: Function): void { + const listeners = this.eventListeners.get(eventType); + if (listeners) { + listeners.delete(listener); + console.log(`[FacePoke] Removed event listener for '${eventType}'`); + } +} + +/** + * Emits an event to all registered listeners for that event type. + * @param eventType - The type of event to emit. + * @param data - Optional data to pass to the event listeners. + */ +private emitEvent(eventType: string, data?: any): void { + const listeners = this.eventListeners.get(eventType); + if (listeners) { + console.log(`[FacePoke] Emitting event '${eventType}' with data:`, data); + listeners.forEach(listener => listener(data)); + } +} +} + +/** +* Singleton instance of the FacePoke class. +*/ +export const facePoke = new FacePoke(); diff --git a/client/src/lib/throttle.ts b/client/src/lib/throttle.ts new file mode 100644 index 0000000000000000000000000000000000000000..70a7dd34ec33d85c0e8383a10dac86a9244dcfc4 --- /dev/null +++ b/client/src/lib/throttle.ts @@ -0,0 +1,32 @@ + +/** + * Custom throttle function that allows the first call to go through immediately + * and then limits subsequent calls. + * @param func - The function to throttle. + * @param limit - The minimum time between function calls in milliseconds. + * @returns A throttled version of the function. + */ +export function throttle any>(func: T, limit: number): T { + let lastCall = 0; + let timeoutId: NodeJS.Timer | null = null; + + return function (this: any, ...args: Parameters) { + const context = this; + const now = Date.now(); + + if (now - lastCall >= limit) { + if (timeoutId !== null) { + clearTimeout(timeoutId); + timeoutId = null; + } + lastCall = now; + return func.apply(context, args); + } else if (!timeoutId) { + timeoutId = setTimeout(() => { + lastCall = Date.now(); + timeoutId = null; + func.apply(context, args); + }, limit - (now - lastCall)); + } + } as T; +} diff --git a/client/src/lib/utils.ts b/client/src/lib/utils.ts new file mode 100644 index 0000000000000000000000000000000000000000..a4eb975048cbdbf4ea12e4c2dc55da20d7181486 --- /dev/null +++ b/client/src/lib/utils.ts @@ -0,0 +1,15 @@ +import { clsx, type ClassValue } from "clsx" +import { twMerge } from "tailwind-merge" + +export function cn(...inputs: ClassValue[]) { + return twMerge(clsx(inputs)) +} + +export function truncateFileName(fileName: string, maxLength: number = 16) { + if (fileName.length <= maxLength) return fileName; + + const start = fileName.slice(0, maxLength / 2 - 1); + const end = fileName.slice(-maxLength / 2 + 2); + + return `${start}...${end}`; +}; diff --git a/client/src/styles/globals.css b/client/src/styles/globals.css new file mode 100644 index 0000000000000000000000000000000000000000..95d8cab145debb5c4214ff5a53fb71977f58504b --- /dev/null +++ b/client/src/styles/globals.css @@ -0,0 +1,81 @@ +@tailwind base; +@tailwind components; +@tailwind utilities; + +@layer base { + :root { + --background: 0 0% 100%; + --foreground: 222.2 47.4% 11.2%; + + --muted: 210 40% 96.1%; + --muted-foreground: 215.4 16.3% 46.9%; + + --popover: 0 0% 100%; + --popover-foreground: 222.2 47.4% 11.2%; + + --border: 214.3 31.8% 91.4%; + --input: 214.3 31.8% 91.4%; + + --card: 0 0% 100%; + --card-foreground: 222.2 47.4% 11.2%; + + --primary: 222.2 47.4% 11.2%; + --primary-foreground: 210 40% 98%; + + --secondary: 210 40% 96.1%; + --secondary-foreground: 222.2 47.4% 11.2%; + + --accent: 210 40% 96.1%; + 
--accent-foreground: 222.2 47.4% 11.2%; + + --destructive: 0 100% 50%; + --destructive-foreground: 210 40% 98%; + + --ring: 215 20.2% 65.1%; + + --radius: 0.5rem; + } + + .dark { + --background: 224 71% 4%; + --foreground: 213 31% 91%; + + --muted: 223 47% 11%; + --muted-foreground: 215.4 16.3% 56.9%; + + --accent: 216 34% 17%; + --accent-foreground: 210 40% 98%; + + --popover: 224 71% 4%; + --popover-foreground: 215 20.2% 65.1%; + + --border: 216 34% 17%; + --input: 216 34% 17%; + + --card: 224 71% 4%; + --card-foreground: 213 31% 91%; + + --primary: 210 40% 98%; + --primary-foreground: 222.2 47.4% 1.2%; + + --secondary: 222.2 47.4% 11.2%; + --secondary-foreground: 210 40% 98%; + + --destructive: 0 63% 31%; + --destructive-foreground: 210 40% 98%; + + --ring: 216 34% 17%; + + --radius: 0.5rem; + } +} + +@layer base { + * { + @apply border-border; + } + body { + @apply bg-background text-foreground; + font-feature-settings: "rlig" 1, "calt" 1; + } +} diff --git a/client/tailwind.config.js b/client/tailwind.config.js new file mode 100644 index 0000000000000000000000000000000000000000..ab1e406f32370649581dbf771427e5c88946796c --- /dev/null +++ b/client/tailwind.config.js @@ -0,0 +1,86 @@ +const { fontFamily } = require("tailwindcss/defaultTheme") + +/** @type {import('tailwindcss').Config} */ +module.exports = { + darkMode: ["class"], + content: [ + "app/**/*.{ts,tsx}", + "components/**/*.{ts,tsx}", + '../public/index.html' + ], + theme: { + container: { + center: true, + padding: "2rem", + screens: { + "2xl": "1400px", + }, + }, + extend: { + colors: { + border: "hsl(var(--border))", + input: "hsl(var(--input))", + ring: "hsl(var(--ring))", + background: "hsl(var(--background))", + foreground: "hsl(var(--foreground))", + primary: { + DEFAULT: "hsl(var(--primary))", + foreground: "hsl(var(--primary-foreground))", + }, + secondary: { + DEFAULT: "hsl(var(--secondary))", + foreground: "hsl(var(--secondary-foreground))", + }, + destructive: { + DEFAULT: "hsl(var(--destructive))", + foreground: "hsl(var(--destructive-foreground))", + }, + muted: { + DEFAULT: "hsl(var(--muted))", + foreground: "hsl(var(--muted-foreground))", + }, + accent: { + DEFAULT: "hsl(var(--accent))", + foreground: "hsl(var(--accent-foreground))", + }, + popover: { + DEFAULT: "hsl(var(--popover))", + foreground: "hsl(var(--popover-foreground))", + }, + card: { + DEFAULT: "hsl(var(--card))", + foreground: "hsl(var(--card-foreground))", + }, + }, + borderRadius: { + lg: `var(--radius)`, + md: `calc(var(--radius) - 2px)`, + sm: "calc(var(--radius) - 4px)", + }, + fontFamily: { + sans: ["var(--font-sans)", ...fontFamily.sans], + }, + fontSize: { + "5xs": "8px", + "4xs": "9px", + "3xs": "10px", + "2xs": "11px" + }, + keyframes: { + "accordion-down": { + from: { height: "0" }, + to: { height: "var(--radix-accordion-content-height)" }, + }, + "accordion-up": { + from: { height: "var(--radix-accordion-content-height)" }, + to: { height: "0" }, + }, + }, + animation: { + "accordion-down": "accordion-down 0.2s ease-out", + "accordion-up": "accordion-up 0.2s ease-out", + }, + }, + }, + plugins: [require("tailwindcss-animate")], +} diff --git a/client/tsconfig.json b/client/tsconfig.json new file mode 100644 index 0000000000000000000000000000000000000000..66823aa62486477bfce85e0c79a06a2665b2fd3a --- /dev/null +++ b/client/tsconfig.json @@ -0,0 +1,32 @@ +{ + "compilerOptions": { + // Enable latest features + "lib": ["ESNext", "DOM", "DOM.Iterable"], + "target": "ESNext", + "module": "ESNext", + "moduleDetection": "force", + "jsx": 
"react-jsx", + "allowJs": true, + + // Bundler mode + "moduleResolution": "bundler", + "allowImportingTsExtensions": true, + "verbatimModuleSyntax": true, + "noEmit": true, + + "baseUrl": ".", + "paths": { + "@/*": ["./src/*"] + }, + + // Best practices + "strict": true, + "skipLibCheck": true, + "noFallthroughCasesInSwitch": true, + + // Some stricter flags (disabled by default) + "noUnusedLocals": false, + "noUnusedParameters": false, + "noPropertyAccessFromIndexSignature": false + } +} diff --git a/engine.py b/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..04af6f7bbd70a1f132d52ee83fdab0693b57e10d --- /dev/null +++ b/engine.py @@ -0,0 +1,300 @@ +import time +import logging +import hashlib +import uuid +import os +import io +import shutil +import asyncio +import base64 +from concurrent.futures import ThreadPoolExecutor +from queue import Queue +from typing import Dict, Any, List, Optional, AsyncGenerator, Tuple, Union +from functools import lru_cache +import av +import numpy as np +import cv2 +import torch +import torch.nn.functional as F +from PIL import Image + +from liveportrait.config.argument_config import ArgumentConfig +from liveportrait.utils.camera import get_rotation_matrix +from liveportrait.utils.io import load_image_rgb, load_driving_info, resize_to_limit +from liveportrait.utils.crop import prepare_paste_back, paste_back + +# Configure logging +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# Global constants +DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data') +MODELS_DIR = os.path.join(DATA_ROOT, "models") + +def base64_data_uri_to_PIL_Image(base64_string: str) -> Image.Image: + """ + Convert a base64 data URI to a PIL Image. + + Args: + base64_string (str): The base64 encoded image data. + + Returns: + Image.Image: The decoded PIL Image. + """ + if ',' in base64_string: + base64_string = base64_string.split(',')[1] + img_data = base64.b64decode(base64_string) + return Image.open(io.BytesIO(img_data)) + +class Engine: + """ + The main engine class for FacePoke + """ + + def __init__(self, live_portrait): + """ + Initialize the FacePoke engine with necessary models and processors. + + Args: + live_portrait (LivePortraitPipeline): The LivePortrait model for video generation. + """ + self.live_portrait = live_portrait + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # cache for the "modify image" workflow + self.image_cache = {} # Stores the original images + self.processed_cache = {} # Stores the processed image data + + logger.info("✅ FacePoke Engine initialized successfully.") + + def get_image_hash(self, image: Union[Image.Image, str, bytes]) -> str: + """ + Compute or retrieve the hash for an image. + + Args: + image (Union[Image.Image, str, bytes]): The input image, either as a PIL Image, + base64 string, or bytes. + + Returns: + str: The computed hash of the image. 
+ """ + if isinstance(image, str): + # Assume it's already a hash if it's a string of the right length + if len(image) == 32: + return image + # Otherwise, assume it's a base64 string + image = base64_data_uri_to_PIL_Image(image) + + if isinstance(image, Image.Image): + return hashlib.md5(image.tobytes()).hexdigest() + elif isinstance(image, bytes): + return hashlib.md5(image).hexdigest() + else: + raise ValueError("Unsupported image type") + + @lru_cache(maxsize=128) + def _process_image(self, image_hash: str) -> Dict[str, Any]: + """ + Process the input image and cache the results. + + Args: + image_hash (str): Hash of the input image. + + Returns: + Dict[str, Any]: Processed image data. + """ + logger.info(f"Processing image with hash: {image_hash}") + if image_hash not in self.image_cache: + raise ValueError(f"Image with hash {image_hash} not found in cache") + + image = self.image_cache[image_hash] + img_rgb = np.array(image) + + inference_cfg = self.live_portrait.live_portrait_wrapper.cfg + img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n) + crop_info = self.live_portrait.cropper.crop_single_image(img_rgb) + img_crop_256x256 = crop_info['img_crop_256x256'] + + I_s = self.live_portrait.live_portrait_wrapper.prepare_source(img_crop_256x256) + x_s_info = self.live_portrait.live_portrait_wrapper.get_kp_info(I_s) + f_s = self.live_portrait.live_portrait_wrapper.extract_feature_3d(I_s) + x_s = self.live_portrait.live_portrait_wrapper.transform_keypoint(x_s_info) + + processed_data = { + 'img_rgb': img_rgb, + 'crop_info': crop_info, + 'x_s_info': x_s_info, + 'f_s': f_s, + 'x_s': x_s, + 'inference_cfg': inference_cfg + } + + self.processed_cache[image_hash] = processed_data + + return processed_data + + async def modify_image(self, image_or_hash: Union[Image.Image, str, bytes], params: Dict[str, float]) -> str: + """ + Modify the input image based on the provided parameters, using caching for efficiency + and outputting the result as a WebP image. + + Args: + image_or_hash (Union[Image.Image, str, bytes]): Input image as a PIL Image, base64-encoded string, + image bytes, or a hash string. + params (Dict[str, float]): Parameters for face transformation. + + Returns: + str: Modified image as a base64-encoded WebP data URI. + + Raises: + ValueError: If there's an error modifying the image or WebP is not supported. 
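+ Note: preprocessing results (crop info, keypoints, extracted features) are cached per image hash, so follow-up requests that send only the hash skip the expensive _process_image step.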
+ """ + logger.info("Starting image modification") + logger.debug(f"Modification parameters: {params}") + + try: + image_hash = self.get_image_hash(image_or_hash) + + # If we don't have the image in cache yet, add it + if image_hash not in self.image_cache: + if isinstance(image_or_hash, (Image.Image, bytes)): + self.image_cache[image_hash] = image_or_hash + elif isinstance(image_or_hash, str) and len(image_or_hash) != 32: + # It's a base64 string, not a hash + self.image_cache[image_hash] = base64_data_uri_to_PIL_Image(image_or_hash) + else: + raise ValueError("Image not found in cache and no valid image provided") + + # Process the image (this will use the cache if available) + if image_hash not in self.processed_cache: + processed_data = await asyncio.to_thread(self._process_image, image_hash) + else: + processed_data = self.processed_cache[image_hash] + + # Apply modifications based on params + x_d_new = processed_data['x_s_info']['kp'].clone() + await self._apply_facial_modifications(x_d_new, params) + + # Apply rotation + R_new = get_rotation_matrix( + processed_data['x_s_info']['pitch'] + params.get('rotate_pitch', 0), + processed_data['x_s_info']['yaw'] + params.get('rotate_yaw', 0), + processed_data['x_s_info']['roll'] + params.get('rotate_roll', 0) + ) + x_d_new = processed_data['x_s_info']['scale'] * (x_d_new @ R_new) + processed_data['x_s_info']['t'] + + # Apply stitching + x_d_new = await asyncio.to_thread(self.live_portrait.live_portrait_wrapper.stitching, processed_data['x_s'], x_d_new) + + # Generate the output + out = await asyncio.to_thread(self.live_portrait.live_portrait_wrapper.warp_decode, processed_data['f_s'], processed_data['x_s'], x_d_new) + I_p = self.live_portrait.live_portrait_wrapper.parse_output(out['out'])[0] + + # Paste back to full size + mask_ori = await asyncio.to_thread( + prepare_paste_back, + processed_data['inference_cfg'].mask_crop, processed_data['crop_info']['M_c2o'], + dsize=(processed_data['img_rgb'].shape[1], processed_data['img_rgb'].shape[0]) + ) + I_p_to_ori_blend = await asyncio.to_thread( + paste_back, + I_p, processed_data['crop_info']['M_c2o'], processed_data['img_rgb'], mask_ori + ) + + # Convert the result to a PIL Image + result_image = Image.fromarray(I_p_to_ori_blend) + + # Save as WebP + buffered = io.BytesIO() + result_image.save(buffered, format="WebP", quality=85) # Adjust quality as needed + modified_image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8') + + logger.info("Image modification completed successfully") + return f"data:image/webp;base64,{modified_image_base64}" + + except Exception as e: + logger.error(f"Error in modify_image: {str(e)}") + logger.exception("Full traceback:") + raise ValueError(f"Failed to modify image: {str(e)}") + + async def _apply_facial_modifications(self, x_d_new: torch.Tensor, params: Dict[str, float]) -> None: + """ + Apply facial modifications to the keypoints based on the provided parameters. + + Args: + x_d_new (torch.Tensor): Tensor of facial keypoints to be modified. + params (Dict[str, float]): Parameters for face transformation. 
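+ Each modification below is expressed as a list of (batch, keypoint_index, axis, factor) tuples; the factor is multiplied by the parameter value and added to the corresponding entry of x_d_new.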
+ """ + modifications = [ + ('smile', [ + (0, 20, 1, -0.01), (0, 14, 1, -0.02), (0, 17, 1, 0.0065), (0, 17, 2, 0.003), + (0, 13, 1, -0.00275), (0, 16, 1, -0.00275), (0, 3, 1, -0.0035), (0, 7, 1, -0.0035) + ]), + ('aaa', [ + (0, 19, 1, 0.001), (0, 19, 2, 0.0001), (0, 17, 1, -0.0001) + ]), + ('eee', [ + (0, 20, 2, -0.001), (0, 20, 1, -0.001), (0, 14, 1, -0.001) + ]), + ('woo', [ + (0, 14, 1, 0.001), (0, 3, 1, -0.0005), (0, 7, 1, -0.0005), (0, 17, 2, -0.0005) + ]), + ('wink', [ + (0, 11, 1, 0.001), (0, 13, 1, -0.0003), (0, 17, 0, 0.0003), + (0, 17, 1, 0.0003), (0, 3, 1, -0.0003) + ]), + ('pupil_x', [ + (0, 11, 0, 0.0007 if params.get('pupil_x', 0) > 0 else 0.001), + (0, 15, 0, 0.001 if params.get('pupil_x', 0) > 0 else 0.0007) + ]), + ('pupil_y', [ + (0, 11, 1, -0.001), (0, 15, 1, -0.001) + ]), + ('eyes', [ + (0, 11, 1, -0.001), (0, 13, 1, 0.0003), (0, 15, 1, -0.001), (0, 16, 1, 0.0003), + (0, 1, 1, -0.00025), (0, 2, 1, 0.00025) + ]), + ('eyebrow', [ + (0, 1, 1, 0.001 if params.get('eyebrow', 0) > 0 else 0.0003), + (0, 2, 1, -0.001 if params.get('eyebrow', 0) > 0 else -0.0003), + (0, 1, 0, -0.001 if params.get('eyebrow', 0) <= 0 else 0), + (0, 2, 0, 0.001 if params.get('eyebrow', 0) <= 0 else 0) + ]) + ] + + for param_name, adjustments in modifications: + param_value = params.get(param_name, 0) + for i, j, k, factor in adjustments: + x_d_new[i, j, k] += param_value * factor + + # Special case for pupil_y affecting eyes + x_d_new[0, 11, 1] -= params.get('pupil_y', 0) * 0.001 + x_d_new[0, 15, 1] -= params.get('pupil_y', 0) * 0.001 + params['eyes'] = params.get('eyes', 0) - params.get('pupil_y', 0) / 2. + + async def cleanup(self): + """ + Perform cleanup operations for the Engine. + This method should be called when shutting down the application. + """ + logger.info("Starting Engine cleanup") + try: + # TODO: Add any additional cleanup operations here + logger.info("Engine cleanup completed successfully") + except Exception as e: + logger.error(f"Error during Engine cleanup: {str(e)}") + logger.exception("Full traceback:") + +def create_engine(models): + logger.info("Creating Engine instance...") + + live_portrait = models + + engine = Engine( + live_portrait=live_portrait, + # we might have more in the future + ) + + logger.info("Engine instance created successfully") + return engine diff --git a/liveportrait/config/__init__.py b/liveportrait/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/liveportrait/config/argument_config.py b/liveportrait/config/argument_config.py new file mode 100644 index 0000000000000000000000000000000000000000..043162742ce65227076501c834289fda8d24a0cf --- /dev/null +++ b/liveportrait/config/argument_config.py @@ -0,0 +1,44 @@ +# coding: utf-8 + +""" +config for user +""" + +import os.path as osp +from dataclasses import dataclass +import tyro +from typing_extensions import Annotated +from .base_config import PrintableConfig, make_abs_path + + +@dataclass(repr=False) # use repr from PrintableConfig +class ArgumentConfig(PrintableConfig): + ########## input arguments ########## + source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait + driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format) + output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to 
save output video + ##################################### + + ########## inference arguments ########## + device_id: int = 0 + flag_lip_zero : bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False + flag_eye_retargeting: bool = False + flag_lip_retargeting: bool = False + flag_stitching: bool = True # we recommend setting it to True! + flag_relative: bool = True # whether to use relative motion + flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space + flag_do_crop: bool = True # whether to crop the source portrait to the face-cropping space + flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True + ######################################### + + ########## crop arguments ########## + dsize: int = 512 + scale: float = 2.3 + vx_ratio: float = 0 # vx ratio + vy_ratio: float = -0.125 # vy ratio +up, -down + #################################### + + ########## gradio arguments ########## + server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890 + share: bool = True + server_name: str = "0.0.0.0" diff --git a/liveportrait/config/base_config.py b/liveportrait/config/base_config.py new file mode 100644 index 0000000000000000000000000000000000000000..216b8be50aecc8af4b9d1d2a9401e034dd7769e4 --- /dev/null +++ b/liveportrait/config/base_config.py @@ -0,0 +1,29 @@ +# coding: utf-8 + +""" +pretty printing class +""" + +from __future__ import annotations +import os.path as osp +from typing import Tuple + + +def make_abs_path(fn): + return osp.join(osp.dirname(osp.realpath(__file__)), fn) + + +class PrintableConfig: # pylint: disable=too-few-public-methods + """Printable Config defining str function""" + + def __repr__(self): + lines = [self.__class__.__name__ + ":"] + for key, val in vars(self).items(): + if isinstance(val, Tuple): + flattened_val = "[" + for item in val: + flattened_val += str(item) + "\n" + flattened_val = flattened_val.rstrip("\n") + val = flattened_val + "]" + lines += f"{key}: {str(val)}".split("\n") + return "\n ".join(lines) diff --git a/liveportrait/config/crop_config.py b/liveportrait/config/crop_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d3c79be214ec6018ef4af298e306c47bd2a187f8 --- /dev/null +++ b/liveportrait/config/crop_config.py @@ -0,0 +1,18 @@ +# coding: utf-8 + +""" +parameters used for crop faces +""" + +import os.path as osp +from dataclasses import dataclass +from typing import Union, List +from .base_config import PrintableConfig + + +@dataclass(repr=False) # use repr from PrintableConfig +class CropConfig(PrintableConfig): + dsize: int = 512 # crop size + scale: float = 2.3 # scale factor + vx_ratio: float = 0 # vx ratio + vy_ratio: float = -0.125 # vy ratio +up, -down diff --git a/liveportrait/config/inference_config.py b/liveportrait/config/inference_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0a384458fe709befa9f7b0ae73fdbdd4e576a44b --- /dev/null +++ b/liveportrait/config/inference_config.py @@ -0,0 +1,53 @@ +# coding: utf-8 + +""" +config dataclass used for inference +""" + +import os +import os.path as osp +from dataclasses import dataclass +from typing import Literal, Tuple +from .base_config import PrintableConfig, make_abs_path + +# Configuration +DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data') +MODELS_DIR = os.path.join(DATA_ROOT, "models") + +@dataclass(repr=False) # 
use repr from PrintableConfig +class InferenceConfig(PrintableConfig): + models_config: str = make_abs_path('./models.yaml') # portrait animation config + checkpoint_F = os.path.join(MODELS_DIR, "liveportrait", "appearance_feature_extractor.pth") + checkpoint_M = os.path.join(MODELS_DIR, "liveportrait", "motion_extractor.pth") + checkpoint_W = os.path.join(MODELS_DIR, "liveportrait", "warping_module.pth") + checkpoint_G = os.path.join(MODELS_DIR, "liveportrait", "spade_generator.pth") + checkpoint_S = os.path.join(MODELS_DIR, "liveportrait", "stitching_retargeting_module.pth") + + flag_use_half_precision: bool = True # whether to use half precision + + flag_lip_zero: bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False + lip_zero_threshold: float = 0.03 + + flag_eye_retargeting: bool = False + flag_lip_retargeting: bool = False + flag_stitching: bool = True # we recommend setting it to True! + + flag_relative: bool = True # whether to use relative motion + anchor_frame: int = 0 # set this value if find_best_frame is True + + input_shape: Tuple[int, int] = (256, 256) # input shape + output_format: Literal['mp4', 'gif'] = 'mp4' # output video format + output_fps: int = 25 # MuseTalk prefers 25 fps, so we use 25 as default fps for output video + crf: int = 15 # crf for output video + + flag_write_result: bool = True # whether to write output video + flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space + mask_crop = None + flag_write_gif: bool = False + size_gif: int = 256 + ref_max_shape: int = 1280 + ref_shape_n: int = 2 + + device_id: int = 0 + flag_do_crop: bool = False # whether to crop the source portrait to the face-cropping space + flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True diff --git a/liveportrait/config/models.yaml b/liveportrait/config/models.yaml new file mode 100644 index 0000000000000000000000000000000000000000..131d1c65025c31e37af9239e211ea14454128a2e --- /dev/null +++ b/liveportrait/config/models.yaml @@ -0,0 +1,43 @@ +model_params: + appearance_feature_extractor_params: # the F in the paper + image_channel: 3 + block_expansion: 64 + num_down_blocks: 2 + max_features: 512 + reshape_channel: 32 + reshape_depth: 16 + num_resblocks: 6 + motion_extractor_params: # the M in the paper + num_kp: 21 + backbone: convnextv2_tiny + warping_module_params: # the W in the paper + num_kp: 21 + block_expansion: 64 + max_features: 512 + num_down_blocks: 2 + reshape_channel: 32 + estimate_occlusion_map: True + dense_motion_params: + block_expansion: 32 + max_features: 1024 + num_blocks: 5 + reshape_depth: 16 + compress: 4 + spade_generator_params: # the G in the paper + upscale: 2 # represents upsample factor 256x256 -> 512x512 + block_expansion: 64 + max_features: 512 + num_down_blocks: 2 + stitching_retargeting_module_params: # the S in the paper + stitching: + input_size: 126 # (21*3)*2 + hidden_sizes: [128, 128, 64] + output_size: 65 # (21*3)+2(tx,ty) + lip: + input_size: 65 # (21*3)+2 + hidden_sizes: [128, 128, 64] + output_size: 63 # (21*3) + eye: + input_size: 66 # (21*3)+3 + hidden_sizes: [256, 256, 128, 128, 64] + output_size: 63 # (21*3) diff --git a/liveportrait/gradio_pipeline.py b/liveportrait/gradio_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..c717897483380d69b4cd4622cf15eccde14a2eb0 --- /dev/null +++ 
b/liveportrait/gradio_pipeline.py @@ -0,0 +1,140 @@ +# coding: utf-8 + +""" +Pipeline for gradio +""" +import gradio as gr +from .config.argument_config import ArgumentConfig +from .live_portrait_pipeline import LivePortraitPipeline +from .utils.io import load_img_online +from .utils.rprint import rlog as log +from .utils.crop import prepare_paste_back, paste_back +from .utils.camera import get_rotation_matrix +from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio + +def update_args(args, user_args): + """update the args according to user inputs + """ + for k, v in user_args.items(): + if hasattr(args, k): + setattr(args, k, v) + return args + +class GradioPipeline(LivePortraitPipeline): + + def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig): + super().__init__(inference_cfg, crop_cfg) + # self.live_portrait_wrapper = self.live_portrait_wrapper + self.args = args + # for single image retargeting + self.start_prepare = False + self.f_s_user = None + self.x_c_s_info_user = None + self.x_s_user = None + self.source_lmk_user = None + self.mask_ori = None + self.img_rgb = None + self.crop_M_c2o = None + + + def execute_video( + self, + input_image_path, + input_video_path, + flag_relative_input, + flag_do_crop_input, + flag_remap_input, + ): + """ for video driven potrait animation + """ + if input_image_path is not None and input_video_path is not None: + args_user = { + 'source_image': input_image_path, + 'driving_info': input_video_path, + 'flag_relative': flag_relative_input, + 'flag_do_crop': flag_do_crop_input, + 'flag_pasteback': flag_remap_input, + } + # update config from user input + self.args = update_args(self.args, args_user) + self.live_portrait_wrapper.update_config(self.args.__dict__) + self.cropper.update_config(self.args.__dict__) + # video driven animation + video_path, video_path_concat = self.execute(self.args) + gr.Info("Run successfully!", duration=2) + return video_path, video_path_concat, + else: + raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5) + + def execute_image(self, input_eye_ratio: float, input_lip_ratio: float): + """ for single image retargeting + """ + if input_eye_ratio is None or input_eye_ratio is None: + raise gr.Error("Invalid ratio input 💥!", duration=5) + elif self.f_s_user is None: + if self.start_prepare: + raise gr.Error( + "The source portrait is under processing 💥! Please wait for a second.", + duration=5 + ) + else: + raise gr.Error( + "The source portrait hasn't been prepared yet 💥! 
Please scroll to the top of the page to upload.", + duration=5 + ) + else: + # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i) + combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user) + eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor) + # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i) + combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user) + lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor) + num_kp = self.x_s_user.shape[1] + # default: use x_s + x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3) + # D(W(f_s; x_s, x′_d)) + out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new) + out = self.live_portrait_wrapper.parse_output(out['out'])[0] + out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori) + gr.Info("Run successfully!", duration=2) + return out, out_to_ori_blend + + + def prepare_retargeting(self, input_image_path, flag_do_crop = True): + """ for single image retargeting + """ + if input_image_path is not None: + gr.Info("Upload successfully!", duration=2) + self.start_prepare = True + inference_cfg = self.live_portrait_wrapper.cfg + ######## process source portrait ######## + img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16) + log(f"Load source image from {input_image_path}.") + crop_info = self.cropper.crop_single_image(img_rgb) + if flag_do_crop: + I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256']) + else: + I_s = self.live_portrait_wrapper.prepare_source(img_rgb) + x_s_info = self.live_portrait_wrapper.get_kp_info(I_s) + R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) + ############################################ + + # record global info for next time use + self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s) + self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info) + self.x_s_info_user = x_s_info + self.source_lmk_user = crop_info['lmk_crop'] + self.img_rgb = img_rgb + self.crop_M_c2o = crop_info['M_c2o'] + self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0])) + # update slider + eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None]) + eye_close_ratio = float(eye_close_ratio.squeeze(0).mean()) + lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None]) + lip_close_ratio = float(lip_close_ratio.squeeze(0).mean()) + # for vis + self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0] + return eye_close_ratio, lip_close_ratio, self.I_s_vis + else: + # when press the clear button, go here + return 0.8, 0.8, self.I_s_vis diff --git a/liveportrait/live_portrait_pipeline.py b/liveportrait/live_portrait_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..76a27d1e8d984de2b14eafc0ce1fe0aab90ecf29 --- /dev/null +++ b/liveportrait/live_portrait_pipeline.py @@ -0,0 +1,193 @@ +# coding: utf-8 + +""" +Pipeline of LivePortrait +""" + +# TODO: +# 1. 当前假定所有的模板都是已经裁好的,需要修改下 +# 2. 
pick样例图 source + driving + +import cv2 +import numpy as np +import pickle +import os.path as osp +from rich.progress import track + +from .config.argument_config import ArgumentConfig +from .config.inference_config import InferenceConfig +from .config.crop_config import CropConfig +from .utils.cropper import Cropper +from .utils.camera import get_rotation_matrix +from .utils.video import images2video, concat_frames +from .utils.crop import _transform_img, prepare_paste_back, paste_back +from .utils.retargeting_utils import calc_lip_close_ratio +from .utils.io import load_image_rgb, load_driving_info, resize_to_limit +from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template +from .utils.rprint import rlog as log +from .live_portrait_wrapper import LivePortraitWrapper + + +def make_abs_path(fn): + return osp.join(osp.dirname(osp.realpath(__file__)), fn) + + +class LivePortraitPipeline(object): + + def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig): + self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg) + self.cropper = Cropper(crop_cfg=crop_cfg) + + def execute(self, args: ArgumentConfig): + inference_cfg = self.live_portrait_wrapper.cfg # for convenience + ######## process source portrait ######## + img_rgb = load_image_rgb(args.source_image) + img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n) + log(f"Load source image from {args.source_image}") + crop_info = self.cropper.crop_single_image(img_rgb) + source_lmk = crop_info['lmk_crop'] + img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256'] + if inference_cfg.flag_do_crop: + I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256) + else: + I_s = self.live_portrait_wrapper.prepare_source(img_rgb) + x_s_info = self.live_portrait_wrapper.get_kp_info(I_s) + x_c_s = x_s_info['kp'] + R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) + f_s = self.live_portrait_wrapper.extract_feature_3d(I_s) + x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info) + + if inference_cfg.flag_lip_zero: + # let lip-open scalar to be 0 at first + c_d_lip_before_animation = [0.] 
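+            # Descriptive note: the driving lip-open scalar is forced to 0 so the source lips start
+            # from a closed state. If the source lips are already nearly closed (combined ratio below
+            # lip_zero_threshold), the flag is disabled; otherwise a closing keypoint delta is
+            # pre-computed with retarget_lip for use in the animation loop below.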
+ combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk) + if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold: + inference_cfg.flag_lip_zero = False + else: + lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation) + ############################################ + + ######## process driving info ######## + if is_video(args.driving_info): + log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}") + # TODO: 这里track一下驱动视频 -> 构建模板 + driving_rgb_lst = load_driving_info(args.driving_info) + driving_rgb_lst_256 = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] + I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_256) + n_frames = I_d_lst.shape[0] + if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting: + driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst) + input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst) + elif is_template(args.driving_info): + log(f"Load from video templates {args.driving_info}") + with open(args.driving_info, 'rb') as f: + template_lst, driving_lmk_lst = pickle.load(f) + n_frames = template_lst[0]['n_frames'] + input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst) + else: + raise Exception("Unsupported driving types!") + ######################################### + + ######## prepare for pasteback ######## + if inference_cfg.flag_pasteback: + mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0])) + I_p_paste_lst = [] + ######################################### + + I_p_lst = [] + R_d_0, x_d_0_info = None, None + for i in track(range(n_frames), description='Animating...', total=n_frames): + if is_video(args.driving_info): + # extract kp info by M + I_d_i = I_d_lst[i] + x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i) + R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll']) + else: + # from template + x_d_i_info = template_lst[i] + x_d_i_info = dct2cuda(x_d_i_info, inference_cfg.device_id) + R_d_i = x_d_i_info['R_d'] + + if i == 0: + R_d_0 = R_d_i + x_d_0_info = x_d_i_info + + if inference_cfg.flag_relative: + R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s + delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']) + scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale']) + t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t']) + else: + R_new = R_d_i + delta_new = x_d_i_info['exp'] + scale_new = x_s_info['scale'] + t_new = x_d_i_info['t'] + + t_new[..., 2].fill_(0) # zero tz + x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new + + # Algorithm 1: + if not inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting: + # without stitching or retargeting + if inference_cfg.flag_lip_zero: + x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3) + else: + pass + elif inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting: + # with stitching and without retargeting + if inference_cfg.flag_lip_zero: + x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, 
x_s.shape[1], 3) + else: + x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + else: + eyes_delta, lip_delta = None, None + if inference_cfg.flag_eye_retargeting: + c_d_eyes_i = input_eye_ratio_lst[i] + combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk) + # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i) + eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor) + if inference_cfg.flag_lip_retargeting: + c_d_lip_i = input_lip_ratio_lst[i] + combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk) + # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i) + lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor) + + if inference_cfg.flag_relative: # use x_s + x_d_i_new = x_s + \ + (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \ + (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0) + else: # use x_d,i + x_d_i_new = x_d_i_new + \ + (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \ + (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0) + + if inference_cfg.flag_stitching: + x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + + out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new) + I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0] + I_p_lst.append(I_p_i) + + if inference_cfg.flag_pasteback: + I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori) + I_p_paste_lst.append(I_p_i_to_ori_blend) + + mkdir(args.output_dir) + wfp_concat = None + + # note by @jbilcke-hf: + # I have disabled this block, since we don't need to debug it + #if is_video(args.driving_info): + # frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256) + # # save (driving frames, source image, drived frames) result + # wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4') + # images2video(frames_concatenated, wfp=wfp_concat)# + + # save drived result + wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4') + if inference_cfg.flag_pasteback: + images2video(I_p_paste_lst, wfp=wfp) + else: + images2video(I_p_lst, wfp=wfp) + + return wfp, wfp_concat diff --git a/liveportrait/live_portrait_wrapper.py b/liveportrait/live_portrait_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..d96156c257f43a7e77a052af80ebb088a165db92 --- /dev/null +++ b/liveportrait/live_portrait_wrapper.py @@ -0,0 +1,307 @@ +# coding: utf-8 + +""" +Wrapper for LivePortrait core functions +""" + +import os.path as osp +import numpy as np +import cv2 +import torch +import yaml + +from .utils.timer import Timer +from .utils.helper import load_model, concat_feat +from .utils.camera import headpose_pred_to_degree, get_rotation_matrix +from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio +from .config.inference_config import InferenceConfig +from .utils.rprint import rlog as log + + +class LivePortraitWrapper(object): + + def __init__(self, cfg: InferenceConfig): + + model_config = yaml.load(open(cfg.models_config, 'r'), Loader=yaml.SafeLoader) + + # init F + self.appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor') + #log(f'Load appearance_feature_extractor done.') + # init M + self.motion_extractor = 
load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor') + #log(f'Load motion_extractor done.') + # init W + self.warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module') + #log(f'Load warping_module done.') + # init G + self.spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator') + #log(f'Load spade_generator done.') + # init S and R + if cfg.checkpoint_S is not None and osp.exists(cfg.checkpoint_S): + self.stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module') + #log(f'Load stitching_retargeting_module done.') + else: + self.stitching_retargeting_module = None + + self.cfg = cfg + self.device_id = cfg.device_id + self.timer = Timer() + + def update_config(self, user_args): + for k, v in user_args.items(): + if hasattr(self.cfg, k): + setattr(self.cfg, k, v) + + def prepare_source(self, img: np.ndarray) -> torch.Tensor: + """ construct the input as standard + img: HxWx3, uint8, 256x256 + """ + h, w = img.shape[:2] + if h != self.cfg.input_shape[0] or w != self.cfg.input_shape[1]: + x = cv2.resize(img, (self.cfg.input_shape[0], self.cfg.input_shape[1])) + else: + x = img.copy() + + if x.ndim == 3: + x = x[np.newaxis].astype(np.float32) / 255. # HxWx3 -> 1xHxWx3, normalized to 0~1 + elif x.ndim == 4: + x = x.astype(np.float32) / 255. # BxHxWx3, normalized to 0~1 + else: + raise ValueError(f'img ndim should be 3 or 4: {x.ndim}') + x = np.clip(x, 0, 1) # clip to 0~1 + x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW + x = x.cuda(self.device_id) + return x + + def prepare_driving_videos(self, imgs) -> torch.Tensor: + """ construct the input as standard + imgs: NxBxHxWx3, uint8 + """ + if isinstance(imgs, list): + _imgs = np.array(imgs)[..., np.newaxis] # TxHxWx3x1 + elif isinstance(imgs, np.ndarray): + _imgs = imgs + else: + raise ValueError(f'imgs type error: {type(imgs)}') + + y = _imgs.astype(np.float32) / 255. 
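+        # Descriptive note: pixel values are scaled to [0, 1]; for a list of HxWx3 frames the
+        # permute below rearranges TxHxWx3x1 into Tx1x3xHxW, so each driving frame becomes a
+        # single-image, channel-first batch on the GPU.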
+ y = np.clip(y, 0, 1) # clip to 0~1 + y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW + y = y.cuda(self.device_id) + + return y + + def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor: + """ get the appearance feature of the image by F + x: Bx3xHxW, normalized to 0~1 + """ + with torch.no_grad(): + with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision): + feature_3d = self.appearance_feature_extractor(x) + + return feature_3d.float() + + def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict: + """ get the implicit keypoint information + x: Bx3xHxW, normalized to 0~1 + flag_refine_info: whether to trandform the pose to degrees and the dimention of the reshape + return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp' + """ + with torch.no_grad(): + with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision): + kp_info = self.motion_extractor(x) + + if self.cfg.flag_use_half_precision: + # float the dict + for k, v in kp_info.items(): + if isinstance(v, torch.Tensor): + kp_info[k] = v.float() + + flag_refine_info: bool = kwargs.get('flag_refine_info', True) + if flag_refine_info: + bs = kp_info['kp'].shape[0] + kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None] # Bx1 + kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None] # Bx1 + kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None] # Bx1 + kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3) # BxNx3 + kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3) # BxNx3 + + return kp_info + + def get_pose_dct(self, kp_info: dict) -> dict: + pose_dct = dict( + pitch=headpose_pred_to_degree(kp_info['pitch']).item(), + yaw=headpose_pred_to_degree(kp_info['yaw']).item(), + roll=headpose_pred_to_degree(kp_info['roll']).item(), + ) + return pose_dct + + def get_fs_and_kp_info(self, source_prepared, driving_first_frame): + + # get the canonical keypoints of source image by M + source_kp_info = self.get_kp_info(source_prepared, flag_refine_info=True) + source_rotation = get_rotation_matrix(source_kp_info['pitch'], source_kp_info['yaw'], source_kp_info['roll']) + + # get the canonical keypoints of first driving frame by M + driving_first_frame_kp_info = self.get_kp_info(driving_first_frame, flag_refine_info=True) + driving_first_frame_rotation = get_rotation_matrix( + driving_first_frame_kp_info['pitch'], + driving_first_frame_kp_info['yaw'], + driving_first_frame_kp_info['roll'] + ) + + # get feature volume by F + source_feature_3d = self.extract_feature_3d(source_prepared) + + return source_kp_info, source_rotation, source_feature_3d, driving_first_frame_kp_info, driving_first_frame_rotation + + def transform_keypoint(self, kp_info: dict): + """ + transform the implicit keypoints with the pose, shift, and expression deformation + kp: BxNx3 + """ + kp = kp_info['kp'] # (bs, k, 3) + pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll'] + + t, exp = kp_info['t'], kp_info['exp'] + scale = kp_info['scale'] + + pitch = headpose_pred_to_degree(pitch) + yaw = headpose_pred_to_degree(yaw) + roll = headpose_pred_to_degree(roll) + + bs = kp.shape[0] + if kp.ndim == 2: + num_kp = kp.shape[1] // 3 # Bx(num_kpx3) + else: + num_kp = kp.shape[1] # Bxnum_kpx3 + + rot_mat = get_rotation_matrix(pitch, yaw, roll) # (bs, 3, 3) + + # Eqn.2: s * (R * x_c,s + exp) + t + kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3) + kp_transformed *= scale[..., 
None] # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3) + kp_transformed[:, :, 0:2] += t[:, None, 0:2] # remove z, only apply tx ty + + return kp_transformed + + def retarget_eye(self, kp_source: torch.Tensor, eye_close_ratio: torch.Tensor) -> torch.Tensor: + """ + kp_source: BxNx3 + eye_close_ratio: Bx3 + Return: Bx(3*num_kp+2) + """ + feat_eye = concat_feat(kp_source, eye_close_ratio) + + with torch.no_grad(): + delta = self.stitching_retargeting_module['eye'](feat_eye) + + return delta + + def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor: + """ + kp_source: BxNx3 + lip_close_ratio: Bx2 + """ + feat_lip = concat_feat(kp_source, lip_close_ratio) + + with torch.no_grad(): + delta = self.stitching_retargeting_module['lip'](feat_lip) + + return delta + + def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor: + """ + kp_source: BxNx3 + kp_driving: BxNx3 + Return: Bx(3*num_kp+2) + """ + feat_stiching = concat_feat(kp_source, kp_driving) + + with torch.no_grad(): + delta = self.stitching_retargeting_module['stitching'](feat_stiching) + + return delta + + def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor: + """ conduct the stitching + kp_source: Bxnum_kpx3 + kp_driving: Bxnum_kpx3 + """ + + if self.stitching_retargeting_module is not None: + + bs, num_kp = kp_source.shape[:2] + + kp_driving_new = kp_driving.clone() + delta = self.stitch(kp_source, kp_driving_new) + + delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3) # 1x20x3 + delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2) # 1x1x2 + + kp_driving_new += delta_exp + kp_driving_new[..., :2] += delta_tx_ty + + return kp_driving_new + + return kp_driving + + def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor: + """ get the image after the warping of the implicit keypoints + feature_3d: Bx32x16x64x64, feature volume + kp_source: BxNx3 + kp_driving: BxNx3 + """ + # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i)) + with torch.no_grad(): + with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision): + # get decoder input + ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving) + # decode + ret_dct['out'] = self.spade_generator(feature=ret_dct['out']) + + # float the dict + if self.cfg.flag_use_half_precision: + for k, v in ret_dct.items(): + if isinstance(v, torch.Tensor): + ret_dct[k] = v.float() + + return ret_dct + + def parse_output(self, out: torch.Tensor) -> np.ndarray: + """ construct the output as standard + return: 1xHxWx3, uint8 + """ + out = np.transpose(out.data.cpu().numpy(), [0, 2, 3, 1]) # 1x3xHxW -> 1xHxWx3 + out = np.clip(out, 0, 1) # clip to 0~1 + out = np.clip(out * 255, 0, 255).astype(np.uint8) # 0~1 -> 0~255 + + return out + + def calc_retargeting_ratio(self, source_lmk, driving_lmk_lst): + input_eye_ratio_lst = [] + input_lip_ratio_lst = [] + for lmk in driving_lmk_lst: + # for eyes retargeting + input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None])) + # for lip retargeting + input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None])) + return input_eye_ratio_lst, input_lip_ratio_lst + + def calc_combined_eye_ratio(self, input_eye_ratio, source_lmk): + eye_close_ratio = calc_eye_close_ratio(source_lmk[None]) + eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().cuda(self.device_id) + input_eye_ratio_tensor = 
torch.Tensor([input_eye_ratio[0][0]]).reshape(1, 1).cuda(self.device_id) + # [c_s,eyes, c_d,eyes,i] + combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1) + return combined_eye_ratio_tensor + + def calc_combined_lip_ratio(self, input_lip_ratio, source_lmk): + lip_close_ratio = calc_lip_close_ratio(source_lmk[None]) + lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().cuda(self.device_id) + # [c_s,lip, c_d,lip,i] + input_lip_ratio_tensor = torch.Tensor([input_lip_ratio[0]]).cuda(self.device_id) + if input_lip_ratio_tensor.shape != [1, 1]: + input_lip_ratio_tensor = input_lip_ratio_tensor.reshape(1, 1) + combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1) + return combined_lip_ratio_tensor diff --git a/liveportrait/modules/__init__.py b/liveportrait/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/liveportrait/modules/appearance_feature_extractor.py b/liveportrait/modules/appearance_feature_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..8d89e4f18a2fbe58447f52ab4c5e3f2011a4ec80 --- /dev/null +++ b/liveportrait/modules/appearance_feature_extractor.py @@ -0,0 +1,48 @@ +# coding: utf-8 + +""" +Appearance extractor(F) defined in paper, which maps the source image s to a 3D appearance feature volume. +""" + +import torch +from torch import nn +from .util import SameBlock2d, DownBlock2d, ResBlock3d + + +class AppearanceFeatureExtractor(nn.Module): + + def __init__(self, image_channel, block_expansion, num_down_blocks, max_features, reshape_channel, reshape_depth, num_resblocks): + super(AppearanceFeatureExtractor, self).__init__() + self.image_channel = image_channel + self.block_expansion = block_expansion + self.num_down_blocks = num_down_blocks + self.max_features = max_features + self.reshape_channel = reshape_channel + self.reshape_depth = reshape_depth + + self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1)) + + down_blocks = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2 ** i)) + out_features = min(max_features, block_expansion * (2 ** (i + 1))) + down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.down_blocks = nn.ModuleList(down_blocks) + + self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) + + self.resblocks_3d = torch.nn.Sequential() + for i in range(num_resblocks): + self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) + + def forward(self, source_image): + out = self.first(source_image) # Bx3x256x256 -> Bx64x256x256 + + for i in range(len(self.down_blocks)): + out = self.down_blocks[i](out) + out = self.second(out) + bs, c, h, w = out.shape # ->Bx512x64x64 + + f_s = out.view(bs, self.reshape_channel, self.reshape_depth, h, w) # ->Bx32x16x64x64 + f_s = self.resblocks_3d(f_s) # ->Bx32x16x64x64 + return f_s diff --git a/liveportrait/modules/convnextv2.py b/liveportrait/modules/convnextv2.py new file mode 100644 index 0000000000000000000000000000000000000000..83ea12662b607854915df8c7abb160b588d330b1 --- /dev/null +++ b/liveportrait/modules/convnextv2.py @@ -0,0 +1,149 @@ +# coding: utf-8 + +""" +This moudle is adapted to the ConvNeXtV2 version for the extraction of implicit keypoints, poses, and expression deformation. 
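+
+The forward pass returns a dict with the canonical keypoints ('kp'), expression deltas ('exp'),
+'scale', translation ('t'), and binned head-pose logits ('pitch', 'yaw', 'roll').
+
+Example (illustrative sketch):
+    >>> net = convnextv2_tiny(num_kp=21)
+    >>> out = net(torch.randn(1, 3, 256, 256))
+    >>> sorted(out.keys())
+    ['exp', 'kp', 'pitch', 'roll', 'scale', 't', 'yaw']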
+""" + +import torch +import torch.nn as nn +# from timm.models.layers import trunc_normal_, DropPath +from .util import LayerNorm, DropPath, trunc_normal_, GRN + +__all__ = ['convnextv2_tiny'] + + +class Block(nn.Module): + """ ConvNeXtV2 Block. + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + """ + + def __init__(self, dim, drop_path=0.): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GRN(4 * dim) + self.pwconv2 = nn.Linear(4 * dim, dim) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class ConvNeXtV2(nn.Module): + """ ConvNeXt V2 + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] + dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] + drop_path_rate (float): Stochastic depth rate. Default: 0. + head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. + """ + + def __init__( + self, + in_chans=3, + depths=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + drop_path_rate=0., + **kwargs + ): + super().__init__() + self.depths = depths + self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers + stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first") + ) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), + ) + self.downsample_layers.append(downsample_layer) + + self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks + dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + cur = 0 + for i in range(4): + stage = nn.Sequential( + *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])] + ) + self.stages.append(stage) + cur += depths[i] + + self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer + + # NOTE: the output semantic items + num_bins = kwargs.get('num_bins', 66) + num_kp = kwargs.get('num_kp', 24) # the number of implicit keypoints + self.fc_kp = nn.Linear(dims[-1], 3 * num_kp) # implicit keypoints + + # print('dims[-1]: ', dims[-1]) + self.fc_scale = nn.Linear(dims[-1], 1) # scale + self.fc_pitch = nn.Linear(dims[-1], num_bins) # pitch bins + self.fc_yaw = nn.Linear(dims[-1], num_bins) # yaw bins + self.fc_roll = nn.Linear(dims[-1], num_bins) # roll bins + self.fc_t = nn.Linear(dims[-1], 3) # translation + self.fc_exp = nn.Linear(dims[-1], 3 * num_kp) # expression / delta + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + nn.init.constant_(m.bias, 0) + + def forward_features(self, x): + for 
i in range(4): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) + + def forward(self, x): + x = self.forward_features(x) + + # implicit keypoints + kp = self.fc_kp(x) + + # pose and expression deformation + pitch = self.fc_pitch(x) + yaw = self.fc_yaw(x) + roll = self.fc_roll(x) + t = self.fc_t(x) + exp = self.fc_exp(x) + scale = self.fc_scale(x) + + ret_dct = { + 'pitch': pitch, + 'yaw': yaw, + 'roll': roll, + 't': t, + 'exp': exp, + 'scale': scale, + + 'kp': kp, # canonical keypoint + } + + return ret_dct + + +def convnextv2_tiny(**kwargs): + model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) + return model diff --git a/liveportrait/modules/dense_motion.py b/liveportrait/modules/dense_motion.py new file mode 100644 index 0000000000000000000000000000000000000000..0eec0c46345f8854b125a51eaee730bd4ee77f7d --- /dev/null +++ b/liveportrait/modules/dense_motion.py @@ -0,0 +1,104 @@ +# coding: utf-8 + +""" +The module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving +""" + +from torch import nn +import torch.nn.functional as F +import torch +from .util import Hourglass, make_coordinate_grid, kp2gaussian + + +class DenseMotionNetwork(nn.Module): + def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, estimate_occlusion_map=True): + super(DenseMotionNetwork, self).__init__() + self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) # ~60+G + + self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) # 65G! NOTE: computation cost is large + self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) # 0.8G + self.norm = nn.BatchNorm3d(compress, affine=True) + self.num_kp = num_kp + self.flag_estimate_occlusion_map = estimate_occlusion_map + + if self.flag_estimate_occlusion_map: + self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3) + else: + self.occlusion = None + + def create_sparse_motions(self, feature, kp_driving, kp_source): + bs, _, d, h, w = feature.shape # (bs, 4, 16, 64, 64) + identity_grid = make_coordinate_grid((d, h, w), ref=kp_source) # (16, 64, 64, 3) + identity_grid = identity_grid.view(1, 1, d, h, w, 3) # (1, 1, d=16, h=64, w=64, 3) + coordinate_grid = identity_grid - kp_driving.view(bs, self.num_kp, 1, 1, 1, 3) + + k = coordinate_grid.shape[1] + + # NOTE: there lacks an one-order flow + driving_to_source = coordinate_grid + kp_source.view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3) + + # adding background feature + identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1) + sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) # (bs, 1+num_kp, d, h, w, 3) + return sparse_motions + + def create_deformed_feature(self, feature, sparse_motions): + bs, _, d, h, w = feature.shape + feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w) + feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w) + sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3) + sparse_deformed = F.grid_sample(feature_repeat, sparse_motions, align_corners=False) + sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, 
w)) # (bs, num_kp+1, c, d, h, w) + + return sparse_deformed + + def create_heatmap_representations(self, feature, kp_driving, kp_source): + spatial_size = feature.shape[3:] # (d=16, h=64, w=64) + gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w) + gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w) + heatmap = gaussian_driving - gaussian_source # (bs, num_kp, d, h, w) + + # adding background feature + zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()).to(heatmap.device) + heatmap = torch.cat([zeros, heatmap], dim=1) + heatmap = heatmap.unsqueeze(2) # (bs, 1+num_kp, 1, d, h, w) + return heatmap + + def forward(self, feature, kp_driving, kp_source): + bs, _, d, h, w = feature.shape # (bs, 32, 16, 64, 64) + + feature = self.compress(feature) # (bs, 4, 16, 64, 64) + feature = self.norm(feature) # (bs, 4, 16, 64, 64) + feature = F.relu(feature) # (bs, 4, 16, 64, 64) + + out_dict = dict() + + # 1. deform 3d feature + sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) # (bs, 1+num_kp, d, h, w, 3) + deformed_feature = self.create_deformed_feature(feature, sparse_motion) # (bs, 1+num_kp, c=4, d=16, h=64, w=64) + + # 2. (bs, 1+num_kp, d, h, w) + heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) # (bs, 1+num_kp, 1, d, h, w) + + input = torch.cat([heatmap, deformed_feature], dim=2) # (bs, 1+num_kp, c=5, d=16, h=64, w=64) + input = input.view(bs, -1, d, h, w) # (bs, (1+num_kp)*c=105, d=16, h=64, w=64) + + prediction = self.hourglass(input) + + mask = self.mask(prediction) + mask = F.softmax(mask, dim=1) # (bs, 1+num_kp, d=16, h=64, w=64) + out_dict['mask'] = mask + mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w) + sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w) + deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) mask take effect in this place + deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3) + + out_dict['deformation'] = deformation + + if self.flag_estimate_occlusion_map: + bs, _, d, h, w = prediction.shape + prediction_reshape = prediction.view(bs, -1, h, w) + occlusion_map = torch.sigmoid(self.occlusion(prediction_reshape)) # Bx1x64x64 + out_dict['occlusion_map'] = occlusion_map + + return out_dict diff --git a/liveportrait/modules/motion_extractor.py b/liveportrait/modules/motion_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..b2982e53c52d9ec1e0bec0453cc05edb51a15d23 --- /dev/null +++ b/liveportrait/modules/motion_extractor.py @@ -0,0 +1,35 @@ +# coding: utf-8 + +""" +Motion extractor(M), which directly predicts the canonical keypoints, head pose and expression deformation of the input image +""" + +from torch import nn +import torch + +from .convnextv2 import convnextv2_tiny +from .util import filter_state_dict + +model_dict = { + 'convnextv2_tiny': convnextv2_tiny, +} + + +class MotionExtractor(nn.Module): + def __init__(self, **kwargs): + super(MotionExtractor, self).__init__() + + # default is convnextv2_base + backbone = kwargs.get('backbone', 'convnextv2_tiny') + self.detector = model_dict.get(backbone)(**kwargs) + + def load_pretrained(self, init_path: str): + if init_path not in (None, ''): + state_dict = torch.load(init_path, map_location=lambda storage, loc: storage)['model'] + state_dict = filter_state_dict(state_dict, 
remove_name='head') + ret = self.detector.load_state_dict(state_dict, strict=False) + print(f'Load pretrained model from {init_path}, ret: {ret}') + + def forward(self, x): + out = self.detector(x) + return out diff --git a/liveportrait/modules/spade_generator.py b/liveportrait/modules/spade_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..147a9aed0c7707fe6ae3d59ce1a30154ef75afcc --- /dev/null +++ b/liveportrait/modules/spade_generator.py @@ -0,0 +1,59 @@ +# coding: utf-8 + +""" +Spade decoder(G) defined in the paper, which input the warped feature to generate the animated image. +""" + +import torch +from torch import nn +import torch.nn.functional as F +from .util import SPADEResnetBlock + + +class SPADEDecoder(nn.Module): + def __init__(self, upscale=1, max_features=256, block_expansion=64, out_channels=64, num_down_blocks=2): + for i in range(num_down_blocks): + input_channels = min(max_features, block_expansion * (2 ** (i + 1))) + self.upscale = upscale + super().__init__() + norm_G = 'spadespectralinstance' + label_num_channels = input_channels # 256 + + self.fc = nn.Conv2d(input_channels, 2 * input_channels, 3, padding=1) + self.G_middle_0 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.G_middle_1 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.G_middle_2 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.G_middle_3 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.G_middle_4 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.G_middle_5 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels) + self.up_0 = SPADEResnetBlock(2 * input_channels, input_channels, norm_G, label_num_channels) + self.up_1 = SPADEResnetBlock(input_channels, out_channels, norm_G, label_num_channels) + self.up = nn.Upsample(scale_factor=2) + + if self.upscale is None or self.upscale <= 1: + self.conv_img = nn.Conv2d(out_channels, 3, 3, padding=1) + else: + self.conv_img = nn.Sequential( + nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1), + nn.PixelShuffle(upscale_factor=2) + ) + + def forward(self, feature): + seg = feature # Bx256x64x64 + x = self.fc(feature) # Bx512x64x64 + x = self.G_middle_0(x, seg) + x = self.G_middle_1(x, seg) + x = self.G_middle_2(x, seg) + x = self.G_middle_3(x, seg) + x = self.G_middle_4(x, seg) + x = self.G_middle_5(x, seg) + + x = self.up(x) # Bx512x64x64 -> Bx512x128x128 + x = self.up_0(x, seg) # Bx512x128x128 -> Bx256x128x128 + x = self.up(x) # Bx256x128x128 -> Bx256x256x256 + x = self.up_1(x, seg) # Bx256x256x256 -> Bx64x256x256 + + x = self.conv_img(F.leaky_relu(x, 2e-1)) # Bx64x256x256 -> Bx3xHxW + x = torch.sigmoid(x) # Bx3xHxW + + return x \ No newline at end of file diff --git a/liveportrait/modules/stitching_retargeting_network.py b/liveportrait/modules/stitching_retargeting_network.py new file mode 100644 index 0000000000000000000000000000000000000000..5f50b7cf5a21cd71c70a7bbaaa4b6b68b4762ea3 --- /dev/null +++ b/liveportrait/modules/stitching_retargeting_network.py @@ -0,0 +1,38 @@ +# coding: utf-8 + +""" +Stitching module(S) and two retargeting modules(R) defined in the paper. + +- The stitching module pastes the animated portrait back into the original image space without pixel misalignment, such as in +the stitching region. 
+ +- The eyes retargeting module is designed to address the issue of incomplete eye closure during cross-id reenactment, especially +when a person with small eyes drives a person with larger eyes. + +- The lip retargeting module is designed similarly to the eye retargeting module, and can also normalize the input by ensuring that +the lips are in a closed state, which facilitates better animation driving. +""" +from torch import nn + + +class StitchingRetargetingNetwork(nn.Module): + def __init__(self, input_size, hidden_sizes, output_size): + super(StitchingRetargetingNetwork, self).__init__() + layers = [] + for i in range(len(hidden_sizes)): + if i == 0: + layers.append(nn.Linear(input_size, hidden_sizes[i])) + else: + layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i])) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Linear(hidden_sizes[-1], output_size)) + self.mlp = nn.Sequential(*layers) + + def initialize_weights_to_zero(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.weight) + nn.init.zeros_(m.bias) + + def forward(self, x): + return self.mlp(x) diff --git a/liveportrait/modules/util.py b/liveportrait/modules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..f83980b24372bee38779ceeb3349fca91735e56e --- /dev/null +++ b/liveportrait/modules/util.py @@ -0,0 +1,441 @@ +# coding: utf-8 + +""" +This file defines various neural network modules and utility functions, including convolutional and residual blocks, +normalizations, and functions for spatial transformation and tensor manipulation. +""" + +from torch import nn +import torch.nn.functional as F +import torch +import torch.nn.utils.spectral_norm as spectral_norm +import math +import warnings + + +def kp2gaussian(kp, spatial_size, kp_variance): + """ + Transform a keypoint into gaussian like representation + """ + mean = kp + + coordinate_grid = make_coordinate_grid(spatial_size, mean) + number_of_leading_dimensions = len(mean.shape) - 1 + shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape + coordinate_grid = coordinate_grid.view(*shape) + repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1) + coordinate_grid = coordinate_grid.repeat(*repeats) + + # Preprocess kp shape + shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3) + mean = mean.view(*shape) + + mean_sub = (coordinate_grid - mean) + + out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance) + + return out + + +def make_coordinate_grid(spatial_size, ref, **kwargs): + d, h, w = spatial_size + x = torch.arange(w).type(ref.dtype).to(ref.device) + y = torch.arange(h).type(ref.dtype).to(ref.device) + z = torch.arange(d).type(ref.dtype).to(ref.device) + + # NOTE: must be right-down-in + x = (2 * (x / (w - 1)) - 1) # the x axis faces to the right + y = (2 * (y / (h - 1)) - 1) # the y axis faces to the bottom + z = (2 * (z / (d - 1)) - 1) # the z axis faces to the inner + + yy = y.view(1, -1, 1).repeat(d, 1, w) + xx = x.view(1, 1, -1).repeat(d, h, 1) + zz = z.view(-1, 1, 1).repeat(1, h, w) + + meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3) + + return meshed + + +class ConvT2d(nn.Module): + """ + Upsampling block for use in decoder. 
+ """ + + def __init__(self, in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1): + super(ConvT2d, self).__init__() + + self.convT = nn.ConvTranspose2d(in_features, out_features, kernel_size=kernel_size, stride=stride, + padding=padding, output_padding=output_padding) + self.norm = nn.InstanceNorm2d(out_features) + + def forward(self, x): + out = self.convT(x) + out = self.norm(out) + out = F.leaky_relu(out) + return out + + +class ResBlock3d(nn.Module): + """ + Res block, preserve spatial resolution. + """ + + def __init__(self, in_features, kernel_size, padding): + super(ResBlock3d, self).__init__() + self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) + self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) + self.norm1 = nn.BatchNorm3d(in_features, affine=True) + self.norm2 = nn.BatchNorm3d(in_features, affine=True) + + def forward(self, x): + out = self.norm1(x) + out = F.relu(out) + out = self.conv1(out) + out = self.norm2(out) + out = F.relu(out) + out = self.conv2(out) + out += x + return out + + +class UpBlock3d(nn.Module): + """ + Upsampling block for use in decoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(UpBlock3d, self).__init__() + + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = nn.BatchNorm3d(out_features, affine=True) + + def forward(self, x): + out = F.interpolate(x, scale_factor=(1, 2, 2)) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class DownBlock2d(nn.Module): + """ + Downsampling block for use in encoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) + self.norm = nn.BatchNorm2d(out_features, affine=True) + self.pool = nn.AvgPool2d(kernel_size=(2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class DownBlock3d(nn.Module): + """ + Downsampling block for use in encoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock3d, self).__init__() + ''' + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups, stride=(1, 2, 2)) + ''' + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = nn.BatchNorm3d(out_features, affine=True) + self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class SameBlock2d(nn.Module): + """ + Simple block, preserve spatial resolution. 
+ """ + + def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False): + super(SameBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) + self.norm = nn.BatchNorm2d(out_features, affine=True) + if lrelu: + self.ac = nn.LeakyReLU() + else: + self.ac = nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = self.ac(out) + return out + + +class Encoder(nn.Module): + """ + Hourglass Encoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Encoder, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), kernel_size=3, padding=1)) + self.down_blocks = nn.ModuleList(down_blocks) + + def forward(self, x): + outs = [x] + for down_block in self.down_blocks: + outs.append(down_block(outs[-1])) + return outs + + +class Decoder(nn.Module): + """ + Hourglass Decoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Decoder, self).__init__() + + up_blocks = [] + + for i in range(num_blocks)[::-1]: + in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) + out_filters = min(max_features, block_expansion * (2 ** i)) + up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) + + self.up_blocks = nn.ModuleList(up_blocks) + self.out_filters = block_expansion + in_features + + self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1) + self.norm = nn.BatchNorm3d(self.out_filters, affine=True) + + def forward(self, x): + out = x.pop() + for up_block in self.up_blocks: + out = up_block(out) + skip = x.pop() + out = torch.cat([out, skip], dim=1) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class Hourglass(nn.Module): + """ + Hourglass architecture. 
+ """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Hourglass, self).__init__() + self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) + self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) + self.out_filters = self.decoder.out_filters + + def forward(self, x): + return self.decoder(self.encoder(x)) + + +class SPADE(nn.Module): + def __init__(self, norm_nc, label_nc): + super().__init__() + + self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) + nhidden = 128 + + self.mlp_shared = nn.Sequential( + nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1), + nn.ReLU()) + self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) + self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) + + def forward(self, x, segmap): + normalized = self.param_free_norm(x) + segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + out = normalized * (1 + gamma) + beta + return out + + +class SPADEResnetBlock(nn.Module): + def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1): + super().__init__() + # Attributes + self.learned_shortcut = (fin != fout) + fmiddle = min(fin, fout) + self.use_se = use_se + # create conv layers + self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation) + self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) + # apply spectral norm if specified + if 'spectral' in norm_G: + self.conv_0 = spectral_norm(self.conv_0) + self.conv_1 = spectral_norm(self.conv_1) + if self.learned_shortcut: + self.conv_s = spectral_norm(self.conv_s) + # define normalization layers + self.norm_0 = SPADE(fin, label_nc) + self.norm_1 = SPADE(fmiddle, label_nc) + if self.learned_shortcut: + self.norm_s = SPADE(fin, label_nc) + + def forward(self, x, seg1): + x_s = self.shortcut(x, seg1) + dx = self.conv_0(self.actvn(self.norm_0(x, seg1))) + dx = self.conv_1(self.actvn(self.norm_1(dx, seg1))) + out = x_s + dx + return out + + def shortcut(self, x, seg1): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg1)) + else: + x_s = x + return x_s + + def actvn(self, x): + return F.leaky_relu(x, 2e-1) + + +def filter_state_dict(state_dict, remove_name='fc'): + new_state_dict = {} + for key in state_dict: + if remove_name in key: + continue + new_state_dict[key] = state_dict[key] + return new_state_dict + + +class GRN(nn.Module): + """ GRN (Global Response Normalization) layer + """ + + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class LayerNorm(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). 
+ """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def drop_path(x, drop_prob=0., training=False, scale_by_keep=True): + """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + + def __init__(self, drop_prob=None, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/liveportrait/modules/warping_network.py b/liveportrait/modules/warping_network.py new file mode 100644 index 0000000000000000000000000000000000000000..9191a197055a954272ee8ed86c5e34f3f33f9ad5 --- /dev/null +++ b/liveportrait/modules/warping_network.py @@ -0,0 +1,77 @@ +# coding: utf-8 + +""" +Warping field estimator(W) defined in the paper, which generates a warping field using the implicit +keypoint representations x_s and x_d, and employs this flow field to warp the source feature volume f_s. +""" + +from torch import nn +import torch.nn.functional as F +from .util import SameBlock2d +from .dense_motion import DenseMotionNetwork + + +class WarpingNetwork(nn.Module): + def __init__( + self, + num_kp, + block_expansion, + max_features, + num_down_blocks, + reshape_channel, + estimate_occlusion_map=False, + dense_motion_params=None, + **kwargs + ): + super(WarpingNetwork, self).__init__() + + self.upscale = kwargs.get('upscale', 1) + self.flag_use_occlusion_map = kwargs.get('flag_use_occlusion_map', True) + + if dense_motion_params is not None: + self.dense_motion_network = DenseMotionNetwork( + num_kp=num_kp, + feature_channel=reshape_channel, + estimate_occlusion_map=estimate_occlusion_map, + **dense_motion_params + ) + else: + self.dense_motion_network = None + + self.third = SameBlock2d(max_features, block_expansion * (2 ** num_down_blocks), kernel_size=(3, 3), padding=(1, 1), lrelu=True) + self.fourth = nn.Conv2d(in_channels=block_expansion * (2 ** num_down_blocks), out_channels=block_expansion * (2 ** num_down_blocks), kernel_size=1, stride=1) + + self.estimate_occlusion_map = estimate_occlusion_map + + def deform_input(self, inp, deformation): + return F.grid_sample(inp, deformation, align_corners=False) + + def forward(self, feature_3d, kp_driving, kp_source): + if self.dense_motion_network is not None: + # Feature warper, Transforming feature representation according to deformation and occlusion + dense_motion = self.dense_motion_network( + feature=feature_3d, kp_driving=kp_driving, kp_source=kp_source + ) + if 'occlusion_map' in dense_motion: + occlusion_map = dense_motion['occlusion_map'] # Bx1x64x64 + else: + occlusion_map = None + + deformation = dense_motion['deformation'] # Bx16x64x64x3 + out = self.deform_input(feature_3d, deformation) # Bx32x16x64x64 + + bs, c, d, h, w = out.shape # Bx32x16x64x64 + out = out.view(bs, c * d, h, w) # -> Bx512x64x64 + out = self.third(out) # -> Bx256x64x64 + out = self.fourth(out) # -> Bx256x64x64 + + if self.flag_use_occlusion_map and (occlusion_map is not None): + out = out * occlusion_map + + ret_dct = { + 'occlusion_map': occlusion_map, + 'deformation': deformation, + 'out': out, + } + + return ret_dct diff --git a/liveportrait/template_maker.py b/liveportrait/template_maker.py new file mode 100644 index 0000000000000000000000000000000000000000..7f3ce06201d6f9db98a299346a3324364196fad1 --- /dev/null +++ b/liveportrait/template_maker.py @@ -0,0 +1,65 @@ +# coding: utf-8 + +""" +Make video template +""" + +import os +import cv2 +import numpy as np +import pickle +from rich.progress import track +from .utils.cropper import Cropper + +from .utils.io 
import load_driving_info +from .utils.camera import get_rotation_matrix +from .utils.helper import mkdir, basename +from .utils.rprint import rlog as log +from .config.crop_config import CropConfig +from .config.inference_config import InferenceConfig +from .live_portrait_wrapper import LivePortraitWrapper + +class TemplateMaker: + + def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig): + self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg) + self.cropper = Cropper(crop_cfg=crop_cfg) + + def make_motion_template(self, video_fp: str, output_path: str, **kwargs): + """ make video template (.pkl format) + video_fp: driving video file path + output_path: where to save the pickle file + """ + + driving_rgb_lst = load_driving_info(video_fp) + driving_rgb_lst = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst] + driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst) + I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst) + + n_frames = I_d_lst.shape[0] + + templates = [] + + + for i in track(range(n_frames), description='Making templates...', total=n_frames): + I_d_i = I_d_lst[i] + x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i) + R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll']) + # collect s_d, R_d, δ_d and t_d for inference + template_dct = { + 'n_frames': n_frames, + 'frames_index': i, + } + template_dct['scale'] = x_d_i_info['scale'].cpu().numpy().astype(np.float32) + template_dct['R_d'] = R_d_i.cpu().numpy().astype(np.float32) + template_dct['exp'] = x_d_i_info['exp'].cpu().numpy().astype(np.float32) + template_dct['t'] = x_d_i_info['t'].cpu().numpy().astype(np.float32) + + templates.append(template_dct) + + mkdir(output_path) + # Save the dictionary as a pickle file + pickle_fp = os.path.join(output_path, f'{basename(video_fp)}.pkl') + with open(pickle_fp, 'wb') as f: + pickle.dump([templates, driving_lmk_lst], f) + log(f"Template saved at {pickle_fp}") diff --git a/liveportrait/utils/__init__.py b/liveportrait/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/liveportrait/utils/camera.py b/liveportrait/utils/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..8bbfc90ac87e99caf09bee69982761fdde753527 --- /dev/null +++ b/liveportrait/utils/camera.py @@ -0,0 +1,75 @@ +# coding: utf-8 + +""" +functions for processing and transforming 3D facial keypoints +""" + +import numpy as np +import torch +import torch.nn.functional as F + +PI = np.pi + + +def headpose_pred_to_degree(pred): + """ + pred: (bs, 66) or (bs, 1) or others + """ + if pred.ndim > 1 and pred.shape[1] == 66: + # NOTE: note that the average is modified to 97.5 + device = pred.device + idx_tensor = [idx for idx in range(0, 66)] + idx_tensor = torch.FloatTensor(idx_tensor).to(device) + pred = F.softmax(pred, dim=1) + degree = torch.sum(pred*idx_tensor, axis=1) * 3 - 97.5 + + return degree + + return pred + + +def get_rotation_matrix(pitch_, yaw_, roll_): + """ the input is in degree + """ + # calculate the rotation matrix: vps @ rot + + # transform to radian + pitch = pitch_ / 180 * PI + yaw = yaw_ / 180 * PI + roll = roll_ / 180 * PI + + device = pitch.device + + if pitch.ndim == 1: + pitch = pitch.unsqueeze(1) + if yaw.ndim == 1: + yaw = yaw.unsqueeze(1) + if roll.ndim == 1: + roll = roll.unsqueeze(1) + + # calculate the euler matrix + bs = pitch.shape[0] + ones = 
torch.ones([bs, 1]).to(device) + zeros = torch.zeros([bs, 1]).to(device) + x, y, z = pitch, yaw, roll + + rot_x = torch.cat([ + ones, zeros, zeros, + zeros, torch.cos(x), -torch.sin(x), + zeros, torch.sin(x), torch.cos(x) + ], dim=1).reshape([bs, 3, 3]) + + rot_y = torch.cat([ + torch.cos(y), zeros, torch.sin(y), + zeros, ones, zeros, + -torch.sin(y), zeros, torch.cos(y) + ], dim=1).reshape([bs, 3, 3]) + + rot_z = torch.cat([ + torch.cos(z), -torch.sin(z), zeros, + torch.sin(z), torch.cos(z), zeros, + zeros, zeros, ones + ], dim=1).reshape([bs, 3, 3]) + + rot = rot_z @ rot_y @ rot_x + return rot.permute(0, 2, 1) # transpose diff --git a/liveportrait/utils/crop.py b/liveportrait/utils/crop.py new file mode 100644 index 0000000000000000000000000000000000000000..8f233639fae332c502623e9a0e695af69379dc57 --- /dev/null +++ b/liveportrait/utils/crop.py @@ -0,0 +1,412 @@ +# coding: utf-8 + +""" +cropping function and the related preprocess functions for cropping +""" + +import numpy as np +import os.path as osp +from math import sin, cos, acos, degrees +import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) # NOTE: enforce single thread +from .rprint import rprint as print + +DTYPE = np.float32 +CV2_INTERP = cv2.INTER_LINEAR + +def make_abs_path(fn): + return osp.join(osp.dirname(osp.realpath(__file__)), fn) + +def _transform_img(img, M, dsize, flags=CV2_INTERP, borderMode=None): + """ conduct similarity or affine transformation to the image, do not do border operation! + img: + M: 2x3 matrix or 3x3 matrix + dsize: target shape (width, height) + """ + if isinstance(dsize, tuple) or isinstance(dsize, list): + _dsize = tuple(dsize) + else: + _dsize = (dsize, dsize) + + if borderMode is not None: + return cv2.warpAffine(img, M[:2, :], dsize=_dsize, flags=flags, borderMode=borderMode, borderValue=(0, 0, 0)) + else: + return cv2.warpAffine(img, M[:2, :], dsize=_dsize, flags=flags) + + +def _transform_pts(pts, M): + """ conduct similarity or affine transformation to the pts + pts: Nx2 ndarray + M: 2x3 matrix or 3x3 matrix + return: Nx2 + """ + return pts @ M[:2, :2].T + M[:2, 2] + + +def parse_pt2_from_pt101(pt101, use_lip=True): + """ + parsing the 2 points according to the 101 points, which cancels the roll + """ + # the former version use the eye center, but it is not robust, now use interpolation + pt_left_eye = np.mean(pt101[[39, 42, 45, 48]], axis=0) # left eye center + pt_right_eye = np.mean(pt101[[51, 54, 57, 60]], axis=0) # right eye center + + if use_lip: + # use lip + pt_center_eye = (pt_left_eye + pt_right_eye) / 2 + pt_center_lip = (pt101[75] + pt101[81]) / 2 + pt2 = np.stack([pt_center_eye, pt_center_lip], axis=0) + else: + pt2 = np.stack([pt_left_eye, pt_right_eye], axis=0) + return pt2 + + +def parse_pt2_from_pt106(pt106, use_lip=True): + """ + parsing the 2 points according to the 106 points, which cancels the roll + """ + pt_left_eye = np.mean(pt106[[33, 35, 40, 39]], axis=0) # left eye center + pt_right_eye = np.mean(pt106[[87, 89, 94, 93]], axis=0) # right eye center + + if use_lip: + # use lip + pt_center_eye = (pt_left_eye + pt_right_eye) / 2 + pt_center_lip = (pt106[52] + pt106[61]) / 2 + pt2 = np.stack([pt_center_eye, pt_center_lip], axis=0) + else: + pt2 = np.stack([pt_left_eye, pt_right_eye], axis=0) + return pt2 + + +def parse_pt2_from_pt203(pt203, use_lip=True): + """ + parsing the 2 points according to the 203 points, which cancels the roll + """ + pt_left_eye = np.mean(pt203[[0, 6, 12, 18]], axis=0) # left eye center + pt_right_eye = np.mean(pt203[[24, 30, 36, 42]], 
axis=0) # right eye center + if use_lip: + # use lip + pt_center_eye = (pt_left_eye + pt_right_eye) / 2 + pt_center_lip = (pt203[48] + pt203[66]) / 2 + pt2 = np.stack([pt_center_eye, pt_center_lip], axis=0) + else: + pt2 = np.stack([pt_left_eye, pt_right_eye], axis=0) + return pt2 + + +def parse_pt2_from_pt68(pt68, use_lip=True): + """ + parsing the 2 points according to the 68 points, which cancels the roll + """ + lm_idx = np.array([31, 37, 40, 43, 46, 49, 55], dtype=np.int32) - 1 + if use_lip: + pt5 = np.stack([ + np.mean(pt68[lm_idx[[1, 2]], :], 0), # left eye + np.mean(pt68[lm_idx[[3, 4]], :], 0), # right eye + pt68[lm_idx[0], :], # nose + pt68[lm_idx[5], :], # lip + pt68[lm_idx[6], :] # lip + ], axis=0) + + pt2 = np.stack([ + (pt5[0] + pt5[1]) / 2, + (pt5[3] + pt5[4]) / 2 + ], axis=0) + else: + pt2 = np.stack([ + np.mean(pt68[lm_idx[[1, 2]], :], 0), # left eye + np.mean(pt68[lm_idx[[3, 4]], :], 0), # right eye + ], axis=0) + + return pt2 + + +def parse_pt2_from_pt5(pt5, use_lip=True): + """ + parsing the 2 points according to the 5 points, which cancels the roll + """ + if use_lip: + pt2 = np.stack([ + (pt5[0] + pt5[1]) / 2, + (pt5[3] + pt5[4]) / 2 + ], axis=0) + else: + pt2 = np.stack([ + pt5[0], + pt5[1] + ], axis=0) + return pt2 + + +def parse_pt2_from_pt_x(pts, use_lip=True): + if pts.shape[0] == 101: + pt2 = parse_pt2_from_pt101(pts, use_lip=use_lip) + elif pts.shape[0] == 106: + pt2 = parse_pt2_from_pt106(pts, use_lip=use_lip) + elif pts.shape[0] == 68: + pt2 = parse_pt2_from_pt68(pts, use_lip=use_lip) + elif pts.shape[0] == 5: + pt2 = parse_pt2_from_pt5(pts, use_lip=use_lip) + elif pts.shape[0] == 203: + pt2 = parse_pt2_from_pt203(pts, use_lip=use_lip) + elif pts.shape[0] > 101: + # take the first 101 points + pt2 = parse_pt2_from_pt101(pts[:101], use_lip=use_lip) + else: + raise Exception(f'Unknow shape: {pts.shape}') + + if not use_lip: + # NOTE: to compile with the latter code, need to rotate the pt2 90 degrees clockwise manually + v = pt2[1] - pt2[0] + pt2[1, 0] = pt2[0, 0] - v[1] + pt2[1, 1] = pt2[0, 1] + v[0] + + return pt2 + + +def parse_rect_from_landmark( + pts, + scale=1.5, + need_square=True, + vx_ratio=0, + vy_ratio=0, + use_deg_flag=False, + **kwargs +): + """parsing center, size, angle from 101/68/5/x landmarks + vx_ratio: the offset ratio along the pupil axis x-axis, multiplied by size + vy_ratio: the offset ratio along the pupil axis y-axis, multiplied by size, which is used to contain more forehead area + + judge with pts.shape + """ + pt2 = parse_pt2_from_pt_x(pts, use_lip=kwargs.get('use_lip', True)) + + uy = pt2[1] - pt2[0] + l = np.linalg.norm(uy) + if l <= 1e-3: + uy = np.array([0, 1], dtype=DTYPE) + else: + uy /= l + ux = np.array((uy[1], -uy[0]), dtype=DTYPE) + + # the rotation degree of the x-axis, the clockwise is positive, the counterclockwise is negative (image coordinate system) + # print(uy) + # print(ux) + angle = acos(ux[0]) + if ux[1] < 0: + angle = -angle + + # rotation matrix + M = np.array([ux, uy]) + + # calculate the size which contains the angle degree of the bbox, and the center + center0 = np.mean(pts, axis=0) + rpts = (pts - center0) @ M.T # (M @ P.T).T = P @ M.T + lt_pt = np.min(rpts, axis=0) + rb_pt = np.max(rpts, axis=0) + center1 = (lt_pt + rb_pt) / 2 + + size = rb_pt - lt_pt + if need_square: + m = max(size[0], size[1]) + size[0] = m + size[1] = m + + size *= scale # scale size + center = center0 + ux * center1[0] + uy * center1[1] # counterclockwise rotation, equivalent to M.T @ center1.T + center = center + ux * (vx_ratio * 
size) + uy * \ + (vy_ratio * size) # considering the offset in vx and vy direction + + if use_deg_flag: + angle = degrees(angle) + + return center, size, angle + + +def parse_bbox_from_landmark(pts, **kwargs): + center, size, angle = parse_rect_from_landmark(pts, **kwargs) + cx, cy = center + w, h = size + + # calculate the vertex positions before rotation + bbox = np.array([ + [cx-w/2, cy-h/2], # left, top + [cx+w/2, cy-h/2], + [cx+w/2, cy+h/2], # right, bottom + [cx-w/2, cy+h/2] + ], dtype=DTYPE) + + # construct rotation matrix + bbox_rot = bbox.copy() + R = np.array([ + [np.cos(angle), -np.sin(angle)], + [np.sin(angle), np.cos(angle)] + ], dtype=DTYPE) + + # calculate the relative position of each vertex from the rotation center, then rotate these positions, and finally add the coordinates of the rotation center + bbox_rot = (bbox_rot - center) @ R.T + center + + return { + 'center': center, # 2x1 + 'size': size, # scalar + 'angle': angle, # rad, counterclockwise + 'bbox': bbox, # 4x2 + 'bbox_rot': bbox_rot, # 4x2 + } + + +def crop_image_by_bbox(img, bbox, lmk=None, dsize=512, angle=None, flag_rot=False, **kwargs): + left, top, right, bot = bbox + if int(right - left) != int(bot - top): + print(f'right-left {right-left} != bot-top {bot-top}') + size = right - left + + src_center = np.array([(left + right) / 2, (top + bot) / 2], dtype=DTYPE) + tgt_center = np.array([dsize / 2, dsize / 2], dtype=DTYPE) + + s = dsize / size # scale + if flag_rot and angle is not None: + costheta, sintheta = cos(angle), sin(angle) + cx, cy = src_center[0], src_center[1] # ori center + tcx, tcy = tgt_center[0], tgt_center[1] # target center + # need to infer + M_o2c = np.array( + [[s * costheta, s * sintheta, tcx - s * (costheta * cx + sintheta * cy)], + [-s * sintheta, s * costheta, tcy - s * (-sintheta * cx + costheta * cy)]], + dtype=DTYPE + ) + else: + M_o2c = np.array( + [[s, 0, tgt_center[0] - s * src_center[0]], + [0, s, tgt_center[1] - s * src_center[1]]], + dtype=DTYPE + ) + + if flag_rot and angle is None: + print('angle is None, but flag_rotate is True', style="bold yellow") + + img_crop = _transform_img(img, M_o2c, dsize=dsize, borderMode=kwargs.get('borderMode', None)) + + lmk_crop = _transform_pts(lmk, M_o2c) if lmk is not None else None + + M_o2c = np.vstack([M_o2c, np.array([0, 0, 1], dtype=DTYPE)]) + M_c2o = np.linalg.inv(M_o2c) + + # cv2.imwrite('crop.jpg', img_crop) + + return { + 'img_crop': img_crop, + 'lmk_crop': lmk_crop, + 'M_o2c': M_o2c, + 'M_c2o': M_c2o, + } + + +def _estimate_similar_transform_from_pts( + pts, + dsize, + scale=1.5, + vx_ratio=0, + vy_ratio=-0.1, + flag_do_rot=True, + **kwargs +): + """ calculate the affine matrix of the cropped image from sparse points, the original image to the cropped image, the inverse is the cropped image to the original image + pts: landmark, 101 or 68 points or other points, Nx2 + scale: the larger scale factor, the smaller face ratio + vx_ratio: x shift + vy_ratio: y shift, the smaller the y shift, the lower the face region + rot_flag: if it is true, conduct correction + """ + center, size, angle = parse_rect_from_landmark( + pts, scale=scale, vx_ratio=vx_ratio, vy_ratio=vy_ratio, + use_lip=kwargs.get('use_lip', True) + ) + + s = dsize / size[0] # scale + tgt_center = np.array([dsize / 2, dsize / 2], dtype=DTYPE) # center of dsize + + if flag_do_rot: + costheta, sintheta = cos(angle), sin(angle) + cx, cy = center[0], center[1] # ori center + tcx, tcy = tgt_center[0], tgt_center[1] # target center + # need to infer + M_INV = np.array( + [[s * 
costheta, s * sintheta, tcx - s * (costheta * cx + sintheta * cy)], + [-s * sintheta, s * costheta, tcy - s * (-sintheta * cx + costheta * cy)]], + dtype=DTYPE + ) + else: + M_INV = np.array( + [[s, 0, tgt_center[0] - s * center[0]], + [0, s, tgt_center[1] - s * center[1]]], + dtype=DTYPE + ) + + M_INV_H = np.vstack([M_INV, np.array([0, 0, 1])]) + M = np.linalg.inv(M_INV_H) + + # M_INV is from the original image to the cropped image, M is from the cropped image to the original image + return M_INV, M[:2, ...] + + +def crop_image(img, pts: np.ndarray, **kwargs): + dsize = kwargs.get('dsize', 224) + scale = kwargs.get('scale', 1.5) # 1.5 | 1.6 + vy_ratio = kwargs.get('vy_ratio', -0.1) # -0.0625 | -0.1 + + M_INV, _ = _estimate_similar_transform_from_pts( + pts, + dsize=dsize, + scale=scale, + vy_ratio=vy_ratio, + flag_do_rot=kwargs.get('flag_do_rot', True), + ) + + if img is None: + M_INV_H = np.vstack([M_INV, np.array([0, 0, 1], dtype=DTYPE)]) + M = np.linalg.inv(M_INV_H) + ret_dct = { + 'M': M[:2, ...], # from the original image to the cropped image + 'M_o2c': M[:2, ...], # from the cropped image to the original image + 'img_crop': None, + 'pt_crop': None, + } + return ret_dct + + img_crop = _transform_img(img, M_INV, dsize) # origin to crop + pt_crop = _transform_pts(pts, M_INV) + + M_o2c = np.vstack([M_INV, np.array([0, 0, 1], dtype=DTYPE)]) + M_c2o = np.linalg.inv(M_o2c) + + ret_dct = { + 'M_o2c': M_o2c, # from the original image to the cropped image 3x3 + 'M_c2o': M_c2o, # from the cropped image to the original image 3x3 + 'img_crop': img_crop, # the cropped image + 'pt_crop': pt_crop, # the landmarks of the cropped image + } + + return ret_dct + +def average_bbox_lst(bbox_lst): + if len(bbox_lst) == 0: + return None + bbox_arr = np.array(bbox_lst) + return np.mean(bbox_arr, axis=0).tolist() + +def prepare_paste_back(mask_crop, crop_M_c2o, dsize): + """prepare mask for later image paste back + """ + if mask_crop is None: + mask_crop = cv2.imread(make_abs_path('./resources/mask_template.png'), cv2.IMREAD_COLOR) + mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize) + mask_ori = mask_ori.astype(np.float32) / 255. 
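+    # mask_ori is the crop-space mask warped into original-image coordinates and scaled to [0, 1];
+    # paste_back() below uses it as a per-pixel alpha: mask_ori * result + (1 - mask_ori) * rgb_ori.
+    # Illustrative call sequence (variable names here are only an example):
+    #   mask_ori = prepare_paste_back(None, crop_info['M_c2o'], dsize=(w, h))
+    #   blended = paste_back(edited_crop, crop_info['M_c2o'], img_rgb, mask_ori)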
+    return mask_ori
+
+def paste_back(image_to_processed, crop_M_c2o, rgb_ori, mask_ori):
+    """paste back the image
+    """
+    dsize = (rgb_ori.shape[1], rgb_ori.shape[0])
+    result = _transform_img(image_to_processed, crop_M_c2o, dsize=dsize)
+    result = np.clip(mask_ori * result + (1 - mask_ori) * rgb_ori, 0, 255).astype(np.uint8)
+    return result
\ No newline at end of file
diff --git a/liveportrait/utils/cropper.py b/liveportrait/utils/cropper.py
new file mode 100644
index 0000000000000000000000000000000000000000..7800ed83873eee08cafa17741cfe47d7b2a6af3b
--- /dev/null
+++ b/liveportrait/utils/cropper.py
@@ -0,0 +1,150 @@
+# coding: utf-8
+
+import gradio as gr
+import numpy as np
+import os
+import os.path as osp
+from typing import List, Union, Tuple
+from dataclasses import dataclass, field
+import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
+
+from .landmark_runner import LandmarkRunner
+from .face_analysis_diy import FaceAnalysisDIY
+from .helper import prefix
+from .crop import crop_image, crop_image_by_bbox, parse_bbox_from_landmark, average_bbox_lst
+from .timer import Timer
+from .rprint import rlog as log
+from .io import load_image_rgb
+from .video import VideoWriter, get_fps, change_video_fps
+
+import os
+
+DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
+MODELS_DIR = os.path.join(DATA_ROOT, "models")
+
+def make_abs_path(fn):
+    return osp.join(osp.dirname(osp.realpath(__file__)), fn)
+
+
+@dataclass
+class Trajectory:
+    start: int = -1  # start frame (inclusive)
+    end: int = -1  # end frame (inclusive)
+    lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)  # lmk list
+    bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)  # bbox list
+    frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)  # frame list
+    frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)  # frame crop list
+
+
+class Cropper(object):
+    def __init__(self, **kwargs) -> None:
+        device_id = kwargs.get('device_id', 0)
+        self.landmark_runner = LandmarkRunner(
+            ckpt_path=make_abs_path(os.path.join(MODELS_DIR, "liveportrait", "landmark.onnx")),
+            onnx_provider='cuda',
+            device_id=device_id
+        )
+        self.landmark_runner.warmup()
+
+        self.face_analysis_wrapper = FaceAnalysisDIY(
+            name='buffalo_l',
+            root=make_abs_path(os.path.join(MODELS_DIR, "insightface")),
+            providers=["CUDAExecutionProvider"]
+        )
+        self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512))
+        self.face_analysis_wrapper.warmup()
+
+        self.crop_cfg = kwargs.get('crop_cfg', None)
+
+    def update_config(self, user_args):
+        for k, v in user_args.items():
+            if hasattr(self.crop_cfg, k):
+                setattr(self.crop_cfg, k, v)
+
+    def crop_single_image(self, obj, **kwargs):
+        direction = kwargs.get('direction', 'large-small')
+
+        # crop and align a single image
+        if isinstance(obj, str):
+            img_rgb = load_image_rgb(obj)
+        elif isinstance(obj, np.ndarray):
+            img_rgb = obj
+
+        src_face = self.face_analysis_wrapper.get(
+            img_rgb,
+            flag_do_landmark_2d_106=True,
+            direction=direction
+        )
+
+        if len(src_face) == 0:
+            log('No face detected in the source image.')
+            raise gr.Error("No face detected in the source image 💥!", duration=5)
+            raise Exception("No face detected in the source image!")
+        elif len(src_face) > 1:
+            log(f'More than one face detected in the image, only pick one face by rule {direction}.')
+
+        src_face = src_face[0]
+        pts = src_face.landmark_2d_106
+
+        # crop the face
+        ret_dct = crop_image(
+            img_rgb,  # ndarray
+            pts,  # 106x2 or Nx2
+            dsize=kwargs.get('dsize', 512),
+            scale=kwargs.get('scale', 2.3),
+            vy_ratio=kwargs.get('vy_ratio', -0.15),
+        )
+        # update a 256x256 version for network input or else
+        ret_dct['img_crop_256x256'] = cv2.resize(ret_dct['img_crop'], (256, 256), interpolation=cv2.INTER_AREA)
+        ret_dct['pt_crop_256x256'] = ret_dct['pt_crop'] * 256 / kwargs.get('dsize', 512)
+
+        recon_ret = self.landmark_runner.run(img_rgb, pts)
+        lmk = recon_ret['pts']
+        ret_dct['lmk_crop'] = lmk
+
+        return ret_dct
+
+    def get_retargeting_lmk_info(self, driving_rgb_lst):
+        # TODO: implement a tracking-based version
+        driving_lmk_lst = []
+        for driving_image in driving_rgb_lst:
+            ret_dct = self.crop_single_image(driving_image)
+            driving_lmk_lst.append(ret_dct['lmk_crop'])
+        return driving_lmk_lst
+
+    def make_video_clip(self, driving_rgb_lst, output_path, output_fps=30, **kwargs):
+        trajectory = Trajectory()
+        direction = kwargs.get('direction', 'large-small')
+        for idx, driving_image in enumerate(driving_rgb_lst):
+            if idx == 0 or trajectory.start == -1:
+                src_face = self.face_analysis_wrapper.get(
+                    driving_image,
+                    flag_do_landmark_2d_106=True,
+                    direction=direction
+                )
+                if len(src_face) == 0:
+                    # No face detected in the driving_image
+                    continue
+                elif len(src_face) > 1:
+                    log(f'More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.')
+                src_face = src_face[0]
+                pts = src_face.landmark_2d_106
+                lmk_203 = self.landmark_runner(driving_image, pts)['pts']
+                trajectory.start, trajectory.end = idx, idx
+            else:
+                lmk_203 = self.face_recon_wrapper(driving_image, trajectory.lmk_lst[-1])['pts']
+                trajectory.end = idx
+
+            trajectory.lmk_lst.append(lmk_203)
+            ret_bbox = parse_bbox_from_landmark(lmk_203, scale=self.crop_cfg.globalscale, vy_ratio=self.crop_cfg.vy_ratio)['bbox']
+            bbox = [ret_bbox[0, 0], ret_bbox[0, 1], ret_bbox[2, 0], ret_bbox[2, 1]]  # 4,
+            trajectory.bbox_lst.append(bbox)  # bbox
+            trajectory.frame_rgb_lst.append(driving_image)
+
+        global_bbox = average_bbox_lst(trajectory.bbox_lst)
+        for idx, (frame_rgb, lmk) in enumerate(zip(trajectory.frame_rgb_lst, trajectory.lmk_lst)):
+            ret_dct = crop_image_by_bbox(
+                frame_rgb, global_bbox, lmk=lmk,
+                dsize=self.video_crop_cfg.dsize, flag_rot=self.video_crop_cfg.flag_rot, borderValue=self.video_crop_cfg.borderValue
+            )
+            frame_rgb_crop = ret_dct['img_crop']
diff --git a/liveportrait/utils/dependencies/insightface/__init__.py b/liveportrait/utils/dependencies/insightface/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1680083da47850b31da10803c7d255e67dda619a
--- /dev/null
+++ b/liveportrait/utils/dependencies/insightface/__init__.py
@@ -0,0 +1,20 @@
+# coding: utf-8
+# pylint: disable=wrong-import-position
+"""InsightFace: A Face Analysis Toolkit."""
+from __future__ import absolute_import
+
+try:
+    #import mxnet as mx
+    import onnxruntime
+except ImportError:
+    raise ImportError(
+        "Unable to import dependency onnxruntime. "
+    )
+
+__version__ = '0.7.3'
+
+from . import model_zoo
+from . import utils
+from . import app
+from .
import data + diff --git a/liveportrait/utils/dependencies/insightface/app/__init__.py b/liveportrait/utils/dependencies/insightface/app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc574616885290489798bac5c682e7aaa65a5dad --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/app/__init__.py @@ -0,0 +1 @@ +from .face_analysis import * diff --git a/liveportrait/utils/dependencies/insightface/app/common.py b/liveportrait/utils/dependencies/insightface/app/common.py new file mode 100644 index 0000000000000000000000000000000000000000..82ca987aeede35510b3aef72b4edf2390ad84e65 --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/app/common.py @@ -0,0 +1,49 @@ +import numpy as np +from numpy.linalg import norm as l2norm +#from easydict import EasyDict + +class Face(dict): + + def __init__(self, d=None, **kwargs): + if d is None: + d = {} + if kwargs: + d.update(**kwargs) + for k, v in d.items(): + setattr(self, k, v) + # Class attributes + #for k in self.__class__.__dict__.keys(): + # if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'): + # setattr(self, k, getattr(self, k)) + + def __setattr__(self, name, value): + if isinstance(value, (list, tuple)): + value = [self.__class__(x) + if isinstance(x, dict) else x for x in value] + elif isinstance(value, dict) and not isinstance(value, self.__class__): + value = self.__class__(value) + super(Face, self).__setattr__(name, value) + super(Face, self).__setitem__(name, value) + + __setitem__ = __setattr__ + + def __getattr__(self, name): + return None + + @property + def embedding_norm(self): + if self.embedding is None: + return None + return l2norm(self.embedding) + + @property + def normed_embedding(self): + if self.embedding is None: + return None + return self.embedding / self.embedding_norm + + @property + def sex(self): + if self.gender is None: + return None + return 'M' if self.gender==1 else 'F' diff --git a/liveportrait/utils/dependencies/insightface/app/face_analysis.py b/liveportrait/utils/dependencies/insightface/app/face_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..aa5128b3f5e02c2c19e7df195cc1c1e7fcf36c4d --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/app/face_analysis.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +# @Organization : insightface.ai +# @Author : Jia Guo +# @Time : 2021-05-04 +# @Function : + + +from __future__ import division + +import glob +import os.path as osp + +import numpy as np +import onnxruntime +from numpy.linalg import norm + +from ..model_zoo import model_zoo +from ..utils import ensure_available +from .common import Face + + +DEFAULT_MP_NAME = 'buffalo_l' +__all__ = ['FaceAnalysis'] + +class FaceAnalysis: + def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs): + onnxruntime.set_default_logger_severity(3) + self.models = {} + self.model_dir = ensure_available('models', name, root=root) + onnx_files = glob.glob(osp.join(self.model_dir, '*.onnx')) + onnx_files = sorted(onnx_files) + for onnx_file in onnx_files: + model = model_zoo.get_model(onnx_file, **kwargs) + if model is None: + print('model not recognized:', onnx_file) + elif allowed_modules is not None and model.taskname not in allowed_modules: + print('model ignore:', onnx_file, model.taskname) + del model + elif model.taskname not in self.models and (allowed_modules is None or model.taskname in allowed_modules): + # print('find model:', onnx_file, model.taskname, 
model.input_shape, model.input_mean, model.input_std) + self.models[model.taskname] = model + else: + print('duplicated model task type, ignore:', onnx_file, model.taskname) + del model + assert 'detection' in self.models + self.det_model = self.models['detection'] + + + def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)): + self.det_thresh = det_thresh + assert det_size is not None + # print('set det-size:', det_size) + self.det_size = det_size + for taskname, model in self.models.items(): + if taskname=='detection': + model.prepare(ctx_id, input_size=det_size, det_thresh=det_thresh) + else: + model.prepare(ctx_id) + + def get(self, img, max_num=0): + bboxes, kpss = self.det_model.detect(img, + max_num=max_num, + metric='default') + if bboxes.shape[0] == 0: + return [] + ret = [] + for i in range(bboxes.shape[0]): + bbox = bboxes[i, 0:4] + det_score = bboxes[i, 4] + kps = None + if kpss is not None: + kps = kpss[i] + face = Face(bbox=bbox, kps=kps, det_score=det_score) + for taskname, model in self.models.items(): + if taskname=='detection': + continue + model.get(img, face) + ret.append(face) + return ret + + def draw_on(self, img, faces): + import cv2 + dimg = img.copy() + for i in range(len(faces)): + face = faces[i] + box = face.bbox.astype(np.int) + color = (0, 0, 255) + cv2.rectangle(dimg, (box[0], box[1]), (box[2], box[3]), color, 2) + if face.kps is not None: + kps = face.kps.astype(np.int) + #print(landmark.shape) + for l in range(kps.shape[0]): + color = (0, 0, 255) + if l == 0 or l == 3: + color = (0, 255, 0) + cv2.circle(dimg, (kps[l][0], kps[l][1]), 1, color, + 2) + if face.gender is not None and face.age is not None: + cv2.putText(dimg,'%s,%d'%(face.sex,face.age), (box[0]-1, box[1]-4),cv2.FONT_HERSHEY_COMPLEX,0.7,(0,255,0),1) + + #for key, value in face.items(): + # if key.startswith('landmark_3d'): + # print(key, value.shape) + # print(value[0:10,:]) + # lmk = np.round(value).astype(np.int) + # for l in range(lmk.shape[0]): + # color = (255, 0, 0) + # cv2.circle(dimg, (lmk[l][0], lmk[l][1]), 1, color, + # 2) + return dimg diff --git a/liveportrait/utils/dependencies/insightface/data/__init__.py b/liveportrait/utils/dependencies/insightface/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..665c59ec99b6ebf12822015e0350969c7903e243 --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/data/__init__.py @@ -0,0 +1,2 @@ +from .image import get_image +from .pickle_object import get_object diff --git a/liveportrait/utils/dependencies/insightface/data/image.py b/liveportrait/utils/dependencies/insightface/data/image.py new file mode 100644 index 0000000000000000000000000000000000000000..6d32c4bcb1b13d33bcb0d840cf7b8c08d183b3ea --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/data/image.py @@ -0,0 +1,27 @@ +import cv2 +import os +import os.path as osp +from pathlib import Path + +class ImageCache: + data = {} + +def get_image(name, to_rgb=False): + key = (name, to_rgb) + if key in ImageCache.data: + return ImageCache.data[key] + images_dir = osp.join(Path(__file__).parent.absolute(), 'images') + ext_names = ['.jpg', '.png', '.jpeg'] + image_file = None + for ext_name in ext_names: + _image_file = osp.join(images_dir, "%s%s"%(name, ext_name)) + if osp.exists(_image_file): + image_file = _image_file + break + assert image_file is not None, '%s not found'%name + img = cv2.imread(image_file) + if to_rgb: + img = img[:,:,::-1] + ImageCache.data[key] = img + return img + diff --git 
a/liveportrait/utils/dependencies/insightface/data/images/Tom_Hanks_54745.png b/liveportrait/utils/dependencies/insightface/data/images/Tom_Hanks_54745.png new file mode 100644 index 0000000000000000000000000000000000000000..83387c3572e335441b617d0ba1f6a7f897b87f37 --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/data/images/Tom_Hanks_54745.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8545da294e8c7c79911169c3915fed8528f1960cd0ed99b92453788ca4275083 +size 12123 diff --git a/liveportrait/utils/dependencies/insightface/data/images/mask_black.jpg b/liveportrait/utils/dependencies/insightface/data/images/mask_black.jpg new file mode 100644 index 0000000000000000000000000000000000000000..68233a29abd20c692a151143e22abcd62461ce2b --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/data/images/mask_black.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d81c40c77a0e38717dc7e1d4ca994e2bc58efa804465ff8d09d915daeddf2c83 +size 21308 diff --git a/liveportrait/utils/dependencies/insightface/data/images/mask_blue.jpg b/liveportrait/utils/dependencies/insightface/data/images/mask_blue.jpg new file mode 100644 index 0000000000000000000000000000000000000000..37a029125e7a6257d59cf2201be4b3409564d33e --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/data/images/mask_blue.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a55ec2b181de4143af27c441852ebb4bbd364fb0e0b9c46bb17631c59bc4c840 +size 44728 diff --git a/liveportrait/utils/dependencies/insightface/data/images/mask_green.jpg b/liveportrait/utils/dependencies/insightface/data/images/mask_green.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8ed0111375d6c189f2a071ec6fcbcbdb8607b4fb --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/data/images/mask_green.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79526b5a31ef75cb7fe51227abc69d589e4b7ba39dbbf51e79147da577a3f154 +size 6116 diff --git a/liveportrait/utils/dependencies/insightface/data/images/mask_white.jpg b/liveportrait/utils/dependencies/insightface/data/images/mask_white.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e6c92ec3c5dfadc36e8de5bc1c91572d602f2ce3 --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/data/images/mask_white.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3adee6ba6680d26197bbea95deb9d335252d01c2c51ead23e01af9784c310aa6 +size 78905 diff --git a/liveportrait/utils/dependencies/insightface/data/images/t1.jpg b/liveportrait/utils/dependencies/insightface/data/images/t1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8fd6427a177bd01650c0150e9d02457c3a5dcddd --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/data/images/t1.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47f682e945b659f93a9e490b9c9c4a2a864abe64dace9e1a2893845ddfd69489 +size 128824 diff --git a/liveportrait/utils/dependencies/insightface/data/objects/meanshape_68.pkl b/liveportrait/utils/dependencies/insightface/data/objects/meanshape_68.pkl new file mode 100644 index 0000000000000000000000000000000000000000..2af4d76020057e90528756c62d8833aaf0e0a5c8 Binary files /dev/null and b/liveportrait/utils/dependencies/insightface/data/objects/meanshape_68.pkl differ diff --git a/liveportrait/utils/dependencies/insightface/data/pickle_object.py b/liveportrait/utils/dependencies/insightface/data/pickle_object.py new 
file mode 100644 index 0000000000000000000000000000000000000000..fbd87030ea15e1d01af1cd4cff1be2bc54cc82dd --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/data/pickle_object.py @@ -0,0 +1,17 @@ +import cv2 +import os +import os.path as osp +from pathlib import Path +import pickle + +def get_object(name): + objects_dir = osp.join(Path(__file__).parent.absolute(), 'objects') + if not name.endswith('.pkl'): + name = name+".pkl" + filepath = osp.join(objects_dir, name) + if not osp.exists(filepath): + return None + with open(filepath, 'rb') as f: + obj = pickle.load(f) + return obj + diff --git a/liveportrait/utils/dependencies/insightface/data/rec_builder.py b/liveportrait/utils/dependencies/insightface/data/rec_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..e02abc969da2f882639326f5bad3c7e8d08c1fde --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/data/rec_builder.py @@ -0,0 +1,71 @@ +import pickle +import numpy as np +import os +import os.path as osp +import sys +import mxnet as mx + + +class RecBuilder(): + def __init__(self, path, image_size=(112, 112)): + self.path = path + self.image_size = image_size + self.widx = 0 + self.wlabel = 0 + self.max_label = -1 + assert not osp.exists(path), '%s exists' % path + os.makedirs(path) + self.writer = mx.recordio.MXIndexedRecordIO(os.path.join(path, 'train.idx'), + os.path.join(path, 'train.rec'), + 'w') + self.meta = [] + + def add(self, imgs): + #!!! img should be BGR!!!! + #assert label >= 0 + #assert label > self.last_label + assert len(imgs) > 0 + label = self.wlabel + for img in imgs: + idx = self.widx + image_meta = {'image_index': idx, 'image_classes': [label]} + header = mx.recordio.IRHeader(0, label, idx, 0) + if isinstance(img, np.ndarray): + s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg') + else: + s = mx.recordio.pack(header, img) + self.writer.write_idx(idx, s) + self.meta.append(image_meta) + self.widx += 1 + self.max_label = label + self.wlabel += 1 + + + def add_image(self, img, label): + #!!! img should be BGR!!!! 
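+        # Appends a single BGR image (ndarray) or a pre-encoded record, together with its label, to the train.idx/train.rec pair.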
+ #assert label >= 0 + #assert label > self.last_label + idx = self.widx + header = mx.recordio.IRHeader(0, label, idx, 0) + if isinstance(label, list): + idlabel = label[0] + else: + idlabel = label + image_meta = {'image_index': idx, 'image_classes': [idlabel]} + if isinstance(img, np.ndarray): + s = mx.recordio.pack_img(header,img,quality=95,img_fmt='.jpg') + else: + s = mx.recordio.pack(header, img) + self.writer.write_idx(idx, s) + self.meta.append(image_meta) + self.widx += 1 + self.max_label = max(self.max_label, idlabel) + + def close(self): + with open(osp.join(self.path, 'train.meta'), 'wb') as pfile: + pickle.dump(self.meta, pfile, protocol=pickle.HIGHEST_PROTOCOL) + print('stat:', self.widx, self.wlabel) + with open(os.path.join(self.path, 'property'), 'w') as f: + f.write("%d,%d,%d\n" % (self.max_label+1, self.image_size[0], self.image_size[1])) + f.write("%d\n" % (self.widx)) + diff --git a/liveportrait/utils/dependencies/insightface/model_zoo/__init__.py b/liveportrait/utils/dependencies/insightface/model_zoo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..225623d6142c968b4040f391039bfab88bdd1b2a --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/model_zoo/__init__.py @@ -0,0 +1,6 @@ +from .model_zoo import get_model +from .arcface_onnx import ArcFaceONNX +from .retinaface import RetinaFace +from .scrfd import SCRFD +from .landmark import Landmark +from .attribute import Attribute diff --git a/liveportrait/utils/dependencies/insightface/model_zoo/arcface_onnx.py b/liveportrait/utils/dependencies/insightface/model_zoo/arcface_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..b537ce2ee15d4a1834d54e185f34e336aab30a77 --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/model_zoo/arcface_onnx.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +# @Organization : insightface.ai +# @Author : Jia Guo +# @Time : 2021-05-04 +# @Function : + +from __future__ import division +import numpy as np +import cv2 +import onnx +import onnxruntime +from ..utils import face_align + +__all__ = [ + 'ArcFaceONNX', +] + + +class ArcFaceONNX: + def __init__(self, model_file=None, session=None): + assert model_file is not None + self.model_file = model_file + self.session = session + self.taskname = 'recognition' + find_sub = False + find_mul = False + model = onnx.load(self.model_file) + graph = model.graph + for nid, node in enumerate(graph.node[:8]): + #print(nid, node.name) + if node.name.startswith('Sub') or node.name.startswith('_minus'): + find_sub = True + if node.name.startswith('Mul') or node.name.startswith('_mul'): + find_mul = True + if find_sub and find_mul: + #mxnet arcface model + input_mean = 0.0 + input_std = 1.0 + else: + input_mean = 127.5 + input_std = 127.5 + self.input_mean = input_mean + self.input_std = input_std + #print('input mean and std:', self.input_mean, self.input_std) + if self.session is None: + self.session = onnxruntime.InferenceSession(self.model_file, None) + input_cfg = self.session.get_inputs()[0] + input_shape = input_cfg.shape + input_name = input_cfg.name + self.input_size = tuple(input_shape[2:4][::-1]) + self.input_shape = input_shape + outputs = self.session.get_outputs() + output_names = [] + for out in outputs: + output_names.append(out.name) + self.input_name = input_name + self.output_names = output_names + assert len(self.output_names)==1 + self.output_shape = outputs[0].shape + + def prepare(self, ctx_id, **kwargs): + if ctx_id<0: + 
self.session.set_providers(['CPUExecutionProvider']) + + def get(self, img, face): + aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0]) + face.embedding = self.get_feat(aimg).flatten() + return face.embedding + + def compute_sim(self, feat1, feat2): + from numpy.linalg import norm + feat1 = feat1.ravel() + feat2 = feat2.ravel() + sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2)) + return sim + + def get_feat(self, imgs): + if not isinstance(imgs, list): + imgs = [imgs] + input_size = self.input_size + + blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size, + (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + net_out = self.session.run(self.output_names, {self.input_name: blob})[0] + return net_out + + def forward(self, batch_data): + blob = (batch_data - self.input_mean) / self.input_std + net_out = self.session.run(self.output_names, {self.input_name: blob})[0] + return net_out + + diff --git a/liveportrait/utils/dependencies/insightface/model_zoo/attribute.py b/liveportrait/utils/dependencies/insightface/model_zoo/attribute.py new file mode 100644 index 0000000000000000000000000000000000000000..40c34de3f0995499448cf5779004cc1e5f3564fb --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/model_zoo/attribute.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# @Organization : insightface.ai +# @Author : Jia Guo +# @Time : 2021-06-19 +# @Function : + +from __future__ import division +import numpy as np +import cv2 +import onnx +import onnxruntime +from ..utils import face_align + +__all__ = [ + 'Attribute', +] + + +class Attribute: + def __init__(self, model_file=None, session=None): + assert model_file is not None + self.model_file = model_file + self.session = session + find_sub = False + find_mul = False + model = onnx.load(self.model_file) + graph = model.graph + for nid, node in enumerate(graph.node[:8]): + #print(nid, node.name) + if node.name.startswith('Sub') or node.name.startswith('_minus'): + find_sub = True + if node.name.startswith('Mul') or node.name.startswith('_mul'): + find_mul = True + if nid<3 and node.name=='bn_data': + find_sub = True + find_mul = True + if find_sub and find_mul: + #mxnet arcface model + input_mean = 0.0 + input_std = 1.0 + else: + input_mean = 127.5 + input_std = 128.0 + self.input_mean = input_mean + self.input_std = input_std + #print('input mean and std:', model_file, self.input_mean, self.input_std) + if self.session is None: + self.session = onnxruntime.InferenceSession(self.model_file, None) + input_cfg = self.session.get_inputs()[0] + input_shape = input_cfg.shape + input_name = input_cfg.name + self.input_size = tuple(input_shape[2:4][::-1]) + self.input_shape = input_shape + outputs = self.session.get_outputs() + output_names = [] + for out in outputs: + output_names.append(out.name) + self.input_name = input_name + self.output_names = output_names + assert len(self.output_names)==1 + output_shape = outputs[0].shape + #print('init output_shape:', output_shape) + if output_shape[1]==3: + self.taskname = 'genderage' + else: + self.taskname = 'attribute_%d'%output_shape[1] + + def prepare(self, ctx_id, **kwargs): + if ctx_id<0: + self.session.set_providers(['CPUExecutionProvider']) + + def get(self, img, face): + bbox = face.bbox + w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1]) + center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2 + rotate = 0 + _scale = self.input_size[0] / (max(w, h)*1.5) + #print('param:', img.shape, bbox, center, self.input_size, _scale, 
rotate) + aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate) + input_size = tuple(aimg.shape[0:2][::-1]) + #assert input_size==self.input_size + blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + pred = self.session.run(self.output_names, {self.input_name : blob})[0][0] + if self.taskname=='genderage': + assert len(pred)==3 + gender = np.argmax(pred[:2]) + age = int(np.round(pred[2]*100)) + face['gender'] = gender + face['age'] = age + return gender, age + else: + return pred + + diff --git a/liveportrait/utils/dependencies/insightface/model_zoo/inswapper.py b/liveportrait/utils/dependencies/insightface/model_zoo/inswapper.py new file mode 100644 index 0000000000000000000000000000000000000000..f321c627ee66cceddcab98b561b997441dd4f768 --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/model_zoo/inswapper.py @@ -0,0 +1,114 @@ +import time +import numpy as np +import onnxruntime +import cv2 +import onnx +from onnx import numpy_helper +from ..utils import face_align + + + + +class INSwapper(): + def __init__(self, model_file=None, session=None): + self.model_file = model_file + self.session = session + model = onnx.load(self.model_file) + graph = model.graph + self.emap = numpy_helper.to_array(graph.initializer[-1]) + self.input_mean = 0.0 + self.input_std = 255.0 + #print('input mean and std:', model_file, self.input_mean, self.input_std) + if self.session is None: + self.session = onnxruntime.InferenceSession(self.model_file, None) + inputs = self.session.get_inputs() + self.input_names = [] + for inp in inputs: + self.input_names.append(inp.name) + outputs = self.session.get_outputs() + output_names = [] + for out in outputs: + output_names.append(out.name) + self.output_names = output_names + assert len(self.output_names)==1 + output_shape = outputs[0].shape + input_cfg = inputs[0] + input_shape = input_cfg.shape + self.input_shape = input_shape + # print('inswapper-shape:', self.input_shape) + self.input_size = tuple(input_shape[2:4][::-1]) + + def forward(self, img, latent): + img = (img - self.input_mean) / self.input_std + pred = self.session.run(self.output_names, {self.input_names[0]: img, self.input_names[1]: latent})[0] + return pred + + def get(self, img, target_face, source_face, paste_back=True): + face_mask = np.zeros((img.shape[0], img.shape[1]), np.uint8) + cv2.fillPoly(face_mask, np.array([target_face.landmark_2d_106[[1,9,10,11,12,13,14,15,16,2,3,4,5,6,7,8,0,24,23,22,21,20,19,18,32,31,30,29,28,27,26,25,17,101,105,104,103,51,49,48,43]].astype('int64')]), 1) + aimg, M = face_align.norm_crop2(img, target_face.kps, self.input_size[0]) + blob = cv2.dnn.blobFromImage(aimg, 1.0 / self.input_std, self.input_size, + (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + latent = source_face.normed_embedding.reshape((1,-1)) + latent = np.dot(latent, self.emap) + latent /= np.linalg.norm(latent) + pred = self.session.run(self.output_names, {self.input_names[0]: blob, self.input_names[1]: latent})[0] + #print(latent.shape, latent.dtype, pred.shape) + img_fake = pred.transpose((0,2,3,1))[0] + bgr_fake = np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:,:,::-1] + if not paste_back: + return bgr_fake, M + else: + target_img = img + fake_diff = bgr_fake.astype(np.float32) - aimg.astype(np.float32) + fake_diff = np.abs(fake_diff).mean(axis=2) + fake_diff[:2,:] = 0 + fake_diff[-2:,:] = 0 + fake_diff[:,:2] = 0 + fake_diff[:,-2:] = 0 + IM = 
cv2.invertAffineTransform(M)
+            img_white = np.full((aimg.shape[0],aimg.shape[1]), 255, dtype=np.float32)
+            bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
+            img_white = cv2.warpAffine(img_white, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
+            fake_diff = cv2.warpAffine(fake_diff, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
+            img_white[img_white>20] = 255
+            fthresh = 10
+            fake_diff[fake_diff<fthresh] = 0
+            fake_diff[fake_diff>=fthresh] = 255
+            img_mask = img_white
+            mask_h_inds, mask_w_inds = np.where(img_mask==255)
+            mask_h = np.max(mask_h_inds) - np.min(mask_h_inds)
+            mask_w = np.max(mask_w_inds) - np.min(mask_w_inds)
+            mask_size = int(np.sqrt(mask_h*mask_w))
+            k = max(mask_size//10, 10)
+            #k = max(mask_size//20, 6)
+            #k = 6
+            kernel = np.ones((k,k),np.uint8)
+            img_mask = cv2.erode(img_mask,kernel,iterations = 1)
+            kernel = np.ones((2,2),np.uint8)
+            fake_diff = cv2.dilate(fake_diff,kernel,iterations = 1)
+
+            face_mask = cv2.erode(face_mask,np.ones((11,11),np.uint8),iterations = 1)
+            fake_diff[face_mask==1] = 255
+
+            k = max(mask_size//20, 5)
+            #k = 3
+            #k = 3
+            kernel_size = (k, k)
+            blur_size = tuple(2*i+1 for i in kernel_size)
+            img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
+            k = 5
+            kernel_size = (k, k)
+            blur_size = tuple(2*i+1 for i in kernel_size)
+            fake_diff = cv2.blur(fake_diff, (11,11), 0)
+            ##fake_diff = cv2.GaussianBlur(fake_diff, blur_size, 0)
+            # print('blur_size: ', blur_size)
+            # fake_diff = cv2.blur(fake_diff, (21, 21), 0) # blur_size
+            img_mask /= 255
+            fake_diff /= 255
+            # img_mask = fake_diff
+            img_mask = img_mask*fake_diff
+            img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])
+            fake_merged = img_mask * bgr_fake + (1-img_mask) * target_img.astype(np.float32)
+            fake_merged = fake_merged.astype(np.uint8)
+            return fake_merged
diff --git a/liveportrait/utils/dependencies/insightface/model_zoo/landmark.py b/liveportrait/utils/dependencies/insightface/model_zoo/landmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..598b4b29a2d0674d8bb25b681f921c61460d101c
--- /dev/null
+++ b/liveportrait/utils/dependencies/insightface/model_zoo/landmark.py
@@ -0,0 +1,114 @@
+# -*- coding: utf-8 -*-
+# @Organization : insightface.ai
+# @Author : Jia Guo
+# @Time : 2021-05-04
+# @Function :
+
+from __future__ import division
+import numpy as np
+import cv2
+import onnx
+import onnxruntime
+from ..utils import face_align
+from ..utils import transform
+from ..data import get_object
+
+__all__ = [
+    'Landmark',
+]
+
+
+class Landmark:
+    def __init__(self, model_file=None, session=None):
+        assert model_file is not None
+        self.model_file = model_file
+        self.session = session
+        find_sub = False
+        find_mul = False
+        model = onnx.load(self.model_file)
+        graph = model.graph
+        for nid, node in enumerate(graph.node[:8]):
+            #print(nid, node.name)
+            if node.name.startswith('Sub') or node.name.startswith('_minus'):
+                find_sub = True
+            if node.name.startswith('Mul') or node.name.startswith('_mul'):
+                find_mul = True
+            if nid<3 and node.name=='bn_data':
+                find_sub = True
+                find_mul = True
+        if find_sub and find_mul:
+            #mxnet arcface model
+            input_mean = 0.0
+            input_std = 1.0
+        else:
+            input_mean = 127.5
+            input_std = 128.0
+        self.input_mean = input_mean
+        self.input_std = input_std
+        #print('input mean and std:', model_file, self.input_mean, self.input_std)
+        if self.session is None:
+            self.session = onnxruntime.InferenceSession(self.model_file, None)
+        input_cfg = self.session.get_inputs()[0]
+
input_shape = input_cfg.shape + input_name = input_cfg.name + self.input_size = tuple(input_shape[2:4][::-1]) + self.input_shape = input_shape + outputs = self.session.get_outputs() + output_names = [] + for out in outputs: + output_names.append(out.name) + self.input_name = input_name + self.output_names = output_names + assert len(self.output_names)==1 + output_shape = outputs[0].shape + self.require_pose = False + #print('init output_shape:', output_shape) + if output_shape[1]==3309: + self.lmk_dim = 3 + self.lmk_num = 68 + self.mean_lmk = get_object('meanshape_68.pkl') + self.require_pose = True + else: + self.lmk_dim = 2 + self.lmk_num = output_shape[1]//self.lmk_dim + self.taskname = 'landmark_%dd_%d'%(self.lmk_dim, self.lmk_num) + + def prepare(self, ctx_id, **kwargs): + if ctx_id<0: + self.session.set_providers(['CPUExecutionProvider']) + + def get(self, img, face): + bbox = face.bbox + w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1]) + center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2 + rotate = 0 + _scale = self.input_size[0] / (max(w, h)*1.5) + #print('param:', img.shape, bbox, center, self.input_size, _scale, rotate) + aimg, M = face_align.transform(img, center, self.input_size[0], _scale, rotate) + input_size = tuple(aimg.shape[0:2][::-1]) + #assert input_size==self.input_size + blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + pred = self.session.run(self.output_names, {self.input_name : blob})[0][0] + if pred.shape[0] >= 3000: + pred = pred.reshape((-1, 3)) + else: + pred = pred.reshape((-1, 2)) + if self.lmk_num < pred.shape[0]: + pred = pred[self.lmk_num*-1:,:] + pred[:, 0:2] += 1 + pred[:, 0:2] *= (self.input_size[0] // 2) + if pred.shape[1] == 3: + pred[:, 2] *= (self.input_size[0] // 2) + + IM = cv2.invertAffineTransform(M) + pred = face_align.trans_points(pred, IM) + face[self.taskname] = pred + if self.require_pose: + P = transform.estimate_affine_matrix_3d23d(self.mean_lmk, pred) + s, R, t = transform.P2sRt(P) + rx, ry, rz = transform.matrix2angle(R) + pose = np.array( [rx, ry, rz], dtype=np.float32 ) + face['pose'] = pose #pitch, yaw, roll + return pred + + diff --git a/liveportrait/utils/dependencies/insightface/model_zoo/model_store.py b/liveportrait/utils/dependencies/insightface/model_zoo/model_store.py new file mode 100644 index 0000000000000000000000000000000000000000..50bb85d314f5b7a0ea8211d2cd21186e32791592 --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/model_zoo/model_store.py @@ -0,0 +1,103 @@ +""" +This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_store.py +""" +from __future__ import print_function + +__all__ = ['get_model_file'] +import os +import zipfile +import glob + +from ..utils import download, check_sha1 + +_model_sha1 = { + name: checksum + for checksum, name in [ + ('95be21b58e29e9c1237f229dae534bd854009ce0', 'arcface_r100_v1'), + ('', 'arcface_mfn_v1'), + ('39fd1e087a2a2ed70a154ac01fecaa86c315d01b', 'retinaface_r50_v1'), + ('2c9de8116d1f448fd1d4661f90308faae34c990a', 'retinaface_mnet025_v1'), + ('0db1d07921d005e6c9a5b38e059452fc5645e5a4', 'retinaface_mnet025_v2'), + ('7dd8111652b7aac2490c5dcddeb268e53ac643e6', 'genderage_v1'), + ] +} + +base_repo_url = 'https://insightface.ai/files/' +_url_format = '{repo_url}models/{file_name}.zip' + + +def short_hash(name): + if name not in _model_sha1: + raise ValueError( + 'Pretrained model for {name} is not available.'.format(name=name)) + return 
_model_sha1[name][:8] + + +def find_params_file(dir_path): + if not os.path.exists(dir_path): + return None + paths = glob.glob("%s/*.params" % dir_path) + if len(paths) == 0: + return None + paths = sorted(paths) + return paths[-1] + + +def get_model_file(name, root=os.path.join('~', '.insightface', 'models')): + r"""Return location for the pretrained on local file system. + + This function will download from online model zoo when model cannot be found or has mismatch. + The root directory will be created if it doesn't exist. + + Parameters + ---------- + name : str + Name of the model. + root : str, default '~/.mxnet/models' + Location for keeping the model parameters. + + Returns + ------- + file_path + Path to the requested pretrained model file. + """ + + file_name = name + root = os.path.expanduser(root) + dir_path = os.path.join(root, name) + file_path = find_params_file(dir_path) + #file_path = os.path.join(root, file_name + '.params') + sha1_hash = _model_sha1[name] + if file_path is not None: + if check_sha1(file_path, sha1_hash): + return file_path + else: + print( + 'Mismatch in the content of model file detected. Downloading again.' + ) + else: + print('Model file is not found. Downloading.') + + if not os.path.exists(root): + os.makedirs(root) + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + zip_file_path = os.path.join(root, file_name + '.zip') + repo_url = base_repo_url + if repo_url[-1] != '/': + repo_url = repo_url + '/' + download(_url_format.format(repo_url=repo_url, file_name=file_name), + path=zip_file_path, + overwrite=True) + with zipfile.ZipFile(zip_file_path) as zf: + zf.extractall(dir_path) + os.remove(zip_file_path) + file_path = find_params_file(dir_path) + + if check_sha1(file_path, sha1_hash): + return file_path + else: + raise ValueError( + 'Downloaded file has different hash. Please try again.') + diff --git a/liveportrait/utils/dependencies/insightface/model_zoo/model_zoo.py b/liveportrait/utils/dependencies/insightface/model_zoo/model_zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..d8366e2a5461d5d6688f23e102a40944330084a4 --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/model_zoo/model_zoo.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# @Organization : insightface.ai +# @Author : Jia Guo +# @Time : 2021-05-04 +# @Function : + +import os +import os.path as osp +import glob +import onnxruntime +from .arcface_onnx import * +from .retinaface import * +#from .scrfd import * +from .landmark import * +from .attribute import Attribute +from .inswapper import INSwapper +from ..utils import download_onnx + +__all__ = ['get_model'] + + +class PickableInferenceSession(onnxruntime.InferenceSession): + # This is a wrapper to make the current InferenceSession class pickable. 
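+ # Keeping only model_path in __getstate__ is what lets a session survive pickling (e.g. when a FaceAnalysis instance is handed to multiprocessing workers); __setstate__ simply re-opens the ONNX file, so providers and session options chosen at construction time are not restored. + # Minimal, hypothetical usage sketch (assumes a local det.onnx; illustrative only, not part of the vendored file): + # import pickle + # sess = PickableInferenceSession('det.onnx') + # restored = pickle.loads(pickle.dumps(sess)) # re-loads det.onnx with default providers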
+ def __init__(self, model_path, **kwargs): + super().__init__(model_path, **kwargs) + self.model_path = model_path + + def __getstate__(self): + return {'model_path': self.model_path} + + def __setstate__(self, values): + model_path = values['model_path'] + self.__init__(model_path) + +class ModelRouter: + def __init__(self, onnx_file): + self.onnx_file = onnx_file + + def get_model(self, **kwargs): + session = PickableInferenceSession(self.onnx_file, **kwargs) + # print(f'Applied providers: {session._providers}, with options: {session._provider_options}') + inputs = session.get_inputs() + input_cfg = inputs[0] + input_shape = input_cfg.shape + outputs = session.get_outputs() + + if len(outputs)>=5: + return RetinaFace(model_file=self.onnx_file, session=session) + elif input_shape[2]==192 and input_shape[3]==192: + return Landmark(model_file=self.onnx_file, session=session) + elif input_shape[2]==96 and input_shape[3]==96: + return Attribute(model_file=self.onnx_file, session=session) + elif len(inputs)==2 and input_shape[2]==128 and input_shape[3]==128: + return INSwapper(model_file=self.onnx_file, session=session) + elif input_shape[2]==input_shape[3] and input_shape[2]>=112 and input_shape[2]%16==0: + return ArcFaceONNX(model_file=self.onnx_file, session=session) + else: + #raise RuntimeError('error on model routing') + return None + +def find_onnx_file(dir_path): + if not os.path.exists(dir_path): + return None + paths = glob.glob("%s/*.onnx" % dir_path) + if len(paths) == 0: + return None + paths = sorted(paths) + return paths[-1] + +def get_default_providers(): + return ['CUDAExecutionProvider', 'CPUExecutionProvider'] + +def get_default_provider_options(): + return None + +def get_model(name, **kwargs): + root = kwargs.get('root', '~/.insightface') + root = os.path.expanduser(root) + model_root = osp.join(root, 'models') + allow_download = kwargs.get('download', False) + download_zip = kwargs.get('download_zip', False) + if not name.endswith('.onnx'): + model_dir = os.path.join(model_root, name) + model_file = find_onnx_file(model_dir) + if model_file is None: + return None + else: + model_file = name + if not osp.exists(model_file) and allow_download: + model_file = download_onnx('models', model_file, root=root, download_zip=download_zip) + assert osp.exists(model_file), 'model_file %s should exist'%model_file + assert osp.isfile(model_file), 'model_file %s should be a file'%model_file + router = ModelRouter(model_file) + providers = kwargs.get('providers', get_default_providers()) + provider_options = kwargs.get('provider_options', get_default_provider_options()) + model = router.get_model(providers=providers, provider_options=provider_options) + return model diff --git a/liveportrait/utils/dependencies/insightface/model_zoo/retinaface.py b/liveportrait/utils/dependencies/insightface/model_zoo/retinaface.py new file mode 100644 index 0000000000000000000000000000000000000000..fc4ad91ed70688b38503127137e928dc7e5433e1 --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/model_zoo/retinaface.py @@ -0,0 +1,301 @@ +# -*- coding: utf-8 -*- +# @Organization : insightface.ai +# @Author : Jia Guo +# @Time : 2021-09-18 +# @Function : + +from __future__ import division +import datetime +import numpy as np +import onnx +import onnxruntime +import os +import os.path as osp +import cv2 +import sys + +def softmax(z): + assert len(z.shape) == 2 + s = np.max(z, axis=1) + s = s[:, np.newaxis] # necessary step to do broadcasting + e_x = np.exp(z - s) + div = np.sum(e_x, axis=1) + div = 
div[:, np.newaxis] # dito + return e_x / div + +def distance2bbox(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded bboxes. + """ + x1 = points[:, 0] - distance[:, 0] + y1 = points[:, 1] - distance[:, 1] + x2 = points[:, 0] + distance[:, 2] + y2 = points[:, 1] + distance[:, 3] + if max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) + return np.stack([x1, y1, x2, y2], axis=-1) + +def distance2kps(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded bboxes. + """ + preds = [] + for i in range(0, distance.shape[1], 2): + px = points[:, i%2] + distance[:, i] + py = points[:, i%2+1] + distance[:, i+1] + if max_shape is not None: + px = px.clamp(min=0, max=max_shape[1]) + py = py.clamp(min=0, max=max_shape[0]) + preds.append(px) + preds.append(py) + return np.stack(preds, axis=-1) + +class RetinaFace: + def __init__(self, model_file=None, session=None): + import onnxruntime + self.model_file = model_file + self.session = session + self.taskname = 'detection' + if self.session is None: + assert self.model_file is not None + assert osp.exists(self.model_file) + self.session = onnxruntime.InferenceSession(self.model_file, None) + self.center_cache = {} + self.nms_thresh = 0.4 + self.det_thresh = 0.5 + self._init_vars() + + def _init_vars(self): + input_cfg = self.session.get_inputs()[0] + input_shape = input_cfg.shape + #print(input_shape) + if isinstance(input_shape[2], str): + self.input_size = None + else: + self.input_size = tuple(input_shape[2:4][::-1]) + #print('image_size:', self.image_size) + input_name = input_cfg.name + self.input_shape = input_shape + outputs = self.session.get_outputs() + output_names = [] + for o in outputs: + output_names.append(o.name) + self.input_name = input_name + self.output_names = output_names + self.input_mean = 127.5 + self.input_std = 128.0 + #print(self.output_names) + #assert len(outputs)==10 or len(outputs)==15 + self.use_kps = False + self._anchor_ratio = 1.0 + self._num_anchors = 1 + if len(outputs)==6: + self.fmc = 3 + self._feat_stride_fpn = [8, 16, 32] + self._num_anchors = 2 + elif len(outputs)==9: + self.fmc = 3 + self._feat_stride_fpn = [8, 16, 32] + self._num_anchors = 2 + self.use_kps = True + elif len(outputs)==10: + self.fmc = 5 + self._feat_stride_fpn = [8, 16, 32, 64, 128] + self._num_anchors = 1 + elif len(outputs)==15: + self.fmc = 5 + self._feat_stride_fpn = [8, 16, 32, 64, 128] + self._num_anchors = 1 + self.use_kps = True + + def prepare(self, ctx_id, **kwargs): + if ctx_id<0: + self.session.set_providers(['CPUExecutionProvider']) + nms_thresh = kwargs.get('nms_thresh', None) + if nms_thresh is not None: + self.nms_thresh = nms_thresh + det_thresh = kwargs.get('det_thresh', None) + if det_thresh is not None: + self.det_thresh = det_thresh + input_size = kwargs.get('input_size', None) + if input_size is not None: + if self.input_size is not None: + print('warning: det_size is already set in detection model, 
ignore') + else: + self.input_size = input_size + + def forward(self, img, threshold): + scores_list = [] + bboxes_list = [] + kpss_list = [] + input_size = tuple(img.shape[0:2][::-1]) + blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + net_outs = self.session.run(self.output_names, {self.input_name : blob}) + + input_height = blob.shape[2] + input_width = blob.shape[3] + fmc = self.fmc + for idx, stride in enumerate(self._feat_stride_fpn): + scores = net_outs[idx] + bbox_preds = net_outs[idx+fmc] + bbox_preds = bbox_preds * stride + if self.use_kps: + kps_preds = net_outs[idx+fmc*2] * stride + height = input_height // stride + width = input_width // stride + K = height * width + key = (height, width, stride) + if key in self.center_cache: + anchor_centers = self.center_cache[key] + else: + #solution-1, c style: + #anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 ) + #for i in range(height): + # anchor_centers[i, :, 1] = i + #for i in range(width): + # anchor_centers[:, i, 0] = i + + #solution-2: + #ax = np.arange(width, dtype=np.float32) + #ay = np.arange(height, dtype=np.float32) + #xv, yv = np.meshgrid(np.arange(width), np.arange(height)) + #anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32) + + #solution-3: + anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32) + #print(anchor_centers.shape) + + anchor_centers = (anchor_centers * stride).reshape( (-1, 2) ) + if self._num_anchors>1: + anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) ) + if len(self.center_cache)<100: + self.center_cache[key] = anchor_centers + + pos_inds = np.where(scores>=threshold)[0] + bboxes = distance2bbox(anchor_centers, bbox_preds) + pos_scores = scores[pos_inds] + pos_bboxes = bboxes[pos_inds] + scores_list.append(pos_scores) + bboxes_list.append(pos_bboxes) + if self.use_kps: + kpss = distance2kps(anchor_centers, kps_preds) + #kpss = kps_preds + kpss = kpss.reshape( (kpss.shape[0], -1, 2) ) + pos_kpss = kpss[pos_inds] + kpss_list.append(pos_kpss) + return scores_list, bboxes_list, kpss_list + + def detect(self, img, input_size = None, max_num=0, metric='default'): + assert input_size is not None or self.input_size is not None + input_size = self.input_size if input_size is None else input_size + + im_ratio = float(img.shape[0]) / img.shape[1] + model_ratio = float(input_size[1]) / input_size[0] + if im_ratio>model_ratio: + new_height = input_size[1] + new_width = int(new_height / im_ratio) + else: + new_width = input_size[0] + new_height = int(new_width * im_ratio) + det_scale = float(new_height) / img.shape[0] + resized_img = cv2.resize(img, (new_width, new_height)) + det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 ) + det_img[:new_height, :new_width, :] = resized_img + + scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh) + + scores = np.vstack(scores_list) + scores_ravel = scores.ravel() + order = scores_ravel.argsort()[::-1] + bboxes = np.vstack(bboxes_list) / det_scale + if self.use_kps: + kpss = np.vstack(kpss_list) / det_scale + pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False) + pre_det = pre_det[order, :] + keep = self.nms(pre_det) + det = pre_det[keep, :] + if self.use_kps: + kpss = kpss[order,:,:] + kpss = kpss[keep,:,:] + else: + kpss = None + if max_num > 0 and det.shape[0] > max_num: + area = (det[:, 2] - det[:, 0]) * (det[:, 3] - + det[:, 1]) + 
img_center = img.shape[0] // 2, img.shape[1] // 2 + offsets = np.vstack([ + (det[:, 0] + det[:, 2]) / 2 - img_center[1], + (det[:, 1] + det[:, 3]) / 2 - img_center[0] + ]) + offset_dist_squared = np.sum(np.power(offsets, 2.0), 0) + if metric=='max': + values = area + else: + values = area - offset_dist_squared * 2.0 # some extra weight on the centering + bindex = np.argsort( + values)[::-1] # some extra weight on the centering + bindex = bindex[0:max_num] + det = det[bindex, :] + if kpss is not None: + kpss = kpss[bindex, :] + return det, kpss + + def nms(self, dets): + thresh = self.nms_thresh + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep + +def get_retinaface(name, download=False, root='~/.insightface/models', **kwargs): + if not download: + assert os.path.exists(name) + return RetinaFace(name) + else: + from .model_store import get_model_file + _file = get_model_file("retinaface_%s" % name, root=root) + return retinaface(_file) + + diff --git a/liveportrait/utils/dependencies/insightface/model_zoo/scrfd.py b/liveportrait/utils/dependencies/insightface/model_zoo/scrfd.py new file mode 100644 index 0000000000000000000000000000000000000000..674db4bba761157592dfb95c5d1638da1099f89c --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/model_zoo/scrfd.py @@ -0,0 +1,348 @@ +# -*- coding: utf-8 -*- +# @Organization : insightface.ai +# @Author : Jia Guo +# @Time : 2021-05-04 +# @Function : + +from __future__ import division +import datetime +import numpy as np +import onnx +import onnxruntime +import os +import os.path as osp +import cv2 +import sys + +def softmax(z): + assert len(z.shape) == 2 + s = np.max(z, axis=1) + s = s[:, np.newaxis] # necessary step to do broadcasting + e_x = np.exp(z - s) + div = np.sum(e_x, axis=1) + div = div[:, np.newaxis] # dito + return e_x / div + +def distance2bbox(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded bboxes. + """ + x1 = points[:, 0] - distance[:, 0] + y1 = points[:, 1] - distance[:, 1] + x2 = points[:, 0] + distance[:, 2] + y2 = points[:, 1] + distance[:, 3] + if max_shape is not None: + x1 = x1.clamp(min=0, max=max_shape[1]) + y1 = y1.clamp(min=0, max=max_shape[0]) + x2 = x2.clamp(min=0, max=max_shape[1]) + y2 = y2.clamp(min=0, max=max_shape[0]) + return np.stack([x1, y1, x2, y2], axis=-1) + +def distance2kps(points, distance, max_shape=None): + """Decode distance prediction to bounding box. + + Args: + points (Tensor): Shape (n, 2), [x, y]. + distance (Tensor): Distance from the given point to 4 + boundaries (left, top, right, bottom). + max_shape (tuple): Shape of the image. + + Returns: + Tensor: Decoded bboxes. 
+ """ + preds = [] + for i in range(0, distance.shape[1], 2): + px = points[:, i%2] + distance[:, i] + py = points[:, i%2+1] + distance[:, i+1] + if max_shape is not None: + px = px.clamp(min=0, max=max_shape[1]) + py = py.clamp(min=0, max=max_shape[0]) + preds.append(px) + preds.append(py) + return np.stack(preds, axis=-1) + +class SCRFD: + def __init__(self, model_file=None, session=None): + import onnxruntime + self.model_file = model_file + self.session = session + self.taskname = 'detection' + self.batched = False + if self.session is None: + assert self.model_file is not None + assert osp.exists(self.model_file) + self.session = onnxruntime.InferenceSession(self.model_file, None) + self.center_cache = {} + self.nms_thresh = 0.4 + self.det_thresh = 0.5 + self._init_vars() + + def _init_vars(self): + input_cfg = self.session.get_inputs()[0] + input_shape = input_cfg.shape + #print(input_shape) + if isinstance(input_shape[2], str): + self.input_size = None + else: + self.input_size = tuple(input_shape[2:4][::-1]) + #print('image_size:', self.image_size) + input_name = input_cfg.name + self.input_shape = input_shape + outputs = self.session.get_outputs() + if len(outputs[0].shape) == 3: + self.batched = True + output_names = [] + for o in outputs: + output_names.append(o.name) + self.input_name = input_name + self.output_names = output_names + self.input_mean = 127.5 + self.input_std = 128.0 + #print(self.output_names) + #assert len(outputs)==10 or len(outputs)==15 + self.use_kps = False + self._anchor_ratio = 1.0 + self._num_anchors = 1 + if len(outputs)==6: + self.fmc = 3 + self._feat_stride_fpn = [8, 16, 32] + self._num_anchors = 2 + elif len(outputs)==9: + self.fmc = 3 + self._feat_stride_fpn = [8, 16, 32] + self._num_anchors = 2 + self.use_kps = True + elif len(outputs)==10: + self.fmc = 5 + self._feat_stride_fpn = [8, 16, 32, 64, 128] + self._num_anchors = 1 + elif len(outputs)==15: + self.fmc = 5 + self._feat_stride_fpn = [8, 16, 32, 64, 128] + self._num_anchors = 1 + self.use_kps = True + + def prepare(self, ctx_id, **kwargs): + if ctx_id<0: + self.session.set_providers(['CPUExecutionProvider']) + nms_thresh = kwargs.get('nms_thresh', None) + if nms_thresh is not None: + self.nms_thresh = nms_thresh + det_thresh = kwargs.get('det_thresh', None) + if det_thresh is not None: + self.det_thresh = det_thresh + input_size = kwargs.get('input_size', None) + if input_size is not None: + if self.input_size is not None: + print('warning: det_size is already set in scrfd model, ignore') + else: + self.input_size = input_size + + def forward(self, img, threshold): + scores_list = [] + bboxes_list = [] + kpss_list = [] + input_size = tuple(img.shape[0:2][::-1]) + blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + net_outs = self.session.run(self.output_names, {self.input_name : blob}) + + input_height = blob.shape[2] + input_width = blob.shape[3] + fmc = self.fmc + for idx, stride in enumerate(self._feat_stride_fpn): + # If model support batch dim, take first output + if self.batched: + scores = net_outs[idx][0] + bbox_preds = net_outs[idx + fmc][0] + bbox_preds = bbox_preds * stride + if self.use_kps: + kps_preds = net_outs[idx + fmc * 2][0] * stride + # If model doesn't support batching take output as is + else: + scores = net_outs[idx] + bbox_preds = net_outs[idx + fmc] + bbox_preds = bbox_preds * stride + if self.use_kps: + kps_preds = net_outs[idx + fmc * 2] * stride + + height = input_height // stride + 
width = input_width // stride + K = height * width + key = (height, width, stride) + if key in self.center_cache: + anchor_centers = self.center_cache[key] + else: + #solution-1, c style: + #anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 ) + #for i in range(height): + # anchor_centers[i, :, 1] = i + #for i in range(width): + # anchor_centers[:, i, 0] = i + + #solution-2: + #ax = np.arange(width, dtype=np.float32) + #ay = np.arange(height, dtype=np.float32) + #xv, yv = np.meshgrid(np.arange(width), np.arange(height)) + #anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32) + + #solution-3: + anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32) + #print(anchor_centers.shape) + + anchor_centers = (anchor_centers * stride).reshape( (-1, 2) ) + if self._num_anchors>1: + anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) ) + if len(self.center_cache)<100: + self.center_cache[key] = anchor_centers + + pos_inds = np.where(scores>=threshold)[0] + bboxes = distance2bbox(anchor_centers, bbox_preds) + pos_scores = scores[pos_inds] + pos_bboxes = bboxes[pos_inds] + scores_list.append(pos_scores) + bboxes_list.append(pos_bboxes) + if self.use_kps: + kpss = distance2kps(anchor_centers, kps_preds) + #kpss = kps_preds + kpss = kpss.reshape( (kpss.shape[0], -1, 2) ) + pos_kpss = kpss[pos_inds] + kpss_list.append(pos_kpss) + return scores_list, bboxes_list, kpss_list + + def detect(self, img, input_size = None, max_num=0, metric='default'): + assert input_size is not None or self.input_size is not None + input_size = self.input_size if input_size is None else input_size + + im_ratio = float(img.shape[0]) / img.shape[1] + model_ratio = float(input_size[1]) / input_size[0] + if im_ratio>model_ratio: + new_height = input_size[1] + new_width = int(new_height / im_ratio) + else: + new_width = input_size[0] + new_height = int(new_width * im_ratio) + det_scale = float(new_height) / img.shape[0] + resized_img = cv2.resize(img, (new_width, new_height)) + det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 ) + det_img[:new_height, :new_width, :] = resized_img + + scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh) + + scores = np.vstack(scores_list) + scores_ravel = scores.ravel() + order = scores_ravel.argsort()[::-1] + bboxes = np.vstack(bboxes_list) / det_scale + if self.use_kps: + kpss = np.vstack(kpss_list) / det_scale + pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False) + pre_det = pre_det[order, :] + keep = self.nms(pre_det) + det = pre_det[keep, :] + if self.use_kps: + kpss = kpss[order,:,:] + kpss = kpss[keep,:,:] + else: + kpss = None + if max_num > 0 and det.shape[0] > max_num: + area = (det[:, 2] - det[:, 0]) * (det[:, 3] - + det[:, 1]) + img_center = img.shape[0] // 2, img.shape[1] // 2 + offsets = np.vstack([ + (det[:, 0] + det[:, 2]) / 2 - img_center[1], + (det[:, 1] + det[:, 3]) / 2 - img_center[0] + ]) + offset_dist_squared = np.sum(np.power(offsets, 2.0), 0) + if metric=='max': + values = area + else: + values = area - offset_dist_squared * 2.0 # some extra weight on the centering + bindex = np.argsort( + values)[::-1] # some extra weight on the centering + bindex = bindex[0:max_num] + det = det[bindex, :] + if kpss is not None: + kpss = kpss[bindex, :] + return det, kpss + + def nms(self, dets): + thresh = self.nms_thresh + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 
- y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep + +def get_scrfd(name, download=False, root='~/.insightface/models', **kwargs): + if not download: + assert os.path.exists(name) + return SCRFD(name) + else: + from .model_store import get_model_file + _file = get_model_file("scrfd_%s" % name, root=root) + return SCRFD(_file) + + +def scrfd_2p5gkps(**kwargs): + return get_scrfd("2p5gkps", download=True, **kwargs) + + +if __name__ == '__main__': + import glob + detector = SCRFD(model_file='./det.onnx') + detector.prepare(-1) + img_paths = ['tests/data/t1.jpg'] + for img_path in img_paths: + img = cv2.imread(img_path) + + for _ in range(1): + ta = datetime.datetime.now() + #bboxes, kpss = detector.detect(img, 0.5, input_size = (640, 640)) + bboxes, kpss = detector.detect(img, 0.5) + tb = datetime.datetime.now() + print('all cost:', (tb-ta).total_seconds()*1000) + print(img_path, bboxes.shape) + if kpss is not None: + print(kpss.shape) + for i in range(bboxes.shape[0]): + bbox = bboxes[i] + x1,y1,x2,y2,score = bbox.astype(np.int) + cv2.rectangle(img, (x1,y1) , (x2,y2) , (255,0,0) , 2) + if kpss is not None: + kps = kpss[i] + for kp in kps: + kp = kp.astype(np.int) + cv2.circle(img, tuple(kp) , 1, (0,0,255) , 2) + filename = img_path.split('/')[-1] + print('output:', filename) + cv2.imwrite('./outputs/%s'%filename, img) + diff --git a/liveportrait/utils/dependencies/insightface/utils/__init__.py b/liveportrait/utils/dependencies/insightface/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6960431b1bd6db38890e391c4c94dd2182f2e1fd --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/utils/__init__.py @@ -0,0 +1,6 @@ +from __future__ import absolute_import + +from .storage import download, ensure_available, download_onnx +from .filesystem import get_model_dir +from .filesystem import makedirs, try_import_dali +from .constant import * diff --git a/liveportrait/utils/dependencies/insightface/utils/constant.py b/liveportrait/utils/dependencies/insightface/utils/constant.py new file mode 100644 index 0000000000000000000000000000000000000000..8860ff077ae7227235591edfc84c0cdc227a6432 --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/utils/constant.py @@ -0,0 +1,3 @@ + +DEFAULT_MP_NAME = 'buffalo_l' + diff --git a/liveportrait/utils/dependencies/insightface/utils/download.py b/liveportrait/utils/dependencies/insightface/utils/download.py new file mode 100644 index 0000000000000000000000000000000000000000..5cda84dede45b81dcd99161d87792b6c409fa279 --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/utils/download.py @@ -0,0 +1,95 @@ +""" +This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/utils/download.py +""" +import os +import hashlib +import requests +from tqdm import tqdm + + +def check_sha1(filename, sha1_hash): + """Check whether the sha1 hash of the file content matches the expected hash. + Parameters + ---------- + filename : str + Path to the file. + sha1_hash : str + Expected sha1 hash in hexadecimal digits. 
+ Returns + ------- + bool + Whether the file content matches the expected hash. + """ + sha1 = hashlib.sha1() + with open(filename, 'rb') as f: + while True: + data = f.read(1048576) + if not data: + break + sha1.update(data) + + sha1_file = sha1.hexdigest() + l = min(len(sha1_file), len(sha1_hash)) + return sha1.hexdigest()[0:l] == sha1_hash[0:l] + + +def download_file(url, path=None, overwrite=False, sha1_hash=None): + """Download an given URL + Parameters + ---------- + url : str + URL to download + path : str, optional + Destination path to store downloaded file. By default stores to the + current directory with same name as in url. + overwrite : bool, optional + Whether to overwrite destination file if already exists. + sha1_hash : str, optional + Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified + but doesn't match. + Returns + ------- + str + The file path of the downloaded file. + """ + if path is None: + fname = url.split('/')[-1] + else: + path = os.path.expanduser(path) + if os.path.isdir(path): + fname = os.path.join(path, url.split('/')[-1]) + else: + fname = path + + if overwrite or not os.path.exists(fname) or ( + sha1_hash and not check_sha1(fname, sha1_hash)): + dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) + if not os.path.exists(dirname): + os.makedirs(dirname) + + print('Downloading %s from %s...' % (fname, url)) + r = requests.get(url, stream=True) + if r.status_code != 200: + raise RuntimeError("Failed downloading url %s" % url) + total_length = r.headers.get('content-length') + with open(fname, 'wb') as f: + if total_length is None: # no content length header + for chunk in r.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + else: + total_length = int(total_length) + for chunk in tqdm(r.iter_content(chunk_size=1024), + total=int(total_length / 1024. + 0.5), + unit='KB', + unit_scale=False, + dynamic_ncols=True): + f.write(chunk) + + if sha1_hash and not check_sha1(fname, sha1_hash): + raise UserWarning('File {} is downloaded but the content hash does not match. ' \ + 'The repo may be outdated or download may be incomplete. 
' \ + 'If the "repo_url" is overridden, consider switching to ' \ + 'the default repo.'.format(fname)) + + return fname diff --git a/liveportrait/utils/dependencies/insightface/utils/face_align.py b/liveportrait/utils/dependencies/insightface/utils/face_align.py new file mode 100644 index 0000000000000000000000000000000000000000..226628b39cf743947df230feffbb97bf5c585e1d --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/utils/face_align.py @@ -0,0 +1,103 @@ +import cv2 +import numpy as np +from skimage import transform as trans + + +arcface_dst = np.array( + [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], + [41.5493, 92.3655], [70.7299, 92.2041]], + dtype=np.float32) + +def estimate_norm(lmk, image_size=112,mode='arcface'): + assert lmk.shape == (5, 2) + assert image_size%112==0 or image_size%128==0 + if image_size%112==0: + ratio = float(image_size)/112.0 + diff_x = 0 + else: + ratio = float(image_size)/128.0 + diff_x = 8.0*ratio + dst = arcface_dst * ratio + dst[:,0] += diff_x + tform = trans.SimilarityTransform() + tform.estimate(lmk, dst) + M = tform.params[0:2, :] + return M + +def norm_crop(img, landmark, image_size=112, mode='arcface'): + M = estimate_norm(landmark, image_size, mode) + warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) + return warped + +def norm_crop2(img, landmark, image_size=112, mode='arcface'): + M = estimate_norm(landmark, image_size, mode) + warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) + return warped, M + +def square_crop(im, S): + if im.shape[0] > im.shape[1]: + height = S + width = int(float(im.shape[1]) / im.shape[0] * S) + scale = float(S) / im.shape[0] + else: + width = S + height = int(float(im.shape[0]) / im.shape[1] * S) + scale = float(S) / im.shape[1] + resized_im = cv2.resize(im, (width, height)) + det_im = np.zeros((S, S, 3), dtype=np.uint8) + det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im + return det_im, scale + + +def transform(data, center, output_size, scale, rotation): + scale_ratio = scale + rot = float(rotation) * np.pi / 180.0 + #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio) + t1 = trans.SimilarityTransform(scale=scale_ratio) + cx = center[0] * scale_ratio + cy = center[1] * scale_ratio + t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) + t3 = trans.SimilarityTransform(rotation=rot) + t4 = trans.SimilarityTransform(translation=(output_size / 2, + output_size / 2)) + t = t1 + t2 + t3 + t4 + M = t.params[0:2] + cropped = cv2.warpAffine(data, + M, (output_size, output_size), + borderValue=0.0) + return cropped, M + + +def trans_points2d(pts, M): + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) + new_pt = np.dot(M, new_pt) + #print('new_pt', new_pt.shape, new_pt) + new_pts[i] = new_pt[0:2] + + return new_pts + + +def trans_points3d(pts, M): + scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) + #print(scale) + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) + new_pt = np.dot(M, new_pt) + #print('new_pt', new_pt.shape, new_pt) + new_pts[i][0:2] = new_pt[0:2] + new_pts[i][2] = pts[i][2] * scale + + return new_pts + + +def trans_points(pts, M): + if pts.shape[1] == 2: + return trans_points2d(pts, M) + else: + return trans_points3d(pts, M) + diff --git 
a/liveportrait/utils/dependencies/insightface/utils/filesystem.py b/liveportrait/utils/dependencies/insightface/utils/filesystem.py new file mode 100644 index 0000000000000000000000000000000000000000..01e3851975bdcbbf7f5eeb7e68e70a36dc040535 --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/utils/filesystem.py @@ -0,0 +1,157 @@ +""" +This code file mainly comes from https://github.com/dmlc/gluon-cv/blob/master/gluoncv/utils/filesystem.py +""" +import os +import os.path as osp +import errno + + +def get_model_dir(name, root='~/.insightface'): + root = os.path.expanduser(root) + model_dir = osp.join(root, 'models', name) + return model_dir + +def makedirs(path): + """Create directory recursively if not exists. + Similar to `makedir -p`, you can skip checking existence before this function. + + Parameters + ---------- + path : str + Path of the desired dir + """ + try: + os.makedirs(path) + except OSError as exc: + if exc.errno != errno.EEXIST: + raise + + +def try_import(package, message=None): + """Try import specified package, with custom message support. + + Parameters + ---------- + package : str + The name of the targeting package. + message : str, default is None + If not None, this function will raise customized error message when import error is found. + + + Returns + ------- + module if found, raise ImportError otherwise + + """ + try: + return __import__(package) + except ImportError as e: + if not message: + raise e + raise ImportError(message) + + +def try_import_cv2(): + """Try import cv2 at runtime. + + Returns + ------- + cv2 module if found. Raise ImportError otherwise + + """ + msg = "cv2 is required, you can install by package manager, e.g. 'apt-get', \ + or `pip install opencv-python --user` (note that this is unofficial PYPI package)." + + return try_import('cv2', msg) + + +def try_import_mmcv(): + """Try import mmcv at runtime. + + Returns + ------- + mmcv module if found. Raise ImportError otherwise + + """ + msg = "mmcv is required, you can install by first `pip install Cython --user` \ + and then `pip install mmcv --user` (note that this is unofficial PYPI package)." + + return try_import('mmcv', msg) + + +def try_import_rarfile(): + """Try import rarfile at runtime. + + Returns + ------- + rarfile module if found. Raise ImportError otherwise + + """ + msg = "rarfile is required, you can install by first `sudo apt-get install unrar` \ + and then `pip install rarfile --user` (note that this is unofficial PYPI package)." + + return try_import('rarfile', msg) + + +def import_try_install(package, extern_url=None): + """Try import the specified package. + If the package not installed, try use pip to install and import if success. + + Parameters + ---------- + package : str + The name of the package trying to import. + extern_url : str or None, optional + The external url if package is not hosted on PyPI. + For example, you can install a package using: + "pip install git+http://github.com/user/repo/tarball/master/egginfo=xxx". + In this case, you can pass the url to the extern_url. + + Returns + ------- + + The imported python module. 
+ + """ + try: + return __import__(package) + except ImportError: + try: + from pip import main as pipmain + except ImportError: + from pip._internal import main as pipmain + + # trying to install package + url = package if extern_url is None else extern_url + pipmain(['install', '--user', + url]) # will raise SystemExit Error if fails + + # trying to load again + try: + return __import__(package) + except ImportError: + import sys + import site + user_site = site.getusersitepackages() + if user_site not in sys.path: + sys.path.append(user_site) + return __import__(package) + return __import__(package) + + +def try_import_dali(): + """Try import NVIDIA DALI at runtime. + """ + try: + dali = __import__('nvidia.dali', fromlist=['pipeline', 'ops', 'types']) + dali.Pipeline = dali.pipeline.Pipeline + except ImportError: + + class dali: + class Pipeline: + def __init__(self): + raise NotImplementedError( + "DALI not found, please check if you installed it correctly." + ) + + return dali diff --git a/liveportrait/utils/dependencies/insightface/utils/storage.py b/liveportrait/utils/dependencies/insightface/utils/storage.py new file mode 100644 index 0000000000000000000000000000000000000000..5bf37e2d17b28dee2a8839484778815f87fc4a9c --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/utils/storage.py @@ -0,0 +1,52 @@ + +import os +import os.path as osp +import zipfile +from .download import download_file + +BASE_REPO_URL = 'https://github.com/deepinsight/insightface/releases/download/v0.7' + +def download(sub_dir, name, force=False, root='~/.insightface'): + _root = os.path.expanduser(root) + dir_path = os.path.join(_root, sub_dir, name) + if osp.exists(dir_path) and not force: + return dir_path + print('download_path:', dir_path) + zip_file_path = os.path.join(_root, sub_dir, name + '.zip') + model_url = "%s/%s.zip"%(BASE_REPO_URL, name) + download_file(model_url, + path=zip_file_path, + overwrite=True) + if not os.path.exists(dir_path): + os.makedirs(dir_path) + with zipfile.ZipFile(zip_file_path) as zf: + zf.extractall(dir_path) + #os.remove(zip_file_path) + return dir_path + +def ensure_available(sub_dir, name, root='~/.insightface'): + return download(sub_dir, name, force=False, root=root) + +def download_onnx(sub_dir, model_file, force=False, root='~/.insightface', download_zip=False): + _root = os.path.expanduser(root) + model_root = osp.join(_root, sub_dir) + new_model_file = osp.join(model_root, model_file) + if osp.exists(new_model_file) and not force: + return new_model_file + if not osp.exists(model_root): + os.makedirs(model_root) + print('download_path:', new_model_file) + if not download_zip: + model_url = "%s/%s"%(BASE_REPO_URL, model_file) + download_file(model_url, + path=new_model_file, + overwrite=True) + else: + model_url = "%s/%s.zip"%(BASE_REPO_URL, model_file) + zip_file_path = new_model_file+".zip" + download_file(model_url, + path=zip_file_path, + overwrite=True) + with zipfile.ZipFile(zip_file_path) as zf: + zf.extractall(model_root) + return new_model_file diff --git a/liveportrait/utils/dependencies/insightface/utils/transform.py b/liveportrait/utils/dependencies/insightface/utils/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..06531d257b694211a0b9a09c9d741b9b2ff53bfe --- /dev/null +++ b/liveportrait/utils/dependencies/insightface/utils/transform.py @@ -0,0 +1,116 @@ +import cv2 +import math +import numpy as np +from skimage import transform as trans + + +def transform(data, center, output_size, scale, rotation): + 
scale_ratio = scale + rot = float(rotation) * np.pi / 180.0 + #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio) + t1 = trans.SimilarityTransform(scale=scale_ratio) + cx = center[0] * scale_ratio + cy = center[1] * scale_ratio + t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) + t3 = trans.SimilarityTransform(rotation=rot) + t4 = trans.SimilarityTransform(translation=(output_size / 2, + output_size / 2)) + t = t1 + t2 + t3 + t4 + M = t.params[0:2] + cropped = cv2.warpAffine(data, + M, (output_size, output_size), + borderValue=0.0) + return cropped, M + + +def trans_points2d(pts, M): + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) + new_pt = np.dot(M, new_pt) + #print('new_pt', new_pt.shape, new_pt) + new_pts[i] = new_pt[0:2] + + return new_pts + + +def trans_points3d(pts, M): + scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) + #print(scale) + new_pts = np.zeros(shape=pts.shape, dtype=np.float32) + for i in range(pts.shape[0]): + pt = pts[i] + new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) + new_pt = np.dot(M, new_pt) + #print('new_pt', new_pt.shape, new_pt) + new_pts[i][0:2] = new_pt[0:2] + new_pts[i][2] = pts[i][2] * scale + + return new_pts + + +def trans_points(pts, M): + if pts.shape[1] == 2: + return trans_points2d(pts, M) + else: + return trans_points3d(pts, M) + +def estimate_affine_matrix_3d23d(X, Y): + ''' Using least-squares solution + Args: + X: [n, 3]. 3d points(fixed) + Y: [n, 3]. corresponding 3d points(moving). Y = PX + Returns: + P_Affine: (3, 4). Affine camera matrix (the third row is [0, 0, 0, 1]). + ''' + X_homo = np.hstack((X, np.ones([X.shape[0],1]))) #n x 4 + P = np.linalg.lstsq(X_homo, Y)[0].T # Affine matrix. 3 x 4 + return P + +def P2sRt(P): + ''' decompositing camera matrix P + Args: + P: (3, 4). Affine Camera Matrix. + Returns: + s: scale factor. + R: (3, 3). rotation matrix. + t: (3,). translation. + ''' + t = P[:, 3] + R1 = P[0:1, :3] + R2 = P[1:2, :3] + s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2.0 + r1 = R1/np.linalg.norm(R1) + r2 = R2/np.linalg.norm(R2) + r3 = np.cross(r1, r2) + + R = np.concatenate((r1, r2, r3), 0) + return s, R, t + +def matrix2angle(R): + ''' get three Euler angles from Rotation Matrix + Args: + R: (3,3). 
rotation matrix + Returns: + x: pitch + y: yaw + z: roll + ''' + sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0]) + + singular = sy < 1e-6 + + if not singular : + x = math.atan2(R[2,1] , R[2,2]) + y = math.atan2(-R[2,0], sy) + z = math.atan2(R[1,0], R[0,0]) + else : + x = math.atan2(-R[1,2], R[1,1]) + y = math.atan2(-R[2,0], sy) + z = 0 + + # rx, ry, rz = np.rad2deg(x), np.rad2deg(y), np.rad2deg(z) + rx, ry, rz = x*180/np.pi, y*180/np.pi, z*180/np.pi + return rx, ry, rz + diff --git a/liveportrait/utils/face_analysis_diy.py b/liveportrait/utils/face_analysis_diy.py new file mode 100644 index 0000000000000000000000000000000000000000..5fa2986df93db39e81aaf55765bc46bad9e05546 --- /dev/null +++ b/liveportrait/utils/face_analysis_diy.py @@ -0,0 +1,79 @@ +# coding: utf-8 + +""" +face detectoin and alignment using InsightFace +""" + +import numpy as np +from .rprint import rlog as log +from .dependencies.insightface.app import FaceAnalysis +from .dependencies.insightface.app.common import Face +from .timer import Timer + + +def sort_by_direction(faces, direction: str = 'large-small', face_center=None): + if len(faces) <= 0: + return faces + + if direction == 'left-right': + return sorted(faces, key=lambda face: face['bbox'][0]) + if direction == 'right-left': + return sorted(faces, key=lambda face: face['bbox'][0], reverse=True) + if direction == 'top-bottom': + return sorted(faces, key=lambda face: face['bbox'][1]) + if direction == 'bottom-top': + return sorted(faces, key=lambda face: face['bbox'][1], reverse=True) + if direction == 'small-large': + return sorted(faces, key=lambda face: (face['bbox'][2] - face['bbox'][0]) * (face['bbox'][3] - face['bbox'][1])) + if direction == 'large-small': + return sorted(faces, key=lambda face: (face['bbox'][2] - face['bbox'][0]) * (face['bbox'][3] - face['bbox'][1]), reverse=True) + if direction == 'distance-from-retarget-face': + return sorted(faces, key=lambda face: (((face['bbox'][2]+face['bbox'][0])/2-face_center[0])**2+((face['bbox'][3]+face['bbox'][1])/2-face_center[1])**2)**0.5) + return faces + + +class FaceAnalysisDIY(FaceAnalysis): + def __init__(self, name='buffalo_l', root='~/.insightface', allowed_modules=None, **kwargs): + super().__init__(name=name, root=root, allowed_modules=allowed_modules, **kwargs) + + self.timer = Timer() + + def get(self, img_bgr, **kwargs): + max_num = kwargs.get('max_num', 0) # the number of the detected faces, 0 means no limit + flag_do_landmark_2d_106 = kwargs.get('flag_do_landmark_2d_106', True) # whether to do 106-point detection + direction = kwargs.get('direction', 'large-small') # sorting direction + face_center = None + + bboxes, kpss = self.det_model.detect(img_bgr, max_num=max_num, metric='default') + if bboxes.shape[0] == 0: + return [] + ret = [] + for i in range(bboxes.shape[0]): + bbox = bboxes[i, 0:4] + det_score = bboxes[i, 4] + kps = None + if kpss is not None: + kps = kpss[i] + face = Face(bbox=bbox, kps=kps, det_score=det_score) + for taskname, model in self.models.items(): + if taskname == 'detection': + continue + + if (not flag_do_landmark_2d_106) and taskname == 'landmark_2d_106': + continue + + # print(f'taskname: {taskname}') + model.get(img_bgr, face) + ret.append(face) + + ret = sort_by_direction(ret, direction, face_center) + return ret + + def warmup(self): + self.timer.tic() + + img_bgr = np.zeros((512, 512, 3), dtype=np.uint8) + self.get(img_bgr) + + elapse = self.timer.toc() + #log(f'FaceAnalysisDIY warmup time: {elapse:.3f}s') diff --git a/liveportrait/utils/helper.py 
b/liveportrait/utils/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..9854563734a41bd217842311b9093b95e8f11e6d --- /dev/null +++ b/liveportrait/utils/helper.py @@ -0,0 +1,154 @@ +# coding: utf-8 + +""" +utility functions and classes to handle feature extraction and model loading +""" + +import os +import os.path as osp +import torch +from collections import OrderedDict + +from ..modules.spade_generator import SPADEDecoder +from ..modules.warping_network import WarpingNetwork +from ..modules.motion_extractor import MotionExtractor +from ..modules.appearance_feature_extractor import AppearanceFeatureExtractor +from ..modules.stitching_retargeting_network import StitchingRetargetingNetwork + + +def suffix(filename): + """a.jpg -> jpg""" + pos = filename.rfind(".") + if pos == -1: + return "" + return filename[pos + 1:] + + +def prefix(filename): + """a.jpg -> a""" + pos = filename.rfind(".") + if pos == -1: + return filename + return filename[:pos] + + +def basename(filename): + """a/b/c.jpg -> c""" + return prefix(osp.basename(filename)) + + +def is_video(file_path): + if file_path.lower().endswith((".mp4", ".mov", ".avi", ".webm")) or osp.isdir(file_path): + return True + return False + + +def is_template(file_path): + if file_path.endswith(".pkl"): + return True + return False + + +def mkdir(d, log=False): + # return self-assined `d`, for one line code + if not osp.exists(d): + os.makedirs(d, exist_ok=True) + if log: + print(f"Make dir: {d}") + return d + + +def squeeze_tensor_to_numpy(tensor): + out = tensor.data.squeeze(0).cpu().numpy() + return out + + +def dct2cuda(dct: dict, device_id: int): + for key in dct: + dct[key] = torch.tensor(dct[key]).cuda(device_id) + return dct + + +def concat_feat(kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor: + """ + kp_source: (bs, k, 3) + kp_driving: (bs, k, 3) + Return: (bs, 2k*3) + """ + bs_src = kp_source.shape[0] + bs_dri = kp_driving.shape[0] + assert bs_src == bs_dri, 'batch size must be equal' + + feat = torch.cat([kp_source.view(bs_src, -1), kp_driving.view(bs_dri, -1)], dim=1) + return feat + + +def remove_ddp_dumplicate_key(state_dict): + state_dict_new = OrderedDict() + for key in state_dict.keys(): + state_dict_new[key.replace('module.', '')] = state_dict[key] + return state_dict_new + + +def load_model(ckpt_path, model_config, device, model_type): + model_params = model_config['model_params'][f'{model_type}_params'] + + if model_type == 'appearance_feature_extractor': + model = AppearanceFeatureExtractor(**model_params).cuda(device) + elif model_type == 'motion_extractor': + model = MotionExtractor(**model_params).cuda(device) + elif model_type == 'warping_module': + model = WarpingNetwork(**model_params).cuda(device) + elif model_type == 'spade_generator': + model = SPADEDecoder(**model_params).cuda(device) + elif model_type == 'stitching_retargeting_module': + # Special handling for stitching and retargeting module + config = model_config['model_params']['stitching_retargeting_module_params'] + checkpoint = torch.load(ckpt_path, weights_only=True, map_location=lambda storage, loc: storage) + + stitcher = StitchingRetargetingNetwork(**config.get('stitching')) + stitcher.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_shoulder'])) + stitcher = stitcher.cuda(device) + stitcher.eval() + + retargetor_lip = StitchingRetargetingNetwork(**config.get('lip')) + retargetor_lip.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_mouth'])) + retargetor_lip = 
retargetor_lip.cuda(device) + retargetor_lip.eval() + + retargetor_eye = StitchingRetargetingNetwork(**config.get('eye')) + retargetor_eye.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_eye'])) + retargetor_eye = retargetor_eye.cuda(device) + retargetor_eye.eval() + + return { + 'stitching': stitcher, + 'lip': retargetor_lip, + 'eye': retargetor_eye + } + else: + raise ValueError(f"Unknown model type: {model_type}") + + model.load_state_dict(torch.load(ckpt_path, weights_only=True, map_location=lambda storage, loc: storage)) + model.eval() + return model + + +# get coefficients of Eqn. 7 +def calculate_transformation(config, s_kp_info, t_0_kp_info, t_i_kp_info, R_s, R_t_0, R_t_i): + if config.relative: + new_rotation = (R_t_i @ R_t_0.permute(0, 2, 1)) @ R_s + new_expression = s_kp_info['exp'] + (t_i_kp_info['exp'] - t_0_kp_info['exp']) + else: + new_rotation = R_t_i + new_expression = t_i_kp_info['exp'] + new_translation = s_kp_info['t'] + (t_i_kp_info['t'] - t_0_kp_info['t']) + new_translation[..., 2].fill_(0) # Keep the z-axis unchanged + new_scale = s_kp_info['scale'] * (t_i_kp_info['scale'] / t_0_kp_info['scale']) + return new_rotation, new_expression, new_translation, new_scale + + +def load_description(fp): + with open(fp, 'r', encoding='utf-8') as f: + content = f.read() + return content diff --git a/liveportrait/utils/io.py b/liveportrait/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..29a7e008759ec8428171a18ac65960fc14397fdc --- /dev/null +++ b/liveportrait/utils/io.py @@ -0,0 +1,97 @@ +# coding: utf-8 + +import os +from glob import glob +import os.path as osp +import imageio +import numpy as np +import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) + + +def load_image_rgb(image_path: str): + if not osp.exists(image_path): + raise FileNotFoundError(f"Image not found: {image_path}") + img = cv2.imread(image_path, cv2.IMREAD_COLOR) + return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + +def load_driving_info(driving_info): + driving_video_ori = [] + + def load_images_from_directory(directory): + image_paths = sorted(glob(osp.join(directory, '*.png')) + glob(osp.join(directory, '*.jpg'))) + return [load_image_rgb(im_path) for im_path in image_paths] + + def load_images_from_video(file_path): + reader = imageio.get_reader(file_path) + return [image for idx, image in enumerate(reader)] + + if osp.isdir(driving_info): + driving_video_ori = load_images_from_directory(driving_info) + elif osp.isfile(driving_info): + driving_video_ori = load_images_from_video(driving_info) + + return driving_video_ori + + +def contiguous(obj): + if not obj.flags.c_contiguous: + obj = obj.copy(order="C") + return obj + + +def resize_to_limit(img: np.ndarray, max_dim=1920, n=2): + """ + ajust the size of the image so that the maximum dimension does not exceed max_dim, and the width and the height of the image are multiples of n. + :param img: the image to be processed. + :param max_dim: the maximum dimension constraint. + :param n: the number that needs to be multiples of. + :return: the adjusted image. 
+ """ + h, w = img.shape[:2] + + # ajust the size of the image according to the maximum dimension + if max_dim > 0 and max(h, w) > max_dim: + if h > w: + new_h = max_dim + new_w = int(w * (max_dim / h)) + else: + new_w = max_dim + new_h = int(h * (max_dim / w)) + img = cv2.resize(img, (new_w, new_h)) + + # ensure that the image dimensions are multiples of n + n = max(n, 1) + new_h = img.shape[0] - (img.shape[0] % n) + new_w = img.shape[1] - (img.shape[1] % n) + + if new_h == 0 or new_w == 0: + # when the width or height is less than n, no need to process + return img + + if new_h != img.shape[0] or new_w != img.shape[1]: + img = img[:new_h, :new_w] + + return img + + +def load_img_online(obj, mode="bgr", **kwargs): + max_dim = kwargs.get("max_dim", 1920) + n = kwargs.get("n", 2) + if isinstance(obj, str): + if mode.lower() == "gray": + img = cv2.imread(obj, cv2.IMREAD_GRAYSCALE) + else: + img = cv2.imread(obj, cv2.IMREAD_COLOR) + else: + img = obj + + # Resize image to satisfy constraints + img = resize_to_limit(img, max_dim=max_dim, n=n) + + if mode.lower() == "bgr": + return contiguous(img) + elif mode.lower() == "rgb": + return contiguous(img[..., ::-1]) + else: + raise Exception(f"Unknown mode {mode}") diff --git a/liveportrait/utils/landmark_runner.py b/liveportrait/utils/landmark_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..75bf5c254f7eb32f3096bd667fcacfe190da6bc4 --- /dev/null +++ b/liveportrait/utils/landmark_runner.py @@ -0,0 +1,89 @@ +# coding: utf-8 + +import os.path as osp +import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) +import torch +import numpy as np +import onnxruntime +from .timer import Timer +from .rprint import rlog +from .crop import crop_image, _transform_pts + + +def make_abs_path(fn): + return osp.join(osp.dirname(osp.realpath(__file__)), fn) + + +def to_ndarray(obj): + if isinstance(obj, torch.Tensor): + return obj.cpu().numpy() + elif isinstance(obj, np.ndarray): + return obj + else: + return np.array(obj) + + +class LandmarkRunner(object): + """landmark runner""" + def __init__(self, **kwargs): + ckpt_path = kwargs.get('ckpt_path') + onnx_provider = kwargs.get('onnx_provider', 'cuda') # 默认用cuda + device_id = kwargs.get('device_id', 0) + self.dsize = kwargs.get('dsize', 224) + self.timer = Timer() + + if onnx_provider.lower() == 'cuda': + self.session = onnxruntime.InferenceSession( + ckpt_path, providers=[ + ('CUDAExecutionProvider', {'device_id': device_id}) + ] + ) + else: + opts = onnxruntime.SessionOptions() + opts.intra_op_num_threads = 4 # 默认线程数为 4 + self.session = onnxruntime.InferenceSession( + ckpt_path, providers=['CPUExecutionProvider'], + sess_options=opts + ) + + def _run(self, inp): + out = self.session.run(None, {'input': inp}) + return out + + def run(self, img_rgb: np.ndarray, lmk=None): + if lmk is not None: + crop_dct = crop_image(img_rgb, lmk, dsize=self.dsize, scale=1.5, vy_ratio=-0.1) + img_crop_rgb = crop_dct['img_crop'] + else: + img_crop_rgb = cv2.resize(img_rgb, (self.dsize, self.dsize)) + scale = max(img_rgb.shape[:2]) / self.dsize + crop_dct = { + 'M_c2o': np.array([ + [scale, 0., 0.], + [0., scale, 0.], + [0., 0., 1.], + ], dtype=np.float32), + } + + inp = (img_crop_rgb.astype(np.float32) / 255.).transpose(2, 0, 1)[None, ...] # HxWx3 (BGR) -> 1x3xHxW (RGB!) 
+ + out_lst = self._run(inp) + out_pts = out_lst[2] + + pts = to_ndarray(out_pts[0]).reshape(-1, 2) * self.dsize # scale to 0-224 + pts = _transform_pts(pts, M=crop_dct['M_c2o']) + + return { + 'pts': pts, # 2d landmarks 203 points + } + + def warmup(self): + # 构造dummy image进行warmup + self.timer.tic() + + dummy_image = np.zeros((1, 3, self.dsize, self.dsize), dtype=np.float32) + + _ = self._run(dummy_image) + + elapse = self.timer.toc() + #rlog(f'LandmarkRunner warmup time: {elapse:.3f}s') diff --git a/liveportrait/utils/resources/mask_template.png b/liveportrait/utils/resources/mask_template.png new file mode 100644 index 0000000000000000000000000000000000000000..d3522210825b8fe128210f261d32e80a7ca02a80 --- /dev/null +++ b/liveportrait/utils/resources/mask_template.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c53e64a07ae3056af2b38548ae5b6cceb04d5c3e6514c7a8d2c3aff9da1ee76 +size 3470 diff --git a/liveportrait/utils/retargeting_utils.py b/liveportrait/utils/retargeting_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..20a1bdd37a05f8140124db2a1be568683e917653 --- /dev/null +++ b/liveportrait/utils/retargeting_utils.py @@ -0,0 +1,54 @@ + +""" +Functions to compute distance ratios between specific pairs of facial landmarks +""" + +import numpy as np + + +def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray: + """ + Calculate the ratio of the distance between two pairs of landmarks. + + Parameters: + lmk (np.ndarray): Landmarks array of shape (B, N, 2). + idx1, idx2, idx3, idx4 (int): Indices of the landmarks. + eps (float): Small value to avoid division by zero. + + Returns: + np.ndarray: Calculated distance ratio. + """ + return (np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True) / + (np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True) + eps)) + + +def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: np.ndarray = None) -> np.ndarray: + """ + Calculate the eye-close ratio for left and right eyes. + + Parameters: + lmk (np.ndarray): Landmarks array of shape (B, N, 2). + target_eye_ratio (np.ndarray, optional): Additional target eye ratio array to include. + + Returns: + np.ndarray: Concatenated eye-close ratios. + """ + lefteye_close_ratio = calculate_distance_ratio(lmk, 6, 18, 0, 12) + righteye_close_ratio = calculate_distance_ratio(lmk, 30, 42, 24, 36) + if target_eye_ratio is not None: + return np.concatenate([lefteye_close_ratio, righteye_close_ratio, target_eye_ratio], axis=1) + else: + return np.concatenate([lefteye_close_ratio, righteye_close_ratio], axis=1) + + +def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray: + """ + Calculate the lip-close ratio. + + Parameters: + lmk (np.ndarray): Landmarks array of shape (B, N, 2). + + Returns: + np.ndarray: Calculated lip-close ratio. 
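+
+    Note:
+        Computed as ||lmk[:, 90] - lmk[:, 102]|| / (||lmk[:, 48] - lmk[:, 66]|| + eps),
+        i.e. the same distance-ratio helper used for the eyes, with the lip landmark indices.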
+ """ + return calculate_distance_ratio(lmk, 90, 102, 48, 66) diff --git a/liveportrait/utils/rprint.py b/liveportrait/utils/rprint.py new file mode 100644 index 0000000000000000000000000000000000000000..c43a42f9855bbb019725e6c2b6c6c50e6fa4d0c5 --- /dev/null +++ b/liveportrait/utils/rprint.py @@ -0,0 +1,16 @@ +# coding: utf-8 + +""" +custom print and log functions +""" + +__all__ = ['rprint', 'rlog'] + +try: + from rich.console import Console + console = Console() + rprint = console.print + rlog = console.log +except: + rprint = print + rlog = print diff --git a/liveportrait/utils/timer.py b/liveportrait/utils/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..3570fa45d3ff36376471b82a5b3c02efe46eed98 --- /dev/null +++ b/liveportrait/utils/timer.py @@ -0,0 +1,29 @@ +# coding: utf-8 + +""" +tools to measure elapsed time +""" + +import time + +class Timer(object): + """A simple timer.""" + + def __init__(self): + self.total_time = 0. + self.calls = 0 + self.start_time = 0. + self.diff = 0. + + def tic(self): + # using time.time instead of time.clock because time time.clock + # does not normalize for multithreading + self.start_time = time.time() + + def toc(self, average=True): + self.diff = time.time() - self.start_time + return self.diff + + def clear(self): + self.start_time = 0. + self.diff = 0. diff --git a/liveportrait/utils/video.py b/liveportrait/utils/video.py new file mode 100644 index 0000000000000000000000000000000000000000..555122f728a1d8956d5532104c1f9c930e529925 --- /dev/null +++ b/liveportrait/utils/video.py @@ -0,0 +1,140 @@ +# coding: utf-8 + +""" +functions for processing video +""" + +import os.path as osp +import numpy as np +import subprocess +import imageio +import cv2 + +from rich.progress import track +from .helper import prefix +from .rprint import rprint as print + + +def exec_cmd(cmd): + subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + +def images2video(images, wfp, **kwargs): + fps = kwargs.get('fps', 25) # MuseTalk prefers 25 fps by default + video_format = kwargs.get('format', 'mp4') # default is mp4 format + codec = kwargs.get('codec', 'libx264') # default is libx264 encoding + quality = kwargs.get('quality') # video quality + pixelformat = kwargs.get('pixelformat', 'yuv420p') # video pixel format + image_mode = kwargs.get('image_mode', 'rgb') + macro_block_size = kwargs.get('macro_block_size', 2) + ffmpeg_params = ['-crf', str(kwargs.get('crf', 18))] + + writer = imageio.get_writer( + wfp, fps=fps, format=video_format, + codec=codec, quality=quality, ffmpeg_params=ffmpeg_params, pixelformat=pixelformat, macro_block_size=macro_block_size + ) + + n = len(images) + for i in track(range(n), description='writing', transient=True): + if image_mode.lower() == 'bgr': + writer.append_data(images[i][..., ::-1]) + else: + writer.append_data(images[i]) + + writer.close() + + # print(f':smiley: Dump to {wfp}\n', style="bold green") + print(f'Dump to {wfp}\n') + + +# useTalk prefers 25 fps by default +def video2gif(video_fp, fps=25, size=256): + if osp.exists(video_fp): + d = osp.split(video_fp)[0] + fn = prefix(osp.basename(video_fp)) + palette_wfp = osp.join(d, 'palette.png') + gif_wfp = osp.join(d, f'{fn}.gif') + # generate the palette + cmd = f'ffmpeg -i {video_fp} -vf "fps={fps},scale={size}:-1:flags=lanczos,palettegen" {palette_wfp} -y' + exec_cmd(cmd) + # use the palette to generate the gif + cmd = f'ffmpeg -i {video_fp} -i {palette_wfp} -filter_complex 
"fps={fps},scale={size}:-1:flags=lanczos[x];[x][1:v]paletteuse" {gif_wfp} -y' + exec_cmd(cmd) + else: + print(f'video_fp: {video_fp} not exists!') + + +def merge_audio_video(video_fp, audio_fp, wfp): + if osp.exists(video_fp) and osp.exists(audio_fp): + cmd = f'ffmpeg -i {video_fp} -i {audio_fp} -c:v copy -c:a aac {wfp} -y' + exec_cmd(cmd) + print(f'merge {video_fp} and {audio_fp} to {wfp}') + else: + print(f'video_fp: {video_fp} or audio_fp: {audio_fp} not exists!') + + +def blend(img: np.ndarray, mask: np.ndarray, background_color=(255, 255, 255)): + mask_float = mask.astype(np.float32) / 255. + background_color = np.array(background_color).reshape([1, 1, 3]) + bg = np.ones_like(img) * background_color + img = np.clip(mask_float * img + (1 - mask_float) * bg, 0, 255).astype(np.uint8) + return img + + +def concat_frames(I_p_lst, driving_rgb_lst, img_rgb): + # TODO: add more concat style, e.g., left-down corner driving + out_lst = [] + for idx, _ in track(enumerate(I_p_lst), total=len(I_p_lst), description='Concatenating result...'): + source_image_drived = I_p_lst[idx] + image_drive = driving_rgb_lst[idx] + + # resize images to match source_image_drived shape + h, w, _ = source_image_drived.shape + image_drive_resized = cv2.resize(image_drive, (w, h)) + img_rgb_resized = cv2.resize(img_rgb, (w, h)) + + # concatenate images horizontally + frame = np.concatenate((image_drive_resized, img_rgb_resized, source_image_drived), axis=1) + out_lst.append(frame) + return out_lst + + +class VideoWriter: + def __init__(self, **kwargs): + self.fps = kwargs.get('fps', 25) # MuseTalk prefers 25 fps by default + self.wfp = kwargs.get('wfp', 'video.mp4') + self.video_format = kwargs.get('format', 'mp4') + self.codec = kwargs.get('codec', 'libx264') + self.quality = kwargs.get('quality') + self.pixelformat = kwargs.get('pixelformat', 'yuv420p') + self.image_mode = kwargs.get('image_mode', 'rgb') + self.ffmpeg_params = kwargs.get('ffmpeg_params') + + self.writer = imageio.get_writer( + self.wfp, fps=self.fps, format=self.video_format, + codec=self.codec, quality=self.quality, + ffmpeg_params=self.ffmpeg_params, pixelformat=self.pixelformat + ) + + def write(self, image): + if self.image_mode.lower() == 'bgr': + self.writer.append_data(image[..., ::-1]) + else: + self.writer.append_data(image) + + def close(self): + if self.writer is not None: + self.writer.close() + + +def change_video_fps(input_file, output_file, fps=25, codec='libx264', crf=5): + cmd = f"ffmpeg -i {input_file} -c:v {codec} -crf {crf} -r {fps} {output_file} -y" + exec_cmd(cmd) + + +def get_fps(filepath): + import ffmpeg + probe = ffmpeg.probe(filepath) + video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None) + fps = eval(video_stream['avg_frame_rate']) + return fps diff --git a/loader.py b/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..0172b81a01e1bfe58ed674c5f74032e66a25cc12 --- /dev/null +++ b/loader.py @@ -0,0 +1,156 @@ +import os +import logging +import torch +import asyncio +import aiohttp +import requests +from huggingface_hub import hf_hub_download +import sentencepiece + + +# Configure logging +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# Configuration +DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data') +MODELS_DIR = os.path.join(DATA_ROOT, "models") +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Hugging Face repository information 
+HF_REPO_ID = "jbilcke-hf/model-cocktail" + +# Model files to download +MODEL_FILES = [ + "dwpose/dw-ll_ucoco_384.pth", + "face-detector/s3fd-619a316812.pth", + "liveportrait/spade_generator.pth", + "liveportrait/warping_module.pth", + "liveportrait/motion_extractor.pth", + "liveportrait/stitching_retargeting_module.pth", + "liveportrait/appearance_feature_extractor.pth", + "liveportrait/landmark.onnx", + + # this is a hack, instead we should probably try to + # fix liveportrait/utils/dependencies/insightface/utils/storage.py + "insightface/models/buffalo_l.zip", + + "insightface/buffalo_l/det_10g.onnx", + "insightface/buffalo_l/2d106det.onnx", + "sd-vae-ft-mse/diffusion_pytorch_model.bin", + "sd-vae-ft-mse/diffusion_pytorch_model.safetensors", + "sd-vae-ft-mse/config.json", + + # we don't use those yet + #"flux-dev/flux-dev-fp8.safetensors", + #"flux-dev/flux_dev_quantization_map.json", + #"pulid-flux/pulid_flux_v0.9.0.safetensors", + #"pulid-flux/pulid_v1.bin" +] + +def create_directory(directory): + """Create a directory if it doesn't exist and log its status.""" + if not os.path.exists(directory): + os.makedirs(directory) + logger.info(f" Directory created: {directory}") + else: + logger.info(f" Directory already exists: {directory}") + +def print_directory_structure(startpath): + """Print the directory structure starting from the given path.""" + for root, dirs, files in os.walk(startpath): + level = root.replace(startpath, '').count(os.sep) + indent = ' ' * 4 * level + logger.info(f"{indent}{os.path.basename(root)}/") + subindent = ' ' * 4 * (level + 1) + for f in files: + logger.info(f"{subindent}{f}") + +async def download_hf_file(filename: str) -> None: + """Download a file from Hugging Face to the models directory.""" + dest = os.path.join(MODELS_DIR, filename) + os.makedirs(os.path.dirname(dest), exist_ok=True) + if os.path.exists(dest): + # this is really for debugging purposes only + logger.debug(f" ✅ {filename}") + return + + logger.info(f" ⏳ Downloading {HF_REPO_ID}/{filename}") + + try: + await asyncio.get_event_loop().run_in_executor( + None, + lambda: hf_hub_download( + repo_id=HF_REPO_ID, + filename=filename, + local_dir=MODELS_DIR + ) + ) + logger.info(f" ✅ Downloaded {filename}") + except Exception as e: + logger.error(f"🚨 Error downloading file from Hugging Face: {e}") + if os.path.exists(dest): + os.remove(dest) + raise + +async def download_all_models(): + """Download all required models from the Hugging Face repository.""" + logger.info(" 🔎 Looking for models...") + tasks = [download_hf_file(filename) for filename in MODEL_FILES] + await asyncio.gather(*tasks) + logger.info(" ✅ All models are available") + + # are you looking to debug the app and verify that models are downloaded properly? 
+ # then un-comment the two following lines: + #logger.info("💡 Printing directory structure of models:") + #print_directory_structure(MODELS_DIR) + +class ModelLoader: + """A class responsible for loading and initializing all required models.""" + + def __init__(self): + self.device = DEVICE + self.models_dir = MODELS_DIR + + async def load_live_portrait(self): + """Load LivePortrait models.""" + from liveportrait.config.inference_config import InferenceConfig + from liveportrait.config.crop_config import CropConfig + from liveportrait.live_portrait_pipeline import LivePortraitPipeline + + logger.info(" ⏳ Loading LivePortrait models...") + live_portrait_pipeline = await asyncio.to_thread( + LivePortraitPipeline, + inference_cfg=InferenceConfig( + # default values + flag_stitching=True, # we recommend setting it to True! + flag_relative=True, # whether to use relative motion + flag_pasteback=True, # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space + flag_do_crop= True, # whether to crop the source portrait to the face-cropping space + flag_do_rot=True, # whether to conduct the rotation when flag_do_crop is True + ), + crop_cfg=CropConfig() + ) + logger.info(" ✅ LivePortrait models loaded successfully.") + return live_portrait_pipeline + +async def initialize_models(): + """Initialize and load all required models.""" + logger.info("🚀 Starting model initialization...") + + # Ensure all required models are downloaded + await download_all_models() + + # Initialize the ModelLoader + loader = ModelLoader() + + # Load LivePortrait models + live_portrait = await loader.load_live_portrait() + + logger.info("✅ Model initialization completed.") + return live_portrait + +# Initial setup +logger.info("🚀 Setting up storage directories...") +create_directory(MODELS_DIR) +logger.info("✅ Storage directories setup completed.") diff --git a/public/hf-logo.svg b/public/hf-logo.svg new file mode 100644 index 0000000000000000000000000000000000000000..43c5d3c0c97a9150b19d10bc68f4c9042c5034cd --- /dev/null +++ b/public/hf-logo.svg @@ -0,0 +1,37 @@ + + + + + + + + + + + diff --git a/public/index.html b/public/index.html new file mode 100644 index 0000000000000000000000000000000000000000..1f41cf6fdbb89ce439c948a8b816056c06b0555b --- /dev/null +++ b/public/index.html @@ -0,0 +1,13 @@ + + + + + + FacePoke by @jbilcke-hf + + + +
+ + + diff --git a/public/index.js b/public/index.js new file mode 100644 index 0000000000000000000000000000000000000000..ef702b3385ec524aaad921114acc60aa3258289e --- /dev/null +++ b/public/index.js @@ -0,0 +1,33512 @@ +var __create = Object.create; +var __defProp = Object.defineProperty; +var __getProtoOf = Object.getPrototypeOf; +var __getOwnPropNames = Object.getOwnPropertyNames; +var __hasOwnProp = Object.prototype.hasOwnProperty; +var __toESM = (mod, isNodeMode, target) => { + target = mod != null ? __create(__getProtoOf(mod)) : {}; + const to = isNodeMode || !mod || !mod.__esModule ? __defProp(target, "default", { value: mod, enumerable: true }) : target; + for (let key of __getOwnPropNames(mod)) + if (!__hasOwnProp.call(to, key)) + __defProp(to, key, { + get: () => mod[key], + enumerable: true + }); + return to; +}; +var __commonJS = (cb, mod) => () => (mod || cb((mod = { exports: {} }).exports, mod), mod.exports); +var __export = (target, all) => { + for (var name in all) + __defProp(target, name, { + get: all[name], + enumerable: true, + configurable: true, + set: (newValue) => all[name] = () => newValue + }); +}; + +// node_modules/react/cjs/react.development.js +var require_react_development = __commonJS((exports, module) => { + if (true) { + (function() { + if (typeof __REACT_DEVTOOLS_GLOBAL_HOOK__ !== "undefined" && typeof __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStart === "function") { + __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStart(new Error); + } + var ReactVersion = "18.3.1"; + var REACT_ELEMENT_TYPE = Symbol.for("react.element"); + var REACT_PORTAL_TYPE = Symbol.for("react.portal"); + var REACT_FRAGMENT_TYPE = Symbol.for("react.fragment"); + var REACT_STRICT_MODE_TYPE = Symbol.for("react.strict_mode"); + var REACT_PROFILER_TYPE = Symbol.for("react.profiler"); + var REACT_PROVIDER_TYPE = Symbol.for("react.provider"); + var REACT_CONTEXT_TYPE = Symbol.for("react.context"); + var REACT_FORWARD_REF_TYPE = Symbol.for("react.forward_ref"); + var REACT_SUSPENSE_TYPE = Symbol.for("react.suspense"); + var REACT_SUSPENSE_LIST_TYPE = Symbol.for("react.suspense_list"); + var REACT_MEMO_TYPE = Symbol.for("react.memo"); + var REACT_LAZY_TYPE = Symbol.for("react.lazy"); + var REACT_OFFSCREEN_TYPE = Symbol.for("react.offscreen"); + var MAYBE_ITERATOR_SYMBOL = Symbol.iterator; + var FAUX_ITERATOR_SYMBOL = "@@iterator"; + function getIteratorFn(maybeIterable) { + if (maybeIterable === null || typeof maybeIterable !== "object") { + return null; + } + var maybeIterator = MAYBE_ITERATOR_SYMBOL && maybeIterable[MAYBE_ITERATOR_SYMBOL] || maybeIterable[FAUX_ITERATOR_SYMBOL]; + if (typeof maybeIterator === "function") { + return maybeIterator; + } + return null; + } + var ReactCurrentDispatcher = { + current: null + }; + var ReactCurrentBatchConfig = { + transition: null + }; + var ReactCurrentActQueue = { + current: null, + isBatchingLegacy: false, + didScheduleLegacyUpdate: false + }; + var ReactCurrentOwner = { + current: null + }; + var ReactDebugCurrentFrame = {}; + var currentExtraStackFrame = null; + function setExtraStackFrame(stack) { + { + currentExtraStackFrame = stack; + } + } + { + ReactDebugCurrentFrame.setExtraStackFrame = function(stack) { + { + currentExtraStackFrame = stack; + } + }; + ReactDebugCurrentFrame.getCurrentStack = null; + ReactDebugCurrentFrame.getStackAddendum = function() { + var stack = ""; + if (currentExtraStackFrame) { + stack += currentExtraStackFrame; + } + var impl = ReactDebugCurrentFrame.getCurrentStack; + if (impl) { + stack += 
impl() || ""; + } + return stack; + }; + } + var enableScopeAPI = false; + var enableCacheElement = false; + var enableTransitionTracing = false; + var enableLegacyHidden = false; + var enableDebugTracing = false; + var ReactSharedInternals = { + ReactCurrentDispatcher, + ReactCurrentBatchConfig, + ReactCurrentOwner + }; + { + ReactSharedInternals.ReactDebugCurrentFrame = ReactDebugCurrentFrame; + ReactSharedInternals.ReactCurrentActQueue = ReactCurrentActQueue; + } + function warn(format) { + { + { + for (var _len = arguments.length, args = new Array(_len > 1 ? _len - 1 : 0), _key = 1;_key < _len; _key++) { + args[_key - 1] = arguments[_key]; + } + printWarning("warn", format, args); + } + } + } + function error(format) { + { + { + for (var _len2 = arguments.length, args = new Array(_len2 > 1 ? _len2 - 1 : 0), _key2 = 1;_key2 < _len2; _key2++) { + args[_key2 - 1] = arguments[_key2]; + } + printWarning("error", format, args); + } + } + } + function printWarning(level, format, args) { + { + var ReactDebugCurrentFrame2 = ReactSharedInternals.ReactDebugCurrentFrame; + var stack = ReactDebugCurrentFrame2.getStackAddendum(); + if (stack !== "") { + format += "%s"; + args = args.concat([stack]); + } + var argsWithFormat = args.map(function(item) { + return String(item); + }); + argsWithFormat.unshift("Warning: " + format); + Function.prototype.apply.call(console[level], console, argsWithFormat); + } + } + var didWarnStateUpdateForUnmountedComponent = {}; + function warnNoop(publicInstance, callerName) { + { + var _constructor = publicInstance.constructor; + var componentName = _constructor && (_constructor.displayName || _constructor.name) || "ReactClass"; + var warningKey = componentName + "." + callerName; + if (didWarnStateUpdateForUnmountedComponent[warningKey]) { + return; + } + error("Can't call %s on a component that is not yet mounted. This is a no-op, but it might indicate a bug in your application. 
Instead, assign to `this.state` directly or define a `state = {};` class property with the desired state in the %s component.", callerName, componentName); + didWarnStateUpdateForUnmountedComponent[warningKey] = true; + } + } + var ReactNoopUpdateQueue = { + isMounted: function(publicInstance) { + return false; + }, + enqueueForceUpdate: function(publicInstance, callback, callerName) { + warnNoop(publicInstance, "forceUpdate"); + }, + enqueueReplaceState: function(publicInstance, completeState, callback, callerName) { + warnNoop(publicInstance, "replaceState"); + }, + enqueueSetState: function(publicInstance, partialState, callback, callerName) { + warnNoop(publicInstance, "setState"); + } + }; + var assign = Object.assign; + var emptyObject = {}; + { + Object.freeze(emptyObject); + } + function Component(props, context, updater) { + this.props = props; + this.context = context; + this.refs = emptyObject; + this.updater = updater || ReactNoopUpdateQueue; + } + Component.prototype.isReactComponent = {}; + Component.prototype.setState = function(partialState, callback) { + if (typeof partialState !== "object" && typeof partialState !== "function" && partialState != null) { + throw new Error("setState(...): takes an object of state variables to update or a function which returns an object of state variables."); + } + this.updater.enqueueSetState(this, partialState, callback, "setState"); + }; + Component.prototype.forceUpdate = function(callback) { + this.updater.enqueueForceUpdate(this, callback, "forceUpdate"); + }; + { + var deprecatedAPIs = { + isMounted: ["isMounted", "Instead, make sure to clean up subscriptions and pending requests in componentWillUnmount to prevent memory leaks."], + replaceState: ["replaceState", "Refactor your code to use setState instead (see https://github.com/facebook/react/issues/3236)."] + }; + var defineDeprecationWarning = function(methodName, info) { + Object.defineProperty(Component.prototype, methodName, { + get: function() { + warn("%s(...) is deprecated in plain JavaScript React classes. %s", info[0], info[1]); + return; + } + }); + }; + for (var fnName in deprecatedAPIs) { + if (deprecatedAPIs.hasOwnProperty(fnName)) { + defineDeprecationWarning(fnName, deprecatedAPIs[fnName]); + } + } + } + function ComponentDummy() { + } + ComponentDummy.prototype = Component.prototype; + function PureComponent(props, context, updater) { + this.props = props; + this.context = context; + this.refs = emptyObject; + this.updater = updater || ReactNoopUpdateQueue; + } + var pureComponentPrototype = PureComponent.prototype = new ComponentDummy; + pureComponentPrototype.constructor = PureComponent; + assign(pureComponentPrototype, Component.prototype); + pureComponentPrototype.isPureReactComponent = true; + function createRef() { + var refObject = { + current: null + }; + { + Object.seal(refObject); + } + return refObject; + } + var isArrayImpl = Array.isArray; + function isArray(a) { + return isArrayImpl(a); + } + function typeName(value) { + { + var hasToStringTag = typeof Symbol === "function" && Symbol.toStringTag; + var type = hasToStringTag && value[Symbol.toStringTag] || value.constructor.name || "Object"; + return type; + } + } + function willCoercionThrow(value) { + { + try { + testStringCoercion(value); + return false; + } catch (e) { + return true; + } + } + } + function testStringCoercion(value) { + return "" + value; + } + function checkKeyStringCoercion(value) { + { + if (willCoercionThrow(value)) { + error("The provided key is an unsupported type %s. 
This value must be coerced to a string before before using it here.", typeName(value)); + return testStringCoercion(value); + } + } + } + function getWrappedName(outerType, innerType, wrapperName) { + var displayName = outerType.displayName; + if (displayName) { + return displayName; + } + var functionName = innerType.displayName || innerType.name || ""; + return functionName !== "" ? wrapperName + "(" + functionName + ")" : wrapperName; + } + function getContextName(type) { + return type.displayName || "Context"; + } + function getComponentNameFromType(type) { + if (type == null) { + return null; + } + { + if (typeof type.tag === "number") { + error("Received an unexpected object in getComponentNameFromType(). This is likely a bug in React. Please file an issue."); + } + } + if (typeof type === "function") { + return type.displayName || type.name || null; + } + if (typeof type === "string") { + return type; + } + switch (type) { + case REACT_FRAGMENT_TYPE: + return "Fragment"; + case REACT_PORTAL_TYPE: + return "Portal"; + case REACT_PROFILER_TYPE: + return "Profiler"; + case REACT_STRICT_MODE_TYPE: + return "StrictMode"; + case REACT_SUSPENSE_TYPE: + return "Suspense"; + case REACT_SUSPENSE_LIST_TYPE: + return "SuspenseList"; + } + if (typeof type === "object") { + switch (type.$$typeof) { + case REACT_CONTEXT_TYPE: + var context = type; + return getContextName(context) + ".Consumer"; + case REACT_PROVIDER_TYPE: + var provider = type; + return getContextName(provider._context) + ".Provider"; + case REACT_FORWARD_REF_TYPE: + return getWrappedName(type, type.render, "ForwardRef"); + case REACT_MEMO_TYPE: + var outerName = type.displayName || null; + if (outerName !== null) { + return outerName; + } + return getComponentNameFromType(type.type) || "Memo"; + case REACT_LAZY_TYPE: { + var lazyComponent = type; + var payload = lazyComponent._payload; + var init = lazyComponent._init; + try { + return getComponentNameFromType(init(payload)); + } catch (x) { + return null; + } + } + } + } + return null; + } + var hasOwnProperty = Object.prototype.hasOwnProperty; + var RESERVED_PROPS = { + key: true, + ref: true, + __self: true, + __source: true + }; + var specialPropKeyWarningShown, specialPropRefWarningShown, didWarnAboutStringRefs; + { + didWarnAboutStringRefs = {}; + } + function hasValidRef(config) { + { + if (hasOwnProperty.call(config, "ref")) { + var getter = Object.getOwnPropertyDescriptor(config, "ref").get; + if (getter && getter.isReactWarning) { + return false; + } + } + } + return config.ref !== undefined; + } + function hasValidKey(config) { + { + if (hasOwnProperty.call(config, "key")) { + var getter = Object.getOwnPropertyDescriptor(config, "key").get; + if (getter && getter.isReactWarning) { + return false; + } + } + } + return config.key !== undefined; + } + function defineKeyPropWarningGetter(props, displayName) { + var warnAboutAccessingKey = function() { + { + if (!specialPropKeyWarningShown) { + specialPropKeyWarningShown = true; + error("%s: `key` is not a prop. Trying to access it will result in `undefined` being returned. If you need to access the same value within the child component, you should pass it as a different prop. 
(https://reactjs.org/link/special-props)", displayName); + } + } + }; + warnAboutAccessingKey.isReactWarning = true; + Object.defineProperty(props, "key", { + get: warnAboutAccessingKey, + configurable: true + }); + } + function defineRefPropWarningGetter(props, displayName) { + var warnAboutAccessingRef = function() { + { + if (!specialPropRefWarningShown) { + specialPropRefWarningShown = true; + error("%s: `ref` is not a prop. Trying to access it will result in `undefined` being returned. If you need to access the same value within the child component, you should pass it as a different prop. (https://reactjs.org/link/special-props)", displayName); + } + } + }; + warnAboutAccessingRef.isReactWarning = true; + Object.defineProperty(props, "ref", { + get: warnAboutAccessingRef, + configurable: true + }); + } + function warnIfStringRefCannotBeAutoConverted(config) { + { + if (typeof config.ref === "string" && ReactCurrentOwner.current && config.__self && ReactCurrentOwner.current.stateNode !== config.__self) { + var componentName = getComponentNameFromType(ReactCurrentOwner.current.type); + if (!didWarnAboutStringRefs[componentName]) { + error('Component "%s" contains the string ref "%s". Support for string refs will be removed in a future major release. This case cannot be automatically converted to an arrow function. We ask you to manually fix this case by using useRef() or createRef() instead. Learn more about using refs safely here: https://reactjs.org/link/strict-mode-string-ref', componentName, config.ref); + didWarnAboutStringRefs[componentName] = true; + } + } + } + } + var ReactElement = function(type, key, ref, self2, source, owner, props) { + var element = { + $$typeof: REACT_ELEMENT_TYPE, + type, + key, + ref, + props, + _owner: owner + }; + { + element._store = {}; + Object.defineProperty(element._store, "validated", { + configurable: false, + enumerable: false, + writable: true, + value: false + }); + Object.defineProperty(element, "_self", { + configurable: false, + enumerable: false, + writable: false, + value: self2 + }); + Object.defineProperty(element, "_source", { + configurable: false, + enumerable: false, + writable: false, + value: source + }); + if (Object.freeze) { + Object.freeze(element.props); + Object.freeze(element); + } + } + return element; + }; + function createElement(type, config, children) { + var propName; + var props = {}; + var key = null; + var ref = null; + var self2 = null; + var source = null; + if (config != null) { + if (hasValidRef(config)) { + ref = config.ref; + { + warnIfStringRefCannotBeAutoConverted(config); + } + } + if (hasValidKey(config)) { + { + checkKeyStringCoercion(config.key); + } + key = "" + config.key; + } + self2 = config.__self === undefined ? null : config.__self; + source = config.__source === undefined ? 
null : config.__source; + for (propName in config) { + if (hasOwnProperty.call(config, propName) && !RESERVED_PROPS.hasOwnProperty(propName)) { + props[propName] = config[propName]; + } + } + } + var childrenLength = arguments.length - 2; + if (childrenLength === 1) { + props.children = children; + } else if (childrenLength > 1) { + var childArray = Array(childrenLength); + for (var i = 0;i < childrenLength; i++) { + childArray[i] = arguments[i + 2]; + } + { + if (Object.freeze) { + Object.freeze(childArray); + } + } + props.children = childArray; + } + if (type && type.defaultProps) { + var defaultProps = type.defaultProps; + for (propName in defaultProps) { + if (props[propName] === undefined) { + props[propName] = defaultProps[propName]; + } + } + } + { + if (key || ref) { + var displayName = typeof type === "function" ? type.displayName || type.name || "Unknown" : type; + if (key) { + defineKeyPropWarningGetter(props, displayName); + } + if (ref) { + defineRefPropWarningGetter(props, displayName); + } + } + } + return ReactElement(type, key, ref, self2, source, ReactCurrentOwner.current, props); + } + function cloneAndReplaceKey(oldElement, newKey) { + var newElement = ReactElement(oldElement.type, newKey, oldElement.ref, oldElement._self, oldElement._source, oldElement._owner, oldElement.props); + return newElement; + } + function cloneElement(element, config, children) { + if (element === null || element === undefined) { + throw new Error("React.cloneElement(...): The argument must be a React element, but you passed " + element + "."); + } + var propName; + var props = assign({}, element.props); + var key = element.key; + var ref = element.ref; + var self2 = element._self; + var source = element._source; + var owner = element._owner; + if (config != null) { + if (hasValidRef(config)) { + ref = config.ref; + owner = ReactCurrentOwner.current; + } + if (hasValidKey(config)) { + { + checkKeyStringCoercion(config.key); + } + key = "" + config.key; + } + var defaultProps; + if (element.type && element.type.defaultProps) { + defaultProps = element.type.defaultProps; + } + for (propName in config) { + if (hasOwnProperty.call(config, propName) && !RESERVED_PROPS.hasOwnProperty(propName)) { + if (config[propName] === undefined && defaultProps !== undefined) { + props[propName] = defaultProps[propName]; + } else { + props[propName] = config[propName]; + } + } + } + } + var childrenLength = arguments.length - 2; + if (childrenLength === 1) { + props.children = children; + } else if (childrenLength > 1) { + var childArray = Array(childrenLength); + for (var i = 0;i < childrenLength; i++) { + childArray[i] = arguments[i + 2]; + } + props.children = childArray; + } + return ReactElement(element.type, key, ref, self2, source, owner, props); + } + function isValidElement(object) { + return typeof object === "object" && object !== null && object.$$typeof === REACT_ELEMENT_TYPE; + } + var SEPARATOR = "."; + var SUBSEPARATOR = ":"; + function escape(key) { + var escapeRegex = /[=:]/g; + var escaperLookup = { + "=": "=0", + ":": "=2" + }; + var escapedString = key.replace(escapeRegex, function(match) { + return escaperLookup[match]; + }); + return "$" + escapedString; + } + var didWarnAboutMaps = false; + var userProvidedKeyEscapeRegex = /\/+/g; + function escapeUserProvidedKey(text) { + return text.replace(userProvidedKeyEscapeRegex, "$&/"); + } + function getElementKey(element, index) { + if (typeof element === "object" && element !== null && element.key != null) { + { + 
checkKeyStringCoercion(element.key); + } + return escape("" + element.key); + } + return index.toString(36); + } + function mapIntoArray(children, array, escapedPrefix, nameSoFar, callback) { + var type = typeof children; + if (type === "undefined" || type === "boolean") { + children = null; + } + var invokeCallback = false; + if (children === null) { + invokeCallback = true; + } else { + switch (type) { + case "string": + case "number": + invokeCallback = true; + break; + case "object": + switch (children.$$typeof) { + case REACT_ELEMENT_TYPE: + case REACT_PORTAL_TYPE: + invokeCallback = true; + } + } + } + if (invokeCallback) { + var _child = children; + var mappedChild = callback(_child); + var childKey = nameSoFar === "" ? SEPARATOR + getElementKey(_child, 0) : nameSoFar; + if (isArray(mappedChild)) { + var escapedChildKey = ""; + if (childKey != null) { + escapedChildKey = escapeUserProvidedKey(childKey) + "/"; + } + mapIntoArray(mappedChild, array, escapedChildKey, "", function(c) { + return c; + }); + } else if (mappedChild != null) { + if (isValidElement(mappedChild)) { + { + if (mappedChild.key && (!_child || _child.key !== mappedChild.key)) { + checkKeyStringCoercion(mappedChild.key); + } + } + mappedChild = cloneAndReplaceKey(mappedChild, escapedPrefix + (mappedChild.key && (!_child || _child.key !== mappedChild.key) ? escapeUserProvidedKey("" + mappedChild.key) + "/" : "") + childKey); + } + array.push(mappedChild); + } + return 1; + } + var child; + var nextName; + var subtreeCount = 0; + var nextNamePrefix = nameSoFar === "" ? SEPARATOR : nameSoFar + SUBSEPARATOR; + if (isArray(children)) { + for (var i = 0;i < children.length; i++) { + child = children[i]; + nextName = nextNamePrefix + getElementKey(child, i); + subtreeCount += mapIntoArray(child, array, escapedPrefix, nextName, callback); + } + } else { + var iteratorFn = getIteratorFn(children); + if (typeof iteratorFn === "function") { + var iterableChildren = children; + { + if (iteratorFn === iterableChildren.entries) { + if (!didWarnAboutMaps) { + warn("Using Maps as children is not supported. Use an array of keyed ReactElements instead."); + } + didWarnAboutMaps = true; + } + } + var iterator = iteratorFn.call(iterableChildren); + var step; + var ii = 0; + while (!(step = iterator.next()).done) { + child = step.value; + nextName = nextNamePrefix + getElementKey(child, ii++); + subtreeCount += mapIntoArray(child, array, escapedPrefix, nextName, callback); + } + } else if (type === "object") { + var childrenString = String(children); + throw new Error("Objects are not valid as a React child (found: " + (childrenString === "[object Object]" ? "object with keys {" + Object.keys(children).join(", ") + "}" : childrenString) + "). 
If you meant to render a collection of children, use an array instead."); + } + } + return subtreeCount; + } + function mapChildren(children, func, context) { + if (children == null) { + return children; + } + var result = []; + var count = 0; + mapIntoArray(children, result, "", "", function(child) { + return func.call(context, child, count++); + }); + return result; + } + function countChildren(children) { + var n = 0; + mapChildren(children, function() { + n++; + }); + return n; + } + function forEachChildren(children, forEachFunc, forEachContext) { + mapChildren(children, function() { + forEachFunc.apply(this, arguments); + }, forEachContext); + } + function toArray(children) { + return mapChildren(children, function(child) { + return child; + }) || []; + } + function onlyChild(children) { + if (!isValidElement(children)) { + throw new Error("React.Children.only expected to receive a single React element child."); + } + return children; + } + function createContext(defaultValue) { + var context = { + $$typeof: REACT_CONTEXT_TYPE, + _currentValue: defaultValue, + _currentValue2: defaultValue, + _threadCount: 0, + Provider: null, + Consumer: null, + _defaultValue: null, + _globalName: null + }; + context.Provider = { + $$typeof: REACT_PROVIDER_TYPE, + _context: context + }; + var hasWarnedAboutUsingNestedContextConsumers = false; + var hasWarnedAboutUsingConsumerProvider = false; + var hasWarnedAboutDisplayNameOnConsumer = false; + { + var Consumer = { + $$typeof: REACT_CONTEXT_TYPE, + _context: context + }; + Object.defineProperties(Consumer, { + Provider: { + get: function() { + if (!hasWarnedAboutUsingConsumerProvider) { + hasWarnedAboutUsingConsumerProvider = true; + error("Rendering is not supported and will be removed in a future major release. Did you mean to render instead?"); + } + return context.Provider; + }, + set: function(_Provider) { + context.Provider = _Provider; + } + }, + _currentValue: { + get: function() { + return context._currentValue; + }, + set: function(_currentValue) { + context._currentValue = _currentValue; + } + }, + _currentValue2: { + get: function() { + return context._currentValue2; + }, + set: function(_currentValue2) { + context._currentValue2 = _currentValue2; + } + }, + _threadCount: { + get: function() { + return context._threadCount; + }, + set: function(_threadCount) { + context._threadCount = _threadCount; + } + }, + Consumer: { + get: function() { + if (!hasWarnedAboutUsingNestedContextConsumers) { + hasWarnedAboutUsingNestedContextConsumers = true; + error("Rendering is not supported and will be removed in a future major release. Did you mean to render instead?"); + } + return context.Consumer; + } + }, + displayName: { + get: function() { + return context.displayName; + }, + set: function(displayName) { + if (!hasWarnedAboutDisplayNameOnConsumer) { + warn("Setting `displayName` on Context.Consumer has no effect. 
You should set it directly on the context with Context.displayName = '%s'.", displayName); + hasWarnedAboutDisplayNameOnConsumer = true; + } + } + } + }); + context.Consumer = Consumer; + } + { + context._currentRenderer = null; + context._currentRenderer2 = null; + } + return context; + } + var Uninitialized = -1; + var Pending = 0; + var Resolved = 1; + var Rejected = 2; + function lazyInitializer(payload) { + if (payload._status === Uninitialized) { + var ctor = payload._result; + var thenable = ctor(); + thenable.then(function(moduleObject2) { + if (payload._status === Pending || payload._status === Uninitialized) { + var resolved = payload; + resolved._status = Resolved; + resolved._result = moduleObject2; + } + }, function(error2) { + if (payload._status === Pending || payload._status === Uninitialized) { + var rejected = payload; + rejected._status = Rejected; + rejected._result = error2; + } + }); + if (payload._status === Uninitialized) { + var pending = payload; + pending._status = Pending; + pending._result = thenable; + } + } + if (payload._status === Resolved) { + var moduleObject = payload._result; + { + if (moduleObject === undefined) { + error("lazy: Expected the result of a dynamic import() call. Instead received: %s\n\nYour code should look like: \n const MyComponent = lazy(() => import('./MyComponent'))\n\nDid you accidentally put curly braces around the import?", moduleObject); + } + } + { + if (!("default" in moduleObject)) { + error("lazy: Expected the result of a dynamic import() call. Instead received: %s\n\nYour code should look like: \n const MyComponent = lazy(() => import('./MyComponent'))", moduleObject); + } + } + return moduleObject.default; + } else { + throw payload._result; + } + } + function lazy(ctor) { + var payload = { + _status: Uninitialized, + _result: ctor + }; + var lazyType = { + $$typeof: REACT_LAZY_TYPE, + _payload: payload, + _init: lazyInitializer + }; + { + var defaultProps; + var propTypes; + Object.defineProperties(lazyType, { + defaultProps: { + configurable: true, + get: function() { + return defaultProps; + }, + set: function(newDefaultProps) { + error("React.lazy(...): It is not supported to assign `defaultProps` to a lazy component import. Either specify them where the component is defined, or create a wrapping component around it."); + defaultProps = newDefaultProps; + Object.defineProperty(lazyType, "defaultProps", { + enumerable: true + }); + } + }, + propTypes: { + configurable: true, + get: function() { + return propTypes; + }, + set: function(newPropTypes) { + error("React.lazy(...): It is not supported to assign `propTypes` to a lazy component import. Either specify them where the component is defined, or create a wrapping component around it."); + propTypes = newPropTypes; + Object.defineProperty(lazyType, "propTypes", { + enumerable: true + }); + } + } + }); + } + return lazyType; + } + function forwardRef(render) { + { + if (render != null && render.$$typeof === REACT_MEMO_TYPE) { + error("forwardRef requires a render function but received a `memo` component. Instead of forwardRef(memo(...)), use memo(forwardRef(...))."); + } else if (typeof render !== "function") { + error("forwardRef requires a render function but was given %s.", render === null ? "null" : typeof render); + } else { + if (render.length !== 0 && render.length !== 2) { + error("forwardRef render functions accept exactly two parameters: props and ref. %s", render.length === 1 ? "Did you forget to use the ref parameter?" 
: "Any additional parameter will be undefined."); + } + } + if (render != null) { + if (render.defaultProps != null || render.propTypes != null) { + error("forwardRef render functions do not support propTypes or defaultProps. Did you accidentally pass a React component?"); + } + } + } + var elementType = { + $$typeof: REACT_FORWARD_REF_TYPE, + render + }; + { + var ownName; + Object.defineProperty(elementType, "displayName", { + enumerable: false, + configurable: true, + get: function() { + return ownName; + }, + set: function(name) { + ownName = name; + if (!render.name && !render.displayName) { + render.displayName = name; + } + } + }); + } + return elementType; + } + var REACT_MODULE_REFERENCE; + { + REACT_MODULE_REFERENCE = Symbol.for("react.module.reference"); + } + function isValidElementType(type) { + if (typeof type === "string" || typeof type === "function") { + return true; + } + if (type === REACT_FRAGMENT_TYPE || type === REACT_PROFILER_TYPE || enableDebugTracing || type === REACT_STRICT_MODE_TYPE || type === REACT_SUSPENSE_TYPE || type === REACT_SUSPENSE_LIST_TYPE || enableLegacyHidden || type === REACT_OFFSCREEN_TYPE || enableScopeAPI || enableCacheElement || enableTransitionTracing) { + return true; + } + if (typeof type === "object" && type !== null) { + if (type.$$typeof === REACT_LAZY_TYPE || type.$$typeof === REACT_MEMO_TYPE || type.$$typeof === REACT_PROVIDER_TYPE || type.$$typeof === REACT_CONTEXT_TYPE || type.$$typeof === REACT_FORWARD_REF_TYPE || type.$$typeof === REACT_MODULE_REFERENCE || type.getModuleId !== undefined) { + return true; + } + } + return false; + } + function memo(type, compare) { + { + if (!isValidElementType(type)) { + error("memo: The first argument must be a component. Instead received: %s", type === null ? "null" : typeof type); + } + } + var elementType = { + $$typeof: REACT_MEMO_TYPE, + type, + compare: compare === undefined ? null : compare + }; + { + var ownName; + Object.defineProperty(elementType, "displayName", { + enumerable: false, + configurable: true, + get: function() { + return ownName; + }, + set: function(name) { + ownName = name; + if (!type.name && !type.displayName) { + type.displayName = name; + } + } + }); + } + return elementType; + } + function resolveDispatcher() { + var dispatcher = ReactCurrentDispatcher.current; + { + if (dispatcher === null) { + error("Invalid hook call. Hooks can only be called inside of the body of a function component. This could happen for one of the following reasons:\n1. You might have mismatching versions of React and the renderer (such as React DOM)\n2. You might be breaking the Rules of Hooks\n3. You might have more than one copy of React in the same app\nSee https://reactjs.org/link/invalid-hook-call for tips about how to debug and fix this problem."); + } + } + return dispatcher; + } + function useContext(Context) { + var dispatcher = resolveDispatcher(); + { + if (Context._context !== undefined) { + var realContext = Context._context; + if (realContext.Consumer === Context) { + error("Calling useContext(Context.Consumer) is not supported, may cause bugs, and will be removed in a future major release. Did you mean to call useContext(Context) instead?"); + } else if (realContext.Provider === Context) { + error("Calling useContext(Context.Provider) is not supported. 
Did you mean to call useContext(Context) instead?"); + } + } + } + return dispatcher.useContext(Context); + } + function useState(initialState) { + var dispatcher = resolveDispatcher(); + return dispatcher.useState(initialState); + } + function useReducer(reducer, initialArg, init) { + var dispatcher = resolveDispatcher(); + return dispatcher.useReducer(reducer, initialArg, init); + } + function useRef(initialValue) { + var dispatcher = resolveDispatcher(); + return dispatcher.useRef(initialValue); + } + function useEffect(create, deps) { + var dispatcher = resolveDispatcher(); + return dispatcher.useEffect(create, deps); + } + function useInsertionEffect(create, deps) { + var dispatcher = resolveDispatcher(); + return dispatcher.useInsertionEffect(create, deps); + } + function useLayoutEffect(create, deps) { + var dispatcher = resolveDispatcher(); + return dispatcher.useLayoutEffect(create, deps); + } + function useCallback(callback, deps) { + var dispatcher = resolveDispatcher(); + return dispatcher.useCallback(callback, deps); + } + function useMemo(create, deps) { + var dispatcher = resolveDispatcher(); + return dispatcher.useMemo(create, deps); + } + function useImperativeHandle(ref, create, deps) { + var dispatcher = resolveDispatcher(); + return dispatcher.useImperativeHandle(ref, create, deps); + } + function useDebugValue(value, formatterFn) { + { + var dispatcher = resolveDispatcher(); + return dispatcher.useDebugValue(value, formatterFn); + } + } + function useTransition() { + var dispatcher = resolveDispatcher(); + return dispatcher.useTransition(); + } + function useDeferredValue(value) { + var dispatcher = resolveDispatcher(); + return dispatcher.useDeferredValue(value); + } + function useId() { + var dispatcher = resolveDispatcher(); + return dispatcher.useId(); + } + function useSyncExternalStore(subscribe, getSnapshot, getServerSnapshot) { + var dispatcher = resolveDispatcher(); + return dispatcher.useSyncExternalStore(subscribe, getSnapshot, getServerSnapshot); + } + var disabledDepth = 0; + var prevLog; + var prevInfo; + var prevWarn; + var prevError; + var prevGroup; + var prevGroupCollapsed; + var prevGroupEnd; + function disabledLog() { + } + disabledLog.__reactDisabledLog = true; + function disableLogs() { + { + if (disabledDepth === 0) { + prevLog = console.log; + prevInfo = console.info; + prevWarn = console.warn; + prevError = console.error; + prevGroup = console.group; + prevGroupCollapsed = console.groupCollapsed; + prevGroupEnd = console.groupEnd; + var props = { + configurable: true, + enumerable: true, + value: disabledLog, + writable: true + }; + Object.defineProperties(console, { + info: props, + log: props, + warn: props, + error: props, + group: props, + groupCollapsed: props, + groupEnd: props + }); + } + disabledDepth++; + } + } + function reenableLogs() { + { + disabledDepth--; + if (disabledDepth === 0) { + var props = { + configurable: true, + enumerable: true, + writable: true + }; + Object.defineProperties(console, { + log: assign({}, props, { + value: prevLog + }), + info: assign({}, props, { + value: prevInfo + }), + warn: assign({}, props, { + value: prevWarn + }), + error: assign({}, props, { + value: prevError + }), + group: assign({}, props, { + value: prevGroup + }), + groupCollapsed: assign({}, props, { + value: prevGroupCollapsed + }), + groupEnd: assign({}, props, { + value: prevGroupEnd + }) + }); + } + if (disabledDepth < 0) { + error("disabledDepth fell below zero. This is a bug in React. 
Please file an issue."); + } + } + } + var ReactCurrentDispatcher$1 = ReactSharedInternals.ReactCurrentDispatcher; + var prefix; + function describeBuiltInComponentFrame(name, source, ownerFn) { + { + if (prefix === undefined) { + try { + throw Error(); + } catch (x) { + var match = x.stack.trim().match(/\n( *(at )?)/); + prefix = match && match[1] || ""; + } + } + return "\n" + prefix + name; + } + } + var reentry = false; + var componentFrameCache; + { + var PossiblyWeakMap = typeof WeakMap === "function" ? WeakMap : Map; + componentFrameCache = new PossiblyWeakMap; + } + function describeNativeComponentFrame(fn, construct) { + if (!fn || reentry) { + return ""; + } + { + var frame = componentFrameCache.get(fn); + if (frame !== undefined) { + return frame; + } + } + var control; + reentry = true; + var previousPrepareStackTrace = Error.prepareStackTrace; + Error.prepareStackTrace = undefined; + var previousDispatcher; + { + previousDispatcher = ReactCurrentDispatcher$1.current; + ReactCurrentDispatcher$1.current = null; + disableLogs(); + } + try { + if (construct) { + var Fake = function() { + throw Error(); + }; + Object.defineProperty(Fake.prototype, "props", { + set: function() { + throw Error(); + } + }); + if (typeof Reflect === "object" && Reflect.construct) { + try { + Reflect.construct(Fake, []); + } catch (x) { + control = x; + } + Reflect.construct(fn, [], Fake); + } else { + try { + Fake.call(); + } catch (x) { + control = x; + } + fn.call(Fake.prototype); + } + } else { + try { + throw Error(); + } catch (x) { + control = x; + } + fn(); + } + } catch (sample) { + if (sample && control && typeof sample.stack === "string") { + var sampleLines = sample.stack.split("\n"); + var controlLines = control.stack.split("\n"); + var s = sampleLines.length - 1; + var c = controlLines.length - 1; + while (s >= 1 && c >= 0 && sampleLines[s] !== controlLines[c]) { + c--; + } + for (;s >= 1 && c >= 0; s--, c--) { + if (sampleLines[s] !== controlLines[c]) { + if (s !== 1 || c !== 1) { + do { + s--; + c--; + if (c < 0 || sampleLines[s] !== controlLines[c]) { + var _frame = "\n" + sampleLines[s].replace(" at new ", " at "); + if (fn.displayName && _frame.includes("")) { + _frame = _frame.replace("", fn.displayName); + } + { + if (typeof fn === "function") { + componentFrameCache.set(fn, _frame); + } + } + return _frame; + } + } while (s >= 1 && c >= 0); + } + break; + } + } + } + } finally { + reentry = false; + { + ReactCurrentDispatcher$1.current = previousDispatcher; + reenableLogs(); + } + Error.prepareStackTrace = previousPrepareStackTrace; + } + var name = fn ? fn.displayName || fn.name : ""; + var syntheticFrame = name ? 
describeBuiltInComponentFrame(name) : ""; + { + if (typeof fn === "function") { + componentFrameCache.set(fn, syntheticFrame); + } + } + return syntheticFrame; + } + function describeFunctionComponentFrame(fn, source, ownerFn) { + { + return describeNativeComponentFrame(fn, false); + } + } + function shouldConstruct(Component2) { + var prototype = Component2.prototype; + return !!(prototype && prototype.isReactComponent); + } + function describeUnknownElementTypeFrameInDEV(type, source, ownerFn) { + if (type == null) { + return ""; + } + if (typeof type === "function") { + { + return describeNativeComponentFrame(type, shouldConstruct(type)); + } + } + if (typeof type === "string") { + return describeBuiltInComponentFrame(type); + } + switch (type) { + case REACT_SUSPENSE_TYPE: + return describeBuiltInComponentFrame("Suspense"); + case REACT_SUSPENSE_LIST_TYPE: + return describeBuiltInComponentFrame("SuspenseList"); + } + if (typeof type === "object") { + switch (type.$$typeof) { + case REACT_FORWARD_REF_TYPE: + return describeFunctionComponentFrame(type.render); + case REACT_MEMO_TYPE: + return describeUnknownElementTypeFrameInDEV(type.type, source, ownerFn); + case REACT_LAZY_TYPE: { + var lazyComponent = type; + var payload = lazyComponent._payload; + var init = lazyComponent._init; + try { + return describeUnknownElementTypeFrameInDEV(init(payload), source, ownerFn); + } catch (x) { + } + } + } + } + return ""; + } + var loggedTypeFailures = {}; + var ReactDebugCurrentFrame$1 = ReactSharedInternals.ReactDebugCurrentFrame; + function setCurrentlyValidatingElement(element) { + { + if (element) { + var owner = element._owner; + var stack = describeUnknownElementTypeFrameInDEV(element.type, element._source, owner ? owner.type : null); + ReactDebugCurrentFrame$1.setExtraStackFrame(stack); + } else { + ReactDebugCurrentFrame$1.setExtraStackFrame(null); + } + } + } + function checkPropTypes(typeSpecs, values, location, componentName, element) { + { + var has = Function.call.bind(hasOwnProperty); + for (var typeSpecName in typeSpecs) { + if (has(typeSpecs, typeSpecName)) { + var error$1 = undefined; + try { + if (typeof typeSpecs[typeSpecName] !== "function") { + var err = Error((componentName || "React class") + ": " + location + " type `" + typeSpecName + "` is invalid; it must be a function, usually from the `prop-types` package, but received `" + typeof typeSpecs[typeSpecName] + "`.This often happens because of typos such as `PropTypes.function` instead of `PropTypes.func`."); + err.name = "Invariant Violation"; + throw err; + } + error$1 = typeSpecs[typeSpecName](values, typeSpecName, componentName, location, null, "SECRET_DO_NOT_PASS_THIS_OR_YOU_WILL_BE_FIRED"); + } catch (ex) { + error$1 = ex; + } + if (error$1 && !(error$1 instanceof Error)) { + setCurrentlyValidatingElement(element); + error("%s: type specification of %s `%s` is invalid; the type checker function must return `null` or an `Error` but returned a %s. 
You may have forgotten to pass an argument to the type checker creator (arrayOf, instanceOf, objectOf, oneOf, oneOfType, and shape all require an argument).", componentName || "React class", location, typeSpecName, typeof error$1); + setCurrentlyValidatingElement(null); + } + if (error$1 instanceof Error && !(error$1.message in loggedTypeFailures)) { + loggedTypeFailures[error$1.message] = true; + setCurrentlyValidatingElement(element); + error("Failed %s type: %s", location, error$1.message); + setCurrentlyValidatingElement(null); + } + } + } + } + } + function setCurrentlyValidatingElement$1(element) { + { + if (element) { + var owner = element._owner; + var stack = describeUnknownElementTypeFrameInDEV(element.type, element._source, owner ? owner.type : null); + setExtraStackFrame(stack); + } else { + setExtraStackFrame(null); + } + } + } + var propTypesMisspellWarningShown; + { + propTypesMisspellWarningShown = false; + } + function getDeclarationErrorAddendum() { + if (ReactCurrentOwner.current) { + var name = getComponentNameFromType(ReactCurrentOwner.current.type); + if (name) { + return "\n\nCheck the render method of `" + name + "`."; + } + } + return ""; + } + function getSourceInfoErrorAddendum(source) { + if (source !== undefined) { + var fileName = source.fileName.replace(/^.*[\\\/]/, ""); + var lineNumber = source.lineNumber; + return "\n\nCheck your code at " + fileName + ":" + lineNumber + "."; + } + return ""; + } + function getSourceInfoErrorAddendumForProps(elementProps) { + if (elementProps !== null && elementProps !== undefined) { + return getSourceInfoErrorAddendum(elementProps.__source); + } + return ""; + } + var ownerHasKeyUseWarning = {}; + function getCurrentComponentErrorInfo(parentType) { + var info = getDeclarationErrorAddendum(); + if (!info) { + var parentName = typeof parentType === "string" ? 
parentType : parentType.displayName || parentType.name; + if (parentName) { + info = "\n\nCheck the top-level render call using <" + parentName + ">."; + } + } + return info; + } + function validateExplicitKey(element, parentType) { + if (!element._store || element._store.validated || element.key != null) { + return; + } + element._store.validated = true; + var currentComponentErrorInfo = getCurrentComponentErrorInfo(parentType); + if (ownerHasKeyUseWarning[currentComponentErrorInfo]) { + return; + } + ownerHasKeyUseWarning[currentComponentErrorInfo] = true; + var childOwner = ""; + if (element && element._owner && element._owner !== ReactCurrentOwner.current) { + childOwner = " It was passed a child from " + getComponentNameFromType(element._owner.type) + "."; + } + { + setCurrentlyValidatingElement$1(element); + error('Each child in a list should have a unique "key" prop.%s%s See https://reactjs.org/link/warning-keys for more information.', currentComponentErrorInfo, childOwner); + setCurrentlyValidatingElement$1(null); + } + } + function validateChildKeys(node, parentType) { + if (typeof node !== "object") { + return; + } + if (isArray(node)) { + for (var i = 0;i < node.length; i++) { + var child = node[i]; + if (isValidElement(child)) { + validateExplicitKey(child, parentType); + } + } + } else if (isValidElement(node)) { + if (node._store) { + node._store.validated = true; + } + } else if (node) { + var iteratorFn = getIteratorFn(node); + if (typeof iteratorFn === "function") { + if (iteratorFn !== node.entries) { + var iterator = iteratorFn.call(node); + var step; + while (!(step = iterator.next()).done) { + if (isValidElement(step.value)) { + validateExplicitKey(step.value, parentType); + } + } + } + } + } + } + function validatePropTypes(element) { + { + var type = element.type; + if (type === null || type === undefined || typeof type === "string") { + return; + } + var propTypes; + if (typeof type === "function") { + propTypes = type.propTypes; + } else if (typeof type === "object" && (type.$$typeof === REACT_FORWARD_REF_TYPE || type.$$typeof === REACT_MEMO_TYPE)) { + propTypes = type.propTypes; + } else { + return; + } + if (propTypes) { + var name = getComponentNameFromType(type); + checkPropTypes(propTypes, element.props, "prop", name, element); + } else if (type.PropTypes !== undefined && !propTypesMisspellWarningShown) { + propTypesMisspellWarningShown = true; + var _name = getComponentNameFromType(type); + error("Component %s declared `PropTypes` instead of `propTypes`. Did you misspell the property assignment?", _name || "Unknown"); + } + if (typeof type.getDefaultProps === "function" && !type.getDefaultProps.isReactClassApproved) { + error("getDefaultProps is only used on classic React.createClass definitions. Use a static property named `defaultProps` instead."); + } + } + } + function validateFragmentProps(fragment) { + { + var keys = Object.keys(fragment.props); + for (var i = 0;i < keys.length; i++) { + var key = keys[i]; + if (key !== "children" && key !== "key") { + setCurrentlyValidatingElement$1(fragment); + error("Invalid prop `%s` supplied to `React.Fragment`. 
React.Fragment can only have `key` and `children` props.", key); + setCurrentlyValidatingElement$1(null); + break; + } + } + if (fragment.ref !== null) { + setCurrentlyValidatingElement$1(fragment); + error("Invalid attribute `ref` supplied to `React.Fragment`."); + setCurrentlyValidatingElement$1(null); + } + } + } + function createElementWithValidation(type, props, children) { + var validType = isValidElementType(type); + if (!validType) { + var info = ""; + if (type === undefined || typeof type === "object" && type !== null && Object.keys(type).length === 0) { + info += " You likely forgot to export your component from the file it's defined in, or you might have mixed up default and named imports."; + } + var sourceInfo = getSourceInfoErrorAddendumForProps(props); + if (sourceInfo) { + info += sourceInfo; + } else { + info += getDeclarationErrorAddendum(); + } + var typeString; + if (type === null) { + typeString = "null"; + } else if (isArray(type)) { + typeString = "array"; + } else if (type !== undefined && type.$$typeof === REACT_ELEMENT_TYPE) { + typeString = "<" + (getComponentNameFromType(type.type) || "Unknown") + " />"; + info = " Did you accidentally export a JSX literal instead of a component?"; + } else { + typeString = typeof type; + } + { + error("React.createElement: type is invalid -- expected a string (for built-in components) or a class/function (for composite components) but got: %s.%s", typeString, info); + } + } + var element = createElement.apply(this, arguments); + if (element == null) { + return element; + } + if (validType) { + for (var i = 2;i < arguments.length; i++) { + validateChildKeys(arguments[i], type); + } + } + if (type === REACT_FRAGMENT_TYPE) { + validateFragmentProps(element); + } else { + validatePropTypes(element); + } + return element; + } + var didWarnAboutDeprecatedCreateFactory = false; + function createFactoryWithValidation(type) { + var validatedFactory = createElementWithValidation.bind(null, type); + validatedFactory.type = type; + { + if (!didWarnAboutDeprecatedCreateFactory) { + didWarnAboutDeprecatedCreateFactory = true; + warn("React.createFactory() is deprecated and will be removed in a future major release. Consider using JSX or use React.createElement() directly instead."); + } + Object.defineProperty(validatedFactory, "type", { + enumerable: false, + get: function() { + warn("Factory.type is deprecated. Access the class directly before passing it to createFactory."); + Object.defineProperty(this, "type", { + value: type + }); + return type; + } + }); + } + return validatedFactory; + } + function cloneElementWithValidation(element, props, children) { + var newElement = cloneElement.apply(this, arguments); + for (var i = 2;i < arguments.length; i++) { + validateChildKeys(arguments[i], newElement.type); + } + validatePropTypes(newElement); + return newElement; + } + function startTransition(scope, options) { + var prevTransition = ReactCurrentBatchConfig.transition; + ReactCurrentBatchConfig.transition = {}; + var currentTransition = ReactCurrentBatchConfig.transition; + { + ReactCurrentBatchConfig.transition._updatedFibers = new Set; + } + try { + scope(); + } finally { + ReactCurrentBatchConfig.transition = prevTransition; + { + if (prevTransition === null && currentTransition._updatedFibers) { + var updatedFibersCount = currentTransition._updatedFibers.size; + if (updatedFibersCount > 10) { + warn("Detected a large number of updates inside startTransition. 
If this is due to a subscription please re-write it to use React provided hooks. Otherwise concurrent mode guarantees are off the table."); + } + currentTransition._updatedFibers.clear(); + } + } + } + } + var didWarnAboutMessageChannel = false; + var enqueueTaskImpl = null; + function enqueueTask(task) { + if (enqueueTaskImpl === null) { + try { + var requireString = ("require" + Math.random()).slice(0, 7); + var nodeRequire = module && module[requireString]; + enqueueTaskImpl = nodeRequire.call(module, "timers").setImmediate; + } catch (_err) { + enqueueTaskImpl = function(callback) { + { + if (didWarnAboutMessageChannel === false) { + didWarnAboutMessageChannel = true; + if (typeof MessageChannel === "undefined") { + error("This browser does not have a MessageChannel implementation, so enqueuing tasks via await act(async () => ...) will fail. Please file an issue at https://github.com/facebook/react/issues if you encounter this warning."); + } + } + } + var channel = new MessageChannel; + channel.port1.onmessage = callback; + channel.port2.postMessage(undefined); + }; + } + } + return enqueueTaskImpl(task); + } + var actScopeDepth = 0; + var didWarnNoAwaitAct = false; + function act(callback) { + { + var prevActScopeDepth = actScopeDepth; + actScopeDepth++; + if (ReactCurrentActQueue.current === null) { + ReactCurrentActQueue.current = []; + } + var prevIsBatchingLegacy = ReactCurrentActQueue.isBatchingLegacy; + var result; + try { + ReactCurrentActQueue.isBatchingLegacy = true; + result = callback(); + if (!prevIsBatchingLegacy && ReactCurrentActQueue.didScheduleLegacyUpdate) { + var queue = ReactCurrentActQueue.current; + if (queue !== null) { + ReactCurrentActQueue.didScheduleLegacyUpdate = false; + flushActQueue(queue); + } + } + } catch (error2) { + popActScope(prevActScopeDepth); + throw error2; + } finally { + ReactCurrentActQueue.isBatchingLegacy = prevIsBatchingLegacy; + } + if (result !== null && typeof result === "object" && typeof result.then === "function") { + var thenableResult = result; + var wasAwaited = false; + var thenable = { + then: function(resolve, reject) { + wasAwaited = true; + thenableResult.then(function(returnValue2) { + popActScope(prevActScopeDepth); + if (actScopeDepth === 0) { + recursivelyFlushAsyncActWork(returnValue2, resolve, reject); + } else { + resolve(returnValue2); + } + }, function(error2) { + popActScope(prevActScopeDepth); + reject(error2); + }); + } + }; + { + if (!didWarnNoAwaitAct && typeof Promise !== "undefined") { + Promise.resolve().then(function() { + }).then(function() { + if (!wasAwaited) { + didWarnNoAwaitAct = true; + error("You called act(async () => ...) without await. This could lead to unexpected testing behaviour, interleaving multiple act calls and mixing their scopes. 
You should - await act(async () => ...);"); + } + }); + } + } + return thenable; + } else { + var returnValue = result; + popActScope(prevActScopeDepth); + if (actScopeDepth === 0) { + var _queue = ReactCurrentActQueue.current; + if (_queue !== null) { + flushActQueue(_queue); + ReactCurrentActQueue.current = null; + } + var _thenable = { + then: function(resolve, reject) { + if (ReactCurrentActQueue.current === null) { + ReactCurrentActQueue.current = []; + recursivelyFlushAsyncActWork(returnValue, resolve, reject); + } else { + resolve(returnValue); + } + } + }; + return _thenable; + } else { + var _thenable2 = { + then: function(resolve, reject) { + resolve(returnValue); + } + }; + return _thenable2; + } + } + } + } + function popActScope(prevActScopeDepth) { + { + if (prevActScopeDepth !== actScopeDepth - 1) { + error("You seem to have overlapping act() calls, this is not supported. Be sure to await previous act() calls before making a new one. "); + } + actScopeDepth = prevActScopeDepth; + } + } + function recursivelyFlushAsyncActWork(returnValue, resolve, reject) { + { + var queue = ReactCurrentActQueue.current; + if (queue !== null) { + try { + flushActQueue(queue); + enqueueTask(function() { + if (queue.length === 0) { + ReactCurrentActQueue.current = null; + resolve(returnValue); + } else { + recursivelyFlushAsyncActWork(returnValue, resolve, reject); + } + }); + } catch (error2) { + reject(error2); + } + } else { + resolve(returnValue); + } + } + } + var isFlushing = false; + function flushActQueue(queue) { + { + if (!isFlushing) { + isFlushing = true; + var i = 0; + try { + for (;i < queue.length; i++) { + var callback = queue[i]; + do { + callback = callback(true); + } while (callback !== null); + } + queue.length = 0; + } catch (error2) { + queue = queue.slice(i + 1); + throw error2; + } finally { + isFlushing = false; + } + } + } + } + var createElement$1 = createElementWithValidation; + var cloneElement$1 = cloneElementWithValidation; + var createFactory = createFactoryWithValidation; + var Children = { + map: mapChildren, + forEach: forEachChildren, + count: countChildren, + toArray, + only: onlyChild + }; + exports.Children = Children; + exports.Component = Component; + exports.Fragment = REACT_FRAGMENT_TYPE; + exports.Profiler = REACT_PROFILER_TYPE; + exports.PureComponent = PureComponent; + exports.StrictMode = REACT_STRICT_MODE_TYPE; + exports.Suspense = REACT_SUSPENSE_TYPE; + exports.__SECRET_INTERNALS_DO_NOT_USE_OR_YOU_WILL_BE_FIRED = ReactSharedInternals; + exports.act = act; + exports.cloneElement = cloneElement$1; + exports.createContext = createContext; + exports.createElement = createElement$1; + exports.createFactory = createFactory; + exports.createRef = createRef; + exports.forwardRef = forwardRef; + exports.isValidElement = isValidElement; + exports.lazy = lazy; + exports.memo = memo; + exports.startTransition = startTransition; + exports.unstable_act = act; + exports.useCallback = useCallback; + exports.useContext = useContext; + exports.useDebugValue = useDebugValue; + exports.useDeferredValue = useDeferredValue; + exports.useEffect = useEffect; + exports.useId = useId; + exports.useImperativeHandle = useImperativeHandle; + exports.useInsertionEffect = useInsertionEffect; + exports.useLayoutEffect = useLayoutEffect; + exports.useMemo = useMemo; + exports.useReducer = useReducer; + exports.useRef = useRef; + exports.useState = useState; + exports.useSyncExternalStore = useSyncExternalStore; + exports.useTransition = useTransition; + exports.version = 
ReactVersion; + if (typeof __REACT_DEVTOOLS_GLOBAL_HOOK__ !== "undefined" && typeof __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStop === "function") { + __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStop(new Error); + } + })(); + } +}); + +// node_modules/react/index.js +var require_react = __commonJS((exports, module) => { + var react_development = __toESM(require_react_development(), 1); + if (false) { + } else { + module.exports = react_development; + } +}); + +// node_modules/scheduler/cjs/scheduler.development.js +var require_scheduler_development = __commonJS((exports) => { + if (true) { + (function() { + if (typeof __REACT_DEVTOOLS_GLOBAL_HOOK__ !== "undefined" && typeof __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStart === "function") { + __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStart(new Error); + } + var enableSchedulerDebugging = false; + var enableProfiling = false; + var frameYieldMs = 5; + function push(heap, node) { + var index = heap.length; + heap.push(node); + siftUp(heap, node, index); + } + function peek(heap) { + return heap.length === 0 ? null : heap[0]; + } + function pop(heap) { + if (heap.length === 0) { + return null; + } + var first = heap[0]; + var last = heap.pop(); + if (last !== first) { + heap[0] = last; + siftDown(heap, last, 0); + } + return first; + } + function siftUp(heap, node, i) { + var index = i; + while (index > 0) { + var parentIndex = index - 1 >>> 1; + var parent = heap[parentIndex]; + if (compare(parent, node) > 0) { + heap[parentIndex] = node; + heap[index] = parent; + index = parentIndex; + } else { + return; + } + } + } + function siftDown(heap, node, i) { + var index = i; + var length = heap.length; + var halfLength = length >>> 1; + while (index < halfLength) { + var leftIndex = (index + 1) * 2 - 1; + var left = heap[leftIndex]; + var rightIndex = leftIndex + 1; + var right = heap[rightIndex]; + if (compare(left, node) < 0) { + if (rightIndex < length && compare(right, left) < 0) { + heap[index] = right; + heap[rightIndex] = node; + index = rightIndex; + } else { + heap[index] = left; + heap[leftIndex] = node; + index = leftIndex; + } + } else if (rightIndex < length && compare(right, node) < 0) { + heap[index] = right; + heap[rightIndex] = node; + index = rightIndex; + } else { + return; + } + } + } + function compare(a, b) { + var diff = a.sortIndex - b.sortIndex; + return diff !== 0 ? 
diff : a.id - b.id; + } + var ImmediatePriority = 1; + var UserBlockingPriority = 2; + var NormalPriority = 3; + var LowPriority = 4; + var IdlePriority = 5; + function markTaskErrored(task, ms) { + } + var hasPerformanceNow = typeof performance === "object" && typeof performance.now === "function"; + if (hasPerformanceNow) { + var localPerformance = performance; + exports.unstable_now = function() { + return localPerformance.now(); + }; + } else { + var localDate = Date; + var initialTime = localDate.now(); + exports.unstable_now = function() { + return localDate.now() - initialTime; + }; + } + var maxSigned31BitInt = 1073741823; + var IMMEDIATE_PRIORITY_TIMEOUT = -1; + var USER_BLOCKING_PRIORITY_TIMEOUT = 250; + var NORMAL_PRIORITY_TIMEOUT = 5000; + var LOW_PRIORITY_TIMEOUT = 1e4; + var IDLE_PRIORITY_TIMEOUT = maxSigned31BitInt; + var taskQueue = []; + var timerQueue = []; + var taskIdCounter = 1; + var currentTask = null; + var currentPriorityLevel = NormalPriority; + var isPerformingWork = false; + var isHostCallbackScheduled = false; + var isHostTimeoutScheduled = false; + var localSetTimeout = typeof setTimeout === "function" ? setTimeout : null; + var localClearTimeout = typeof clearTimeout === "function" ? clearTimeout : null; + var localSetImmediate = typeof setImmediate !== "undefined" ? setImmediate : null; + var isInputPending = typeof navigator !== "undefined" && navigator.scheduling !== undefined && navigator.scheduling.isInputPending !== undefined ? navigator.scheduling.isInputPending.bind(navigator.scheduling) : null; + function advanceTimers(currentTime) { + var timer = peek(timerQueue); + while (timer !== null) { + if (timer.callback === null) { + pop(timerQueue); + } else if (timer.startTime <= currentTime) { + pop(timerQueue); + timer.sortIndex = timer.expirationTime; + push(taskQueue, timer); + } else { + return; + } + timer = peek(timerQueue); + } + } + function handleTimeout(currentTime) { + isHostTimeoutScheduled = false; + advanceTimers(currentTime); + if (!isHostCallbackScheduled) { + if (peek(taskQueue) !== null) { + isHostCallbackScheduled = true; + requestHostCallback(flushWork); + } else { + var firstTimer = peek(timerQueue); + if (firstTimer !== null) { + requestHostTimeout(handleTimeout, firstTimer.startTime - currentTime); + } + } + } + } + function flushWork(hasTimeRemaining, initialTime2) { + isHostCallbackScheduled = false; + if (isHostTimeoutScheduled) { + isHostTimeoutScheduled = false; + cancelHostTimeout(); + } + isPerformingWork = true; + var previousPriorityLevel = currentPriorityLevel; + try { + if (enableProfiling) { + try { + return workLoop(hasTimeRemaining, initialTime2); + } catch (error) { + if (currentTask !== null) { + var currentTime = exports.unstable_now(); + markTaskErrored(currentTask, currentTime); + currentTask.isQueued = false; + } + throw error; + } + } else { + return workLoop(hasTimeRemaining, initialTime2); + } + } finally { + currentTask = null; + currentPriorityLevel = previousPriorityLevel; + isPerformingWork = false; + } + } + function workLoop(hasTimeRemaining, initialTime2) { + var currentTime = initialTime2; + advanceTimers(currentTime); + currentTask = peek(taskQueue); + while (currentTask !== null && !enableSchedulerDebugging) { + if (currentTask.expirationTime > currentTime && (!hasTimeRemaining || shouldYieldToHost())) { + break; + } + var callback = currentTask.callback; + if (typeof callback === "function") { + currentTask.callback = null; + currentPriorityLevel = currentTask.priorityLevel; + var 
didUserCallbackTimeout = currentTask.expirationTime <= currentTime; + var continuationCallback = callback(didUserCallbackTimeout); + currentTime = exports.unstable_now(); + if (typeof continuationCallback === "function") { + currentTask.callback = continuationCallback; + } else { + if (currentTask === peek(taskQueue)) { + pop(taskQueue); + } + } + advanceTimers(currentTime); + } else { + pop(taskQueue); + } + currentTask = peek(taskQueue); + } + if (currentTask !== null) { + return true; + } else { + var firstTimer = peek(timerQueue); + if (firstTimer !== null) { + requestHostTimeout(handleTimeout, firstTimer.startTime - currentTime); + } + return false; + } + } + function unstable_runWithPriority(priorityLevel, eventHandler) { + switch (priorityLevel) { + case ImmediatePriority: + case UserBlockingPriority: + case NormalPriority: + case LowPriority: + case IdlePriority: + break; + default: + priorityLevel = NormalPriority; + } + var previousPriorityLevel = currentPriorityLevel; + currentPriorityLevel = priorityLevel; + try { + return eventHandler(); + } finally { + currentPriorityLevel = previousPriorityLevel; + } + } + function unstable_next(eventHandler) { + var priorityLevel; + switch (currentPriorityLevel) { + case ImmediatePriority: + case UserBlockingPriority: + case NormalPriority: + priorityLevel = NormalPriority; + break; + default: + priorityLevel = currentPriorityLevel; + break; + } + var previousPriorityLevel = currentPriorityLevel; + currentPriorityLevel = priorityLevel; + try { + return eventHandler(); + } finally { + currentPriorityLevel = previousPriorityLevel; + } + } + function unstable_wrapCallback(callback) { + var parentPriorityLevel = currentPriorityLevel; + return function() { + var previousPriorityLevel = currentPriorityLevel; + currentPriorityLevel = parentPriorityLevel; + try { + return callback.apply(this, arguments); + } finally { + currentPriorityLevel = previousPriorityLevel; + } + }; + } + function unstable_scheduleCallback(priorityLevel, callback, options) { + var currentTime = exports.unstable_now(); + var startTime2; + if (typeof options === "object" && options !== null) { + var delay = options.delay; + if (typeof delay === "number" && delay > 0) { + startTime2 = currentTime + delay; + } else { + startTime2 = currentTime; + } + } else { + startTime2 = currentTime; + } + var timeout; + switch (priorityLevel) { + case ImmediatePriority: + timeout = IMMEDIATE_PRIORITY_TIMEOUT; + break; + case UserBlockingPriority: + timeout = USER_BLOCKING_PRIORITY_TIMEOUT; + break; + case IdlePriority: + timeout = IDLE_PRIORITY_TIMEOUT; + break; + case LowPriority: + timeout = LOW_PRIORITY_TIMEOUT; + break; + case NormalPriority: + default: + timeout = NORMAL_PRIORITY_TIMEOUT; + break; + } + var expirationTime = startTime2 + timeout; + var newTask = { + id: taskIdCounter++, + callback, + priorityLevel, + startTime: startTime2, + expirationTime, + sortIndex: -1 + }; + if (startTime2 > currentTime) { + newTask.sortIndex = startTime2; + push(timerQueue, newTask); + if (peek(taskQueue) === null && newTask === peek(timerQueue)) { + if (isHostTimeoutScheduled) { + cancelHostTimeout(); + } else { + isHostTimeoutScheduled = true; + } + requestHostTimeout(handleTimeout, startTime2 - currentTime); + } + } else { + newTask.sortIndex = expirationTime; + push(taskQueue, newTask); + if (!isHostCallbackScheduled && !isPerformingWork) { + isHostCallbackScheduled = true; + requestHostCallback(flushWork); + } + } + return newTask; + } + function unstable_pauseExecution() { + } + function 
unstable_continueExecution() { + if (!isHostCallbackScheduled && !isPerformingWork) { + isHostCallbackScheduled = true; + requestHostCallback(flushWork); + } + } + function unstable_getFirstCallbackNode() { + return peek(taskQueue); + } + function unstable_cancelCallback(task) { + task.callback = null; + } + function unstable_getCurrentPriorityLevel() { + return currentPriorityLevel; + } + var isMessageLoopRunning = false; + var scheduledHostCallback = null; + var taskTimeoutID = -1; + var frameInterval = frameYieldMs; + var startTime = -1; + function shouldYieldToHost() { + var timeElapsed = exports.unstable_now() - startTime; + if (timeElapsed < frameInterval) { + return false; + } + return true; + } + function requestPaint() { + } + function forceFrameRate(fps) { + if (fps < 0 || fps > 125) { + console["error"]("forceFrameRate takes a positive int between 0 and 125, forcing frame rates higher than 125 fps is not supported"); + return; + } + if (fps > 0) { + frameInterval = Math.floor(1000 / fps); + } else { + frameInterval = frameYieldMs; + } + } + var performWorkUntilDeadline = function() { + if (scheduledHostCallback !== null) { + var currentTime = exports.unstable_now(); + startTime = currentTime; + var hasTimeRemaining = true; + var hasMoreWork = true; + try { + hasMoreWork = scheduledHostCallback(hasTimeRemaining, currentTime); + } finally { + if (hasMoreWork) { + schedulePerformWorkUntilDeadline(); + } else { + isMessageLoopRunning = false; + scheduledHostCallback = null; + } + } + } else { + isMessageLoopRunning = false; + } + }; + var schedulePerformWorkUntilDeadline; + if (typeof localSetImmediate === "function") { + schedulePerformWorkUntilDeadline = function() { + localSetImmediate(performWorkUntilDeadline); + }; + } else if (typeof MessageChannel !== "undefined") { + var channel = new MessageChannel; + var port = channel.port2; + channel.port1.onmessage = performWorkUntilDeadline; + schedulePerformWorkUntilDeadline = function() { + port.postMessage(null); + }; + } else { + schedulePerformWorkUntilDeadline = function() { + localSetTimeout(performWorkUntilDeadline, 0); + }; + } + function requestHostCallback(callback) { + scheduledHostCallback = callback; + if (!isMessageLoopRunning) { + isMessageLoopRunning = true; + schedulePerformWorkUntilDeadline(); + } + } + function requestHostTimeout(callback, ms) { + taskTimeoutID = localSetTimeout(function() { + callback(exports.unstable_now()); + }, ms); + } + function cancelHostTimeout() { + localClearTimeout(taskTimeoutID); + taskTimeoutID = -1; + } + var unstable_requestPaint = requestPaint; + var unstable_Profiling = null; + exports.unstable_IdlePriority = IdlePriority; + exports.unstable_ImmediatePriority = ImmediatePriority; + exports.unstable_LowPriority = LowPriority; + exports.unstable_NormalPriority = NormalPriority; + exports.unstable_Profiling = unstable_Profiling; + exports.unstable_UserBlockingPriority = UserBlockingPriority; + exports.unstable_cancelCallback = unstable_cancelCallback; + exports.unstable_continueExecution = unstable_continueExecution; + exports.unstable_forceFrameRate = forceFrameRate; + exports.unstable_getCurrentPriorityLevel = unstable_getCurrentPriorityLevel; + exports.unstable_getFirstCallbackNode = unstable_getFirstCallbackNode; + exports.unstable_next = unstable_next; + exports.unstable_pauseExecution = unstable_pauseExecution; + exports.unstable_requestPaint = unstable_requestPaint; + exports.unstable_runWithPriority = unstable_runWithPriority; + exports.unstable_scheduleCallback = 
unstable_scheduleCallback; + exports.unstable_shouldYield = shouldYieldToHost; + exports.unstable_wrapCallback = unstable_wrapCallback; + if (typeof __REACT_DEVTOOLS_GLOBAL_HOOK__ !== "undefined" && typeof __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStop === "function") { + __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStop(new Error); + } + })(); + } +}); + +// node_modules/scheduler/index.js +var require_scheduler = __commonJS((exports, module) => { + var scheduler_development = __toESM(require_scheduler_development(), 1); + if (false) { + } else { + module.exports = scheduler_development; + } +}); + +// node_modules/react-dom/cjs/react-dom.development.js +var require_react_dom_development = __commonJS((exports) => { + var React = __toESM(require_react(), 1); + var Scheduler = __toESM(require_scheduler(), 1); + if (true) { + (function() { + if (typeof __REACT_DEVTOOLS_GLOBAL_HOOK__ !== "undefined" && typeof __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStart === "function") { + __REACT_DEVTOOLS_GLOBAL_HOOK__.registerInternalModuleStart(new Error); + } + var ReactSharedInternals = React.__SECRET_INTERNALS_DO_NOT_USE_OR_YOU_WILL_BE_FIRED; + var suppressWarning = false; + function setSuppressWarning(newSuppressWarning) { + { + suppressWarning = newSuppressWarning; + } + } + function warn(format) { + { + if (!suppressWarning) { + for (var _len = arguments.length, args = new Array(_len > 1 ? _len - 1 : 0), _key = 1;_key < _len; _key++) { + args[_key - 1] = arguments[_key]; + } + printWarning("warn", format, args); + } + } + } + function error(format) { + { + if (!suppressWarning) { + for (var _len2 = arguments.length, args = new Array(_len2 > 1 ? _len2 - 1 : 0), _key2 = 1;_key2 < _len2; _key2++) { + args[_key2 - 1] = arguments[_key2]; + } + printWarning("error", format, args); + } + } + } + function printWarning(level, format, args) { + { + var ReactDebugCurrentFrame2 = ReactSharedInternals.ReactDebugCurrentFrame; + var stack = ReactDebugCurrentFrame2.getStackAddendum(); + if (stack !== "") { + format += "%s"; + args = args.concat([stack]); + } + var argsWithFormat = args.map(function(item) { + return String(item); + }); + argsWithFormat.unshift("Warning: " + format); + Function.prototype.apply.call(console[level], console, argsWithFormat); + } + } + var FunctionComponent = 0; + var ClassComponent = 1; + var IndeterminateComponent = 2; + var HostRoot = 3; + var HostPortal = 4; + var HostComponent = 5; + var HostText = 6; + var Fragment = 7; + var Mode = 8; + var ContextConsumer = 9; + var ContextProvider = 10; + var ForwardRef = 11; + var Profiler = 12; + var SuspenseComponent = 13; + var MemoComponent = 14; + var SimpleMemoComponent = 15; + var LazyComponent = 16; + var IncompleteClassComponent = 17; + var DehydratedFragment = 18; + var SuspenseListComponent = 19; + var ScopeComponent = 21; + var OffscreenComponent = 22; + var LegacyHiddenComponent = 23; + var CacheComponent = 24; + var TracingMarkerComponent = 25; + var enableClientRenderFallbackOnTextMismatch = true; + var enableNewReconciler = false; + var enableLazyContextPropagation = false; + var enableLegacyHidden = false; + var enableSuspenseAvoidThisFallback = false; + var disableCommentsAsDOMContainers = true; + var enableCustomElementPropertySupport = false; + var warnAboutStringRefs = true; + var enableSchedulingProfiler = true; + var enableProfilerTimer = true; + var enableProfilerCommitHooks = true; + var allNativeEvents = new Set; + var registrationNameDependencies = {}; + var possibleRegistrationNames = 
{}; + function registerTwoPhaseEvent(registrationName, dependencies) { + registerDirectEvent(registrationName, dependencies); + registerDirectEvent(registrationName + "Capture", dependencies); + } + function registerDirectEvent(registrationName, dependencies) { + { + if (registrationNameDependencies[registrationName]) { + error("EventRegistry: More than one plugin attempted to publish the same registration name, `%s`.", registrationName); + } + } + registrationNameDependencies[registrationName] = dependencies; + { + var lowerCasedName = registrationName.toLowerCase(); + possibleRegistrationNames[lowerCasedName] = registrationName; + if (registrationName === "onDoubleClick") { + possibleRegistrationNames.ondblclick = registrationName; + } + } + for (var i = 0;i < dependencies.length; i++) { + allNativeEvents.add(dependencies[i]); + } + } + var canUseDOM = !!(typeof window !== "undefined" && typeof window.document !== "undefined" && typeof window.document.createElement !== "undefined"); + var hasOwnProperty = Object.prototype.hasOwnProperty; + function typeName(value) { + { + var hasToStringTag = typeof Symbol === "function" && Symbol.toStringTag; + var type = hasToStringTag && value[Symbol.toStringTag] || value.constructor.name || "Object"; + return type; + } + } + function willCoercionThrow(value) { + { + try { + testStringCoercion(value); + return false; + } catch (e) { + return true; + } + } + } + function testStringCoercion(value) { + return "" + value; + } + function checkAttributeStringCoercion(value, attributeName) { + { + if (willCoercionThrow(value)) { + error("The provided `%s` attribute is an unsupported type %s. This value must be coerced to a string before before using it here.", attributeName, typeName(value)); + return testStringCoercion(value); + } + } + } + function checkKeyStringCoercion(value) { + { + if (willCoercionThrow(value)) { + error("The provided key is an unsupported type %s. This value must be coerced to a string before before using it here.", typeName(value)); + return testStringCoercion(value); + } + } + } + function checkPropStringCoercion(value, propName) { + { + if (willCoercionThrow(value)) { + error("The provided `%s` prop is an unsupported type %s. This value must be coerced to a string before before using it here.", propName, typeName(value)); + return testStringCoercion(value); + } + } + } + function checkCSSPropertyStringCoercion(value, propName) { + { + if (willCoercionThrow(value)) { + error("The provided `%s` CSS property is an unsupported type %s. This value must be coerced to a string before before using it here.", propName, typeName(value)); + return testStringCoercion(value); + } + } + } + function checkHtmlStringCoercion(value) { + { + if (willCoercionThrow(value)) { + error("The provided HTML markup uses a value of unsupported type %s. This value must be coerced to a string before before using it here.", typeName(value)); + return testStringCoercion(value); + } + } + } + function checkFormFieldValueStringCoercion(value) { + { + if (willCoercionThrow(value)) { + error("Form field values (value, checked, defaultValue, or defaultChecked props) must be strings, not %s. 
This value must be coerced to a string before before using it here.", typeName(value)); + return testStringCoercion(value); + } + } + } + var RESERVED = 0; + var STRING = 1; + var BOOLEANISH_STRING = 2; + var BOOLEAN = 3; + var OVERLOADED_BOOLEAN = 4; + var NUMERIC = 5; + var POSITIVE_NUMERIC = 6; + var ATTRIBUTE_NAME_START_CHAR = ":A-Z_a-z\\u00C0-\\u00D6\\u00D8-\\u00F6\\u00F8-\\u02FF\\u0370-\\u037D\\u037F-\\u1FFF\\u200C-\\u200D\\u2070-\\u218F\\u2C00-\\u2FEF\\u3001-\\uD7FF\\uF900-\\uFDCF\\uFDF0-\\uFFFD"; + var ATTRIBUTE_NAME_CHAR = ATTRIBUTE_NAME_START_CHAR + "\\-.0-9\\u00B7\\u0300-\\u036F\\u203F-\\u2040"; + var VALID_ATTRIBUTE_NAME_REGEX = new RegExp("^[" + ATTRIBUTE_NAME_START_CHAR + "][" + ATTRIBUTE_NAME_CHAR + "]*$"); + var illegalAttributeNameCache = {}; + var validatedAttributeNameCache = {}; + function isAttributeNameSafe(attributeName) { + if (hasOwnProperty.call(validatedAttributeNameCache, attributeName)) { + return true; + } + if (hasOwnProperty.call(illegalAttributeNameCache, attributeName)) { + return false; + } + if (VALID_ATTRIBUTE_NAME_REGEX.test(attributeName)) { + validatedAttributeNameCache[attributeName] = true; + return true; + } + illegalAttributeNameCache[attributeName] = true; + { + error("Invalid attribute name: `%s`", attributeName); + } + return false; + } + function shouldIgnoreAttribute(name, propertyInfo, isCustomComponentTag) { + if (propertyInfo !== null) { + return propertyInfo.type === RESERVED; + } + if (isCustomComponentTag) { + return false; + } + if (name.length > 2 && (name[0] === "o" || name[0] === "O") && (name[1] === "n" || name[1] === "N")) { + return true; + } + return false; + } + function shouldRemoveAttributeWithWarning(name, value, propertyInfo, isCustomComponentTag) { + if (propertyInfo !== null && propertyInfo.type === RESERVED) { + return false; + } + switch (typeof value) { + case "function": + case "symbol": + return true; + case "boolean": { + if (isCustomComponentTag) { + return false; + } + if (propertyInfo !== null) { + return !propertyInfo.acceptsBooleans; + } else { + var prefix2 = name.toLowerCase().slice(0, 5); + return prefix2 !== "data-" && prefix2 !== "aria-"; + } + } + default: + return false; + } + } + function shouldRemoveAttribute(name, value, propertyInfo, isCustomComponentTag) { + if (value === null || typeof value === "undefined") { + return true; + } + if (shouldRemoveAttributeWithWarning(name, value, propertyInfo, isCustomComponentTag)) { + return true; + } + if (isCustomComponentTag) { + return false; + } + if (propertyInfo !== null) { + switch (propertyInfo.type) { + case BOOLEAN: + return !value; + case OVERLOADED_BOOLEAN: + return value === false; + case NUMERIC: + return isNaN(value); + case POSITIVE_NUMERIC: + return isNaN(value) || value < 1; + } + } + return false; + } + function getPropertyInfo(name) { + return properties.hasOwnProperty(name) ? 
properties[name] : null; + } + function PropertyInfoRecord(name, type, mustUseProperty, attributeName, attributeNamespace, sanitizeURL2, removeEmptyString) { + this.acceptsBooleans = type === BOOLEANISH_STRING || type === BOOLEAN || type === OVERLOADED_BOOLEAN; + this.attributeName = attributeName; + this.attributeNamespace = attributeNamespace; + this.mustUseProperty = mustUseProperty; + this.propertyName = name; + this.type = type; + this.sanitizeURL = sanitizeURL2; + this.removeEmptyString = removeEmptyString; + } + var properties = {}; + var reservedProps = [ + "children", + "dangerouslySetInnerHTML", + "defaultValue", + "defaultChecked", + "innerHTML", + "suppressContentEditableWarning", + "suppressHydrationWarning", + "style" + ]; + reservedProps.forEach(function(name) { + properties[name] = new PropertyInfoRecord(name, RESERVED, false, name, null, false, false); + }); + [["acceptCharset", "accept-charset"], ["className", "class"], ["htmlFor", "for"], ["httpEquiv", "http-equiv"]].forEach(function(_ref) { + var name = _ref[0], attributeName = _ref[1]; + properties[name] = new PropertyInfoRecord(name, STRING, false, attributeName, null, false, false); + }); + ["contentEditable", "draggable", "spellCheck", "value"].forEach(function(name) { + properties[name] = new PropertyInfoRecord(name, BOOLEANISH_STRING, false, name.toLowerCase(), null, false, false); + }); + ["autoReverse", "externalResourcesRequired", "focusable", "preserveAlpha"].forEach(function(name) { + properties[name] = new PropertyInfoRecord(name, BOOLEANISH_STRING, false, name, null, false, false); + }); + [ + "allowFullScreen", + "async", + "autoFocus", + "autoPlay", + "controls", + "default", + "defer", + "disabled", + "disablePictureInPicture", + "disableRemotePlayback", + "formNoValidate", + "hidden", + "loop", + "noModule", + "noValidate", + "open", + "playsInline", + "readOnly", + "required", + "reversed", + "scoped", + "seamless", + "itemScope" + ].forEach(function(name) { + properties[name] = new PropertyInfoRecord(name, BOOLEAN, false, name.toLowerCase(), null, false, false); + }); + [ + "checked", + "multiple", + "muted", + "selected" + ].forEach(function(name) { + properties[name] = new PropertyInfoRecord(name, BOOLEAN, true, name, null, false, false); + }); + [ + "capture", + "download" + ].forEach(function(name) { + properties[name] = new PropertyInfoRecord(name, OVERLOADED_BOOLEAN, false, name, null, false, false); + }); + [ + "cols", + "rows", + "size", + "span" + ].forEach(function(name) { + properties[name] = new PropertyInfoRecord(name, POSITIVE_NUMERIC, false, name, null, false, false); + }); + ["rowSpan", "start"].forEach(function(name) { + properties[name] = new PropertyInfoRecord(name, NUMERIC, false, name.toLowerCase(), null, false, false); + }); + var CAMELIZE = /[\-\:]([a-z])/g; + var capitalize = function(token) { + return token[1].toUpperCase(); + }; + [ + "accent-height", + "alignment-baseline", + "arabic-form", + "baseline-shift", + "cap-height", + "clip-path", + "clip-rule", + "color-interpolation", + "color-interpolation-filters", + "color-profile", + "color-rendering", + "dominant-baseline", + "enable-background", + "fill-opacity", + "fill-rule", + "flood-color", + "flood-opacity", + "font-family", + "font-size", + "font-size-adjust", + "font-stretch", + "font-style", + "font-variant", + "font-weight", + "glyph-name", + "glyph-orientation-horizontal", + "glyph-orientation-vertical", + "horiz-adv-x", + "horiz-origin-x", + "image-rendering", + "letter-spacing", + "lighting-color", + 
"marker-end", + "marker-mid", + "marker-start", + "overline-position", + "overline-thickness", + "paint-order", + "panose-1", + "pointer-events", + "rendering-intent", + "shape-rendering", + "stop-color", + "stop-opacity", + "strikethrough-position", + "strikethrough-thickness", + "stroke-dasharray", + "stroke-dashoffset", + "stroke-linecap", + "stroke-linejoin", + "stroke-miterlimit", + "stroke-opacity", + "stroke-width", + "text-anchor", + "text-decoration", + "text-rendering", + "underline-position", + "underline-thickness", + "unicode-bidi", + "unicode-range", + "units-per-em", + "v-alphabetic", + "v-hanging", + "v-ideographic", + "v-mathematical", + "vector-effect", + "vert-adv-y", + "vert-origin-x", + "vert-origin-y", + "word-spacing", + "writing-mode", + "xmlns:xlink", + "x-height" + ].forEach(function(attributeName) { + var name = attributeName.replace(CAMELIZE, capitalize); + properties[name] = new PropertyInfoRecord(name, STRING, false, attributeName, null, false, false); + }); + [ + "xlink:actuate", + "xlink:arcrole", + "xlink:role", + "xlink:show", + "xlink:title", + "xlink:type" + ].forEach(function(attributeName) { + var name = attributeName.replace(CAMELIZE, capitalize); + properties[name] = new PropertyInfoRecord(name, STRING, false, attributeName, "http://www.w3.org/1999/xlink", false, false); + }); + [ + "xml:base", + "xml:lang", + "xml:space" + ].forEach(function(attributeName) { + var name = attributeName.replace(CAMELIZE, capitalize); + properties[name] = new PropertyInfoRecord(name, STRING, false, attributeName, "http://www.w3.org/XML/1998/namespace", false, false); + }); + ["tabIndex", "crossOrigin"].forEach(function(attributeName) { + properties[attributeName] = new PropertyInfoRecord(attributeName, STRING, false, attributeName.toLowerCase(), null, false, false); + }); + var xlinkHref = "xlinkHref"; + properties[xlinkHref] = new PropertyInfoRecord("xlinkHref", STRING, false, "xlink:href", "http://www.w3.org/1999/xlink", true, false); + ["src", "href", "action", "formAction"].forEach(function(attributeName) { + properties[attributeName] = new PropertyInfoRecord(attributeName, STRING, false, attributeName.toLowerCase(), null, true, true); + }); + var isJavaScriptProtocol = /^[\u0000-\u001F ]*j[\r\n\t]*a[\r\n\t]*v[\r\n\t]*a[\r\n\t]*s[\r\n\t]*c[\r\n\t]*r[\r\n\t]*i[\r\n\t]*p[\r\n\t]*t[\r\n\t]*\:/i; + var didWarn = false; + function sanitizeURL(url) { + { + if (!didWarn && isJavaScriptProtocol.test(url)) { + didWarn = true; + error("A future version of React will block javascript: URLs as a security precaution. Use event handlers instead if you can. If you need to generate unsafe HTML try using dangerouslySetInnerHTML instead. 
React was passed %s.", JSON.stringify(url)); + } + } + } + function getValueForProperty(node, name, expected, propertyInfo) { + { + if (propertyInfo.mustUseProperty) { + var propertyName = propertyInfo.propertyName; + return node[propertyName]; + } else { + { + checkAttributeStringCoercion(expected, name); + } + if (propertyInfo.sanitizeURL) { + sanitizeURL("" + expected); + } + var attributeName = propertyInfo.attributeName; + var stringValue = null; + if (propertyInfo.type === OVERLOADED_BOOLEAN) { + if (node.hasAttribute(attributeName)) { + var value = node.getAttribute(attributeName); + if (value === "") { + return true; + } + if (shouldRemoveAttribute(name, expected, propertyInfo, false)) { + return value; + } + if (value === "" + expected) { + return expected; + } + return value; + } + } else if (node.hasAttribute(attributeName)) { + if (shouldRemoveAttribute(name, expected, propertyInfo, false)) { + return node.getAttribute(attributeName); + } + if (propertyInfo.type === BOOLEAN) { + return expected; + } + stringValue = node.getAttribute(attributeName); + } + if (shouldRemoveAttribute(name, expected, propertyInfo, false)) { + return stringValue === null ? expected : stringValue; + } else if (stringValue === "" + expected) { + return expected; + } else { + return stringValue; + } + } + } + } + function getValueForAttribute(node, name, expected, isCustomComponentTag) { + { + if (!isAttributeNameSafe(name)) { + return; + } + if (!node.hasAttribute(name)) { + return expected === undefined ? undefined : null; + } + var value = node.getAttribute(name); + { + checkAttributeStringCoercion(expected, name); + } + if (value === "" + expected) { + return expected; + } + return value; + } + } + function setValueForProperty(node, name, value, isCustomComponentTag) { + var propertyInfo = getPropertyInfo(name); + if (shouldIgnoreAttribute(name, propertyInfo, isCustomComponentTag)) { + return; + } + if (shouldRemoveAttribute(name, value, propertyInfo, isCustomComponentTag)) { + value = null; + } + if (isCustomComponentTag || propertyInfo === null) { + if (isAttributeNameSafe(name)) { + var _attributeName = name; + if (value === null) { + node.removeAttribute(_attributeName); + } else { + { + checkAttributeStringCoercion(value, name); + } + node.setAttribute(_attributeName, "" + value); + } + } + return; + } + var mustUseProperty = propertyInfo.mustUseProperty; + if (mustUseProperty) { + var propertyName = propertyInfo.propertyName; + if (value === null) { + var type = propertyInfo.type; + node[propertyName] = type === BOOLEAN ? 
false : ""; + } else { + node[propertyName] = value; + } + return; + } + var { attributeName, attributeNamespace } = propertyInfo; + if (value === null) { + node.removeAttribute(attributeName); + } else { + var _type = propertyInfo.type; + var attributeValue; + if (_type === BOOLEAN || _type === OVERLOADED_BOOLEAN && value === true) { + attributeValue = ""; + } else { + { + { + checkAttributeStringCoercion(value, attributeName); + } + attributeValue = "" + value; + } + if (propertyInfo.sanitizeURL) { + sanitizeURL(attributeValue.toString()); + } + } + if (attributeNamespace) { + node.setAttributeNS(attributeNamespace, attributeName, attributeValue); + } else { + node.setAttribute(attributeName, attributeValue); + } + } + } + var REACT_ELEMENT_TYPE = Symbol.for("react.element"); + var REACT_PORTAL_TYPE = Symbol.for("react.portal"); + var REACT_FRAGMENT_TYPE = Symbol.for("react.fragment"); + var REACT_STRICT_MODE_TYPE = Symbol.for("react.strict_mode"); + var REACT_PROFILER_TYPE = Symbol.for("react.profiler"); + var REACT_PROVIDER_TYPE = Symbol.for("react.provider"); + var REACT_CONTEXT_TYPE = Symbol.for("react.context"); + var REACT_FORWARD_REF_TYPE = Symbol.for("react.forward_ref"); + var REACT_SUSPENSE_TYPE = Symbol.for("react.suspense"); + var REACT_SUSPENSE_LIST_TYPE = Symbol.for("react.suspense_list"); + var REACT_MEMO_TYPE = Symbol.for("react.memo"); + var REACT_LAZY_TYPE = Symbol.for("react.lazy"); + var REACT_SCOPE_TYPE = Symbol.for("react.scope"); + var REACT_DEBUG_TRACING_MODE_TYPE = Symbol.for("react.debug_trace_mode"); + var REACT_OFFSCREEN_TYPE = Symbol.for("react.offscreen"); + var REACT_LEGACY_HIDDEN_TYPE = Symbol.for("react.legacy_hidden"); + var REACT_CACHE_TYPE = Symbol.for("react.cache"); + var REACT_TRACING_MARKER_TYPE = Symbol.for("react.tracing_marker"); + var MAYBE_ITERATOR_SYMBOL = Symbol.iterator; + var FAUX_ITERATOR_SYMBOL = "@@iterator"; + function getIteratorFn(maybeIterable) { + if (maybeIterable === null || typeof maybeIterable !== "object") { + return null; + } + var maybeIterator = MAYBE_ITERATOR_SYMBOL && maybeIterable[MAYBE_ITERATOR_SYMBOL] || maybeIterable[FAUX_ITERATOR_SYMBOL]; + if (typeof maybeIterator === "function") { + return maybeIterator; + } + return null; + } + var assign = Object.assign; + var disabledDepth = 0; + var prevLog; + var prevInfo; + var prevWarn; + var prevError; + var prevGroup; + var prevGroupCollapsed; + var prevGroupEnd; + function disabledLog() { + } + disabledLog.__reactDisabledLog = true; + function disableLogs() { + { + if (disabledDepth === 0) { + prevLog = console.log; + prevInfo = console.info; + prevWarn = console.warn; + prevError = console.error; + prevGroup = console.group; + prevGroupCollapsed = console.groupCollapsed; + prevGroupEnd = console.groupEnd; + var props = { + configurable: true, + enumerable: true, + value: disabledLog, + writable: true + }; + Object.defineProperties(console, { + info: props, + log: props, + warn: props, + error: props, + group: props, + groupCollapsed: props, + groupEnd: props + }); + } + disabledDepth++; + } + } + function reenableLogs() { + { + disabledDepth--; + if (disabledDepth === 0) { + var props = { + configurable: true, + enumerable: true, + writable: true + }; + Object.defineProperties(console, { + log: assign({}, props, { + value: prevLog + }), + info: assign({}, props, { + value: prevInfo + }), + warn: assign({}, props, { + value: prevWarn + }), + error: assign({}, props, { + value: prevError + }), + group: assign({}, props, { + value: prevGroup + }), + groupCollapsed: 
assign({}, props, { + value: prevGroupCollapsed + }), + groupEnd: assign({}, props, { + value: prevGroupEnd + }) + }); + } + if (disabledDepth < 0) { + error("disabledDepth fell below zero. This is a bug in React. Please file an issue."); + } + } + } + var ReactCurrentDispatcher = ReactSharedInternals.ReactCurrentDispatcher; + var prefix; + function describeBuiltInComponentFrame(name, source, ownerFn) { + { + if (prefix === undefined) { + try { + throw Error(); + } catch (x) { + var match = x.stack.trim().match(/\n( *(at )?)/); + prefix = match && match[1] || ""; + } + } + return "\n" + prefix + name; + } + } + var reentry = false; + var componentFrameCache; + { + var PossiblyWeakMap = typeof WeakMap === "function" ? WeakMap : Map; + componentFrameCache = new PossiblyWeakMap; + } + function describeNativeComponentFrame(fn, construct) { + if (!fn || reentry) { + return ""; + } + { + var frame = componentFrameCache.get(fn); + if (frame !== undefined) { + return frame; + } + } + var control; + reentry = true; + var previousPrepareStackTrace = Error.prepareStackTrace; + Error.prepareStackTrace = undefined; + var previousDispatcher; + { + previousDispatcher = ReactCurrentDispatcher.current; + ReactCurrentDispatcher.current = null; + disableLogs(); + } + try { + if (construct) { + var Fake = function() { + throw Error(); + }; + Object.defineProperty(Fake.prototype, "props", { + set: function() { + throw Error(); + } + }); + if (typeof Reflect === "object" && Reflect.construct) { + try { + Reflect.construct(Fake, []); + } catch (x) { + control = x; + } + Reflect.construct(fn, [], Fake); + } else { + try { + Fake.call(); + } catch (x) { + control = x; + } + fn.call(Fake.prototype); + } + } else { + try { + throw Error(); + } catch (x) { + control = x; + } + fn(); + } + } catch (sample) { + if (sample && control && typeof sample.stack === "string") { + var sampleLines = sample.stack.split("\n"); + var controlLines = control.stack.split("\n"); + var s = sampleLines.length - 1; + var c = controlLines.length - 1; + while (s >= 1 && c >= 0 && sampleLines[s] !== controlLines[c]) { + c--; + } + for (;s >= 1 && c >= 0; s--, c--) { + if (sampleLines[s] !== controlLines[c]) { + if (s !== 1 || c !== 1) { + do { + s--; + c--; + if (c < 0 || sampleLines[s] !== controlLines[c]) { + var _frame = "\n" + sampleLines[s].replace(" at new ", " at "); + if (fn.displayName && _frame.includes("<anonymous>")) { + _frame = _frame.replace("<anonymous>", fn.displayName); + } + { + if (typeof fn === "function") { + componentFrameCache.set(fn, _frame); + } + } + return _frame; + } + } while (s >= 1 && c >= 0); + } + break; + } + } + } + } finally { + reentry = false; + { + ReactCurrentDispatcher.current = previousDispatcher; + reenableLogs(); + } + Error.prepareStackTrace = previousPrepareStackTrace; + } + var name = fn ? fn.displayName || fn.name : ""; + var syntheticFrame = name ?
describeBuiltInComponentFrame(name) : ""; + { + if (typeof fn === "function") { + componentFrameCache.set(fn, syntheticFrame); + } + } + return syntheticFrame; + } + function describeClassComponentFrame(ctor, source, ownerFn) { + { + return describeNativeComponentFrame(ctor, true); + } + } + function describeFunctionComponentFrame(fn, source, ownerFn) { + { + return describeNativeComponentFrame(fn, false); + } + } + function shouldConstruct(Component) { + var prototype = Component.prototype; + return !!(prototype && prototype.isReactComponent); + } + function describeUnknownElementTypeFrameInDEV(type, source, ownerFn) { + if (type == null) { + return ""; + } + if (typeof type === "function") { + { + return describeNativeComponentFrame(type, shouldConstruct(type)); + } + } + if (typeof type === "string") { + return describeBuiltInComponentFrame(type); + } + switch (type) { + case REACT_SUSPENSE_TYPE: + return describeBuiltInComponentFrame("Suspense"); + case REACT_SUSPENSE_LIST_TYPE: + return describeBuiltInComponentFrame("SuspenseList"); + } + if (typeof type === "object") { + switch (type.$$typeof) { + case REACT_FORWARD_REF_TYPE: + return describeFunctionComponentFrame(type.render); + case REACT_MEMO_TYPE: + return describeUnknownElementTypeFrameInDEV(type.type, source, ownerFn); + case REACT_LAZY_TYPE: { + var lazyComponent = type; + var payload = lazyComponent._payload; + var init = lazyComponent._init; + try { + return describeUnknownElementTypeFrameInDEV(init(payload), source, ownerFn); + } catch (x) { + } + } + } + } + return ""; + } + function describeFiber(fiber) { + var owner = fiber._debugOwner ? fiber._debugOwner.type : null; + var source = fiber._debugSource; + switch (fiber.tag) { + case HostComponent: + return describeBuiltInComponentFrame(fiber.type); + case LazyComponent: + return describeBuiltInComponentFrame("Lazy"); + case SuspenseComponent: + return describeBuiltInComponentFrame("Suspense"); + case SuspenseListComponent: + return describeBuiltInComponentFrame("SuspenseList"); + case FunctionComponent: + case IndeterminateComponent: + case SimpleMemoComponent: + return describeFunctionComponentFrame(fiber.type); + case ForwardRef: + return describeFunctionComponentFrame(fiber.type.render); + case ClassComponent: + return describeClassComponentFrame(fiber.type); + default: + return ""; + } + } + function getStackByFiberInDevAndProd(workInProgress2) { + try { + var info = ""; + var node = workInProgress2; + do { + info += describeFiber(node); + node = node.return; + } while (node); + return info; + } catch (x) { + return "\nError generating stack: " + x.message + "\n" + x.stack; + } + } + function getWrappedName(outerType, innerType, wrapperName) { + var displayName = outerType.displayName; + if (displayName) { + return displayName; + } + var functionName = innerType.displayName || innerType.name || ""; + return functionName !== "" ? wrapperName + "(" + functionName + ")" : wrapperName; + } + function getContextName(type) { + return type.displayName || "Context"; + } + function getComponentNameFromType(type) { + if (type == null) { + return null; + } + { + if (typeof type.tag === "number") { + error("Received an unexpected object in getComponentNameFromType(). This is likely a bug in React. 
Please file an issue."); + } + } + if (typeof type === "function") { + return type.displayName || type.name || null; + } + if (typeof type === "string") { + return type; + } + switch (type) { + case REACT_FRAGMENT_TYPE: + return "Fragment"; + case REACT_PORTAL_TYPE: + return "Portal"; + case REACT_PROFILER_TYPE: + return "Profiler"; + case REACT_STRICT_MODE_TYPE: + return "StrictMode"; + case REACT_SUSPENSE_TYPE: + return "Suspense"; + case REACT_SUSPENSE_LIST_TYPE: + return "SuspenseList"; + } + if (typeof type === "object") { + switch (type.$$typeof) { + case REACT_CONTEXT_TYPE: + var context = type; + return getContextName(context) + ".Consumer"; + case REACT_PROVIDER_TYPE: + var provider = type; + return getContextName(provider._context) + ".Provider"; + case REACT_FORWARD_REF_TYPE: + return getWrappedName(type, type.render, "ForwardRef"); + case REACT_MEMO_TYPE: + var outerName = type.displayName || null; + if (outerName !== null) { + return outerName; + } + return getComponentNameFromType(type.type) || "Memo"; + case REACT_LAZY_TYPE: { + var lazyComponent = type; + var payload = lazyComponent._payload; + var init = lazyComponent._init; + try { + return getComponentNameFromType(init(payload)); + } catch (x) { + return null; + } + } + } + } + return null; + } + function getWrappedName$1(outerType, innerType, wrapperName) { + var functionName = innerType.displayName || innerType.name || ""; + return outerType.displayName || (functionName !== "" ? wrapperName + "(" + functionName + ")" : wrapperName); + } + function getContextName$1(type) { + return type.displayName || "Context"; + } + function getComponentNameFromFiber(fiber) { + var { tag, type } = fiber; + switch (tag) { + case CacheComponent: + return "Cache"; + case ContextConsumer: + var context = type; + return getContextName$1(context) + ".Consumer"; + case ContextProvider: + var provider = type; + return getContextName$1(provider._context) + ".Provider"; + case DehydratedFragment: + return "DehydratedFragment"; + case ForwardRef: + return getWrappedName$1(type, type.render, "ForwardRef"); + case Fragment: + return "Fragment"; + case HostComponent: + return type; + case HostPortal: + return "Portal"; + case HostRoot: + return "Root"; + case HostText: + return "Text"; + case LazyComponent: + return getComponentNameFromType(type); + case Mode: + if (type === REACT_STRICT_MODE_TYPE) { + return "StrictMode"; + } + return "Mode"; + case OffscreenComponent: + return "Offscreen"; + case Profiler: + return "Profiler"; + case ScopeComponent: + return "Scope"; + case SuspenseComponent: + return "Suspense"; + case SuspenseListComponent: + return "SuspenseList"; + case TracingMarkerComponent: + return "TracingMarker"; + case ClassComponent: + case FunctionComponent: + case IncompleteClassComponent: + case IndeterminateComponent: + case MemoComponent: + case SimpleMemoComponent: + if (typeof type === "function") { + return type.displayName || type.name || null; + } + if (typeof type === "string") { + return type; + } + break; + } + return null; + } + var ReactDebugCurrentFrame = ReactSharedInternals.ReactDebugCurrentFrame; + var current = null; + var isRendering = false; + function getCurrentFiberOwnerNameInDevOrNull() { + { + if (current === null) { + return null; + } + var owner = current._debugOwner; + if (owner !== null && typeof owner !== "undefined") { + return getComponentNameFromFiber(owner); + } + } + return null; + } + function getCurrentFiberStackInDev() { + { + if (current === null) { + return ""; + } + return 
getStackByFiberInDevAndProd(current); + } + } + function resetCurrentFiber() { + { + ReactDebugCurrentFrame.getCurrentStack = null; + current = null; + isRendering = false; + } + } + function setCurrentFiber(fiber) { + { + ReactDebugCurrentFrame.getCurrentStack = fiber === null ? null : getCurrentFiberStackInDev; + current = fiber; + isRendering = false; + } + } + function getCurrentFiber() { + { + return current; + } + } + function setIsRendering(rendering) { + { + isRendering = rendering; + } + } + function toString(value) { + return "" + value; + } + function getToStringValue(value) { + switch (typeof value) { + case "boolean": + case "number": + case "string": + case "undefined": + return value; + case "object": + { + checkFormFieldValueStringCoercion(value); + } + return value; + default: + return ""; + } + } + var hasReadOnlyValue = { + button: true, + checkbox: true, + image: true, + hidden: true, + radio: true, + reset: true, + submit: true + }; + function checkControlledValueProps(tagName, props) { + { + if (!(hasReadOnlyValue[props.type] || props.onChange || props.onInput || props.readOnly || props.disabled || props.value == null)) { + error("You provided a `value` prop to a form field without an `onChange` handler. This will render a read-only field. If the field should be mutable use `defaultValue`. Otherwise, set either `onChange` or `readOnly`."); + } + if (!(props.onChange || props.readOnly || props.disabled || props.checked == null)) { + error("You provided a `checked` prop to a form field without an `onChange` handler. This will render a read-only field. If the field should be mutable use `defaultChecked`. Otherwise, set either `onChange` or `readOnly`."); + } + } + } + function isCheckable(elem) { + var type = elem.type; + var nodeName = elem.nodeName; + return nodeName && nodeName.toLowerCase() === "input" && (type === "checkbox" || type === "radio"); + } + function getTracker(node) { + return node._valueTracker; + } + function detachTracker(node) { + node._valueTracker = null; + } + function getValueFromNode(node) { + var value = ""; + if (!node) { + return value; + } + if (isCheckable(node)) { + value = node.checked ? "true" : "false"; + } else { + value = node.value; + } + return value; + } + function trackValueOnNode(node) { + var valueField = isCheckable(node) ? 
"checked" : "value"; + var descriptor = Object.getOwnPropertyDescriptor(node.constructor.prototype, valueField); + { + checkFormFieldValueStringCoercion(node[valueField]); + } + var currentValue = "" + node[valueField]; + if (node.hasOwnProperty(valueField) || typeof descriptor === "undefined" || typeof descriptor.get !== "function" || typeof descriptor.set !== "function") { + return; + } + var { get: get2, set: set2 } = descriptor; + Object.defineProperty(node, valueField, { + configurable: true, + get: function() { + return get2.call(this); + }, + set: function(value) { + { + checkFormFieldValueStringCoercion(value); + } + currentValue = "" + value; + set2.call(this, value); + } + }); + Object.defineProperty(node, valueField, { + enumerable: descriptor.enumerable + }); + var tracker = { + getValue: function() { + return currentValue; + }, + setValue: function(value) { + { + checkFormFieldValueStringCoercion(value); + } + currentValue = "" + value; + }, + stopTracking: function() { + detachTracker(node); + delete node[valueField]; + } + }; + return tracker; + } + function track(node) { + if (getTracker(node)) { + return; + } + node._valueTracker = trackValueOnNode(node); + } + function updateValueIfChanged(node) { + if (!node) { + return false; + } + var tracker = getTracker(node); + if (!tracker) { + return true; + } + var lastValue = tracker.getValue(); + var nextValue = getValueFromNode(node); + if (nextValue !== lastValue) { + tracker.setValue(nextValue); + return true; + } + return false; + } + function getActiveElement(doc) { + doc = doc || (typeof document !== "undefined" ? document : undefined); + if (typeof doc === "undefined") { + return null; + } + try { + return doc.activeElement || doc.body; + } catch (e) { + return doc.body; + } + } + var didWarnValueDefaultValue = false; + var didWarnCheckedDefaultChecked = false; + var didWarnControlledToUncontrolled = false; + var didWarnUncontrolledToControlled = false; + function isControlled(props) { + var usesChecked = props.type === "checkbox" || props.type === "radio"; + return usesChecked ? props.checked != null : props.value != null; + } + function getHostProps(element, props) { + var node = element; + var checked = props.checked; + var hostProps = assign({}, props, { + defaultChecked: undefined, + defaultValue: undefined, + value: undefined, + checked: checked != null ? checked : node._wrapperState.initialChecked + }); + return hostProps; + } + function initWrapperState(element, props) { + { + checkControlledValueProps("input", props); + if (props.checked !== undefined && props.defaultChecked !== undefined && !didWarnCheckedDefaultChecked) { + error("%s contains an input of type %s with both checked and defaultChecked props. Input elements must be either controlled or uncontrolled (specify either the checked prop, or the defaultChecked prop, but not both). Decide between using a controlled or uncontrolled input element and remove one of these props. More info: https://reactjs.org/link/controlled-components", getCurrentFiberOwnerNameInDevOrNull() || "A component", props.type); + didWarnCheckedDefaultChecked = true; + } + if (props.value !== undefined && props.defaultValue !== undefined && !didWarnValueDefaultValue) { + error("%s contains an input of type %s with both value and defaultValue props. Input elements must be either controlled or uncontrolled (specify either the value prop, or the defaultValue prop, but not both). Decide between using a controlled or uncontrolled input element and remove one of these props. 
More info: https://reactjs.org/link/controlled-components", getCurrentFiberOwnerNameInDevOrNull() || "A component", props.type); + didWarnValueDefaultValue = true; + } + } + var node = element; + var defaultValue = props.defaultValue == null ? "" : props.defaultValue; + node._wrapperState = { + initialChecked: props.checked != null ? props.checked : props.defaultChecked, + initialValue: getToStringValue(props.value != null ? props.value : defaultValue), + controlled: isControlled(props) + }; + } + function updateChecked(element, props) { + var node = element; + var checked = props.checked; + if (checked != null) { + setValueForProperty(node, "checked", checked, false); + } + } + function updateWrapper(element, props) { + var node = element; + { + var controlled = isControlled(props); + if (!node._wrapperState.controlled && controlled && !didWarnUncontrolledToControlled) { + error("A component is changing an uncontrolled input to be controlled. This is likely caused by the value changing from undefined to a defined value, which should not happen. Decide between using a controlled or uncontrolled input element for the lifetime of the component. More info: https://reactjs.org/link/controlled-components"); + didWarnUncontrolledToControlled = true; + } + if (node._wrapperState.controlled && !controlled && !didWarnControlledToUncontrolled) { + error("A component is changing a controlled input to be uncontrolled. This is likely caused by the value changing from a defined to undefined, which should not happen. Decide between using a controlled or uncontrolled input element for the lifetime of the component. More info: https://reactjs.org/link/controlled-components"); + didWarnControlledToUncontrolled = true; + } + } + updateChecked(element, props); + var value = getToStringValue(props.value); + var type = props.type; + if (value != null) { + if (type === "number") { + if (value === 0 && node.value === "" || node.value != value) { + node.value = toString(value); + } + } else if (node.value !== toString(value)) { + node.value = toString(value); + } + } else if (type === "submit" || type === "reset") { + node.removeAttribute("value"); + return; + } + { + if (props.hasOwnProperty("value")) { + setDefaultValue(node, props.type, value); + } else if (props.hasOwnProperty("defaultValue")) { + setDefaultValue(node, props.type, getToStringValue(props.defaultValue)); + } + } + { + if (props.checked == null && props.defaultChecked != null) { + node.defaultChecked = !!props.defaultChecked; + } + } + } + function postMountWrapper(element, props, isHydrating2) { + var node = element; + if (props.hasOwnProperty("value") || props.hasOwnProperty("defaultValue")) { + var type = props.type; + var isButton = type === "submit" || type === "reset"; + if (isButton && (props.value === undefined || props.value === null)) { + return; + } + var initialValue = toString(node._wrapperState.initialValue); + if (!isHydrating2) { + { + if (initialValue !== node.value) { + node.value = initialValue; + } + } + } + { + node.defaultValue = initialValue; + } + } + var name = node.name; + if (name !== "") { + node.name = ""; + } + { + node.defaultChecked = !node.defaultChecked; + node.defaultChecked = !!node._wrapperState.initialChecked; + } + if (name !== "") { + node.name = name; + } + } + function restoreControlledState(element, props) { + var node = element; + updateWrapper(node, props); + updateNamedCousins(node, props); + } + function updateNamedCousins(rootNode, props) { + var name = props.name; + if (props.type === "radio" && 
name != null) { + var queryRoot = rootNode; + while (queryRoot.parentNode) { + queryRoot = queryRoot.parentNode; + } + { + checkAttributeStringCoercion(name, "name"); + } + var group = queryRoot.querySelectorAll("input[name=" + JSON.stringify("" + name) + '][type="radio"]'); + for (var i = 0;i < group.length; i++) { + var otherNode = group[i]; + if (otherNode === rootNode || otherNode.form !== rootNode.form) { + continue; + } + var otherProps = getFiberCurrentPropsFromNode(otherNode); + if (!otherProps) { + throw new Error("ReactDOMInput: Mixing React and non-React radio inputs with the same `name` is not supported."); + } + updateValueIfChanged(otherNode); + updateWrapper(otherNode, otherProps); + } + } + } + function setDefaultValue(node, type, value) { + if (type !== "number" || getActiveElement(node.ownerDocument) !== node) { + if (value == null) { + node.defaultValue = toString(node._wrapperState.initialValue); + } else if (node.defaultValue !== toString(value)) { + node.defaultValue = toString(value); + } + } + } + var didWarnSelectedSetOnOption = false; + var didWarnInvalidChild = false; + var didWarnInvalidInnerHTML = false; + function validateProps(element, props) { + { + if (props.value == null) { + if (typeof props.children === "object" && props.children !== null) { + React.Children.forEach(props.children, function(child) { + if (child == null) { + return; + } + if (typeof child === "string" || typeof child === "number") { + return; + } + if (!didWarnInvalidChild) { + didWarnInvalidChild = true; + error("Cannot infer the option value of complex children. Pass a `value` prop or use a plain string as children to