|
import sys, os |
|
thismodule = sys.modules[__name__] |
|
|
|
from lib.util.motion import preprocess_mixamo, rotate_motion_3d, limb_scale_motion_2d, normalize_motion, get_change_of_basis, localize_motion, scale_limbs |
|
|
|
import torch |
|
import glob |
|
import numpy as np |
|
import random |
|
from torch.utils.data import Dataset, DataLoader |
|
from easydict import EasyDict as edict |
|
from tqdm import tqdm |
|
|
|
# Seven candidate view angles, evenly spaced from -pi/2 to +pi/2 in steps of pi/6.
view_angles = np.arange(-3, 4) * np.pi / 6
|
|
|
def _reseed_numpy_worker(worker_id):
    """Re-seed NumPy's global RNG inside a freshly started DataLoader worker.

    Kept as a module-level function (not a lambda) so that it can be pickled
    when worker processes are started with the ``spawn`` method.

    :param worker_id: worker index supplied by the DataLoader (unused)
    """
    # Called without a seed so each worker draws fresh entropy instead of
    # reusing whatever NumPy state it started with.
    np.random.seed()


def get_dataloader(phase, config):
    """
    Build the DataLoader for the given phase.

    :param phase: 'train' selects ``config.data.train_cls`` (shuffled, with
        ``config.data.num_workers`` workers); any other value selects
        ``config.data.eval_cls`` (not shuffled, one worker)
    :param config: experiment config; ``batch_size`` and ``seq_len`` are
        copied onto ``config.data`` (the config is mutated in place)
    :return: a DataLoader that only yields full batches (``drop_last=True``)
    """
    # The dataset classes only receive config.data, so forward the top-level
    # settings onto it.
    config.data.batch_size = config.batch_size
    config.data.seq_len = config.seq_len

    is_train = phase == 'train'

    # Dataset classes are looked up by name among this module's attributes.
    dataset_cls_name = config.data.train_cls if is_train else config.data.eval_cls
    dataset_cls = getattr(thismodule, dataset_cls_name)
    dataset = dataset_cls(phase, config.data)

    dataloader = DataLoader(dataset,
                            shuffle=is_train,
                            batch_size=config.batch_size,
                            num_workers=(config.data.num_workers if is_train else 1),
                            worker_init_fn=_reseed_numpy_worker,
                            drop_last=True)

    return dataloader
