|
import os |
|
import sys |
|
import cv2 |
|
import yaml |
|
import imageio |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
import subprocess, platform |
|
from mutagen.wave import WAVE |
|
from datetime import timedelta |
|
|
|
from face_vid2vid.sync_batchnorm.replicate import DataParallelWithCallback |
|
from face_vid2vid.modules.generator import OcclusionAwareSPADEGenerator |
|
from face_vid2vid.modules.keypoint_detector import KPDetector, HEEstimator |
|
from face_vid2vid.animate import normalize_kp |
|
from batch_face import RetinaFace |
|
|
|
|
|
# Refuse to run on a Python 2 interpreter; the message also states the
# version the project recommends.
if sys.version_info.major < 3:
    raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
|
|
|
|
|
def load_checkpoints(config_path, checkpoint_path):
    """Build the three face-vid2vid networks and restore their weights.

    The YAML file at ``config_path`` supplies the constructor arguments
    (``model_params`` -> ``*_params`` plus the shared ``common_params``).
    Every network is moved to CUDA (the generator additionally to half
    precision), its state dict is loaded from ``checkpoint_path``, it is
    wrapped in ``DataParallelWithCallback`` and switched to eval mode.

    Args:
        config_path: Path to the YAML model configuration.
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` that
            holds the keys ``"generator"``, ``"kp_detector"`` and
            ``"he_estimator"``.

    Returns:
        Tuple ``(generator, kp_detector, he_estimator)``.
    """
    with open(config_path) as config_file:
        # NOTE(review): FullLoader is kept as in the original; consider
        # yaml.safe_load if the config can ever come from an untrusted source.
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    model_params = config["model_params"]
    shared_params = model_params["common_params"]

    generator = OcclusionAwareSPADEGenerator(**model_params["generator_params"], **shared_params)
    generator.cuda().half()  # only the generator runs in fp16

    kp_detector = KPDetector(**model_params["kp_detector_params"], **shared_params)
    kp_detector.cuda()

    he_estimator = HEEstimator(**model_params["he_estimator_params"], **shared_params)
    he_estimator.cuda()

    print("Loading checkpoints")
    weights = torch.load(checkpoint_path)

    generator.load_state_dict(weights["generator"])
    kp_detector.load_state_dict(weights["kp_detector"])
    he_estimator.load_state_dict(weights["he_estimator"])

    # Wrap all three first, then put each wrapper into inference mode.
    networks = [DataParallelWithCallback(net) for net in (generator, kp_detector, he_estimator)]
    for net in networks:
        net.eval()
    print("Model successfully loaded!")

    generator, kp_detector, he_estimator = networks
    return generator, kp_detector, he_estimator
|
|
|
|
|
def headpose_pred_to_degree(pred):
    """Turn binned head-pose scores into one scalar angle per row.

    ``pred`` is soft-maxed along dim 1, the expected bin index (bins are
    numbered 0..65) is computed, and that expectation is mapped through
    ``index * 3 - 99``.

    Args:
        pred: Tensor whose dim 1 has 66 entries (it is multiplied element-wise
            with a 66-element index vector).

    Returns:
        Tensor with dim 1 reduced away, holding the resulting angle values.
    """
    bin_indices = torch.arange(66, dtype=torch.float32, device=pred.device)
    probabilities = F.softmax(pred, dim=1)
    expected_bin = (probabilities * bin_indices).sum(dim=1)
    return expected_bin * 3 - 99
