# TMM / lib/data.py
# Fazhong Liu - fin - 7ca9b42 - 16.4 kB
# (file-viewer link labels: raw / history / blame)
import sys, os
thismodule = sys.modules[__name__]
from lib.util.motion import preprocess_mixamo, rotate_motion_3d, limb_scale_motion_2d, normalize_motion, get_change_of_basis, localize_motion, scale_limbs
import torch
import glob
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from easydict import EasyDict as edict
from tqdm import tqdm
# Seven candidate view angles, evenly spaced by pi/6 from -pi/2 to +pi/2.
view_angles = np.arange(-3, 4) * np.pi / 6
def _seed_worker(_worker_id):
    """Re-seed NumPy's global RNG inside a DataLoader worker.

    Defined at module level (instead of as a lambda) so that it can be
    pickled when worker processes are started with the ``spawn`` method.
    """
    np.random.seed()


def get_dataloader(phase, config):
    """Build the DataLoader for the given phase.

    :param phase: 'train' selects ``config.data.train_cls`` with shuffling;
        any other value selects ``config.data.eval_cls`` without shuffling
    :param config: experiment config; ``batch_size`` and ``seq_len`` are
        copied onto ``config.data`` (the config object is modified in place)
    :return: a DataLoader over the dataset class looked up by name in this module
    """
    # The dataset constructors only receive config.data, so mirror the
    # top-level settings onto it.
    config.data.batch_size = config.batch_size
    config.data.seq_len = config.seq_len

    is_train = (phase == 'train')
    dataset_cls_name = config.data.train_cls if is_train else config.data.eval_cls
    dataset_cls = getattr(thismodule, dataset_cls_name)
    dataset = dataset_cls(phase, config.data)

    return DataLoader(dataset, shuffle=is_train,
                      batch_size=config.batch_size,
                      num_workers=(config.data.num_workers if is_train else 1),
                      worker_init_fn=_seed_worker,
                      drop_last=True)
class _MixamoDatasetBase(Dataset):
    """Common file discovery, caching and augmentation for the Mixamo datasets.

    Files are expected at ``<data_root>/<character>/<animation>/motions/<name>.npy``.
    The list of motions is taken from the first character (in sorted order);
    subclasses implement ``__getitem__``.
    """

    def __init__(self, phase, config):
        """
        :param phase: 'train' or 'test'
        :param config: data config providing train/test directories, the
            meanpose/stdpose paths, ``unit`` and optionally ``preload``
        """
        super(_MixamoDatasetBase, self).__init__()

        assert phase in ['train', 'test']
        is_train = (phase == 'train')
        self.phase = phase
        self.data_root = config.train_dir if is_train else config.test_dir
        self.meanpose_path = config.train_meanpose_path if is_train else config.test_meanpose_path
        self.stdpose_path = config.train_stdpose_path if is_train else config.test_stdpose_path
        self.unit = config.unit
        self.aug = is_train  # augmentation is applied in the training phase only

        self.character_names = sorted(os.listdir(self.data_root))

        first_char_dir = os.path.join(self.data_root, self.character_names[0])
        items = glob.glob(os.path.join(first_char_dir, '*/motions/*.npy'))
        # Keep "animation/motions/file.npy" relative to the character folder.
        # relpath + os.sep makes this independent of the platform separator,
        # and sorting removes the dependence on glob's arbitrary ordering.
        self.motion_names = sorted(
            '/'.join(os.path.relpath(x, first_char_dir).split(os.sep)) for x in items
        )

        self.meanpose, self.stdpose = get_meanpose(phase, config)
        self.meanpose = self.meanpose.astype(np.float32)
        self.stdpose = self.stdpose.astype(np.float32)

        if 'preload' in config and config.preload:
            self.preload()
            self.cached = True
        else:
            self.cached = False

    def build_item(self, mot_name, char_name):
        """
        :param mot_name: animation_name/motions/xxx.npy
        :param char_name: character_name
        :return: path of that motion file for that character under data_root
        """
        return os.path.join(self.data_root, char_name, mot_name)

    def load_item(self, item):
        """Return the array for ``item`` (a path from ``build_item``), from the
        in-memory cache when preloading was enabled, otherwise from disk."""
        if self.cached:
            data = self.cache[item]
        else:
            data = np.load(item)
        return data

    def preload(self):
        """Load every (motion, character) file into ``self.cache``, keyed by path."""
        print("pre-loading into memory")
        self.cache = {}
        # Context manager guarantees the progress bar is closed, even on error.
        with tqdm(total=len(self)) as pbar:
            for motion_name in self.motion_names:
                for character_name in self.character_names:
                    item = self.build_item(motion_name, character_name)
                    self.cache[item] = np.load(item)
                    pbar.update(1)

    @staticmethod
    def gen_aug_params(rotate=False):
        """Draw random augmentation parameters.

        :param rotate: if True, also draw a 'roll' triple of angles (within
            +-pi/9, +-pi/9, +-pi/6) and use a narrower scale range (0.8-1.2
            instead of 0.5-1.5)
        :return: EasyDict with 'ratio' and, when rotate is True, 'roll'
        """
        if rotate:
            params = {'ratio': np.random.uniform(0.8, 1.2),
                      'roll': np.random.uniform((-np.pi / 9, -np.pi / 9, -np.pi / 6),
                                                (np.pi / 9, np.pi / 9, np.pi / 6))}
        else:
            params = {'ratio': np.random.uniform(0.5, 1.5)}
        return edict(params)

    @staticmethod
    def augmentation(data, params=None):
        """
        :param data: numpy array of size (joints, 3, len_frames)
        :param params: object from gen_aug_params, or None for no augmentation
        :return: (augmented data, params); the input array is not modified
        """
        if params is None:
            return data, params

        # rotate
        if 'roll' in params.keys():
            cx, cy, cz = np.cos(params.roll)
            sx, sy, sz = np.sin(params.roll)
            mat33_x = np.array([
                [1, 0, 0],
                [0, cx, -sx],
                [0, sx, cx]
            ], dtype='float')
            mat33_y = np.array([
                [cy, 0, sy],
                [0, 1, 0],
                [-sy, 0, cy]
            ], dtype='float')
            mat33_z = np.array([
                [cz, -sz, 0],
                [sz, cz, 0],
                [0, 0, 1]
            ], dtype='float')
            # The combined 3x3 rotation is broadcast over the leading axis of data.
            data = mat33_x @ mat33_y @ mat33_z @ data

        # scale
        if 'ratio' in params.keys():
            data = data * params.ratio

        return data, params

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        """Number of (motion, character) combinations."""
        return len(self.motion_names) * len(self.character_names)