|
|
|
|
|
class _MixamoDatasetBase(Dataset):
    """
    Shared plumbing for the Mixamo datasets: file discovery, optional
    in-memory caching and data augmentation.

    Motion files are looked up at
    ``<data_root>/<character>/<animation>/motions/<file>.npy``.
    Subclasses must implement ``__getitem__``.
    """

    def __init__(self, phase, config):
        """
        :param phase: 'train' or 'test'; selects the data directory and the
            mean/std pose paths, and enables augmentation for 'train'
        :param config: data config providing ``train_dir``/``test_dir``,
            ``train_*pose_path``/``test_*pose_path``, ``unit`` and optionally
            ``preload``
        """
        super(_MixamoDatasetBase, self).__init__()

        assert phase in ['train', 'test']
        self.phase = phase
        self.data_root = config.train_dir if phase == 'train' else config.test_dir
        self.meanpose_path = config.train_meanpose_path if phase == 'train' else config.test_meanpose_path
        self.stdpose_path = config.train_stdpose_path if phase == 'train' else config.test_stdpose_path
        self.unit = config.unit
        self.aug = (phase == 'train')
        self.character_names = sorted(os.listdir(self.data_root))

        # Motion names are discovered from the first character only; every
        # character is assumed to provide the same set of motions.
        items = glob.glob(os.path.join(self.data_root, self.character_names[0], '*/motions/*.npy'))
        # Keep the trailing "<animation>/motions/<file>.npy" part. Splitting on
        # os.sep after normalisation works on any platform, and sorting makes
        # the order independent of the file system's listing order.
        self.motion_names = sorted(
            os.path.join(*os.path.normpath(path).split(os.sep)[-3:]) for path in items
        )

        self.meanpose, self.stdpose = get_meanpose(phase, config)
        self.meanpose = self.meanpose.astype(np.float32)
        self.stdpose = self.stdpose.astype(np.float32)

        self.cached = False
        if 'preload' in config and config.preload:
            self.preload()
            self.cached = True

    def build_item(self, mot_name, char_name):
        """
        :param mot_name: animation_name/motions/xxx.npy
        :param char_name: character_name
        :return: path of the motion file for this character under data_root
        """
        return os.path.join(self.data_root, char_name, mot_name)

    def load_item(self, item):
        """
        :param item: path built by ``build_item``
        :return: the motion array, from the cache if preloaded, else from disk
        """
        if self.cached:
            return self.cache[item]
        return np.load(item)

    def preload(self):
        """Load every (motion, character) file into ``self.cache``, keyed by path."""
        print("pre-loading into memory")
        self.cache = {}
        # The context manager closes the progress bar even if a load fails.
        with tqdm(total=len(self)) as pbar:
            for motion_name in self.motion_names:
                for character_name in self.character_names:
                    item = self.build_item(motion_name, character_name)
                    self.cache[item] = np.load(item)
                    pbar.update(1)

    @staticmethod
    def gen_aug_params(rotate=False):
        """
        Draw random augmentation parameters.

        :param rotate: if True, also draw a 'roll' of three angles (one per
            axis, within +-pi/9, +-pi/9, +-pi/6) and use the narrower scale
            range [0.8, 1.2); otherwise only a 'ratio' in [0.5, 1.5)
        :return: EasyDict with key 'ratio' and, when rotating, 'roll'
        """
        if rotate:
            params = {'ratio': np.random.uniform(0.8, 1.2),
                      'roll': np.random.uniform((-np.pi / 9, -np.pi / 9, -np.pi / 6),
                                                (np.pi / 9, np.pi / 9, np.pi / 6))}
        else:
            params = {'ratio': np.random.uniform(0.5, 1.5)}
        return edict(params)

    @staticmethod
    def augmentation(data, params=None):
        """
        Apply the optional rotation ('roll') and scaling ('ratio') to a motion.

        :param data: numpy array of size (joints, 3, len_frames)
        :param params: mapping with attribute access (e.g. EasyDict) as
            returned by ``gen_aug_params``, or None for no augmentation
        :return: (augmented data, params); data is returned untouched when
            params is None
        """
        if params is None:
            return data, params

        if 'roll' in params:
            cx, cy, cz = np.cos(params.roll)
            sx, sy, sz = np.sin(params.roll)
            mat33_x = np.array([
                [1, 0, 0],
                [0, cx, -sx],
                [0, sx, cx]
            ], dtype='float')
            mat33_y = np.array([
                [cy, 0, sy],
                [0, 1, 0],
                [-sy, 0, cy]
            ], dtype='float')
            mat33_z = np.array([
                [cz, -sz, 0],
                [sz, cz, 0],
                [0, 0, 1]
            ], dtype='float')
            # The (3, 3) rotation broadcasts over the leading joints axis.
            data = mat33_x @ mat33_y @ mat33_z @ data

        if 'ratio' in params:
            # Out-of-place multiply: a cached input array is never modified.
            data = data * params.ratio

        return data, params

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        """Number of (motion, character) combinations."""
        return len(self.motion_names) * len(self.character_names)