|
|
|
|
|
def get_rotation_matrix(yaw, pitch, roll):
    """Compose batched 3x3 rotation matrices from yaw, pitch and roll.

    Each argument is a 1-D tensor with one angle (in degrees) per batch
    element.  The result is the matrix product ``pitch_mat @ yaw_mat @
    roll_mat`` where the three factors rotate about the x, y and z axis
    respectively.

    Args:
        yaw: Tensor of shape ``(B,)``.
        pitch: Tensor of shape ``(B,)``.
        roll: Tensor of shape ``(B,)``.

    Returns:
        Tensor of shape ``(B, 3, 3)``.
    """

    def _to_radian_column(angle):
        # NOTE(review): 3.14 (not math.pi) is kept so the numeric output stays
        # identical; presumably it matches how the checkpoint was trained -
        # confirm before changing.
        return (angle / 180 * 3.14).unsqueeze(1)

    def _assemble(row_major_entries):
        # Nine (B, 1) columns -> (B, 9) -> (B, 3, 3), row-major.
        flat = torch.cat(row_major_entries, dim=1)
        return flat.view(flat.shape[0], 3, 3)

    yaw = _to_radian_column(yaw)
    pitch = _to_radian_column(pitch)
    roll = _to_radian_column(roll)

    # Rotation about the x axis.
    cos_p, sin_p = torch.cos(pitch), torch.sin(pitch)
    one_p, zero_p = torch.ones_like(pitch), torch.zeros_like(pitch)
    pitch_mat = _assemble(
        [
            one_p, zero_p, zero_p,
            zero_p, cos_p, -sin_p,
            zero_p, sin_p, cos_p,
        ]
    )

    # Rotation about the y axis.
    cos_y, sin_y = torch.cos(yaw), torch.sin(yaw)
    one_y, zero_y = torch.ones_like(yaw), torch.zeros_like(yaw)
    yaw_mat = _assemble(
        [
            cos_y, zero_y, sin_y,
            zero_y, one_y, zero_y,
            -sin_y, zero_y, cos_y,
        ]
    )

    # Rotation about the z axis.
    cos_r, sin_r = torch.cos(roll), torch.sin(roll)
    one_r, zero_r = torch.ones_like(roll), torch.zeros_like(roll)
    roll_mat = _assemble(
        [
            cos_r, -sin_r, zero_r,
            sin_r, cos_r, zero_r,
            zero_r, zero_r, one_r,
        ]
    )

    return torch.einsum("bij,bjk,bkm->bim", pitch_mat, yaw_mat, roll_mat)
|
|
|
|
|
def keypoint_transformation(kp_canonical, he, estimate_jacobian=False, free_view=False, yaw=0, pitch=0, roll=0, output_coord=False):
    """Rotate, translate and deform canonical keypoints with a head pose.

    The transformed keypoints are ``R @ kp + t + exp`` where ``R`` is built
    from yaw/pitch/roll, ``t`` is ``he["t"]`` and ``exp`` is ``he["exp"]``
    reshaped to ``(B, -1, 3)``.

    Args:
        kp_canonical: Dict with ``"value"`` (tensor ``(B, K, 3)``) and, when
            ``estimate_jacobian`` is set, ``"jacobian"``.
        he: Dict with ``"yaw"``, ``"pitch"``, ``"roll"`` (binned scores decoded
            by ``headpose_pred_to_degree``), ``"t"`` and ``"exp"``.  The dict
            and its tensors are not modified.
        estimate_jacobian: Also rotate ``kp_canonical["jacobian"]``.
        free_view: When true, ``yaw`` / ``pitch`` / ``roll`` that are not
            ``None`` override the angles predicted in ``he``.
        yaw, pitch, roll: Override angles in degrees (scalars); only consulted
            when ``free_view`` is true.
        output_coord: Additionally return the angles that were used.

    Returns:
        ``{"value": ..., "jacobian": ...}`` (``jacobian`` is ``None`` unless
        ``estimate_jacobian``).  With ``output_coord`` a tuple of that dict
        and ``{"yaw": float, "pitch": float, "roll": float}``; this requires
        each angle tensor to hold exactly one element.
    """
    kp = kp_canonical["value"]

    def _resolve_angle(name, override):
        # An explicit override is honoured only in free-view mode; otherwise
        # the angle is decoded from the estimator's binned prediction.
        if free_view and override is not None:
            # Follow the keypoints' device instead of hard-coding CUDA.
            return torch.tensor([override]).to(kp.device)
        return headpose_pred_to_degree(he[name])

    yaw = _resolve_angle("yaw", yaw)
    pitch = _resolve_angle("pitch", pitch)
    roll = _resolve_angle("roll", roll)

    t, exp = he["t"], he["exp"]

    rot_mat = get_rotation_matrix(yaw, pitch, roll)

    # Keypoint rotation.
    kp_rotated = torch.einsum("bmp,bkp->bkm", rot_mat, kp)

    # Keypoint translation.  The out-of-place unsqueeze matters: the previous
    # in-place ``unsqueeze_`` changed the shape of ``he["t"]`` for the caller,
    # so feeding the same ``he`` dict in twice failed on the second call.
    t = t.unsqueeze(1).repeat(1, kp.shape[1], 1)
    kp_t = kp_rotated + t

    # Expression deviation.
    exp = exp.view(exp.shape[0], -1, 3)
    kp_transformed = kp_t + exp

    if estimate_jacobian:
        jacobian = kp_canonical["jacobian"]
        jacobian_transformed = torch.einsum("bmp,bkps->bkms", rot_mat, jacobian)
    else:
        jacobian_transformed = None

    transformed = {"value": kp_transformed, "jacobian": jacobian_transformed}

    if output_coord:
        # ``.item()`` avoids float() on a 1-D NumPy array, which newer NumPy
        # releases deprecate; float() keeps the return type for int overrides.
        angles = {
            "yaw": float(yaw.item()),
            "pitch": float(pitch.item()),
            "roll": float(roll.item()),
        }
        return transformed, angles

    return transformed