def get_meanpose(phase, config):
    """Load the per-joint mean and std pose, computing and saving them if absent.

    If both files for the phase exist they are loaded; otherwise the
    statistics are computed with ``gen_meanpose`` and written to those paths.
    Arrays whose last dimension is 2 are widened to 3 columns by duplicating
    the first column: (c0, c1) -> (c0, c0, c1). The widening is applied to
    the returned arrays only, not to the saved files.

    :param phase: "train" selects the train paths, anything else the test paths
    :param config: data config providing ``{train,test}_{meanpose,stdpose}_path``
    :return: (meanpose, stdpose) numpy arrays
    """
    is_train = (phase == "train")
    meanpose_path = config.train_meanpose_path if is_train else config.test_meanpose_path
    stdpose_path = config.train_stdpose_path if is_train else config.test_stdpose_path

    if os.path.exists(meanpose_path) and os.path.exists(stdpose_path):
        meanpose = np.load(meanpose_path)
        stdpose = np.load(stdpose_path)
    else:
        # Create the target folders up front so that a missing directory is
        # reported before the (slow) statistics pass rather than after it.
        for path in (meanpose_path, stdpose_path):
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)
        meanpose, stdpose = gen_meanpose(phase, config)
        np.save(meanpose_path, meanpose)
        np.save(stdpose_path, stdpose)
        print("meanpose saved at {}".format(meanpose_path))
        print("stdpose saved at {}".format(stdpose_path))

    if meanpose.shape[-1] == 2:
        mean_x, mean_y = meanpose[:, 0], meanpose[:, 1]
        meanpose = np.stack([mean_x, mean_x, mean_y], axis=1)

    if stdpose.shape[-1] == 2:
        std_x, std_y = stdpose[:, 0], stdpose[:, 1]
        stdpose = np.stack([std_x, std_x, std_y], axis=1)

    return meanpose, stdpose