|
|
|
|
|
def get_meanpose(phase, config):
    """
    Return the mean and std pose for a phase, computing and saving them first
    if either file is missing.

    :param phase: "train" selects the train_* paths, anything else the test_* paths
    :param config: config providing ``train_meanpose_path``, ``train_stdpose_path``,
        ``test_meanpose_path`` and ``test_stdpose_path``
    :return: (meanpose, stdpose); an array whose last dimension is 2 is widened
        to three columns laid out as (x, x, y)
    """
    if phase == "train":
        meanpose_path, stdpose_path = config.train_meanpose_path, config.train_stdpose_path
    else:
        meanpose_path, stdpose_path = config.test_meanpose_path, config.test_stdpose_path

    both_on_disk = os.path.exists(meanpose_path) and os.path.exists(stdpose_path)
    if not both_on_disk:
        # Compute the statistics from the data and persist them for later runs.
        meanpose, stdpose = gen_meanpose(phase, config)
        np.save(meanpose_path, meanpose)
        np.save(stdpose_path, stdpose)
        print("meanpose saved at {}".format(meanpose_path))
        print("stdpose saved at {}".format(stdpose_path))
    else:
        meanpose = np.load(meanpose_path)
        stdpose = np.load(stdpose_path)

    # Two-column statistics are widened to three columns by repeating the
    # first column.
    if meanpose.shape[-1] == 2:
        first_col, second_col = meanpose[:, 0], meanpose[:, 1]
        meanpose = np.stack([first_col, first_col, second_col], axis=1)

    if stdpose.shape[-1] == 2:
        first_col, second_col = stdpose[:, 0], stdpose[:, 1]
        stdpose = np.stack([first_col, first_col, second_col], axis=1)

    return meanpose, stdpose
|
|
|
|
|
def _view_angle_grid(rotation_axes):
    """
    Enumerate the candidate view-angle triples.

    Each enabled axis (truthy entry of ``rotation_axes``) sweeps the
    module-level ``view_angles``; a disabled axis stays at 0.

    :param rotation_axes: three flags, one per angle column
    :return: array of shape (n_combinations, 3)
    """
    x_angles = view_angles if rotation_axes[0] else np.array([0])
    z_angles = view_angles if rotation_axes[1] else np.array([0])
    y_angles = view_angles if rotation_axes[2] else np.array([0])
    x_angles, z_angles, y_angles = np.meshgrid(x_angles, z_angles, y_angles)
    return np.stack([x_angles.flatten(), z_angles.flatten(), y_angles.flatten()], axis=1)


def gen_meanpose(phase, config, n_samp=20000):
    """
    Estimate the per-joint mean and standard deviation from a random sample
    of motion files.

    :param phase: "train" reads ``config.train_dir``, anything else ``config.test_dir``
    :param config: data config; ``rotation_axes`` is only read when a file
        whose second dimension is 3 is encountered
    :param n_samp: maximum number of files to sample
    :return: (meanpose, stdpose), reduced over axis 2 of the concatenated
        motions; zero std entries are replaced by 1e-9
    :raises ValueError: if no motion file is found under the data directory
    """
    data_dir = config.train_dir if phase == "train" else config.test_dir
    all_paths = glob.glob(os.path.join(data_dir, '*/*/motions/*.npy'))
    if not all_paths:
        # Fail early with a clear message rather than inside np.concatenate.
        raise ValueError("no motion files found under {}".format(data_dir))
    random.shuffle(all_paths)
    all_paths = all_paths[:n_samp]

    print("computing meanpose and stdpose")

    # Built on first use only: the grid is identical for every file, and
    # configs for 2D-only data may not define rotation_axes at all.
    angle_grid = None
    all_joints = []

    for path in tqdm(all_paths):
        motion = np.load(path)
        if motion.shape[1] == 3:
            # 3D motion: optionally view it from a random candidate angle.
            basis = None
            if sum(config.rotation_axes) > 0:
                if angle_grid is None:
                    angle_grid = _view_angle_grid(config.rotation_axes)
                i = np.random.choice(len(angle_grid))
                basis = get_change_of_basis(motion, angle_grid[i])
            motion = preprocess_mixamo(motion)
            motion = rotate_motion_3d(motion, basis)
        else:
            # Non-3D motion: rescaled by a fixed factor of 128.
            motion = motion * 128
        all_joints.append(localize_motion(motion))

    all_joints = np.concatenate(all_joints, axis=2)

    meanpose = np.mean(all_joints, axis=2)
    stdpose = np.std(all_joints, axis=2)
    # Guard against division by zero when normalising constant coordinates.
    stdpose[stdpose == 0] = 1e-9

    return meanpose, stdpose
