# Spaces: Runtime error
# coding: utf-8

"""
Functions for processing and transforming 3D facial keypoints.
"""

import numpy as np
import torch
import torch.nn.functional as F

# Module-level pi; get_rotation_matrix uses it to convert degrees to radians.
PI = np.pi
def headpose_pred_to_degree(pred):
    """
    Convert a head-pose prediction to degrees.

    pred: (bs, 66) or (bs, 1) or others

    If ``pred`` has more than one dimension and exactly 66 entries along
    dim 1, it is treated as per-bin logits. A softmax over dim 1 yields
    bin probabilities. Their expected bin index is multiplied by 3 and
    shifted by -97.5, giving a float32 tensor of shape (bs,) for a
    (bs, 66) input.

    Any other input is returned unchanged (the very same tensor object).
    """
    if pred.ndim > 1 and pred.shape[1] == 66:
        # NOTE: note that the average is modified to 97.5
        # NOTE(review): presumably 66 bins of 3 degrees spanning [-99, 99],
        # so 3 * idx - 97.5 is the centre of bin idx -- confirm with the
        # model that produces these logits.
        # Bin indices 0..65 as float32 (what the former FloatTensor-based
        # construction produced), created directly on pred's device
        # instead of building a Python list on the CPU and copying it.
        idx_tensor = torch.arange(66, dtype=torch.float32, device=pred.device)
        prob = F.softmax(pred, dim=1)
        degree = torch.sum(prob * idx_tensor, dim=1) * 3 - 97.5
        return degree
    return pred
def get_rotation_matrix(pitch_, yaw_, roll_):
    """Build a batch of 3x3 rotation matrices from Euler angles in degrees.

    The product ``rot_z @ rot_y @ rot_x`` is formed (x = pitch, y = yaw,
    z = roll). Its transpose is returned, so points stored as row vectors
    can be rotated as ``vps @ rot``.

    Args:
        pitch_, yaw_, roll_: tensors of angles in degrees. Each may be a
            0-d tensor (one angle), shape (bs,), or shape (bs, 1). The
            three are expected to share batch size and device.

    Returns:
        Tensor of shape (bs, 3, 3). The constant 0/1 entries take the
        dtype and device of the converted pitch tensor.
    """

    def _to_radian_column(angle_deg):
        # Convert degrees to radians.
        angle = angle_deg / 180 * PI
        # Bring 0-d and 1-d inputs to (bs, 1) so every matrix entry below
        # is a column that can be concatenated along dim 1.
        if angle.ndim <= 1:
            angle = angle.reshape(-1, 1)
        return angle

    x = _to_radian_column(pitch_)
    y = _to_radian_column(yaw_)
    z = _to_radian_column(roll_)

    bs = x.shape[0]
    # ones_like/zeros_like follow x's dtype and device. This avoids a
    # float32 CPU allocation followed by a device copy, and avoids a
    # dtype mismatch with the trig terms in torch.cat.
    ones = torch.ones_like(x)
    zeros = torch.zeros_like(x)

    cos_x, sin_x = torch.cos(x), torch.sin(x)
    cos_y, sin_y = torch.cos(y), torch.sin(y)
    cos_z, sin_z = torch.cos(z), torch.sin(z)

    rot_x = torch.cat([
        ones, zeros, zeros,
        zeros, cos_x, -sin_x,
        zeros, sin_x, cos_x,
    ], dim=1).reshape([bs, 3, 3])

    rot_y = torch.cat([
        cos_y, zeros, sin_y,
        zeros, ones, zeros,
        -sin_y, zeros, cos_y,
    ], dim=1).reshape([bs, 3, 3])

    rot_z = torch.cat([
        cos_z, -sin_z, zeros,
        sin_z, cos_z, zeros,
        zeros, zeros, ones,
    ], dim=1).reshape([bs, 3, 3])

    rot = rot_z @ rot_y @ rot_x
    return rot.permute(0, 2, 1)  # transpose