def gen_meanpose(phase, config, n_samp=20000):
    """Compute per-joint mean and std over a random sample of motion files.

    Up to ``n_samp`` files matching ``<data_dir>/*/*/motions/*.npy`` are used.
    Arrays with ``shape[1] == 3`` go through preprocess_mixamo ->
    rotate_motion_3d -> localize_motion, with a randomly chosen view angle
    when any entry of ``config.rotation_axes`` is set; all other arrays are
    multiplied by 128 and localized.

    :param phase: "train" selects ``config.train_dir``, anything else ``config.test_dir``
    :param config: data config (``train_dir``/``test_dir``, ``rotation_axes``)
    :param n_samp: maximum number of files to sample
    :return: (meanpose, stdpose), statistics taken over axis 2 of the
        concatenated motions; zero std entries are replaced by 1e-9
    :raises ValueError: if no motion files are found
    """
    data_dir = config.train_dir if phase == "train" else config.test_dir
    all_paths = glob.glob(os.path.join(data_dir, '*/*/motions/*.npy'))
    if not all_paths:
        raise ValueError("no motion files found under {}".format(data_dir))
    random.shuffle(all_paths)
    all_paths = all_paths[:n_samp]

    # Grid of candidate view angles; identical for every file, so it is built
    # once, and only when a 3D motion actually needs it.
    angle_grid = None

    all_joints = []
    print("computing meanpose and stdpose")

    for path in tqdm(all_paths):
        motion = np.load(path)
        if motion.shape[1] == 3:
            basis = None
            if sum(config.rotation_axes) > 0:
                if angle_grid is None:
                    x_angles = view_angles if config.rotation_axes[0] else np.array([0])
                    z_angles = view_angles if config.rotation_axes[1] else np.array([0])
                    y_angles = view_angles if config.rotation_axes[2] else np.array([0])
                    x_angles, z_angles, y_angles = np.meshgrid(x_angles, z_angles, y_angles)
                    angle_grid = np.stack(
                        [x_angles.flatten(), z_angles.flatten(), y_angles.flatten()], axis=1)
                i = np.random.choice(len(angle_grid))
                basis = get_change_of_basis(motion, angle_grid[i])
            motion = preprocess_mixamo(motion)
            motion = rotate_motion_3d(motion, basis)
            motion = localize_motion(motion)
        else:
            # NOTE(review): 128 is hard-coded here while the dataset classes
            # read config.unit -- confirm the two are meant to agree.
            motion = localize_motion(motion * 128)
        all_joints.append(motion)

    all_joints = np.concatenate(all_joints, axis=2)

    meanpose = np.mean(all_joints, axis=2)
    stdpose = np.std(all_joints, axis=2)
    # Replace exact zeros with a tiny value (presumably so that later
    # normalisation never divides by zero).
    stdpose[stdpose == 0] = 1e-9

    return meanpose, stdpose
