jbilcke-hf committed d69879c
0 Parent(s):

initial commit

Files changed (50)
  1. .dockerignore +17 -0
  2. .gitattributes +13 -0
  3. .gitignore +31 -0
  4. Dockerfile +67 -0
  5. LICENSE +56 -0
  6. README.md +119 -0
  7. app.py +264 -0
  8. build.sh +3 -0
  9. client/.gitignore +175 -0
  10. client/README.md +13 -0
  11. client/bun.lockb +0 -0
  12. client/package.json +35 -0
  13. client/src/app.tsx +190 -0
  14. client/src/components/DoubleCard.tsx +18 -0
  15. client/src/components/PoweredBy.tsx +17 -0
  16. client/src/components/Spinner.tsx +7 -0
  17. client/src/components/Title.tsx +8 -0
  18. client/src/components/ui/alert.tsx +59 -0
  19. client/src/hooks/landmarks.ts +520 -0
  20. client/src/hooks/useFaceLandmarkDetection.tsx +632 -0
  21. client/src/hooks/useFacePokeAPI.ts +44 -0
  22. client/src/hooks/useMainStore.ts +58 -0
  23. client/src/index.tsx +6 -0
  24. client/src/layout.tsx +14 -0
  25. client/src/lib/circularBuffer.ts +31 -0
  26. client/src/lib/convertImageToBase64.ts +19 -0
  27. client/src/lib/facePoke.ts +398 -0
  28. client/src/lib/throttle.ts +32 -0
  29. client/src/lib/utils.ts +15 -0
  30. client/src/styles/globals.css +81 -0
  31. client/tailwind.config.js +86 -0
  32. client/tsconfig.json +32 -0
  33. engine.py +300 -0
  34. liveportrait/config/__init__.py +0 -0
  35. liveportrait/config/argument_config.py +44 -0
  36. liveportrait/config/base_config.py +29 -0
  37. liveportrait/config/crop_config.py +18 -0
  38. liveportrait/config/inference_config.py +53 -0
  39. liveportrait/config/models.yaml +43 -0
  40. liveportrait/gradio_pipeline.py +140 -0
  41. liveportrait/live_portrait_pipeline.py +193 -0
  42. liveportrait/live_portrait_wrapper.py +307 -0
  43. liveportrait/modules/__init__.py +0 -0
  44. liveportrait/modules/appearance_feature_extractor.py +48 -0
  45. liveportrait/modules/convnextv2.py +149 -0
  46. liveportrait/modules/dense_motion.py +104 -0
  47. liveportrait/modules/motion_extractor.py +35 -0
  48. liveportrait/modules/spade_generator.py +59 -0
  49. liveportrait/modules/stitching_retargeting_network.py +38 -0
  50. liveportrait/modules/util.py +441 -0
.dockerignore ADDED
@@ -0,0 +1,17 @@
+ # The .dockerignore file excludes files from the container build process.
+ #
+ # https://docs.docker.com/engine/reference/builder/#dockerignore-file
+
+ # Exclude Git files
+ .git
+ .github
+ .gitignore
+
+ # Exclude Python cache files
+ __pycache__
+ .mypy_cache
+ .pytest_cache
+ .ruff_cache
+
+ # Exclude Python virtual environment
+ /venv
.gitattributes ADDED
@@ -0,0 +1,13 @@
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.xml filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.pdf filter=lfs diff=lfs merge=lfs -text
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
+ *.wav filter=lfs diff=lfs merge=lfs -text
+ *.mpg filter=lfs diff=lfs merge=lfs -text
+ *.webp filter=lfs diff=lfs merge=lfs -text
+ *.webm filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,31 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ **/__pycache__/
+ *.py[cod]
+ **/*.py[cod]
+ *$py.class
+
+ # Model weights
+ **/*.pth
+ **/*.onnx
+
+ # Ipython notebook
+ *.ipynb
+
+ # Temporary files or benchmark resources
+ animations/*
+ tmp/*
+
+ # more ignores
+ .DS_Store
+ *.log
+ .idea/
+ .vscode/
+ *.pyc
+ .ipynb_checkpoints
+ results/
+ data/audio/*.wav
+ data/video/*.mp4
+ ffmpeg-7.0-amd64-static
+ venv/
+ .cog/
Dockerfile ADDED
@@ -0,0 +1,67 @@
+ FROM nvidia/cuda:12.4.0-devel-ubuntu22.04
+
+ ARG DEBIAN_FRONTEND=noninteractive
+
+ ENV PYTHONUNBUFFERED=1
+
+ RUN apt-get update && apt-get install --no-install-recommends -y \
+ build-essential \
+ python3.11 \
+ python3-pip \
+ python3-dev \
+ git \
+ curl \
+ ffmpeg \
+ libglib2.0-0 \
+ libsm6 \
+ libxrender1 \
+ libxext6 \
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /code
+
+ COPY ./requirements.txt /code/requirements.txt
+
+ # Install pget as root
+ RUN echo "Installing pget" && \
+ curl -o /usr/local/bin/pget -L 'https://github.com/replicate/pget/releases/download/v0.2.1/pget' && \
+ chmod +x /usr/local/bin/pget
+
+ # Set up a new user named "user" with user ID 1000
+ RUN useradd -m -u 1000 user
+ # Switch to the "user" user
+ USER user
+ # Set home to the user's home directory
+ ENV HOME=/home/user \
+ PATH=/home/user/.local/bin:$PATH
+
+
+ # Set home to the user's home directory
+ ENV PYTHONPATH=$HOME/app \
+ PYTHONUNBUFFERED=1 \
+ DATA_ROOT=/tmp/data
+
+ RUN echo "Installing requirements.txt"
+ RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ # yeah.. this is manual for now
+ #RUN cd client
+ #RUN bun i
+ #RUN bun build ./src/index.tsx --outdir ../public/
+
+ RUN echo "Installing openmim and mim dependencies"
+ RUN pip3 install --no-cache-dir -U openmim
+ RUN mim install mmengine
+ RUN mim install "mmcv>=2.0.1"
+ RUN mim install "mmdet>=3.3.0"
+ RUN mim install "mmpose>=1.3.2"
+
+ WORKDIR $HOME/app
+
+ COPY --chown=user . $HOME/app
+
+ EXPOSE 8080
+
+ ENV PORT 8080
+
+ CMD python3 app.py
LICENSE ADDED
@@ -0,0 +1,56 @@
+ ## For FacePoke (the modifications I made + the server itself)
+
+ MIT License
+
+ Copyright (c) 2024 Julian Bilcke
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ ## For LivePortrait
+
+ MIT License
+
+ Copyright (c) 2024 Kuaishou Visual Generation and Interaction Center
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ ---
+
+ The code of InsightFace is released under the MIT License.
+ The models of InsightFace are for non-commercial research purposes only.
+
+ If you want to use the LivePortrait project for commercial purposes, you
+ should remove and replace InsightFace’s detection models to fully comply with
+ the MIT license.
README.md ADDED
@@ -0,0 +1,119 @@
+ ---
+ title: FacePoke
+ emoji: 💬
+ colorFrom: yellow
+ colorTo: red
+ sdk: docker
+ pinned: true
+ license: mit
+ header: mini
+ app_file: app.py
+ app_port: 8080
+ ---
+
+ # FacePoke
+
+ ![FacePoke Demo](https://your-demo-image-url-here.gif)
+
+ ## Table of Contents
+
+ - [Introduction](#introduction)
+ - [Acknowledgements](#acknowledgements)
+ - [Installation](#installation)
+ - [Local Setup](#local-setup)
+ - [Docker Deployment](#docker-deployment)
+ - [Development](#development)
+ - [Contributing](#contributing)
+ - [License](#license)
+
+ ## Introduction
+
+ A real-time head transformation app.
+
+ For best performance, please run the app on your own machine (locally or in the cloud).
+
+ **Repository**: [GitHub - jbilcke-hf/FacePoke](https://github.com/jbilcke-hf/FacePoke)
+
+ You can try the demo, but it is a shared Space: latency may be high if there are multiple users or if you live far from the datacenter hosting the Hugging Face Space.
+
+ **Live Demo**: [FacePoke on Hugging Face Spaces](https://huggingface.co/spaces/jbilcke-hf/FacePoke)
+
+ ## Acknowledgements
+
+ This project is based on LivePortrait: https://arxiv.org/abs/2407.03168
+
+ It uses the face transformation routines from https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait
+
+ ## Installation
+
+ ### Local Setup
+
+ 1. Clone the repository:
+ ```bash
+ git clone https://github.com/jbilcke-hf/FacePoke.git
+ cd FacePoke
+ ```
+
+ 2. Install Python dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 3. Install frontend dependencies:
+ ```bash
+ cd client
+ bun install
+ ```
+
+ 4. Build the frontend:
+ ```bash
+ bun build ./src/index.tsx --outdir ../public/
+ ```
+
+ 5. Start the backend server:
+ ```bash
+ python app.py
+ ```
+
+ 6. Open `http://localhost:8080` in your web browser.
+
+ ### Docker Deployment
+
+ 1. Build the Docker image:
+ ```bash
+ docker build -t facepoke .
+ ```
+
+ 2. Run the container:
+ ```bash
+ docker run -p 8080:8080 facepoke
+ ```
+
+ 3. To deploy to Hugging Face Spaces:
+ - Fork the repository on GitHub.
+ - Create a new Space on Hugging Face.
+ - Connect your GitHub repository to the Space.
+ - Configure the Space to use the Docker runtime.
+
+ ## Development
+
+ The project structure is organized as follows:
+
+ - `app.py`: Main backend server handling WebSocket connections.
+ - `engine.py`: Core logic.
+ - `loader.py`: Initializes and loads AI models.
+ - `client/`: Frontend React application.
+ - `src/`: TypeScript source files.
+ - `public/`: Static assets and built files.
+
+ ## Contributing
+
+ Contributions to FacePoke are welcome! Please read our [Contributing Guidelines](CONTRIBUTING.md) for details on how to submit pull requests, report issues, or request features.
+
+ ## License
+
+ FacePoke is released under the MIT License. See the [LICENSE](LICENSE) file for details.
+
+ ---
+
+ Developed with ❤️ by Julian Bilcke at Hugging Face
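Note on step 3 of the Docker Deployment section above: instead of connecting a GitHub repository through the web UI, a Docker Space can also be created and populated programmatically. The sketch below is illustrative and not part of this commit; it assumes the `huggingface_hub` Python client is installed and authenticated, and `your-username/facepoke` is a placeholder repository id.

```python
# Illustrative only (not part of this commit): create a Docker Space and
# upload a local FacePoke checkout with the huggingface_hub client.
from huggingface_hub import create_repo, upload_folder

repo_id = "your-username/facepoke"  # placeholder id
create_repo(repo_id, repo_type="space", space_sdk="docker", exist_ok=True)
upload_folder(
    repo_id=repo_id,
    repo_type="space",
    folder_path=".",  # the cloned FacePoke directory
    ignore_patterns=["client/node_modules/*", "venv/*", ".git/*"],
)
```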
app.py ADDED
@@ -0,0 +1,264 @@
+ """
+ FacePoke API
+
+ Author: Julian Bilcke
+ Date: September 30, 2024
+ """
+
+ import sys
+ import asyncio
+ import hashlib
+ from aiohttp import web, WSMsgType
+ import json
+ import uuid
+ import logging
+ import os
+ import zipfile
+ import signal
+ from typing import Dict, Any, List, Optional
+ import base64
+ import io
+ from PIL import Image
+ import numpy as np
+
+ # Configure logging
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ # Set asyncio logger to DEBUG level
+ logging.getLogger("asyncio").setLevel(logging.DEBUG)
+
+ logger.debug(f"Python version: {sys.version}")
+
+ # SIGSEGV handler
+ def SIGSEGV_signal_arises(signalNum, stack):
+ logger.critical(f"{signalNum} : SIGSEGV arises")
+ logger.critical(f"Stack trace: {stack}")
+
+ signal.signal(signal.SIGSEGV, SIGSEGV_signal_arises)
+
+ from loader import initialize_models
+ from engine import Engine, base64_data_uri_to_PIL_Image, create_engine
+
+ # Global constants
+ DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
+ MODELS_DIR = os.path.join(DATA_ROOT, "models")
+
+ image_cache: Dict[str, Image.Image] = {}
+
+ async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
+ """
+ Handle WebSocket connections for the FacePoke application.
+
+ Args:
+ request (web.Request): The incoming request object.
+
+ Returns:
+ web.WebSocketResponse: The WebSocket response object.
+ """
+ ws = web.WebSocketResponse()
+ await ws.prepare(request)
+
+ session: Optional[FacePokeSession] = None
+ try:
+ logger.info("New WebSocket connection established")
+
+ while True:
+ msg = await ws.receive()
+
+ if msg.type == WSMsgType.TEXT:
+ data = json.loads(msg.data)
+
+ # let's not log user requests, they are heavy
+ #logger.debug(f"Received message: {data}")
+
+ if data['type'] == 'modify_image':
+ uuid = data.get('uuid')
+ if not uuid:
+ logger.warning("Received message without UUID")
+
+ await handle_modify_image(request, ws, data, uuid)
+
+
+ elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
+ logger.warning(f"WebSocket connection closed: {msg.type}")
+ break
+
+ except Exception as e:
+ logger.error(f"Error in websocket_handler: {str(e)}")
+ logger.exception("Full traceback:")
+ finally:
+ if session:
+ await session.stop()
+ del active_sessions[session.session_id]
+ logger.info("WebSocket connection closed")
+ return ws
+
+ async def handle_modify_image(request: web.Request, ws: web.WebSocketResponse, msg: Dict[str, Any], uuid: str):
+ """
+ Handle the 'modify_image' request.
+
+ Args:
+ request (web.Request): The incoming request object.
+ ws (web.WebSocketResponse): The WebSocket response object.
+ msg (Dict[str, Any]): The message containing the image or image_hash and modification parameters.
+ uuid: A unique identifier for the request.
+ """
+ logger.info("Received modify_image request")
+ try:
+ engine = request.app['engine']
+ image_hash = msg.get('image_hash')
+
+ if image_hash:
+ image_or_hash = image_hash
+ else:
+ image_data = msg['image']
+ image_or_hash = image_data
+
+ modified_image_base64 = await engine.modify_image(image_or_hash, msg['params'])
+
+ await ws.send_json({
+ "type": "modified_image",
+ "image": modified_image_base64,
+ "image_hash": engine.get_image_hash(image_or_hash),
+ "success": True,
+ "uuid": uuid # Include the UUID in the response
+ })
+ logger.info("Successfully sent modified image")
+ except Exception as e:
+ logger.error(f"Error in modify_image: {str(e)}")
+ await ws.send_json({
+ "type": "modified_image",
+ "success": False,
+ "error": str(e),
+ "uuid": uuid # Include the UUID even in error responses
+ })
+
+ async def index(request: web.Request) -> web.Response:
+ """Serve the index.html file"""
+ content = open(os.path.join(os.path.dirname(__file__), "public", "index.html"), "r").read()
+ return web.Response(content_type="text/html", text=content)
+
+ async def js_index(request: web.Request) -> web.Response:
+ """Serve the index.js file"""
+ content = open(os.path.join(os.path.dirname(__file__), "public", "index.js"), "r").read()
+ return web.Response(content_type="application/javascript", text=content)
+
+ async def hf_logo(request: web.Request) -> web.Response:
+ """Serve the hf-logo.svg file"""
+ content = open(os.path.join(os.path.dirname(__file__), "public", "hf-logo.svg"), "r").read()
+ return web.Response(content_type="image/svg+xml", text=content)
+
+ async def on_shutdown(app: web.Application):
+ """Cleanup function to be called on server shutdown."""
+ logger.info("Server shutdown initiated, cleaning up resources...")
+ for session in list(active_sessions.values()):
+ await session.stop()
+ active_sessions.clear()
+ logger.info("All active sessions have been closed")
+
+ if 'engine' in app:
+ await app['engine'].cleanup()
+ logger.info("Engine instance cleaned up")
+
+ logger.info("Server shutdown complete")
+
+ async def initialize_app() -> web.Application:
+ """Initialize and configure the web application."""
+ try:
+ logger.info("Initializing application...")
+ models = await initialize_models()
+ logger.info("🚀 Creating Engine instance...")
+ engine = create_engine(models)
+ logger.info("✅ Engine instance created.")
+
+ app = web.Application()
+ app['engine'] = engine
+
+ app.on_shutdown.append(on_shutdown)
+
+ # Configure routes
+ app.router.add_get("/", index)
+ app.router.add_get("/index.js", js_index)
+ app.router.add_get("/hf-logo.svg", hf_logo)
+ app.router.add_get("/ws", websocket_handler)
+
+ logger.info("Application routes configured")
+
+ return app
+ except Exception as e:
+ logger.error(f"🚨 Error during application initialization: {str(e)}")
+ logger.exception("Full traceback:")
+ raise
+
+ async def start_background_tasks(app: web.Application):
+ """
+ Start background tasks for the application.
+
+ Args:
+ app (web.Application): The web application instance.
+ """
+ app['cleanup_task'] = asyncio.create_task(periodic_cleanup(app))
+
+ async def cleanup_background_tasks(app: web.Application):
+ """
+ Clean up background tasks when the application is shutting down.
+
+ Args:
+ app (web.Application): The web application instance.
+ """
+ app['cleanup_task'].cancel()
+ await app['cleanup_task']
+
+ async def periodic_cleanup(app: web.Application):
+ """
+ Perform periodic cleanup tasks for the application.
+
+ Args:
+ app (web.Application): The web application instance.
+ """
+ while True:
+ try:
+ await asyncio.sleep(3600) # Run cleanup every hour
+ await cleanup_inactive_sessions(app)
+ except asyncio.CancelledError:
+ break
+ except Exception as e:
+ logger.error(f"Error in periodic cleanup: {str(e)}")
+ logger.exception("Full traceback:")
+
+ async def cleanup_inactive_sessions(app: web.Application):
+ """
+ Clean up inactive sessions.
+
+ Args:
+ app (web.Application): The web application instance.
+ """
+ logger.info("Starting cleanup of inactive sessions")
+ inactive_sessions = [
+ session_id for session_id, session in active_sessions.items()
+ if not session.is_running.is_set()
+ ]
+ for session_id in inactive_sessions:
+ session = active_sessions.pop(session_id)
+ await session.stop()
+ logger.info(f"Cleaned up inactive session: {session_id}")
+ logger.info(f"Cleaned up {len(inactive_sessions)} inactive sessions")
+
+ def main():
+ """
+ Main function to start the FacePoke application.
+ """
+ try:
+ logger.info("Starting FacePoke application")
+ app = asyncio.run(initialize_app())
+ app.on_startup.append(start_background_tasks)
+ app.on_cleanup.append(cleanup_background_tasks)
+ logger.info("Application initialized, starting web server")
+ web.run_app(app, host="0.0.0.0", port=8080)
+ except Exception as e:
+ logger.critical(f"🚨 FATAL: Failed to start the app: {str(e)}")
+ logger.exception("Full traceback:")
+
+ if __name__ == "__main__":
+ main()
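For reference, the `modify_image` flow handled by `websocket_handler` and `handle_modify_image` above can be exercised with a small standalone client. The sketch below is illustrative and not part of this commit; it assumes the server is running locally on port 8080, that the `image` field is a base64 data URI (matching `base64_data_uri_to_PIL_Image` in `engine.py`), and the `params` key shown is a hypothetical placeholder, since the accepted parameters are defined in `engine.py`.

```python
# Illustrative only (not part of this commit): minimal client for the /ws endpoint.
import asyncio
import base64
import uuid

import aiohttp


async def poke(image_path: str) -> None:
    # Encode the input image as a base64 data URI (assumption, see lead-in above).
    with open(image_path, "rb") as f:
        data_uri = "data:image/jpeg;base64," + base64.b64encode(f.read()).decode()

    async with aiohttp.ClientSession() as session:
        async with session.ws_connect("ws://localhost:8080/ws") as ws:
            await ws.send_json({
                "type": "modify_image",
                "uuid": str(uuid.uuid4()),      # echoed back in the response
                "image": data_uri,              # or "image_hash" for a cached image
                "params": {"rotate_yaw": 0.1},  # hypothetical example parameter
            })
            reply = await ws.receive_json()
            if reply.get("success"):
                print("modified image received, hash:", reply["image_hash"])
            else:
                print("error:", reply.get("error"))


if __name__ == "__main__":
    asyncio.run(poke("face.jpg"))
```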
build.sh ADDED
@@ -0,0 +1,3 @@
+ cd client
+ bun i
+ bun build ./src/index.tsx --outdir ../public/
client/.gitignore ADDED
@@ -0,0 +1,175 @@
+ # Based on https://raw.githubusercontent.com/github/gitignore/main/Node.gitignore
+
+ # Logs
+
+ logs
+ _.log
+ npm-debug.log_
+ yarn-debug.log*
+ yarn-error.log*
+ lerna-debug.log*
+ .pnpm-debug.log*
+
+ # Caches
+
+ .cache
+
+ # Diagnostic reports (https://nodejs.org/api/report.html)
+
+ report.[0-9]_.[0-9]_.[0-9]_.[0-9]_.json
+
+ # Runtime data
+
+ pids
+ _.pid
+ _.seed
+ *.pid.lock
+
+ # Directory for instrumented libs generated by jscoverage/JSCover
+
+ lib-cov
+
+ # Coverage directory used by tools like istanbul
+
+ coverage
+ *.lcov
+
+ # nyc test coverage
+
+ .nyc_output
+
+ # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
+
+ .grunt
+
+ # Bower dependency directory (https://bower.io/)
+
+ bower_components
+
+ # node-waf configuration
+
+ .lock-wscript
+
+ # Compiled binary addons (https://nodejs.org/api/addons.html)
+
+ build/Release
+
+ # Dependency directories
+
+ node_modules/
+ jspm_packages/
+
+ # Snowpack dependency directory (https://snowpack.dev/)
+
+ web_modules/
+
+ # TypeScript cache
+
+ *.tsbuildinfo
+
+ # Optional npm cache directory
+
+ .npm
+
+ # Optional eslint cache
+
+ .eslintcache
+
+ # Optional stylelint cache
+
+ .stylelintcache
+
+ # Microbundle cache
+
+ .rpt2_cache/
+ .rts2_cache_cjs/
+ .rts2_cache_es/
+ .rts2_cache_umd/
+
+ # Optional REPL history
+
+ .node_repl_history
+
+ # Output of 'npm pack'
+
+ *.tgz
+
+ # Yarn Integrity file
+
+ .yarn-integrity
+
+ # dotenv environment variable files
+
+ .env
+ .env.development.local
+ .env.test.local
+ .env.production.local
+ .env.local
+
+ # parcel-bundler cache (https://parceljs.org/)
+
+ .parcel-cache
+
+ # Next.js build output
+
+ .next
+ out
+
+ # Nuxt.js build / generate output
+
+ .nuxt
+ dist
+
+ # Gatsby files
+
+ # Comment in the public line in if your project uses Gatsby and not Next.js
+
+ # https://nextjs.org/blog/next-9-1#public-directory-support
+
+ # public
+
+ # vuepress build output
+
+ .vuepress/dist
+
+ # vuepress v2.x temp and cache directory
+
+ .temp
+
+ # Docusaurus cache and generated files
+
+ .docusaurus
+
+ # Serverless directories
+
+ .serverless/
+
+ # FuseBox cache
+
+ .fusebox/
+
+ # DynamoDB Local files
+
+ .dynamodb/
+
+ # TernJS port file
+
+ .tern-port
+
+ # Stores VSCode versions used for testing VSCode extensions
+
+ .vscode-test
+
+ # yarn v2
+
+ .yarn/cache
+ .yarn/unplugged
+ .yarn/build-state.yml
+ .yarn/install-state.gz
+ .pnp.*
+
+ # IntelliJ based IDEs
+ .idea
+
+ # Finder (MacOS) folder config
+ .DS_Store
client/README.md ADDED
@@ -0,0 +1,13 @@
+ # FacePoke.js
+
+ To install dependencies:
+
+ ```bash
+ bun i
+ ```
+
+ To build:
+
+ ```bash
+ bun build ./src/index.tsx --outdir ../public
+ ```
client/bun.lockb ADDED
Binary file (54.9 kB)
client/package.json ADDED
@@ -0,0 +1,35 @@
+ {
+ "name": "@aitube/facepoke",
+ "module": "src/index.ts",
+ "type": "module",
+ "scripts": {
+ "build": "bun build ./src/index.tsx --outdir ../public/"
+ },
+ "devDependencies": {
+ "@types/bun": "latest"
+ },
+ "peerDependencies": {
+ "typescript": "^5.0.0"
+ },
+ "dependencies": {
+ "@mediapipe/tasks-vision": "^0.10.16",
+ "@radix-ui/react-icons": "^1.3.0",
+ "@types/lodash": "^4.17.10",
+ "@types/react": "^18.3.9",
+ "@types/react-dom": "^18.3.0",
+ "@types/uuid": "^10.0.0",
+ "beautiful-react-hooks": "^5.0.2",
+ "class-variance-authority": "^0.7.0",
+ "clsx": "^2.1.1",
+ "lodash": "^4.17.21",
+ "lucide-react": "^0.446.0",
+ "react": "^18.3.1",
+ "react-dom": "^18.3.1",
+ "tailwind-merge": "^2.5.2",
+ "tailwindcss": "^3.4.13",
+ "tailwindcss-animate": "^1.0.7",
+ "usehooks-ts": "^3.1.0",
+ "uuid": "^10.0.0",
+ "zustand": "^5.0.0-rc.2"
+ }
+ }
client/src/app.tsx ADDED
@@ -0,0 +1,190 @@
+ import React, { useState, useEffect, useRef, useCallback, useMemo } from 'react';
+ import { RotateCcw } from 'lucide-react';
+
+ import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert';
+ import { truncateFileName } from './lib/utils';
+ import { useFaceLandmarkDetection } from './hooks/useFaceLandmarkDetection';
+ import { PoweredBy } from './components/PoweredBy';
+ import { Spinner } from './components/Spinner';
+ import { DoubleCard } from './components/DoubleCard';
+ import { useFacePokeAPI } from './hooks/useFacePokeAPI';
+ import { Layout } from './layout';
+ import { useMainStore } from './hooks/useMainStore';
+ import { convertImageToBase64 } from './lib/convertImageToBase64';
+
+ export function App() {
+ const error = useMainStore(s => s.error);
+ const setError = useMainStore(s => s.setError);
+ const imageFile = useMainStore(s => s.imageFile);
+ const setImageFile = useMainStore(s => s.setImageFile);
+ const originalImage = useMainStore(s => s.originalImage);
+ const setOriginalImage = useMainStore(s => s.setOriginalImage);
+ const previewImage = useMainStore(s => s.previewImage);
+ const setPreviewImage = useMainStore(s => s.setPreviewImage);
+ const resetImage = useMainStore(s => s.resetImage);
+
+ const {
+ status,
+ setStatus,
+ isDebugMode,
+ setIsDebugMode,
+ interruptMessage,
+ } = useFacePokeAPI()
+
+ // State for face detection
+ const {
+ canvasRef,
+ canvasRefCallback,
+ mediaPipeRef,
+ faceLandmarks,
+ isMediaPipeReady,
+ blendShapes,
+
+ setFaceLandmarks,
+ setBlendShapes,
+
+ handleMouseDown,
+ handleMouseUp,
+ handleMouseMove,
+ handleMouseEnter,
+ handleMouseLeave,
+ currentOpacity
+ } = useFaceLandmarkDetection()
+
+ // Refs
+ const videoRef = useRef<HTMLDivElement>(null);
+
+ // Handle file change
+ const handleFileChange = useCallback(async (event: React.ChangeEvent<HTMLInputElement>) => {
+ const files = event.target.files;
+ if (files && files[0]) {
+ setImageFile(files[0]);
+ setStatus(`File selected: ${truncateFileName(files[0].name, 16)}`);
+
+ try {
+ const image = await convertImageToBase64(files[0]);
+ setPreviewImage(image);
+ setOriginalImage(image);
+ } catch (err) {
+ console.log(`failed to convert the image: `, err);
+ setImageFile(null);
+ setStatus('');
+ setPreviewImage('');
+ setOriginalImage('');
+ setFaceLandmarks([]);
+ setBlendShapes([]);
+ }
+ } else {
+ setImageFile(null);
+ setStatus('');
+ setPreviewImage('');
+ setOriginalImage('');
+ setFaceLandmarks([]);
+ setBlendShapes([]);
+ }
+ }, [isMediaPipeReady, setImageFile, setPreviewImage, setOriginalImage, setFaceLandmarks, setBlendShapes, setStatus]);
+
+ const canDisplayBlendShapes = false
+
+ // Display blend shapes
+ const displayBlendShapes = useMemo(() => (
+ <div className="mt-4">
+ <h3 className="text-lg font-semibold mb-2">Blend Shapes</h3>
+ <ul className="space-y-1">
+ {(blendShapes?.[0]?.categories || []).map((shape, index) => (
+ <li key={index} className="flex items-center">
+ <span className="w-32 text-sm">{shape.categoryName || shape.displayName}</span>
+ <div className="w-full bg-gray-200 rounded-full h-2.5">
+ <div
+ className="bg-blue-600 h-2.5 rounded-full"
+ style={{ width: `${shape.score * 100}%` }}
+ ></div>
+ </div>
+ <span className="ml-2 text-sm">{shape.score.toFixed(2)}</span>
+ </li>
+ ))}
+ </ul>
+ </div>
+ ), [JSON.stringify(blendShapes)])
+
+ // JSX
+ return (
+ <Layout>
+ {error && (
+ <Alert variant="destructive">
+ <AlertTitle>Error</AlertTitle>
+ <AlertDescription>{error}</AlertDescription>
+ </Alert>
+ )}
+ {interruptMessage && (
+ <Alert>
+ <AlertTitle>Notice</AlertTitle>
+ <AlertDescription>{interruptMessage}</AlertDescription>
+ </Alert>
+ )}
+ <div className="mb-5 relative">
+ <div className="flex flex-row items-center justify-between w-full">
+ <div className="relative">
+ <input
+ id="imageInput"
+ type="file"
+ accept="image/*"
+ onChange={handleFileChange}
+ className="hidden"
+ disabled={!isMediaPipeReady}
+ />
+ <label
+ htmlFor="imageInput"
+ className={`cursor-pointer inline-flex items-center px-3 py-1.5 border border-transparent text-sm font-medium rounded-md text-white ${
+ isMediaPipeReady ? 'bg-gray-600 hover:bg-gray-500' : 'bg-gray-500 cursor-not-allowed'
+ } focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-gray-500 shadow-xl`}
+ >
+ <Spinner />
+ {imageFile ? truncateFileName(imageFile.name, 32) : (isMediaPipeReady ? 'Choose an image' : 'Initializing...')}
+ </label>
+ </div>
+ {previewImage && <label className="mt-4 flex items-center">
+ <input
+ type="checkbox"
+ checked={isDebugMode}
+ onChange={(e) => setIsDebugMode(e.target.checked)}
+ className="mr-2"
+ />
+ Show face landmarks on hover
+ </label>}
+ </div>
+ {previewImage && (
+ <div className="mt-5 relative shadow-2xl rounded-xl overflow-hidden">
+ <img
+ src={previewImage}
+ alt="Preview"
+ className="w-full"
+ />
+ <canvas
+ ref={canvasRefCallback}
+ className="absolute top-0 left-0 w-full h-full select-none"
+ onMouseEnter={handleMouseEnter}
+ onMouseLeave={handleMouseLeave}
+ onMouseDown={handleMouseDown}
+ onMouseUp={handleMouseUp}
+ onMouseMove={handleMouseMove}
+ style={{
+ position: 'absolute',
+ top: 0,
+ left: 0,
+ width: '100%',
+ height: '100%',
+ opacity: isDebugMode ? currentOpacity : 0.0,
+ transition: 'opacity 0.2s ease-in-out'
+ }}
+
+ />
+ </div>
+ )}
+ {canDisplayBlendShapes && displayBlendShapes}
+ </div>
+ <PoweredBy />
+
+ </Layout>
+ );
+ }
client/src/components/DoubleCard.tsx ADDED
@@ -0,0 +1,18 @@
+ import React, { type ReactNode } from 'react';
+
+ export function DoubleCard({ children }: { children: ReactNode }) {
+ return (
+ <>
+ <div className="absolute inset-0 bg-gradient-to-r from-cyan-200 to-sky-300 shadow-2xl transform -skew-y-6 sm:skew-y-0 sm:-rotate-6 sm:rounded-3xl" style={{ borderTop: "solid 2px rgba(255, 255, 255, 0.2)" }}></div>
+ <div className="relative px-5 py-8 bg-gradient-to-r from-cyan-100 to-sky-200 shadow-2xl sm:rounded-3xl sm:p-12" style={{ borderTop: "solid 2px #ffffff33" }}>
+ <div className="max-w-lg mx-auto">
+ <div className="divide-y divide-gray-200">
+ <div className="text-lg leading-7 space-y-5 text-gray-700 sm:text-xl sm:leading-8">
+ {children}
+ </div>
+ </div>
+ </div>
+ </div>
+ </>
+ );
+ }
client/src/components/PoweredBy.tsx ADDED
@@ -0,0 +1,17 @@
+ export function PoweredBy() {
+ return (
+ <div className="flex flex-row items-center justify-center font-sans mt-4 w-full">
+ {/*<span className="text-neutral-900 text-sm"
+ style={{ textShadow: "rgb(255 255 255 / 80%) 0px 0px 2px" }}>
+ Best hosted on
+ </span>*/}
+ <span className="ml-2 mr-1">
+ <img src="/hf-logo.svg" alt="Hugging Face" className="w-5 h-5" />
+ </span>
+ <span className="text-neutral-900 text-sm font-semibold"
+ style={{ textShadow: "rgb(255 255 255 / 80%) 0px 0px 2px" }}>
+ Hugging Face
+ </span>
+ </div>
+ )
+ }
client/src/components/Spinner.tsx ADDED
@@ -0,0 +1,7 @@
+ export function Spinner() {
+ return (
+ <svg className="mr-3 h-6 w-6" fill="none" viewBox="0 0 24 24" stroke="currentColor">
+ <path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M4 16l4.586-4.586a2 2 0 012.828 0L16 16m-2-2l1.586-1.586a2 2 0 012.828 0L20 14m-6-6h.01M6 20h12a2 2 0 002-2V6a2 2 0 00-2-2H6a2 2 0 00-2 2v12a2 2 0 002 2z" />
+ </svg>
+ )
+ }
client/src/components/Title.tsx ADDED
@@ -0,0 +1,8 @@
+ export function Title() {
+ return (
+ <h2 className="bg-gradient-to-bl from-sky-500 to-sky-800 bg-clip-text text-5xl font-extrabold text-transparent leading-normal text-center"
+ style={{ textShadow: "rgb(176 229 255 / 61%) 0px 0px 2px" }}>
+ 💬 FacePoke
+ </h2>
+ )
+ }
client/src/components/ui/alert.tsx ADDED
@@ -0,0 +1,59 @@
+ import * as React from "react"
+ import { cva, type VariantProps } from "class-variance-authority"
+
+ import { cn } from "@/lib/utils"
+
+ const alertVariants = cva(
+ "relative w-full rounded-lg border p-4 [&>svg~*]:pl-7 [&>svg+div]:translate-y-[-3px] [&>svg]:absolute [&>svg]:left-4 [&>svg]:top-4 [&>svg]:text-foreground",
+ {
+ variants: {
+ variant: {
+ default: "bg-background text-foreground",
+ destructive:
+ "border-destructive/50 text-destructive dark:border-destructive [&>svg]:text-destructive",
+ },
+ },
+ defaultVariants: {
+ variant: "default",
+ },
+ }
+ )
+
+ const Alert = React.forwardRef<
+ HTMLDivElement,
+ React.HTMLAttributes<HTMLDivElement> & VariantProps<typeof alertVariants>
+ >(({ className, variant, ...props }, ref) => (
+ <div
+ ref={ref}
+ role="alert"
+ className={cn(alertVariants({ variant }), className)}
+ {...props}
+ />
+ ))
+ Alert.displayName = "Alert"
+
+ const AlertTitle = React.forwardRef<
+ HTMLParagraphElement,
+ React.HTMLAttributes<HTMLHeadingElement>
+ >(({ className, ...props }, ref) => (
+ <h5
+ ref={ref}
+ className={cn("mb-1 font-medium leading-none tracking-tight", className)}
+ {...props}
+ />
+ ))
+ AlertTitle.displayName = "AlertTitle"
+
+ const AlertDescription = React.forwardRef<
+ HTMLParagraphElement,
+ React.HTMLAttributes<HTMLParagraphElement>
+ >(({ className, ...props }, ref) => (
+ <div
+ ref={ref}
+ className={cn("text-sm [&_p]:leading-relaxed", className)}
+ {...props}
+ />
+ ))
+ AlertDescription.displayName = "AlertDescription"
+
+ export { Alert, AlertTitle, AlertDescription }
client/src/hooks/landmarks.ts ADDED
@@ -0,0 +1,520 @@
+ import * as vision from '@mediapipe/tasks-vision';
+
+ // Define unique colors for each landmark group
+ export const landmarkColors: { [key: string]: string } = {
+ lips: '#FF0000',
+ leftEye: '#00FF00',
+ leftEyebrow: '#0000FF',
+ leftIris: '#FFFF00',
+ rightEye: '#FF00FF',
+ rightEyebrow: '#00FFFF',
+ rightIris: '#FFA500',
+ faceOval: '#800080',
+ tesselation: '#C0C0C0',
+ };
+
+ // Define landmark groups with their semantic names
+ export const landmarkGroups: { [key: string]: any } = {
+ lips: vision.FaceLandmarker.FACE_LANDMARKS_LIPS,
+ leftEye: vision.FaceLandmarker.FACE_LANDMARKS_LEFT_EYE,
+ leftEyebrow: vision.FaceLandmarker.FACE_LANDMARKS_LEFT_EYEBROW,
+ leftIris: vision.FaceLandmarker.FACE_LANDMARKS_LEFT_IRIS,
+ rightEye: vision.FaceLandmarker.FACE_LANDMARKS_RIGHT_EYE,
+ rightEyebrow: vision.FaceLandmarker.FACE_LANDMARKS_RIGHT_EYEBROW,
+ rightIris: vision.FaceLandmarker.FACE_LANDMARKS_RIGHT_IRIS,
+ faceOval: vision.FaceLandmarker.FACE_LANDMARKS_FACE_OVAL,
+ // tesselation: vision.FaceLandmarker.FACE_LANDMARKS_TESSELATION,
+ };
+
+ export const FACEMESH_LIPS = Object.freeze(new Set([[61, 146], [146, 91], [91, 181], [181, 84], [84, 17],
+ [17, 314], [314, 405], [405, 321], [321, 375],
+ [375, 291], [61, 185], [185, 40], [40, 39], [39, 37],
+ [37, 0], [0, 267],
+ [267, 269], [269, 270], [270, 409], [409, 291],
+ [78, 95], [95, 88], [88, 178], [178, 87], [87, 14],
+ [14, 317], [317, 402], [402, 318], [318, 324],
+ [324, 308], [78, 191], [191, 80], [80, 81], [81, 82],
+ [82, 13], [13, 312], [312, 311], [311, 310],
+ [310, 415], [415, 308]]))
+
+ export const FACEMESH_LEFT_EYE = Object.freeze(new Set([[263, 249], [249, 390], [390, 373], [373, 374],
+ [374, 380], [380, 381], [381, 382], [382, 362],
+ [263, 466], [466, 388], [388, 387], [387, 386],
+ [386, 385], [385, 384], [384, 398], [398, 362]]))
+
+ export const FACEMESH_LEFT_IRIS = Object.freeze(new Set([[474, 475], [475, 476], [476, 477],
+ [477, 474]]))
+
+ export const FACEMESH_LEFT_EYEBROW = Object.freeze(new Set([[276, 283], [283, 282], [282, 295],
+ [295, 285], [300, 293], [293, 334],
+ [334, 296], [296, 336]]))
+
+ export const FACEMESH_RIGHT_EYE = Object.freeze(new Set([[33, 7], [7, 163], [163, 144], [144, 145],
+ [145, 153], [153, 154], [154, 155], [155, 133],
+ [33, 246], [246, 161], [161, 160], [160, 159],
+ [159, 158], [158, 157], [157, 173], [173, 133]]))
+
+ export const FACEMESH_RIGHT_EYEBROW = Object.freeze(new Set([[46, 53], [53, 52], [52, 65], [65, 55],
+ [70, 63], [63, 105], [105, 66], [66, 107]]))
+
+ export const FACEMESH_RIGHT_IRIS = Object.freeze(new Set([[469, 470], [470, 471], [471, 472],
+ [472, 469]]))
+
+ export const FACEMESH_FACE_OVAL = Object.freeze(new Set([[10, 338], [338, 297], [297, 332], [332, 284],
+ [284, 251], [251, 389], [389, 356], [356, 454],
+ [454, 323], [323, 361], [361, 288], [288, 397],
+ [397, 365], [365, 379], [379, 378], [378, 400],
+ [400, 377], [377, 152], [152, 148], [148, 176],
+ [176, 149], [149, 150], [150, 136], [136, 172],
+ [172, 58], [58, 132], [132, 93], [93, 234],
+ [234, 127], [127, 162], [162, 21], [21, 54],
+ [54, 103], [103, 67], [67, 109], [109, 10]]))
+
+ export const FACEMESH_NOSE = Object.freeze(new Set([[168, 6], [6, 197], [197, 195], [195, 5],
+ [5, 4], [4, 1], [1, 19], [19, 94], [94, 2], [98, 97],
+ [97, 2], [2, 326], [326, 327], [327, 294],
+ [294, 278], [278, 344], [344, 440], [440, 275],
+ [275, 4], [4, 45], [45, 220], [220, 115], [115, 48],
+ [48, 64], [64, 98]]))
+
+ export const FACEMESH_CONTOURS = Object.freeze(new Set([
+ ...FACEMESH_LIPS,
+ ...FACEMESH_LEFT_EYE,
+ ...FACEMESH_LEFT_EYEBROW,
+ ...FACEMESH_RIGHT_EYE,
+ ...FACEMESH_RIGHT_EYEBROW,
+ ...FACEMESH_FACE_OVAL
+ ]));
+
+ export const FACEMESH_IRISES = Object.freeze(new Set([
+ ...FACEMESH_LEFT_IRIS,
+ ...FACEMESH_RIGHT_IRIS
+ ]));
+
+ export const FACEMESH_TESSELATION = Object.freeze(new Set([
+ [127, 34], [34, 139], [139, 127], [11, 0], [0, 37], [37, 11],
+ [232, 231], [231, 120], [120, 232], [72, 37], [37, 39], [39, 72],
+ [128, 121], [121, 47], [47, 128], [232, 121], [121, 128], [128, 232],
+ [104, 69], [69, 67], [67, 104], [175, 171], [171, 148], [148, 175],
+ [118, 50], [50, 101], [101, 118], [73, 39], [39, 40], [40, 73],
+ [9, 151], [151, 108], [108, 9], [48, 115], [115, 131], [131, 48],
+ [194, 204], [204, 211], [211, 194], [74, 40], [40, 185], [185, 74],
+ [80, 42], [42, 183], [183, 80], [40, 92], [92, 186], [186, 40],
+ [230, 229], [229, 118], [118, 230], [202, 212], [212, 214], [214, 202],
+ [83, 18], [18, 17], [17, 83], [76, 61], [61, 146], [146, 76],
+ [160, 29], [29, 30], [30, 160], [56, 157], [157, 173], [173, 56],
+ [106, 204], [204, 194], [194, 106], [135, 214], [214, 192], [192, 135],
+ [203, 165], [165, 98], [98, 203], [21, 71], [71, 68], [68, 21],
+ [51, 45], [45, 4], [4, 51], [144, 24], [24, 23], [23, 144],
+ [77, 146], [146, 91], [91, 77], [205, 50], [50, 187], [187, 205],
+ [201, 200], [200, 18], [18, 201], [91, 106], [106, 182], [182, 91],
+ [90, 91], [91, 181], [181, 90], [85, 84], [84, 17], [17, 85],
+ [206, 203], [203, 36], [36, 206], [148, 171], [171, 140], [140, 148],
+ [92, 40], [40, 39], [39, 92], [193, 189], [189, 244], [244, 193],
+ [159, 158], [158, 28], [28, 159], [247, 246], [246, 161], [161, 247],
+ [236, 3], [3, 196], [196, 236], [54, 68], [68, 104], [104, 54],
+ [193, 168], [168, 8], [8, 193], [117, 228], [228, 31], [31, 117],
+ [189, 193], [193, 55], [55, 189], [98, 97], [97, 99], [99, 98],
+ [126, 47], [47, 100], [100, 126], [166, 79], [79, 218], [218, 166],
+ [155, 154], [154, 26], [26, 155], [209, 49], [49, 131], [131, 209],
+ [135, 136], [136, 150], [150, 135], [47, 126], [126, 217], [217, 47],
+ [223, 52], [52, 53], [53, 223], [45, 51], [51, 134], [134, 45],
+ [211, 170], [170, 140], [140, 211], [67, 69], [69, 108], [108, 67],
+ [43, 106], [106, 91], [91, 43], [230, 119], [119, 120], [120, 230],
+ [226, 130], [130, 247], [247, 226], [63, 53], [53, 52], [52, 63],
+ [238, 20], [20, 242], [242, 238], [46, 70], [70, 156], [156, 46],
+ [78, 62], [62, 96], [96, 78], [46, 53], [53, 63], [63, 46],
+ [143, 34], [34, 227], [227, 143], [123, 117], [117, 111], [111, 123],
+ [44, 125], [125, 19], [19, 44], [236, 134], [134, 51], [51, 236],
+ [216, 206], [206, 205], [205, 216], [154, 153], [153, 22], [22, 154],
+ [39, 37], [37, 167], [167, 39], [200, 201], [201, 208], [208, 200],
+ [36, 142], [142, 100], [100, 36], [57, 212], [212, 202], [202, 57],
+ [20, 60], [60, 99], [99, 20], [28, 158], [158, 157], [157, 28],
+ [35, 226], [226, 113], [113, 35], [160, 159], [159, 27], [27, 160],
+ [204, 202], [202, 210], [210, 204], [113, 225], [225, 46], [46, 113],
+ [43, 202], [202, 204], [204, 43], [62, 76], [76, 77], [77, 62],
+ [137, 123], [123, 116], [116, 137], [41, 38], [38, 72], [72, 41],
+ [203, 129], [129, 142], [142, 203], [64, 98], [98, 240], [240, 64],
+ [49, 102], [102, 64], [64, 49], [41, 73], [73, 74], [74, 41],
+ [212, 216], [216, 207], [207, 212], [42, 74], [74, 184], [184, 42],
+ [169, 170], [170, 211], [211, 169], [170, 149], [149, 176], [176, 170],
+ [105, 66], [66, 69], [69, 105], [122, 6], [6, 168], [168, 122],
+ [123, 147], [147, 187], [187, 123], [96, 77], [77, 90], [90, 96],
+ [65, 55], [55, 107], [107, 65], [89, 90], [90, 180], [180, 89],
+ [101, 100], [100, 120], [120, 101], [63, 105], [105, 104], [104, 63],
+ [93, 137], [137, 227], [227, 93], [15, 86], [86, 85], [85, 15],
+ [129, 102], [102, 49], [49, 129], [14, 87], [87, 86], [86, 14],
+ [55, 8], [8, 9], [9, 55], [100, 47], [47, 121], [121, 100],
+ [145, 23], [23, 22], [22, 145], [88, 89], [89, 179], [179, 88],
+ [6, 122], [122, 196], [196, 6], [88, 95], [95, 96], [96, 88],
+ [138, 172], [172, 136], [136, 138], [215, 58], [58, 172], [172, 215],
+ [115, 48], [48, 219], [219, 115], [42, 80], [80, 81], [81, 42],
+ [195, 3], [3, 51], [51, 195], [43, 146], [146, 61], [61, 43],
+ [171, 175], [175, 199], [199, 171], [81, 82], [82, 38], [38, 81],
+ [53, 46], [46, 225], [225, 53], [144, 163], [163, 110], [110, 144],
+ [52, 65], [65, 66], [66, 52], [229, 228], [228, 117], [117, 229],
+ [34, 127], [127, 234], [234, 34], [107, 108], [108, 69], [69, 107],
+ [109, 108], [108, 151], [151, 109], [48, 64], [64, 235], [235, 48],
+ [62, 78], [78, 191], [191, 62], [129, 209], [209, 126], [126, 129],
+ [111, 35], [35, 143], [143, 111], [117, 123], [123, 50], [50, 117],
+ [222, 65], [65, 52], [52, 222], [19, 125], [125, 141], [141, 19],
+ [221, 55], [55, 65], [65, 221], [3, 195], [195, 197], [197, 3],
+ [25, 7], [7, 33], [33, 25], [220, 237], [237, 44], [44, 220],
+ [70, 71], [71, 139], [139, 70], [122, 193], [193, 245], [245, 122],
+ [247, 130], [130, 33], [33, 247], [71, 21], [21, 162], [162, 71],
+ [170, 169], [169, 150], [150, 170], [188, 174], [174, 196], [196, 188],
+ [216, 186], [186, 92], [92, 216], [2, 97], [97, 167], [167, 2],
+ [141, 125], [125, 241], [241, 141], [164, 167], [167, 37], [37, 164],
+ [72, 38], [38, 12], [12, 72], [38, 82], [82, 13], [13, 38],
+ [63, 68], [68, 71], [71, 63], [226, 35], [35, 111], [111, 226],
+ [101, 50], [50, 205], [205, 101], [206, 92], [92, 165], [165, 206],
+ [209, 198], [198, 217], [217, 209], [165, 167], [167, 97], [97, 165],
+ [220, 115], [115, 218], [218, 220], [133, 112], [112, 243], [243, 133],
+ [239, 238], [238, 241], [241, 239], [214, 135], [135, 169], [169, 214],
+ [190, 173], [173, 133], [133, 190], [171, 208], [208, 32], [32, 171],
+ [125, 44], [44, 237], [237, 125], [86, 87], [87, 178], [178, 86],
+ [85, 86], [86, 179], [179, 85], [84, 85], [85, 180], [180, 84],
+ [83, 84], [84, 181], [181, 83], [201, 83], [83, 182], [182, 201],
+ [137, 93], [93, 132], [132, 137], [76, 62], [62, 183], [183, 76],
+ [61, 76], [76, 184], [184, 61], [57, 61], [61, 185], [185, 57],
+ [212, 57], [57, 186], [186, 212], [214, 207], [207, 187], [187, 214],
+ [34, 143], [143, 156], [156, 34], [79, 239], [239, 237], [237, 79],
+ [123, 137], [137, 177], [177, 123], [44, 1], [1, 4], [4, 44],
+ [201, 194], [194, 32], [32, 201], [64, 102], [102, 129], [129, 64],
+ [213, 215], [215, 138], [138, 213], [59, 166], [166, 219], [219, 59],
+ [242, 99], [99, 97], [97, 242], [2, 94], [94, 141], [141, 2],
+ [75, 59], [59, 235], [235, 75], [24, 110], [110, 228], [228, 24],
+ [25, 130], [130, 226], [226, 25], [23, 24], [24, 229], [229, 23],
+ [22, 23], [23, 230], [230, 22], [26, 22], [22, 231], [231, 26],
+ [112, 26], [26, 232], [232, 112], [189, 190], [190, 243], [243, 189],
+ [221, 56], [56, 190], [190, 221], [28, 56], [56, 221], [221, 28],
+ [27, 28], [28, 222], [222, 27], [29, 27], [27, 223], [223, 29],
+ [30, 29], [29, 224], [224, 30], [247, 30], [30, 225], [225, 247],
+ [238, 79], [79, 20], [20, 238], [166, 59], [59, 75], [75, 166],
+ [60, 75], [75, 240], [240, 60], [147, 177], [177, 215], [215, 147],
+ [20, 79], [79, 166], [166, 20], [187, 147], [147, 213], [213, 187],
+ [112, 233], [233, 244], [244, 112], [233, 128], [128, 245], [245, 233],
+ [128, 114], [114, 188], [188, 128], [114, 217], [217, 174], [174, 114],
+ [131, 115], [115, 220], [220, 131], [217, 198], [198, 236], [236, 217],
+ [198, 131], [131, 134], [134, 198], [177, 132], [132, 58], [58, 177],
+ [143, 35], [35, 124], [124, 143], [110, 163], [163, 7], [7, 110],
+ [228, 110], [110, 25], [25, 228], [356, 389], [389, 368], [368, 356],
+ [11, 302], [302, 267], [267, 11], [452, 350], [350, 349], [349, 452],
+ [302, 303], [303, 269], [269, 302], [357, 343], [343, 277], [277, 357],
+ [452, 453], [453, 357], [357, 452], [333, 332], [332, 297], [297, 333],
+ [175, 152], [152, 377], [377, 175], [347, 348], [348, 330], [330, 347],
+ [303, 304], [304, 270], [270, 303], [9, 336], [336, 337], [337, 9],
+ [278, 279], [279, 360], [360, 278], [418, 262], [262, 431], [431, 418],
+ [304, 408], [408, 409], [409, 304], [310, 415], [415, 407], [407, 310],
+ [270, 409], [409, 410], [410, 270], [450, 348], [348, 347], [347, 450],
+ [422, 430], [430, 434], [434, 422], [313, 314], [314, 17], [17, 313],
+ [306, 307], [307, 375], [375, 306], [387, 388], [388, 260], [260, 387],
+ [286, 414], [414, 398], [398, 286], [335, 406], [406, 418], [418, 335],
+ [364, 367], [367, 416], [416, 364], [423, 358], [358, 327], [327, 423],
+ [251, 284], [284, 298], [298, 251], [281, 5], [5, 4], [4, 281],
+ [373, 374], [374, 253], [253, 373], [307, 320], [320, 321], [321, 307],
+ [425, 427], [427, 411], [411, 425], [421, 313], [313, 18], [18, 421],
+ [321, 405], [405, 406], [406, 321], [320, 404], [404, 405], [405, 320],
+ [315, 16], [16, 17], [17, 315], [426, 425], [425, 266], [266, 426],
+ [377, 400], [400, 369], [369, 377], [322, 391], [391, 269], [269, 322],
+ [417, 465], [465, 464], [464, 417], [386, 257], [257, 258], [258, 386],
+ [466, 260], [260, 388], [388, 466], [456, 399], [399, 419], [419, 456],
+ [284, 332], [332, 333], [333, 284], [417, 285], [285, 8], [8, 417],
+ [346, 340], [340, 261], [261, 346], [413, 441], [441, 285], [285, 413],
+ [327, 460], [460, 328], [328, 327], [355, 371], [371, 329], [329, 355],
+ [392, 439], [439, 438], [438, 392], [382, 341], [341, 256], [256, 382],
+ [429, 420], [420, 360], [360, 429], [364, 394], [394, 379], [379, 364],
+ [277, 343], [343, 437], [437, 277], [443, 444], [444, 283], [283, 443],
+ [275, 440], [440, 363], [363, 275], [431, 262], [262, 369], [369, 431],
+ [297, 338], [338, 337], [337, 297], [273, 375], [375, 321], [321, 273],
+ [450, 451], [451, 349], [349, 450], [446, 342], [342, 467], [467, 446],
+ [293, 334], [334, 282], [282, 293], [458, 461], [461, 462], [462, 458],
+ [276, 353], [353, 383], [383, 276], [308, 324], [324, 325], [325, 308],
+ [276, 300], [300, 293], [293, 276], [372, 345], [345, 447], [447, 372],
+ [352, 345], [345, 340], [340, 352], [274, 1], [1, 19], [19, 274],
+ [456, 248], [248, 281], [281, 456], [436, 427], [427, 425], [425, 436],
+ [381, 256], [256, 252], [252, 381], [269, 391], [391, 393], [393, 269],
+ [200, 199], [199, 428], [428, 200], [266, 330], [330, 329], [329, 266],
+ [287, 273], [273, 422], [422, 287], [250, 462], [462, 328], [328, 250],
+ [258, 286], [286, 384], [384, 258], [265, 353], [353, 342], [342, 265],
+ [387, 259], [259, 257], [257, 387], [424, 431], [431, 430], [430, 424],
+ [342, 353], [353, 276], [276, 342], [273, 335], [335, 424], [424, 273],
+ [292, 325], [325, 307], [307, 292], [366, 447], [447, 345], [345, 366],
+ [271, 303], [303, 302], [302, 271], [423, 266], [266, 371], [371, 423],
+ [294, 455], [455, 460], [460, 294], [279, 278], [278, 294], [294, 279],
+ [271, 272], [272, 304], [304, 271], [432, 434], [434, 427], [427, 432],
+ [272, 407], [407, 408], [408, 272], [394, 430], [430, 431], [431, 394],
+ [395, 369], [369, 400], [400, 395], [334, 333], [333, 299], [299, 334],
+ [351, 417], [417, 168], [168, 351], [352, 280], [280, 411], [411, 352],
+ [325, 319], [319, 320], [320, 325], [295, 296], [296, 336], [336, 295],
+ [319, 403], [403, 404], [404, 319], [330, 348], [348, 349], [349, 330],
+ [293, 298], [298, 333], [333, 293], [323, 454], [454, 447], [447, 323],
+ [15, 16], [16, 315], [315, 15], [358, 429], [429, 279], [279, 358],
+ [14, 15], [15, 316], [316, 14], [285, 336], [336, 9], [9, 285],
+ [329, 349], [349, 350], [350, 329], [374, 380], [380, 252], [252, 374],
+ [318, 402], [402, 403], [403, 318], [6, 197], [197, 419], [419, 6],
+ [318, 319], [319, 325], [325, 318], [367, 364], [364, 365], [365, 367],
+ [435, 367], [367, 397], [397, 435], [344, 438], [438, 439], [439, 344],
+ [272, 271], [271, 311], [311, 272], [195, 5], [5, 281], [281, 195],
+ [273, 287], [287, 291], [291, 273], [396, 428], [428, 199], [199, 396],
+ [311, 271], [271, 268], [268, 311], [283, 444], [444, 445], [445, 283],
+ [373, 254], [254, 339], [339, 373], [282, 334], [334, 296], [296, 282],
+ [449, 347], [347, 346], [346, 449], [264, 447], [447, 454], [454, 264],
+ [336, 296], [296, 299], [299, 336], [338, 10], [10, 151], [151, 338],
+ [278, 439], [439, 455], [455, 278], [292, 407], [407, 415], [415, 292],
+ [358, 371], [371, 355], [355, 358], [340, 345], [345, 372], [372, 340],
+ [346, 347], [347, 280], [280, 346], [442, 443], [443, 282], [282, 442],
+ [19, 94], [94, 370], [370, 19], [441, 442], [442, 295], [295, 441],
+ [248, 419], [419, 197], [197, 248], [263, 255], [255, 359], [359, 263],
+ [440, 275], [275, 274], [274, 440], [300, 383], [383, 368], [368, 300],
+ [351, 412], [412, 465], [465, 351], [263, 467], [467, 466], [466, 263],
+ [301, 368], [368, 389], [389, 301], [395, 378], [378, 379], [379, 395],
+ [412, 351], [351, 419], [419, 412], [436, 426], [426, 322], [322, 436],
+ [2, 164], [164, 393], [393, 2], [370, 462], [462, 461], [461, 370],
+ [164, 0], [0, 267], [267, 164], [302, 11], [11, 12], [12, 302],
+ [268, 12], [12, 13], [13, 268], [293, 300], [300, 301], [301, 293],
+ [446, 261], [261, 340], [340, 446], [330, 266], [266, 425], [425, 330],
+ [426, 423], [423, 391], [391, 426], [429, 355], [355, 437], [437, 429],
+ [391, 327], [327, 326], [326, 391], [440, 457], [457, 438], [438, 440],
+ [341, 382], [382, 362], [362, 341], [459, 457], [457, 461], [461, 459],
+ [434, 430], [430, 394], [394, 434], [414, 463], [463, 362], [362, 414],
+ [396, 369], [369, 262], [262, 396], [354, 461], [461, 457], [457, 354],
+ [316, 403], [403, 402], [402, 316], [315, 404], [404, 403], [403, 315],
+ [314, 405], [405, 404], [404, 314], [313, 406], [406, 405], [405, 313],
+ [421, 418], [418, 406], [406, 421], [366, 401], [401, 361], [361, 366],
+ [306, 408], [408, 407], [407, 306], [291, 409], [409, 408], [408, 291],
+ [287, 410], [410, 409], [409, 287], [432, 436], [436, 410], [410, 432],
+ [434, 416], [416, 411], [411, 434], [264, 368], [368, 383], [383, 264],
+ [309, 438], [438, 457], [457, 309], [352, 376], [376, 401], [401, 352],
+ [274, 275], [275, 4], [4, 274], [421, 428], [428, 262], [262, 421],
+ [294, 327], [327, 358], [358, 294], [433, 416], [416, 367], [367, 433],
+ [289, 455], [455, 439], [439, 289], [462, 370], [370, 326], [326, 462],
+ [2, 326], [326, 370], [370, 2], [305, 460], [460, 455], [455, 305],
+ [254, 449], [449, 448], [448, 254], [255, 261], [261, 446], [446, 255],
+ [253, 450], [450, 449], [449, 253], [252, 451], [451, 450], [450, 252],
+ [256, 452], [452, 451], [451, 256], [341, 453], [453, 452], [452, 341],
+ [413, 464], [464, 463], [463, 413], [441, 413], [413, 414], [414, 441],
+ [258, 442], [442, 441], [441, 258], [257, 443], [443, 442], [442, 257],
+ [259, 444], [444, 443], [443, 259], [260, 445], [445, 444], [444, 260],
+ [467, 342], [342, 445], [445, 467], [459, 458], [458, 250], [250, 459],
+ [289, 392], [392, 290], [290, 289], [290, 328], [328, 460], [460, 290],
+ [376, 433], [433, 435], [435, 376], [250, 290], [290, 392], [392, 250],
+ [411, 416], [416, 433], [433, 411], [341, 463], [463, 464], [464, 341],
+ [453, 464], [464, 465], [465, 453], [357, 465], [465, 412], [412, 357],
+ [343, 412], [412, 399], [399, 343], [360, 363], [363, 440], [440, 360],
+ [437, 399], [399, 456], [456, 437], [420, 456], [456, 363], [363, 420],
+ [401, 435], [435, 288], [288, 401], [372, 383], [383, 353], [353, 372],
+ [339, 255], [255, 249], [249, 339], [448, 261], [261, 255], [255, 448],
+ [133, 243], [243, 190], [190, 133], [133, 155], [155, 112], [112, 133],
+ [33, 246], [246, 247], [247, 33], [33, 130], [130, 25], [25, 33],
+ [398, 384], [384, 286], [286, 398], [362, 398], [398, 414], [414, 362],
+ [362, 463], [463, 341], [341, 362], [263, 359], [359, 467], [467, 263],
+ [263, 249], [249, 255], [255, 263], [466, 467], [467, 260], [260, 466],
+ [75, 60], [60, 166], [166, 75], [238, 239], [239, 79], [79, 238],
+ [162, 127], [127, 139], [139, 162], [72, 11], [11, 37], [37, 72],
+ [121, 232], [232, 120], [120, 121], [73, 72], [72, 39], [39, 73],
+ [114, 128], [128, 47], [47, 114], [233, 232], [232, 128], [128, 233],
+ [103, 104], [104, 67], [67, 103], [152, 175], [175, 148], [148, 152],
+ [119, 118], [118, 101], [101, 119], [74, 73], [73, 40], [40, 74],
+ [107, 9], [9, 108], [108, 107], [49, 48], [48, 131], [131, 49],
+ [32, 194], [194, 211], [211, 32], [184, 74], [74, 185], [185, 184],
+ [191, 80], [80, 183], [183, 191], [185, 40], [40, 186], [186, 185],
+ [119, 230], [230, 118], [118, 119], [210, 202], [202, 214], [214, 210],
+ [84, 83], [83, 17], [17, 84], [77, 76], [76, 146], [146, 77],
+ [161, 160], [160, 30], [30, 161], [190, 56], [56, 173], [173, 190],
+ [182, 106], [106, 194], [194, 182], [138, 135], [135, 192], [192, 138],
+ [129, 203], [203, 98], [98, 129], [54, 21], [21, 68], [68, 54],
+ [5, 51], [51, 4], [4, 5], [145, 144], [144, 23], [23, 145],
+ [90, 77], [77, 91], [91, 90], [207, 205], [205, 187], [187, 207],
+ [83, 201], [201, 18], [18, 83], [181, 91], [91, 182], [182, 181],
+ [180, 90], [90, 181], [181, 180], [16, 85], [85, 17], [17, 16],
+ [205, 206], [206, 36], [36, 205], [176, 148], [148, 140], [140, 176],
332
+ [165, 92], [92, 39], [39, 165], [245, 193], [193, 244], [244, 245],
333
+ [27, 159], [159, 28], [28, 27], [30, 247], [247, 161], [161, 30],
334
+ [174, 236], [236, 196], [196, 174], [103, 54], [54, 104], [104, 103],
335
+ [55, 193], [193, 8], [8, 55], [111, 117], [117, 31], [31, 111],
336
+ [221, 189], [189, 55], [55, 221], [240, 98], [98, 99], [99, 240],
337
+ [142, 126], [126, 100], [100, 142], [219, 166], [166, 218], [218, 219],
338
+ [112, 155], [155, 26], [26, 112], [198, 209], [209, 131], [131, 198],
339
+ [169, 135], [135, 150], [150, 169], [114, 47], [47, 217], [217, 114],
340
+ [224, 223], [223, 53], [53, 224], [220, 45], [45, 134], [134, 220],
341
+ [32, 211], [211, 140], [140, 32], [109, 67], [67, 108], [108, 109],
342
+ [146, 43], [43, 91], [91, 146], [231, 230], [230, 120], [120, 231],
343
+ [113, 226], [226, 247], [247, 113], [105, 63], [63, 52], [52, 105],
344
+ [241, 238], [238, 242], [242, 241], [124, 46], [46, 156], [156, 124],
345
+ [95, 78], [78, 96], [96, 95], [70, 46], [46, 63], [63, 70],
346
+ [116, 143], [143, 227], [227, 116], [116, 123], [123, 111], [111, 116],
347
+ [1, 44], [44, 19], [19, 1], [3, 236], [236, 51], [51, 3],
348
+ [207, 216], [216, 205], [205, 207], [26, 154], [154, 22], [22, 26],
349
+ [165, 39], [39, 167], [167, 165], [199, 200], [200, 208], [208, 199],
350
+ [101, 36], [36, 100], [100, 101], [43, 57], [57, 202], [202, 43],
351
+ [242, 20], [20, 99], [99, 242], [56, 28], [28, 157], [157, 56],
352
+ [124, 35], [35, 113], [113, 124], [29, 160], [160, 27], [27, 29],
353
+ [211, 204], [204, 210], [210, 211], [124, 113], [113, 46], [46, 124],
354
+ [106, 43], [43, 204], [204, 106], [96, 62], [62, 77], [77, 96],
355
+ [227, 137], [137, 116], [116, 227], [73, 41], [41, 72], [72, 73],
356
+ [36, 203], [203, 142], [142, 36], [235, 64], [64, 240], [240, 235],
357
+ [48, 49], [49, 64], [64, 48], [42, 41], [41, 74], [74, 42],
358
+ [214, 212], [212, 207], [207, 214], [183, 42], [42, 184], [184, 183],
359
+ [210, 169], [169, 211], [211, 210], [140, 170], [170, 176], [176, 140],
360
+ [104, 105], [105, 69], [69, 104], [193, 122], [122, 168], [168, 193],
361
+ [50, 123], [123, 187], [187, 50], [89, 96], [96, 90], [90, 89],
362
+ [66, 65], [65, 107], [107, 66], [179, 89], [89, 180], [180, 179],
363
+ [119, 101], [101, 120], [120, 119], [68, 63], [63, 104], [104, 68],
364
+ [234, 93], [93, 227], [227, 234], [16, 15], [15, 85], [85, 16],
365
+ [209, 129], [129, 49], [49, 209], [15, 14], [14, 86], [86, 15],
366
+ [107, 55], [55, 9], [9, 107], [120, 100], [100, 121], [121, 120],
367
+ [153, 145], [145, 22], [22, 153], [178, 88], [88, 179], [179, 178],
368
+ [197, 6], [6, 196], [196, 197], [89, 88], [88, 96], [96, 89],
369
+ [135, 138], [138, 136], [136, 135], [138, 215], [215, 172], [172, 138],
370
+ [218, 115], [115, 219], [219, 218], [41, 42], [42, 81], [81, 41],
371
+ [5, 195], [195, 51], [51, 5], [57, 43], [43, 61], [61, 57],
372
+ [208, 171], [171, 199], [199, 208], [41, 81], [81, 38], [38, 41],
373
+ [224, 53], [53, 225], [225, 224], [24, 144], [144, 110], [110, 24],
374
+ [105, 52], [52, 66], [66, 105], [118, 229], [229, 117], [117, 118],
375
+ [227, 34], [34, 234], [234, 227], [66, 107], [107, 69], [69, 66],
376
+ [10, 109], [109, 151], [151, 10], [219, 48], [48, 235], [235, 219],
377
+ [183, 62], [62, 191], [191, 183], [142, 129], [129, 126], [126, 142],
378
+ [116, 111], [111, 143], [143, 116], [118, 117], [117, 50], [50, 118],
379
+ [223, 222], [222, 52], [52, 223], [94, 19], [19, 141], [141, 94],
380
+ [222, 221], [221, 65], [65, 222], [196, 3], [3, 197], [197, 196],
381
+ [45, 220], [220, 44], [44, 45], [156, 70], [70, 139], [139, 156],
382
+ [188, 122], [122, 245], [245, 188], [139, 71], [71, 162], [162, 139],
383
+ [149, 170], [170, 150], [150, 149], [122, 188], [188, 196], [196, 122],
384
+ [206, 216], [216, 92], [92, 206], [164, 2], [2, 167], [167, 164],
385
+ [242, 141], [141, 241], [241, 242], [0, 164], [164, 37], [37, 0],
386
+ [11, 72], [72, 12], [12, 11], [12, 38], [38, 13], [13, 12],
387
+ [70, 63], [63, 71], [71, 70], [31, 226], [226, 111], [111, 31],
388
+ [36, 101], [101, 205], [205, 36], [203, 206], [206, 165], [165, 203],
389
+ [126, 209], [209, 217], [217, 126], [98, 165], [165, 97], [97, 98],
390
+ [237, 220], [220, 218], [218, 237], [237, 239], [239, 241], [241, 237],
391
+ [210, 214], [214, 169], [169, 210], [140, 171], [171, 32], [32, 140],
392
+ [241, 125], [125, 237], [237, 241], [179, 86], [86, 178], [178, 179],
393
+ [180, 85], [85, 179], [179, 180], [181, 84], [84, 180], [180, 181],
394
+ [182, 83], [83, 181], [181, 182], [194, 201], [201, 182], [182, 194],
395
+ [177, 137], [137, 132], [132, 177], [184, 76], [76, 183], [183, 184],
396
+ [185, 61], [61, 184], [184, 185], [186, 57], [57, 185], [185, 186],
397
+ [216, 212], [212, 186], [186, 216], [192, 214], [214, 187], [187, 192],
398
+ [139, 34], [34, 156], [156, 139], [218, 79], [79, 237], [237, 218],
399
+ [147, 123], [123, 177], [177, 147], [45, 44], [44, 4], [4, 45],
400
+ [208, 201], [201, 32], [32, 208], [98, 64], [64, 129], [129, 98],
401
+ [192, 213], [213, 138], [138, 192], [235, 59], [59, 219], [219, 235],
402
+ [141, 242], [242, 97], [97, 141], [97, 2], [2, 141], [141, 97],
403
+ [240, 75], [75, 235], [235, 240], [229, 24], [24, 228], [228, 229],
404
+ [31, 25], [25, 226], [226, 31], [230, 23], [23, 229], [229, 230],
405
+ [231, 22], [22, 230], [230, 231], [232, 26], [26, 231], [231, 232],
406
+ [233, 112], [112, 232], [232, 233], [244, 189], [189, 243], [243, 244],
407
+ [189, 221], [221, 190], [190, 189], [222, 28], [28, 221], [221, 222],
408
+ [223, 27], [27, 222], [222, 223], [224, 29], [29, 223], [223, 224],
409
+ [225, 30], [30, 224], [224, 225], [113, 247], [247, 225], [225, 113],
410
+ [99, 60], [60, 240], [240, 99], [213, 147], [147, 215], [215, 213],
411
+ [60, 20], [20, 166], [166, 60], [192, 187], [187, 213], [213, 192],
412
+ [243, 112], [112, 244], [244, 243], [244, 233], [233, 245], [245, 244],
413
+ [245, 128], [128, 188], [188, 245], [188, 114], [114, 174], [174, 188],
414
+ [134, 131], [131, 220], [220, 134], [174, 217], [217, 236], [236, 174],
415
+ [236, 198], [198, 134], [134, 236], [215, 177], [177, 58], [58, 215],
416
+ [156, 143], [143, 124], [124, 156], [25, 110], [110, 7], [7, 25],
417
+ [31, 228], [228, 25], [25, 31], [264, 356], [356, 368], [368, 264],
418
+ [0, 11], [11, 267], [267, 0], [451, 452], [452, 349], [349, 451],
419
+ [267, 302], [302, 269], [269, 267], [350, 357], [357, 277], [277, 350],
420
+ [350, 452], [452, 357], [357, 350], [299, 333], [333, 297], [297, 299],
421
+ [396, 175], [175, 377], [377, 396], [280, 347], [347, 330], [330, 280],
422
+ [269, 303], [303, 270], [270, 269], [151, 9], [9, 337], [337, 151],
423
+ [344, 278], [278, 360], [360, 344], [424, 418], [418, 431], [431, 424],
424
+ [270, 304], [304, 409], [409, 270], [272, 310], [310, 407], [407, 272],
425
+ [322, 270], [270, 410], [410, 322], [449, 450], [450, 347], [347, 449],
426
+ [432, 422], [422, 434], [434, 432], [18, 313], [313, 17], [17, 18],
427
+ [291, 306], [306, 375], [375, 291], [259, 387], [387, 260], [260, 259],
428
+ [424, 335], [335, 418], [418, 424], [434, 364], [364, 416], [416, 434],
429
+ [391, 423], [423, 327], [327, 391], [301, 251], [251, 298], [298, 301],
430
+ [275, 281], [281, 4], [4, 275], [254, 373], [373, 253], [253, 254],
431
+ [375, 307], [307, 321], [321, 375], [280, 425], [425, 411], [411, 280],
432
+ [200, 421], [421, 18], [18, 200], [335, 321], [321, 406], [406, 335],
433
+ [321, 320], [320, 405], [405, 321], [314, 315], [315, 17], [17, 314],
434
+ [423, 426], [426, 266], [266, 423], [396, 377], [377, 369], [369, 396],
435
+ [270, 322], [322, 269], [269, 270], [413, 417], [417, 464], [464, 413],
436
+ [385, 386], [386, 258], [258, 385], [248, 456], [456, 419], [419, 248],
437
+ [298, 284], [284, 333], [333, 298], [168, 417], [417, 8], [8, 168],
438
+ [448, 346], [346, 261], [261, 448], [417, 413], [413, 285], [285, 417],
439
+ [326, 327], [327, 328], [328, 326], [277, 355], [355, 329], [329, 277],
440
+ [309, 392], [392, 438], [438, 309], [381, 382], [382, 256], [256, 381],
441
+ [279, 429], [429, 360], [360, 279], [365, 364], [364, 379], [379, 365],
442
+ [355, 277], [277, 437], [437, 355], [282, 443], [443, 283], [283, 282],
443
+ [281, 275], [275, 363], [363, 281], [395, 431], [431, 369], [369, 395],
444
+ [299, 297], [297, 337], [337, 299], [335, 273], [273, 321], [321, 335],
445
+ [348, 450], [450, 349], [349, 348], [359, 446], [446, 467], [467, 359],
446
+ [283, 293], [293, 282], [282, 283], [250, 458], [458, 462], [462, 250],
447
+ [300, 276], [276, 383], [383, 300], [292, 308], [308, 325], [325, 292],
448
+ [283, 276], [276, 293], [293, 283], [264, 372], [372, 447], [447, 264],
449
+ [346, 352], [352, 340], [340, 346], [354, 274], [274, 19], [19, 354],
450
+ [363, 456], [456, 281], [281, 363], [426, 436], [436, 425], [425, 426],
451
+ [380, 381], [381, 252], [252, 380], [267, 269], [269, 393], [393, 267],
452
+ [421, 200], [200, 428], [428, 421], [371, 266], [266, 329], [329, 371],
453
+ [432, 287], [287, 422], [422, 432], [290, 250], [250, 328], [328, 290],
454
+ [385, 258], [258, 384], [384, 385], [446, 265], [265, 342], [342, 446],
455
+ [386, 387], [387, 257], [257, 386], [422, 424], [424, 430], [430, 422],
456
+ [445, 342], [342, 276], [276, 445], [422, 273], [273, 424], [424, 422],
457
+ [306, 292], [292, 307], [307, 306], [352, 366], [366, 345], [345, 352],
458
+ [268, 271], [271, 302], [302, 268], [358, 423], [423, 371], [371, 358],
459
+ [327, 294], [294, 460], [460, 327], [331, 279], [279, 294], [294, 331],
460
+ [303, 271], [271, 304], [304, 303], [436, 432], [432, 427], [427, 436],
461
+ [304, 272], [272, 408], [408, 304], [395, 394], [394, 431], [431, 395],
462
+ [378, 395], [395, 400], [400, 378], [296, 334], [334, 299], [299, 296],
463
+ [6, 351], [351, 168], [168, 6], [376, 352], [352, 411], [411, 376],
464
+ [307, 325], [325, 320], [320, 307], [285, 295], [295, 336], [336, 285],
465
+ [320, 319], [319, 404], [404, 320], [329, 330], [330, 349], [349, 329],
466
+ [334, 293], [293, 333], [333, 334], [366, 323], [323, 447], [447, 366],
467
+ [316, 15], [15, 315], [315, 316], [331, 358], [358, 279], [279, 331],
468
+ [317, 14], [14, 316], [316, 317], [8, 285], [285, 9], [9, 8],
469
+ [277, 329], [329, 350], [350, 277], [253, 374], [374, 252], [252, 253],
470
+ [319, 318], [318, 403], [403, 319], [351, 6], [6, 419], [419, 351],
471
+ [324, 318], [318, 325], [325, 324], [397, 367], [367, 365], [365, 397],
472
+ [288, 435], [435, 397], [397, 288], [278, 344], [344, 439], [439, 278],
473
+ [310, 272], [272, 311], [311, 310], [248, 195], [195, 281], [281, 248],
474
+ [375, 273], [273, 291], [291, 375], [175, 396], [396, 199], [199, 175],
475
+ [312, 311], [311, 268], [268, 312], [276, 283], [283, 445], [445, 276],
476
+ [390, 373], [373, 339], [339, 390], [295, 282], [282, 296], [296, 295],
477
+ [448, 449], [449, 346], [346, 448], [356, 264], [264, 454], [454, 356],
478
+ [337, 336], [336, 299], [299, 337], [337, 338], [338, 151], [151, 337],
479
+ [294, 278], [278, 455], [455, 294], [308, 292], [292, 415], [415, 308],
480
+ [429, 358], [358, 355], [355, 429], [265, 340], [340, 372], [372, 265],
481
+ [352, 346], [346, 280], [280, 352], [295, 442], [442, 282], [282, 295],
482
+ [354, 19], [19, 370], [370, 354], [285, 441], [441, 295], [295, 285],
483
+ [195, 248], [248, 197], [197, 195], [457, 440], [440, 274], [274, 457],
484
+ [301, 300], [300, 368], [368, 301], [417, 351], [351, 465], [465, 417],
485
+ [251, 301], [301, 389], [389, 251], [394, 395], [395, 379], [379, 394],
486
+ [399, 412], [412, 419], [419, 399], [410, 436], [436, 322], [322, 410],
487
+ [326, 2], [2, 393], [393, 326], [354, 370], [370, 461], [461, 354],
488
+ [393, 164], [164, 267], [267, 393], [268, 302], [302, 12], [12, 268],
489
+ [312, 268], [268, 13], [13, 312], [298, 293], [293, 301], [301, 298],
490
+ [265, 446], [446, 340], [340, 265], [280, 330], [330, 425], [425, 280],
491
+ [322, 426], [426, 391], [391, 322], [420, 429], [429, 437], [437, 420],
492
+ [393, 391], [391, 326], [326, 393], [344, 440], [440, 438], [438, 344],
493
+ [458, 459], [459, 461], [461, 458], [364, 434], [434, 394], [394, 364],
494
+ [428, 396], [396, 262], [262, 428], [274, 354], [354, 457], [457, 274],
495
+ [317, 316], [316, 402], [402, 317], [316, 315], [315, 403], [403, 316],
496
+ [315, 314], [314, 404], [404, 315], [314, 313], [313, 405], [405, 314],
497
+ [313, 421], [421, 406], [406, 313], [323, 366], [366, 361], [361, 323],
498
+ [292, 306], [306, 407], [407, 292], [306, 291], [291, 408], [408, 306],
499
+ [291, 287], [287, 409], [409, 291], [287, 432], [432, 410], [410, 287],
500
+ [427, 434], [434, 411], [411, 427], [372, 264], [264, 383], [383, 372],
501
+ [459, 309], [309, 457], [457, 459], [366, 352], [352, 401], [401, 366],
502
+ [1, 274], [274, 4], [4, 1], [418, 421], [421, 262], [262, 418],
503
+ [331, 294], [294, 358], [358, 331], [435, 433], [433, 367], [367, 435],
504
+ [392, 289], [289, 439], [439, 392], [328, 462], [462, 326], [326, 328],
505
+ [94, 2], [2, 370], [370, 94], [289, 305], [305, 455], [455, 289],
506
+ [339, 254], [254, 448], [448, 339], [359, 255], [255, 446], [446, 359],
507
+ [254, 253], [253, 449], [449, 254], [253, 252], [252, 450], [450, 253],
508
+ [252, 256], [256, 451], [451, 252], [256, 341], [341, 452], [452, 256],
509
+ [414, 413], [413, 463], [463, 414], [286, 441], [441, 414], [414, 286],
510
+ [286, 258], [258, 441], [441, 286], [258, 257], [257, 442], [442, 258],
511
+ [257, 259], [259, 443], [443, 257], [259, 260], [260, 444], [444, 259],
512
+ [260, 467], [467, 445], [445, 260], [309, 459], [459, 250], [250, 309],
513
+ [305, 289], [289, 290], [290, 305], [305, 290], [290, 460], [460, 305],
514
+ [401, 376], [376, 435], [435, 401], [309, 250], [250, 392], [392, 309],
515
+ [376, 411], [411, 433], [433, 376], [453, 341], [341, 464], [464, 453],
516
+ [357, 453], [453, 465], [465, 357], [343, 357], [357, 412], [412, 343],
517
+ [437, 343], [343, 399], [399, 437], [344, 360], [360, 440], [440, 344],
518
+ [420, 437], [437, 456], [456, 420], [360, 420], [420, 363], [363, 360],
519
+ [361, 401], [401, 288], [288, 361], [265, 372], [372, 353], [353, 265],
520
+ [390, 339], [339, 249], [249, 390], [339, 448], [448, 255], [255, 339]]))
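The arrays above are edge lists for the face-mesh triangulation: every [start, end] pair indexes into the landmark array returned by MediaPipe. A minimal TypeScript sketch (illustrative, not part of this commit) of how such a list can be reduced to the unique landmark indices of a group, which is the first step for anything that needs a group's points:

// Illustrative helper: collect the unique landmark indices referenced by a
// connection group such as FACEMESH_FACE_OVAL or FACEMESH_LIPS.
function indicesOfGroup(connections: Iterable<[number, number]>): Set<number> {
  const indices = new Set<number>();
  for (const [start, end] of connections) {
    indices.add(start);
    indices.add(end);
  }
  return indices;
}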
client/src/hooks/useFaceLandmarkDetection.tsx ADDED
@@ -0,0 +1,632 @@
1
+ import React, { useState, useEffect, useRef, useCallback } from 'react';
2
+ import * as vision from '@mediapipe/tasks-vision';
3
+
4
+ import { facePoke } from '@/lib/facePoke';
5
+ import { useMainStore } from './useMainStore';
6
+ import useThrottledCallback from 'beautiful-react-hooks/useThrottledCallback';
7
+
8
+ import { landmarkGroups, FACEMESH_LIPS, FACEMESH_LEFT_EYE, FACEMESH_LEFT_EYEBROW, FACEMESH_RIGHT_EYE, FACEMESH_RIGHT_EYEBROW, FACEMESH_FACE_OVAL } from './landmarks';
9
+
10
+ // New types for improved type safety
11
+ export type LandmarkGroup = 'lips' | 'leftEye' | 'leftEyebrow' | 'rightEye' | 'rightEyebrow' | 'faceOval' | 'background';
12
+ export type LandmarkCenter = { x: number; y: number; z: number };
13
+ export type ClosestLandmark = { group: LandmarkGroup; distance: number; vector: { x: number; y: number; z: number } };
14
+
15
+ export type MediaPipeResources = {
16
+ faceLandmarker: vision.FaceLandmarker | null;
17
+ drawingUtils: vision.DrawingUtils | null;
18
+ };
19
+
20
+ export function useFaceLandmarkDetection() {
21
+ const error = useMainStore(s => s.error);
22
+ const setError = useMainStore(s => s.setError);
23
+ const imageFile = useMainStore(s => s.imageFile);
24
+ const setImageFile = useMainStore(s => s.setImageFile);
25
+ const originalImage = useMainStore(s => s.originalImage);
26
+ const originalImageHash = useMainStore(s => s.originalImageHash);
27
+ const setOriginalImageHash = useMainStore(s => s.setOriginalImageHash);
28
+ const previewImage = useMainStore(s => s.previewImage);
29
+ const setPreviewImage = useMainStore(s => s.setPreviewImage);
30
+ const resetImage = useMainStore(s => s.resetImage);
31
+
32
+ ;(window as any).debugJuju = useMainStore;
33
+ ////////////////////////////////////////////////////////////////////////
34
+ // note: the latency cannot be varied for now; doing so triggers a bug
35
+ // const averageLatency = useMainStore(s => s.averageLatency);
36
+ const averageLatency = 220
37
+ ////////////////////////////////////////////////////////////////////////
38
+
39
+ // State for face detection
40
+ const [faceLandmarks, setFaceLandmarks] = useState<vision.NormalizedLandmark[][]>([]);
41
+ const [isMediaPipeReady, setIsMediaPipeReady] = useState(false);
42
+ const [isDrawingUtilsReady, setIsDrawingUtilsReady] = useState(false);
43
+ const [blendShapes, setBlendShapes] = useState<vision.Classifications[]>([]);
44
+
45
+ // State for mouse interaction
46
+ const [dragStart, setDragStart] = useState<{ x: number; y: number } | null>(null);
47
+ const [dragEnd, setDragEnd] = useState<{ x: number; y: number } | null>(null);
48
+
49
+ const [isDragging, setIsDragging] = useState(false);
50
+ const [isWaitingForResponse, setIsWaitingForResponse] = useState(false);
51
+ const dragStartRef = useRef<{ x: number; y: number } | null>(null);
52
+ const currentMousePosRef = useRef<{ x: number; y: number } | null>(null);
53
+ const lastModifiedImageHashRef = useRef<string | null>(null);
54
+
55
+ const [currentLandmark, setCurrentLandmark] = useState<ClosestLandmark | null>(null);
56
+ const [previousLandmark, setPreviousLandmark] = useState<ClosestLandmark | null>(null);
57
+ const [currentOpacity, setCurrentOpacity] = useState(0);
58
+ const [previousOpacity, setPreviousOpacity] = useState(0);
59
+
60
+ const [isHovering, setIsHovering] = useState(false);
61
+
62
+ // Refs
63
+ const canvasRef = useRef<HTMLCanvasElement>(null);
64
+ const mediaPipeRef = useRef<MediaPipeResources>({
65
+ faceLandmarker: null,
66
+ drawingUtils: null,
67
+ });
68
+
69
+ const setActiveLandmark = useCallback((newLandmark: ClosestLandmark | undefined) => {
70
+ //if (newLandmark && (!currentLandmark || newLandmark.group !== currentLandmark.group)) {
71
+ setPreviousLandmark(currentLandmark || null);
72
+ setCurrentLandmark(newLandmark || null);
73
+ setCurrentOpacity(0);
74
+ setPreviousOpacity(1);
75
+ //}
76
+ }, [currentLandmark, setPreviousLandmark, setCurrentLandmark, setCurrentOpacity, setPreviousOpacity]);
77
+
78
+ // Initialize MediaPipe
79
+ useEffect(() => {
80
+ console.log('Initializing MediaPipe...');
81
+ let isMounted = true;
82
+
83
+ const initializeMediaPipe = async () => {
84
+ const { FaceLandmarker, FilesetResolver, DrawingUtils } = vision;
85
+
86
+ try {
87
+ console.log('Initializing FilesetResolver...');
88
+ const filesetResolver = await FilesetResolver.forVisionTasks(
89
+ "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@0.10.3/wasm"
90
+ );
91
+
92
+ console.log('Creating FaceLandmarker...');
93
+ const faceLandmarker = await FaceLandmarker.createFromOptions(filesetResolver, {
94
+ baseOptions: {
95
+ modelAssetPath: `https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task`,
96
+ delegate: "GPU"
97
+ },
98
+ outputFaceBlendshapes: true,
99
+ runningMode: "IMAGE",
100
+ numFaces: 1
101
+ });
102
+
103
+ if (isMounted) {
104
+ console.log('FaceLandmarker created successfully.');
105
+ mediaPipeRef.current.faceLandmarker = faceLandmarker;
106
+ setIsMediaPipeReady(true);
107
+ } else {
108
+ faceLandmarker.close();
109
+ }
110
+ } catch (error) {
111
+ console.error('Error during MediaPipe initialization:', error);
112
+ setError('Failed to initialize face detection. Please try refreshing the page.');
113
+ }
114
+ };
115
+
116
+ initializeMediaPipe();
117
+
118
+
119
+ return () => {
120
+ isMounted = false;
121
+ if (mediaPipeRef.current.faceLandmarker) {
122
+ mediaPipeRef.current.faceLandmarker.close();
123
+ }
124
+ };
125
+ }, []);
126
+
127
+ // New state for storing landmark centers
128
+ const [landmarkCenters, setLandmarkCenters] = useState<Record<LandmarkGroup, LandmarkCenter>>({} as Record<LandmarkGroup, LandmarkCenter>);
129
+
130
+ // Function to compute the center of each landmark group
131
+ const computeLandmarkCenters = useCallback((landmarks: vision.NormalizedLandmark[]) => {
132
+ const centers: Record<LandmarkGroup, LandmarkCenter> = {} as Record<LandmarkGroup, LandmarkCenter>;
133
+
134
+ const computeGroupCenter = (group: Readonly<Set<number[]>>): LandmarkCenter => {
135
+ let sumX = 0, sumY = 0, sumZ = 0, count = 0;
136
+ group.forEach(([index]) => {
137
+ if (landmarks[index]) {
138
+ sumX += landmarks[index].x;
139
+ sumY += landmarks[index].y;
140
+ sumZ += landmarks[index].z || 0;
141
+ count++;
142
+ }
143
+ });
144
+ return { x: sumX / count, y: sumY / count, z: sumZ / count };
145
+ };
146
+
147
+ centers.lips = computeGroupCenter(FACEMESH_LIPS);
148
+ centers.leftEye = computeGroupCenter(FACEMESH_LEFT_EYE);
149
+ centers.leftEyebrow = computeGroupCenter(FACEMESH_LEFT_EYEBROW);
150
+ centers.rightEye = computeGroupCenter(FACEMESH_RIGHT_EYE);
151
+ centers.rightEyebrow = computeGroupCenter(FACEMESH_RIGHT_EYEBROW);
152
+ centers.faceOval = computeGroupCenter(FACEMESH_FACE_OVAL);
153
+ centers.background = { x: 0.5, y: 0.5, z: 0 };
154
+
155
+ setLandmarkCenters(centers);
156
+ // console.log('Landmark centers computed:', centers);
157
+ }, []);
158
+
159
+ // Function to find the closest landmark to the mouse position
160
+ const findClosestLandmark = useCallback((mouseX: number, mouseY: number, isGroup?: LandmarkGroup): ClosestLandmark => {
161
+ const defaultLandmark: ClosestLandmark = {
162
+ group: 'background',
163
+ distance: 0,
164
+ vector: {
165
+ x: mouseX,
166
+ y: mouseY,
167
+ z: 0
168
+ }
169
+ }
170
+
171
+ if (Object.keys(landmarkCenters).length === 0) {
172
+ console.warn('Landmark centers not computed yet');
173
+ return defaultLandmark;
174
+ }
175
+
176
+ let closestGroup: LandmarkGroup | null = null;
177
+ let minDistance = Infinity;
178
+ let closestVector = { x: 0, y: 0, z: 0 };
179
+ let faceOvalDistance = Infinity;
180
+ let faceOvalVector = { x: 0, y: 0, z: 0 };
181
+
182
+ Object.entries(landmarkCenters).forEach(([group, center]) => {
183
+ const dx = mouseX - center.x;
184
+ const dy = mouseY - center.y;
185
+ const distance = Math.sqrt(dx * dx + dy * dy);
186
+
187
+ if (group === 'faceOval') {
188
+ faceOvalDistance = distance;
189
+ faceOvalVector = { x: dx, y: dy, z: 0 };
190
+ }
191
+
192
+ // when a group filter is provided, only consider that group
193
+ if (isGroup) {
194
+ if (group !== isGroup) {
195
+ return
196
+ }
197
+ }
198
+
199
+ if (distance < minDistance) {
200
+ minDistance = distance;
201
+ closestGroup = group as LandmarkGroup;
202
+ closestVector = { x: dx, y: dy, z: 0 }; // Z is 0 as mouse interaction is 2D
203
+ }
204
+ });
205
+
206
+ // Fall back to the background group (reusing the faceOval distance/vector) when nothing is close enough
207
+ if (minDistance > 0.05) {
208
+ // console.log('Distance is too high, so we use the faceOval group');
209
+ closestGroup = 'background';
210
+ minDistance = faceOvalDistance;
211
+ closestVector = faceOvalVector;
212
+ }
213
+
214
+ if (closestGroup) {
215
+ // console.log(`Closest landmark: ${closestGroup}, distance: ${minDistance.toFixed(4)}`);
216
+ return { group: closestGroup, distance: minDistance, vector: closestVector };
217
+ } else {
218
+ // console.log('No group found, returning fallback');
219
+ return defaultLandmark
220
+ }
221
+ }, [landmarkCenters]);
222
+
223
+ // Detect face landmarks
224
+ const detectFaceLandmarks = useCallback(async (imageDataUrl: string) => {
225
+ // console.log('Attempting to detect face landmarks...');
226
+ if (!isMediaPipeReady) {
227
+ console.log('MediaPipe not ready. Skipping detection.');
228
+ return;
229
+ }
230
+
231
+ const faceLandmarker = mediaPipeRef.current.faceLandmarker;
232
+
233
+ if (!faceLandmarker) {
234
+ console.error('FaceLandmarker is not initialized.');
235
+ return;
236
+ }
237
+
238
+ const drawingUtils = mediaPipeRef.current.drawingUtils;
239
+
240
+ const image = new Image();
241
+ image.src = imageDataUrl;
242
+ await new Promise((resolve) => { image.onload = resolve; });
243
+
244
+ const faceLandmarkerResult = faceLandmarker.detect(image);
245
+ // console.log("Face landmarks detected:", faceLandmarkerResult);
246
+
247
+ setFaceLandmarks(faceLandmarkerResult.faceLandmarks);
248
+ setBlendShapes(faceLandmarkerResult.faceBlendshapes || []);
249
+
250
+ if (faceLandmarkerResult.faceLandmarks && faceLandmarkerResult.faceLandmarks[0]) {
251
+ computeLandmarkCenters(faceLandmarkerResult.faceLandmarks[0]);
252
+ }
253
+
254
+ if (canvasRef.current && drawingUtils) {
255
+ drawLandmarks(faceLandmarkerResult.faceLandmarks[0], canvasRef.current, drawingUtils);
256
+ }
257
+ }, [isMediaPipeReady, isDrawingUtilsReady, computeLandmarkCenters]);
258
+
259
+ const drawLandmarks = useCallback((
260
+ landmarks: vision.NormalizedLandmark[],
261
+ canvas: HTMLCanvasElement,
262
+ drawingUtils: vision.DrawingUtils
263
+ ) => {
264
+ const ctx = canvas.getContext('2d');
265
+ if (!ctx) return;
266
+
267
+ ctx.clearRect(0, 0, canvas.width, canvas.height);
268
+
269
+ if (canvasRef.current && previewImage) {
270
+ const img = new Image();
271
+ img.onload = () => {
272
+ canvas.width = img.width;
273
+ canvas.height = img.height;
274
+
275
+ const drawLandmarkGroup = (landmark: ClosestLandmark | null, opacity: number) => {
276
+ if (!landmark) return;
277
+ const connections = landmarkGroups[landmark.group];
278
+ if (connections) {
279
+ ctx.globalAlpha = opacity;
280
+ drawingUtils.drawConnectors(
281
+ landmarks,
282
+ connections,
283
+ { color: 'orange', lineWidth: 4 }
284
+ );
285
+ }
286
+ };
287
+
288
+ drawLandmarkGroup(previousLandmark, previousOpacity);
289
+ drawLandmarkGroup(currentLandmark, currentOpacity);
290
+
291
+ ctx.globalAlpha = 1;
292
+ };
293
+ img.src = previewImage;
294
+ }
295
+ }, [previewImage, currentLandmark, previousLandmark, currentOpacity, previousOpacity]);
296
+
297
+ useEffect(() => {
298
+ if (isMediaPipeReady && isDrawingUtilsReady && faceLandmarks.length > 0 && canvasRef.current && mediaPipeRef.current.drawingUtils) {
299
+ drawLandmarks(faceLandmarks[0], canvasRef.current, mediaPipeRef.current.drawingUtils);
300
+ }
301
+ }, [isMediaPipeReady, isDrawingUtilsReady, faceLandmarks, currentLandmark, previousLandmark, currentOpacity, previousOpacity, drawLandmarks]);
302
+ useEffect(() => {
303
+ let animationFrame: number;
304
+ const animate = () => {
305
+ setCurrentOpacity((prev) => Math.min(prev + 0.2, 1));
306
+ setPreviousOpacity((prev) => Math.max(prev - 0.2, 0));
307
+
308
+ if (currentOpacity < 1 || previousOpacity > 0) {
309
+ animationFrame = requestAnimationFrame(animate);
310
+ }
311
+ };
312
+ animationFrame = requestAnimationFrame(animate);
313
+ return () => cancelAnimationFrame(animationFrame);
314
+ }, [currentLandmark]);
315
+
316
+ // Canvas ref callback
317
+ const canvasRefCallback = useCallback((node: HTMLCanvasElement | null) => {
318
+ if (node !== null) {
319
+ const ctx = node.getContext('2d');
320
+ if (ctx) {
321
+ // Get device pixel ratio
322
+ const pixelRatio = window.devicePixelRatio || 1;
323
+
324
+ // Scale canvas based on the pixel ratio
325
+ node.width = node.clientWidth * pixelRatio;
326
+ node.height = node.clientHeight * pixelRatio;
327
+ ctx.scale(pixelRatio, pixelRatio);
328
+
329
+ mediaPipeRef.current.drawingUtils = new vision.DrawingUtils(ctx);
330
+ setIsDrawingUtilsReady(true);
331
+ } else {
332
+ console.error('Failed to get 2D context from canvas.');
333
+ }
334
+ canvasRef.current = node;
335
+ }
336
+ }, []);
337
+
338
+
339
+ useEffect(() => {
340
+ if (!isMediaPipeReady) {
341
+ console.log('MediaPipe not ready. Skipping landmark detection.');
342
+ return
343
+ }
344
+ if (!previewImage) {
345
+ console.log('Preview image not ready. Skipping landmark detection.');
346
+ return
347
+ }
348
+ if (!isDrawingUtilsReady) {
349
+ console.log('DrawingUtils not ready. Skipping landmark detection.');
350
+ return
351
+ }
352
+ detectFaceLandmarks(previewImage);
353
+ }, [isMediaPipeReady, isDrawingUtilsReady, previewImage])
354
+
355
+
356
+
357
+ const modifyImage = useCallback(({ landmark, vector }: {
358
+ landmark: ClosestLandmark
359
+ vector: { x: number; y: number; z: number }
360
+ }) => {
361
+
362
+ const {
363
+ originalImage,
364
+ originalImageHash,
365
+ params: previousParams,
366
+ setParams,
367
+ setError
368
+ } = useMainStore.getState()
369
+
370
+
371
+ if (!originalImage) {
372
+ console.error('Image file or facePoke not available');
373
+ return;
374
+ }
375
+
376
+ const params = {
377
+ ...previousParams
378
+ }
379
+
380
+ const minX = -0.50;
381
+ const maxX = 0.50;
382
+ const minY = -0.50;
383
+ const maxY = 0.50;
384
+
385
+ // Function to map a value from one range to another
386
+ const mapRange = (value: number, inMin: number, inMax: number, outMin: number, outMax: number): number => {
387
+ return Math.min(outMax, Math.max(outMin, ((value - inMin) * (outMax - outMin)) / (inMax - inMin) + outMin));
388
+ };
389
+
390
+ console.log("modifyImage:", {
391
+ originalImage,
392
+ originalImageHash,
393
+ landmark,
394
+ vector,
395
+ minX,
396
+ maxX,
397
+ minY,
398
+ maxY,
399
+ })
400
+
401
+ // Map landmarks to ImageModificationParams
402
+ switch (landmark.group) {
403
+ case 'leftEye':
404
+ case 'rightEye':
405
+ // eyes (min: -20, max: 5, default: 0)
406
+ const eyesMin = -20
407
+ const eyesMax = 5
408
+ params.eyes = mapRange(vector.x, minX, maxX, eyesMin, eyesMax);
409
+
410
+ break;
411
+ case 'leftEyebrow':
412
+ case 'rightEyebrow':
413
+ // moving the mouse vertically for the eyebrow
414
+ // should make them up/down
415
+ // eyebrow (min: -10, max: 15, default: 0)
416
+ const eyebrowMin = -10
417
+ const eyebrowMax = 15
418
+ params.eyebrow = mapRange(vector.y, minY, maxY, eyebrowMin, eyebrowMax);
419
+
420
+ break;
421
+ case 'lips':
422
+ // aaa (min: -30, max: 120, default: 0)
423
+ //const aaaMin = -30
424
+ //const aaaMax = 120
425
+ //params.aaa = mapRange(vector.x, minY, maxY, aaaMin, aaaMax);
426
+
427
+ // eee (min: -20, max: 15, default: 0)
428
+ const eeeMin = -20
429
+ const eeeMax = 15
430
+ params.eee = mapRange(vector.y, minY, maxY, eeeMin, eeeMax);
431
+
432
+
433
+ // woo (min: -20, max: 15, default: 0)
434
+ const wooMin = -20
435
+ const wooMax = 15
436
+ params.woo = mapRange(vector.x, minX, maxX, wooMin, wooMax);
437
+
438
+ break;
439
+ case 'faceOval':
440
+ // displacing the face horizontally by moving the mouse on the X axis
441
+ // should perform a roll rotation
442
+ // rotate_roll (min: -20, max: 20, default: 0)
443
+ const rollMin = -40
444
+ const rollMax = 40
445
+
446
+ // note: unlike the yaw mapping below, the axis is not inverted here
447
+ params.rotate_roll = mapRange(vector.x, minX, maxX, rollMin, rollMax);
448
+ break;
449
+
450
+ case 'background':
451
+ // displacing the face horizontally by moving the mouse on the X axis
452
+ // should perform a yaw rotation
453
+ // rotate_yaw (min: -20, max: 20, default: 0)
454
+ const yawMin = -40
455
+ const yawMax = 40
456
+
457
+ // note: we invert the axis here
458
+ params.rotate_yaw = mapRange(-vector.x, minX, maxX, yawMin, yawMax);
459
+
460
+ // displacing the face vertically by moving the mouse on the Y axis
461
+ // should perform a pitch rotation
462
+ // rotate_pitch (min: -20, max: 20, default: 0)
463
+ const pitchMin = -40
464
+ const pitchMax = 40
465
+ params.rotate_pitch = mapRange(vector.y, minY, maxY, pitchMin, pitchMax);
466
+ break;
467
+ default:
468
+ return
469
+ }
470
+
471
+ for (const [key, value] of Object.entries(params)) {
472
+ if (isNaN(value as any) || !isFinite(value as any)) {
473
+ console.log(`${key} is NaN, aborting`)
474
+ return
475
+ }
476
+ }
477
+ console.log(`PITCH=${params.rotate_pitch || 0}, YAW=${params.rotate_yaw || 0}, ROLL=${params.rotate_roll || 0}`);
478
+
479
+ setParams(params)
480
+ try {
481
+ // For the first request or when the image file changes, send the full image
482
+ if (!lastModifiedImageHashRef.current || lastModifiedImageHashRef.current !== originalImageHash) {
483
+ lastModifiedImageHashRef.current = originalImageHash;
484
+ facePoke.modifyImage(originalImage, null, params);
485
+ } else {
486
+ // For subsequent requests, send only the hash
487
+ facePoke.modifyImage(null, lastModifiedImageHashRef.current, params);
488
+ }
489
+ } catch (error) {
490
+ // console.error('Error modifying image:', error);
491
+ setError('Failed to modify image');
492
+ }
493
+ }, []);
494
+
495
+ // this is throttled by our average latency
496
+ const modifyImageWithRateLimit = useThrottledCallback((params: {
497
+ landmark: ClosestLandmark
498
+ vector: { x: number; y: number; z: number }
499
+ }) => {
500
+ modifyImage(params);
501
+ }, [modifyImage], averageLatency);
502
+
503
+ const handleMouseEnter = useCallback(() => {
504
+ setIsHovering(true);
505
+ }, []);
506
+
507
+ const handleMouseLeave = useCallback(() => {
508
+ setIsHovering(false);
509
+ }, []);
510
+
511
+ // Update mouse event handlers
512
+ const handleMouseDown = useCallback((event: React.MouseEvent<HTMLCanvasElement>) => {
513
+ if (!canvasRef.current) return;
514
+
515
+ const rect = canvasRef.current.getBoundingClientRect();
516
+ const x = (event.clientX - rect.left) / rect.width;
517
+ const y = (event.clientY - rect.top) / rect.height;
518
+
519
+ const landmark = findClosestLandmark(x, y);
520
+ console.log(`Mouse down on ${landmark.group}`);
521
+ setActiveLandmark(landmark);
522
+ setDragStart({ x, y });
523
+ dragStartRef.current = { x, y };
524
+ }, [findClosestLandmark, setActiveLandmark, setDragStart]);
525
+
526
+ const handleMouseMove = useCallback((event: React.MouseEvent<HTMLCanvasElement>) => {
527
+ if (!canvasRef.current) return;
528
+
529
+ const rect = canvasRef.current.getBoundingClientRect();
530
+ const x = (event.clientX - rect.left) / rect.width;
531
+ const y = (event.clientY - rect.top) / rect.height;
532
+
533
+ // only send an API request to modify the image if we are actively dragging
534
+ if (dragStart && dragStartRef.current) {
535
+
536
+ const landmark = findClosestLandmark(x, y, currentLandmark?.group);
537
+
538
+ console.log(`Dragging mouse (was over ${currentLandmark?.group || 'nothing'}, now over ${landmark.group})`);
539
+
540
+ // Compute the vector from the landmark center to the current mouse position
541
+ modifyImageWithRateLimit({
542
+ landmark: currentLandmark || landmark, // this will still use the initially selected landmark
543
+ vector: {
544
+ x: x - landmarkCenters[landmark.group].x,
545
+ y: y - landmarkCenters[landmark.group].y,
546
+ z: 0 // Z is 0 as mouse interaction is 2D
547
+ }
548
+ });
549
+ setIsDragging(true);
550
+ } else {
551
+ const landmark = findClosestLandmark(x, y);
552
+
553
+ //console.log(`Moving mouse over ${landmark.group}`);
554
+ // console.log(`Simple mouse move over ${landmark.group}`);
555
+
556
+ // we need to be careful here, we don't want to change the active
557
+ // landmark dynamically if we are busy dragging
558
+
559
+ if (!currentLandmark || (currentLandmark?.group !== landmark?.group)) {
560
+ // console.log("setting activeLandmark to ", landmark);
561
+ setActiveLandmark(landmark);
562
+ }
563
+ setIsHovering(true); // Ensure hovering state is maintained during movement
564
+ }
565
+ }, [currentLandmark, dragStart, setIsHovering, setActiveLandmark, setIsDragging, modifyImageWithRateLimit, landmarkCenters]);
566
+
567
+ const handleMouseUp = useCallback((event: React.MouseEvent<HTMLCanvasElement>) => {
568
+ if (!canvasRef.current) return;
569
+
570
+ const rect = canvasRef.current.getBoundingClientRect();
571
+ const x = (event.clientX - rect.left) / rect.width;
572
+ const y = (event.clientY - rect.top) / rect.height;
573
+
574
+ // only send an API request to modify the image if we are actively dragging
575
+ if (dragStart && dragStartRef.current) {
576
+
577
+ const landmark = findClosestLandmark(x, y, currentLandmark?.group);
578
+
579
+ console.log(`Mouse up (was over ${currentLandmark?.group || 'nothing'}, now over ${landmark.group})`);
580
+
581
+ // Compute the vector from the landmark center to the current mouse position
582
+ modifyImageWithRateLimit({
583
+ landmark: currentLandmark || landmark, // this will still use the initially selected landmark
584
+ vector: {
585
+ x: x - landmarkCenters[landmark.group].x,
586
+ y: y - landmarkCenters[landmark.group].y,
587
+ z: 0 // Z is 0 as mouse interaction is 2D
588
+ }
589
+ });
590
+ }
591
+
592
+ setIsDragging(false);
593
+ dragStartRef.current = null;
594
+ setActiveLandmark(undefined);
595
+ }, [currentLandmark, isDragging, modifyImageWithRateLimit, findClosestLandmark, setActiveLandmark, landmarkCenters, setIsDragging]);
596
+
597
+ useEffect(() => {
598
+ facePoke.setOnModifiedImage((image: string, image_hash: string) => {
599
+ if (image) {
600
+ setPreviewImage(image);
601
+ }
602
+ setOriginalImageHash(image_hash);
603
+ lastModifiedImageHashRef.current = image_hash;
604
+ });
605
+ }, [setPreviewImage, setOriginalImageHash]);
606
+
607
+ return {
608
+ canvasRef,
609
+ canvasRefCallback,
610
+ mediaPipeRef,
611
+ faceLandmarks,
612
+ isMediaPipeReady,
613
+ isDrawingUtilsReady,
614
+ blendShapes,
615
+
616
+ //dragStart,
617
+ //setDragStart,
618
+ //dragEnd,
619
+ //setDragEnd,
620
+ setFaceLandmarks,
621
+ setBlendShapes,
622
+
623
+ handleMouseDown,
624
+ handleMouseUp,
625
+ handleMouseMove,
626
+ handleMouseEnter,
627
+ handleMouseLeave,
628
+
629
+ currentLandmark,
630
+ currentOpacity,
631
+ }
632
+ }
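For context, a consuming component would typically attach the hook's callback ref and mouse handlers to the preview canvas. A minimal sketch, with a component name of our own choosing (the real wiring lives in app.tsx, not shown here):

import React from 'react';
import { useFaceLandmarkDetection } from '@/hooks/useFaceLandmarkDetection';

// Illustrative only: spreads the hook's ref callback and mouse handlers onto a canvas.
function PreviewCanvas() {
  const {
    canvasRefCallback,
    handleMouseDown,
    handleMouseUp,
    handleMouseMove,
    handleMouseEnter,
    handleMouseLeave,
  } = useFaceLandmarkDetection();

  return (
    <canvas
      ref={canvasRefCallback}
      onMouseDown={handleMouseDown}
      onMouseUp={handleMouseUp}
      onMouseMove={handleMouseMove}
      onMouseEnter={handleMouseEnter}
      onMouseLeave={handleMouseLeave}
    />
  );
}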
client/src/hooks/useFacePokeAPI.ts ADDED
@@ -0,0 +1,44 @@
1
+ import { useEffect, useState } from "react";
2
+
3
+ import { facePoke } from "../lib/facePoke";
4
+ import { useMainStore } from "./useMainStore";
5
+
6
+ export function useFacePokeAPI() {
7
+
8
+ // State for FacePoke
9
+ const [status, setStatus] = useState('');
10
+ const [isDebugMode, setIsDebugMode] = useState(false);
11
+ const [interruptMessage, setInterruptMessage] = useState<string | null>(null);
12
+
13
+ const [isLoading, setIsLoading] = useState(false);
14
+
15
+ // Initialize FacePoke
16
+ useEffect(() => {
17
+ const urlParams = new URLSearchParams(window.location.search);
18
+ setIsDebugMode(urlParams.get('debug') === 'true');
19
+ }, []);
20
+
21
+ // Handle WebSocket interruptions
22
+ useEffect(() => {
23
+ const handleInterruption = (event: CustomEvent) => {
24
+ setInterruptMessage(event.detail.message);
25
+ };
26
+
27
+ window.addEventListener('websocketInterruption' as any, handleInterruption);
28
+
29
+ return () => {
30
+ window.removeEventListener('websocketInterruption' as any, handleInterruption);
31
+ };
32
+ }, []);
33
+
34
+ return {
35
+ facePoke,
36
+ status,
37
+ setStatus,
38
+ isDebugMode,
39
+ setIsDebugMode,
40
+ interruptMessage,
41
+ isLoading,
42
+ setIsLoading,
43
+ }
44
+ }
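The hook above reacts to a custom websocketInterruption DOM event. Any code that detects a server-side interruption can surface a message to the UI by dispatching that event; the payload shape below mirrors the handler's event.detail.message access, while the dispatch site itself is an assumption, not shown in this commit:

// Illustrative: notify useFacePokeAPI consumers that the session was interrupted.
window.dispatchEvent(
  new CustomEvent('websocketInterruption', {
    detail: { message: 'The server interrupted the session, please retry.' },
  })
);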
client/src/hooks/useMainStore.ts ADDED
@@ -0,0 +1,58 @@
1
+ import { create } from 'zustand'
2
+ import type { ClosestLandmark } from './useFaceLandmarkDetection'
3
+ import type { ImageModificationParams } from '@/lib/facePoke'
4
+
5
+ interface ImageState {
6
+ error: string
7
+ imageFile: File | null
8
+ originalImage: string
9
+ previewImage: string
10
+ originalImageHash: string
11
+ minLatency: number
12
+ averageLatency: number
13
+ maxLatency: number
14
+ activeLandmark?: ClosestLandmark
15
+ params: Partial<ImageModificationParams>
16
+ setError: (error?: string) => void
17
+ setImageFile: (file: File | null) => void
18
+ setOriginalImage: (url: string) => void
19
+ setOriginalImageHash: (hash: string) => void
20
+ setPreviewImage: (url: string) => void
21
+ resetImage: () => void
22
+ setAverageLatency: (averageLatency: number) => void
23
+ setActiveLandmark: (activeLandmark?: ClosestLandmark) => void
24
+ setParams: (params: Partial<ImageModificationParams>) => void
25
+ }
26
+
27
+ export const useMainStore = create<ImageState>((set, get) => ({
28
+ error: '',
29
+ imageFile: null,
30
+ originalImage: '',
31
+ originalImageHash: '',
32
+ previewImage: '',
33
+ minLatency: 20, // min time between requests
34
+ averageLatency: 190, // this should be the average for most people
35
+ maxLatency: 4000, // max time between requests
36
+ activeLandmark: undefined,
37
+ params: {},
38
+ setError: (error: string = '') => set({ error }),
39
+ setImageFile: (file) => set({ imageFile: file }),
40
+ setOriginalImage: (url) => set({ originalImage: url }),
41
+ setOriginalImageHash: (originalImageHash) => set({ originalImageHash }),
42
+ setPreviewImage: (url) => set({ previewImage: url }),
43
+ resetImage: () => {
44
+ const { originalImage } = get()
45
+ if (originalImage) {
46
+ set({ previewImage: originalImage })
47
+ }
48
+ },
49
+ setAverageLatency: (averageLatency: number) => set({ averageLatency }),
50
+ setActiveLandmark: (activeLandmark?: ClosestLandmark) => set({ activeLandmark }),
51
+ setParams: (params: Partial<ImageModificationParams>) => {
52
+ const {params: previousParams } = get()
53
+ set({ params: {
54
+ ...previousParams,
55
+ ...params
56
+ }})
57
+ },
58
+ }))
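Components subscribe to individual slices of this store with selectors, while non-React code (such as facePoke.ts later in this commit) reads and writes it through getState(). A small sketch using only the fields defined above; the numeric value is made up:

import { useMainStore } from '@/hooks/useMainStore';

// Inside a React component, subscribe to a single slice to limit re-renders:
//   const previewImage = useMainStore(s => s.previewImage);
//
// Outside React (e.g. in a WebSocket callback), access the store imperatively:
const { minLatency, maxLatency, setAverageLatency } = useMainStore.getState();
setAverageLatency(Math.min(maxLatency, Math.max(minLatency, 240)));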
client/src/index.tsx ADDED
@@ -0,0 +1,6 @@
1
+ import { createRoot } from 'react-dom/client';
2
+
3
+ import { App } from './app';
4
+
5
+ const root = createRoot(document.getElementById('root')!);
6
+ root.render(<App />);
client/src/layout.tsx ADDED
@@ -0,0 +1,14 @@
1
+ import React, { type ReactNode } from 'react';
2
+
3
+ export function Layout({ children }: { children: ReactNode }) {
4
+ return (
5
+ <div className="fixed min-h-screen w-full flex items-center justify-center bg-gradient-to-br from-gray-300 to-stone-300"
6
+ style={{ boxShadow: "inset 0 0 10vh 0 rgb(0 0 0 / 30%)" }}>
7
+ <div className="min-h-screen w-full py-8 flex flex-col justify-center">
8
+ <div className="relative p-4 sm:max-w-5xl sm:mx-auto">
9
+ {children}
10
+ </div>
11
+ </div>
12
+ </div>
13
+ );
14
+ }
client/src/lib/circularBuffer.ts ADDED
@@ -0,0 +1,31 @@
1
+
2
+
3
+ /**
4
+ * Circular buffer for storing and managing response times.
5
+ */
6
+ export class CircularBuffer<T> {
7
+ private buffer: T[];
8
+ private pointer: number;
9
+
10
+ constructor(private capacity: number) {
11
+ this.buffer = new Array<T>(capacity);
12
+ this.pointer = 0;
13
+ }
14
+
15
+ /**
16
+ * Adds an item to the buffer, overwriting the oldest item if full.
17
+ * @param item - The item to add to the buffer.
18
+ */
19
+ push(item: T): void {
20
+ this.buffer[this.pointer] = item;
21
+ this.pointer = (this.pointer + 1) % this.capacity;
22
+ }
23
+
24
+ /**
25
+ * Retrieves all items currently in the buffer.
26
+ * @returns An array of all items in the buffer.
27
+ */
28
+ getAll(): T[] {
29
+ return this.buffer.filter(item => item !== undefined);
30
+ }
31
+ }
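In this client the buffer holds the most recent WebSocket round-trip times, and the average of getAll() drives the adaptive request throttle in facePoke.ts. A short usage sketch (the sample values are made up):

import { CircularBuffer } from '@/lib/circularBuffer';

// Keep only the five most recent response times; older samples are overwritten.
const responseTimes = new CircularBuffer<number>(5);
[180, 210, 195, 250, 230, 205].forEach(ms => responseTimes.push(ms));

const samples = responseTimes.getAll();                      // at most 5 entries
const average = samples.reduce((acc, ms) => acc + ms, 0) / samples.length;
console.log(`average latency: ${average.toFixed(0)}ms`);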
client/src/lib/convertImageToBase64.ts ADDED
@@ -0,0 +1,19 @@
1
+ export async function convertImageToBase64(imageFile: File): Promise<string> {
2
+ return new Promise((resolve, reject) => {
3
+ const reader = new FileReader();
4
+
5
+ reader.onload = () => {
6
+ if (typeof reader.result === 'string') {
7
+ resolve(reader.result);
8
+ } else {
9
+ reject(new Error('Failed to convert image to base64'));
10
+ }
11
+ };
12
+
13
+ reader.onerror = () => {
14
+ reject(new Error('Error reading file'));
15
+ };
16
+
17
+ reader.readAsDataURL(imageFile);
18
+ });
19
+ }
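A typical call site is the image-upload handler: convert the selected File to a data URI, then seed both the original and the preview image in the store. A minimal sketch (the handler is illustrative; the real one lives in app.tsx, not shown here):

import { convertImageToBase64 } from '@/lib/convertImageToBase64';
import { useMainStore } from '@/hooks/useMainStore';

// Illustrative upload handler, not the actual app.tsx code.
export async function handleFileSelected(file: File): Promise<void> {
  try {
    const dataUri = await convertImageToBase64(file);
    const { setImageFile, setOriginalImage, setPreviewImage } = useMainStore.getState();
    setImageFile(file);
    setOriginalImage(dataUri);
    setPreviewImage(dataUri);
  } catch (err) {
    useMainStore.getState().setError('Could not read the selected image');
  }
}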
client/src/lib/facePoke.ts ADDED
@@ -0,0 +1,398 @@
1
+ import { v4 as uuidv4 } from 'uuid';
2
+ import { CircularBuffer } from './circularBuffer';
3
+ import { useMainStore } from '@/hooks/useMainStore';
4
+
5
+ /**
6
+ * Represents a tracked request with its UUID and timestamp.
7
+ */
8
+ export interface TrackedRequest {
9
+ uuid: string;
10
+ timestamp: number;
11
+ }
12
+
13
+ /**
14
+ * Represents the parameters for image modification.
15
+ */
16
+ export interface ImageModificationParams {
17
+ eyes: number;
18
+ eyebrow: number;
19
+ wink: number;
20
+ pupil_x: number;
21
+ pupil_y: number;
22
+ aaa: number;
23
+ eee: number;
24
+ woo: number;
25
+ smile: number;
26
+ rotate_pitch: number;
27
+ rotate_yaw: number;
28
+ rotate_roll: number;
29
+ }
30
+
31
+ /**
32
+ * Represents a message to modify an image.
33
+ */
34
+ export interface ModifyImageMessage {
35
+ type: 'modify_image';
36
+ image?: string;
37
+ image_hash?: string;
38
+ params: Partial<ImageModificationParams>;
39
+ }
40
+
41
+
42
+ /**
43
+ * Callback type for handling modified images.
44
+ */
45
+ type OnModifiedImage = (image: string, image_hash: string) => void;
46
+
47
+ /**
48
+ * Enum representing the different states of a WebSocket connection.
49
+ */
50
+ enum WebSocketState {
51
+ CONNECTING = 0,
52
+ OPEN = 1,
53
+ CLOSING = 2,
54
+ CLOSED = 3
55
+ }
56
+
57
+ /**
58
+ * FacePoke class manages the WebSocket connection
59
+ */
60
+ export class FacePoke {
61
+ private ws: WebSocket | null = null;
62
+ private readonly connectionId: string = uuidv4();
63
+ private isUnloading: boolean = false;
64
+ private onModifiedImage: OnModifiedImage = () => {};
65
+ private reconnectAttempts: number = 0;
66
+ private readonly maxReconnectAttempts: number = 5;
67
+ private readonly reconnectDelay: number = 5000;
68
+ private readonly eventListeners: Map<string, Set<Function>> = new Map();
69
+
70
+ private requestTracker: Map<string, TrackedRequest> = new Map();
71
+ private responseTimeBuffer: CircularBuffer<number>;
72
+ private readonly MAX_TRACKED_TIMES = 5; // Number of recent response times to track
73
+
74
+ /**
75
+ * Creates an instance of FacePoke.
76
+ * Initializes the WebSocket connection.
77
+ */
78
+ constructor() {
79
+ console.log(`[FacePoke] Initializing FacePoke instance with connection ID: ${this.connectionId}`);
80
+ this.initializeWebSocket();
81
+ this.setupUnloadHandler();
82
+
83
+ this.responseTimeBuffer = new CircularBuffer<number>(this.MAX_TRACKED_TIMES);
84
+ console.log(`[FacePoke] Initialized response time tracker with capacity: ${this.MAX_TRACKED_TIMES}`);
85
+ }
86
+
87
+
88
+ /**
89
+ * Generates a unique UUID for a request and starts tracking it.
90
+ * @returns The generated UUID for the request.
91
+ */
92
+ private trackRequest(): string {
93
+ const uuid = uuidv4();
94
+ this.requestTracker.set(uuid, { uuid, timestamp: Date.now() });
95
+ // console.log(`[FacePoke] Started tracking request with UUID: ${uuid}`);
96
+ return uuid;
97
+ }
98
+
99
+ /**
100
+ * Completes tracking for a request and updates response time statistics.
101
+ * @param uuid - The UUID of the completed request.
102
+ */
103
+ private completeRequest(uuid: string): void {
104
+ const request = this.requestTracker.get(uuid);
105
+ if (request) {
106
+ const responseTime = Date.now() - request.timestamp;
107
+ this.responseTimeBuffer.push(responseTime);
108
+ this.requestTracker.delete(uuid);
109
+ this.updateThrottleTime();
110
+ console.log(`[FacePoke] Completed request ${uuid}. Response time: ${responseTime}ms`);
111
+ } else {
112
+ console.warn(`[FacePoke] Attempted to complete unknown request: ${uuid}`);
113
+ }
114
+ }
115
+
116
+ /**
117
+ * Calculates the average response time from recent requests.
118
+ * @returns The average response time in milliseconds.
119
+ */
120
+ private calculateAverageResponseTime(): number {
121
+ const times = this.responseTimeBuffer.getAll();
122
+
123
+ const averageLatency = useMainStore.getState().averageLatency;
124
+
125
+ if (times.length === 0) return averageLatency;
126
+ const sum = times.reduce((acc, time) => acc + time, 0);
127
+ return sum / times.length;
128
+ }
129
+
130
+ /**
131
+ * Updates the throttle time based on recent response times.
132
+ */
133
+ private updateThrottleTime(): void {
134
+ const { minLatency, maxLatency, averageLatency, setAverageLatency } = useMainStore.getState();
135
+ const avgResponseTime = this.calculateAverageResponseTime();
136
+ const newLatency = Math.min(minLatency, Math.max(minLatency, avgResponseTime));
137
+
138
+ if (newLatency !== averageLatency) {
139
+ setAverageLatency(newLatency)
140
+ console.log(`[FacePoke] Updated throttle time (latency is ${newLatency}ms)`);
141
+ }
142
+ }
143
+
144
+ /**
145
+ * Sets the callback function for handling modified images.
146
+ * @param handler - The function to be called when a modified image is received.
147
+ */
148
+ public setOnModifiedImage(handler: OnModifiedImage): void {
149
+ this.onModifiedImage = handler;
150
+ console.log(`[FacePoke] onModifiedImage handler set`);
151
+ }
152
+
153
+ /**
154
+ * Starts or restarts the WebSocket connection.
155
+ */
156
+ public async startWebSocket(): Promise<void> {
157
+ console.log(`[FacePoke] Starting WebSocket connection.`);
158
+ if (!this.ws || this.ws.readyState !== WebSocketState.OPEN) {
159
+ await this.initializeWebSocket();
160
+ }
161
+ }
162
+
163
+ /**
164
+ * Initializes the WebSocket connection.
165
+ * Implements exponential backoff for reconnection attempts.
166
+ */
167
+ private async initializeWebSocket(): Promise<void> {
168
+ console.log(`[FacePoke][${this.connectionId}] Initializing WebSocket connection`);
169
+
170
+ const connect = () => {
171
+ this.ws = new WebSocket(`wss://${window.location.host}/ws`);
172
+
173
+ this.ws.onopen = this.handleWebSocketOpen.bind(this);
174
+ this.ws.onmessage = this.handleWebSocketMessage.bind(this);
175
+ this.ws.onclose = this.handleWebSocketClose.bind(this);
176
+ this.ws.onerror = this.handleWebSocketError.bind(this);
177
+ };
178
+
179
+ // const debouncedConnect = debounce(connect, this.reconnectDelay, { leading: true, trailing: false });
180
+
181
+ connect(); // Initial connection attempt
182
+ }
183
+
184
+ /**
185
+ * Handles the WebSocket open event.
186
+ */
187
+ private handleWebSocketOpen(): void {
188
+ console.log(`[FacePoke][${this.connectionId}] WebSocket connection opened`);
189
+ this.reconnectAttempts = 0; // Reset reconnect attempts on successful connection
190
+ this.emitEvent('websocketOpen');
191
+ }
192
+
193
+ // Update handleWebSocketMessage to complete request tracking
194
+ private handleWebSocketMessage(event: MessageEvent): void {
195
+ try {
196
+ const data = JSON.parse(event.data);
197
+ // console.log(`[FacePoke][${this.connectionId}] Received JSON data:`, data);
198
+
199
+ if (data.uuid) {
200
+ this.completeRequest(data.uuid);
201
+ }
202
+
203
+ if (data.type === 'modified_image') {
204
+ if (data?.image) {
205
+ this.onModifiedImage(data.image, data.image_hash);
206
+ }
207
+ }
208
+
209
+ this.emitEvent('message', data);
210
+ } catch (error) {
211
+ console.error(`[FacePoke][${this.connectionId}] Error parsing WebSocket message:`, error);
212
+ }
213
+ }
214
+
215
+ /**
216
+ * Handles WebSocket close events.
217
+ * Implements reconnection logic with exponential backoff.
218
+ * @param event - The CloseEvent containing close information.
219
+ */
220
+ private handleWebSocketClose(event: CloseEvent): void {
221
+ if (event.wasClean) {
222
+ console.log(`[FacePoke][${this.connectionId}] WebSocket connection closed cleanly, code=${event.code}, reason=${event.reason}`);
223
+ } else {
224
+ console.warn(`[FacePoke][${this.connectionId}] WebSocket connection abruptly closed`);
225
+ }
226
+
227
+ this.emitEvent('websocketClose', event);
228
+
229
+ // Attempt to reconnect after a delay, unless the page is unloading or max attempts reached
230
+ if (!this.isUnloading && this.reconnectAttempts < this.maxReconnectAttempts) {
231
+ this.reconnectAttempts++;
232
+ const delay = Math.min(1000 * (2 ** this.reconnectAttempts), 30000); // Exponential backoff, max 30 seconds
233
+ console.log(`[FacePoke][${this.connectionId}] Attempting to reconnect in ${delay}ms (Attempt ${this.reconnectAttempts}/${this.maxReconnectAttempts})...`);
234
+ setTimeout(() => this.initializeWebSocket(), delay);
235
+ } else if (this.reconnectAttempts >= this.maxReconnectAttempts) {
236
+ console.error(`[FacePoke][${this.connectionId}] Max reconnect attempts reached. Please refresh the page.`);
237
+ this.emitEvent('maxReconnectAttemptsReached');
238
+ }
239
+ }
240
+
241
+ /**
242
+ * Handles WebSocket errors.
243
+ * @param error - The error event.
244
+ */
245
+ private handleWebSocketError(error: Event): void {
246
+ console.error(`[FacePoke][${this.connectionId}] WebSocket error:`, error);
247
+ this.emitEvent('websocketError', error);
248
+ }
249
+
250
+ /**
251
+ * Handles interruption messages from the server.
252
+ * @param message - The interruption message.
253
+ */
254
+ private handleInterruption(message: string): void {
255
+ console.warn(`[FacePoke] Interruption: ${message}`);
256
+ this.emitEvent('interruption', message);
257
+ }
258
+
259
+ /**
260
+ * Toggles the microphone on or off.
261
+ * @param isOn - Whether to turn the microphone on (true) or off (false).
262
+ */
263
+ public async toggleMicrophone(isOn: boolean): Promise<void> {
264
+ console.log(`[FacePoke] Attempting to ${isOn ? 'start' : 'stop'} microphone`);
265
+ try {
266
+ if (isOn) {
267
+ await this.startMicrophone();
268
+ } else {
269
+ this.stopMicrophone();
270
+ }
271
+ this.emitEvent('microphoneToggled', isOn);
272
+ } catch (error) {
273
+ console.error(`[FacePoke] Error toggling microphone:`, error);
274
+ this.emitEvent('microphoneError', error);
275
+ throw error;
276
+ }
277
+ }
278
+
279
+
280
+ /**
281
+ * Cleans up resources and closes connections.
282
+ */
283
+ public cleanup(): void {
284
+ console.log('[FacePoke] Starting cleanup process');
285
+ if (this.ws) {
286
+ this.ws.close();
287
+ this.ws = null;
288
+ }
289
+ this.eventListeners.clear();
290
+ console.log('[FacePoke] Cleanup completed');
291
+ this.emitEvent('cleanup');
292
+ }
293
+
294
+ /**
295
+ * Modifies an image based on the provided parameters
296
+ * @param image - The data-uri base64 image to modify.
297
+ * @param imageHash - The hash of the image to modify.
298
+ * @param params - The parameters for image modification.
299
+ */
300
+ public modifyImage(image: string | null, imageHash: string | null, params: Partial<ImageModificationParams>): void {
301
+ try {
302
+ const message: ModifyImageMessage = {
303
+ type: 'modify_image',
304
+ params: params
305
+ };
306
+
307
+ if (image) {
308
+ message.image = image;
309
+ } else if (imageHash) {
310
+ message.image_hash = imageHash;
311
+ } else {
312
+ throw new Error('Either image or imageHash must be provided');
313
+ }
314
+
315
+ this.sendJsonMessage(message);
316
+ // console.log(`[FacePoke] Sent modify image request with UUID: ${uuid}`);
317
+ } catch (err) {
318
+ console.error(`[FacePoke] Failed to modify the image:`, err);
319
+ }
320
+ }
321
+
322
+ /**
323
+ * Sends a JSON message through the WebSocket connection with request tracking.
324
+ * @param message - The message to send.
325
+ * @throws Error if the WebSocket is not open.
326
+ */
327
+ private sendJsonMessage<T>(message: T): void {
328
+ if (!this.ws || this.ws.readyState !== WebSocketState.OPEN) {
329
+ const error = new Error('WebSocket connection is not open');
330
+ console.error('[FacePoke] Error sending JSON message:', error);
331
+ this.emitEvent('sendJsonMessageError', error);
332
+ throw error;
333
+ }
334
+
335
+ const uuid = this.trackRequest();
336
+ const messageWithUuid = { ...message, uuid };
337
+ // console.log(`[FacePoke] Sending JSON message with UUID ${uuid}:`, messageWithUuid);
338
+ this.ws.send(JSON.stringify(messageWithUuid));
339
+ }
340
+
341
+ /**
342
+ * Sets up the unload handler to clean up resources when the page is unloading.
343
+ */
344
+ private setupUnloadHandler(): void {
345
+ window.addEventListener('beforeunload', () => {
346
+ console.log('[FacePoke] Page is unloading, cleaning up resources');
347
+ this.isUnloading = true;
348
+ if (this.ws) {
349
+ this.ws.close(1000, 'Page is unloading');
350
+ }
351
+ this.cleanup();
352
+ });
353
+ }
354
+
355
+ /**
356
+ * Adds an event listener for a specific event type.
357
+ * @param eventType - The type of event to listen for.
358
+ * @param listener - The function to be called when the event is emitted.
359
+ */
360
+ public addEventListener(eventType: string, listener: Function): void {
361
+ if (!this.eventListeners.has(eventType)) {
362
+ this.eventListeners.set(eventType, new Set());
363
+ }
364
+ this.eventListeners.get(eventType)!.add(listener);
365
+ console.log(`[FacePoke] Added event listener for '${eventType}'`);
366
+ }
367
+
368
+ /**
369
+ * Removes an event listener for a specific event type.
370
+ * @param eventType - The type of event to remove the listener from.
371
+ * @param listener - The function to be removed from the listeners.
372
+ */
373
+ public removeEventListener(eventType: string, listener: Function): void {
374
+ const listeners = this.eventListeners.get(eventType);
375
+ if (listeners) {
376
+ listeners.delete(listener);
377
+ console.log(`[FacePoke] Removed event listener for '${eventType}'`);
378
+ }
379
+ }
380
+
381
+ /**
382
+ * Emits an event to all registered listeners for that event type.
383
+ * @param eventType - The type of event to emit.
384
+ * @param data - Optional data to pass to the event listeners.
385
+ */
386
+ private emitEvent(eventType: string, data?: any): void {
387
+ const listeners = this.eventListeners.get(eventType);
388
+ if (listeners) {
389
+ console.log(`[FacePoke] Emitting event '${eventType}' with data:`, data);
390
+ listeners.forEach(listener => listener(data));
391
+ }
392
+ }
393
+ }
394
+
395
+ /**
396
+ * Singleton instance of the FacePoke class.
397
+ */
398
+ export const facePoke = new FacePoke();
client/src/lib/throttle.ts ADDED
@@ -0,0 +1,32 @@
1
+
2
+ /**
3
+ * Custom throttle function that allows the first call to go through immediately
4
+ * and then limits subsequent calls.
5
+ * @param func - The function to throttle.
6
+ * @param limit - The minimum time between function calls in milliseconds.
7
+ * @returns A throttled version of the function.
8
+ */
9
+ export function throttle<T extends (...args: any[]) => any>(func: T, limit: number): T {
10
+ let lastCall = 0;
11
+ let timeoutId: NodeJS.Timer | null = null;
12
+
13
+ return function (this: any, ...args: Parameters<T>) {
14
+ const context = this;
15
+ const now = Date.now();
16
+
17
+ if (now - lastCall >= limit) {
18
+ if (timeoutId !== null) {
19
+ clearTimeout(timeoutId);
20
+ timeoutId = null;
21
+ }
22
+ lastCall = now;
23
+ return func.apply(context, args);
24
+ } else if (!timeoutId) {
25
+ timeoutId = setTimeout(() => {
26
+ lastCall = Date.now();
27
+ timeoutId = null;
28
+ func.apply(context, args);
29
+ }, limit - (now - lastCall));
30
+ }
31
+ } as T;
32
+ }
client/src/lib/utils.ts ADDED
@@ -0,0 +1,15 @@
1
+ import { clsx, type ClassValue } from "clsx"
2
+ import { twMerge } from "tailwind-merge"
3
+
4
+ export function cn(...inputs: ClassValue[]) {
5
+ return twMerge(clsx(inputs))
6
+ }
7
+
8
+ export function truncateFileName(fileName: string, maxLength: number = 16) {
9
+ if (fileName.length <= maxLength) return fileName;
10
+
11
+ const start = fileName.slice(0, maxLength / 2 - 1);
12
+ const end = fileName.slice(-maxLength / 2 + 2);
13
+
14
+ return `${start}...${end}`;
15
+ };
client/src/styles/globals.css ADDED
@@ -0,0 +1,81 @@
1
+ @tailwind base;
2
+ @tailwind components;
3
+ @tailwind utilities;
4
+
5
+ @layer base {
6
+ :root {
7
+ --background: 0 0% 100%;
8
+ --foreground: 222.2 47.4% 11.2%;
9
+
10
+ --muted: 210 40% 96.1%;
11
+ --muted-foreground: 215.4 16.3% 46.9%;
12
+
13
+ --popover: 0 0% 100%;
14
+ --popover-foreground: 222.2 47.4% 11.2%;
15
+
16
+ --border: 214.3 31.8% 91.4%;
17
+ --input: 214.3 31.8% 91.4%;
18
+
19
+ --card: 0 0% 100%;
20
+ --card-foreground: 222.2 47.4% 11.2%;
21
+
22
+ --primary: 222.2 47.4% 11.2%;
23
+ --primary-foreground: 210 40% 98%;
24
+
25
+ --secondary: 210 40% 96.1%;
26
+ --secondary-foreground: 222.2 47.4% 11.2%;
27
+
28
+ --accent: 210 40% 96.1%;
29
+ --accent-foreground: 222.2 47.4% 11.2%;
30
+
31
+ --destructive: 0 100% 50%;
32
+ --destructive-foreground: 210 40% 98%;
33
+
34
+ --ring: 215 20.2% 65.1%;
35
+
36
+ --radius: 0.5rem;
37
+ }
38
+
39
+ .dark {
40
+ --background: 224 71% 4%;
41
+ --foreground: 213 31% 91%;
42
+
43
+ --muted: 223 47% 11%;
44
+ --muted-foreground: 215.4 16.3% 56.9%;
45
+
46
+ --accent: 216 34% 17%;
47
+ --accent-foreground: 210 40% 98%;
48
+
49
+ --popover: 224 71% 4%;
50
+ --popover-foreground: 215 20.2% 65.1%;
51
+
52
+ --border: 216 34% 17%;
53
+ --input: 216 34% 17%;
54
+
55
+ --card: 224 71% 4%;
56
+ --card-foreground: 213 31% 91%;
57
+
58
+ --primary: 210 40% 98%;
59
+ --primary-foreground: 222.2 47.4% 1.2%;
60
+
61
+ --secondary: 222.2 47.4% 11.2%;
62
+ --secondary-foreground: 210 40% 98%;
63
+
64
+ --destructive: 0 63% 31%;
65
+ --destructive-foreground: 210 40% 98%;
66
+
67
+ --ring: 216 34% 17%;
68
+
69
+ --radius: 0.5rem;
70
+ }
71
+ }
72
+
73
+ @layer base {
74
+ * {
75
+ @apply border-border;
76
+ }
77
+ body {
78
+ @apply bg-background text-foreground;
79
+ font-feature-settings: "rlig" 1, "calt" 1;
80
+ }
81
+ }
client/tailwind.config.js ADDED
@@ -0,0 +1,86 @@
1
+ const { fontFamily } = require("tailwindcss/defaultTheme")
2
+
3
+ /** @type {import('tailwindcss').Config} */
4
+ module.exports = {
5
+ darkMode: ["class"],
6
+ content: [
7
+ "app/**/*.{ts,tsx}",
8
+ "components/**/*.{ts,tsx}",
9
+ '../public/index.html'
10
+ ],
11
+ theme: {
12
+ container: {
13
+ center: true,
14
+ padding: "2rem",
15
+ screens: {
16
+ "2xl": "1400px",
17
+ },
18
+ },
19
+ extend: {
20
+ colors: {
21
+ border: "hsl(var(--border))",
22
+ input: "hsl(var(--input))",
23
+ ring: "hsl(var(--ring))",
24
+ background: "hsl(var(--background))",
25
+ foreground: "hsl(var(--foreground))",
26
+ primary: {
27
+ DEFAULT: "hsl(var(--primary))",
28
+ foreground: "hsl(var(--primary-foreground))",
29
+ },
30
+ secondary: {
31
+ DEFAULT: "hsl(var(--secondary))",
32
+ foreground: "hsl(var(--secondary-foreground))",
33
+ },
34
+ destructive: {
35
+ DEFAULT: "hsl(var(--destructive))",
36
+ foreground: "hsl(var(--destructive-foreground))",
37
+ },
38
+ muted: {
39
+ DEFAULT: "hsl(var(--muted))",
40
+ foreground: "hsl(var(--muted-foreground))",
41
+ },
42
+ accent: {
43
+ DEFAULT: "hsl(var(--accent))",
44
+ foreground: "hsl(var(--accent-foreground))",
45
+ },
46
+ popover: {
47
+ DEFAULT: "hsl(var(--popover))",
48
+ foreground: "hsl(var(--popover-foreground))",
49
+ },
50
+ card: {
51
+ DEFAULT: "hsl(var(--card))",
52
+ foreground: "hsl(var(--card-foreground))",
53
+ },
54
+ },
55
+ borderRadius: {
56
+ lg: `var(--radius)`,
57
+ md: `calc(var(--radius) - 2px)`,
58
+ sm: "calc(var(--radius) - 4px)",
59
+ },
60
+ fontFamily: {
61
+ sans: ["var(--font-sans)", ...fontFamily.sans],
62
+ },
63
+ fontSize: {
64
+ "5xs": "8px",
65
+ "4xs": "9px",
66
+ "3xs": "10px",
67
+ "2xs": "11px"
68
+ },
69
+ keyframes: {
70
+ "accordion-down": {
71
+ from: { height: "0" },
72
+ to: { height: "var(--radix-accordion-content-height)" },
73
+ },
74
+ "accordion-up": {
75
+ from: { height: "var(--radix-accordion-content-height)" },
76
+ to: { height: "0" },
77
+ },
78
+ },
79
+ animation: {
80
+ "accordion-down": "accordion-down 0.2s ease-out",
81
+ "accordion-up": "accordion-up 0.2s ease-out",
82
+ },
83
+ },
84
+ },
85
+ plugins: [require("tailwindcss-animate")],
86
+ }
client/tsconfig.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "compilerOptions": {
3
+ // Enable latest features
4
+ "lib": ["ESNext", "DOM", "DOM.Iterable"],
5
+ "target": "ESNext",
6
+ "module": "ESNext",
7
+ "moduleDetection": "force",
8
+ "jsx": "react-jsx",
9
+ "allowJs": true,
10
+
11
+ // Bundler mode
12
+ "moduleResolution": "bundler",
13
+ "allowImportingTsExtensions": true,
14
+ "verbatimModuleSyntax": true,
15
+ "noEmit": true,
16
+
17
+ "baseUrl": ".",
18
+ "paths": {
19
+ "@/*": ["./src/*"]
20
+ },
21
+
22
+ // Best practices
23
+ "strict": true,
24
+ "skipLibCheck": true,
25
+ "noFallthroughCasesInSwitch": true,
26
+
27
+ // Some stricter flags (disabled by default)
28
+ "noUnusedLocals": false,
29
+ "noUnusedParameters": false,
30
+ "noPropertyAccessFromIndexSignature": false
31
+ }
32
+ }
engine.py ADDED
@@ -0,0 +1,300 @@
1
+ import time
2
+ import logging
3
+ import hashlib
4
+ import uuid
5
+ import os
6
+ import io
7
+ import shutil
8
+ import asyncio
9
+ import base64
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ from queue import Queue
12
+ from typing import Dict, Any, List, Optional, AsyncGenerator, Tuple, Union
13
+ from functools import lru_cache
14
+ import av
15
+ import numpy as np
16
+ import cv2
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from PIL import Image
20
+
21
+ from liveportrait.config.argument_config import ArgumentConfig
22
+ from liveportrait.utils.camera import get_rotation_matrix
23
+ from liveportrait.utils.io import load_image_rgb, load_driving_info, resize_to_limit
24
+ from liveportrait.utils.crop import prepare_paste_back, paste_back
25
+
26
+ # Configure logging
27
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
28
+ logger = logging.getLogger(__name__)
29
+
30
+ # Global constants
31
+ DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
32
+ MODELS_DIR = os.path.join(DATA_ROOT, "models")
33
+
34
+ def base64_data_uri_to_PIL_Image(base64_string: str) -> Image.Image:
35
+ """
36
+ Convert a base64 data URI to a PIL Image.
37
+
38
+ Args:
39
+ base64_string (str): The base64 encoded image data.
40
+
41
+ Returns:
42
+ Image.Image: The decoded PIL Image.
43
+ """
44
+ if ',' in base64_string:
45
+ base64_string = base64_string.split(',')[1]
46
+ img_data = base64.b64decode(base64_string)
47
+ return Image.open(io.BytesIO(img_data))
48
+
49
+ class Engine:
50
+ """
51
+ The main engine class for FacePoke
52
+ """
53
+
54
+ def __init__(self, live_portrait):
55
+ """
56
+ Initialize the FacePoke engine with necessary models and processors.
57
+
58
+ Args:
59
+ live_portrait (LivePortraitPipeline): The LivePortrait model for video generation.
60
+ """
61
+ self.live_portrait = live_portrait
62
+
63
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64
+
65
+ # cache for the "modify image" workflow
66
+ self.image_cache = {} # Stores the original images
67
+ self.processed_cache = {} # Stores the processed image data
68
+
69
+ logger.info("✅ FacePoke Engine initialized successfully.")
70
+
71
+ def get_image_hash(self, image: Union[Image.Image, str, bytes]) -> str:
72
+ """
73
+ Compute or retrieve the hash for an image.
74
+
75
+ Args:
76
+ image (Union[Image.Image, str, bytes]): The input image, either as a PIL Image,
77
+ base64 string, or bytes.
78
+
79
+ Returns:
80
+ str: The computed hash of the image.
81
+ """
82
+ if isinstance(image, str):
83
+ # Assume it's already a hash if it's a string of the right length
84
+ if len(image) == 32:
85
+ return image
86
+ # Otherwise, assume it's a base64 string
87
+ image = base64_data_uri_to_PIL_Image(image)
88
+
89
+ if isinstance(image, Image.Image):
90
+ return hashlib.md5(image.tobytes()).hexdigest()
91
+ elif isinstance(image, bytes):
92
+ return hashlib.md5(image).hexdigest()
93
+ else:
94
+ raise ValueError("Unsupported image type")
95
+
96
+ @lru_cache(maxsize=128)
97
+ def _process_image(self, image_hash: str) -> Dict[str, Any]:
98
+ """
99
+ Process the input image and cache the results.
100
+
101
+ Args:
102
+ image_hash (str): Hash of the input image.
103
+
104
+ Returns:
105
+ Dict[str, Any]: Processed image data.
106
+ """
107
+ logger.info(f"Processing image with hash: {image_hash}")
108
+ if image_hash not in self.image_cache:
109
+ raise ValueError(f"Image with hash {image_hash} not found in cache")
110
+
111
+ image = self.image_cache[image_hash]
112
+ img_rgb = np.array(image)
113
+
114
+ inference_cfg = self.live_portrait.live_portrait_wrapper.cfg
115
+ img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
116
+ crop_info = self.live_portrait.cropper.crop_single_image(img_rgb)
117
+ img_crop_256x256 = crop_info['img_crop_256x256']
118
+
119
+ I_s = self.live_portrait.live_portrait_wrapper.prepare_source(img_crop_256x256)
120
+ x_s_info = self.live_portrait.live_portrait_wrapper.get_kp_info(I_s)
121
+ f_s = self.live_portrait.live_portrait_wrapper.extract_feature_3d(I_s)
122
+ x_s = self.live_portrait.live_portrait_wrapper.transform_keypoint(x_s_info)
123
+
124
+ processed_data = {
125
+ 'img_rgb': img_rgb,
126
+ 'crop_info': crop_info,
127
+ 'x_s_info': x_s_info,
128
+ 'f_s': f_s,
129
+ 'x_s': x_s,
130
+ 'inference_cfg': inference_cfg
131
+ }
132
+
133
+ self.processed_cache[image_hash] = processed_data
134
+
135
+ return processed_data
136
+
137
+ async def modify_image(self, image_or_hash: Union[Image.Image, str, bytes], params: Dict[str, float]) -> str:
138
+ """
139
+ Modify the input image based on the provided parameters, using caching for efficiency
140
+ and outputting the result as a WebP image.
141
+
142
+ Args:
143
+ image_or_hash (Union[Image.Image, str, bytes]): Input image as a PIL Image, base64-encoded string,
144
+ image bytes, or a hash string.
145
+ params (Dict[str, float]): Parameters for face transformation.
146
+
147
+ Returns:
148
+ str: Modified image as a base64-encoded WebP data URI.
149
+
150
+ Raises:
151
+ ValueError: If there's an error modifying the image or WebP is not supported.
152
+ """
153
+ logger.info("Starting image modification")
154
+ logger.debug(f"Modification parameters: {params}")
155
+
156
+ try:
157
+ image_hash = self.get_image_hash(image_or_hash)
158
+
159
+ # If we don't have the image in cache yet, add it
160
+ if image_hash not in self.image_cache:
161
+ if isinstance(image_or_hash, (Image.Image, bytes)):
162
+ self.image_cache[image_hash] = image_or_hash if isinstance(image_or_hash, Image.Image) else Image.open(io.BytesIO(image_or_hash))  # decode raw bytes so _process_image always receives a PIL Image
163
+ elif isinstance(image_or_hash, str) and len(image_or_hash) != 32:
164
+ # It's a base64 string, not a hash
165
+ self.image_cache[image_hash] = base64_data_uri_to_PIL_Image(image_or_hash)
166
+ else:
167
+ raise ValueError("Image not found in cache and no valid image provided")
168
+
169
+ # Process the image (this will use the cache if available)
170
+ if image_hash not in self.processed_cache:
171
+ processed_data = await asyncio.to_thread(self._process_image, image_hash)
172
+ else:
173
+ processed_data = self.processed_cache[image_hash]
174
+
175
+ # Apply modifications based on params
176
+ x_d_new = processed_data['x_s_info']['kp'].clone()
177
+ await self._apply_facial_modifications(x_d_new, params)
178
+
179
+ # Apply rotation
180
+ R_new = get_rotation_matrix(
181
+ processed_data['x_s_info']['pitch'] + params.get('rotate_pitch', 0),
182
+ processed_data['x_s_info']['yaw'] + params.get('rotate_yaw', 0),
183
+ processed_data['x_s_info']['roll'] + params.get('rotate_roll', 0)
184
+ )
185
+ x_d_new = processed_data['x_s_info']['scale'] * (x_d_new @ R_new) + processed_data['x_s_info']['t']
186
+
187
+ # Apply stitching
188
+ x_d_new = await asyncio.to_thread(self.live_portrait.live_portrait_wrapper.stitching, processed_data['x_s'], x_d_new)
189
+
190
+ # Generate the output
191
+ out = await asyncio.to_thread(self.live_portrait.live_portrait_wrapper.warp_decode, processed_data['f_s'], processed_data['x_s'], x_d_new)
192
+ I_p = self.live_portrait.live_portrait_wrapper.parse_output(out['out'])[0]
193
+
194
+ # Paste back to full size
195
+ mask_ori = await asyncio.to_thread(
196
+ prepare_paste_back,
197
+ processed_data['inference_cfg'].mask_crop, processed_data['crop_info']['M_c2o'],
198
+ dsize=(processed_data['img_rgb'].shape[1], processed_data['img_rgb'].shape[0])
199
+ )
200
+ I_p_to_ori_blend = await asyncio.to_thread(
201
+ paste_back,
202
+ I_p, processed_data['crop_info']['M_c2o'], processed_data['img_rgb'], mask_ori
203
+ )
204
+
205
+ # Convert the result to a PIL Image
206
+ result_image = Image.fromarray(I_p_to_ori_blend)
207
+
208
+ # Save as WebP
209
+ buffered = io.BytesIO()
210
+ result_image.save(buffered, format="WebP", quality=85) # Adjust quality as needed
211
+ modified_image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
212
+
213
+ logger.info("Image modification completed successfully")
214
+ return f"data:image/webp;base64,{modified_image_base64}"
215
+
216
+ except Exception as e:
217
+ logger.error(f"Error in modify_image: {str(e)}")
218
+ logger.exception("Full traceback:")
219
+ raise ValueError(f"Failed to modify image: {str(e)}")
220
+
221
+ async def _apply_facial_modifications(self, x_d_new: torch.Tensor, params: Dict[str, float]) -> None:
222
+ """
223
+ Apply facial modifications to the keypoints based on the provided parameters.
224
+
225
+ Args:
226
+ x_d_new (torch.Tensor): Tensor of facial keypoints to be modified.
227
+ params (Dict[str, float]): Parameters for face transformation.
228
+ """
229
+ modifications = [
230
+ ('smile', [
231
+ (0, 20, 1, -0.01), (0, 14, 1, -0.02), (0, 17, 1, 0.0065), (0, 17, 2, 0.003),
232
+ (0, 13, 1, -0.00275), (0, 16, 1, -0.00275), (0, 3, 1, -0.0035), (0, 7, 1, -0.0035)
233
+ ]),
234
+ ('aaa', [
235
+ (0, 19, 1, 0.001), (0, 19, 2, 0.0001), (0, 17, 1, -0.0001)
236
+ ]),
237
+ ('eee', [
238
+ (0, 20, 2, -0.001), (0, 20, 1, -0.001), (0, 14, 1, -0.001)
239
+ ]),
240
+ ('woo', [
241
+ (0, 14, 1, 0.001), (0, 3, 1, -0.0005), (0, 7, 1, -0.0005), (0, 17, 2, -0.0005)
242
+ ]),
243
+ ('wink', [
244
+ (0, 11, 1, 0.001), (0, 13, 1, -0.0003), (0, 17, 0, 0.0003),
245
+ (0, 17, 1, 0.0003), (0, 3, 1, -0.0003)
246
+ ]),
247
+ ('pupil_x', [
248
+ (0, 11, 0, 0.0007 if params.get('pupil_x', 0) > 0 else 0.001),
249
+ (0, 15, 0, 0.001 if params.get('pupil_x', 0) > 0 else 0.0007)
250
+ ]),
251
+ ('pupil_y', [
252
+ (0, 11, 1, -0.001), (0, 15, 1, -0.001)
253
+ ]),
254
+ ('eyes', [
255
+ (0, 11, 1, -0.001), (0, 13, 1, 0.0003), (0, 15, 1, -0.001), (0, 16, 1, 0.0003),
256
+ (0, 1, 1, -0.00025), (0, 2, 1, 0.00025)
257
+ ]),
258
+ ('eyebrow', [
259
+ (0, 1, 1, 0.001 if params.get('eyebrow', 0) > 0 else 0.0003),
260
+ (0, 2, 1, -0.001 if params.get('eyebrow', 0) > 0 else -0.0003),
261
+ (0, 1, 0, -0.001 if params.get('eyebrow', 0) <= 0 else 0),
262
+ (0, 2, 0, 0.001 if params.get('eyebrow', 0) <= 0 else 0)
263
+ ])
264
+ ]
265
+
266
+ for param_name, adjustments in modifications:
267
+ param_value = params.get(param_name, 0)
268
+ for i, j, k, factor in adjustments:
269
+ x_d_new[i, j, k] += param_value * factor
270
+
271
+ # Special case for pupil_y affecting eyes
272
+ x_d_new[0, 11, 1] -= params.get('pupil_y', 0) * 0.001
273
+ x_d_new[0, 15, 1] -= params.get('pupil_y', 0) * 0.001
274
+ params['eyes'] = params.get('eyes', 0) - params.get('pupil_y', 0) / 2.
275
+
276
+ async def cleanup(self):
277
+ """
278
+ Perform cleanup operations for the Engine.
279
+ This method should be called when shutting down the application.
280
+ """
281
+ logger.info("Starting Engine cleanup")
282
+ try:
283
+ # TODO: Add any additional cleanup operations here
284
+ logger.info("Engine cleanup completed successfully")
285
+ except Exception as e:
286
+ logger.error(f"Error during Engine cleanup: {str(e)}")
287
+ logger.exception("Full traceback:")
288
+
289
+ def create_engine(models):
290
+ logger.info("Creating Engine instance...")
291
+
292
+ live_portrait = models
293
+
294
+ engine = Engine(
295
+ live_portrait=live_portrait,
296
+ # we might have more in the future
297
+ )
298
+
299
+ logger.info("Engine instance created successfully")
300
+ return engine
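
A minimal sketch of how the engine above can be driven once a LivePortraitPipeline has been loaded (the loading itself lives in app.py; the image path and parameter values below are placeholders, not part of the commit):

    import asyncio
    from PIL import Image
    from engine import create_engine

    async def demo(live_portrait_pipeline):
        # live_portrait_pipeline is assumed to be an already-initialized LivePortraitPipeline
        engine = create_engine(live_portrait_pipeline)
        source = Image.open("portrait.jpg")  # hypothetical input image
        data_uri = await engine.modify_image(source, {"smile": 0.8, "rotate_yaw": 5.0, "pupil_x": 0.2})
        assert data_uri.startswith("data:image/webp;base64,")

    # asyncio.run(demo(live_portrait_pipeline))
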
liveportrait/config/__init__.py ADDED
File without changes
liveportrait/config/argument_config.py ADDED
@@ -0,0 +1,44 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ config for user
5
+ """
6
+
7
+ import os.path as osp
8
+ from dataclasses import dataclass
9
+ import tyro
10
+ from typing_extensions import Annotated
11
+ from .base_config import PrintableConfig, make_abs_path
12
+
13
+
14
+ @dataclass(repr=False) # use repr from PrintableConfig
15
+ class ArgumentConfig(PrintableConfig):
16
+ ########## input arguments ##########
17
+ source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait
18
+ driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
19
+ output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
20
+ #####################################
21
+
22
+ ########## inference arguments ##########
23
+ device_id: int = 0
24
+ flag_lip_zero: bool = True # whether to set the lips to a closed state before animation; only takes effect when flag_eye_retargeting and flag_lip_retargeting are False
25
+ flag_eye_retargeting: bool = False
26
+ flag_lip_retargeting: bool = False
27
+ flag_stitching: bool = True # we recommend setting it to True!
28
+ flag_relative: bool = True # whether to use relative motion
29
+ flag_pasteback: bool = True # whether to paste back / stitch the animated face crop from the face-crop space into the original image space
30
+ flag_do_crop: bool = True # whether to crop the source portrait to the face-cropping space
31
+ flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
32
+ #########################################
33
+
34
+ ########## crop arguments ##########
35
+ dsize: int = 512
36
+ scale: float = 2.3
37
+ vx_ratio: float = 0 # vx ratio
38
+ vy_ratio: float = -0.125 # vy ratio +up, -down
39
+ ####################################
40
+
41
+ ########## gradio arguments ##########
42
+ server_port: Annotated[int, tyro.conf.arg(aliases=["-p"])] = 8890
43
+ share: bool = True
44
+ server_name: str = "0.0.0.0"
liveportrait/config/base_config.py ADDED
@@ -0,0 +1,29 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ pretty printing class
5
+ """
6
+
7
+ from __future__ import annotations
8
+ import os.path as osp
9
+ from typing import Tuple
10
+
11
+
12
+ def make_abs_path(fn):
13
+ return osp.join(osp.dirname(osp.realpath(__file__)), fn)
14
+
15
+
16
+ class PrintableConfig: # pylint: disable=too-few-public-methods
17
+ """Printable Config defining str function"""
18
+
19
+ def __repr__(self):
20
+ lines = [self.__class__.__name__ + ":"]
21
+ for key, val in vars(self).items():
22
+ if isinstance(val, Tuple):
23
+ flattened_val = "["
24
+ for item in val:
25
+ flattened_val += str(item) + "\n"
26
+ flattened_val = flattened_val.rstrip("\n")
27
+ val = flattened_val + "]"
28
+ lines += f"{key}: {str(val)}".split("\n")
29
+ return "\n ".join(lines)
liveportrait/config/crop_config.py ADDED
@@ -0,0 +1,18 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ parameters used for crop faces
5
+ """
6
+
7
+ import os.path as osp
8
+ from dataclasses import dataclass
9
+ from typing import Union, List
10
+ from .base_config import PrintableConfig
11
+
12
+
13
+ @dataclass(repr=False) # use repr from PrintableConfig
14
+ class CropConfig(PrintableConfig):
15
+ dsize: int = 512 # crop size
16
+ scale: float = 2.3 # scale factor
17
+ vx_ratio: float = 0 # vx ratio
18
+ vy_ratio: float = -0.125 # vy ratio +up, -down
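
Because CropConfig (like the other config dataclasses here) inherits PrintableConfig, instances can be tweaked and inspected like any dataclass; a small sketch with example values:

    from liveportrait.config.crop_config import CropConfig

    cfg = CropConfig(scale=2.0, vy_ratio=-0.1)  # override fields as keyword arguments
    print(cfg)                                  # one "key: value" line per field, via PrintableConfig.__repr__
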
liveportrait/config/inference_config.py ADDED
@@ -0,0 +1,53 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ config dataclass used for inference
5
+ """
6
+
7
+ import os
8
+ import os.path as osp
9
+ from dataclasses import dataclass
10
+ from typing import Literal, Tuple
11
+ from .base_config import PrintableConfig, make_abs_path
12
+
13
+ # Configuration
14
+ DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
15
+ MODELS_DIR = os.path.join(DATA_ROOT, "models")
16
+
17
+ @dataclass(repr=False) # use repr from PrintableConfig
18
+ class InferenceConfig(PrintableConfig):
19
+ models_config: str = make_abs_path('./models.yaml') # portrait animation config
20
+ checkpoint_F = os.path.join(MODELS_DIR, "liveportrait", "appearance_feature_extractor.pth")
21
+ checkpoint_M = os.path.join(MODELS_DIR, "liveportrait", "motion_extractor.pth")
22
+ checkpoint_W = os.path.join(MODELS_DIR, "liveportrait", "warping_module.pth")
23
+ checkpoint_G = os.path.join(MODELS_DIR, "liveportrait", "spade_generator.pth")
24
+ checkpoint_S = os.path.join(MODELS_DIR, "liveportrait", "stitching_retargeting_module.pth")
25
+
26
+ flag_use_half_precision: bool = True # whether to use half precision
27
+
28
+ flag_lip_zero: bool = True # whether let the lip to close state before animation, only take effect when flag_eye_retargeting and flag_lip_retargeting is False
29
+ lip_zero_threshold: float = 0.03
30
+
31
+ flag_eye_retargeting: bool = False
32
+ flag_lip_retargeting: bool = False
33
+ flag_stitching: bool = True # we recommend setting it to True!
34
+
35
+ flag_relative: bool = True # whether to use relative motion
36
+ anchor_frame: int = 0 # set this value if find_best_frame is True
37
+
38
+ input_shape: Tuple[int, int] = (256, 256) # input shape
39
+ output_format: Literal['mp4', 'gif'] = 'mp4' # output video format
40
+ output_fps: int = 25 # MuseTalk prefers 25 fps, so we use 25 as default fps for output video
41
+ crf: int = 15 # crf for output video
42
+
43
+ flag_write_result: bool = True # whether to write output video
44
+ flag_pasteback: bool = True # whether to paste back / stitch the animated face crop from the face-crop space into the original image space
45
+ mask_crop = None
46
+ flag_write_gif: bool = False
47
+ size_gif: int = 256
48
+ ref_max_shape: int = 1280
49
+ ref_shape_n: int = 2
50
+
51
+ device_id: int = 0
52
+ flag_do_crop: bool = False # whether to crop the source portrait to the face-cropping space
53
+ flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
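
Note that DATA_ROOT is read at import time, so it has to be exported before this module is imported. A sketch of the resulting checkpoint layout (the /data root is only an example; directory and file names come from the paths above):

    import os
    os.environ["DATA_ROOT"] = "/data"  # must be set before importing the config module
    from liveportrait.config.inference_config import InferenceConfig

    cfg = InferenceConfig()
    print(cfg.checkpoint_F)  # /data/models/liveportrait/appearance_feature_extractor.pth
    print(cfg.checkpoint_S)  # /data/models/liveportrait/stitching_retargeting_module.pth
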
liveportrait/config/models.yaml ADDED
@@ -0,0 +1,43 @@
1
+ model_params:
2
+ appearance_feature_extractor_params: # the F in the paper
3
+ image_channel: 3
4
+ block_expansion: 64
5
+ num_down_blocks: 2
6
+ max_features: 512
7
+ reshape_channel: 32
8
+ reshape_depth: 16
9
+ num_resblocks: 6
10
+ motion_extractor_params: # the M in the paper
11
+ num_kp: 21
12
+ backbone: convnextv2_tiny
13
+ warping_module_params: # the W in the paper
14
+ num_kp: 21
15
+ block_expansion: 64
16
+ max_features: 512
17
+ num_down_blocks: 2
18
+ reshape_channel: 32
19
+ estimate_occlusion_map: True
20
+ dense_motion_params:
21
+ block_expansion: 32
22
+ max_features: 1024
23
+ num_blocks: 5
24
+ reshape_depth: 16
25
+ compress: 4
26
+ spade_generator_params: # the G in the paper
27
+ upscale: 2 # represents upsample factor 256x256 -> 512x512
28
+ block_expansion: 64
29
+ max_features: 512
30
+ num_down_blocks: 2
31
+ stitching_retargeting_module_params: # the S in the paper
32
+ stitching:
33
+ input_size: 126 # (21*3)*2
34
+ hidden_sizes: [128, 128, 64]
35
+ output_size: 65 # (21*3)+2(tx,ty)
36
+ lip:
37
+ input_size: 65 # (21*3)+2
38
+ hidden_sizes: [128, 128, 64]
39
+ output_size: 63 # (21*3)
40
+ eye:
41
+ input_size: 66 # (21*3)+3
42
+ hidden_sizes: [256, 256, 128, 128, 64]
43
+ output_size: 63 # (21*3)
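
The commented sizes all follow from the 21 implicit keypoints; a small sketch that re-derives them from the YAML (assuming it is read from this path):

    import yaml

    with open("liveportrait/config/models.yaml") as f:
        params = yaml.safe_load(f)["model_params"]

    num_kp = params["motion_extractor_params"]["num_kp"]                   # 21
    stitching = params["stitching_retargeting_module_params"]["stitching"]
    assert stitching["input_size"] == num_kp * 3 * 2                       # 126: source + driving keypoints
    assert stitching["output_size"] == num_kp * 3 + 2                      # 65: keypoint deltas + (tx, ty)
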
liveportrait/gradio_pipeline.py ADDED
@@ -0,0 +1,140 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ Pipeline for gradio
5
+ """
6
+ import gradio as gr
7
+ from .config.argument_config import ArgumentConfig
8
+ from .live_portrait_pipeline import LivePortraitPipeline
9
+ from .utils.io import load_img_online
10
+ from .utils.rprint import rlog as log
11
+ from .utils.crop import prepare_paste_back, paste_back
12
+ from .utils.camera import get_rotation_matrix
13
+ from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
14
+
15
+ def update_args(args, user_args):
16
+ """update the args according to user inputs
17
+ """
18
+ for k, v in user_args.items():
19
+ if hasattr(args, k):
20
+ setattr(args, k, v)
21
+ return args
22
+
23
+ class GradioPipeline(LivePortraitPipeline):
24
+
25
+ def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
26
+ super().__init__(inference_cfg, crop_cfg)
27
+ # self.live_portrait_wrapper = self.live_portrait_wrapper
28
+ self.args = args
29
+ # for single image retargeting
30
+ self.start_prepare = False
31
+ self.f_s_user = None
32
+ self.x_c_s_info_user = None
33
+ self.x_s_user = None
34
+ self.source_lmk_user = None
35
+ self.mask_ori = None
36
+ self.img_rgb = None
37
+ self.crop_M_c2o = None
38
+
39
+
40
+ def execute_video(
41
+ self,
42
+ input_image_path,
43
+ input_video_path,
44
+ flag_relative_input,
45
+ flag_do_crop_input,
46
+ flag_remap_input,
47
+ ):
48
+ """ for video driven potrait animation
49
+ """
50
+ if input_image_path is not None and input_video_path is not None:
51
+ args_user = {
52
+ 'source_image': input_image_path,
53
+ 'driving_info': input_video_path,
54
+ 'flag_relative': flag_relative_input,
55
+ 'flag_do_crop': flag_do_crop_input,
56
+ 'flag_pasteback': flag_remap_input,
57
+ }
58
+ # update config from user input
59
+ self.args = update_args(self.args, args_user)
60
+ self.live_portrait_wrapper.update_config(self.args.__dict__)
61
+ self.cropper.update_config(self.args.__dict__)
62
+ # video driven animation
63
+ video_path, video_path_concat = self.execute(self.args)
64
+ gr.Info("Run successfully!", duration=2)
65
+ return video_path, video_path_concat,
66
+ else:
67
+ raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)
68
+
69
+ def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
70
+ """ for single image retargeting
71
+ """
72
+ if input_eye_ratio is None or input_lip_ratio is None:
73
+ raise gr.Error("Invalid ratio input 💥!", duration=5)
74
+ elif self.f_s_user is None:
75
+ if self.start_prepare:
76
+ raise gr.Error(
77
+ "The source portrait is under processing 💥! Please wait for a second.",
78
+ duration=5
79
+ )
80
+ else:
81
+ raise gr.Error(
82
+ "The source portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
83
+ duration=5
84
+ )
85
+ else:
86
+ # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
87
+ combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
88
+ eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
89
+ # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
90
+ combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
91
+ lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor)
92
+ num_kp = self.x_s_user.shape[1]
93
+ # default: use x_s
94
+ x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
95
+ # D(W(f_s; x_s, x′_d))
96
+ out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
97
+ out = self.live_portrait_wrapper.parse_output(out['out'])[0]
98
+ out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
99
+ gr.Info("Run successfully!", duration=2)
100
+ return out, out_to_ori_blend
101
+
102
+
103
+ def prepare_retargeting(self, input_image_path, flag_do_crop = True):
104
+ """ for single image retargeting
105
+ """
106
+ if input_image_path is not None:
107
+ gr.Info("Upload successfully!", duration=2)
108
+ self.start_prepare = True
109
+ inference_cfg = self.live_portrait_wrapper.cfg
110
+ ######## process source portrait ########
111
+ img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
112
+ log(f"Load source image from {input_image_path}.")
113
+ crop_info = self.cropper.crop_single_image(img_rgb)
114
+ if flag_do_crop:
115
+ I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
116
+ else:
117
+ I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
118
+ x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
119
+ R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
120
+ ############################################
121
+
122
+ # record global info for next time use
123
+ self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
124
+ self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
125
+ self.x_s_info_user = x_s_info
126
+ self.source_lmk_user = crop_info['lmk_crop']
127
+ self.img_rgb = img_rgb
128
+ self.crop_M_c2o = crop_info['M_c2o']
129
+ self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
130
+ # update slider
131
+ eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
132
+ eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
133
+ lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
134
+ lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
135
+ # for vis
136
+ self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0]
137
+ return eye_close_ratio, lip_close_ratio, self.I_s_vis
138
+ else:
139
+ # when press the clear button, go here
140
+ return 0.8, 0.8, self.I_s_vis
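
A sketch of the single-image retargeting flow outside of the Gradio UI, assuming the checkpoints described in inference_config.py are in place; portrait.jpg is a hypothetical input path:

    from liveportrait.config.inference_config import InferenceConfig
    from liveportrait.config.crop_config import CropConfig
    from liveportrait.config.argument_config import ArgumentConfig
    from liveportrait.gradio_pipeline import GradioPipeline

    pipeline = GradioPipeline(InferenceConfig(), CropConfig(), ArgumentConfig())
    eye_ratio, lip_ratio, preview = pipeline.prepare_retargeting("portrait.jpg")
    crop_out, full_out = pipeline.execute_image(eye_ratio, lip_ratio)  # re-render at the source ratios
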
liveportrait/live_portrait_pipeline.py ADDED
@@ -0,0 +1,193 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ Pipeline of LivePortrait
5
+ """
6
+
7
+ # TODO:
8
+ # 1. Currently this assumes all driving templates are already cropped; this needs to be revised
9
+ # 2. Pick example images: source + driving
10
+
11
+ import cv2
12
+ import numpy as np
13
+ import pickle
14
+ import os.path as osp
15
+ from rich.progress import track
16
+
17
+ from .config.argument_config import ArgumentConfig
18
+ from .config.inference_config import InferenceConfig
19
+ from .config.crop_config import CropConfig
20
+ from .utils.cropper import Cropper
21
+ from .utils.camera import get_rotation_matrix
22
+ from .utils.video import images2video, concat_frames
23
+ from .utils.crop import _transform_img, prepare_paste_back, paste_back
24
+ from .utils.retargeting_utils import calc_lip_close_ratio
25
+ from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
26
+ from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template
27
+ from .utils.rprint import rlog as log
28
+ from .live_portrait_wrapper import LivePortraitWrapper
29
+
30
+
31
+ def make_abs_path(fn):
32
+ return osp.join(osp.dirname(osp.realpath(__file__)), fn)
33
+
34
+
35
+ class LivePortraitPipeline(object):
36
+
37
+ def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):
38
+ self.live_portrait_wrapper: LivePortraitWrapper = LivePortraitWrapper(cfg=inference_cfg)
39
+ self.cropper = Cropper(crop_cfg=crop_cfg)
40
+
41
+ def execute(self, args: ArgumentConfig):
42
+ inference_cfg = self.live_portrait_wrapper.cfg # for convenience
43
+ ######## process source portrait ########
44
+ img_rgb = load_image_rgb(args.source_image)
45
+ img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
46
+ log(f"Load source image from {args.source_image}")
47
+ crop_info = self.cropper.crop_single_image(img_rgb)
48
+ source_lmk = crop_info['lmk_crop']
49
+ img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
50
+ if inference_cfg.flag_do_crop:
51
+ I_s = self.live_portrait_wrapper.prepare_source(img_crop_256x256)
52
+ else:
53
+ I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
54
+ x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
55
+ x_c_s = x_s_info['kp']
56
+ R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
57
+ f_s = self.live_portrait_wrapper.extract_feature_3d(I_s)
58
+ x_s = self.live_portrait_wrapper.transform_keypoint(x_s_info)
59
+
60
+ if inference_cfg.flag_lip_zero:
61
+ # let lip-open scalar to be 0 at first
62
+ c_d_lip_before_animation = [0.]
63
+ combined_lip_ratio_tensor_before_animation = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_before_animation, source_lmk)
64
+ if combined_lip_ratio_tensor_before_animation[0][0] < inference_cfg.lip_zero_threshold:
65
+ inference_cfg.flag_lip_zero = False
66
+ else:
67
+ lip_delta_before_animation = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor_before_animation)
68
+ ############################################
69
+
70
+ ######## process driving info ########
71
+ if is_video(args.driving_info):
72
+ log(f"Load from video file (mp4 mov avi etc...): {args.driving_info}")
73
+ # TODO: track the driving video here -> build a template
74
+ driving_rgb_lst = load_driving_info(args.driving_info)
75
+ driving_rgb_lst_256 = [cv2.resize(_, (256, 256)) for _ in driving_rgb_lst]
76
+ I_d_lst = self.live_portrait_wrapper.prepare_driving_videos(driving_rgb_lst_256)
77
+ n_frames = I_d_lst.shape[0]
78
+ if inference_cfg.flag_eye_retargeting or inference_cfg.flag_lip_retargeting:
79
+ driving_lmk_lst = self.cropper.get_retargeting_lmk_info(driving_rgb_lst)
80
+ input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
81
+ elif is_template(args.driving_info):
82
+ log(f"Load from video templates {args.driving_info}")
83
+ with open(args.driving_info, 'rb') as f:
84
+ template_lst, driving_lmk_lst = pickle.load(f)
85
+ n_frames = template_lst[0]['n_frames']
86
+ input_eye_ratio_lst, input_lip_ratio_lst = self.live_portrait_wrapper.calc_retargeting_ratio(source_lmk, driving_lmk_lst)
87
+ else:
88
+ raise Exception("Unsupported driving types!")
89
+ #########################################
90
+
91
+ ######## prepare for pasteback ########
92
+ if inference_cfg.flag_pasteback:
93
+ mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
94
+ I_p_paste_lst = []
95
+ #########################################
96
+
97
+ I_p_lst = []
98
+ R_d_0, x_d_0_info = None, None
99
+ for i in track(range(n_frames), description='Animating...', total=n_frames):
100
+ if is_video(args.driving_info):
101
+ # extract kp info by M
102
+ I_d_i = I_d_lst[i]
103
+ x_d_i_info = self.live_portrait_wrapper.get_kp_info(I_d_i)
104
+ R_d_i = get_rotation_matrix(x_d_i_info['pitch'], x_d_i_info['yaw'], x_d_i_info['roll'])
105
+ else:
106
+ # from template
107
+ x_d_i_info = template_lst[i]
108
+ x_d_i_info = dct2cuda(x_d_i_info, inference_cfg.device_id)
109
+ R_d_i = x_d_i_info['R_d']
110
+
111
+ if i == 0:
112
+ R_d_0 = R_d_i
113
+ x_d_0_info = x_d_i_info
114
+
115
+ if inference_cfg.flag_relative:
116
+ R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
117
+ delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
118
+ scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
119
+ t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
120
+ else:
121
+ R_new = R_d_i
122
+ delta_new = x_d_i_info['exp']
123
+ scale_new = x_s_info['scale']
124
+ t_new = x_d_i_info['t']
125
+
126
+ t_new[..., 2].fill_(0) # zero tz
127
+ x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
128
+
129
+ # Algorithm 1:
130
+ if not inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
131
+ # without stitching or retargeting
132
+ if inference_cfg.flag_lip_zero:
133
+ x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
134
+ else:
135
+ pass
136
+ elif inference_cfg.flag_stitching and not inference_cfg.flag_eye_retargeting and not inference_cfg.flag_lip_retargeting:
137
+ # with stitching and without retargeting
138
+ if inference_cfg.flag_lip_zero:
139
+ x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
140
+ else:
141
+ x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
142
+ else:
143
+ eyes_delta, lip_delta = None, None
144
+ if inference_cfg.flag_eye_retargeting:
145
+ c_d_eyes_i = input_eye_ratio_lst[i]
146
+ combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
147
+ # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
148
+ eyes_delta = self.live_portrait_wrapper.retarget_eye(x_s, combined_eye_ratio_tensor)
149
+ if inference_cfg.flag_lip_retargeting:
150
+ c_d_lip_i = input_lip_ratio_lst[i]
151
+ combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
152
+ # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
153
+ lip_delta = self.live_portrait_wrapper.retarget_lip(x_s, combined_lip_ratio_tensor)
154
+
155
+ if inference_cfg.flag_relative: # use x_s
156
+ x_d_i_new = x_s + \
157
+ (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
158
+ (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
159
+ else: # use x_d,i
160
+ x_d_i_new = x_d_i_new + \
161
+ (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
162
+ (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
163
+
164
+ if inference_cfg.flag_stitching:
165
+ x_d_i_new = self.live_portrait_wrapper.stitching(x_s, x_d_i_new)
166
+
167
+ out = self.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
168
+ I_p_i = self.live_portrait_wrapper.parse_output(out['out'])[0]
169
+ I_p_lst.append(I_p_i)
170
+
171
+ if inference_cfg.flag_pasteback:
172
+ I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
173
+ I_p_paste_lst.append(I_p_i_to_ori_blend)
174
+
175
+ mkdir(args.output_dir)
176
+ wfp_concat = None
177
+
178
+ # note by @jbilcke-hf:
179
+ # I have disabled this block, since we don't need to debug it
180
+ #if is_video(args.driving_info):
181
+ # frames_concatenated = concat_frames(I_p_lst, driving_rgb_lst, img_crop_256x256)
182
+ # # save (driving frames, source image, drived frames) result
183
+ # wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
184
+ # images2video(frames_concatenated, wfp=wfp_concat)#
185
+
186
+ # save drived result
187
+ wfp = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}.mp4')
188
+ if inference_cfg.flag_pasteback:
189
+ images2video(I_p_paste_lst, wfp=wfp)
190
+ else:
191
+ images2video(I_p_lst, wfp=wfp)
192
+
193
+ return wfp, wfp_concat
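
For reference, a sketch of running this offline, video-driven pipeline directly (paths are placeholders and the checkpoints must already be downloaded):

    from liveportrait.config.inference_config import InferenceConfig
    from liveportrait.config.crop_config import CropConfig
    from liveportrait.config.argument_config import ArgumentConfig
    from liveportrait.live_portrait_pipeline import LivePortraitPipeline

    pipeline = LivePortraitPipeline(InferenceConfig(), CropConfig())
    args = ArgumentConfig(source_image="source.jpg", driving_info="driving.mp4", output_dir="animations/")
    wfp, wfp_concat = pipeline.execute(args)  # wfp: rendered .mp4; wfp_concat stays None in this build
    print(wfp)
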
liveportrait/live_portrait_wrapper.py ADDED
@@ -0,0 +1,307 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ Wrapper for LivePortrait core functions
5
+ """
6
+
7
+ import os.path as osp
8
+ import numpy as np
9
+ import cv2
10
+ import torch
11
+ import yaml
12
+
13
+ from .utils.timer import Timer
14
+ from .utils.helper import load_model, concat_feat
15
+ from .utils.camera import headpose_pred_to_degree, get_rotation_matrix
16
+ from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
17
+ from .config.inference_config import InferenceConfig
18
+ from .utils.rprint import rlog as log
19
+
20
+
21
+ class LivePortraitWrapper(object):
22
+
23
+ def __init__(self, cfg: InferenceConfig):
24
+
25
+ model_config = yaml.load(open(cfg.models_config, 'r'), Loader=yaml.SafeLoader)
26
+
27
+ # init F
28
+ self.appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
29
+ #log(f'Load appearance_feature_extractor done.')
30
+ # init M
31
+ self.motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
32
+ #log(f'Load motion_extractor done.')
33
+ # init W
34
+ self.warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
35
+ #log(f'Load warping_module done.')
36
+ # init G
37
+ self.spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
38
+ #log(f'Load spade_generator done.')
39
+ # init S and R
40
+ if cfg.checkpoint_S is not None and osp.exists(cfg.checkpoint_S):
41
+ self.stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
42
+ #log(f'Load stitching_retargeting_module done.')
43
+ else:
44
+ self.stitching_retargeting_module = None
45
+
46
+ self.cfg = cfg
47
+ self.device_id = cfg.device_id
48
+ self.timer = Timer()
49
+
50
+ def update_config(self, user_args):
51
+ for k, v in user_args.items():
52
+ if hasattr(self.cfg, k):
53
+ setattr(self.cfg, k, v)
54
+
55
+ def prepare_source(self, img: np.ndarray) -> torch.Tensor:
56
+ """ construct the input as standard
57
+ img: HxWx3, uint8, 256x256
58
+ """
59
+ h, w = img.shape[:2]
60
+ if h != self.cfg.input_shape[0] or w != self.cfg.input_shape[1]:
61
+ x = cv2.resize(img, (self.cfg.input_shape[0], self.cfg.input_shape[1]))
62
+ else:
63
+ x = img.copy()
64
+
65
+ if x.ndim == 3:
66
+ x = x[np.newaxis].astype(np.float32) / 255. # HxWx3 -> 1xHxWx3, normalized to 0~1
67
+ elif x.ndim == 4:
68
+ x = x.astype(np.float32) / 255. # BxHxWx3, normalized to 0~1
69
+ else:
70
+ raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
71
+ x = np.clip(x, 0, 1) # clip to 0~1
72
+ x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW
73
+ x = x.cuda(self.device_id)
74
+ return x
75
+
76
+ def prepare_driving_videos(self, imgs) -> torch.Tensor:
77
+ """ construct the input as standard
78
+ imgs: NxBxHxWx3, uint8
79
+ """
80
+ if isinstance(imgs, list):
81
+ _imgs = np.array(imgs)[..., np.newaxis] # TxHxWx3x1
82
+ elif isinstance(imgs, np.ndarray):
83
+ _imgs = imgs
84
+ else:
85
+ raise ValueError(f'imgs type error: {type(imgs)}')
86
+
87
+ y = _imgs.astype(np.float32) / 255.
88
+ y = np.clip(y, 0, 1) # clip to 0~1
89
+ y = torch.from_numpy(y).permute(0, 4, 3, 1, 2) # TxHxWx3x1 -> Tx1x3xHxW
90
+ y = y.cuda(self.device_id)
91
+
92
+ return y
93
+
94
+ def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
95
+ """ get the appearance feature of the image by F
96
+ x: Bx3xHxW, normalized to 0~1
97
+ """
98
+ with torch.no_grad():
99
+ with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
100
+ feature_3d = self.appearance_feature_extractor(x)
101
+
102
+ return feature_3d.float()
103
+
104
+ def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
105
+ """ get the implicit keypoint information
106
+ x: Bx3xHxW, normalized to 0~1
107
+ flag_refine_info: whether to convert the pose predictions to degrees and reshape the keypoint/expression tensors
108
+ return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
109
+ """
110
+ with torch.no_grad():
111
+ with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
112
+ kp_info = self.motion_extractor(x)
113
+
114
+ if self.cfg.flag_use_half_precision:
115
+ # float the dict
116
+ for k, v in kp_info.items():
117
+ if isinstance(v, torch.Tensor):
118
+ kp_info[k] = v.float()
119
+
120
+ flag_refine_info: bool = kwargs.get('flag_refine_info', True)
121
+ if flag_refine_info:
122
+ bs = kp_info['kp'].shape[0]
123
+ kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None] # Bx1
124
+ kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None] # Bx1
125
+ kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None] # Bx1
126
+ kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3) # BxNx3
127
+ kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3) # BxNx3
128
+
129
+ return kp_info
130
+
131
+ def get_pose_dct(self, kp_info: dict) -> dict:
132
+ pose_dct = dict(
133
+ pitch=headpose_pred_to_degree(kp_info['pitch']).item(),
134
+ yaw=headpose_pred_to_degree(kp_info['yaw']).item(),
135
+ roll=headpose_pred_to_degree(kp_info['roll']).item(),
136
+ )
137
+ return pose_dct
138
+
139
+ def get_fs_and_kp_info(self, source_prepared, driving_first_frame):
140
+
141
+ # get the canonical keypoints of source image by M
142
+ source_kp_info = self.get_kp_info(source_prepared, flag_refine_info=True)
143
+ source_rotation = get_rotation_matrix(source_kp_info['pitch'], source_kp_info['yaw'], source_kp_info['roll'])
144
+
145
+ # get the canonical keypoints of first driving frame by M
146
+ driving_first_frame_kp_info = self.get_kp_info(driving_first_frame, flag_refine_info=True)
147
+ driving_first_frame_rotation = get_rotation_matrix(
148
+ driving_first_frame_kp_info['pitch'],
149
+ driving_first_frame_kp_info['yaw'],
150
+ driving_first_frame_kp_info['roll']
151
+ )
152
+
153
+ # get feature volume by F
154
+ source_feature_3d = self.extract_feature_3d(source_prepared)
155
+
156
+ return source_kp_info, source_rotation, source_feature_3d, driving_first_frame_kp_info, driving_first_frame_rotation
157
+
158
+ def transform_keypoint(self, kp_info: dict):
159
+ """
160
+ transform the implicit keypoints with the pose, shift, and expression deformation
161
+ kp: BxNx3
162
+ """
163
+ kp = kp_info['kp'] # (bs, k, 3)
164
+ pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']
165
+
166
+ t, exp = kp_info['t'], kp_info['exp']
167
+ scale = kp_info['scale']
168
+
169
+ pitch = headpose_pred_to_degree(pitch)
170
+ yaw = headpose_pred_to_degree(yaw)
171
+ roll = headpose_pred_to_degree(roll)
172
+
173
+ bs = kp.shape[0]
174
+ if kp.ndim == 2:
175
+ num_kp = kp.shape[1] // 3 # Bx(num_kpx3)
176
+ else:
177
+ num_kp = kp.shape[1] # Bxnum_kpx3
178
+
179
+ rot_mat = get_rotation_matrix(pitch, yaw, roll) # (bs, 3, 3)
180
+
181
+ # Eqn.2: s * (R * x_c,s + exp) + t
182
+ kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
183
+ kp_transformed *= scale[..., None] # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
184
+ kp_transformed[:, :, 0:2] += t[:, None, 0:2] # remove z, only apply tx ty
185
+
186
+ return kp_transformed
187
+
188
+ def retarget_eye(self, kp_source: torch.Tensor, eye_close_ratio: torch.Tensor) -> torch.Tensor:
189
+ """
190
+ kp_source: BxNx3
191
+ eye_close_ratio: Bx3
192
+ Return: Bx(3*num_kp+2)
193
+ """
194
+ feat_eye = concat_feat(kp_source, eye_close_ratio)
195
+
196
+ with torch.no_grad():
197
+ delta = self.stitching_retargeting_module['eye'](feat_eye)
198
+
199
+ return delta
200
+
201
+ def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor:
202
+ """
203
+ kp_source: BxNx3
204
+ lip_close_ratio: Bx2
205
+ """
206
+ feat_lip = concat_feat(kp_source, lip_close_ratio)
207
+
208
+ with torch.no_grad():
209
+ delta = self.stitching_retargeting_module['lip'](feat_lip)
210
+
211
+ return delta
212
+
213
+ def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
214
+ """
215
+ kp_source: BxNx3
216
+ kp_driving: BxNx3
217
+ Return: Bx(3*num_kp+2)
218
+ """
219
+ feat_stitching = concat_feat(kp_source, kp_driving)
220
+
221
+ with torch.no_grad():
222
+ delta = self.stitching_retargeting_module['stitching'](feat_stitching)
223
+
224
+ return delta
225
+
226
+ def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
227
+ """ conduct the stitching
228
+ kp_source: Bxnum_kpx3
229
+ kp_driving: Bxnum_kpx3
230
+ """
231
+
232
+ if self.stitching_retargeting_module is not None:
233
+
234
+ bs, num_kp = kp_source.shape[:2]
235
+
236
+ kp_driving_new = kp_driving.clone()
237
+ delta = self.stitch(kp_source, kp_driving_new)
238
+
239
+ delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3) # 1x20x3
240
+ delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2) # 1x1x2
241
+
242
+ kp_driving_new += delta_exp
243
+ kp_driving_new[..., :2] += delta_tx_ty
244
+
245
+ return kp_driving_new
246
+
247
+ return kp_driving
248
+
249
+ def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
250
+ """ get the image after the warping of the implicit keypoints
251
+ feature_3d: Bx32x16x64x64, feature volume
252
+ kp_source: BxNx3
253
+ kp_driving: BxNx3
254
+ """
255
+ # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
256
+ with torch.no_grad():
257
+ with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.cfg.flag_use_half_precision):
258
+ # get decoder input
259
+ ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
260
+ # decode
261
+ ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])
262
+
263
+ # float the dict
264
+ if self.cfg.flag_use_half_precision:
265
+ for k, v in ret_dct.items():
266
+ if isinstance(v, torch.Tensor):
267
+ ret_dct[k] = v.float()
268
+
269
+ return ret_dct
270
+
271
+ def parse_output(self, out: torch.Tensor) -> np.ndarray:
272
+ """ construct the output as standard
273
+ return: 1xHxWx3, uint8
274
+ """
275
+ out = np.transpose(out.data.cpu().numpy(), [0, 2, 3, 1]) # 1x3xHxW -> 1xHxWx3
276
+ out = np.clip(out, 0, 1) # clip to 0~1
277
+ out = np.clip(out * 255, 0, 255).astype(np.uint8) # 0~1 -> 0~255
278
+
279
+ return out
280
+
281
+ def calc_retargeting_ratio(self, source_lmk, driving_lmk_lst):
282
+ input_eye_ratio_lst = []
283
+ input_lip_ratio_lst = []
284
+ for lmk in driving_lmk_lst:
285
+ # for eyes retargeting
286
+ input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
287
+ # for lip retargeting
288
+ input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
289
+ return input_eye_ratio_lst, input_lip_ratio_lst
290
+
291
+ def calc_combined_eye_ratio(self, input_eye_ratio, source_lmk):
292
+ eye_close_ratio = calc_eye_close_ratio(source_lmk[None])
293
+ eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().cuda(self.device_id)
294
+ input_eye_ratio_tensor = torch.Tensor([input_eye_ratio[0][0]]).reshape(1, 1).cuda(self.device_id)
295
+ # [c_s,eyes, c_d,eyes,i]
296
+ combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1)
297
+ return combined_eye_ratio_tensor
298
+
299
+ def calc_combined_lip_ratio(self, input_lip_ratio, source_lmk):
300
+ lip_close_ratio = calc_lip_close_ratio(source_lmk[None])
301
+ lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().cuda(self.device_id)
302
+ # [c_s,lip, c_d,lip,i]
303
+ input_lip_ratio_tensor = torch.Tensor([input_lip_ratio[0]]).cuda(self.device_id)
304
+ if input_lip_ratio_tensor.shape != (1, 1): # compare against a tuple; torch.Size never equals a list
305
+ input_lip_ratio_tensor = input_lip_ratio_tensor.reshape(1, 1)
306
+ combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1)
307
+ return combined_lip_ratio_tensor
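
The wrapper methods above follow Eqn. 2 of the paper: the driven keypoints are obtained as s * (x_c,s @ R + exp) + t, with the z-translation dropped. Below is a minimal standalone sketch of that transform, mirroring transform_keypoint(); the batch size and num_kp=21 are illustrative values only, not taken from the repository config.

import torch

bs, num_kp = 1, 21                       # illustrative sizes
kp = torch.randn(bs, num_kp, 3)          # canonical keypoints x_c,s
rot_mat = torch.eye(3).expand(bs, 3, 3)  # R, as returned by get_rotation_matrix(pitch, yaw, roll)
exp = torch.zeros(bs, num_kp, 3)         # expression deformation
scale = torch.ones(bs, 1)                # s
t = torch.zeros(bs, 3)                   # translation; tz is ignored

# Eqn. 2: x_d = s * (x_c,s @ R + exp) + t, applied exactly as in transform_keypoint()
kp_transformed = kp @ rot_mat + exp
kp_transformed = kp_transformed * scale[..., None]
kp_transformed[:, :, 0:2] += t[:, None, 0:2]
print(kp_transformed.shape)              # torch.Size([1, 21, 3])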
liveportrait/modules/__init__.py ADDED
File without changes
liveportrait/modules/appearance_feature_extractor.py ADDED
@@ -0,0 +1,48 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ Appearance extractor(F) defined in paper, which maps the source image s to a 3D appearance feature volume.
5
+ """
6
+
7
+ import torch
8
+ from torch import nn
9
+ from .util import SameBlock2d, DownBlock2d, ResBlock3d
10
+
11
+
12
+ class AppearanceFeatureExtractor(nn.Module):
13
+
14
+ def __init__(self, image_channel, block_expansion, num_down_blocks, max_features, reshape_channel, reshape_depth, num_resblocks):
15
+ super(AppearanceFeatureExtractor, self).__init__()
16
+ self.image_channel = image_channel
17
+ self.block_expansion = block_expansion
18
+ self.num_down_blocks = num_down_blocks
19
+ self.max_features = max_features
20
+ self.reshape_channel = reshape_channel
21
+ self.reshape_depth = reshape_depth
22
+
23
+ self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))
24
+
25
+ down_blocks = []
26
+ for i in range(num_down_blocks):
27
+ in_features = min(max_features, block_expansion * (2 ** i))
28
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
29
+ down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
30
+ self.down_blocks = nn.ModuleList(down_blocks)
31
+
32
+ self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
33
+
34
+ self.resblocks_3d = torch.nn.Sequential()
35
+ for i in range(num_resblocks):
36
+ self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
37
+
38
+ def forward(self, source_image):
39
+ out = self.first(source_image) # Bx3x256x256 -> Bx64x256x256
40
+
41
+ for i in range(len(self.down_blocks)):
42
+ out = self.down_blocks[i](out)
43
+ out = self.second(out)
44
+ bs, c, h, w = out.shape # ->Bx512x64x64
45
+
46
+ f_s = out.view(bs, self.reshape_channel, self.reshape_depth, h, w) # ->Bx32x16x64x64
47
+ f_s = self.resblocks_3d(f_s) # ->Bx32x16x64x64
48
+ return f_s
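
A minimal usage sketch of the extractor above. The hyperparameters are chosen to reproduce the shape comments in forward() (Bx3x256x256 -> Bx512x64x64 -> Bx32x16x64x64); the authoritative values live in liveportrait/config/models.yaml, so treat these numbers as illustrative.

import torch
from liveportrait.modules.appearance_feature_extractor import AppearanceFeatureExtractor

F_net = AppearanceFeatureExtractor(
    image_channel=3, block_expansion=64, num_down_blocks=2,
    max_features=512, reshape_channel=32, reshape_depth=16, num_resblocks=6,
)
x = torch.rand(1, 3, 256, 256)  # source image, normalized to 0~1
with torch.no_grad():
    f_s = F_net(x)
print(f_s.shape)  # torch.Size([1, 32, 16, 64, 64])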
liveportrait/modules/convnextv2.py ADDED
@@ -0,0 +1,149 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ This module adapts ConvNeXtV2 for the extraction of implicit keypoints, poses, and expression deformation.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ # from timm.models.layers import trunc_normal_, DropPath
10
+ from .util import LayerNorm, DropPath, trunc_normal_, GRN
11
+
12
+ __all__ = ['convnextv2_tiny']
13
+
14
+
15
+ class Block(nn.Module):
16
+ """ ConvNeXtV2 Block.
17
+
18
+ Args:
19
+ dim (int): Number of input channels.
20
+ drop_path (float): Stochastic depth rate. Default: 0.0
21
+ """
22
+
23
+ def __init__(self, dim, drop_path=0.):
24
+ super().__init__()
25
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
26
+ self.norm = LayerNorm(dim, eps=1e-6)
27
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
28
+ self.act = nn.GELU()
29
+ self.grn = GRN(4 * dim)
30
+ self.pwconv2 = nn.Linear(4 * dim, dim)
31
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
32
+
33
+ def forward(self, x):
34
+ input = x
35
+ x = self.dwconv(x)
36
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
37
+ x = self.norm(x)
38
+ x = self.pwconv1(x)
39
+ x = self.act(x)
40
+ x = self.grn(x)
41
+ x = self.pwconv2(x)
42
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
43
+
44
+ x = input + self.drop_path(x)
45
+ return x
46
+
47
+
48
+ class ConvNeXtV2(nn.Module):
49
+ """ ConvNeXt V2
50
+
51
+ Args:
52
+ in_chans (int): Number of input image channels. Default: 3
53
+ num_classes (int): Number of classes for classification head. Default: 1000
54
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
55
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
56
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
57
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ in_chans=3,
63
+ depths=[3, 3, 9, 3],
64
+ dims=[96, 192, 384, 768],
65
+ drop_path_rate=0.,
66
+ **kwargs
67
+ ):
68
+ super().__init__()
69
+ self.depths = depths
70
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
71
+ stem = nn.Sequential(
72
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
73
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
74
+ )
75
+ self.downsample_layers.append(stem)
76
+ for i in range(3):
77
+ downsample_layer = nn.Sequential(
78
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
79
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
80
+ )
81
+ self.downsample_layers.append(downsample_layer)
82
+
83
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
84
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
85
+ cur = 0
86
+ for i in range(4):
87
+ stage = nn.Sequential(
88
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
89
+ )
90
+ self.stages.append(stage)
91
+ cur += depths[i]
92
+
93
+ self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
94
+
95
+ # NOTE: the output semantic items
96
+ num_bins = kwargs.get('num_bins', 66)
97
+ num_kp = kwargs.get('num_kp', 24) # the number of implicit keypoints
98
+ self.fc_kp = nn.Linear(dims[-1], 3 * num_kp) # implicit keypoints
99
+
100
+ # print('dims[-1]: ', dims[-1])
101
+ self.fc_scale = nn.Linear(dims[-1], 1) # scale
102
+ self.fc_pitch = nn.Linear(dims[-1], num_bins) # pitch bins
103
+ self.fc_yaw = nn.Linear(dims[-1], num_bins) # yaw bins
104
+ self.fc_roll = nn.Linear(dims[-1], num_bins) # roll bins
105
+ self.fc_t = nn.Linear(dims[-1], 3) # translation
106
+ self.fc_exp = nn.Linear(dims[-1], 3 * num_kp) # expression / delta
107
+
108
+ def _init_weights(self, m):
109
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
110
+ trunc_normal_(m.weight, std=.02)
111
+ nn.init.constant_(m.bias, 0)
112
+
113
+ def forward_features(self, x):
114
+ for i in range(4):
115
+ x = self.downsample_layers[i](x)
116
+ x = self.stages[i](x)
117
+ return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
118
+
119
+ def forward(self, x):
120
+ x = self.forward_features(x)
121
+
122
+ # implicit keypoints
123
+ kp = self.fc_kp(x)
124
+
125
+ # pose and expression deformation
126
+ pitch = self.fc_pitch(x)
127
+ yaw = self.fc_yaw(x)
128
+ roll = self.fc_roll(x)
129
+ t = self.fc_t(x)
130
+ exp = self.fc_exp(x)
131
+ scale = self.fc_scale(x)
132
+
133
+ ret_dct = {
134
+ 'pitch': pitch,
135
+ 'yaw': yaw,
136
+ 'roll': roll,
137
+ 't': t,
138
+ 'exp': exp,
139
+ 'scale': scale,
140
+
141
+ 'kp': kp, # canonical keypoint
142
+ }
143
+
144
+ return ret_dct
145
+
146
+
147
+ def convnextv2_tiny(**kwargs):
148
+ model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
149
+ return model
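
A quick sanity-check sketch of the backbone above, using only the defaults visible in this file (num_bins=66, num_kp=24); the motion extractor configures it differently via models.yaml, so the exact head sizes here are illustrative.

import torch
from liveportrait.modules.convnextv2 import convnextv2_tiny

net = convnextv2_tiny()
x = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    out = net(x)
for k, v in out.items():
    print(k, tuple(v.shape))
# pitch/yaw/roll: (1, 66) bins, t: (1, 3), scale: (1, 1), exp/kp: (1, 72)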
liveportrait/modules/dense_motion.py ADDED
@@ -0,0 +1,104 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ The module that predicts a dense motion field from the sparse motion representation given by kp_source and kp_driving
5
+ """
6
+
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ import torch
10
+ from .util import Hourglass, make_coordinate_grid, kp2gaussian
11
+
12
+
13
+ class DenseMotionNetwork(nn.Module):
14
+ def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, estimate_occlusion_map=True):
15
+ super(DenseMotionNetwork, self).__init__()
16
+ self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) # ~60+G
17
+
18
+ self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) # 65G! NOTE: computation cost is large
19
+ self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) # 0.8G
20
+ self.norm = nn.BatchNorm3d(compress, affine=True)
21
+ self.num_kp = num_kp
22
+ self.flag_estimate_occlusion_map = estimate_occlusion_map
23
+
24
+ if self.flag_estimate_occlusion_map:
25
+ self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)
26
+ else:
27
+ self.occlusion = None
28
+
29
+ def create_sparse_motions(self, feature, kp_driving, kp_source):
30
+ bs, _, d, h, w = feature.shape # (bs, 4, 16, 64, 64)
31
+ identity_grid = make_coordinate_grid((d, h, w), ref=kp_source) # (16, 64, 64, 3)
32
+ identity_grid = identity_grid.view(1, 1, d, h, w, 3) # (1, 1, d=16, h=64, w=64, 3)
33
+ coordinate_grid = identity_grid - kp_driving.view(bs, self.num_kp, 1, 1, 1, 3)
34
+
35
+ k = coordinate_grid.shape[1]
36
+
37
+ # NOTE: a first-order flow term is missing here
38
+ driving_to_source = coordinate_grid + kp_source.view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3)
39
+
40
+ # adding background feature
41
+ identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
42
+ sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) # (bs, 1+num_kp, d, h, w, 3)
43
+ return sparse_motions
44
+
45
+ def create_deformed_feature(self, feature, sparse_motions):
46
+ bs, _, d, h, w = feature.shape
47
+ feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w)
48
+ feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w)
49
+ sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3)
50
+ sparse_deformed = F.grid_sample(feature_repeat, sparse_motions, align_corners=False)
51
+ sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w)
52
+
53
+ return sparse_deformed
54
+
55
+ def create_heatmap_representations(self, feature, kp_driving, kp_source):
56
+ spatial_size = feature.shape[3:] # (d=16, h=64, w=64)
57
+ gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w)
58
+ gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w)
59
+ heatmap = gaussian_driving - gaussian_source # (bs, num_kp, d, h, w)
60
+
61
+ # adding background feature
62
+ zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()).to(heatmap.device)
63
+ heatmap = torch.cat([zeros, heatmap], dim=1)
64
+ heatmap = heatmap.unsqueeze(2) # (bs, 1+num_kp, 1, d, h, w)
65
+ return heatmap
66
+
67
+ def forward(self, feature, kp_driving, kp_source):
68
+ bs, _, d, h, w = feature.shape # (bs, 32, 16, 64, 64)
69
+
70
+ feature = self.compress(feature) # (bs, 4, 16, 64, 64)
71
+ feature = self.norm(feature) # (bs, 4, 16, 64, 64)
72
+ feature = F.relu(feature) # (bs, 4, 16, 64, 64)
73
+
74
+ out_dict = dict()
75
+
76
+ # 1. deform 3d feature
77
+ sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) # (bs, 1+num_kp, d, h, w, 3)
78
+ deformed_feature = self.create_deformed_feature(feature, sparse_motion) # (bs, 1+num_kp, c=4, d=16, h=64, w=64)
79
+
80
+ # 2. (bs, 1+num_kp, d, h, w)
81
+ heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) # (bs, 1+num_kp, 1, d, h, w)
82
+
83
+ input = torch.cat([heatmap, deformed_feature], dim=2) # (bs, 1+num_kp, c=5, d=16, h=64, w=64)
84
+ input = input.view(bs, -1, d, h, w) # (bs, (1+num_kp)*c=105, d=16, h=64, w=64)
85
+
86
+ prediction = self.hourglass(input)
87
+
88
+ mask = self.mask(prediction)
89
+ mask = F.softmax(mask, dim=1) # (bs, 1+num_kp, d=16, h=64, w=64)
90
+ out_dict['mask'] = mask
91
+ mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
92
+ sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w)
93
+ deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) mask take effect in this place
94
+ deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3)
95
+
96
+ out_dict['deformation'] = deformation
97
+
98
+ if self.flag_estimate_occlusion_map:
99
+ bs, _, d, h, w = prediction.shape
100
+ prediction_reshape = prediction.view(bs, -1, h, w)
101
+ occlusion_map = torch.sigmoid(self.occlusion(prediction_reshape)) # Bx1x64x64
102
+ out_dict['occlusion_map'] = occlusion_map
103
+
104
+ return out_dict
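
A minimal sketch of how this network is invoked. The hourglass sizes below are deliberately small so the example runs on CPU; the real configuration (see models.yaml) is much larger. feature_channel=32, reshape_depth=16 and compress=4 match the shape comments in forward(), and num_kp=20 is consistent with the (1+num_kp)*c=105 comment; all of these are illustrative here.

import torch
from liveportrait.modules.dense_motion import DenseMotionNetwork

num_kp = 20
net = DenseMotionNetwork(
    block_expansion=8, num_blocks=2, max_features=64, num_kp=num_kp,
    feature_channel=32, reshape_depth=16, compress=4, estimate_occlusion_map=True,
)
feature = torch.randn(1, 32, 16, 64, 64)  # appearance feature volume f_s
kp_source = torch.randn(1, num_kp, 3)
kp_driving = torch.randn(1, num_kp, 3)
with torch.no_grad():
    out = net(feature, kp_source, kp_driving)
print(out['deformation'].shape)    # torch.Size([1, 16, 64, 64, 3])
print(out['occlusion_map'].shape)  # torch.Size([1, 1, 64, 64])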
liveportrait/modules/motion_extractor.py ADDED
@@ -0,0 +1,35 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ Motion extractor (M), which directly predicts the canonical keypoints, head pose, and expression deformation of the input image
5
+ """
6
+
7
+ from torch import nn
8
+ import torch
9
+
10
+ from .convnextv2 import convnextv2_tiny
11
+ from .util import filter_state_dict
12
+
13
+ model_dict = {
14
+ 'convnextv2_tiny': convnextv2_tiny,
15
+ }
16
+
17
+
18
+ class MotionExtractor(nn.Module):
19
+ def __init__(self, **kwargs):
20
+ super(MotionExtractor, self).__init__()
21
+
22
+ # default backbone is convnextv2_tiny
23
+ backbone = kwargs.get('backbone', 'convnextv2_tiny')
24
+ self.detector = model_dict.get(backbone)(**kwargs)
25
+
26
+ def load_pretrained(self, init_path: str):
27
+ if init_path not in (None, ''):
28
+ state_dict = torch.load(init_path, map_location=lambda storage, loc: storage)['model']
29
+ state_dict = filter_state_dict(state_dict, remove_name='head')
30
+ ret = self.detector.load_state_dict(state_dict, strict=False)
31
+ print(f'Load pretrained model from {init_path}, ret: {ret}')
32
+
33
+ def forward(self, x):
34
+ out = self.detector(x)
35
+ return out
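
A small usage sketch for the wrapper above. The keyword arguments shown are the defaults from convnextv2.py, and the checkpoint path is hypothetical; the real values come from models.yaml and the inference config.

import torch
from liveportrait.modules.motion_extractor import MotionExtractor

M = MotionExtractor(backbone='convnextv2_tiny', num_kp=24, num_bins=66)
# M.load_pretrained('path/to/motion_extractor.pth')  # hypothetical path, for illustration only
with torch.no_grad():
    kp_info = M(torch.rand(1, 3, 256, 256))
print(sorted(kp_info.keys()))  # ['exp', 'kp', 'pitch', 'roll', 'scale', 't', 'yaw']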
liveportrait/modules/spade_generator.py ADDED
@@ -0,0 +1,59 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ SPADE decoder (G) defined in the paper, which takes the warped feature as input and generates the animated image.
5
+ """
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ from .util import SPADEResnetBlock
11
+
12
+
13
+ class SPADEDecoder(nn.Module):
14
+ def __init__(self, upscale=1, max_features=256, block_expansion=64, out_channels=64, num_down_blocks=2):
15
+ for i in range(num_down_blocks):
16
+ input_channels = min(max_features, block_expansion * (2 ** (i + 1)))
17
+ self.upscale = upscale
18
+ super().__init__()
19
+ norm_G = 'spadespectralinstance'
20
+ label_num_channels = input_channels # 256
21
+
22
+ self.fc = nn.Conv2d(input_channels, 2 * input_channels, 3, padding=1)
23
+ self.G_middle_0 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
24
+ self.G_middle_1 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
25
+ self.G_middle_2 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
26
+ self.G_middle_3 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
27
+ self.G_middle_4 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
28
+ self.G_middle_5 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
29
+ self.up_0 = SPADEResnetBlock(2 * input_channels, input_channels, norm_G, label_num_channels)
30
+ self.up_1 = SPADEResnetBlock(input_channels, out_channels, norm_G, label_num_channels)
31
+ self.up = nn.Upsample(scale_factor=2)
32
+
33
+ if self.upscale is None or self.upscale <= 1:
34
+ self.conv_img = nn.Conv2d(out_channels, 3, 3, padding=1)
35
+ else:
36
+ self.conv_img = nn.Sequential(
37
+ nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1),
38
+ nn.PixelShuffle(upscale_factor=2)
39
+ )
40
+
41
+ def forward(self, feature):
42
+ seg = feature # Bx256x64x64
43
+ x = self.fc(feature) # Bx512x64x64
44
+ x = self.G_middle_0(x, seg)
45
+ x = self.G_middle_1(x, seg)
46
+ x = self.G_middle_2(x, seg)
47
+ x = self.G_middle_3(x, seg)
48
+ x = self.G_middle_4(x, seg)
49
+ x = self.G_middle_5(x, seg)
50
+
51
+ x = self.up(x) # Bx512x64x64 -> Bx512x128x128
52
+ x = self.up_0(x, seg) # Bx512x128x128 -> Bx256x128x128
53
+ x = self.up(x) # Bx256x128x128 -> Bx256x256x256
54
+ x = self.up_1(x, seg) # Bx256x256x256 -> Bx64x256x256
55
+
56
+ x = self.conv_img(F.leaky_relu(x, 2e-1)) # Bx64x256x256 -> Bx3xHxW
57
+ x = torch.sigmoid(x) # Bx3xHxW
58
+
59
+ return x
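
A minimal decode sketch under the defaults above (block_expansion=64, max_features=256, num_down_blocks=2), where input_channels resolves to 256 as in the Bx256x64x64 comment in forward(). In the full pipeline the input is the warped feature produced by the warping module; a random tensor stands in for it here.

import torch
from liveportrait.modules.spade_generator import SPADEDecoder

G = SPADEDecoder(upscale=1)
feature = torch.randn(1, 256, 64, 64)  # stand-in for the warped feature
with torch.no_grad():
    img = G(feature)
print(img.shape)  # torch.Size([1, 3, 256, 256]); upscale=2 would give 512x512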
liveportrait/modules/stitching_retargeting_network.py ADDED
@@ -0,0 +1,38 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ Stitching module(S) and two retargeting modules(R) defined in the paper.
5
+
6
+ - The stitching module pastes the animated portrait back into the original image space without pixel misalignment, such as in
7
+ the stitching region.
8
+
9
+ - The eyes retargeting module is designed to address the issue of incomplete eye closure during cross-id reenactment, especially
10
+ when a person with small eyes drives a person with larger eyes.
11
+
12
+ - The lip retargeting module is designed similarly to the eye retargeting module, and can also normalize the input by ensuring that
13
+ the lips are in a closed state, which facilitates better animation driving.
14
+ """
15
+ from torch import nn
16
+
17
+
18
+ class StitchingRetargetingNetwork(nn.Module):
19
+ def __init__(self, input_size, hidden_sizes, output_size):
20
+ super(StitchingRetargetingNetwork, self).__init__()
21
+ layers = []
22
+ for i in range(len(hidden_sizes)):
23
+ if i == 0:
24
+ layers.append(nn.Linear(input_size, hidden_sizes[i]))
25
+ else:
26
+ layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
27
+ layers.append(nn.ReLU(inplace=True))
28
+ layers.append(nn.Linear(hidden_sizes[-1], output_size))
29
+ self.mlp = nn.Sequential(*layers)
30
+
31
+ def initialize_weights_to_zero(self):
32
+ for m in self.modules():
33
+ if isinstance(m, nn.Linear):
34
+ nn.init.zeros_(m.weight)
35
+ nn.init.zeros_(m.bias)
36
+
37
+ def forward(self, x):
38
+ return self.mlp(x)
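
An illustrative instantiation of the MLP above for the stitching head. Assuming 21 implicit keypoints, the input is concat(kp_source, kp_driving) flattened to 2*3*21 = 126 values, and the output is 3*num_kp + 2 = 65 values as in the stitch() docstring; the hidden sizes here are placeholders, since the real ones are fixed by the pretrained checkpoint.

import torch
from liveportrait.modules.stitching_retargeting_network import StitchingRetargetingNetwork

S = StitchingRetargetingNetwork(input_size=126, hidden_sizes=[128, 128, 64], output_size=65)
feat = torch.randn(1, 126)  # concat(kp_source, kp_driving), flattened
delta = S(feat)
print(delta.shape)  # torch.Size([1, 65]) -> 63 expression offsets + (tx, ty)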
liveportrait/modules/util.py ADDED
@@ -0,0 +1,441 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ This file defines various neural network modules and utility functions, including convolutional and residual blocks,
5
+ normalizations, and functions for spatial transformation and tensor manipulation.
6
+ """
7
+
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torch
11
+ import torch.nn.utils.spectral_norm as spectral_norm
12
+ import math
13
+ import warnings
14
+
15
+
16
+ def kp2gaussian(kp, spatial_size, kp_variance):
17
+ """
18
+ Transform a keypoint into gaussian like representation
19
+ """
20
+ mean = kp
21
+
22
+ coordinate_grid = make_coordinate_grid(spatial_size, mean)
23
+ number_of_leading_dimensions = len(mean.shape) - 1
24
+ shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
25
+ coordinate_grid = coordinate_grid.view(*shape)
26
+ repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
27
+ coordinate_grid = coordinate_grid.repeat(*repeats)
28
+
29
+ # Preprocess kp shape
30
+ shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
31
+ mean = mean.view(*shape)
32
+
33
+ mean_sub = (coordinate_grid - mean)
34
+
35
+ out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
36
+
37
+ return out
38
+
39
+
40
+ def make_coordinate_grid(spatial_size, ref, **kwargs):
41
+ d, h, w = spatial_size
42
+ x = torch.arange(w).type(ref.dtype).to(ref.device)
43
+ y = torch.arange(h).type(ref.dtype).to(ref.device)
44
+ z = torch.arange(d).type(ref.dtype).to(ref.device)
45
+
46
+ # NOTE: must be right-down-in
47
+ x = (2 * (x / (w - 1)) - 1) # the x axis faces to the right
48
+ y = (2 * (y / (h - 1)) - 1) # the y axis faces to the bottom
49
+ z = (2 * (z / (d - 1)) - 1) # the z axis faces to the inner
50
+
51
+ yy = y.view(1, -1, 1).repeat(d, 1, w)
52
+ xx = x.view(1, 1, -1).repeat(d, h, 1)
53
+ zz = z.view(-1, 1, 1).repeat(1, h, w)
54
+
55
+ meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
56
+
57
+ return meshed
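
A quick sketch of the two helpers above: make_coordinate_grid builds a normalized [-1, 1] grid over a (d, h, w) volume, and kp2gaussian turns keypoints into per-keypoint gaussian heatmaps, as used by DenseMotionNetwork. The sizes are illustrative.

import torch
from liveportrait.modules.util import kp2gaussian, make_coordinate_grid

kp = torch.zeros(1, 2, 3)  # two keypoints at the volume center
grid = make_coordinate_grid((16, 64, 64), ref=kp)
print(grid.shape)          # torch.Size([16, 64, 64, 3])
heat = kp2gaussian(kp, spatial_size=(16, 64, 64), kp_variance=0.01)
print(heat.shape)          # torch.Size([1, 2, 16, 64, 64])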
58
+
59
+
60
+ class ConvT2d(nn.Module):
61
+ """
62
+ Upsampling block for use in decoder.
63
+ """
64
+
65
+ def __init__(self, in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1):
66
+ super(ConvT2d, self).__init__()
67
+
68
+ self.convT = nn.ConvTranspose2d(in_features, out_features, kernel_size=kernel_size, stride=stride,
69
+ padding=padding, output_padding=output_padding)
70
+ self.norm = nn.InstanceNorm2d(out_features)
71
+
72
+ def forward(self, x):
73
+ out = self.convT(x)
74
+ out = self.norm(out)
75
+ out = F.leaky_relu(out)
76
+ return out
77
+
78
+
79
+ class ResBlock3d(nn.Module):
80
+ """
81
+ Res block, preserve spatial resolution.
82
+ """
83
+
84
+ def __init__(self, in_features, kernel_size, padding):
85
+ super(ResBlock3d, self).__init__()
86
+ self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding)
87
+ self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding)
88
+ self.norm1 = nn.BatchNorm3d(in_features, affine=True)
89
+ self.norm2 = nn.BatchNorm3d(in_features, affine=True)
90
+
91
+ def forward(self, x):
92
+ out = self.norm1(x)
93
+ out = F.relu(out)
94
+ out = self.conv1(out)
95
+ out = self.norm2(out)
96
+ out = F.relu(out)
97
+ out = self.conv2(out)
98
+ out += x
99
+ return out
100
+
101
+
102
+ class UpBlock3d(nn.Module):
103
+ """
104
+ Upsampling block for use in decoder.
105
+ """
106
+
107
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
108
+ super(UpBlock3d, self).__init__()
109
+
110
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
111
+ padding=padding, groups=groups)
112
+ self.norm = nn.BatchNorm3d(out_features, affine=True)
113
+
114
+ def forward(self, x):
115
+ out = F.interpolate(x, scale_factor=(1, 2, 2))
116
+ out = self.conv(out)
117
+ out = self.norm(out)
118
+ out = F.relu(out)
119
+ return out
120
+
121
+
122
+ class DownBlock2d(nn.Module):
123
+ """
124
+ Downsampling block for use in encoder.
125
+ """
126
+
127
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
128
+ super(DownBlock2d, self).__init__()
129
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups)
130
+ self.norm = nn.BatchNorm2d(out_features, affine=True)
131
+ self.pool = nn.AvgPool2d(kernel_size=(2, 2))
132
+
133
+ def forward(self, x):
134
+ out = self.conv(x)
135
+ out = self.norm(out)
136
+ out = F.relu(out)
137
+ out = self.pool(out)
138
+ return out
139
+
140
+
141
+ class DownBlock3d(nn.Module):
142
+ """
143
+ Downsampling block for use in encoder.
144
+ """
145
+
146
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
147
+ super(DownBlock3d, self).__init__()
148
+ '''
149
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
150
+ padding=padding, groups=groups, stride=(1, 2, 2))
151
+ '''
152
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
153
+ padding=padding, groups=groups)
154
+ self.norm = nn.BatchNorm3d(out_features, affine=True)
155
+ self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2))
156
+
157
+ def forward(self, x):
158
+ out = self.conv(x)
159
+ out = self.norm(out)
160
+ out = F.relu(out)
161
+ out = self.pool(out)
162
+ return out
163
+
164
+
165
+ class SameBlock2d(nn.Module):
166
+ """
167
+ Simple block, preserve spatial resolution.
168
+ """
169
+
170
+ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False):
171
+ super(SameBlock2d, self).__init__()
172
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups)
173
+ self.norm = nn.BatchNorm2d(out_features, affine=True)
174
+ if lrelu:
175
+ self.ac = nn.LeakyReLU()
176
+ else:
177
+ self.ac = nn.ReLU()
178
+
179
+ def forward(self, x):
180
+ out = self.conv(x)
181
+ out = self.norm(out)
182
+ out = self.ac(out)
183
+ return out
184
+
185
+
186
+ class Encoder(nn.Module):
187
+ """
188
+ Hourglass Encoder
189
+ """
190
+
191
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
192
+ super(Encoder, self).__init__()
193
+
194
+ down_blocks = []
195
+ for i in range(num_blocks):
196
+ down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), kernel_size=3, padding=1))
197
+ self.down_blocks = nn.ModuleList(down_blocks)
198
+
199
+ def forward(self, x):
200
+ outs = [x]
201
+ for down_block in self.down_blocks:
202
+ outs.append(down_block(outs[-1]))
203
+ return outs
204
+
205
+
206
+ class Decoder(nn.Module):
207
+ """
208
+ Hourglass Decoder
209
+ """
210
+
211
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
212
+ super(Decoder, self).__init__()
213
+
214
+ up_blocks = []
215
+
216
+ for i in range(num_blocks)[::-1]:
217
+ in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
218
+ out_filters = min(max_features, block_expansion * (2 ** i))
219
+ up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
220
+
221
+ self.up_blocks = nn.ModuleList(up_blocks)
222
+ self.out_filters = block_expansion + in_features
223
+
224
+ self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)
225
+ self.norm = nn.BatchNorm3d(self.out_filters, affine=True)
226
+
227
+ def forward(self, x):
228
+ out = x.pop()
229
+ for up_block in self.up_blocks:
230
+ out = up_block(out)
231
+ skip = x.pop()
232
+ out = torch.cat([out, skip], dim=1)
233
+ out = self.conv(out)
234
+ out = self.norm(out)
235
+ out = F.relu(out)
236
+ return out
237
+
238
+
239
+ class Hourglass(nn.Module):
240
+ """
241
+ Hourglass architecture.
242
+ """
243
+
244
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
245
+ super(Hourglass, self).__init__()
246
+ self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
247
+ self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
248
+ self.out_filters = self.decoder.out_filters
249
+
250
+ def forward(self, x):
251
+ return self.decoder(self.encoder(x))
252
+
253
+
254
+ class SPADE(nn.Module):
255
+ def __init__(self, norm_nc, label_nc):
256
+ super().__init__()
257
+
258
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
259
+ nhidden = 128
260
+
261
+ self.mlp_shared = nn.Sequential(
262
+ nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
263
+ nn.ReLU())
264
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
265
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
266
+
267
+ def forward(self, x, segmap):
268
+ normalized = self.param_free_norm(x)
269
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
270
+ actv = self.mlp_shared(segmap)
271
+ gamma = self.mlp_gamma(actv)
272
+ beta = self.mlp_beta(actv)
273
+ out = normalized * (1 + gamma) + beta
274
+ return out
275
+
276
+
277
+ class SPADEResnetBlock(nn.Module):
278
+ def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
279
+ super().__init__()
280
+ # Attributes
281
+ self.learned_shortcut = (fin != fout)
282
+ fmiddle = min(fin, fout)
283
+ self.use_se = use_se
284
+ # create conv layers
285
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
286
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
287
+ if self.learned_shortcut:
288
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
289
+ # apply spectral norm if specified
290
+ if 'spectral' in norm_G:
291
+ self.conv_0 = spectral_norm(self.conv_0)
292
+ self.conv_1 = spectral_norm(self.conv_1)
293
+ if self.learned_shortcut:
294
+ self.conv_s = spectral_norm(self.conv_s)
295
+ # define normalization layers
296
+ self.norm_0 = SPADE(fin, label_nc)
297
+ self.norm_1 = SPADE(fmiddle, label_nc)
298
+ if self.learned_shortcut:
299
+ self.norm_s = SPADE(fin, label_nc)
300
+
301
+ def forward(self, x, seg1):
302
+ x_s = self.shortcut(x, seg1)
303
+ dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
304
+ dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
305
+ out = x_s + dx
306
+ return out
307
+
308
+ def shortcut(self, x, seg1):
309
+ if self.learned_shortcut:
310
+ x_s = self.conv_s(self.norm_s(x, seg1))
311
+ else:
312
+ x_s = x
313
+ return x_s
314
+
315
+ def actvn(self, x):
316
+ return F.leaky_relu(x, 2e-1)
317
+
318
+
319
+ def filter_state_dict(state_dict, remove_name='fc'):
320
+ new_state_dict = {}
321
+ for key in state_dict:
322
+ if remove_name in key:
323
+ continue
324
+ new_state_dict[key] = state_dict[key]
325
+ return new_state_dict
326
+
327
+
328
+ class GRN(nn.Module):
329
+ """ GRN (Global Response Normalization) layer
330
+ """
331
+
332
+ def __init__(self, dim):
333
+ super().__init__()
334
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
335
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
336
+
337
+ def forward(self, x):
338
+ Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
339
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
340
+ return self.gamma * (x * Nx) + self.beta + x
341
+
342
+
343
+ class LayerNorm(nn.Module):
344
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
345
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
346
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
347
+ with shape (batch_size, channels, height, width).
348
+ """
349
+
350
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
351
+ super().__init__()
352
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
353
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
354
+ self.eps = eps
355
+ self.data_format = data_format
356
+ if self.data_format not in ["channels_last", "channels_first"]:
357
+ raise NotImplementedError
358
+ self.normalized_shape = (normalized_shape, )
359
+
360
+ def forward(self, x):
361
+ if self.data_format == "channels_last":
362
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
363
+ elif self.data_format == "channels_first":
364
+ u = x.mean(1, keepdim=True)
365
+ s = (x - u).pow(2).mean(1, keepdim=True)
366
+ x = (x - u) / torch.sqrt(s + self.eps)
367
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
368
+ return x
369
+
370
+
371
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
372
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
373
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
374
+ def norm_cdf(x):
375
+ # Computes standard normal cumulative distribution function
376
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
377
+
378
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
379
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
380
+ "The distribution of values may be incorrect.",
381
+ stacklevel=2)
382
+
383
+ with torch.no_grad():
384
+ # Values are generated by using a truncated uniform distribution and
385
+ # then using the inverse CDF for the normal distribution.
386
+ # Get upper and lower cdf values
387
+ l = norm_cdf((a - mean) / std)
388
+ u = norm_cdf((b - mean) / std)
389
+
390
+ # Uniformly fill tensor with values from [l, u], then translate to
391
+ # [2l-1, 2u-1].
392
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
393
+
394
+ # Use inverse cdf transform for normal distribution to get truncated
395
+ # standard normal
396
+ tensor.erfinv_()
397
+
398
+ # Transform to proper mean, std
399
+ tensor.mul_(std * math.sqrt(2.))
400
+ tensor.add_(mean)
401
+
402
+ # Clamp to ensure it's in the proper range
403
+ tensor.clamp_(min=a, max=b)
404
+ return tensor
405
+
406
+
407
+ def drop_path(x, drop_prob=0., training=False, scale_by_keep=True):
408
+ """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
409
+
410
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
411
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
412
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
413
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
414
+ 'survival rate' as the argument.
415
+
416
+ """
417
+ if drop_prob == 0. or not training:
418
+ return x
419
+ keep_prob = 1 - drop_prob
420
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
421
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
422
+ if keep_prob > 0.0 and scale_by_keep:
423
+ random_tensor.div_(keep_prob)
424
+ return x * random_tensor
425
+
426
+
427
+ class DropPath(nn.Module):
428
+ """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
429
+ """
430
+
431
+ def __init__(self, drop_prob=None, scale_by_keep=True):
432
+ super(DropPath, self).__init__()
433
+ self.drop_prob = drop_prob
434
+ self.scale_by_keep = scale_by_keep
435
+
436
+ def forward(self, x):
437
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
438
+
439
+
440
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
441
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
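
A small sketch of three of the helpers above that the ConvNeXtV2 backbone relies on (LayerNorm in channels_first mode, trunc_normal_ initialization, and DropPath); shapes and values are illustrative.

import torch
from liveportrait.modules.util import LayerNorm, trunc_normal_, DropPath

x = torch.randn(2, 96, 32, 32)
ln = LayerNorm(96, data_format="channels_first")
print(ln(x).shape)                  # torch.Size([2, 96, 32, 32])

w = torch.empty(96, 96)
trunc_normal_(w, std=.02)           # truncated to [a, b] = [-2, 2] by default
print(float(w.abs().max()) <= 2.0)  # True

dp = DropPath(drop_prob=0.1)
dp.train()                          # drop_path is a no-op in eval mode
print(dp(x).shape)                  # torch.Size([2, 96, 32, 32])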