# coding: utf-8
# Source: Hugging Face file view, uploaded by jbilcke-hf (HF staff),
# commit d69879c ("initial commit"), 1.86 kB.
"""
functions for processing and transforming 3D facial keypoints
"""
import numpy as np
import torch
import torch.nn.functional as F
# Shorthand for pi; used to convert degrees to radians in get_rotation_matrix.
PI = np.pi
def headpose_pred_to_degree(pred):
    """Convert a head-pose prediction to degrees.

    pred: (bs, 66) or (bs, 1) or others

    When ``pred`` has 66 columns along dim 1 it is treated as per-bin logits:
    a softmax over dim 1 gives bin probabilities, and the expected bin index
    is mapped to degrees as ``idx * 3 - 97.5``, so bin 0 maps to -97.5 and
    bin 65 to +97.5. The result then has shape (bs,).

    Any other input (1-D, or a second dimension other than 66) is returned
    unchanged, as the same object.
    """
    if pred.ndim > 1 and pred.shape[1] == 66:
        # NOTE: the average (offset) is modified to 97.5, which makes the
        # reachable degree range [-97.5, 97.5] symmetric around zero.
        # Bin indices are created directly on pred's device, as float32 to
        # match the previous FloatTensor construction. Before, a Python list
        # was built and copied host->device on every call.
        idx_tensor = torch.arange(66, dtype=torch.float32, device=pred.device)
        pred = F.softmax(pred, dim=1)
        # Expected bin index under the softmax distribution, scaled to degrees.
        degree = torch.sum(pred * idx_tensor, dim=1) * 3 - 97.5
        return degree
    return pred
def get_rotation_matrix(pitch_, yaw_, roll_):
    """Build per-sample 3x3 rotation matrices from Euler angles in degrees.

    pitch_, yaw_, roll_: angle tensors of shape (bs,) or (bs, 1), in degrees.
        The device and batch size are taken from pitch_. Presumably all three
        share them; this is not checked here.

    Returns a (bs, 3, 3) tensor equal to ``(R_z @ R_y @ R_x)^T``, where pitch
    rotates about x, yaw about y and roll about z. The transpose is returned
    so that row-vector points can be rotated as ``vps @ rot``.
    """
    # degrees -> radians
    pitch = pitch_ / 180 * PI
    yaw = yaw_ / 180 * PI
    roll = roll_ / 180 * PI

    device = pitch.device

    # Promote (bs,) inputs to (bs, 1) columns so they can be concatenated
    # along dim 1 with the ones/zeros columns below.
    if pitch.ndim == 1:
        pitch = pitch.unsqueeze(1)
    if yaw.ndim == 1:
        yaw = yaw.unsqueeze(1)
    if roll.ndim == 1:
        roll = roll.unsqueeze(1)

    # calculate the euler matrix
    bs = pitch.shape[0]
    # Allocate directly on the target device (default dtype, as before)
    # instead of creating on the CPU and copying over.
    ones = torch.ones([bs, 1], device=device)
    zeros = torch.zeros([bs, 1], device=device)

    # Each trig term is evaluated once and reused in the matrices below.
    cos_x, sin_x = torch.cos(pitch), torch.sin(pitch)
    cos_y, sin_y = torch.cos(yaw), torch.sin(yaw)
    cos_z, sin_z = torch.cos(roll), torch.sin(roll)

    # Row-major 3x3 entries, flattened to 9 columns then reshaped.
    rot_x = torch.cat([
        ones, zeros, zeros,
        zeros, cos_x, -sin_x,
        zeros, sin_x, cos_x,
    ], dim=1).reshape([bs, 3, 3])

    rot_y = torch.cat([
        cos_y, zeros, sin_y,
        zeros, ones, zeros,
        -sin_y, zeros, cos_y,
    ], dim=1).reshape([bs, 3, 3])

    rot_z = torch.cat([
        cos_z, -sin_z, zeros,
        sin_z, cos_z, zeros,
        zeros, zeros, ones,
    ], dim=1).reshape([bs, 3, 3])

    rot = rot_z @ rot_y @ rot_x
    return rot.permute(0, 2, 1)  # transpose, for row-vector use: vps @ rot