# Spaces: Running on L40S
import logging | |
import hashlib | |
import os | |
import io | |
import asyncio | |
from async_lru import alru_cache | |
import base64 | |
from queue import Queue | |
from typing import Dict, Any, List, Optional, Union | |
from functools import lru_cache | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image | |
from liveportrait.config.argument_config import ArgumentConfig | |
from liveportrait.utils.camera import get_rotation_matrix | |
from liveportrait.utils.io import resize_to_limit | |
from liveportrait.utils.crop import prepare_paste_back, paste_back, parse_bbox_from_landmark | |
# Logging: root handler at DEBUG level with timestamped, logger-named records.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

# Filesystem layout: everything lives under DATA_ROOT (overridable through the
# DATA_ROOT environment variable); model files go in its "models" subfolder.
DATA_ROOT = os.getenv('DATA_ROOT', '/tmp/data')
MODELS_DIR = os.path.join(DATA_ROOT, "models")
def base64_data_uri_to_PIL_Image(base64_string: str) -> Image.Image:
    """
    Decode base64 image data, optionally wrapped in a data URI, into a PIL Image.

    When the input contains a comma (e.g. ``data:image/png;base64,<payload>``),
    the second comma-separated field is taken as the base64 payload; otherwise
    the whole string is decoded.

    Args:
        base64_string (str): Raw base64 text or a data URI.

    Returns:
        Image.Image: The image opened from the decoded bytes (PIL opens lazily,
        so pixel data is read on first access).
    """
    fields = base64_string.split(',')
    payload = fields[1] if len(fields) > 1 else base64_string
    raw_bytes = base64.b64decode(payload)
    buffer = io.BytesIO(raw_bytes)
    return Image.open(buffer)
