|
import torch |
|
import torch.nn.functional as F |
|
import numpy as np |
|
import imageio |
|
from math import pi |
|
from tqdm import tqdm |
|
from lib.data import get_dataloader, get_meanpose |
|
from lib.util.general import get_config |
|
from lib.util.visualization import motion2video_np, hex2rgb |
|
import os |
|
|
|
# Small positive constant. It is not referenced by any function in the visible
# part of this file; presumably a numerical-stability epsilon -- TODO confirm.
eps = 1e-16
|
|
|
|
|
def localize_motion_torch(motion):
    """
    Express a motion relative to its center joint and append the center velocity.

    The joint at index 8 is used as the center. It is subtracted from every
    joint, dropped from the joint list, and replaced (as the last entry along
    the joint axis) by its frame-to-frame displacement, so the joint count
    stays the same.

    :param motion: B x J x D x T
    :return: B x J x D x T tensor; the last entry along J is the center
        velocity, which is zero at the first frame.
    """
    # The unpacking doubles as a check that the input has exactly four axes.
    n_batch, n_joints, n_dims, n_frames = motion.size()

    root = motion[:, 8:9, :, :]
    relative = motion - root

    # Displacement of the center between consecutive frames, left-padded with
    # a single zero frame so it keeps T entries along the time axis.
    step = root[:, :, :, 1:] - root[:, :, :, :-1]
    root_velocity = F.pad(step, [1, 0], "constant", 0.)

    pieces = [relative[:, :8], relative[:, 9:], root_velocity]
    return torch.cat(pieces, dim=1)
|
|
|
|
|
def normalize_motion_torch(motion, meanpose, stdpose):
    """
    Standardize a motion with per-joint, per-coordinate mean and std.

    When the motion is 2D but the statistics have three columns, only
    columns 0 and 2 of the statistics are used.

    :param motion: (B, J, D, T)
    :param meanpose: (J, D)
    :param stdpose: (J, D)
    :return: (B, J, D, T) tensor, ``(motion - meanpose) / stdpose``
    """
    _, n_joints, n_dims, _ = motion.size()

    if n_dims == 2:
        # Reduce 3-column statistics to the two columns matching a 2D motion.
        if meanpose.size(1) == 3:
            meanpose = meanpose[:, [0, 2]]
        if stdpose.size(1) == 3:
            stdpose = stdpose[:, [0, 2]]

    # Reshape so the statistics broadcast over the batch and time axes.
    mean = meanpose.view(1, n_joints, n_dims, 1)
    std = stdpose.view(1, n_joints, n_dims, 1)
    return (motion - mean) / std
|
|
|
|
|
def normalize_motion_inv_torch(motion, meanpose, stdpose):
    """
    Undo the standardization: scale by the std and add the mean back.

    When the motion is 2D but the statistics have three columns, only
    columns 0 and 2 of the statistics are used.

    :param motion: (B, J, D, T)
    :param meanpose: (J, D)
    :param stdpose: (J, D)
    :return: (B, J, D, T) tensor, ``motion * stdpose + meanpose``
    """
    _, n_joints, n_dims, _ = motion.size()

    if n_dims == 2:
        # Reduce 3-column statistics to the two columns matching a 2D motion.
        if meanpose.size(1) == 3:
            meanpose = meanpose[:, [0, 2]]
        if stdpose.size(1) == 3:
            stdpose = stdpose[:, [0, 2]]

    # Reshape so the statistics broadcast over the batch and time axes.
    std = stdpose.view(1, n_joints, n_dims, 1)
    mean = meanpose.view(1, n_joints, n_dims, 1)
    return motion * std + mean