class MixamoDataset(_MixamoDatasetBase):
    """Yields random (motion, character, view) pairs together with every
    cross combination of the two, as 3D and 2D tensors."""

    def __init__(self, phase, config):
        """
        :param phase: 'train' or 'test'
        :param config: data config; ``rotation_axes`` holds three flags that
            enable the view-angle sweep per axis
        """
        super(MixamoDataset, self).__init__(phase, config)
        # A disabled axis contributes the single angle 0.
        x_angles = view_angles if config.rotation_axes[0] else np.array([0])
        z_angles = view_angles if config.rotation_axes[1] else np.array([0])
        y_angles = view_angles if config.rotation_axes[2] else np.array([0])
        x_angles, z_angles, y_angles = np.meshgrid(x_angles, z_angles, y_angles)
        angles = np.stack([x_angles.flatten(), z_angles.flatten(), y_angles.flatten()], axis=1)
        self.view_angles = angles  # one (x, z, y) angle triple per row

    def preprocessing(self, motion3d, view_angle=None, params=None):
        """
        :param motion3d: motion array as returned by self.load_item
        :param view_angle: angle triple used to build the change of basis, or
            None for no change of basis
        :param params: augmentation parameters (used only when self.aug is set)
        :return: (motion3d, motion2d) float tensors, each flattened to
            (-1, last dim); motion2d keeps entries 0 and 2 of the second axis
        """
        if self.aug:
            motion3d, params = self.augmentation(motion3d, params)

        basis = None
        if view_angle is not None:
            basis = get_change_of_basis(motion3d, view_angle)

        motion3d = preprocess_mixamo(motion3d)
        motion3d = rotate_motion_3d(motion3d, basis)
        motion3d = localize_motion(motion3d)
        motion3d = normalize_motion(motion3d, self.meanpose, self.stdpose)

        motion2d = motion3d[:, [0, 2], :]

        motion3d = motion3d.reshape([-1, motion3d.shape[-1]])
        motion2d = motion2d.reshape([-1, motion2d.shape[-1]])

        motion3d = torch.from_numpy(motion3d).float()
        motion2d = torch.from_numpy(motion2d).float()

        return motion3d, motion2d

    def __getitem__(self, index):
        """Return one randomly drawn sample; ``index`` is ignored.

        Keys ``X_*`` are 3D tensors and ``x_*`` the matching 2D tensors. A
        three-letter suffix reads (motion, character, view), e.g. ``X_aba``
        is motion a performed by character b seen from view a.
        """
        # select two motions
        idx_a, idx_b = np.random.choice(len(self.motion_names), size=2, replace=False)
        mot_a, mot_b = self.motion_names[idx_a], self.motion_names[idx_b]

        # select two characters
        idx_a, idx_b = np.random.choice(len(self.character_names), size=2, replace=False)
        char_a, char_b = self.character_names[idx_a], self.character_names[idx_b]

        # select two views; with a single available view (all rotation axes
        # disabled) distinct sampling is impossible, so the view is reused
        n_views = len(self.view_angles)
        idx_a, idx_b = np.random.choice(n_views, size=2, replace=(n_views < 2))
        view_a, view_b = self.view_angles[idx_a], self.view_angles[idx_b]

        if self.aug:
            param_a = self.gen_aug_params(rotate=False)
            param_b = self.gen_aug_params(rotate=False)
        else:
            param_a = param_b = None

        item_a = self.load_item(self.build_item(mot_a, char_a))
        item_b = self.load_item(self.build_item(mot_b, char_b))
        item_ab = self.load_item(self.build_item(mot_a, char_b))
        item_ba = self.load_item(self.build_item(mot_b, char_a))

        # The augmentation parameters follow the character of each item.
        X_a, x_a = self.preprocessing(item_a, view_a, param_a)
        X_b, x_b = self.preprocessing(item_b, view_b, param_b)

        X_aab, x_aab = self.preprocessing(item_a, view_b, param_a)
        X_bba, x_bba = self.preprocessing(item_b, view_a, param_b)
        X_aba, x_aba = self.preprocessing(item_ab, view_a, param_b)
        X_bab, x_bab = self.preprocessing(item_ba, view_b, param_a)
        X_abb, x_abb = self.preprocessing(item_ab, view_b, param_b)
        X_baa, x_baa = self.preprocessing(item_ba, view_a, param_a)

        return {"X_a": X_a, "X_b": X_b,
                "X_aab": X_aab, "X_bba": X_bba,
                "X_aba": X_aba, "X_bab": X_bab,
                "X_abb": X_abb, "X_baa": X_baa,
                "x_a": x_a, "x_b": x_b,
                "x_aab": x_aab, "x_bba": x_bba,
                "x_aba": x_aba, "x_bab": x_bab,
                "x_abb": x_abb, "x_baa": x_baa,
                "mot_a": mot_a, "mot_b": mot_b,
                "char_a": char_a, "char_b": char_b,
                "view_a": view_a, "view_b": view_b,
                "meanpose": self.meanpose, "stdpose": self.stdpose}
class MixamoLimbScaleDataset(_MixamoDatasetBase):
    """Yields a random 2D motion together with a limb-scaled copy of it."""

    def __init__(self, phase, config):
        """
        :param phase: 'train' or 'test'
        :param config: data config providing ``global_range``, ``local_range``
            and the three ``rotation_axes`` flags
        """
        super().__init__(phase, config)
        self.global_range = config.global_range
        self.local_range = config.local_range

        # One candidate list per axis; a disabled axis only offers angle 0.
        per_axis = [view_angles if config.rotation_axes[k] else np.array([0])
                    for k in range(3)]
        grids = np.meshgrid(per_axis[0], per_axis[1], per_axis[2])
        self.view_angles = np.stack([g.flatten() for g in grids], axis=1)

    def preprocessing(self, motion3d, view_angle=None, params=None):
        """
        :param motion3d: motion array as returned by self.load_item
        :param view_angle: angle triple for the change of basis, or None
        :param params: augmentation parameters (used only when self.aug is set)
        :return: (motion2d, motion2d_scale) float tensors flattened to (-1, last dim)
        """
        if self.aug:
            motion3d, params = self.augmentation(motion3d, params)

        basis = get_change_of_basis(motion3d, view_angle) if view_angle is not None else None

        rotated = rotate_motion_3d(preprocess_mixamo(motion3d), basis)

        # Keep entries 0 and 2 of the second axis, then derive the scaled copy
        # from that projection before either of them is localized.
        plain = rotated[:, [0, 2], :]
        scaled = limb_scale_motion_2d(plain, self.global_range, self.local_range)

        plain = localize_motion(plain)
        scaled = localize_motion(scaled)

        plain = normalize_motion(plain, self.meanpose, self.stdpose)
        scaled = normalize_motion(scaled, self.meanpose, self.stdpose)

        plain = plain.reshape([-1, plain.shape[-1]])
        scaled = scaled.reshape((-1, scaled.shape[-1]))

        return torch.from_numpy(plain).float(), torch.from_numpy(scaled).float()

    def __getitem__(self, index):
        """Return one randomly drawn sample; ``index`` is ignored."""
        # pick one motion, one character and one view at random
        motion = self.motion_names[np.random.choice(len(self.motion_names))]
        character = self.character_names[np.random.choice(len(self.character_names))]
        view = self.view_angles[np.random.choice(len(self.view_angles))]

        param = self.gen_aug_params(rotate=True) if self.aug else None

        path = self.build_item(motion, character)
        x, x_s = self.preprocessing(self.load_item(path), view, param)

        sample = {"x": x, "x_s": x_s, "mot": motion, "char": character, "view": view,
                  "meanpose": self.meanpose, "stdpose": self.stdpose}
        return sample