class Engine:
    """
    The main engine class for FacePoke.

    ``load_image`` runs the LivePortrait source-side preprocessing (crop,
    keypoint info, 3D features) once per image and caches the result under the
    image hash; ``transform_image`` then re-renders a cached image with
    expression and head-rotation offsets applied.
    """

    def __init__(self, live_portrait):
        """
        Initialize the FacePoke engine with necessary models and processors.

        Args:
            live_portrait (LivePortraitPipeline): The LivePortrait model for video generation.
        """
        self.live_portrait = live_portrait
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # image hash -> pre-computed source data (filled by load_image).
        # Entries are never evicted, so this grows with every distinct image.
        self.processed_cache = {}
        logger.info("β FacePoke Engine initialized successfully.")

    async def get_image_hash(self, image: Union[Image.Image, str, bytes]) -> str:
        """
        Compute or retrieve the hash for an image.

        Args:
            image (Union[Image.Image, str, bytes]): A PIL Image, raw bytes, or a
                string. A 32-character string is assumed to already be a hash
                and is returned unchanged; any other string is decoded as
                base64 (optionally a data URI) into a PIL Image first.

        Returns:
            str: Hex MD5 digest. For a PIL Image it is computed over the decoded
            pixel buffer (``Image.tobytes()``); for ``bytes`` over the bytes as
            given, so the same picture hashes differently in the two forms.

        Raises:
            ValueError: If ``image`` is of an unsupported type.
        """
        if isinstance(image, str):
            # Assume it's already a hash if it has the length of an MD5 hex digest
            if len(image) == 32:
                return image
            # Otherwise, assume it's a base64 string
            image = base64_data_uri_to_PIL_Image(image)

        if isinstance(image, Image.Image):
            # NOTE(review): the pixel buffer does not encode width/height/mode,
            # so differently-shaped images with identical buffers collide. Left
            # unchanged because this digest is the cache key handed to clients.
            return hashlib.md5(image.tobytes()).hexdigest()
        elif isinstance(image, bytes):
            return hashlib.md5(image).hexdigest()
        else:
            raise ValueError("Unsupported image type")

    async def load_image(self, data):
        """
        Decode an uploaded image, pre-compute its LivePortrait source data and
        cache it for later ``transform_image`` calls.

        Args:
            data (bytes): Encoded image file contents (any format PIL can open).

        Returns:
            dict: ``'h'`` is the cache key (image hash) to pass to
            ``transform_image``; ``'c'``, ``'s'``, ``'b'`` and ``'a'`` are the
            ``center``, ``size``, ``bbox`` and ``angle`` fields computed by
            ``parse_bbox_from_landmark`` from the crop landmarks.
        """
        with Image.open(io.BytesIO(data)) as image:
            # Hash the image as uploaded, before any mode conversion
            image_hash = await self.get_image_hash(image)
            # Everything below treats this array as RGB; convert explicitly so
            # RGBA / palette / greyscale uploads don't yield 4-channel or 2-D arrays.
            img_rgb = np.array(image.convert('RGB'))

        wrapper = self.live_portrait.live_portrait_wrapper
        inference_cfg = wrapper.cfg

        # Offload each pipeline step to a worker thread so the event loop is
        # not blocked while it runs.
        img_rgb = await asyncio.to_thread(resize_to_limit, img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
        crop_info = await asyncio.to_thread(self.live_portrait.cropper.crop_single_image, img_rgb)
        I_s = await asyncio.to_thread(wrapper.prepare_source, crop_info['img_crop_256x256'])
        x_s_info = await asyncio.to_thread(wrapper.get_kp_info, I_s)
        f_s = await asyncio.to_thread(wrapper.extract_feature_3d, I_s)
        x_s = await asyncio.to_thread(wrapper.transform_keypoint, x_s_info)

        self.processed_cache[image_hash] = {
            'img_rgb': img_rgb,
            'crop_info': crop_info,
            'x_s_info': x_s_info,
            'f_s': f_s,
            'x_s': x_s,
            'inference_cfg': inference_cfg,
        }

        # Calculate the face bounding box from the crop landmarks
        bbox_info = parse_bbox_from_landmark(crop_info['lmk_crop'], scale=1.0)
        return {
            'h': image_hash,
            # those aren't easy to serialize
            'c': bbox_info['center'],  # 2x1
            's': bbox_info['size'],    # scalar
            'b': bbox_info['bbox'],    # 4x2
            'a': bbox_info['angle'],   # rad, counterclockwise
            # 'bbox_rot': bbox_info['bbox_rot'].toList(), # 4x2
        }

    async def transform_image(self, image_hash: str, params: Dict[str, float]) -> bytes:
        """
        Re-render a cached image with expression and head-rotation offsets.

        Args:
            image_hash (str): Key returned as ``'h'`` by ``load_image``.
            params (Dict[str, float]): Slider values. Expression keys: 'smile',
                'aaa', 'eee', 'woo', 'wink', 'pupil_x', 'pupil_y', 'eyes',
                'eyebrow'; rotation keys 'rotate_pitch', 'rotate_yaw',
                'rotate_roll' are added to the source pose angles. Missing keys
                default to 0. The dict is not modified.

        Returns:
            bytes: The edited face pasted back into the full original image,
            encoded as lossy WebP.

        Raises:
            ValueError: ``"cache miss"`` if ``image_hash`` was never loaded, or
                ``"Failed to modify image: ..."`` wrapping any rendering error
                (the original exception is chained as ``__cause__``).
        """
        if image_hash not in self.processed_cache:
            raise ValueError("cache miss")
        processed_data = self.processed_cache[image_hash]

        try:
            x_s_info = processed_data['x_s_info']
            wrapper = self.live_portrait.live_portrait_wrapper

            # Expression: offset a copy of the cached source keypoints
            x_d_new = x_s_info['kp'].clone()
            self._apply_expression_params(x_d_new, params)

            # Pose: source angles plus the requested rotation deltas
            R_new = get_rotation_matrix(
                x_s_info['pitch'] + params.get('rotate_pitch', 0),
                x_s_info['yaw'] + params.get('rotate_yaw', 0),
                x_s_info['roll'] + params.get('rotate_roll', 0),
            )
            x_d_new = x_s_info['scale'] * (x_d_new @ R_new) + x_s_info['t']

            # Apply stitching
            x_d_new = await asyncio.to_thread(wrapper.stitching, processed_data['x_s'], x_d_new)

            # Generate the output crop
            out = await asyncio.to_thread(wrapper.warp_decode, processed_data['f_s'], processed_data['x_s'], x_d_new)
            I_p = await asyncio.to_thread(wrapper.parse_output, out['out'])

            # Stitch the generated crop back into the original image. This is
            # expensive in compute and in payload size (the whole picture is
            # sent back); doing it in the frontend is being experimented with.
            mask_ori = await asyncio.to_thread(
                prepare_paste_back,
                processed_data['inference_cfg'].mask_crop,
                processed_data['crop_info']['M_c2o'],
                dsize=(processed_data['img_rgb'].shape[1], processed_data['img_rgb'].shape[0]),
            )
            I_p_to_ori_blend = await asyncio.to_thread(
                paste_back,
                I_p[0], processed_data['crop_info']['M_c2o'], processed_data['img_rgb'], mask_ori,
            )
            result_image = Image.fromarray(I_p_to_ori_blend)
            # Possible future frontend-side variant: return only the crop instead
            # result_image = Image.fromarray(I_p[0])

            buffered = io.BytesIO()
            result_image.save(buffered, format="WebP", quality=82, lossless=False, method=6)
            return buffered.getvalue()
        except Exception as e:
            raise ValueError(f"Failed to modify image: {str(e)}") from e

    @staticmethod
    def _expression_modifications(params: Dict[str, float]) -> List[tuple]:
        """
        Build the keypoint offset table for the expression sliders.

        Each entry is ``(param_name, adjustments)``; each adjustment is
        ``(i, j, k, factor)`` and means ``kp[i, j, k] += value * factor``.
        A few factors depend on the sign of 'pupil_x' / 'eyebrow', hence the
        ``params`` argument.
        """
        pupil_x_positive = params.get('pupil_x', 0) > 0
        eyebrow_raised = params.get('eyebrow', 0) > 0
        return [
            ('smile', [
                (0, 20, 1, -0.01), (0, 14, 1, -0.02), (0, 17, 1, 0.0065), (0, 17, 2, 0.003),
                (0, 13, 1, -0.00275), (0, 16, 1, -0.00275), (0, 3, 1, -0.0035), (0, 7, 1, -0.0035),
            ]),
            ('aaa', [
                (0, 19, 1, 0.001), (0, 19, 2, 0.0001), (0, 17, 1, -0.0001),
            ]),
            ('eee', [
                (0, 20, 2, -0.001), (0, 20, 1, -0.001), (0, 14, 1, -0.001),
            ]),
            ('woo', [
                (0, 14, 1, 0.001), (0, 3, 1, -0.0005), (0, 7, 1, -0.0005), (0, 17, 2, -0.0005),
            ]),
            ('wink', [
                (0, 11, 1, 0.001), (0, 13, 1, -0.0003), (0, 17, 0, 0.0003),
                (0, 17, 1, 0.0003), (0, 3, 1, -0.0003),
            ]),
            ('pupil_x', [
                (0, 11, 0, 0.0007 if pupil_x_positive else 0.001),
                (0, 15, 0, 0.001 if pupil_x_positive else 0.0007),
            ]),
            ('pupil_y', [
                (0, 11, 1, -0.001), (0, 15, 1, -0.001),
            ]),
            ('eyes', [
                (0, 11, 1, -0.001), (0, 13, 1, 0.0003), (0, 15, 1, -0.001), (0, 16, 1, 0.0003),
                (0, 1, 1, -0.00025), (0, 2, 1, 0.00025),
            ]),
            ('eyebrow', [
                (0, 1, 1, 0.001 if eyebrow_raised else 0.0003),
                (0, 2, 1, -0.001 if eyebrow_raised else -0.0003),
                (0, 1, 0, 0 if eyebrow_raised else -0.001),
                (0, 2, 0, 0 if eyebrow_raised else 0.001),
            ]),
        ]

    @classmethod
    def _apply_expression_params(cls, kp, params: Dict[str, float]) -> None:
        """
        Add the expression slider offsets to ``kp`` in place.

        ``params`` is not modified. 'pupil_y' also shifts the effective 'eyes'
        value by ``-pupil_y / 2``; that is folded in *before* the offsets are
        applied. Previously the adjustment was written into the caller's dict
        after 'eyes' had already been applied, so it had no effect, and
        'pupil_y' was additionally applied a second time.
        """
        values = dict(params)
        values['eyes'] = values.get('eyes', 0) - values.get('pupil_y', 0) / 2.
        for param_name, adjustments in cls._expression_modifications(params):
            value = values.get(param_name, 0)
            for i, j, k, factor in adjustments:
                kp[i, j, k] += value * factor