|
|
|
|
|
class MixamoDataset(_MixamoDatasetBase):
    """
    Dataset of randomly paired Mixamo samples: two motions, two characters
    and two view angles are drawn per item, and every combination needed by
    the caller is returned in both 3D (``X_*``) and 2D (``x_*``) form.
    """

    def __init__(self, phase, config):
        """
        :param phase: 'train' or 'test'
        :param config: data config; ``rotation_axes`` holds three flags that
            enable sweeping ``view_angles`` for the corresponding angle column
        """
        super(MixamoDataset, self).__init__(phase, config)
        x_angles = view_angles if config.rotation_axes[0] else np.array([0])
        z_angles = view_angles if config.rotation_axes[1] else np.array([0])
        y_angles = view_angles if config.rotation_axes[2] else np.array([0])
        x_angles, z_angles, y_angles = np.meshgrid(x_angles, z_angles, y_angles)
        angles = np.stack([x_angles.flatten(), z_angles.flatten(), y_angles.flatten()], axis=1)
        # One row per candidate view: (n_combinations, 3).
        self.view_angles = angles

    @staticmethod
    def _sample_index_pair(n):
        """
        Draw two indices in [0, n).

        The indices are distinct whenever n >= 2. With a single candidate the
        same index is returned twice instead of raising, so the dataset still
        works with one motion, one character or rotation disabled.
        """
        return np.random.choice(n, size=2, replace=(n < 2))

    def preprocessing(self, motion3d, view_angle=None, params=None):
        """
        :param motion3d: raw 3D motion array as loaded by ``load_item``
        :param view_angle: view-angle triple used to rotate the motion, or
            None to keep the original view
        :param params: augmentation parameters from ``gen_aug_params``; only
            applied when augmentation is enabled (training phase)
        :return: (motion3d, motion2d) as float tensors with the joint and
            coordinate axes flattened into one; motion2d keeps coordinate
            columns 0 and 2 of the normalised 3D motion
        """
        if self.aug:
            motion3d, params = self.augmentation(motion3d, params)

        basis = None
        if view_angle is not None:
            basis = get_change_of_basis(motion3d, view_angle)

        motion3d = preprocess_mixamo(motion3d)
        motion3d = rotate_motion_3d(motion3d, basis)
        motion3d = localize_motion(motion3d)
        motion3d = normalize_motion(motion3d, self.meanpose, self.stdpose)

        motion2d = motion3d[:, [0, 2], :]

        motion3d = motion3d.reshape([-1, motion3d.shape[-1]])
        motion2d = motion2d.reshape([-1, motion2d.shape[-1]])

        motion3d = torch.from_numpy(motion3d).float()
        motion2d = torch.from_numpy(motion2d).float()

        return motion3d, motion2d

    def __getitem__(self, index):
        """
        Draw a random sample; ``index`` is ignored.

        In the three-letter keys the letters name, in order, the source of
        the motion, of the character and of the view (e.g. ``X_aba`` is
        motion a performed by character b seen from view a).

        :return: dict with the 3D tensors ``X_*``, the 2D tensors ``x_*``,
            the sampled names and views, and the mean/std pose
        """
        idx_a, idx_b = self._sample_index_pair(len(self.motion_names))
        mot_a, mot_b = self.motion_names[idx_a], self.motion_names[idx_b]

        idx_a, idx_b = self._sample_index_pair(len(self.character_names))
        char_a, char_b = self.character_names[idx_a], self.character_names[idx_b]

        idx_a, idx_b = self._sample_index_pair(len(self.view_angles))
        view_a, view_b = self.view_angles[idx_a], self.view_angles[idx_b]

        if self.aug:
            param_a = self.gen_aug_params(rotate=False)
            param_b = self.gen_aug_params(rotate=False)
        else:
            param_a = param_b = None

        item_a = self.load_item(self.build_item(mot_a, char_a))
        item_b = self.load_item(self.build_item(mot_b, char_b))
        item_ab = self.load_item(self.build_item(mot_a, char_b))
        item_ba = self.load_item(self.build_item(mot_b, char_a))

        # The augmentation parameters always follow the character.
        X_a, x_a = self.preprocessing(item_a, view_a, param_a)
        X_b, x_b = self.preprocessing(item_b, view_b, param_b)

        X_aab, x_aab = self.preprocessing(item_a, view_b, param_a)
        X_bba, x_bba = self.preprocessing(item_b, view_a, param_b)
        X_aba, x_aba = self.preprocessing(item_ab, view_a, param_b)
        X_bab, x_bab = self.preprocessing(item_ba, view_b, param_a)
        X_abb, x_abb = self.preprocessing(item_ab, view_b, param_b)
        X_baa, x_baa = self.preprocessing(item_ba, view_a, param_a)

        return {"X_a": X_a, "X_b": X_b,
                "X_aab": X_aab, "X_bba": X_bba,
                "X_aba": X_aba, "X_bab": X_bab,
                "X_abb": X_abb, "X_baa": X_baa,
                "x_a": x_a, "x_b": x_b,
                "x_aab": x_aab, "x_bba": x_bba,
                "x_aba": x_aba, "x_bab": x_bab,
                "x_abb": x_abb, "x_baa": x_baa,
                "mot_a": mot_a, "mot_b": mot_b,
                "char_a": char_a, "char_b": char_b,
                "view_a": view_a, "view_b": view_b,
                "meanpose": self.meanpose, "stdpose": self.stdpose}