class SoloDanceDataset(Dataset):
    """Dataset over every ``<data_root>/*/*/motions/*.npy`` file, yielding each
    motion together with a randomly limb-scaled copy."""

    def __init__(self, phase, config):
        """
        :param phase: 'train' or 'test'
        :param config: data config providing ``global_range``, ``local_range``,
            ``unit``, the train/test directories, the meanpose/stdpose paths
            and optionally ``preload``
        """
        super(SoloDanceDataset, self).__init__()
        self.global_range = config.global_range
        self.local_range = config.local_range

        assert phase in ['train', 'test']
        is_train = (phase == 'train')
        self.data_root = config.train_dir if is_train else config.test_dir
        self.phase = phase
        self.unit = config.unit
        self.meanpose_path = config.train_meanpose_path if is_train else config.test_meanpose_path
        self.stdpose_path = config.train_stdpose_path if is_train else config.test_stdpose_path
        self.character_names = sorted(os.listdir(self.data_root))

        # glob returns paths in arbitrary order; sorting makes the mapping
        # from index to file reproducible across runs and machines.
        self.items = sorted(glob.glob(os.path.join(self.data_root, '*/*/motions/*.npy')))
        self.meanpose, self.stdpose = get_meanpose(phase, config)
        self.meanpose = self.meanpose.astype(np.float32)
        self.stdpose = self.stdpose.astype(np.float32)

        if 'preload' in config and config.preload:
            self.preload()
            self.cached = True
        else:
            self.cached = False

    def load_item(self, item):
        """Return the array for path ``item``, from the in-memory cache when
        preloading was enabled, otherwise from disk."""
        if self.cached:
            data = self.cache[item]
        else:
            data = np.load(item)
        return data

    def preload(self):
        """Load every file in ``self.items`` into ``self.cache``, keyed by path."""
        print("pre-loading into memory")
        self.cache = {}
        # Context manager guarantees the progress bar is closed, even on error.
        with tqdm(total=len(self)) as pbar:
            for item in self.items:
                self.cache[item] = np.load(item)
                pbar.update(1)

    def preprocessing(self, motion):
        """
        :param motion: motion array as returned by self.load_item
        :return: (motion, motion_scale) float tensors flattened to (-1, last dim)
        """
        # The multiplication creates a new array, so the in-place writes below
        # never touch the (possibly cached) input.
        motion = motion * self.unit

        # Entries 1 and 8 along the first axis are overwritten with the
        # midpoints of entries (2, 5) and (9, 12) respectively.
        motion[1, :, :] = (motion[2, :, :] + motion[5, :, :]) / 2
        motion[8, :, :] = (motion[9, :, :] + motion[12, :, :]) / 2

        # One global factor and eight local factors, drawn uniformly from the
        # configured ranges.
        global_scale = self.global_range[0] + np.random.random() * (self.global_range[1] - self.global_range[0])
        local_scales = self.local_range[0] + np.random.random([8]) * (self.local_range[1] - self.local_range[0])
        motion_scale = scale_limbs(motion, global_scale, local_scales)

        motion = localize_motion(motion)
        motion_scale = localize_motion(motion_scale)
        motion = normalize_motion(motion, self.meanpose, self.stdpose)
        motion_scale = normalize_motion(motion_scale, self.meanpose, self.stdpose)
        motion = motion.reshape((-1, motion.shape[-1]))
        motion_scale = motion_scale.reshape((-1, motion_scale.shape[-1]))
        motion = torch.from_numpy(motion).float()
        motion_scale = torch.from_numpy(motion_scale).float()

        return motion, motion_scale

    def __len__(self):
        """Number of motion files found."""
        return len(self.items)

    def __getitem__(self, index):
        """Return the sample built from the ``index``-th motion file."""
        item = self.items[index]
        motion = self.load_item(item)
        x, x_s = self.preprocessing(motion)
        return {"x": x, "x_s": x_s, "meanpose": self.meanpose, "stdpose": self.stdpose}