|
|
|
|
|
def globalize_motion_torch(motion):
    """
    Rebuild absolute joint positions from a center-relative motion.

    The last entry along the joint axis is read as the per-frame velocity of
    the center joint. Its running sum over time gives the center trajectory,
    which is added to every joint. The center joint itself is re-inserted at
    index 8 with a zero offset, so it coincides with the trajectory.

    :param motion: B x J x D x T, last entry along J is the center velocity
    :return: B x J x D x T tensor of positions with the center restored at
        joint index 8
    :raises ValueError: if ``motion`` does not have exactly four dimensions
    """
    if motion.dim() != 4:
        raise ValueError(
            "motion must have shape (B, J, D, T), got %s" % (tuple(motion.shape),)
        )

    velocity = motion[:, -1:, :, :]

    # Running sum along time replaces an explicit per-frame Python loop.
    # The dtype is pinned so the result keeps the input dtype.
    centers = torch.cumsum(velocity, dim=-1, dtype=velocity.dtype)

    # Joints 0..7 keep their slot, a zero placeholder takes slot 8 (the
    # center), and the remaining joints shift back up by one. The trailing
    # velocity entry is dropped.
    center_slot = torch.zeros_like(velocity)
    joints = torch.cat([motion[:, :8], center_slot, motion[:, 8:-1]], dim=1)

    return joints + centers