|
|
|
|
|
class MixamoLimbScaleDataset(_MixamoDatasetBase):
    """
    Dataset yielding a 2D motion together with a limb-scaled version of the
    same motion, drawn from a random motion, character and view.
    """

    def __init__(self, phase, config):
        """
        :param phase: 'train' or 'test'
        :param config: data config; reads ``global_range``, ``local_range``
            and the three ``rotation_axes`` flags
        """
        super(MixamoLimbScaleDataset, self).__init__(phase, config)
        self.global_range = config.global_range
        self.local_range = config.local_range

        # An enabled axis sweeps the module-level view_angles; a disabled one stays at 0.
        zero = np.array([0])
        per_axis = [view_angles if config.rotation_axes[k] else zero for k in range(3)]
        grid_x, grid_z, grid_y = np.meshgrid(per_axis[0], per_axis[1], per_axis[2])
        self.view_angles = np.stack(
            [grid_x.flatten(), grid_z.flatten(), grid_y.flatten()], axis=1)

    def preprocessing(self, motion3d, view_angle=None, params=None):
        """
        :param motion3d: raw 3D motion array as loaded by ``load_item``
        :param view_angle: view-angle triple used to rotate the motion, or None
        :param params: augmentation parameters, applied only when augmentation is on
        :return: (motion2d, motion2d_scale) float tensors, each flattened to
            two dimensions with the last axis preserved
        """
        if self.aug:
            motion3d, params = self.augmentation(motion3d, params)

        if view_angle is None:
            basis = None
        else:
            basis = get_change_of_basis(motion3d, view_angle)

        rotated = rotate_motion_3d(preprocess_mixamo(motion3d), basis)

        # Keep coordinate columns 0 and 2, then derive the limb-scaled variant.
        plain = rotated[:, [0, 2], :]
        scaled = limb_scale_motion_2d(plain, self.global_range, self.local_range)

        plain = localize_motion(plain)
        scaled = localize_motion(scaled)

        plain = normalize_motion(plain, self.meanpose, self.stdpose)
        scaled = normalize_motion(scaled, self.meanpose, self.stdpose)

        plain = plain.reshape((-1, plain.shape[-1]))
        scaled = scaled.reshape((-1, scaled.shape[-1]))

        return torch.from_numpy(plain).float(), torch.from_numpy(scaled).float()

    def __getitem__(self, index):
        """
        Draw a random sample; ``index`` is ignored.

        :return: dict with the 2D motion ``x``, its limb-scaled version
            ``x_s``, the sampled motion/character/view and the mean/std pose
        """
        motion = self.motion_names[np.random.choice(len(self.motion_names))]
        character = self.character_names[np.random.choice(len(self.character_names))]
        view = self.view_angles[np.random.choice(len(self.view_angles))]

        param = self.gen_aug_params(rotate=True) if self.aug else None

        raw = self.load_item(self.build_item(motion, character))
        x, x_s = self.preprocessing(raw, view, param)

        sample = {"x": x, "x_s": x_s, "mot": motion, "char": character, "view": view,
                  "meanpose": self.meanpose, "stdpose": self.stdpose}
        return sample