|
|
|
|
|
def get_square_face(coords, image):
    """Crop an enlarged, roughly square region around a face box.

    The box is first grown on every side by a quarter of its longer edge
    (floor-divided), then replaced by a square centred on the grown box whose
    half-side is half the grown box's longer edge.  The square is clipped to
    the image, so the crop can be non-square (or empty) near the borders.

    Args:
        coords: ``(x1, y1, x2, y2)`` of the face box.
        image: Array indexed as ``image[row, col, ...]``.

    Returns:
        The cropped slice of ``image`` (a view, not a copy, for NumPy arrays).
    """
    left, top, right, bottom = coords

    # Pad the detection box on all four sides.
    margin = (max(right - left, bottom - top) // 2) * 0.5
    left, right = left - margin, right + margin
    top, bottom = top - margin, bottom + margin

    # Square around the padded box's centre.
    center_x = (left + right) // 2
    center_y = (top + bottom) // 2
    half_side = max(right - left, bottom - top) // 2

    rows, cols = image.shape[0], image.shape[1]
    col_start = max(int(round(center_x - half_side)), 0)
    col_stop = min(int(round(center_x + half_side)), cols)
    row_start = max(int(round(center_y - half_side)), 0)
    row_stop = min(int(round(center_y + half_side)), rows)
    return image[row_start:row_stop, col_start:col_stop]
|
|
|
|
|
def smooth_coord(last_coord, current_coord, smooth_factor=0.2):
    """Move ``last_coord`` a fraction of the way towards ``current_coord``.

    Args:
        last_coord: Previous coordinates (sequence of numbers).
        current_coord: Newly measured coordinates, same length.
        smooth_factor: Fraction of the difference to apply.

    Returns:
        List of ints: ``last + (current - last) * smooth_factor`` with each
        component truncated by ``astype(int)``.
    """
    previous = np.array(last_coord)
    step = (np.array(current_coord) - previous) * smooth_factor
    return (previous + step).astype(int).tolist()
|
|
|
|
|
class FaceAnimationClass:
    """Animate one source portrait with faces cropped from driving frames.

    Construction loads the face-vid2vid networks (downloading the checkpoint
    into the torch hub cache directory when it is missing), prepares the
    source image as a CUDA tensor and builds two fallback frames.  Call
    ``inference`` once per driving frame afterwards.
    """

    def __init__(self, source_image_path=None, use_sr=False):
        """Load models and prepare the source image.

        Args:
            source_image_path: Path of the portrait to animate (required).
            use_sr: When true, results (and the fallback ``base_frame``) are
                passed through the GPEN ``FaceEnhancement`` model.

        Raises:
            AssertionError: ``source_image_path`` is ``None``.
            ValueError: OpenCV could not read the image.
        """
        assert source_image_path is not None, "source_image_path is None, please set source_image_path"

        # Read the image once and fail early with a clear message; previously
        # an unreadable path only surfaced as an opaque cvtColor error after
        # the checkpoint download and model load.
        source_bgr = cv2.imread(source_image_path)
        if source_bgr is None:
            raise ValueError(f"Could not read source image: {source_image_path}")

        config_path = os.path.join(os.path.dirname(__file__), "face_vid2vid/config/vox-256-spade.yaml")

        # Fetch the checkpoint on first use.
        checkpoint_path = os.path.join(os.path.expanduser("~"), ".cache/torch/hub/checkpoints/FaceMapping.pth.tar")
        if not os.path.exists(checkpoint_path):
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            from gdown import download

            file_id = "11ZgyjKI5OcB7klcsIdPpCCX38AIX8Soc"
            download(id=file_id, output=checkpoint_path, quiet=False)

        if use_sr:
            from face_vid2vid.GPEN.face_enhancement import FaceEnhancement

            self.faceenhancer = FaceEnhancement(
                size=256, model="GPEN-BFR-256", use_sr=False, sr_model="realesrnet_x2", channel_multiplier=1, narrow=0.5, use_facegan=True
            )

        # Networks.
        self.generator, self.kp_detector, self.he_estimator = load_checkpoints(config_path=config_path, checkpoint_path=checkpoint_path)

        # Source tensor: channel order swapped (RGB2BGR and BGR2RGB are the
        # same swap), scaled by 1/255, resized to 256x256, laid out as NCHW.
        source_image = cv2.cvtColor(source_bgr, cv2.COLOR_RGB2BGR).astype(np.float32) / 255.
        source_image = cv2.resize(source_image, (256, 256), interpolation=cv2.INTER_AREA)
        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        self.source = source.cuda()

        # Face detection settings.
        self.face_detector = RetinaFace()
        self.detect_interval = 8  # run the detector on every 8th frame
        self.smooth_factor = 0.2  # fraction passed to smooth_coord

        # Fallback frames returned by inference() when processing fails.
        self.base_frame = source_bgr if not use_sr else self.faceenhancer.process(source_bgr)[0]
        self.base_frame = cv2.resize(self.base_frame, (256, 256))
        self.blank_frame = np.ones(self.base_frame.shape, dtype=np.uint8) * 255
        cv2.putText(self.blank_frame, "Face not", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        cv2.putText(self.blank_frame, "detected!", (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

        # Frame counter since the last (re)initialisation.
        self.n_frame = 0

        # Tracking / keypoint state filled in by the first processed frame.
        self.first_frame = True
        self.last_coords = None
        self.coords = None
        self.use_sr = use_sr
        self.kp_source = None
        self.kp_driving_initial = None

    def _conver_input_frame(self, frame):
        """Resize ``frame`` to 256x256, scale by 1/255 and return an NCHW CUDA tensor."""
        frame = cv2.resize(frame, (256, 256), interpolation=cv2.INTER_NEAREST).astype(np.float32) / 255.0
        return torch.tensor(frame[np.newaxis]).permute(0, 3, 1, 2).cuda()

    def _detect_face_box(self, frame):
        """Return the first entry of the first detection in ``frame``.

        That entry is later unpacked as ``(x1, y1, x2, y2)`` by
        ``get_square_face``.

        Raises:
            ValueError: the detector returned no faces.
        """
        faces = self.face_detector(frame, cv=True)
        if len(faces) == 0:
            raise ValueError("Face is not detected")
        return faces[0][0]

    def _process_first_frame(self, frame):
        """Initialise tracking and reference keypoints from the first frame.

        Stores the detected box, the canonical keypoints and head pose of the
        source image, the driving keypoints of this frame, and source
        keypoints re-posed to this frame's yaw/pitch/roll.

        Raises:
            ValueError: no face was detected in ``frame``.
        """
        print("Processing first frame")

        # Locate and crop the driving face.
        self.coords = self._detect_face_box(frame)
        face = get_square_face(self.coords, frame)
        self.last_coords = self.coords

        with torch.no_grad():
            self.kp_canonical = self.kp_detector(self.source)
            self.he_source = self.he_estimator(self.source)

            face_input = self._conver_input_frame(face)
            he_driving_initial = self.he_estimator(face_input)
            self.kp_driving_initial, coordinates = keypoint_transformation(self.kp_canonical, he_driving_initial, output_coord=True)
            # Pose the source keypoints with the initial driving head angles.
            self.kp_source = keypoint_transformation(
                self.kp_canonical, self.he_source, free_view=True, yaw=coordinates["yaw"], pitch=coordinates["pitch"], roll=coordinates["roll"]
            )

    def _inference(self, frame):
        """Animate the source with one driving frame.

        Returns:
            Tuple ``(face, result)``: the face crop taken from ``frame`` and
            the generated uint8 image.

        Raises:
            ValueError: a scheduled face detection found no face.
        """
        with torch.no_grad():
            self.n_frame += 1
            if self.first_frame:
                self._process_first_frame(frame)
                self.first_frame = False

            # Re-detect only every ``detect_interval`` frames and ease the box
            # towards the new detection; other frames reuse the stored box.
            if self.n_frame % self.detect_interval == 0:
                detected = self._detect_face_box(frame)
                self.coords = smooth_coord(self.last_coords, detected, self.smooth_factor)

            face = get_square_face(self.coords, frame)
            self.last_coords = self.coords
            face_input = self._conver_input_frame(face)

            he_driving = self.he_estimator(face_input)
            kp_driving = keypoint_transformation(self.kp_canonical, he_driving)
            kp_norm = normalize_kp(
                kp_source=self.kp_source,
                kp_driving=kp_driving,
                kp_driving_initial=self.kp_driving_initial,
                use_relative_movement=True,
                adapt_movement_scale=True,
            )

            out = self.generator(self.source, kp_source=self.kp_source, kp_driving=kp_norm, fp16=True)
            # NCHW -> HWC for the single batch element, then scale to uint8.
            image = np.transpose(out["prediction"].data.cpu().numpy(), [0, 2, 3, 1])[0]
            image = (np.array(image).astype(np.float32) * 255).astype(np.uint8)
            # Swap the channel order back for OpenCV-style consumers.
            result = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            return face, result

    def _reset_and_fallback(self):
        """Restart tracking on the next frame and return the fallback frames."""
        self.first_frame = True
        self.n_frame = 0
        return self.blank_frame, self.base_frame

    def inference(self, frame):
        """Animate one frame, never raising to the caller.

        Args:
            frame: Driving frame, or ``None`` when no frame is available.

        Returns:
            Always a 2-tuple.  On success ``(face, result)`` where ``result``
            is 256x256 when super-resolution is enabled.  When ``frame`` is
            ``None`` or processing fails, ``(blank_frame, base_frame)`` is
            returned and tracking state is reset.
        """
        # A missing frame used to fall through and return None, which broke
        # callers that unpack two values.
        if frame is None:
            return self._reset_and_fallback()

        try:
            face, result = self._inference(frame)
            if self.use_sr:
                result, _, _ = self.faceenhancer.process(result)
                result = cv2.resize(result, (256, 256))
            return face, result
        except Exception as e:
            # Deliberate catch-all boundary: report the problem and keep the
            # frame loop alive with placeholder output.
            print(e)
            return self._reset_and_fallback()
|
|
|
|
|
def get_audio_duration(audioPath):
    """Return ``info.length`` of the WAVE file at ``audioPath`` as read by mutagen.

    The caller treats the value as a duration in seconds.
    """
    return WAVE(audioPath).info.length
|
|
|
def seconds_to_hms(seconds):
    """Format a duration as ``HH:MM:SS``, rounded up past the next second.

    The value is truncated to an int and one second is added, so the stamp
    always covers the whole duration (``5.2`` -> ``00:00:06``, ``5`` ->
    ``00:00:06``).

    Hours are not wrapped at 24: the previous ``str(timedelta)`` based
    formatting produced ``"1 day, 0:00:01"`` style text for durations of a
    day or more, which is not a valid timestamp.

    Args:
        seconds: Non-negative duration in seconds (int or float).

    Returns:
        Zero-padded ``HH:MM:SS`` string; the hour field grows beyond two
        digits when needed.
    """
    total_seconds = int(seconds) + 1
    hours, remainder = divmod(total_seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
|
|
|
def animate_face(path_id, audiofile, driverfile, imgfile, animatedfile):
    """Animate ``imgfile`` with a driving video trimmed to the audio length.

    Steps: read the audio duration, cut ``driverfile`` to that duration with
    ffmpeg (stream copy) into ``temp/<path_id>/tmp.mp4``, run every frame of
    the clip through ``FaceAnimationClass.inference`` and write the results
    to ``temp/<path_id>/<animatedfile>`` at the clip's frame rate.

    Args:
        path_id: Name of the working sub-directory below ``temp``.
        audiofile: WAVE file name inside the working directory.
        driverfile: Path of the driving video (used as given).
        imgfile: Source portrait file name inside the working directory.
        animatedfile: Output video file name inside the working directory.

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status.
    """
    from tqdm import tqdm
    import time

    work_dir = os.path.join("temp", path_id)
    faceanimation = FaceAnimationClass(source_image_path=os.path.join(work_dir, imgfile), use_sr=False)

    tmpfile = os.path.join(work_dir, "tmp.mp4")
    duration = get_audio_duration(os.path.join(work_dir, audiofile))
    print("duration of audio:", duration)
    hms = seconds_to_hms(duration)
    print("converted into hms:", hms)

    # Argument list instead of a shell string: no quoting problems with
    # spaces or shell metacharacters in paths, and it behaves the same on
    # every platform.  "-y" overwrites a tmp.mp4 left by an earlier run
    # instead of prompting.
    command = ["ffmpeg", "-y", "-ss", "00:00:00", "-i", driverfile, "-to", hms, "-c", "copy", tmpfile]
    return_code = subprocess.call(command)
    if return_code != 0:
        # Continuing would animate a missing or stale clip.
        raise RuntimeError(f"ffmpeg failed with exit code {return_code} while trimming {driverfile}")

    # Load the whole trimmed clip into memory.
    capture = cv2.VideoCapture(tmpfile)
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
        frames = []
        _, frame = capture.read()
        while frame is not None:
            frames.append(frame)
            _, frame = capture.read()
    finally:
        capture.release()

    time_start = time.time()
    # inference() returns (face_crop, result); only the result is kept.
    output_frames = [faceanimation.inference(frame)[1] for frame in tqdm(frames)]
    elapsed = time.time() - time_start
    # Guard the rate against a zero-length interval (e.g. no frames read).
    frames_per_second = len(frames) / elapsed if elapsed > 0 else 0.0
    print("Time cost: %.2f" % elapsed, "FPS: %.2f" % frames_per_second)

    writer = imageio.get_writer(os.path.join(work_dir, animatedfile), fps=fps, quality=9, macro_block_size=1,
                                codec="libx264", pixelformat="yuv420p")
    try:
        for frame in output_frames:
            # Swap channel order for imageio.
            writer.append_data(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Close the writer even when encoding a frame fails.
        writer.close()
|
|
|
|
|
|