|
|
|
|
|
def restore_world_space(motion, meanpose, stdpose, n_joints=15):
    """
    Turn a flattened, normalized motion back into per-joint positions.

    The channel axis is split into ``(n_joints, C // n_joints)``, the
    normalization is undone, and the result is globalized.

    :param motion: (B, C, T)
    :param meanpose: mean passed to ``normalize_motion_inv_torch``
    :param stdpose: std passed to ``normalize_motion_inv_torch``
    :param n_joints: number of entries along the joint axis
    :return: (B, n_joints, C // n_joints, T)
    """
    batch, channels, frames = motion.size()
    per_joint = motion.view(batch, n_joints, channels // n_joints, frames)
    denormalized = normalize_motion_inv_torch(per_joint, meanpose, stdpose)
    return globalize_motion_torch(denormalized)
|
|
|
|
|
def convert_to_learning_space(motion, meanpose, stdpose):
    """
    Localize, normalize and flatten a motion.

    :param motion: (B, J, D, T)
    :param meanpose: mean passed to ``normalize_motion_torch``
    :param stdpose: std passed to ``normalize_motion_torch``
    :return: (B, J * D, T)
    """
    batch, joints, dims, frames = motion.size()
    localized = localize_motion_torch(motion)
    normalized = normalize_motion_torch(localized, meanpose, stdpose)
    # Merge the joint and coordinate axes into a single channel axis.
    return normalized.view(batch, joints * dims, frames)
|
|
|
|
|
|
|
|
|
def get_body_basis(motion_3d):
    """
    Get the unit vectors of a body-aligned rectangular coordinate system for
    each 3D motion in the batch.

    A horizontal direction is derived from the joint pairs (2, 5) and (9, 12)
    (presumably left/right pairs -- confirm against the skeleton definition),
    averaged over time and normalized. The z vector is fixed to (0, 0, 1),
    y is the normalized ``horizontal x z`` and x is ``y x z``.

    A horizontal direction that is zero or parallel to z leads to a division
    by zero and therefore NaN entries.

    :param motion_3d: 3D motion from 3D joints positions, shape (B, n_joints, 3, seq_len).
    :return: basis vectors, shape (B, 3, 3). ``[:, :, 0]``, ``[:, :, 1]`` and
        ``[:, :, 2]`` hold the x, y and z vectors. Detached from the autograd graph.
    """
    B = motion_3d.size(0)

    # Average of the two joint-pair differences, then averaged over time.
    horizontal = (motion_3d[:, 2] - motion_3d[:, 5] + motion_3d[:, 9] - motion_3d[:, 12]) / 2
    horizontal = horizontal.mean(dim=-1)
    horizontal = horizontal / horizontal.norm(dim=-1, keepdim=True)

    vector_z = torch.tensor(
        [0., 0., 1.], device=motion_3d.device, dtype=motion_3d.dtype
    ).unsqueeze(0).repeat(B, 1)

    # dim=-1 is required: without it torch.cross uses the first axis of size 3,
    # which is the batch axis whenever B == 3 and gives wrong results.
    vector_y = torch.cross(horizontal, vector_z, dim=-1)
    vector_y = vector_y / vector_y.norm(dim=-1, keepdim=True)
    vector_x = torch.cross(vector_y, vector_z, dim=-1)

    vectors = torch.stack([vector_x, vector_y, vector_z], dim=2)

    # The basis is treated as a constant; no gradient flows back to the motion.
    return vectors.detach()
|
|
|
|
|
def rotate_basis_euler(basis_vectors, angles):
    """
    Rotate a set of basis vectors by the given angles.

    Two rotations are built per (batch, view, frame): one by ``angles[..., 0]``
    about the axis ``basis_vectors[:, 0, :]`` (axis-angle form), and one by
    ``angles[..., 2]`` about the z axis. ``angles[..., 1]`` is not used.
    The result is ``basis @ R_x^T @ R_z``.

    :param basis_vectors: [B, 3, 3]
    :param angles: [B, K, T, 3] Rotation angles around each axis.
    :return: [B, K, T, 3, 3]
    """
    # The unpacking doubles as a check that angles has exactly four axes.
    n_batch, n_views, n_frames, _ = angles.size()

    cos_all = torch.cos(angles)
    sin_all = torch.sin(angles)
    cos_x, cos_z = cos_all[:, :, :, 0], cos_all[:, :, :, 2]
    sin_x, sin_z = sin_all[:, :, :, 0], sin_all[:, :, :, 2]

    # Rotation axis for the first rotation: row 0 of every basis, [B, 3].
    axis = basis_vectors[:, 0, :]
    a0, a1, a2 = axis[:, 0], axis[:, 1], axis[:, 2]
    zero_b = torch.zeros_like(a0)

    # Cross-product (skew-symmetric) matrix of the axis, [B, 1, 1, 3, 3].
    skew = torch.stack(
        [
            torch.stack([zero_b, -a2, a1], dim=1),
            torch.stack([a2, zero_b, -a0], dim=1),
            torch.stack([-a1, a0, zero_b], dim=1),
        ],
        dim=1,
    )
    skew = skew[:, None, None]

    # Outer product of the axis with itself, [B, 1, 1, 3, 3].
    axis_col = axis.unsqueeze(-1)
    outer = torch.matmul(axis_col, axis_col.transpose(-1, -2))[:, None, None]

    identity = torch.eye(n=3, dtype=basis_vectors.dtype, device=basis_vectors.device)
    identity = identity[None, None, None]

    # cos * I + sin * [axis]_x + (1 - cos) * axis axis^T, [B, K, T, 3, 3].
    rot_x = cos_x[..., None, None] * identity \
        + sin_x[..., None, None] * skew \
        + (1. - cos_x)[..., None, None] * outer

    # Rotation about z, assembled row by row, [B, K, T, 3, 3].
    zero = torch.zeros_like(cos_z)
    one = torch.ones_like(cos_z)
    rot_z = torch.stack(
        [
            torch.stack([cos_z, sin_z, zero], dim=3),
            torch.stack([-sin_z, cos_z, zero], dim=3),
            torch.stack([zero, zero, one], dim=3),
        ],
        dim=3,
    )

    expanded_basis = basis_vectors[:, None, None]
    return expanded_basis @ rot_x.transpose(-1, -2) @ rot_z
|
|
|
|
|
def change_of_basis(motion_3d, basis_vectors=None, project_2d=False):
    """
    Express a 3D motion in the given bases, optionally keeping only two rows.

    Without ``basis_vectors`` the motion is reduced to its coordinates 0 and 2,
    regardless of ``project_2d``.

    :param motion_3d: (B, n_joints, 3, seq_len)
    :param basis_vectors: (B, K, T, 3, 3) or None
    :param project_2d: when bases are given, keep only rows 0 and 2 of each basis
    :return: (B, n_joints, 2, seq_len) without bases; otherwise
        (B, K, n_joints, 3, seq_len), or (B, K, n_joints, 2, seq_len) when ``project_2d`` is set
    """
    if basis_vectors is None:
        return motion_3d[:, :, [0, 2], :]

    if project_2d:
        basis_vectors = basis_vectors[:, :, :, [0, 2], :]

    # The unpacking doubles as a check that the bases have exactly five axes.
    _, n_views, n_frames, _, _ = basis_vectors.size()

    # One copy of the motion per view, then move joints to the last axis so
    # each frame is a (3, n_joints) matrix: (B, K, T, 3, n_joints).
    per_view = motion_3d.unsqueeze(1).repeat(1, n_views, 1, 1, 1)
    per_view = per_view.permute([0, 1, 4, 3, 2])

    projected = basis_vectors @ per_view

    # Back to (B, K, n_joints, D, T).
    return projected.permute([0, 1, 4, 3, 2])
|
|
|
|
|
def rotate_and_maybe_project_world(X, angles=None, body_reference=True, project_2d=False):
    """
    Rotate a world-space 3D motion into K views and optionally project to 2D.

    Without ``angles`` no rotation is applied: the motion is returned as is,
    or reduced to two coordinates by ``change_of_basis`` when ``project_2d``
    is set.

    :param X: (B, n_joints, 3, T)
    :param angles: [B, K, T, 3] rotation angles, or None for no rotation
    :param body_reference: rotate the basis from ``get_body_basis(X)`` when
        True, otherwise rotate the identity basis
    :param project_2d: return 2 coordinates per joint instead of 3
    :return: (B * K, n_joints, out_dim, T) with angles, otherwise
        (B, n_joints, out_dim, T); out_dim is 2 if ``project_2d`` else 3
    """
    out_dim = 2 if project_2d else 3
    batch_size, n_joints, _, seq_len = X.size()

    if angles is None:
        # change_of_basis without bases always returns the 2-coordinate
        # projection, so it is only called when that projection is wanted.
        # Otherwise the 3D motion passes through unchanged; previously this
        # case failed in the reshape below.
        X_trans = change_of_basis(X, project_2d=project_2d) if project_2d else X
        return X_trans.reshape(batch_size, n_joints, out_dim, seq_len)

    K = angles.size(1)
    if body_reference:
        basis_vectors = get_body_basis(X)
    else:
        # Match X's dtype (as get_body_basis does) so the later matmul with
        # the motion does not mix dtypes.
        basis_vectors = torch.eye(3, device=X.device, dtype=X.dtype)
        basis_vectors = basis_vectors.unsqueeze(0).repeat(batch_size, 1, 1)

    basis_vectors = rotate_basis_euler(basis_vectors, angles)
    X_trans = change_of_basis(X, basis_vectors, project_2d=project_2d)

    # Fold the K views into the batch axis.
    return X_trans.reshape(batch_size * K, n_joints, out_dim, seq_len)
|
|
|
|
|
|
|
def rotate_and_maybe_project_learning(X, meanpose, stdpose, angles=None, body_reference=True, project_2d=False):
    """
    Rotate (and optionally project) a motion given in learning space.

    The motion is restored to world space assuming three coordinates per
    joint, rotated and/or projected there, then converted back to learning
    space.

    :param X: (B, C, T), with C // 3 taken as the number of joints
    :param meanpose: mean used for de-normalization and re-normalization
    :param stdpose: std used for de-normalization and re-normalization
    :param angles: forwarded to ``rotate_and_maybe_project_world``
    :param body_reference: forwarded to ``rotate_and_maybe_project_world``
    :param project_2d: forwarded to ``rotate_and_maybe_project_world``
    :return: output of ``convert_to_learning_space`` on the transformed motion
    """
    batch_size, channels, seq_len = X.size()
    world = restore_world_space(X, meanpose, stdpose, channels // 3)
    transformed = rotate_and_maybe_project_world(world, angles, body_reference, project_2d)
    return convert_to_learning_space(transformed, meanpose, stdpose)
|
|