TMM / lib /operation.py
Fazhong Liu
fin
7ca9b42
import torch
import torch.nn.functional as F
import numpy as np
import imageio
from math import pi
from tqdm import tqdm
from lib.data import get_dataloader, get_meanpose
from lib.util.general import get_config
from lib.util.visualization import motion2video_np, hex2rgb
import os
eps = 1e-16
def localize_motion_torch(motion):
"""
:param motion: B x J x D x T
:return:
"""
B, J, D, T = motion.size()
# subtract centers to local coordinates
centers = motion[:, 8:9, :, :] # B x 1 x D x (T-1)
motion = motion - centers
# adding velocity
translation = centers[:, :, :, 1:] - centers[:, :, :, :-1] # B x 1 x D x (T-1)
velocity = F.pad(translation, [1, 0], "constant", 0.) # B x 1 x D x T
motion = torch.cat([motion[:, :8], motion[:, 9:], velocity], dim=1)
return motion
def normalize_motion_torch(motion, meanpose, stdpose):
"""
:param motion: (B, J, D, T)
:param meanpose: (J, D)
:param stdpose: (J, D)
:return:
"""
B, J, D, T = motion.size()
if D == 2 and meanpose.size(1) == 3:
meanpose = meanpose[:, [0, 2]]
if D == 2 and stdpose.size(1) == 3:
stdpose = stdpose[:, [0, 2]]
return (motion - meanpose.view(1, J, D, 1)) / stdpose.view(1, J, D, 1)
def normalize_motion_inv_torch(motion, meanpose, stdpose):
"""
:param motion: (B, J, D, T)
:param meanpose: (J, D)
:param stdpose: (J, D)
:return:
"""
B, J, D, T = motion.size()
if D == 2 and meanpose.size(1) == 3:
meanpose = meanpose[:, [0, 2]]
if D == 2 and stdpose.size(1) == 3:
stdpose = stdpose[:, [0, 2]]
return motion * stdpose.view(1, J, D, 1) + meanpose.view(1, J, D, 1)
def globalize_motion_torch(motion):
"""
:param motion: B x J x D x T
:return:
"""
B, J, D, T = motion.size()
motion_inv = torch.zeros_like(motion)
motion_inv[:, :8] = motion[:, :8]
motion_inv[:, 9:] = motion[:, 8:-1]
velocity = motion[:, -1:, :, :]
centers = torch.zeros_like(velocity)
displacement = torch.zeros_like(velocity[:, :, :, 0])
for t in range(T):
displacement += velocity[:, :, :, t]
centers[:, :, :, t] = displacement
motion_inv = motion_inv + centers
return motion_inv
def restore_world_space(motion, meanpose, stdpose, n_joints=15):
B, C, T = motion.size()
motion = motion.view(B, n_joints, C // n_joints, T)
motion = normalize_motion_inv_torch(motion, meanpose, stdpose)
motion = globalize_motion_torch(motion)
return motion
def convert_to_learning_space(motion, meanpose, stdpose):
B, J, D, T = motion.size()
motion = localize_motion_torch(motion)
motion = normalize_motion_torch(motion, meanpose, stdpose)
motion = motion.view(B, J*D, T)
return motion
# tensor operations for rotating and projecting 3d skeleton sequence
def get_body_basis(motion_3d):
"""
Get the unit vectors for vector rectangular coordinates for given 3D motion
:param motion_3d: 3D motion from 3D joints positions, shape (B, n_joints, 3, seq_len).
:param angles: (K, 3), Rotation angles around each axis.
:return: unit vectors for vector rectangular coordinates's , shape (B, 3, 3).
"""
B = motion_3d.size(0)
# 2 RightArm 5 LeftArm 9 RightUpLeg 12 LeftUpLeg
horizontal = (motion_3d[:, 2] - motion_3d[:, 5] + motion_3d[:, 9] - motion_3d[:, 12]) / 2 # [B, 3, seq_len]
horizontal = horizontal.mean(dim=-1) # [B, 3]
horizontal = horizontal / horizontal.norm(dim=-1).unsqueeze(-1) # [B, 3]
vector_z = torch.tensor([0., 0., 1.], device=motion_3d.device, dtype=motion_3d.dtype).unsqueeze(0).repeat(B, 1) # [B, 3]
vector_y = torch.cross(horizontal, vector_z) # [B, 3]
vector_y = vector_y / vector_y.norm(dim=-1).unsqueeze(-1)
vector_x = torch.cross(vector_y, vector_z)
vectors = torch.stack([vector_x, vector_y, vector_z], dim=2) # [B, 3, 3]
vectors = vectors.detach()
return vectors
def rotate_basis_euler(basis_vectors, angles):
"""
Rotate vector rectangular coordinates from given angles.
:param basis_vectors: [B, 3, 3]
:param angles: [B, K, T, 3] Rotation angles around each axis.
:return: [B, K, T, 3, 3]
"""
B, K, T, _ = angles.size()
cos, sin = torch.cos(angles), torch.sin(angles)
cx, cy, cz = cos[:, :, :, 0], cos[:, :, :, 1], cos[:, :, :, 2] # [B, K, T]
sx, sy, sz = sin[:, :, :, 0], sin[:, :, :, 1], sin[:, :, :, 2] # [B, K, T]
x = basis_vectors[:, 0, :] # [B, 3]
o = torch.zeros_like(x[:, 0]) # [B]
x_cpm_0 = torch.stack([o, -x[:, 2], x[:, 1]], dim=1) # [B, 3]
x_cpm_1 = torch.stack([x[:, 2], o, -x[:, 0]], dim=1) # [B, 3]
x_cpm_2 = torch.stack([-x[:, 1], x[:, 0], o], dim=1) # [B, 3]
x_cpm = torch.stack([x_cpm_0, x_cpm_1, x_cpm_2], dim=1) # [B, 3, 3]
x_cpm = x_cpm.unsqueeze(1).unsqueeze(2) # [B, 1, 1, 3, 3]
x = x.unsqueeze(-1) # [B, 3, 1]
xx = torch.matmul(x, x.transpose(-1, -2)).unsqueeze(1).unsqueeze(2) # [B, 1, 1, 3, 3]
eye = torch.eye(n=3, dtype=basis_vectors.dtype, device=basis_vectors.device)
eye = eye.unsqueeze(0).unsqueeze(0).unsqueeze(0) # [1, 1, 1, 3, 3]
mat33_x = cx.unsqueeze(-1).unsqueeze(-1) * eye \
+ sx.unsqueeze(-1).unsqueeze(-1) * x_cpm \
+ (1. - cx).unsqueeze(-1).unsqueeze(-1) * xx # [B, K, T, 3, 3]
o = torch.zeros_like(cz)
i = torch.ones_like(cz)
mat33_z_0 = torch.stack([cz, sz, o], dim=3) # [B, K, T, 3]
mat33_z_1 = torch.stack([-sz, cz, o], dim=3) # [B, K, T, 3]
mat33_z_2 = torch.stack([o, o, i], dim=3) # [B, K, T, 3]
mat33_z = torch.stack([mat33_z_0, mat33_z_1, mat33_z_2], dim=3) # [B, K, T, 3, 3]
basis_vectors = basis_vectors.unsqueeze(1).unsqueeze(2)
basis_vectors = basis_vectors @ mat33_x.transpose(-1, -2) @ mat33_z
return basis_vectors
def change_of_basis(motion_3d, basis_vectors=None, project_2d=False):
# motion_3d: (B, n_joints, 3, seq_len)
# basis_vectors: (B, K, T, 3, 3)
if basis_vectors is None:
motion_proj = motion_3d[:, :, [0, 2], :] # [B, n_joints, 2, seq_len]
else:
if project_2d: basis_vectors = basis_vectors[:, :, :, [0, 2], :]
_, K, seq_len, _, _ = basis_vectors.size()
motion_3d = motion_3d.unsqueeze(1).repeat(1, K, 1, 1, 1)
motion_3d = motion_3d.permute([0, 1, 4, 3, 2]) # [B, K, J, 3, T] -> [B, K, T, 3, J]
motion_proj = basis_vectors @ motion_3d # [B, K, T, 2, 3] @ [B, K, T, 3, J] -> [B, K, T, 2, J]
motion_proj = motion_proj.permute([0, 1, 4, 3, 2]) # [B, K, T, 3, J] -> [B, K, J, 3, T]
return motion_proj
def rotate_and_maybe_project_world(X, angles=None, body_reference=True, project_2d=False):
out_dim = 2 if project_2d else 3
batch_size, n_joints, _, seq_len = X.size()
if angles is not None:
K = angles.size(1)
basis_vectors = get_body_basis(X) if body_reference else \
torch.eye(3, device=X.device).unsqueeze(0).repeat(batch_size, 1, 1)
basis_vectors = rotate_basis_euler(basis_vectors, angles)
X_trans = change_of_basis(X, basis_vectors, project_2d=project_2d)
X_trans = X_trans.reshape(batch_size * K, n_joints, out_dim, seq_len)
else:
X_trans = change_of_basis(X, project_2d=project_2d)
X_trans = X_trans.reshape(batch_size, n_joints, out_dim, seq_len)
return X_trans
def rotate_and_maybe_project_learning(X, meanpose, stdpose, angles=None, body_reference=True, project_2d=False):
batch_size, channels, seq_len = X.size()
n_joints = channels // 3
X = restore_world_space(X, meanpose, stdpose, n_joints)
X = rotate_and_maybe_project_world(X, angles, body_reference, project_2d)
X = convert_to_learning_space(X, meanpose, stdpose)
return X