|
|
|
|
|
class SoloDanceDataset(Dataset):
    """
    Dataset over every motion file found at
    ``<data_root>/*/*/motions/*.npy``, yielding each motion together with a
    randomly limb-scaled version of it.
    """

    def __init__(self, phase, config):
        """
        :param phase: 'train' or 'test'; selects the data directory and the
            mean/std pose paths
        :param config: data config providing ``global_range``, ``local_range``,
            ``unit``, the directories and pose paths, and optionally ``preload``
        """
        super(SoloDanceDataset, self).__init__()
        self.global_range = config.global_range
        self.local_range = config.local_range

        assert phase in ['train', 'test']
        self.data_root = config.train_dir if phase == 'train' else config.test_dir
        self.phase = phase
        self.unit = config.unit
        self.meanpose_path = config.train_meanpose_path if phase == 'train' else config.test_meanpose_path
        self.stdpose_path = config.train_stdpose_path if phase == 'train' else config.test_stdpose_path
        self.character_names = sorted(os.listdir(self.data_root))

        # glob returns files in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.items = sorted(glob.glob(os.path.join(self.data_root, '*/*/motions/*.npy')))
        self.meanpose, self.stdpose = get_meanpose(phase, config)
        self.meanpose = self.meanpose.astype(np.float32)
        self.stdpose = self.stdpose.astype(np.float32)

        self.cached = False
        if 'preload' in config and config.preload:
            self.preload()
            self.cached = True

    def load_item(self, item):
        """
        :param item: path of a motion file
        :return: the motion array, from the cache if preloaded, else from disk
        """
        if self.cached:
            return self.cache[item]
        return np.load(item)

    def preload(self):
        """Load every motion file into ``self.cache``, keyed by path."""
        print("pre-loading into memory")
        self.cache = {}
        # The context manager closes the progress bar even if a load fails.
        with tqdm(total=len(self)) as pbar:
            for item in self.items:
                self.cache[item] = np.load(item)
                pbar.update(1)

    def preprocessing(self, motion):
        """
        :param motion: raw motion array; indexed as (joint, coordinate, frame)
        :return: (motion, motion_scale) float tensors, each flattened to two
            dimensions with the last axis preserved
        """
        # Multiplying creates a new array, so the in-place edits below never
        # touch a cached input.
        motion = motion * self.unit

        # Joints 1 and 8 are replaced by the midpoints of joints (2, 5) and
        # (9, 12) respectively.
        # NOTE(review): presumably neck and mid-hip of the skeleton layout; confirm.
        motion[1, :, :] = (motion[2, :, :] + motion[5, :, :]) / 2
        motion[8, :, :] = (motion[9, :, :] + motion[12, :, :]) / 2

        # One global scale plus eight local scales, each drawn uniformly from
        # its configured range.
        global_scale = self.global_range[0] + np.random.random() * (self.global_range[1] - self.global_range[0])
        local_scales = self.local_range[0] + np.random.random([8]) * (self.local_range[1] - self.local_range[0])
        motion_scale = scale_limbs(motion, global_scale, local_scales)

        motion = localize_motion(motion)
        motion_scale = localize_motion(motion_scale)
        motion = normalize_motion(motion, self.meanpose, self.stdpose)
        motion_scale = normalize_motion(motion_scale, self.meanpose, self.stdpose)
        motion = motion.reshape((-1, motion.shape[-1]))
        motion_scale = motion_scale.reshape((-1, motion_scale.shape[-1]))
        motion = torch.from_numpy(motion).float()
        motion_scale = torch.from_numpy(motion_scale).float()
        return motion, motion_scale

    def __len__(self):
        """Number of motion files."""
        return len(self.items)

    def __getitem__(self, index):
        """
        :param index: position in ``self.items``
        :return: dict with the motion ``x``, its limb-scaled version ``x_s``
            and the mean/std pose
        """
        item = self.items[index]
        motion = self.load_item(item)
        x, x_s = self.preprocessing(motion)
        return {"x": x, "x_s": x_s, "meanpose": self.meanpose, "stdpose": self.stdpose}
|
|