Fazhong Liu committed
Commit: 7ca9b42 (parent: a4660a7)
Commit message: fin

Files changed:
- .gitattributes +35 -35
- .gitignore +1 -0
- README.md +4 -4
- app.py +114 -0
- configs/transmomo.yaml +100 -0
- configs/transmomo_solo_dance.yaml +100 -0
- data/meanpose_with_view.npy +0 -0
- data/mse_description.json +1 -0
- data/stdpose_with_view.npy +0 -0
- lib/__init__.py +0 -0
- lib/__pycache__/__init__.cpython-38.pyc +0 -0
- lib/__pycache__/data.cpython-38.pyc +0 -0
- lib/__pycache__/network.cpython-38.pyc +0 -0
- lib/__pycache__/operation.cpython-38.pyc +0 -0
- lib/data.py +421 -0
- lib/loss.py +57 -0
- lib/network.py +356 -0
- lib/operation.py +219 -0
- lib/trainer.py +298 -0
- lib/util/__init__.py +0 -0
- lib/util/__pycache__/__init__.cpython-37.pyc +0 -0
- lib/util/__pycache__/__init__.cpython-38.pyc +0 -0
- lib/util/__pycache__/general.cpython-37.pyc +0 -0
- lib/util/__pycache__/general.cpython-38.pyc +0 -0
- lib/util/__pycache__/motion.cpython-37.pyc +0 -0
- lib/util/__pycache__/motion.cpython-38.pyc +0 -0
- lib/util/__pycache__/visualization.cpython-37.pyc +0 -0
- lib/util/__pycache__/visualization.cpython-38.pyc +0 -0
- lib/util/general.py +361 -0
- lib/util/global_norm.py +29 -0
- lib/util/motion.py +309 -0
- lib/util/visualization.py +448 -0
- requirements.txt +17 -0
.gitattributes
CHANGED
@@ -1,35 +1,35 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+# *.7z filter=lfs diff=lfs merge=lfs -text
+# *.arrow filter=lfs diff=lfs merge=lfs -text
+# *.bin filter=lfs diff=lfs merge=lfs -text
+# *.bz2 filter=lfs diff=lfs merge=lfs -text
+# *.ckpt filter=lfs diff=lfs merge=lfs -text
+# *.ftz filter=lfs diff=lfs merge=lfs -text
+# *.gz filter=lfs diff=lfs merge=lfs -text
+# *.h5 filter=lfs diff=lfs merge=lfs -text
+# *.joblib filter=lfs diff=lfs merge=lfs -text
+# *.lfs.* filter=lfs diff=lfs merge=lfs -text
+# *.mlmodel filter=lfs diff=lfs merge=lfs -text
+# *.model filter=lfs diff=lfs merge=lfs -text
+# *.msgpack filter=lfs diff=lfs merge=lfs -text
+# *.npy filter=lfs diff=lfs merge=lfs -text
+# *.npz filter=lfs diff=lfs merge=lfs -text
+# *.onnx filter=lfs diff=lfs merge=lfs -text
+# *.ot filter=lfs diff=lfs merge=lfs -text
+# *.parquet filter=lfs diff=lfs merge=lfs -text
+# *.pb filter=lfs diff=lfs merge=lfs -text
+# *.pickle filter=lfs diff=lfs merge=lfs -text
+# *.pkl filter=lfs diff=lfs merge=lfs -text
+# *.pt filter=lfs diff=lfs merge=lfs -text
+# *.pth filter=lfs diff=lfs merge=lfs -text
+# *.rar filter=lfs diff=lfs merge=lfs -text
+# *.safetensors filter=lfs diff=lfs merge=lfs -text
+# saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+# *.tar.* filter=lfs diff=lfs merge=lfs -text
+# *.tar filter=lfs diff=lfs merge=lfs -text
+# *.tflite filter=lfs diff=lfs merge=lfs -text
+# *.tgz filter=lfs diff=lfs merge=lfs -text
+# *.wasm filter=lfs diff=lfs merge=lfs -text
+# *.xz filter=lfs diff=lfs merge=lfs -text
+# *.zip filter=lfs diff=lfs merge=lfs -text
+# *.zst filter=lfs diff=lfs merge=lfs -text
+# *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
+*.pt
README.md
CHANGED
@@ -1,8 +1,8 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Transmomo
+emoji: 📈
+colorFrom: indigo
+colorTo: red
 sdk: gradio
 sdk_version: 4.24.0
 app_file: app.py
app.py
ADDED
@@ -0,0 +1,114 @@
import time
import shutil
import gradio as gr
import os
import json
import torch
import argparse
import numpy as np
from lib.data import get_meanpose
from lib.network import get_autoencoder
from lib.util.motion import preprocess_mixamo, preprocess_test, postprocess
from lib.util.general import get_config
from lib.operation import rotate_and_maybe_project_world
from itertools import combinations
from lib.util.visualization import motion2video
from PIL import Image

def load_and_preprocess(path, config, mean_pose, std_pose):

    motion3d = np.load(path)

    # length must be multiples of 8 due to the size of convolution
    _, _, T = motion3d.shape
    T = (T // 8) * 8
    motion3d = motion3d[:, :, :T]

    # project to 2d
    motion_proj = motion3d[:, [0, 2], :]

    # reformat for mixamo data
    motion_proj = preprocess_mixamo(motion_proj, unit=1.0)

    # preprocess for network input
    motion_proj, start = preprocess_test(motion_proj, mean_pose, std_pose, config.data.unit)
    motion_proj = motion_proj.reshape((-1, motion_proj.shape[-1]))
    motion_proj = torch.from_numpy(motion_proj).float()

    return motion_proj, start

def handle_motion_generation(npy1, npy2):
    path1 = './data/a.npy'
    path2 = './data/b.npy'
    np.save(path1, npy1)
    np.save(path2, npy2)
    config_path = './configs/transmomo.yaml'  # replace with your config file path
    description_path = "./data/mse_description.json"
    checkpoint_path = './data/autoencoder_00200000.pt'
    out_dir_path = './output'  # replace with your output directory path

    config = get_config(config_path)
    ae = get_autoencoder(config)
    ae.load_state_dict(torch.load(checkpoint_path))
    ae.cuda()
    ae.eval()
    mean_pose, std_pose = get_meanpose("test", config.data)
    # print("loaded model")

    description = json.load(open(description_path))
    chars = list(description.keys())

    os.makedirs(out_dir_path, exist_ok=True)

    # path1 = '/home/fazhong/studio/transmomo.pytorch/data/mixamo/36_800_24/test/PUMPKINHULK_L/Back_Squat/motions/2.npy'
    # path2 = '/home/fazhong/studio/transmomo.pytorch/data/mixamo/36_800_24/test/PUMPKINHULK_L/Golf_Post_Shot/motions/3.npy'
    out_path1 = os.path.join(out_dir_path, "adv.npy")

    x_a, x_a_start = load_and_preprocess(path1, config, mean_pose, std_pose)
    x_b, x_b_start = load_and_preprocess(path2, config, mean_pose, std_pose)

    x_a_batch = x_a.unsqueeze(0).cuda()
    x_b_batch = x_b.unsqueeze(0).cuda()

    x_ab = ae.cross2d(x_a_batch, x_b_batch, x_a_batch)
    x_ab = postprocess(x_ab, mean_pose, std_pose, config.data.unit, start=x_a_start)

    np.save(out_path1, x_ab)
    motion_data = x_ab
    height = 512  # video height
    width = 512  # video width
    save_path = './an.mp4'  # path to save the video
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]  # joint colors
    bg_color = (255, 255, 255)  # background color
    fps = 25  # video frame rate

    # print(motion_data.shape)
    # call the function to generate the video
    motion2video(motion_data, height, width, save_path, colors, bg_color=bg_color, transparency=False, fps=fps)
    first_frame_image = Image.open('./an-frames/0000.png')
    return first_frame_image
    # print('hi')

with gr.Blocks() as demo:
    gr.Markdown("Upload two `.npy` files to generate motion and visualize the first frame of the output animation.")

    with gr.Row():
        file1 = gr.File(file_types=[".npy"], label="Upload first .npy file")
        file2 = gr.File(file_types=[".npy"], label="Upload second .npy file")

    with gr.Row():
        generate_btn = gr.Button("Generate Motion")

    output_image = gr.Image(label="First Frame of the Generated Animation")

    generate_btn.click(
        fn=handle_motion_generation,
        inputs=[file1, file2],
        outputs=output_image
    )


if __name__ == "__main__":
    # tsp_page.launch(debug = True)
    demo.launch()
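Note (editorial addition, not part of the commit): load_and_preprocess above reads each uploaded .npy as a (joints, 3, frames) array, truncates the frame count to a multiple of 8, and projects onto the x/z plane before normalization. A minimal sketch of writing a file with that layout, using placeholder dimensions; real inputs must follow the Mixamo joint ordering expected by lib.util.motion.preprocess_mixamo.

    # Sketch only: writes an .npy in the layout load_and_preprocess reads.
    import numpy as np

    joints, frames = 31, 128                           # placeholder sizes, not prescribed by the commit
    motion3d = np.zeros((joints, 3, frames), dtype=np.float32)
    np.save("./data/example_motion.npy", motion3d)     # hypothetical path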
configs/transmomo.yaml
ADDED
@@ -0,0 +1,100 @@
trainer: TransmomoTrainer
K: 3
rotation_axes: &rotation_axes [0, 0, 1] # horizontal, depth, vertical
body_reference: &body_reference True # if set True, will use spine as vertical axis

# model options
n_joints: 15 # number of body joints
seq_len: 64 # length of motion sequence

# logger options
snapshot_save_iter: 20000
log_iter: 40
val_iter: 400
val_batches: 10

# optimization options
max_iter: 200000 # maximum number of training iterations
batch_size: 64 # batch size
weight_decay: 0.0001 # weight decay
beta1: 0.5 # Adam parameter
beta2: 0.999 # Adam parameter
init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal]
lr: 0.0002 # initial learning rate
lr_policy: step # learning rate scheduler
step_size: 20000 # how often to decay learning rate
gamma: 0.5 # how much to decay learning rate

trans_gan_w: 2 # weight of GAN loss
trans_gan_ls_w: 0 # if set > 0, will treat limb-scaled data as "real" data
recon_x_w: 10 # weight of reconstruction loss
cross_x_w: 4 # weight of cross reconstruction loss
inv_v_ls_w: 2 # weight of view invariance loss against limb scale
inv_m_ls_w: 2 # weight of motion invariance loss against limb scale
inv_b_trans_w: 2 # weight of body invariance loss against rotation
inv_m_trans_w: 2 # weight of motion invariance loss against rotation

triplet_b_w: 10 # weight of body triplet loss
triplet_v_w: 10 # weight of view triplet loss
triplet_margin: 0.2 # triplet loss: margin
triplet_neg_range: [0.0, 0.5] # triplet loss: range of negative examples

# network options
autoencoder:
  cls: Autoencoder3f
  body_reference: *body_reference
  motion_encoder:
    cls: ConvEncoder
    channels: [30, 64, 128, 128]
    padding: 3
    kernel_size: 8
    conv_stride: 2
    conv_pool: null
  body_encoder:
    cls: ConvEncoder
    channels: [28, 64, 128, 256]
    padding: 2
    kernel_size: 7
    conv_stride: 1
    conv_pool: AvgPool1d
    global_pool: avg_pool1d
  view_encoder:
    cls: ConvEncoder
    channels: [28, 64, 32, 8]
    padding: 2
    kernel_size: 7
    conv_stride: 1
    conv_pool: MaxPool1d
    global_pool: max_pool1d
  decoder:
    channels: [392, 256, 128, 45]
    kernel_size: 7

discriminator:
  encoder_cls: ConvEncoder
  gan_type: lsgan
  channels: [30, 64, 96, 128]
  padding: 3
  kernel_size: 8
  conv_stride: 2
  conv_pool: null

body_discriminator:
  gan_type: lsgan
  channels: [512, 128, 32]

# data options
data:
  train_cls: MixamoLimbScaleDataset
  eval_cls: MixamoDataset
  global_range: [0.5, 2.0] # limb scale: range of gamma_g
  local_range: [0.5, 2.0] # limb scale: range of the gammas
  rotation_axes: *rotation_axes
  unit: 128
  # train_dir: ./data/mixamo/36_800_24/train
  # test_dir: ./data/mixamo/36_800_24/test
  num_workers: 4
  train_meanpose_path: ./data/meanpose_with_view.npy
  train_stdpose_path: ./data/stdpose_with_view.npy
  test_meanpose_path: ./data/meanpose_with_view.npy
  test_stdpose_path: ./data/stdpose_with_view.npy
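Note (editorial addition, not part of the commit): elsewhere in this commit the config is read through lib.util.general.get_config and then treated as an attribute-accessible mapping, and lib.network.get_autoencoder instantiates the class named under autoencoder.cls. A minimal sketch of that wiring:

    # Sketch only: how this YAML is consumed by the rest of the repo.
    from lib.util.general import get_config
    from lib.network import get_autoencoder

    config = get_config("./configs/transmomo.yaml")
    print(config.autoencoder.cls)    # "Autoencoder3f"
    ae = get_autoencoder(config)     # builds the autoencoder described above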
configs/transmomo_solo_dance.yaml
ADDED
@@ -0,0 +1,100 @@
trainer: TransmomoTrainer
K: 3
rotation_axes: &rotation_axes [0, 0, 1] # horizontal, depth, vertical
body_reference: &body_reference True # if set True, will use spine as vertical axis

# model options
n_joints: 15 # number of body joints
seq_len: 64 # length of motion sequence

# logger options
snapshot_save_iter: 20000
log_iter: 40
val_iter: 400
val_batches: 10

# optimization options
max_iter: 200000 # maximum number of training iterations
batch_size: 64 # batch size
weight_decay: 0.0001 # weight decay
beta1: 0.5 # Adam parameter
beta2: 0.999 # Adam parameter
init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal]
lr: 0.0002 # initial learning rate
lr_policy: step # learning rate scheduler
step_size: 20000 # how often to decay learning rate
gamma: 0.5 # how much to decay learning rate

trans_gan_w: 2 # weight of GAN loss
trans_gan_ls_w: 0 # if set > 0, will treat limb-scaled data as "real" data
recon_x_w: 10 # weight of reconstruction loss
cross_x_w: 4 # weight of cross reconstruction loss
inv_v_ls_w: 2 # weight of view invariance loss against limb scale
inv_m_ls_w: 2 # weight of motion invariance loss against limb scale
inv_b_trans_w: 2 # weight of body invariance loss against rotation
inv_m_trans_w: 2 # weight of motion invariance loss against rotation

triplet_b_w: 10 # weight of body triplet loss
triplet_v_w: 10 # weight of view triplet loss
triplet_margin: 0.2 # triplet loss: margin
triplet_neg_range: [0.0, 0.5] # triplet loss: range of negative examples

# network options
autoencoder:
  cls: Autoencoder3f
  body_reference: *body_reference
  motion_encoder:
    cls: ConvEncoder
    channels: [30, 64, 128, 128]
    padding: 3
    kernel_size: 8
    conv_stride: 2
    conv_pool: null
  body_encoder:
    cls: ConvEncoder
    channels: [28, 64, 128, 256]
    padding: 2
    kernel_size: 7
    conv_stride: 1
    conv_pool: AvgPool1d
    global_pool: avg_pool1d
  view_encoder:
    cls: ConvEncoder
    channels: [28, 64, 32, 8]
    padding: 2
    kernel_size: 7
    conv_stride: 1
    conv_pool: MaxPool1d
    global_pool: max_pool1d
  decoder:
    channels: [392, 256, 128, 45]
    kernel_size: 7

discriminator:
  encoder_cls: ConvEncoder
  gan_type: lsgan
  channels: [30, 64, 96, 128]
  padding: 3
  kernel_size: 8
  conv_stride: 2
  conv_pool: null

body_discriminator:
  gan_type: lsgan
  channels: [512, 128, 32]

# data options
data:
  train_cls: SoloDanceDataset
  eval_cls: MixamoDataset
  global_range: [0.5, 2.0] # limb scale: range of gamma_g
  local_range: [0.5, 2.0] # limb scale: range of the gammas
  rotation_axes: *rotation_axes
  unit: 128
  train_dir: ./data/solo_dance/train
  test_dir: ./data/mixamo/36_800_24/test
  num_workers: 4
  train_meanpose_path: ./data/mixamo/36_800_24/meanpose_with_view.npy
  train_stdpose_path: ./data/mixamo/36_800_24/stdpose_with_view.npy
  test_meanpose_path: ./data/mixamo/36_800_24/meanpose_with_view.npy
  test_stdpose_path: ./data/mixamo/36_800_24/stdpose_with_view.npy
data/meanpose_with_view.npy
ADDED
Binary file (488 Bytes)
data/mse_description.json
ADDED
@@ -0,0 +1 @@
{"ANDROMEDA": ["Goalkeeper_Directing_(1)", "Standing_Aim_Idle_02_Looking", "Pilot_Flips_Switches_(1)", "Running_Tired", "Slide_Hip_Hop_Dance", "Military_Signaling_(3)", "Aim_Pistol", "Standing_Idle_(1)", "Drinking_Fountain", "Golf_Putt_Failure", "Standing_Torch_Burn_Webs", "Back_Squat", "Baseball_Pitching", "Dancing_Twerk", "Superhuman_Choke_Lift", "Zombie_Crawl"], "PUMPKINHULK_L": ["Front_Raises", "Rifle_Idle", "Defender", "Talking_Phone_Pacing", "Sitting_Clap_(2)", "Standing_Clap", "Samba_Dancing_(6)", "Golf_Bad_Shot_(1)", "Golf_Putt_Victory", "Grab_Rifle_And_Put_Back", "Salsa_Dancing_(4)", "Military_Signaling_(2)", "Robot_Hip_Hop_Dance", "Speedbag", "Jog_In_Circle", "Looking_Around"], "SPORTY_GRANY": ["Standing_Torch_Idle_04", "Look_Around_(1)", "Happy", "Standing_Torch_Inspect_Downward", "Quarterback_Pass", "Zombie_Stand_Up_(2)", "Struck_In_Head", "Shooting_Gun", "Zombie_Transition", "Hostage_Situation_Idle_-_Hostage", "Jazz_Dancing_(2)", "Samba_Dancing_(5)", "Knocked_Out_(1)", "Being_Electrocuted", "Falling_From_Losing_Balance", "Pulling_A_Rope"], "TY": ["Golf_Pre-Putt_(1)", "Back_Flip_To_Uppercut", "Standing_Torch_Idle_03", "Tonic_Seizure", "Talking_At_Watercooler", "Sitting_Clap", "Shuffling", "Tender_Placement_(1)", "Helping_Out", "Northern_Soul_Spin_Combo", "Golf_Post_Shot", "Salsa_Dancing_(3)", "Sword_And_Shield_Idle_(1)", "Hip_Hop_Dancing_(1)", "Golf_Tee_Up_(1)"]}
data/stdpose_with_view.npy
ADDED
Binary file (488 Bytes)
lib/__init__.py
ADDED
File without changes
lib/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (145 Bytes)
lib/__pycache__/data.cpython-38.pyc
ADDED
Binary file (12.6 kB)
lib/__pycache__/network.cpython-38.pyc
ADDED
Binary file (9.88 kB)
lib/__pycache__/operation.cpython-38.pyc
ADDED
Binary file (6.16 kB)
lib/data.py
ADDED
@@ -0,0 +1,421 @@
import sys, os
thismodule = sys.modules[__name__]

from lib.util.motion import preprocess_mixamo, rotate_motion_3d, limb_scale_motion_2d, normalize_motion, get_change_of_basis, localize_motion, scale_limbs

import torch
import glob
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from easydict import EasyDict as edict
from tqdm import tqdm

view_angles = np.array([ i * np.pi / 6 for i in range(-3, 4)])

def get_dataloader(phase, config):

    config.data.batch_size = config.batch_size
    config.data.seq_len = config.seq_len
    dataset_cls_name = config.data.train_cls if phase == 'train' else config.data.eval_cls
    dataset_cls = getattr(thismodule, dataset_cls_name)
    dataset = dataset_cls(phase, config.data)

    dataloader = DataLoader(dataset, shuffle=(phase=='train'),
                            batch_size=config.batch_size,
                            num_workers=(config.data.num_workers if phase == 'train' else 1),
                            worker_init_fn=lambda _: np.random.seed(),
                            drop_last=True)

    return dataloader


class _MixamoDatasetBase(Dataset):
    def __init__(self, phase, config):
        super(_MixamoDatasetBase, self).__init__()

        assert phase in ['train', 'test']
        self.phase = phase
        self.data_root = config.train_dir if phase=='train' else config.test_dir
        self.meanpose_path = config.train_meanpose_path if phase=='train' else config.test_meanpose_path
        self.stdpose_path = config.train_stdpose_path if phase=='train' else config.test_stdpose_path
        self.unit = config.unit
        self.aug = (phase == 'train')
        self.character_names = sorted(os.listdir(self.data_root))

        items = glob.glob(os.path.join(self.data_root, self.character_names[0], '*/motions/*.npy'))
        self.motion_names = ['/'.join(x.split('/')[-3:]) for x in items]

        self.meanpose, self.stdpose = get_meanpose(phase, config)
        self.meanpose = self.meanpose.astype(np.float32)
        self.stdpose = self.stdpose.astype(np.float32)

        if 'preload' in config and config.preload:
            self.preload()
            self.cached = True
        else:
            self.cached = False

    def build_item(self, mot_name, char_name):
        """
        :param mot_name: animation_name/motions/xxx.npy
        :param char_name: character_name
        :return:
        """
        return os.path.join(self.data_root, char_name, mot_name)

    def load_item(self, item):
        if self.cached:
            data = self.cache[item]
        else:
            data = np.load(item)
        return data

    def preload(self):
        print("pre-loading into memory")
        pbar = tqdm(total=len(self))
        self.cache = {}
        for motion_name in self.motion_names:
            for character_name in self.character_names:
                item = self.build_item(motion_name, character_name)
                motion3d = np.load(item)
                self.cache[item] = motion3d
                pbar.update(1)

    @staticmethod
    def gen_aug_params(rotate=False):
        if rotate:
            params = {'ratio': np.random.uniform(0.8, 1.2),
                      'roll': np.random.uniform((-np.pi / 9, -np.pi / 9, -np.pi / 6), (np.pi / 9, np.pi / 9, np.pi / 6))}
        else:
            params = {'ratio': np.random.uniform(0.5, 1.5)}
        return edict(params)

    @staticmethod
    def augmentation(data, params=None):
        """
        :param data: numpy array of size (joints, 3, len_frames)
        :return:
        """
        if params is None:
            return data, params

        # rotate
        if 'roll' in params.keys():
            cx, cy, cz = np.cos(params.roll)
            sx, sy, sz = np.sin(params.roll)
            mat33_x = np.array([
                [1, 0, 0],
                [0, cx, -sx],
                [0, sx, cx]
            ], dtype='float')
            mat33_y = np.array([
                [cy, 0, sy],
                [0, 1, 0],
                [-sy, 0, cy]
            ], dtype='float')
            mat33_z = np.array([
                [cz, -sz, 0],
                [sz, cz, 0],
                [0, 0, 1]
            ], dtype='float')
            data = mat33_x @ mat33_y @ mat33_z @ data

        # scale
        if 'ratio' in params.keys():
            data = data * params.ratio

        return data, params

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        return len(self.motion_names) * len(self.character_names)


def get_meanpose(phase, config):

    meanpose_path = config.train_meanpose_path if phase == "train" else config.test_meanpose_path
    stdpose_path = config.train_stdpose_path if phase == "train" else config.test_stdpose_path

    if os.path.exists(meanpose_path) and os.path.exists(stdpose_path):
        meanpose = np.load(meanpose_path)
        stdpose = np.load(stdpose_path)
    else:
        meanpose, stdpose = gen_meanpose(phase, config)
        np.save(meanpose_path, meanpose)
        np.save(stdpose_path, stdpose)
        print("meanpose saved at {}".format(meanpose_path))
        print("stdpose saved at {}".format(stdpose_path))

    if meanpose.shape[-1] == 2:
        mean_x, mean_y = meanpose[:, 0], meanpose[:, 1]
        meanpose = np.stack([mean_x, mean_x, mean_y], axis=1)

    if stdpose.shape[-1] == 2:
        std_x, std_y = stdpose[:, 0], stdpose[:, 1]
        stdpose = np.stack([std_x, std_x, std_y], axis=1)

    return meanpose, stdpose


def gen_meanpose(phase, config, n_samp=20000):

    data_dir = config.train_dir if phase == "train" else config.test_dir
    all_paths = glob.glob(os.path.join(data_dir, '*/*/motions/*.npy'))
    random.shuffle(all_paths)
    all_paths = all_paths[:n_samp]
    all_joints = []

    print("computing meanpose and stdpose")

    for path in tqdm(all_paths):
        motion = np.load(path)
        if motion.shape[1] == 3:
            basis = None
            if sum(config.rotation_axes) > 0:
                x_angles = view_angles if config.rotation_axes[0] else np.array([0])
                z_angles = view_angles if config.rotation_axes[1] else np.array([0])
                y_angles = view_angles if config.rotation_axes[2] else np.array([0])
                x_angles, z_angles, y_angles = np.meshgrid(x_angles, z_angles, y_angles)
                angles = np.stack([x_angles.flatten(), z_angles.flatten(), y_angles.flatten()], axis=1)
                i = np.random.choice(len(angles))
                basis = get_change_of_basis(motion, angles[i])
                motion = preprocess_mixamo(motion)
                motion = rotate_motion_3d(motion, basis)
                motion = localize_motion(motion)
                all_joints.append(motion)
            else:
                motion = preprocess_mixamo(motion)
                motion = rotate_motion_3d(motion, basis)
                motion = localize_motion(motion)
                all_joints.append(motion)
        else:
            motion = motion * 128
            motion_proj = localize_motion(motion)
            all_joints.append(motion_proj)

    all_joints = np.concatenate(all_joints, axis=2)

    meanpose = np.mean(all_joints, axis=2)
    stdpose = np.std(all_joints, axis=2)
    stdpose[np.where(stdpose == 0)] = 1e-9

    return meanpose, stdpose


class MixamoDataset(_MixamoDatasetBase):

    def __init__(self, phase, config):
        super(MixamoDataset, self).__init__(phase, config)
        x_angles = view_angles if config.rotation_axes[0] else np.array([0])
        z_angles = view_angles if config.rotation_axes[1] else np.array([0])
        y_angles = view_angles if config.rotation_axes[2] else np.array([0])
        x_angles, z_angles, y_angles = np.meshgrid(x_angles, z_angles, y_angles)
        angles = np.stack([x_angles.flatten(), z_angles.flatten(), y_angles.flatten()], axis=1)
        self.view_angles = angles

    def preprocessing(self, motion3d, view_angle=None, params=None):
        """
        :param item: filename built from self.build_item
        :return:
        """

        if self.aug: motion3d, params = self.augmentation(motion3d, params)

        basis = None
        if view_angle is not None: basis = get_change_of_basis(motion3d, view_angle)

        motion3d = preprocess_mixamo(motion3d)
        motion3d = rotate_motion_3d(motion3d, basis)
        motion3d = localize_motion(motion3d)
        motion3d = normalize_motion(motion3d, self.meanpose, self.stdpose)

        motion2d = motion3d[:, [0, 2], :]

        motion3d = motion3d.reshape([-1, motion3d.shape[-1]])
        motion2d = motion2d.reshape([-1, motion2d.shape[-1]])

        motion3d = torch.from_numpy(motion3d).float()
        motion2d = torch.from_numpy(motion2d).float()

        return motion3d, motion2d

    def __getitem__(self, index):
        # select two motions
        idx_a, idx_b = np.random.choice(len(self.motion_names), size=2, replace=False)
        mot_a, mot_b = self.motion_names[idx_a], self.motion_names[idx_b]
        # select two characters
        idx_a, idx_b = np.random.choice(len(self.character_names), size=2, replace=False)
        char_a, char_b = self.character_names[idx_a], self.character_names[idx_b]
        idx_a, idx_b = np.random.choice(len(self.view_angles), size=2, replace=False)
        view_a, view_b = self.view_angles[idx_a], self.view_angles[idx_b]

        if self.aug:
            param_a = self.gen_aug_params(rotate=False)
            param_b = self.gen_aug_params(rotate=False)
        else:
            param_a = param_b = None

        item_a = self.load_item(self.build_item(mot_a, char_a))
        item_b = self.load_item(self.build_item(mot_b, char_b))
        item_ab = self.load_item(self.build_item(mot_a, char_b))
        item_ba = self.load_item(self.build_item(mot_b, char_a))

        X_a, x_a = self.preprocessing(item_a, view_a, param_a)
        X_b, x_b = self.preprocessing(item_b, view_b, param_b)

        X_aab, x_aab = self.preprocessing(item_a, view_b, param_a)
        X_bba, x_bba = self.preprocessing(item_b, view_a, param_b)
        X_aba, x_aba = self.preprocessing(item_ab, view_a, param_b)
        X_bab, x_bab = self.preprocessing(item_ba, view_b, param_a)
        X_abb, x_abb = self.preprocessing(item_ab, view_b, param_b)
        X_baa, x_baa = self.preprocessing(item_ba, view_a, param_a)

        return {"X_a": X_a, "X_b": X_b,
                "X_aab": X_aab, "X_bba": X_bba,
                "X_aba": X_aba, "X_bab": X_bab,
                "X_abb": X_abb, "X_baa": X_baa,
                "x_a": x_a, "x_b": x_b,
                "x_aab": x_aab, "x_bba": x_bba,
                "x_aba": x_aba, "x_bab": x_bab,
                "x_abb": x_abb, "x_baa": x_baa,
                "mot_a": mot_a, "mot_b": mot_b,
                "char_a": char_a, "char_b": char_b,
                "view_a": view_a, "view_b": view_b,
                "meanpose": self.meanpose, "stdpose": self.stdpose}


class MixamoLimbScaleDataset(_MixamoDatasetBase):

    def __init__(self, phase, config):
        super(MixamoLimbScaleDataset, self).__init__(phase, config)
        self.global_range = config.global_range
        self.local_range = config.local_range

        x_angles = view_angles if config.rotation_axes[0] else np.array([0])
        z_angles = view_angles if config.rotation_axes[1] else np.array([0])
        y_angles = view_angles if config.rotation_axes[2] else np.array([0])
        x_angles, z_angles, y_angles = np.meshgrid(x_angles, z_angles, y_angles)
        angles = np.stack([x_angles.flatten(), z_angles.flatten(), y_angles.flatten()], axis=1)
        self.view_angles = angles

    def preprocessing(self, motion3d, view_angle=None, params=None):
        if self.aug: motion3d, params = self.augmentation(motion3d, params)

        basis = None
        if view_angle is not None: basis = get_change_of_basis(motion3d, view_angle)

        motion3d = preprocess_mixamo(motion3d)
        motion3d = rotate_motion_3d(motion3d, basis)
        motion2d = motion3d[:, [0, 2], :]
        motion2d_scale = limb_scale_motion_2d(motion2d, self.global_range, self.local_range)

        motion2d = localize_motion(motion2d)
        motion2d_scale = localize_motion(motion2d_scale)

        motion2d = normalize_motion(motion2d, self.meanpose, self.stdpose)
        motion2d_scale = normalize_motion(motion2d_scale, self.meanpose, self.stdpose)

        motion2d = motion2d.reshape([-1, motion2d.shape[-1]])
        motion2d_scale = motion2d_scale.reshape((-1, motion2d_scale.shape[-1]))
        motion2d = torch.from_numpy(motion2d).float()
        motion2d_scale = torch.from_numpy(motion2d_scale).float()

        return motion2d, motion2d_scale

    def __getitem__(self, index):
        # select two motions
        motion_idx = np.random.choice(len(self.motion_names))
        motion = self.motion_names[motion_idx]
        # select two characters
        char_idx = np.random.choice(len(self.character_names))
        character = self.character_names[char_idx]
        view_idx = np.random.choice(len(self.view_angles))
        view = self.view_angles[view_idx]

        if self.aug:
            param = self.gen_aug_params(rotate=True)
        else:
            param = None

        item = self.build_item(motion, character)

        x, x_s = self.preprocessing(self.load_item(item), view, param)

        return {"x": x, "x_s": x_s, "mot": motion, "char": character, "view": view,
                "meanpose": self.meanpose, "stdpose": self.stdpose}


class SoloDanceDataset(Dataset):

    def __init__(self, phase, config):
        super(SoloDanceDataset, self).__init__()
        self.global_range = config.global_range
        self.local_range = config.local_range

        assert phase in ['train', 'test']
        self.data_root = config.train_dir if phase=='train' else config.test_dir
        self.phase = phase
        self.unit = config.unit
        self.meanpose_path = config.train_meanpose_path if phase == 'train' else config.test_meanpose_path
        self.stdpose_path = config.train_stdpose_path if phase == 'train' else config.test_stdpose_path
        self.character_names = sorted(os.listdir(self.data_root))

        self.items = glob.glob(os.path.join(self.data_root, '*/*/motions/*.npy'))
        self.meanpose, self.stdpose = get_meanpose(phase, config)
        self.meanpose = self.meanpose.astype(np.float32)
        self.stdpose = self.stdpose.astype(np.float32)

        if 'preload' in config and config.preload:
            self.preload()
            self.cached = True
        else:
            self.cached = False

    def load_item(self, item):
        if self.cached:
            data = self.cache[item]
        else:
            data = np.load(item)
        return data

    def preload(self):
        print("pre-loading into memory")
        pbar = tqdm(total=len(self))
        self.cache = {}
        for item in self.items:
            motion = np.load(item)
            self.cache[item] = motion
            pbar.update(1)

    def preprocessing(self, motion):

        motion = motion * self.unit

        motion[1, :, :] = (motion[2, :, :] + motion[5, :, :]) / 2
        motion[8, :, :] = (motion[9, :, :] + motion[12, :, :]) / 2

        global_scale = self.global_range[0] + np.random.random() * (self.global_range[1] - self.global_range[0])
        local_scales = self.local_range[0] + np.random.random([8]) * (self.local_range[1] - self.local_range[0])
        motion_scale = scale_limbs(motion, global_scale, local_scales)

        motion = localize_motion(motion)
        motion_scale = localize_motion(motion_scale)
        motion = normalize_motion(motion, self.meanpose, self.stdpose)
        motion_scale = normalize_motion(motion_scale, self.meanpose, self.stdpose)
        motion = motion.reshape((-1, motion.shape[-1]))
        motion_scale = motion_scale.reshape((-1, motion_scale.shape[-1]))
        motion = torch.from_numpy(motion).float()
        motion_scale = torch.from_numpy(motion_scale).float()
        return motion, motion_scale

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        item = self.items[index]
        motion = self.load_item(item)
        x, x_s = self.preprocessing(motion)
        return {"x": x, "x_s": x_s, "meanpose": self.meanpose, "stdpose": self.stdpose}
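Note (editorial addition, not part of the commit): the normalization statistics used throughout lib/data.py come from get_meanpose, which loads the .npy files bundled under ./data (or regenerates them with gen_meanpose when they are missing). A minimal sketch, mirroring the call made in app.py:

    # Sketch only: loading the bundled mean/std pose used for normalization.
    from lib.data import get_meanpose
    from lib.util.general import get_config

    config = get_config("./configs/transmomo.yaml")
    mean_pose, std_pose = get_meanpose("test", config.data)   # reads ./data/*_with_view.npy
    print(mean_pose.shape, std_pose.shape)                    # per-joint statistics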
lib/loss.py
ADDED
@@ -0,0 +1,57 @@
import torch
import torch.nn.functional as F


def kl_loss(code):
    return torch.mean(torch.pow(code, 2))


def pairwise_cosine_similarity(seqs_i, seqs_j):
    # seqs_i, seqs_j: [batch, statics, channel]
    n_statics = seqs_i.size(1)
    seqs_i_exp = seqs_i.unsqueeze(2).repeat(1, 1, n_statics, 1)
    seqs_j_exp = seqs_j.unsqueeze(1).repeat(1, n_statics, 1, 1)
    return F.cosine_similarity(seqs_i_exp, seqs_j_exp, dim=3)


def temporal_pairwise_cosine_similarity(seqs_i, seqs_j):
    # seqs_i, seqs_j: [batch, channel, time]
    seq_len = seqs_i.size(2)
    seqs_i_exp = seqs_i.unsqueeze(3).repeat(1, 1, 1, seq_len)
    seqs_j_exp = seqs_j.unsqueeze(2).repeat(1, 1, seq_len, 1)
    return F.cosine_similarity(seqs_i_exp, seqs_j_exp, dim=1)


def consecutive_cosine_similarity(seqs):
    # seqs: [batch, channel, time]
    seqs_roll = seqs.roll(shifts=1, dims=2)[1:]
    seqs = seqs[:-1]
    return F.cosine_similarity(seqs, seqs_roll)


def triplet_margin_loss(seqs_a, seqs_b, neg_range=(0.0, 0.5), margin=0.2):
    # seqs_a, seqs_b: [batch, channel, time]

    neg_start, neg_end = neg_range
    batch_size, _, seq_len = seqs_a.size()
    n_neg_all = seq_len ** 2
    n_neg = int(round(neg_end * n_neg_all))
    n_neg_discard = int(round(neg_start * n_neg_all))

    batch_size, _, seq_len = seqs_a.size()
    sim_aa = temporal_pairwise_cosine_similarity(seqs_a, seqs_a)
    sim_bb = temporal_pairwise_cosine_similarity(seqs_b, seqs_b)
    sim_ab = temporal_pairwise_cosine_similarity(seqs_a, seqs_b)
    sim_ba = sim_ab.transpose(1, 2)

    diff_ab = (sim_ab - sim_aa).reshape(batch_size, -1)
    diff_ba = (sim_ba - sim_bb).reshape(batch_size, -1)
    diff = torch.cat([diff_ab, diff_ba], dim=0)
    diff, _ = diff.topk(n_neg, dim=-1, sorted=True)
    diff = diff[:, n_neg_discard:]

    loss = diff + margin
    loss = loss.clamp(min=0.)
    loss = loss.mean()

    return loss
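Note (editorial addition, not part of the commit): triplet_margin_loss compares within-sequence and cross-sequence cosine-similarity maps over time and penalizes the hardest negatives inside neg_range. A quick shape check with placeholder sizes:

    # Sketch only: exercises triplet_margin_loss on random data.
    import torch
    from lib.loss import triplet_margin_loss

    seqs_a = torch.randn(4, 128, 16)   # [batch, channel, time], placeholder sizes
    seqs_b = torch.randn(4, 128, 16)
    loss = triplet_margin_loss(seqs_a, seqs_b, neg_range=(0.0, 0.5), margin=0.2)
    print(loss.item())                 # scalar loss value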
lib/network.py
ADDED
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
thismodule = sys.modules[__name__]
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import random
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import numpy as np
|
10 |
+
torch.manual_seed(123)
|
11 |
+
|
12 |
+
|
13 |
+
def get_autoencoder(config):
|
14 |
+
ae_cls = getattr(thismodule, config.autoencoder.cls)
|
15 |
+
return ae_cls(config.autoencoder)
|
16 |
+
|
17 |
+
|
18 |
+
class ConvEncoder(nn.Module):
|
19 |
+
|
20 |
+
@classmethod
|
21 |
+
def build_from_config(cls, config):
|
22 |
+
conv_pool = None if config.conv_pool is None else getattr(nn, config.conv_pool)
|
23 |
+
encoder = cls(config.channels, config.padding, config.kernel_size, config.conv_stride, conv_pool)
|
24 |
+
return encoder
|
25 |
+
|
26 |
+
def __init__(self, channels, padding=3, kernel_size=8, conv_stride=2, conv_pool=None):
|
27 |
+
super(ConvEncoder, self).__init__()
|
28 |
+
|
29 |
+
self.in_channels = channels[0]
|
30 |
+
|
31 |
+
model = []
|
32 |
+
acti = nn.LeakyReLU(0.2)
|
33 |
+
|
34 |
+
nr_layer = len(channels) - 1
|
35 |
+
|
36 |
+
for i in range(nr_layer):
|
37 |
+
if conv_pool is None:
|
38 |
+
model.append(nn.ReflectionPad1d(padding))
|
39 |
+
model.append(nn.Conv1d(channels[i], channels[i+1], kernel_size=kernel_size, stride=conv_stride))
|
40 |
+
model.append(acti)
|
41 |
+
else:
|
42 |
+
model.append(nn.ReflectionPad1d(padding))
|
43 |
+
model.append(nn.Conv1d(channels[i], channels[i+1], kernel_size=kernel_size, stride=conv_stride))
|
44 |
+
model.append(acti)
|
45 |
+
model.append(conv_pool(kernel_size=2, stride=2))
|
46 |
+
|
47 |
+
self.model = nn.Sequential(*model)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
x = x[:, :self.in_channels, :]
|
51 |
+
x = self.model(x)
|
52 |
+
return x
|
53 |
+
|
54 |
+
|
55 |
+
class ConvDecoder(nn.Module):
|
56 |
+
|
57 |
+
@classmethod
|
58 |
+
def build_from_config(cls, config):
|
59 |
+
decoder = cls(config.channels, config.kernel_size)
|
60 |
+
return decoder
|
61 |
+
|
62 |
+
def __init__(self, channels, kernel_size=7):
|
63 |
+
super(ConvDecoder, self).__init__()
|
64 |
+
|
65 |
+
model = []
|
66 |
+
pad = (kernel_size - 1) // 2
|
67 |
+
acti = nn.LeakyReLU(0.2)
|
68 |
+
|
69 |
+
for i in range(len(channels) - 1):
|
70 |
+
model.append(nn.Upsample(scale_factor=2, mode='nearest'))
|
71 |
+
model.append(nn.ReflectionPad1d(pad))
|
72 |
+
model.append(nn.Conv1d(channels[i], channels[i + 1],
|
73 |
+
kernel_size=kernel_size, stride=1))
|
74 |
+
if i == 0 or i == 1:
|
75 |
+
model.append(nn.Dropout(p=0.2))
|
76 |
+
if not i == len(channels) - 2:
|
77 |
+
model.append(acti) # whether to add tanh a last?
|
78 |
+
#model.append(nn.Dropout(p=0.2))
|
79 |
+
|
80 |
+
self.model = nn.Sequential(*model)
|
81 |
+
|
82 |
+
def forward(self, x):
|
83 |
+
return self.model(x)
|
84 |
+
|
85 |
+
|
86 |
+
class Discriminator(nn.Module):
|
87 |
+
|
88 |
+
def __init__(self, config):
|
89 |
+
super(Discriminator, self).__init__()
|
90 |
+
self.gan_type = config.gan_type
|
91 |
+
encoder_cls = getattr(thismodule, config.encoder_cls)
|
92 |
+
self.encoder = encoder_cls.build_from_config(config)
|
93 |
+
self.linear = nn.Linear(config.channels[-1], 1)
|
94 |
+
|
95 |
+
def forward(self, seqs):
|
96 |
+
|
97 |
+
code_seq = self.encoder(seqs)
|
98 |
+
logits = self.linear(code_seq.permute(0, 2, 1))
|
99 |
+
return logits
|
100 |
+
|
101 |
+
def calc_dis_loss(self, x_gen, x_real):
|
102 |
+
|
103 |
+
fake_logits = self.forward(x_gen)
|
104 |
+
real_logits = self.forward(x_real)
|
105 |
+
|
106 |
+
if self.gan_type == 'lsgan':
|
107 |
+
loss = torch.mean((fake_logits - 0) ** 2) + torch.mean((real_logits - 1) ** 2)
|
108 |
+
elif self.gan_type == 'nsgan':
|
109 |
+
all0 = torch.zeros_like(fake_logits, requires_grad=False)
|
110 |
+
all1 = torch.ones_like(real_logits, requires_grad=False)
|
111 |
+
loss = torch.mean(F.binary_cross_entropy(F.sigmoid(fake_logits), all0) +
|
112 |
+
F.binary_cross_entropy(F.sigmoid(real_logits), all1))
|
113 |
+
else:
|
114 |
+
raise NotImplementedError
|
115 |
+
|
116 |
+
return loss
|
117 |
+
|
118 |
+
def calc_gen_loss(self, x_gen):
|
119 |
+
|
120 |
+
logits = self.forward(x_gen)
|
121 |
+
if self.gan_type == 'lsgan':
|
122 |
+
loss = torch.mean((logits - 1) ** 2)
|
123 |
+
elif self.gan_type == 'nsgan':
|
124 |
+
all1 = torch.ones_like(logits, requires_grad=False)
|
125 |
+
loss = torch.mean(F.binary_cross_entropy(F.sigmoid(logits), all1))
|
126 |
+
else:
|
127 |
+
raise NotImplementedError
|
128 |
+
|
129 |
+
return loss
|
130 |
+
|
131 |
+
|
132 |
+
class Autoencoder3f(nn.Module):
|
133 |
+
|
134 |
+
def __init__(self, config):
|
135 |
+
super(Autoencoder3f, self).__init__()
|
136 |
+
|
137 |
+
assert config.motion_encoder.channels[-1] + config.body_encoder.channels[-1] + \
|
138 |
+
config.view_encoder.channels[-1] == config.decoder.channels[0]
|
139 |
+
|
140 |
+
self.n_joints = config.decoder.channels[-1] // 3
|
141 |
+
self.body_reference = config.body_reference
|
142 |
+
|
143 |
+
motion_cls = getattr(thismodule, config.motion_encoder.cls)
|
144 |
+
body_cls = getattr(thismodule, config.body_encoder.cls)
|
145 |
+
view_cls = getattr(thismodule, config.view_encoder.cls)
|
146 |
+
|
147 |
+
self.motion_encoder = motion_cls.build_from_config(config.motion_encoder)
|
148 |
+
self.body_encoder = body_cls.build_from_config(config.body_encoder)
|
149 |
+
self.view_encoder = view_cls.build_from_config(config.view_encoder)
|
150 |
+
self.decoder = ConvDecoder.build_from_config(config.decoder)
|
151 |
+
|
152 |
+
self.body_pool = getattr(F, config.body_encoder.global_pool) if config.body_encoder.global_pool is not None else None
|
153 |
+
self.view_pool = getattr(F, config.view_encoder.global_pool) if config.view_encoder.global_pool is not None else None
|
154 |
+
|
155 |
+
def forward(self, seqs):
|
156 |
+
return self.reconstruct(seqs)
|
157 |
+
|
158 |
+
def encode_motion(self, seqs):
|
159 |
+
motion_code_seq = self.motion_encoder(seqs)
|
160 |
+
return motion_code_seq
|
161 |
+
|
162 |
+
def encode_body(self, seqs):
|
163 |
+
body_code_seq = self.body_encoder(seqs)
|
164 |
+
kernel_size = body_code_seq.size(-1)
|
165 |
+
body_code = self.body_pool(body_code_seq, kernel_size) if self.body_pool is not None else body_code_seq
|
166 |
+
return body_code, body_code_seq
|
167 |
+
|
168 |
+
def encode_view(self, seqs):
|
169 |
+
view_code_seq = self.view_encoder(seqs)
|
170 |
+
kernel_size = view_code_seq.size(-1)
|
171 |
+
view_code = self.view_pool(view_code_seq, kernel_size) if self.view_pool is not None else view_code_seq
|
172 |
+
return view_code, view_code_seq
|
173 |
+
|
174 |
+
def decode(self, motion_code, body_code, view_code):
|
175 |
+
if body_code.size(-1) == 1:
|
176 |
+
body_code = body_code.repeat(1, 1, motion_code.shape[-1])
|
177 |
+
if view_code.size(-1) == 1:
|
178 |
+
view_code = view_code.repeat(1, 1, motion_code.shape[-1])
|
179 |
+
complete_code = torch.cat([motion_code, body_code, view_code], dim=1)
|
180 |
+
out = self.decoder(complete_code)
|
181 |
+
return out
|
182 |
+
|
183 |
+
def cross3d(self, x_a, x_b, x_c):
|
184 |
+
motion_a = self.encode_motion(x_a)
|
185 |
+
body_b, _ = self.encode_body(x_b)
|
186 |
+
view_c, _ = self.encode_view(x_c)
|
187 |
+
out = self.decode(motion_a, body_b, view_c)
|
188 |
+
return out
|
189 |
+
|
190 |
+
def cross2d(self, x_a, x_b, x_c):
|
191 |
+
motion_a = self.encode_motion(x_a)
|
192 |
+
body_b, _ = self.encode_body(x_b)
|
193 |
+
view_c, _ = self.encode_view(x_c)
|
194 |
+
out = self.decode(motion_a, body_b, view_c)
|
195 |
+
batch_size, channels, seq_len = out.size()
|
196 |
+
n_joints = channels // 3
|
197 |
+
out = out.view(batch_size, n_joints, 3, seq_len)
|
198 |
+
out = out[:, :, [0, 2], :]
|
199 |
+
out = out.view(batch_size, n_joints * 2, seq_len)
|
200 |
+
return out
|
201 |
+
|
202 |
+
def cross2d_adv(self, x_a, x_b, x_c):
|
203 |
+
x_a.cpu()
|
204 |
+
x_a_shape = x_a.shape
|
205 |
+
print(x_a.shape)
|
206 |
+
#motion_a_org = self.encode_motion(x_a)
|
207 |
+
print(x_a)
|
208 |
+
|
209 |
+
|
210 |
+
# The heatmap image is saved as 'tensor_heatmap.png' in the current directory
|
211 |
+
|
212 |
+
# for i in range(0,119):
|
213 |
+
# x_a[0][11][i]+=1
|
214 |
+
|
215 |
+
#x_a[0][7][60]+=0.01
|
216 |
+
|
217 |
+
#motion_a = self.encode_motion(x_a)
|
218 |
+
# print(motion_a.shape)
|
219 |
+
# print(motion_a[0][0]-motion_a_org[0][0])
|
220 |
+
# res = motion_a[0] - motion_a_org[0]
|
221 |
+
# res = res.cpu().detach().numpy()
|
222 |
+
# # Code for plotting the heatmap
|
223 |
+
# plt.figure(figsize=(15, 10))
|
224 |
+
# plt.imshow(res, cmap='hot', interpolation='nearest')
|
225 |
+
# plt.colorbar()
|
226 |
+
# plt.title('Heatmap of the Tensor')
|
227 |
+
|
228 |
+
# # Save the heatmap to a local file
|
229 |
+
# plt.savefig('/home/fazhong/studio/transmomo.pytorch/tensor_heatmap2.png')
|
230 |
+
# plt.close()
|
231 |
+
|
232 |
+
initial_motion_a = self.encode_motion(x_a) # 计算初始的motion_a
|
233 |
+
|
234 |
+
# 定义一个函数来计算motion的变化量
|
235 |
+
def motion_change(motion_a, initial_motion_a):
|
236 |
+
return (motion_a - initial_motion_a).norm()
|
237 |
+
|
238 |
+
# 设置初始的最大变化量为0
|
239 |
+
max_change = 0
|
240 |
+
|
241 |
+
# 扰动次数,可以根据需要更改
|
242 |
+
num_perturbations = 10000
|
243 |
+
init_a = x_a.clone()
|
244 |
+
for _ in range(num_perturbations):
|
245 |
+
# 复制x_a以避免在原始数据上修改
|
246 |
+
x_a_perturbed = x_a.clone().cpu()
|
247 |
+
|
248 |
+
# 选择要扰动的随机点
|
249 |
+
batch_idx, seq_idx, feature_idx = (torch.randint(0, x_a.size(0), (1,)),
|
250 |
+
torch.randint(0, x_a.size(1), (1,)),
|
251 |
+
torch.randint(0, x_a.size(2), (1,)))
|
252 |
+
|
253 |
+
# 在选定点上加上扰动
|
254 |
+
x_a_perturbed[batch_idx, seq_idx, feature_idx] += 10 * torch.randn(1)
|
255 |
+
|
256 |
+
# 计算扰动后的motion_a
|
257 |
+
perturbed_motion_a = self.encode_motion(x_a_perturbed.to('cuda:0'))
|
258 |
+
|
259 |
+
# 计算变化量
|
260 |
+
change = motion_change(perturbed_motion_a, initial_motion_a)
|
261 |
+
|
262 |
+
# 如果变化量大于之前保存的最大变化量,则更新x_a和最大变化量
|
263 |
+
if change > max_change:
|
264 |
+
x_a = x_a_perturbed
|
265 |
+
max_change = change
|
266 |
+
|
267 |
+
# 最后,x_a将是导致最大motion_a变化的扰动版本
|
268 |
+
# max_change是这个变化量
|
269 |
+
# print(max_change)
|
270 |
+
# print(max_change.shape)
|
271 |
+
|
272 |
+
print(x_a_perturbed - init_a.cpu())
|
273 |
+
motion_a = self.encode_motion(x_a_perturbed.to('cuda:0'))
|
274 |
+
# motion_a = self.encode_motion(x_a.to('cuda:0'))
|
275 |
+
body_b, _ = self.encode_body(x_b)
|
276 |
+
view_c, _ = self.encode_view(x_c)
|
277 |
+
|
278 |
+
out = self.decode(motion_a, body_b, view_c)
|
279 |
+
batch_size, channels, seq_len = out.size()
|
280 |
+
n_joints = channels // 3
|
281 |
+
out = out.view(batch_size, n_joints, 3, seq_len)
|
282 |
+
out = out[:, :, [0, 2], :]
|
283 |
+
out = out.view(batch_size, n_joints * 2, seq_len)
|
284 |
+
return out
|
285 |
+
|
286 |
+
def cross2d_one(self, x_a):
|
287 |
+
motion_a = self.encode_motion(x_a)
|
288 |
+
body_b, _ = self.encode_body(x_a)
|
289 |
+
view_c, _ = self.encode_view(x_a)
|
290 |
+
|
291 |
+
out = self.decode(motion_a, body_b, view_c)
|
292 |
+
batch_size, channels, seq_len = out.size()
|
293 |
+
n_joints = channels // 3
|
294 |
+
out = out.view(batch_size, n_joints, 3, seq_len)
|
295 |
+
out = out[:, :, [0, 2], :]
|
296 |
+
out = out.view(batch_size, n_joints * 2, seq_len)
|
297 |
+
return out
|
298 |
+
|
299 |
+
def adv_cross(self,x_a):
|
300 |
+
motion_a = self.encode_motion(x_a)
|
301 |
+
body_b, _ = self.encode_body(x_a)
|
302 |
+
view_c, _ = self.encode_view(x_a)
|
303 |
+
return motion_a
|
304 |
+
|
305 |
+
def reconstruct3d(self, x):
|
306 |
+
motion_code = self.encode_motion(x)
|
307 |
+
body_code, _ = self.encode_body(x)
|
308 |
+
view_code, _ = self.encode_view(x)
|
309 |
+
out = self.decode(motion_code, body_code, view_code)
|
310 |
+
return out
|
311 |
+
|
312 |
+
def reconstruct2d(self, x):
|
313 |
+
motion_code = self.encode_motion(x)
|
314 |
+
body_code, _ = self.encode_body(x)
|
315 |
+
view_code, _ = self.encode_view(x)
|
316 |
+
out = self.decode(motion_code, body_code, view_code)
|
317 |
+
batch_size, channels, seq_len = out.size()
|
318 |
+
n_joints = channels // 3
|
319 |
+
out = out.view(batch_size, n_joints, 3, seq_len)
|
320 |
+
out = out[:, :, [0, 2], :]
|
321 |
+
out = out.view(batch_size, n_joints * 2, seq_len)
|
322 |
+
return out
|
323 |
+
|
324 |
+
def interpolate(self, x_a, x_b, N):
|
325 |
+
|
326 |
+
step_size = 1. / (N-1)
|
327 |
+
batch_size, _, seq_len = x_a.size()
|
328 |
+
|
329 |
+
motion_a = self.encode_motion(x_a)
|
330 |
+
body_a, body_a_seq = self.encode_body(x_a)
|
331 |
+
view_a, view_a_seq = self.encode_view(x_a)
|
332 |
+
|
333 |
+
motion_b = self.encode_motion(x_b)
|
334 |
+
body_b, body_b_seq = self.encode_body(x_b)
|
335 |
+
view_b, view_b_seq = self.encode_view(x_b)
|
336 |
+
|
337 |
+
batch_out = torch.zeros([batch_size, N, N, 2 * self.n_joints, seq_len])
|
338 |
+
|
339 |
+
for i in range(N):
|
340 |
+
motion_weight = i * step_size
|
341 |
+
for j in range(N):
|
342 |
+
body_weight = j * step_size
|
343 |
+
motion = (1. - motion_weight) * motion_a + motion_weight * motion_b
|
344 |
+
body = (1. - body_weight) * body_a + body_weight * body_b
|
345 |
+
view = (1. - body_weight) * view_a + body_weight * view_b
|
346 |
+
out = self.decode(motion, body, view)
|
347 |
+
batch_size, channels, seq_len = out.size()
|
348 |
+
n_joints = channels // 3
|
349 |
+
out = out.view(batch_size, n_joints, 3, seq_len)
|
350 |
+
out = out[:, :, [0, 2], :]
|
351 |
+
out = out.view(batch_size, n_joints * 2, seq_len)
|
352 |
+
batch_out[:, i, j, :, :] = out
|
353 |
+
|
354 |
+
return batch_out
|
355 |
+
|
356 |
+
|
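The 2D variants above (cross2d, cross2d_one, reconstruct2d) all finish with the same projection step: the decoder output of shape (batch, n_joints * 3, seq_len) is reshaped, the middle coordinate is dropped, and the result is flattened back to (batch, n_joints * 2, seq_len). A minimal sketch of that step on a dummy tensor (15 joints and 64 frames are arbitrary values, not tied to the shipped model):

import torch

def project_xz(out, n_dims=3):
    # out: (batch, n_joints * n_dims, seq_len) -> (batch, n_joints * 2, seq_len)
    batch_size, channels, seq_len = out.size()
    n_joints = channels // n_dims
    out = out.view(batch_size, n_joints, n_dims, seq_len)
    out = out[:, :, [0, 2], :]                      # keep dims 0 and 2, drop dim 1
    return out.reshape(batch_size, n_joints * 2, seq_len)

dummy = torch.randn(4, 45, 64)                      # 15 joints x 3 dims, 64 frames
print(project_xz(dummy).shape)                      # torch.Size([4, 30, 64])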
lib/operation.py
ADDED
@@ -0,0 +1,219 @@
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import numpy as np
|
4 |
+
import imageio
|
5 |
+
from math import pi
|
6 |
+
from tqdm import tqdm
|
7 |
+
from lib.data import get_dataloader, get_meanpose
|
8 |
+
from lib.util.general import get_config
|
9 |
+
from lib.util.visualization import motion2video_np, hex2rgb
|
10 |
+
import os
|
11 |
+
|
12 |
+
eps = 1e-16
|
13 |
+
|
14 |
+
|
15 |
+
def localize_motion_torch(motion):
|
16 |
+
"""
|
17 |
+
:param motion: B x J x D x T
|
18 |
+
:return:
|
19 |
+
"""
|
20 |
+
B, J, D, T = motion.size()
|
21 |
+
|
22 |
+
# subtract centers to local coordinates
|
23 |
+
centers = motion[:, 8:9, :, :] # B x 1 x D x T
|
24 |
+
motion = motion - centers
|
25 |
+
|
26 |
+
# adding velocity
|
27 |
+
translation = centers[:, :, :, 1:] - centers[:, :, :, :-1] # B x 1 x D x (T-1)
|
28 |
+
velocity = F.pad(translation, [1, 0], "constant", 0.) # B x 1 x D x T
|
29 |
+
motion = torch.cat([motion[:, :8], motion[:, 9:], velocity], dim=1)
|
30 |
+
|
31 |
+
return motion
|
32 |
+
|
33 |
+
|
34 |
+
def normalize_motion_torch(motion, meanpose, stdpose):
|
35 |
+
"""
|
36 |
+
:param motion: (B, J, D, T)
|
37 |
+
:param meanpose: (J, D)
|
38 |
+
:param stdpose: (J, D)
|
39 |
+
:return:
|
40 |
+
"""
|
41 |
+
B, J, D, T = motion.size()
|
42 |
+
if D == 2 and meanpose.size(1) == 3:
|
43 |
+
meanpose = meanpose[:, [0, 2]]
|
44 |
+
if D == 2 and stdpose.size(1) == 3:
|
45 |
+
stdpose = stdpose[:, [0, 2]]
|
46 |
+
return (motion - meanpose.view(1, J, D, 1)) / stdpose.view(1, J, D, 1)
|
47 |
+
|
48 |
+
|
49 |
+
def normalize_motion_inv_torch(motion, meanpose, stdpose):
|
50 |
+
"""
|
51 |
+
:param motion: (B, J, D, T)
|
52 |
+
:param meanpose: (J, D)
|
53 |
+
:param stdpose: (J, D)
|
54 |
+
:return:
|
55 |
+
"""
|
56 |
+
B, J, D, T = motion.size()
|
57 |
+
if D == 2 and meanpose.size(1) == 3:
|
58 |
+
meanpose = meanpose[:, [0, 2]]
|
59 |
+
if D == 2 and stdpose.size(1) == 3:
|
60 |
+
stdpose = stdpose[:, [0, 2]]
|
61 |
+
return motion * stdpose.view(1, J, D, 1) + meanpose.view(1, J, D, 1)
|
62 |
+
|
63 |
+
|
64 |
+
def globalize_motion_torch(motion):
|
65 |
+
"""
|
66 |
+
:param motion: B x J x D x T
|
67 |
+
:return:
|
68 |
+
"""
|
69 |
+
B, J, D, T = motion.size()
|
70 |
+
|
71 |
+
motion_inv = torch.zeros_like(motion)
|
72 |
+
motion_inv[:, :8] = motion[:, :8]
|
73 |
+
motion_inv[:, 9:] = motion[:, 8:-1]
|
74 |
+
|
75 |
+
velocity = motion[:, -1:, :, :]
|
76 |
+
centers = torch.zeros_like(velocity)
|
77 |
+
displacement = torch.zeros_like(velocity[:, :, :, 0])
|
78 |
+
|
79 |
+
for t in range(T):
|
80 |
+
displacement += velocity[:, :, :, t]
|
81 |
+
centers[:, :, :, t] = displacement
|
82 |
+
|
83 |
+
motion_inv = motion_inv + centers
|
84 |
+
|
85 |
+
return motion_inv
|
86 |
+
|
87 |
+
|
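The frame loop in globalize_motion_torch simply accumulates the per-frame velocity into absolute center positions; an equivalent vectorized form (a sketch for illustration, not a change to the committed code) is a cumulative sum along the time axis:

import torch

velocity = torch.randn(2, 1, 2, 8)                  # hypothetical (B, 1, D, T) velocity track
centers_loop = torch.zeros_like(velocity)
disp = torch.zeros_like(velocity[..., 0])
for t in range(velocity.size(-1)):
    disp = disp + velocity[..., t]
    centers_loop[..., t] = disp

centers_cumsum = velocity.cumsum(dim=-1)            # same result in one call
assert torch.allclose(centers_loop, centers_cumsum)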
88 |
+
def restore_world_space(motion, meanpose, stdpose, n_joints=15):
|
89 |
+
B, C, T = motion.size()
|
90 |
+
motion = motion.view(B, n_joints, C // n_joints, T)
|
91 |
+
motion = normalize_motion_inv_torch(motion, meanpose, stdpose)
|
92 |
+
motion = globalize_motion_torch(motion)
|
93 |
+
return motion
|
94 |
+
|
95 |
+
|
96 |
+
def convert_to_learning_space(motion, meanpose, stdpose):
|
97 |
+
B, J, D, T = motion.size()
|
98 |
+
motion = localize_motion_torch(motion)
|
99 |
+
motion = normalize_motion_torch(motion, meanpose, stdpose)
|
100 |
+
motion = motion.view(B, J*D, T)
|
101 |
+
return motion
|
102 |
+
|
103 |
+
|
104 |
+
# tensor operations for rotating and projecting 3d skeleton sequence
|
105 |
+
|
106 |
+
def get_body_basis(motion_3d):
|
107 |
+
"""
|
108 |
+
Get the unit basis vectors of the body-centric rectangular coordinate frame for the given 3D motion
|
109 |
+
:param motion_3d: 3D motion from 3D joints positions, shape (B, n_joints, 3, seq_len).
|
110 |
+
:param angles: (K, 3), Rotation angles around each axis.
|
111 |
+
:return: basis unit vectors of the rectangular coordinate frame, shape (B, 3, 3).
|
112 |
+
"""
|
113 |
+
B = motion_3d.size(0)
|
114 |
+
|
115 |
+
# 2 RightArm 5 LeftArm 9 RightUpLeg 12 LeftUpLeg
|
116 |
+
horizontal = (motion_3d[:, 2] - motion_3d[:, 5] + motion_3d[:, 9] - motion_3d[:, 12]) / 2 # [B, 3, seq_len]
|
117 |
+
horizontal = horizontal.mean(dim=-1) # [B, 3]
|
118 |
+
horizontal = horizontal / horizontal.norm(dim=-1).unsqueeze(-1) # [B, 3]
|
119 |
+
|
120 |
+
vector_z = torch.tensor([0., 0., 1.], device=motion_3d.device, dtype=motion_3d.dtype).unsqueeze(0).repeat(B, 1) # [B, 3]
|
121 |
+
vector_y = torch.cross(horizontal, vector_z) # [B, 3]
|
122 |
+
vector_y = vector_y / vector_y.norm(dim=-1).unsqueeze(-1)
|
123 |
+
vector_x = torch.cross(vector_y, vector_z)
|
124 |
+
vectors = torch.stack([vector_x, vector_y, vector_z], dim=2) # [B, 3, 3]
|
125 |
+
|
126 |
+
vectors = vectors.detach()
|
127 |
+
|
128 |
+
return vectors
|
129 |
+
|
130 |
+
|
131 |
+
def rotate_basis_euler(basis_vectors, angles):
|
132 |
+
"""
|
133 |
+
Rotate the basis vectors of the rectangular coordinate frame by the given Euler angles.
|
134 |
+
|
135 |
+
:param basis_vectors: [B, 3, 3]
|
136 |
+
:param angles: [B, K, T, 3] Rotation angles around each axis.
|
137 |
+
:return: [B, K, T, 3, 3]
|
138 |
+
"""
|
139 |
+
B, K, T, _ = angles.size()
|
140 |
+
|
141 |
+
cos, sin = torch.cos(angles), torch.sin(angles)
|
142 |
+
cx, cy, cz = cos[:, :, :, 0], cos[:, :, :, 1], cos[:, :, :, 2] # [B, K, T]
|
143 |
+
sx, sy, sz = sin[:, :, :, 0], sin[:, :, :, 1], sin[:, :, :, 2] # [B, K, T]
|
144 |
+
|
145 |
+
x = basis_vectors[:, 0, :] # [B, 3]
|
146 |
+
o = torch.zeros_like(x[:, 0]) # [B]
|
147 |
+
|
148 |
+
x_cpm_0 = torch.stack([o, -x[:, 2], x[:, 1]], dim=1) # [B, 3]
|
149 |
+
x_cpm_1 = torch.stack([x[:, 2], o, -x[:, 0]], dim=1) # [B, 3]
|
150 |
+
x_cpm_2 = torch.stack([-x[:, 1], x[:, 0], o], dim=1) # [B, 3]
|
151 |
+
x_cpm = torch.stack([x_cpm_0, x_cpm_1, x_cpm_2], dim=1) # [B, 3, 3]
|
152 |
+
x_cpm = x_cpm.unsqueeze(1).unsqueeze(2) # [B, 1, 1, 3, 3]
|
153 |
+
|
154 |
+
x = x.unsqueeze(-1) # [B, 3, 1]
|
155 |
+
xx = torch.matmul(x, x.transpose(-1, -2)).unsqueeze(1).unsqueeze(2) # [B, 1, 1, 3, 3]
|
156 |
+
eye = torch.eye(n=3, dtype=basis_vectors.dtype, device=basis_vectors.device)
|
157 |
+
eye = eye.unsqueeze(0).unsqueeze(0).unsqueeze(0) # [1, 1, 1, 3, 3]
|
158 |
+
mat33_x = cx.unsqueeze(-1).unsqueeze(-1) * eye \
|
159 |
+
+ sx.unsqueeze(-1).unsqueeze(-1) * x_cpm \
|
160 |
+
+ (1. - cx).unsqueeze(-1).unsqueeze(-1) * xx # [B, K, T, 3, 3]
|
161 |
+
|
162 |
+
o = torch.zeros_like(cz)
|
163 |
+
i = torch.ones_like(cz)
|
164 |
+
mat33_z_0 = torch.stack([cz, sz, o], dim=3) # [B, K, T, 3]
|
165 |
+
mat33_z_1 = torch.stack([-sz, cz, o], dim=3) # [B, K, T, 3]
|
166 |
+
mat33_z_2 = torch.stack([o, o, i], dim=3) # [B, K, T, 3]
|
167 |
+
mat33_z = torch.stack([mat33_z_0, mat33_z_1, mat33_z_2], dim=3) # [B, K, T, 3, 3]
|
168 |
+
|
169 |
+
basis_vectors = basis_vectors.unsqueeze(1).unsqueeze(2)
|
170 |
+
basis_vectors = basis_vectors @ mat33_x.transpose(-1, -2) @ mat33_z
|
171 |
+
|
172 |
+
|
173 |
+
return basis_vectors
|
174 |
+
|
175 |
+
|
176 |
+
def change_of_basis(motion_3d, basis_vectors=None, project_2d=False):
|
177 |
+
# motion_3d: (B, n_joints, 3, seq_len)
|
178 |
+
# basis_vectors: (B, K, T, 3, 3)
|
179 |
+
|
180 |
+
if basis_vectors is None:
|
181 |
+
motion_proj = motion_3d[:, :, [0, 2], :] # [B, n_joints, 2, seq_len]
|
182 |
+
else:
|
183 |
+
if project_2d: basis_vectors = basis_vectors[:, :, :, [0, 2], :]
|
184 |
+
_, K, seq_len, _, _ = basis_vectors.size()
|
185 |
+
motion_3d = motion_3d.unsqueeze(1).repeat(1, K, 1, 1, 1)
|
186 |
+
motion_3d = motion_3d.permute([0, 1, 4, 3, 2]) # [B, K, J, 3, T] -> [B, K, T, 3, J]
|
187 |
+
motion_proj = basis_vectors @ motion_3d # [B, K, T, 2, 3] @ [B, K, T, 3, J] -> [B, K, T, 2, J]
|
188 |
+
motion_proj = motion_proj.permute([0, 1, 4, 3, 2]) # [B, K, T, 3, J] -> [B, K, J, 3, T]
|
189 |
+
|
190 |
+
return motion_proj
|
191 |
+
|
192 |
+
|
193 |
+
def rotate_and_maybe_project_world(X, angles=None, body_reference=True, project_2d=False):
|
194 |
+
|
195 |
+
out_dim = 2 if project_2d else 3
|
196 |
+
batch_size, n_joints, _, seq_len = X.size()
|
197 |
+
|
198 |
+
if angles is not None:
|
199 |
+
K = angles.size(1)
|
200 |
+
basis_vectors = get_body_basis(X) if body_reference else \
|
201 |
+
torch.eye(3, device=X.device).unsqueeze(0).repeat(batch_size, 1, 1)
|
202 |
+
basis_vectors = rotate_basis_euler(basis_vectors, angles)
|
203 |
+
X_trans = change_of_basis(X, basis_vectors, project_2d=project_2d)
|
204 |
+
X_trans = X_trans.reshape(batch_size * K, n_joints, out_dim, seq_len)
|
205 |
+
else:
|
206 |
+
X_trans = change_of_basis(X, project_2d=project_2d)
|
207 |
+
X_trans = X_trans.reshape(batch_size, n_joints, out_dim, seq_len)
|
208 |
+
|
209 |
+
return X_trans
|
210 |
+
|
211 |
+
|
212 |
+
|
213 |
+
def rotate_and_maybe_project_learning(X, meanpose, stdpose, angles=None, body_reference=True, project_2d=False):
|
214 |
+
batch_size, channels, seq_len = X.size()
|
215 |
+
n_joints = channels // 3
|
216 |
+
X = restore_world_space(X, meanpose, stdpose, n_joints)
|
217 |
+
X = rotate_and_maybe_project_world(X, angles, body_reference, project_2d)
|
218 |
+
X = convert_to_learning_space(X, meanpose, stdpose)
|
219 |
+
return X
|
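rotate_and_maybe_project_learning is the bridge used throughout training: it restores a normalized (B, J*3, T) sequence to world space, applies K view rotations, and optionally projects to 2D. A shape-only sketch, assuming the repository and its requirements are importable; meanpose/stdpose here are dummy statistics, not the released data files:

import torch
from math import pi
from lib.operation import rotate_and_maybe_project_learning

B, J, T, K = 2, 15, 32, 4
X = torch.randn(B, J * 3, T)                        # learning-space 3D motion
meanpose = torch.zeros(J, 3)                        # dummy mean pose
stdpose = torch.ones(J, 3)                          # dummy std pose
angles = torch.rand(1, K, 1, 3) * pi                # K Euler-angle triples, broadcast over batch and time

x2d = rotate_and_maybe_project_learning(X, meanpose, stdpose, angles=angles,
                                         body_reference=True, project_2d=True)
print(x2d.shape)                                    # (B * K, J * 2, T) = (8, 30, 32)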
lib/trainer.py
ADDED
@@ -0,0 +1,298 @@
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import numpy as np
|
5 |
+
import random
|
6 |
+
import lib.network
|
7 |
+
from lib.loss import *
|
8 |
+
from lib.util.general import weights_init, get_model_list, get_scheduler
|
9 |
+
from lib.network import Discriminator
|
10 |
+
from lib.operation import rotate_and_maybe_project_learning
|
11 |
+
|
12 |
+
class BaseTrainer(nn.Module):
|
13 |
+
|
14 |
+
def __init__(self, config):
|
15 |
+
super(BaseTrainer, self).__init__()
|
16 |
+
|
17 |
+
lr = config.lr
|
18 |
+
autoencoder_cls = getattr(lib.network, config.autoencoder.cls)
|
19 |
+
self.autoencoder = autoencoder_cls(config.autoencoder)
|
20 |
+
self.discriminator = Discriminator(config.discriminator)
|
21 |
+
|
22 |
+
# Setup the optimizers
|
23 |
+
beta1 = config.beta1
|
24 |
+
beta2 = config.beta2
|
25 |
+
dis_params = list(self.discriminator.parameters())
|
26 |
+
ae_params = list(self.autoencoder.parameters())
|
27 |
+
self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
|
28 |
+
lr=lr, betas=(beta1, beta2), weight_decay=config.weight_decay)
|
29 |
+
self.ae_opt = torch.optim.Adam([p for p in ae_params if p.requires_grad],
|
30 |
+
lr=lr, betas=(beta1, beta2), weight_decay=config.weight_decay)
|
31 |
+
self.dis_scheduler = get_scheduler(self.dis_opt, config)
|
32 |
+
self.ae_scheduler = get_scheduler(self.ae_opt, config)
|
33 |
+
|
34 |
+
# Network weight initialization
|
35 |
+
self.apply(weights_init(config.init))
|
36 |
+
self.discriminator.apply(weights_init('gaussian'))
|
37 |
+
|
38 |
+
def forward(self, data):
|
39 |
+
x_a, x_b = data["x_a"], data["x_b"]
|
40 |
+
batch_size = x_a.size(0)
|
41 |
+
self.eval()
|
42 |
+
body_a, body_b = self.sample_body_code(batch_size)
|
43 |
+
motion_a = self.autoencoder.encode_motion(x_a)
|
44 |
+
body_a_enc, _ = self.autoencoder.encode_body(x_a)
|
45 |
+
motion_b = self.autoencoder.encode_motion(x_b)
|
46 |
+
body_b_enc, _ = self.autoencoder.encode_body(x_b)
|
47 |
+
x_ab = self.autoencoder.decode(motion_a, body_b)
|
48 |
+
x_ba = self.autoencoder.decode(motion_b, body_a)
|
49 |
+
self.train()
|
50 |
+
return x_ab, x_ba
|
51 |
+
|
52 |
+
def dis_update(self, data, config):
|
53 |
+
raise NotImplementedError
|
54 |
+
|
55 |
+
def ae_update(self, data, config):
|
56 |
+
raise NotImplementedError
|
57 |
+
|
58 |
+
def recon_criterion(self, input, target):
|
59 |
+
raise NotImplementedError
|
60 |
+
|
61 |
+
def update_learning_rate(self):
|
62 |
+
if self.dis_scheduler is not None:
|
63 |
+
self.dis_scheduler.step()
|
64 |
+
if self.ae_scheduler is not None:
|
65 |
+
self.ae_scheduler.step()
|
66 |
+
|
67 |
+
def resume(self, checkpoint_dir, config):
|
68 |
+
# Load generators
|
69 |
+
last_model_name = get_model_list(checkpoint_dir, "autoencoder")
|
70 |
+
state_dict = torch.load(last_model_name)
|
71 |
+
self.autoencoder.load_state_dict(state_dict)
|
72 |
+
iterations = int(last_model_name[-11:-3])
|
73 |
+
# Load discriminators
|
74 |
+
last_model_name = get_model_list(checkpoint_dir, "discriminator")
|
75 |
+
state_dict = torch.load(last_model_name)
|
76 |
+
self.discriminator.load_state_dict(state_dict)
|
77 |
+
# Load optimizers
|
78 |
+
state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
|
79 |
+
self.dis_opt.load_state_dict(state_dict['discriminator'])
|
80 |
+
self.ae_opt.load_state_dict(state_dict['autoencoder'])
|
81 |
+
# Reinitialize schedulers
|
82 |
+
self.dis_scheduler = get_scheduler(self.dis_opt, config, iterations)
|
83 |
+
self.ae_scheduler = get_scheduler(self.ae_opt, config, iterations)
|
84 |
+
print('Resume from iteration %d' % iterations)
|
85 |
+
return iterations
|
86 |
+
|
87 |
+
def save(self, snapshot_dir, iterations):
|
88 |
+
# Save generators, discriminators, and optimizers
|
89 |
+
ae_name = os.path.join(snapshot_dir, 'autoencoder_%08d.pt' % (iterations + 1))
|
90 |
+
dis_name = os.path.join(snapshot_dir, 'discriminator_%08d.pt' % (iterations + 1))
|
91 |
+
opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
|
92 |
+
torch.save(self.autoencoder.state_dict(), ae_name)
|
93 |
+
torch.save(self.discriminator.state_dict(), dis_name)
|
94 |
+
torch.save({'autoencoder': self.ae_opt.state_dict(), 'discriminator': self.dis_opt.state_dict()}, opt_name)
|
95 |
+
|
96 |
+
def validate(self, data, config):
|
97 |
+
re_dict = self.evaluate(self.autoencoder, data, config)
|
98 |
+
for key, val in re_dict.items():
|
99 |
+
setattr(self, key, val)
|
100 |
+
|
101 |
+
@staticmethod
|
102 |
+
def recon_criterion(input, target):
|
103 |
+
return torch.mean(torch.abs(input - target))
|
104 |
+
|
105 |
+
@classmethod
|
106 |
+
def evaluate(cls, autoencoder, data, config):
|
107 |
+
autoencoder.eval()
|
108 |
+
x_a, x_b = data["x_a"], data["x_b"]
|
109 |
+
x_aba, x_bab = data["x_aba"], data["x_bab"]
|
110 |
+
batch_size, _, seq_len = x_a.size()
|
111 |
+
|
112 |
+
re_dict = {}
|
113 |
+
|
114 |
+
with torch.no_grad(): # 2D eval
|
115 |
+
|
116 |
+
x_a_recon = autoencoder.reconstruct2d(x_a)
|
117 |
+
x_b_recon = autoencoder.reconstruct2d(x_b)
|
118 |
+
x_aba_recon = autoencoder.cross2d(x_a, x_b, x_a)
|
119 |
+
x_bab_recon = autoencoder.cross2d(x_b, x_a, x_b)
|
120 |
+
|
121 |
+
re_dict['loss_val_recon_x'] = cls.recon_criterion(x_a_recon, x_a) + cls.recon_criterion(x_b_recon, x_b)
|
122 |
+
re_dict['loss_val_cross_body'] = cls.recon_criterion(x_aba_recon, x_aba) + cls.recon_criterion(
|
123 |
+
x_bab_recon, x_bab)
|
124 |
+
re_dict['loss_val_total'] = 0.5 * re_dict['loss_val_recon_x'] + 0.5 * re_dict['loss_val_cross_body']
|
125 |
+
|
126 |
+
autoencoder.train()
|
127 |
+
return re_dict
|
128 |
+
|
129 |
+
|
130 |
+
class TransmomoTrainer(BaseTrainer):
|
131 |
+
|
132 |
+
def __init__(self, config):
|
133 |
+
super(TransmomoTrainer, self).__init__(config)
|
134 |
+
|
135 |
+
self.angle_unit = np.pi / (config.K + 1)
|
136 |
+
view_angles = np.array([i * self.angle_unit for i in range(1, config.K + 1)])
|
137 |
+
x_angles = view_angles if config.rotation_axes[0] else np.array([0])
|
138 |
+
z_angles = view_angles if config.rotation_axes[1] else np.array([0])
|
139 |
+
y_angles = view_angles if config.rotation_axes[2] else np.array([0])
|
140 |
+
x_angles, z_angles, y_angles = np.meshgrid(x_angles, z_angles, y_angles)
|
141 |
+
angles = np.stack([x_angles.flatten(), z_angles.flatten(), y_angles.flatten()], axis=1)
|
142 |
+
self.angles = torch.tensor(angles).float().cuda()
|
143 |
+
self.rotation_axes = torch.tensor(config.rotation_axes).float().cuda()
|
144 |
+
self.rotation_axes_mask = [(_ > 0) for _ in config.rotation_axes]
|
145 |
+
|
146 |
+
def dis_update(self, data, config):
|
147 |
+
|
148 |
+
x_a = data["x"]
|
149 |
+
x_s = data["x_s"] # the limb-scaled version of x_a
|
150 |
+
meanpose = data["meanpose"][0]
|
151 |
+
stdpose = data["stdpose"][0]
|
152 |
+
|
153 |
+
self.dis_opt.zero_grad()
|
154 |
+
|
155 |
+
# encode
|
156 |
+
motion_a = self.autoencoder.encode_motion(x_a)
|
157 |
+
body_a, body_a_seq = self.autoencoder.encode_body(x_a)
|
158 |
+
view_a, view_a_seq = self.autoencoder.encode_view(x_a)
|
159 |
+
|
160 |
+
motion_s = self.autoencoder.encode_motion(x_s)
|
161 |
+
body_s, body_s_seq = self.autoencoder.encode_body(x_s)
|
162 |
+
view_s, view_s_seq = self.autoencoder.encode_view(x_s)
|
163 |
+
|
164 |
+
# decode (reconstruct, transform)
|
165 |
+
inds = random.sample(list(range(self.angles.size(0))), config.K)
|
166 |
+
angles = self.angles[inds].clone().detach() # [K, 3]
|
167 |
+
angles += self.angle_unit * self.rotation_axes * torch.randn([3], device=x_a.device)
|
168 |
+
angles = angles.unsqueeze(0).unsqueeze(2) # [B=1, K, T=1, 3]
|
169 |
+
|
170 |
+
X_a_recon = self.autoencoder.decode(motion_a, body_a, view_a)
|
171 |
+
x_a_trans = rotate_and_maybe_project_learning(X_a_recon, meanpose, stdpose, angles=angles,
|
172 |
+
body_reference=config.autoencoder.body_reference, project_2d=True)
|
173 |
+
|
174 |
+
x_a_exp = x_a.repeat_interleave(config.K, dim=0)
|
175 |
+
|
176 |
+
self.loss_dis_trans = self.discriminator.calc_dis_loss(x_a_trans.detach(), x_a_exp)
|
177 |
+
|
178 |
+
if config.trans_gan_ls_w > 0:
|
179 |
+
X_s_recon = self.autoencoder.decode(motion_s, body_s, view_s)
|
180 |
+
x_s_trans = rotate_and_maybe_project_learning(X_s_recon, meanpose, stdpose, angles=angles,
|
181 |
+
body_reference=config.autoencoder.body_reference, project_2d=True)
|
182 |
+
x_s_exp = x_s.repeat_interleave(config.K, dim=0)
|
183 |
+
self.loss_dis_trans_ls = self.discriminator.calc_dis_loss(x_s_trans.detach(), x_s_exp)
|
184 |
+
else:
|
185 |
+
self.loss_dis_trans_ls = 0
|
186 |
+
|
187 |
+
self.loss_dis_total = config.trans_gan_w * self.loss_dis_trans + \
|
188 |
+
config.trans_gan_ls_w * self.loss_dis_trans_ls
|
189 |
+
|
190 |
+
self.loss_dis_total.backward()
|
191 |
+
self.dis_opt.step()
|
192 |
+
|
193 |
+
def ae_update(self, data, config):
|
194 |
+
|
195 |
+
x_a = data["x"]
|
196 |
+
x_s = data["x_s"]
|
197 |
+
meanpose = data["meanpose"][0]
|
198 |
+
stdpose = data["stdpose"][0]
|
199 |
+
self.ae_opt.zero_grad()
|
200 |
+
|
201 |
+
# encode
|
202 |
+
motion_a = self.autoencoder.encode_motion(x_a)
|
203 |
+
body_a, body_a_seq = self.autoencoder.encode_body(x_a)
|
204 |
+
view_a, view_a_seq = self.autoencoder.encode_view(x_a)
|
205 |
+
|
206 |
+
motion_s = self.autoencoder.encode_motion(x_s)
|
207 |
+
body_s, body_s_seq = self.autoencoder.encode_body(x_s)
|
208 |
+
view_s, view_s_seq = self.autoencoder.encode_view(x_s)
|
209 |
+
|
210 |
+
# invariance loss
|
211 |
+
self.loss_inv_v_ls = self.recon_criterion(view_a, view_s) if config.inv_v_ls_w > 0 else 0
|
212 |
+
self.loss_inv_m_ls = self.recon_criterion(motion_a, motion_s) if config.inv_m_ls_w > 0 else 0
|
213 |
+
|
214 |
+
# body triplet loss
|
215 |
+
if config.triplet_b_w > 0:
|
216 |
+
self.loss_triplet_b = triplet_margin_loss(
|
217 |
+
body_a_seq, body_s_seq,
|
218 |
+
neg_range=config.triplet_neg_range,
|
219 |
+
margin=config.triplet_margin)
|
220 |
+
else:
|
221 |
+
self.loss_triplet_b = 0
|
222 |
+
|
223 |
+
# reconstruction
|
224 |
+
X_a_recon = self.autoencoder.decode(motion_a, body_a, view_a)
|
225 |
+
x_a_recon = rotate_and_maybe_project_learning(X_a_recon, meanpose, stdpose, angles=None,
|
226 |
+
body_reference=config.autoencoder.body_reference, project_2d=True)
|
227 |
+
|
228 |
+
X_s_recon = self.autoencoder.decode(motion_s, body_s, view_s)
|
229 |
+
x_s_recon = rotate_and_maybe_project_learning(X_s_recon, meanpose, stdpose, angles=None,
|
230 |
+
body_reference=config.autoencoder.body_reference, project_2d=True)
|
231 |
+
|
232 |
+
self.loss_recon_x = 0.5 * self.recon_criterion(x_a_recon, x_a) +\
|
233 |
+
0.5 * self.recon_criterion(x_s_recon, x_s)
|
234 |
+
|
235 |
+
# cross reconstruction
|
236 |
+
X_as_recon = self.autoencoder.decode(motion_a, body_s, view_s)
|
237 |
+
x_as_recon = rotate_and_maybe_project_learning(X_as_recon, meanpose, stdpose, angles=None,
|
238 |
+
body_reference=config.autoencoder.body_reference, project_2d=True)
|
239 |
+
|
240 |
+
X_sa_recon = self.autoencoder.decode(motion_s, body_a, view_a)
|
241 |
+
x_sa_recon = rotate_and_maybe_project_learning(X_sa_recon, meanpose, stdpose, angles=None,
|
242 |
+
body_reference=config.autoencoder.body_reference, project_2d=True)
|
243 |
+
|
244 |
+
self.loss_cross_x = 0.5 * self.recon_criterion(x_as_recon, x_s) + 0.5 * self.recon_criterion(x_sa_recon, x_a)
|
245 |
+
|
246 |
+
# apply transformation
|
247 |
+
inds = random.sample(list(range(self.angles.size(0))), config.K)
|
248 |
+
angles = self.angles[inds].clone().detach()
|
249 |
+
angles += self.angle_unit * self.rotation_axes * torch.randn([3], device=x_a.device)
|
250 |
+
angles = angles.unsqueeze(0).unsqueeze(2)
|
251 |
+
|
252 |
+
x_a_trans = rotate_and_maybe_project_learning(X_a_recon, meanpose, stdpose, angles=angles,
|
253 |
+
body_reference=config.autoencoder.body_reference, project_2d=True)
|
254 |
+
x_s_trans = rotate_and_maybe_project_learning(X_s_recon, meanpose, stdpose, angles=angles,
|
255 |
+
body_reference=config.autoencoder.body_reference, project_2d=True)
|
256 |
+
|
257 |
+
# GAN loss
|
258 |
+
self.loss_gan_trans = self.discriminator.calc_gen_loss(x_a_trans)
|
259 |
+
self.loss_gan_trans_ls = self.discriminator.calc_gen_loss(x_s_trans) if config.trans_gan_ls_w > 0 else 0
|
260 |
+
|
261 |
+
# encode again
|
262 |
+
motion_a_trans = self.autoencoder.encode_motion(x_a_trans)
|
263 |
+
body_a_trans, _ = self.autoencoder.encode_body(x_a_trans)
|
264 |
+
view_a_trans, view_a_trans_seq = self.autoencoder.encode_view(x_a_trans)
|
265 |
+
|
266 |
+
motion_s_trans = self.autoencoder.encode_motion(x_s_trans)
|
267 |
+
body_s_trans, _ = self.autoencoder.encode_body(x_s_trans)
|
268 |
+
|
269 |
+
self.loss_inv_m_trans = 0.5 * self.recon_criterion(motion_a_trans, motion_a.repeat_interleave(config.K, dim=0)) + \
|
270 |
+
0.5 * self.recon_criterion(motion_s_trans, motion_s.repeat_interleave(config.K, dim=0))
|
271 |
+
self.loss_inv_b_trans = 0.5 * self.recon_criterion(body_a_trans, body_a.repeat_interleave(config.K, dim=0)) + \
|
272 |
+
0.5 * self.recon_criterion(body_s_trans, body_s.repeat_interleave(config.K, dim=0))
|
273 |
+
|
274 |
+
# view triplet loss
|
275 |
+
if config.triplet_v_w > 0:
|
276 |
+
view_a_seq_exp = view_a_seq.repeat_interleave(config.K, dim=0)
|
277 |
+
self.loss_triplet_v = triplet_margin_loss(
|
278 |
+
view_a_seq_exp, view_a_trans_seq,
|
279 |
+
neg_range=config.triplet_neg_range, margin=config.triplet_margin)
|
280 |
+
else:
|
281 |
+
self.loss_triplet_v = 0
|
282 |
+
|
283 |
+
# add all losses
|
284 |
+
self.loss_total = torch.tensor(0.).float().cuda()
|
285 |
+
self.loss_total += config.recon_x_w * self.loss_recon_x
|
286 |
+
self.loss_total += config.cross_x_w * self.loss_cross_x
|
287 |
+
self.loss_total += config.inv_v_ls_w * self.loss_inv_v_ls
|
288 |
+
self.loss_total += config.inv_m_ls_w * self.loss_inv_m_ls
|
289 |
+
self.loss_total += config.inv_b_trans_w * self.loss_inv_b_trans
|
290 |
+
self.loss_total += config.inv_m_trans_w * self.loss_inv_m_trans
|
291 |
+
self.loss_total += config.trans_gan_w * self.loss_gan_trans
|
292 |
+
self.loss_total += config.trans_gan_ls_w * self.loss_gan_trans_ls
|
293 |
+
self.loss_total += config.triplet_b_w * self.loss_triplet_b
|
294 |
+
self.loss_total += config.triplet_v_w * self.loss_triplet_v
|
295 |
+
|
296 |
+
self.loss_total.backward()
|
297 |
+
self.ae_opt.step()
|
298 |
+
|
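The driver script that consumes TransmomoTrainer is not part of this commit; the loop below is only a hypothetical sketch of how dis_update and ae_update are typically alternated (data_loader, config and n_iter are placeholders):

from lib.util.general import to_gpu

def train(trainer, data_loader, config, n_iter):
    # hypothetical training loop, not part of this commit
    it = 0
    while it < n_iter:
        for data in data_loader:
            data = to_gpu(data)
            trainer.dis_update(data, config)        # discriminator step
            trainer.ae_update(data, config)         # autoencoder / generator step
            trainer.update_learning_rate()
            it += 1
            if it >= n_iter:
                break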
lib/util/__init__.py
ADDED
File without changes
|
lib/util/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (146 Bytes).
|
|
lib/util/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (150 Bytes).
|
|
lib/util/__pycache__/general.cpython-37.pyc
ADDED
Binary file (13.3 kB).
|
|
lib/util/__pycache__/general.cpython-38.pyc
ADDED
Binary file (13.4 kB).
|
|
lib/util/__pycache__/motion.cpython-37.pyc
ADDED
Binary file (8.05 kB).
|
|
lib/util/__pycache__/motion.cpython-38.pyc
ADDED
Binary file (8.07 kB).
|
|
lib/util/__pycache__/visualization.cpython-37.pyc
ADDED
Binary file (12.6 kB).
|
|
lib/util/__pycache__/visualization.cpython-38.pyc
ADDED
Binary file (12.7 kB).
|
|
lib/util/general.py
ADDED
@@ -0,0 +1,361 @@
|
1 |
+
from PIL import Image
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import shutil
|
6 |
+
import csv
|
7 |
+
# from lib.network.munit import Vgg16
|
8 |
+
from torch.autograd import Variable
|
9 |
+
from torch.optim import lr_scheduler
|
10 |
+
from easydict import EasyDict as edict
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import os
|
15 |
+
import math
|
16 |
+
import torchvision.utils as vutils
|
17 |
+
import yaml
|
18 |
+
import numpy as np
|
19 |
+
import torch.nn.init as init
|
20 |
+
import time
|
21 |
+
|
22 |
+
|
23 |
+
def get_config(config_path):
|
24 |
+
with open(config_path, 'r') as stream:
|
25 |
+
config = yaml.load(stream, Loader=yaml.SafeLoader)
|
26 |
+
config = edict(config)
|
27 |
+
_, config_filename = os.path.split(config_path)
|
28 |
+
config_name, _ = os.path.splitext(config_filename)
|
29 |
+
config.name = config_name
|
30 |
+
return config
|
31 |
+
|
32 |
+
class TextLogger:
|
33 |
+
|
34 |
+
def __init__(self, log_path):
|
35 |
+
self.log_path = log_path
|
36 |
+
with open(self.log_path, "w") as f:
|
37 |
+
f.write("")
|
38 |
+
def log(self, log):
|
39 |
+
with open(self.log_path, "a+") as f:
|
40 |
+
f.write(log + "\n")
|
41 |
+
|
42 |
+
def eformat(f, prec):
|
43 |
+
s = "%.*e"%(prec, f)
|
44 |
+
mantissa, exp = s.split('e')
|
45 |
+
# add 1 to digits as 1 is taken by sign +/-
|
46 |
+
return "%se%d"%(mantissa, int(exp))
|
47 |
+
|
48 |
+
|
49 |
+
def __write_images(image_outputs, display_image_num, file_name):
|
50 |
+
image_outputs = [images.expand(-1, 3, -1, -1) for images in image_outputs] # expand gray-scale images to 3 channels
|
51 |
+
image_tensor = torch.cat([images[:display_image_num] for images in image_outputs], 0)
|
52 |
+
image_grid = vutils.make_grid(image_tensor.data, nrow=display_image_num, padding=0, normalize=True)
|
53 |
+
vutils.save_image(image_grid, file_name, nrow=1)
|
54 |
+
|
55 |
+
|
56 |
+
def write_2images(image_outputs, display_image_num, image_directory, postfix):
|
57 |
+
n = len(image_outputs)
|
58 |
+
__write_images(image_outputs[0:n//2], display_image_num, '%s/gen_a2b_%s.jpg' % (image_directory, postfix))
|
59 |
+
__write_images(image_outputs[n//2:n], display_image_num, '%s/gen_b2a_%s.jpg' % (image_directory, postfix))
|
60 |
+
|
61 |
+
|
62 |
+
def write_one_row_html(html_file, iterations, img_filename, all_size):
|
63 |
+
html_file.write("<h3>iteration [%d] (%s)</h3>" % (iterations,img_filename.split('/')[-1]))
|
64 |
+
html_file.write("""
|
65 |
+
<p><a href="%s">
|
66 |
+
<img src="%s" style="width:%dpx">
|
67 |
+
</a><br>
|
68 |
+
<p>
|
69 |
+
""" % (img_filename, img_filename, all_size))
|
70 |
+
return
|
71 |
+
|
72 |
+
|
73 |
+
def write_html(filename, iterations, image_save_iterations, image_directory, all_size=1536):
|
74 |
+
html_file = open(filename, "w")
|
75 |
+
html_file.write('''
|
76 |
+
<!DOCTYPE html>
|
77 |
+
<html>
|
78 |
+
<head>
|
79 |
+
<title>Experiment name = %s</title>
|
80 |
+
<meta http-equiv="refresh" content="30">
|
81 |
+
</head>
|
82 |
+
<body>
|
83 |
+
''' % os.path.basename(filename))
|
84 |
+
html_file.write("<h3>current</h3>")
|
85 |
+
write_one_row_html(html_file, iterations, '%s/gen_a2b_train_current.jpg' % (image_directory), all_size)
|
86 |
+
write_one_row_html(html_file, iterations, '%s/gen_b2a_train_current.jpg' % (image_directory), all_size)
|
87 |
+
for j in range(iterations, image_save_iterations-1, -1):
|
88 |
+
if j % image_save_iterations == 0:
|
89 |
+
write_one_row_html(html_file, j, '%s/gen_a2b_test_%08d.jpg' % (image_directory, j), all_size)
|
90 |
+
write_one_row_html(html_file, j, '%s/gen_b2a_test_%08d.jpg' % (image_directory, j), all_size)
|
91 |
+
write_one_row_html(html_file, j, '%s/gen_a2b_train_%08d.jpg' % (image_directory, j), all_size)
|
92 |
+
write_one_row_html(html_file, j, '%s/gen_b2a_train_%08d.jpg' % (image_directory, j), all_size)
|
93 |
+
html_file.write("</body></html>")
|
94 |
+
html_file.close()
|
95 |
+
|
96 |
+
|
97 |
+
def write_loss(iterations, trainer, train_writer):
|
98 |
+
members = [attr for attr in dir(trainer) \
|
99 |
+
if not callable(getattr(trainer, attr)) and not attr.startswith("__") and ('loss' in attr or 'grad' in attr or 'nwd' in attr)]
|
100 |
+
for m in members:
|
101 |
+
train_writer.add_scalar(m, getattr(trainer, m), iterations + 1)
|
102 |
+
|
103 |
+
|
104 |
+
def slerp(val, low, high):
|
105 |
+
"""
|
106 |
+
original: Animating Rotation with Quaternion Curves, Ken Shoemake
|
107 |
+
https://arxiv.org/abs/1609.04468
|
108 |
+
Code: https://github.com/soumith/dcgan.torch/issues/14, Tom White
|
109 |
+
"""
|
110 |
+
omega = np.arccos(np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high)))
|
111 |
+
so = np.sin(omega)
|
112 |
+
return np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high
|
113 |
+
|
114 |
+
|
115 |
+
def get_slerp_interp(nb_latents, nb_interp, z_dim):
|
116 |
+
"""
|
117 |
+
modified from: PyTorch inference for "Progressive Growing of GANs" with CelebA snapshot
|
118 |
+
https://github.com/ptrblck/prog_gans_pytorch_inference
|
119 |
+
"""
|
120 |
+
|
121 |
+
latent_interps = np.empty(shape=(0, z_dim), dtype=np.float32)
|
122 |
+
for _ in range(nb_latents):
|
123 |
+
low = np.random.randn(z_dim)
|
124 |
+
high = np.random.randn(z_dim) # low + np.random.randn(512) * 0.7
|
125 |
+
interp_vals = np.linspace(0, 1, num=nb_interp)
|
126 |
+
latent_interp = np.array([slerp(v, low, high) for v in interp_vals],
|
127 |
+
dtype=np.float32)
|
128 |
+
latent_interps = np.vstack((latent_interps, latent_interp))
|
129 |
+
|
130 |
+
return latent_interps[:, :, np.newaxis, np.newaxis]
|
131 |
+
|
132 |
+
|
133 |
+
# Get model list for resume
|
134 |
+
def get_model_list(dirname, key):
|
135 |
+
if os.path.exists(dirname) is False:
|
136 |
+
return None
|
137 |
+
gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if
|
138 |
+
os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f]
|
139 |
+
if not gen_models:
|
140 |
+
return None
|
141 |
+
gen_models.sort()
|
142 |
+
last_model_name = gen_models[-1]
|
143 |
+
return last_model_name
|
144 |
+
|
145 |
+
|
146 |
+
def get_scheduler(optimizer, hyperparameters, iterations=-1):
|
147 |
+
if 'lr_policy' not in hyperparameters or hyperparameters['lr_policy'] == 'constant':
|
148 |
+
scheduler = None # constant scheduler
|
149 |
+
elif hyperparameters['lr_policy'] == 'step':
|
150 |
+
scheduler = lr_scheduler.StepLR(optimizer, step_size=hyperparameters['step_size'],
|
151 |
+
gamma=hyperparameters['gamma'], last_epoch=iterations)
|
152 |
+
else:
|
153 |
+
return NotImplementedError('learning rate policy [%s] is not implemented', hyperparameters['lr_policy'])
|
154 |
+
return scheduler
|
155 |
+
|
156 |
+
|
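For the 'step' policy, get_scheduler only reads three keys from the hyperparameters; a minimal sketch (the optimizer and the values are made up for illustration):

import torch
from lib.util.general import get_scheduler

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
hp = {'lr_policy': 'step', 'step_size': 10000, 'gamma': 0.5}
sched = get_scheduler(opt, hp, iterations=-1)       # StepLR halving the lr every 10k steps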
157 |
+
def weights_init(init_type='gaussian'):
|
158 |
+
def init_fun(m):
|
159 |
+
classname = m.__class__.__name__
|
160 |
+
if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
|
161 |
+
# print m.__class__.__name__
|
162 |
+
if init_type == 'gaussian':
|
163 |
+
init.normal_(m.weight.data, 0.0, 0.02)
|
164 |
+
elif init_type == 'xavier':
|
165 |
+
init.xavier_normal_(m.weight.data, gain=math.sqrt(2))
|
166 |
+
elif init_type == 'kaiming':
|
167 |
+
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
168 |
+
elif init_type == 'orthogonal':
|
169 |
+
init.orthogonal_(m.weight.data, gain=math.sqrt(2))
|
170 |
+
elif init_type == 'default':
|
171 |
+
pass
|
172 |
+
else:
|
173 |
+
assert 0, "Unsupported initialization: {}".format(init_type)
|
174 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
175 |
+
init.constant_(m.bias.data, 0.0)
|
176 |
+
|
177 |
+
return init_fun
|
178 |
+
|
179 |
+
|
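weights_init returns a closure meant to be passed to nn.Module.apply; for example (the small network here is just an illustration):

import torch.nn as nn
from lib.util.general import weights_init

net = nn.Sequential(nn.Conv1d(30, 64, 3), nn.ReLU(), nn.Conv1d(64, 30, 3))
net.apply(weights_init('kaiming'))                  # Kaiming-normal init for every Conv/Linear layer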
180 |
+
class Timer:
|
181 |
+
def __init__(self, msg):
|
182 |
+
self.msg = msg
|
183 |
+
self.start_time = None
|
184 |
+
|
185 |
+
def __enter__(self):
|
186 |
+
self.start_time = time.time()
|
187 |
+
|
188 |
+
def __exit__(self, exc_type, exc_value, exc_tb):
|
189 |
+
print(self.msg % (time.time() - self.start_time))
|
190 |
+
|
191 |
+
|
192 |
+
class TrainClock(object):
|
193 |
+
def __init__(self):
|
194 |
+
self.epoch = 1
|
195 |
+
self.minibatch = 0
|
196 |
+
self.step = 0
|
197 |
+
|
198 |
+
def tick(self):
|
199 |
+
self.minibatch += 1
|
200 |
+
self.step += 1
|
201 |
+
|
202 |
+
def tock(self):
|
203 |
+
self.epoch += 1
|
204 |
+
self.minibatch = 0
|
205 |
+
|
206 |
+
def make_checkpoint(self):
|
207 |
+
return {
|
208 |
+
'epoch': self.epoch,
|
209 |
+
'minibatch': self.minibatch,
|
210 |
+
'step': self.step
|
211 |
+
}
|
212 |
+
|
213 |
+
def restore_checkpoint(self, clock_dict):
|
214 |
+
self.epoch = clock_dict['epoch']
|
215 |
+
self.minibatch = clock_dict['minibatch']
|
216 |
+
self.step = clock_dict['step']
|
217 |
+
|
218 |
+
|
219 |
+
class Table(object):
|
220 |
+
def __init__(self, filename):
|
221 |
+
'''
|
222 |
+
create a table to record experiment results that can be opened in Excel
|
223 |
+
:param filename: file name, which should end with '.csv'
|
224 |
+
'''
|
225 |
+
assert '.csv' in filename
|
226 |
+
self.filename = filename
|
227 |
+
|
228 |
+
@staticmethod
|
229 |
+
def merge_headers(header1, header2):
|
230 |
+
#return list(set(header1 + header2))
|
231 |
+
if len(header1) > len(header2):
|
232 |
+
return header1
|
233 |
+
else:
|
234 |
+
return header2
|
235 |
+
|
236 |
+
def write(self, ordered_dict):
|
237 |
+
'''
|
238 |
+
write an entry
|
239 |
+
:param ordered_dict: something like {'name':'exp1', 'acc':90.5, 'epoch':50}
|
240 |
+
:return:
|
241 |
+
'''
|
242 |
+
if not os.path.exists(self.filename):
|
243 |
+
headers = list(ordered_dict.keys())
|
244 |
+
prev_rec = None
|
245 |
+
else:
|
246 |
+
with open(self.filename) as f:
|
247 |
+
reader = csv.DictReader(f)
|
248 |
+
headers = reader.fieldnames
|
249 |
+
prev_rec = [row for row in reader]
|
250 |
+
headers = self.merge_headers(headers, list(ordered_dict.keys()))
|
251 |
+
|
252 |
+
with open(self.filename, 'w', newline='') as f:
|
253 |
+
writer = csv.DictWriter(f, headers)
|
254 |
+
writer.writeheader()
|
255 |
+
if prev_rec is not None:
|
256 |
+
writer.writerows(prev_rec)
|
257 |
+
writer.writerow(ordered_dict)
|
258 |
+
|
259 |
+
|
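Table appends one row per call and rewrites the header whenever new columns appear; a small usage sketch (the file name and values are arbitrary):

from lib.util.general import Table

table = Table('results.csv')
table.write({'name': 'exp1', 'acc': 90.5, 'epoch': 50})
table.write({'name': 'exp2', 'acc': 91.2, 'epoch': 50})   # appended under the same header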
260 |
+
class WorklogLogger:
|
261 |
+
def __init__(self, log_file):
|
262 |
+
logging.basicConfig(filename=log_file,
|
263 |
+
level=logging.DEBUG,
|
264 |
+
format='%(asctime)s - %(threadName)s - %(levelname)s - %(message)s')
|
265 |
+
|
266 |
+
self.logger = logging.getLogger()
|
267 |
+
|
268 |
+
def put_line(self, line):
|
269 |
+
self.logger.info(line)
|
270 |
+
|
271 |
+
|
272 |
+
class AverageMeter(object):
|
273 |
+
"""Computes and stores the average and current value"""
|
274 |
+
|
275 |
+
def __init__(self, name):
|
276 |
+
self.name = name
|
277 |
+
self.reset()
|
278 |
+
|
279 |
+
def reset(self):
|
280 |
+
self.val = 0
|
281 |
+
self.avg = 0
|
282 |
+
self.sum = 0
|
283 |
+
self.count = 0
|
284 |
+
|
285 |
+
def update(self, val, n=1):
|
286 |
+
self.val = val
|
287 |
+
self.sum += val * n
|
288 |
+
self.count += n
|
289 |
+
self.avg = self.sum / self.count
|
290 |
+
|
291 |
+
|
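AverageMeter is the usual running-average helper; for example:

from lib.util.general import AverageMeter

loss_meter = AverageMeter('loss')
for v in [0.9, 0.7, 0.5]:
    loss_meter.update(v, n=1)
print(loss_meter.avg)                               # ~0.7, the mean of the three updates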
292 |
+
def save_args(args, save_dir):
|
293 |
+
param_path = os.path.join(save_dir, 'params.json')
|
294 |
+
|
295 |
+
with open(param_path, 'w') as fp:
|
296 |
+
json.dump(args.__dict__, fp, indent=4, sort_keys=True)
|
297 |
+
|
298 |
+
|
299 |
+
def ensure_dir(path):
|
300 |
+
"""
|
301 |
+
create the directory if it does not already exist
|
302 |
+
:param path: directory path to create
|
303 |
+
:return:
|
304 |
+
"""
|
305 |
+
if not os.path.exists(path):
|
306 |
+
os.makedirs(path)
|
307 |
+
|
308 |
+
|
309 |
+
def ensure_dirs(paths):
|
310 |
+
"""
|
311 |
+
create paths by first checking their existence
|
312 |
+
:param paths: list of path
|
313 |
+
:return:
|
314 |
+
"""
|
315 |
+
if isinstance(paths, list) and not isinstance(paths, str):
|
316 |
+
for path in paths:
|
317 |
+
ensure_dir(path)
|
318 |
+
else:
|
319 |
+
ensure_dir(paths)
|
320 |
+
|
321 |
+
|
322 |
+
def remkdir(path):
|
323 |
+
"""
|
324 |
+
if dir exists, remove it and create a new one
|
325 |
+
:param path:
|
326 |
+
:return:
|
327 |
+
"""
|
328 |
+
if os.path.exists(path):
|
329 |
+
shutil.rmtree(path)
|
330 |
+
os.makedirs(path)
|
331 |
+
|
332 |
+
|
333 |
+
def cycle(iterable):
|
334 |
+
while True:
|
335 |
+
for x in iterable:
|
336 |
+
yield x
|
337 |
+
|
338 |
+
|
339 |
+
def save_image(image_numpy, image_path):
|
340 |
+
image_pil = Image.fromarray(image_numpy)
|
341 |
+
image_pil.save(image_path)
|
342 |
+
|
343 |
+
|
344 |
+
def pad_to_16x(x):
|
345 |
+
if x % 16 > 0:
|
346 |
+
return x - x % 16 + 16
|
347 |
+
return x
|
348 |
+
|
349 |
+
|
350 |
+
def pad_to_height(tar_height, img_height, img_width):
|
351 |
+
scale = tar_height / img_height
|
352 |
+
h = pad_to_16x(tar_height)
|
353 |
+
w = pad_to_16x(int(img_width * scale))
|
354 |
+
return h, w, scale
|
355 |
+
|
356 |
+
|
357 |
+
def to_gpu(data):
|
358 |
+
for key, item in data.items():
|
359 |
+
if torch.is_tensor(item):
|
360 |
+
data[key] = item.cuda()
|
361 |
+
return data
|
lib/util/global_norm.py
ADDED
@@ -0,0 +1,29 @@
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
def get_box(pose):
|
4 |
+
#input: pose([15,2])
|
5 |
+
return(np.min(pose[:,0]), np.max(pose[:,0]), np.min(pose[:,1]), np.max(pose[:,1]))
|
6 |
+
|
7 |
+
def get_height(pose):
|
8 |
+
#input: pose([15,2])
|
9 |
+
mean_ankle = (pose[14]+pose[11])/2
|
10 |
+
nose = pose[0]
|
11 |
+
return np.linalg.norm(mean_ankle-nose)
|
12 |
+
|
13 |
+
def get_base_mean(pose):
|
14 |
+
#input: pose([15,2])
|
15 |
+
x1, x2, y1, y2 = get_box(pose)
|
16 |
+
return np.array([(x1+x2)/2, y2])
|
17 |
+
|
18 |
+
def global_norm(driving_npy, target_npy):
|
19 |
+
#input: pose([15,2,frame1]), pose([15,2,frame2])
|
20 |
+
target_mean = np.mean(target_npy, axis=2)
|
21 |
+
driving_mean = np.mean(driving_npy, axis=2)
|
22 |
+
k2 = get_height(target_mean)/get_height(driving_mean)
|
23 |
+
target_mean_base = get_base_mean(target_mean)
|
24 |
+
driving_mean_base = get_base_mean(driving_mean)
|
25 |
+
driving_npy_permuted = np.transpose(driving_npy, axes=[2, 0, 1])
|
26 |
+
k = [1, k2]
|
27 |
+
normalized_permuted = (driving_npy_permuted-driving_mean_base)*k+target_mean_base
|
28 |
+
normalized = np.transpose(normalized_permuted, axes=[1,2,0])
|
29 |
+
return normalized # pose([15,2,frame1])
|
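global_norm rescales and re-bases a driving pose sequence so that its body height and base position match a target sequence; a sketch on random stand-in arrays (shapes follow the comments above):

import numpy as np
from lib.util.global_norm import global_norm

driving = np.random.rand(15, 2, 120)                # 15 joints, (x, y), 120 frames
target = np.random.rand(15, 2, 80)
aligned = global_norm(driving, target)
print(aligned.shape)                                # (15, 2, 120), same length as the driving sequence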
lib/util/motion.py
ADDED
@@ -0,0 +1,309 @@
|
1 |
+
from scipy.ndimage import gaussian_filter1d
|
2 |
+
import numpy as np
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
def preprocess_test(motion, meanpose, stdpose, unit=128):
|
9 |
+
|
10 |
+
motion = motion * unit
|
11 |
+
|
12 |
+
motion[1, :, :] = (motion[2, :, :] + motion[5, :, :]) / 2
|
13 |
+
motion[8, :, :] = (motion[9, :, :] + motion[12, :, :]) / 2
|
14 |
+
|
15 |
+
start = motion[8, :, 0]
|
16 |
+
|
17 |
+
motion = localize_motion(motion)
|
18 |
+
motion = normalize_motion(motion, meanpose, stdpose)
|
19 |
+
|
20 |
+
return motion, start
|
21 |
+
|
22 |
+
|
23 |
+
def postprocess(motion, meanpose, stdpose, unit=128, start=None):
|
24 |
+
|
25 |
+
motion = motion.detach().cpu().numpy()[0].reshape(-1, 2, motion.shape[-1])
|
26 |
+
motion = normalize_motion_inv(motion, meanpose, stdpose)
|
27 |
+
motion = globalize_motion(motion, start=start)
|
28 |
+
# motion = motion / unit
|
29 |
+
|
30 |
+
return motion
|
31 |
+
|
32 |
+
|
33 |
+
def preprocess_mixamo(motion, unit=128):
|
34 |
+
|
35 |
+
_, D, _ = motion.shape
|
36 |
+
horizontal_dim = 0
|
37 |
+
vertical_dim = D - 1
|
38 |
+
|
39 |
+
motion[1, :, :] = (motion[2, :, :] + motion[5, :, :]) / 2
|
40 |
+
motion[8, :, :] = (motion[9, :, :] + motion[12, :, :]) / 2
|
41 |
+
|
42 |
+
# rotate 180
|
43 |
+
motion[:, horizontal_dim, :] = - motion[:, horizontal_dim, :]
|
44 |
+
motion[:, vertical_dim, :] = - motion[:, vertical_dim, :]
|
45 |
+
|
46 |
+
motion = motion * unit
|
47 |
+
|
48 |
+
return motion
|
49 |
+
|
50 |
+
|
51 |
+
def rotate_motion_3d(motion3d, change_of_basis):
|
52 |
+
|
53 |
+
if change_of_basis is not None: motion3d = change_of_basis @ motion3d
|
54 |
+
|
55 |
+
return motion3d
|
56 |
+
|
57 |
+
|
58 |
+
def limb_scale_motion_2d(motion2d, global_range, local_range):
|
59 |
+
|
60 |
+
global_scale = global_range[0] + np.random.random() * (global_range[1] - global_range[0])
|
61 |
+
local_scales = local_range[0] + np.random.random([8]) * (local_range[1] - local_range[0])
|
62 |
+
motion_scale = scale_limbs(motion2d, global_scale, local_scales)
|
63 |
+
|
64 |
+
return motion_scale
|
65 |
+
|
66 |
+
|
67 |
+
def localize_motion(motion):
|
68 |
+
"""
|
69 |
+
Motion fed into our network is the local motion, i.e. coordinates relative to the hip joint.
|
70 |
+
This function removes global motion of the hip joint, and instead represents global motion with velocity
|
71 |
+
"""
|
72 |
+
|
73 |
+
D = motion.shape[1]
|
74 |
+
|
75 |
+
# subtract centers to local coordinates
|
76 |
+
centers = motion[8, :, :] # N_dim x T
|
77 |
+
motion = motion - centers
|
78 |
+
|
79 |
+
# adding velocity
|
80 |
+
translation = centers[:, 1:] - centers[:, :-1]
|
81 |
+
velocity = np.c_[np.zeros((D, 1)), translation]
|
82 |
+
velocity = velocity.reshape(1, D, -1)
|
83 |
+
motion = np.r_[motion[:8], motion[9:], velocity]
|
84 |
+
# motion_proj = np.r_[motion_proj[:8], motion_proj[9:]]
|
85 |
+
|
86 |
+
return motion
|
87 |
+
|
88 |
+
|
89 |
+
def globalize_motion(motion, start=None, velocity=None):
|
90 |
+
"""
|
91 |
+
inverse process of localize_motion
|
92 |
+
"""
|
93 |
+
|
94 |
+
if velocity is None: velocity = motion[-1].copy()
|
95 |
+
motion_inv = np.r_[motion[:8], np.zeros((1, 2, motion.shape[-1])), motion[8:-1]]
|
96 |
+
|
97 |
+
# restore centre position
|
98 |
+
centers = np.zeros_like(velocity)
|
99 |
+
sum = 0
|
100 |
+
for i in range(motion.shape[-1]):
|
101 |
+
sum += velocity[:, i]
|
102 |
+
centers[:, i] = sum
|
103 |
+
centers += start.reshape([2, 1])
|
104 |
+
|
105 |
+
return motion_inv + centers.reshape((1, 2, -1))
|
106 |
+
|
107 |
+
|
108 |
+
def normalize_motion(motion, meanpose, stdpose):
|
109 |
+
"""
|
110 |
+
:param motion: (J, 2, T)
|
111 |
+
:param meanpose: (J, 2)
|
112 |
+
:param stdpose: (J, 2)
|
113 |
+
:return:
|
114 |
+
"""
|
115 |
+
if motion.shape[1] == 2 and meanpose.shape[1] == 3:
|
116 |
+
meanpose = meanpose[:, [0, 2]]
|
117 |
+
if motion.shape[1] == 2 and stdpose.shape[1] == 3:
|
118 |
+
stdpose = stdpose[:, [0, 2]]
|
119 |
+
return (motion - meanpose[:, :, np.newaxis]) / stdpose[:, :, np.newaxis]
|
120 |
+
|
121 |
+
|
122 |
+
def normalize_motion_inv(motion, meanpose, stdpose):
|
123 |
+
if motion.shape[1] == 2 and meanpose.shape[1] == 3:
|
124 |
+
meanpose = meanpose[:, [0, 2]]
|
125 |
+
if motion.shape[1] == 2 and stdpose.shape[1] == 3:
|
126 |
+
stdpose = stdpose[:, [0, 2]]
|
127 |
+
return motion * stdpose[:, :, np.newaxis] + meanpose[:, :, np.newaxis]
|
128 |
+
|
129 |
+
|
130 |
+
def get_change_of_basis(motion3d, angles=None):
|
131 |
+
"""
|
132 |
+
Get the unit vectors of the local rectangular coordinate frame for the given 3D motion
|
133 |
+
:param motion3d: numpy array. 3D motion from 3D joints positions, shape (nr_joints, 3, nr_frames).
|
134 |
+
:param angles: tuple of length 3. Rotation angles around each axis.
|
135 |
+
:return: numpy array. unit vectors of the local rectangular coordinate frame, shape (3, 3).
|
136 |
+
"""
|
137 |
+
# 2 RightArm 5 LeftArm 9 RightUpLeg 12 LeftUpLeg
|
138 |
+
horizontal = (motion3d[2] - motion3d[5] + motion3d[9] - motion3d[12]) / 2
|
139 |
+
horizontal = np.mean(horizontal, axis=1)
|
140 |
+
horizontal = horizontal / np.linalg.norm(horizontal)
|
141 |
+
local_z = np.array([0, 0, 1])
|
142 |
+
local_y = np.cross(horizontal, local_z)  # known issue: horizontal and local_z may not be perpendicular
|
143 |
+
local_y = local_y / np.linalg.norm(local_y)
|
144 |
+
local_x = np.cross(local_y, local_z)
|
145 |
+
local = np.stack([local_x, local_y, local_z], axis=0)
|
146 |
+
|
147 |
+
if angles is not None:
|
148 |
+
local = rotate_basis(local, angles)
|
149 |
+
|
150 |
+
return local
|
151 |
+
|
152 |
+
|
153 |
+
def rotate_basis(local3d, angles):
|
154 |
+
"""
|
155 |
+
Rotate local rectangular coordinates from given view_angles.
|
156 |
+
|
157 |
+
:param local3d: numpy array. Unit vectors of the local rectangular coordinate frame, shape (3, 3).
|
158 |
+
:param angles: tuple of length 3. Rotation angles around each axis.
|
159 |
+
:return:
|
160 |
+
"""
|
161 |
+
cx, cy, cz = np.cos(angles)
|
162 |
+
sx, sy, sz = np.sin(angles)
|
163 |
+
|
164 |
+
x = local3d[0]
|
165 |
+
x_cpm = np.array([
|
166 |
+
[0, -x[2], x[1]],
|
167 |
+
[x[2], 0, -x[0]],
|
168 |
+
[-x[1], x[0], 0]
|
169 |
+
], dtype='float')
|
170 |
+
x = x.reshape(-1, 1)
|
171 |
+
mat33_x = cx * np.eye(3) + sx * x_cpm + (1.0 - cx) * np.matmul(x, x.T)
|
172 |
+
|
173 |
+
mat33_z = np.array([
|
174 |
+
[cz, sz, 0],
|
175 |
+
[-sz, cz, 0],
|
176 |
+
[0, 0, 1]
|
177 |
+
], dtype='float')
|
178 |
+
|
179 |
+
local3d = local3d @ mat33_x.T @ mat33_z
|
180 |
+
return local3d
|
181 |
+
|
182 |
+
|
183 |
+
def get_foot_vel(batch_motion, foot_idx):
|
184 |
+
return batch_motion[:, foot_idx, 1:] - batch_motion[:, foot_idx, :-1] + batch_motion[:, -2:, 1:].repeat(1, 2, 1)
|
185 |
+
|
186 |
+
|
187 |
+
def get_limbs(motion):
|
188 |
+
J, D, T = motion.shape
|
189 |
+
limbs = np.zeros([14, D, T])
|
190 |
+
limbs[0] = motion[0] - motion[1] # neck
|
191 |
+
limbs[1] = motion[2] - motion[1] # r_shoulder
|
192 |
+
limbs[2] = motion[3] - motion[2] # r_arm
|
193 |
+
limbs[3] = motion[4] - motion[3] # r_forearm
|
194 |
+
limbs[4] = motion[5] - motion[1] # l_shoulder
|
195 |
+
limbs[5] = motion[6] - motion[5] # l_arm
|
196 |
+
limbs[6] = motion[7] - motion[6] # l_forearm
|
197 |
+
limbs[7] = motion[1] - motion[8] # spine
|
198 |
+
limbs[8] = motion[9] - motion[8] # r_pelvis
|
199 |
+
limbs[9] = motion[10] - motion[9] # r_thigh
|
200 |
+
limbs[10] = motion[11] - motion[10] # r_shin
|
201 |
+
limbs[11] = motion[12] - motion[8] # l_pelvis
|
202 |
+
limbs[12] = motion[13] - motion[12] # l_thigh
|
203 |
+
limbs[13] = motion[14] - motion[13] # l_shin
|
204 |
+
return limbs
|
205 |
+
|
206 |
+
|
207 |
+
def scale_limbs(motion, global_scale, local_scales):
|
208 |
+
"""
|
209 |
+
:param motion: joint sequence [J, 2, T]
|
210 |
+
:param local_scales: 8 numbers of scales
|
211 |
+
:return: scaled joint sequence
|
212 |
+
"""
|
213 |
+
|
214 |
+
limb_dependents = [
|
215 |
+
[0],
|
216 |
+
[2, 3, 4],
|
217 |
+
[3, 4],
|
218 |
+
[4],
|
219 |
+
[5, 6, 7],
|
220 |
+
[6, 7],
|
221 |
+
[7],
|
222 |
+
[0, 1, 2, 3, 4, 5, 6, 7],
|
223 |
+
[9, 10, 11],
|
224 |
+
[10, 11],
|
225 |
+
[11],
|
226 |
+
[12, 13, 14],
|
227 |
+
[13, 14],
|
228 |
+
[14]
|
229 |
+
]
|
230 |
+
|
231 |
+
limbs = get_limbs(motion)
|
232 |
+
scaled_limbs = limbs.copy() * global_scale
|
233 |
+
scaled_limbs[0] *= local_scales[0]
|
234 |
+
scaled_limbs[1] *= local_scales[1]
|
235 |
+
scaled_limbs[2] *= local_scales[2]
|
236 |
+
scaled_limbs[3] *= local_scales[3]
|
237 |
+
scaled_limbs[4] *= local_scales[1]
|
238 |
+
scaled_limbs[5] *= local_scales[2]
|
239 |
+
scaled_limbs[6] *= local_scales[3]
|
240 |
+
scaled_limbs[7] *= local_scales[4]
|
241 |
+
scaled_limbs[8] *= local_scales[5]
|
242 |
+
scaled_limbs[9] *= local_scales[6]
|
243 |
+
scaled_limbs[10] *= local_scales[7]
|
244 |
+
scaled_limbs[11] *= local_scales[5]
|
245 |
+
scaled_limbs[12] *= local_scales[6]
|
246 |
+
scaled_limbs[13] *= local_scales[7]
|
247 |
+
|
248 |
+
delta = scaled_limbs - limbs
|
249 |
+
|
250 |
+
scaled_motion = motion.copy()
|
251 |
+
scaled_motion[limb_dependents[7]] += delta[7] # spine
|
252 |
+
scaled_motion[limb_dependents[1]] += delta[1] # r_shoulder
|
253 |
+
scaled_motion[limb_dependents[4]] += delta[4] # l_shoulder
|
254 |
+
scaled_motion[limb_dependents[2]] += delta[2] # r_arm
|
255 |
+
scaled_motion[limb_dependents[5]] += delta[5] # l_arm
|
256 |
+
scaled_motion[limb_dependents[3]] += delta[3] # r_forearm
|
257 |
+
scaled_motion[limb_dependents[6]] += delta[6] # l_forearm
|
258 |
+
scaled_motion[limb_dependents[0]] += delta[0] # neck
|
259 |
+
scaled_motion[limb_dependents[8]] += delta[8] # r_pelvis
|
260 |
+
scaled_motion[limb_dependents[11]] += delta[11] # l_pelvis
|
261 |
+
scaled_motion[limb_dependents[9]] += delta[9] # r_thigh
|
262 |
+
scaled_motion[limb_dependents[12]] += delta[12] # l_thigh
|
263 |
+
scaled_motion[limb_dependents[10]] += delta[10] # r_shin
|
264 |
+
scaled_motion[limb_dependents[13]] += delta[13] # l_shin
|
265 |
+
|
266 |
+
|
267 |
+
return scaled_motion
|
268 |
+
|
269 |
+
|
270 |
+
def get_limb_lengths(x):
|
271 |
+
_, dims, _ = x.shape
|
272 |
+
if dims == 2:
|
273 |
+
limbs = np.max(np.linalg.norm(get_limbs(x), axis=1), axis=-1)
|
274 |
+
limb_lengths = np.array([
|
275 |
+
limbs[0], # neck
|
276 |
+
max(limbs[1], limbs[4]), # shoulders
|
277 |
+
max(limbs[2], limbs[5]), # arms
|
278 |
+
max(limbs[3], limbs[6]), # forearms
|
279 |
+
limbs[7], # spine
|
280 |
+
max(limbs[8], limbs[11]), # pelvis
|
281 |
+
max(limbs[9], limbs[12]), # thighs
|
282 |
+
max(limbs[10], limbs[13]) # shins
|
283 |
+
])
|
284 |
+
else:
|
285 |
+
limbs = np.mean(np.linalg.norm(get_limbs(x), axis=1), axis=-1)
|
286 |
+
limb_lengths = np.array([
|
287 |
+
limbs[0], # neck
|
288 |
+
(limbs[1] + limbs[4]) / 2., # shoulders
|
289 |
+
(limbs[2] + limbs[5]) / 2., # arms
|
290 |
+
(limbs[3] + limbs[6]) / 2., # forearms
|
291 |
+
limbs[7], # spine
|
292 |
+
(limbs[8] + limbs[11]) / 2., # pelvis
|
293 |
+
(limbs[9] + limbs[12]) / 2., # thighs
|
294 |
+
(limbs[10] + limbs[13]) / 2. # shins
|
295 |
+
])
|
296 |
+
return limb_lengths
|
297 |
+
|
298 |
+
|
299 |
+
def limb_norm(x_a, x_b):
|
300 |
+
|
301 |
+
limb_lengths_a = get_limb_lengths(x_a)
|
302 |
+
limb_lengths_b = get_limb_lengths(x_b)
|
303 |
+
|
304 |
+
limb_lengths_a[limb_lengths_a < 1e-3] = 1e-3
|
305 |
+
local_scales = limb_lengths_b / limb_lengths_a
|
306 |
+
|
307 |
+
x_ab = scale_limbs(x_a, global_scale=1.0, local_scales=local_scales)
|
308 |
+
|
309 |
+
return x_ab
|
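limb_norm retargets the skeleton proportions: it measures per-limb lengths on both sequences and rescales x_a's limbs to match x_b's. A sketch on random stand-ins (with random data the measured limb lengths of the result should match those of x_b up to numerical tolerance):

import numpy as np
from lib.util.motion import limb_norm, get_limb_lengths

x_a = np.random.rand(15, 2, 30)                     # source skeleton sequence
x_b = np.random.rand(15, 2, 30)                     # target skeleton sequence
x_ab = limb_norm(x_a, x_b)                          # x_a's motion with x_b's limb proportions
print(np.allclose(get_limb_lengths(x_ab), get_limb_lengths(x_b)))   # expected: True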
lib/util/visualization.py
ADDED
@@ -0,0 +1,448 @@
|
import numpy as np
import os
import cv2
import math
import imageio
from tqdm import tqdm
from PIL import Image
from lib.util.motion import normalize_motion_inv, globalize_motion
from lib.util.general import ensure_dir
from threading import Thread, Lock


def interpolate_color(color1, color2, alpha):
    color_i = alpha * np.array(color1) + (1 - alpha) * np.array(color2)
    return color_i.tolist()


def two_pts_to_rectangle(point1, point2):
    X = [point1[1], point2[1]]
    Y = [point1[0], point2[0]]
    length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
    length = 5  # fixed half-width of the rectangle; the computed value above is overridden
    alpha = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
    beta = alpha - 90
    if beta <= -180:
        beta += 360
    p1 = (int(point1[0] - length * math.cos(math.radians(beta))), int(point1[1] - length * math.sin(math.radians(beta))))
    p2 = (int(point1[0] + length * math.cos(math.radians(beta))), int(point1[1] + length * math.sin(math.radians(beta))))
    p3 = (int(point2[0] + length * math.cos(math.radians(beta))), int(point2[1] + length * math.sin(math.radians(beta))))
    p4 = (int(point2[0] - length * math.cos(math.radians(beta))), int(point2[1] - length * math.sin(math.radians(beta))))
    return [p1, p2, p3, p4]


def rgb2rgba(color):
    return (color[0], color[1], color[2], 255)


def hex2rgb(hex, number_of_colors=3):
    h = hex
    rgb = []
    for i in range(number_of_colors):
        h = h.lstrip('#')
        hex_color = h[0:6]
        rgb_color = [int(hex_color[i:i + 2], 16) for i in (0, 2, 4)]
        rgb.append(rgb_color)
        h = h[6:]

    return rgb
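
# Quick sketch of how hex2rgb is meant to be called (the colour string is illustrative):
# it consumes `number_of_colors` concatenated "#rrggbb" groups and returns [R, G, B] triples.
#
#   >>> hex2rgb('#a50b69#b73b87#db9dc3')
#   [[165, 11, 105], [183, 59, 135], [219, 157, 195]]
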
def normalize_joints(joints_position, H=512, W=512):
    # find the minimum and maximum joint coordinates
    min_x, min_y = np.min(joints_position, axis=0)
    max_x, max_y = np.max(joints_position, axis=0)

    # extent of the joint coordinates
    range_x, range_y = max_x - min_x, max_y - min_y

    # keep a small margin so the scaled joints do not spill over the canvas border
    buffer = 0.05  # e.g. a 5% margin
    scale_x, scale_y = (1 - buffer) * W / range_x, (1 - buffer) * H / range_y

    # use the smaller scale so that all joints fit inside the canvas
    scale = min(scale_x, scale_y)

    # scale the joint coordinates
    joints_position_scaled = (joints_position - np.array([min_x, min_y])) * scale

    # new bounds of the scaled joint coordinates
    new_min_x, new_min_y = np.min(joints_position_scaled, axis=0)
    new_max_x, new_max_y = np.max(joints_position_scaled, axis=0)

    # translation that centres the joints on the canvas
    translate_x = (W - (new_max_x - new_min_x)) / 2 - new_min_x
    translate_y = (H - (new_max_y - new_min_y)) / 2 - new_min_y

    # translate the joint coordinates
    joints_position_normalized = joints_position_scaled + np.array([translate_x, translate_y])

    return joints_position_normalized
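
# Minimal sketch of what normalize_joints does (the values are illustrative): joints in an
# arbitrary coordinate range are rescaled and centred so they fit a W x H canvas with a
# 5% margin.
#
#   >>> pts = np.array([[0.0, 0.0], [1.0, 2.0], [2.0, 1.0]])
#   >>> out = normalize_joints(pts, H=512, W=512)
#   >>> out.min(axis=0), out.max(axis=0)   # roughly spans the canvas, centred
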
def joints2image(joints_position, colors, transparency=False, H=512, W=512, nr_joints=15, imtype=np.uint8, grayscale=False, bg_color=(255, 255, 255)):
    nr_joints = joints_position.shape[0]
    joints_position = normalize_joints(joints_position)
    if nr_joints == 49:  # full joints(49): basic(15) + eyes(2) + toes(2) + hands(30)
        limbSeq = [[0, 1], [1, 2], [1, 5], [1, 8], [2, 3], [3, 4], [5, 6], [6, 7],
                   [8, 9], [8, 13], [9, 10], [10, 11], [11, 12], [13, 14], [14, 15], [15, 16],
                   ]  # [0, 17], [0, 18] ignore eyes

        L = rgb2rgba(colors[0]) if transparency else colors[0]
        M = rgb2rgba(colors[1]) if transparency else colors[1]
        R = rgb2rgba(colors[2]) if transparency else colors[2]

        colors_joints = [M, M, L, L, L, R, R,
                         R, M, L, L, L, L, R, R, R,
                         R, R, L] + [L] * 15 + [R] * 15

        colors_limbs = [M, L, R, M, L, L, R,
                        R, L, R, L, L, L, R, R, R,
                        R, R]
    elif nr_joints == 15 or nr_joints == 17:  # basic joints(15) + (eyes(2))
        limbSeq = [[0, 1], [1, 2], [1, 5], [1, 8], [2, 3], [3, 4], [5, 6], [6, 7],
                   [8, 9], [8, 12], [9, 10], [10, 11], [12, 13], [13, 14]]
        # [0, 15], [0, 16] two eyes are not drawn

        L = rgb2rgba(colors[0]) if transparency else colors[0]
        M = rgb2rgba(colors[1]) if transparency else colors[1]
        R = rgb2rgba(colors[2]) if transparency else colors[2]

        colors_joints = [M, M, L, L, L, R, R,
                         R, M, L, L, L, R, R, R]

        colors_limbs = [M, L, R, M, L, L, R,
                        R, L, R, L, L, R, R]
    else:
        raise ValueError("Only 49, 17 or 15 joints are supported")

    if transparency:
        canvas = np.zeros(shape=(H, W, 4))
    else:
        canvas = np.ones(shape=(H, W, 3)) * np.array(bg_color).reshape([1, 1, 3])
    hips = joints_position[8]
    neck = joints_position[1]
    torso_length = ((hips[1] - neck[1]) ** 2 + (hips[0] - neck[0]) ** 2) ** 0.5

    head_radius = int(torso_length / 4.5)
    end_effectors_radius = int(torso_length / 15)
    end_effectors_radius = 7
    joints_radius = 7
    cv2.circle(canvas, (int(joints_position[0][0]), int(joints_position[0][1])), head_radius, colors_joints[0], thickness=-1)

    for i in range(1, len(colors_joints)):
        if i in (17, 18):
            continue
        elif i > 18:
            radius = 2
        else:
            radius = joints_radius
        cv2.circle(canvas, (int(joints_position[i][0]), int(joints_position[i][1])), radius, colors_joints[i], thickness=-1)

    stickwidth = 2

    for i in range(len(limbSeq)):
        limb = limbSeq[i]
        cur_canvas = canvas.copy()
        point1_index = limb[0]
        point2_index = limb[1]

        point1 = joints_position[point1_index]
        point2 = joints_position[point2_index]
        X = [point1[1], point2[1]]
        Y = [point1[0], point2[0]]
        mX = np.mean(X)
        mY = np.mean(Y)
        length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
        alpha = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))

        polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(alpha), 0, 360, 1)
        cv2.fillConvexPoly(cur_canvas, polygon, colors_limbs[i])
        canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
    bb = bounding_box(canvas)
    canvas_cropped = canvas[:, bb[2]:bb[3], :]

    canvas = canvas.astype(imtype)
    canvas_cropped = canvas_cropped.astype(imtype)

    if grayscale:
        if transparency:
            canvas = cv2.cvtColor(canvas, cv2.COLOR_RGBA2GRAY)
            canvas_cropped = cv2.cvtColor(canvas_cropped, cv2.COLOR_RGBA2GRAY)
        else:
            canvas = cv2.cvtColor(canvas, cv2.COLOR_RGB2GRAY)
            canvas_cropped = cv2.cvtColor(canvas_cropped, cv2.COLOR_RGB2GRAY)

    return [canvas, canvas_cropped]
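
# A minimal usage sketch for joints2image (shapes and colours are illustrative): it takes
# one frame of joints as an (nr_joints, 2) array in any coordinate range (normalize_joints
# rescales it to the canvas) and returns the full canvas plus a horizontally cropped copy.
#
#   >>> colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]   # left / middle / right
#   >>> frame = np.random.rand(15, 2)                      # one pose, 15 joints
#   >>> canvas, canvas_cropped = joints2image(frame, colors, H=512, W=512)
#   >>> save_image(canvas, 'pose.png')
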
def joints2image_highlight(joints_position, colors, highlights, transparency=False, H=512, W=512, nr_joints=15, imtype=np.uint8, grayscale=False):
    nr_joints = joints_position.shape[0]

    limbSeq = [[0, 1], [1, 2], [1, 5], [1, 8], [2, 3], [3, 4], [5, 6], [6, 7],
               [8, 9], [8, 12], [9, 10], [10, 11], [12, 13], [13, 14]]
    # [0, 15], [0, 16] two eyes are not drawn

    L = rgb2rgba(colors[0]) if transparency else colors[0]
    M = rgb2rgba(colors[1]) if transparency else colors[1]
    R = rgb2rgba(colors[2]) if transparency else colors[2]
    Hi = rgb2rgba(colors[3]) if transparency else colors[3]

    colors_joints = [M, M, L, L, L, R, R,
                     R, M, L, L, L, R, R, R]

    colors_limbs = [M, L, R, M, L, L, R,
                    R, L, R, L, L, R, R]

    for hi in highlights:
        colors_limbs[hi] = Hi

    if transparency:
        canvas = np.zeros(shape=(H, W, 4))
    else:
        canvas = np.ones(shape=(H, W, 3)) * 255
    hips = joints_position[8]
    neck = joints_position[1]
    torso_length = ((hips[1] - neck[1]) ** 2 + (hips[0] - neck[0]) ** 2) ** 0.5

    head_radius = int(torso_length / 4.5)
    end_effectors_radius = int(torso_length / 15)
    end_effectors_radius = 7
    joints_radius = 7

    cv2.circle(canvas, (int(joints_position[0][0] * 500), int(joints_position[0][1] * 500)), head_radius, colors_joints[0], thickness=-1)

    for i in range(1, len(colors_joints)):
        if i in (17, 18):
            continue
        elif i > 18:
            radius = 2
        else:
            radius = joints_radius
        cv2.circle(canvas, (int(joints_position[i][0] * 500), int(joints_position[i][1] * 500)), radius, colors_joints[i], thickness=-1)

    stickwidth = 2

    for i in range(len(limbSeq)):
        limb = limbSeq[i]
        cur_canvas = canvas.copy()
        point1_index = limb[0]
        point2_index = limb[1]

        point1 = joints_position[point1_index]
        point2 = joints_position[point2_index]
        X = [point1[1], point2[1]]
        Y = [point1[0], point2[0]]
        mX = np.mean(X)
        mY = np.mean(Y)
        length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
        alpha = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))

        polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(alpha), 0, 360, 1)
        cv2.fillConvexPoly(cur_canvas, polygon, colors_limbs[i])
        canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
    bb = bounding_box(canvas)
    canvas_cropped = canvas[:, bb[2]:bb[3], :]

    canvas = canvas.astype(imtype)
    canvas_cropped = canvas_cropped.astype(imtype)

    if grayscale:
        if transparency:
            canvas = cv2.cvtColor(canvas, cv2.COLOR_RGBA2GRAY)
            canvas_cropped = cv2.cvtColor(canvas_cropped, cv2.COLOR_RGBA2GRAY)
        else:
            canvas = cv2.cvtColor(canvas, cv2.COLOR_RGB2GRAY)
            canvas_cropped = cv2.cvtColor(canvas_cropped, cv2.COLOR_RGB2GRAY)

    return [canvas, canvas_cropped]


def motion2video(motion, h, w, save_path, colors, bg_color=(255, 255, 255), transparency=False, motion_tgt=None, fps=25, save_frame=True, grayscale=False, show_progress=True):
    nr_joints = motion.shape[0]
    as_array = save_path.endswith(".npy")
    vlen = motion.shape[-1]

    out_array = np.zeros([h, w, vlen]) if as_array else None
    videowriter = None if as_array else imageio.get_writer(save_path, fps=fps, codec='libx264')

    if save_frame:
        frames_dir = save_path[:-4] + '-frames'
        ensure_dir(frames_dir)

    iterator = range(vlen)
    if show_progress:
        iterator = tqdm(iterator)
    for i in iterator:
        [img, img_cropped] = joints2image(motion[:, :, i], colors, transparency=transparency, bg_color=bg_color, H=h, W=w, nr_joints=nr_joints, grayscale=grayscale)
        if motion_tgt is not None:
            [img_tgt, img_tgt_cropped] = joints2image(motion_tgt[:, :, i], colors, transparency=transparency, bg_color=bg_color, H=h, W=w, nr_joints=nr_joints, grayscale=grayscale)
            img_ori = img.copy()
            img = cv2.addWeighted(img_tgt, 0.3, img_ori, 0.7, 0)
            img_cropped = cv2.addWeighted(img_tgt, 0.3, img_ori, 0.7, 0)
            bb = bounding_box(img_cropped)
            img_cropped = img_cropped[:, bb[2]:bb[3], :]
        if save_frame:
            save_image(img_cropped, os.path.join(frames_dir, "%04d.png" % i))
        if as_array:
            out_array[:, :, i] = img
        else:
            videowriter.append_data(img)

    if as_array:
        np.save(save_path, out_array)
    else:
        videowriter.close()

    return out_array
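
# A minimal usage sketch for motion2video (paths and shapes are illustrative): it renders a
# (nr_joints, 2, n_frames) motion array to an .mp4 via imageio/ffmpeg, optionally overlaying
# a target motion, and also dumps per-frame PNGs when save_frame=True.
#
#   >>> motion = np.load('out/retarget.npy')                 # hypothetical (15, 2, T) array
#   >>> colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
#   >>> motion2video(motion, 512, 512, 'retarget.mp4', colors, fps=25, save_frame=False)
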
298 |
+
|
299 |
+
|
300 |
+
def motion2video_np(motion, h, w, colors, bg_color=(255, 255, 255), transparency=False, motion_tgt=None, show_progress=True, workers=6):
|
301 |
+
|
302 |
+
nr_joints = motion.shape[0]
|
303 |
+
vlen = motion.shape[-1]
|
304 |
+
out_array = np.zeros([vlen, h, w , 3])
|
305 |
+
|
306 |
+
queue = [i for i in range(vlen)]
|
307 |
+
lock = Lock()
|
308 |
+
pbar = tqdm(total=vlen) if show_progress else None
|
309 |
+
|
310 |
+
class Worker(Thread):
|
311 |
+
|
312 |
+
def __init__(self):
|
313 |
+
super(Worker, self).__init__()
|
314 |
+
|
315 |
+
def run(self):
|
316 |
+
while True:
|
317 |
+
lock.acquire()
|
318 |
+
if len(queue) == 0:
|
319 |
+
lock.release()
|
320 |
+
break
|
321 |
+
else:
|
322 |
+
i = queue.pop(0)
|
323 |
+
lock.release()
|
324 |
+
[img, img_cropped] = joints2image(motion[:, :, i], colors, transparency=transparency, bg_color=bg_color, H=h, W=w, nr_joints=nr_joints, grayscale=False)
|
325 |
+
if motion_tgt is not None:
|
326 |
+
[img_tgt, img_tgt_cropped] = joints2image(motion_tgt[:, :, i], colors, transparency=transparency, H=h, W=w, nr_joints=nr_joints, grayscale=False)
|
327 |
+
img_ori = img.copy()
|
328 |
+
img = cv2.addWeighted(img_tgt, 0.3, img_ori, 0.7, 0)
|
329 |
+
# img_cropped = cv2.addWeighted(img_tgt, 0.3, img_ori, 0.7, 0)
|
330 |
+
# bb = bounding_box(img_cropped)
|
331 |
+
# img_cropped = img_cropped[:, bb[2]:bb[3], :]
|
332 |
+
out_array[i, :, :] = img
|
333 |
+
if show_progress: pbar.update(1)
|
334 |
+
|
335 |
+
pool = [Worker() for _ in range(workers)]
|
336 |
+
for worker in pool: worker.start()
|
337 |
+
for worker in pool: worker.join()
|
338 |
+
for worker in pool: del worker
|
339 |
+
|
340 |
+
return out_array
|
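
# Sketch of the intended use (shapes are illustrative): unlike motion2video, this renders
# frames in memory with a small thread pool and returns the frames as an array instead of
# writing a file, e.g. for handing them to another writer downstream.
#
#   >>> frames = motion2video_np(motion, 512, 512, colors, workers=6)  # motion: (15, 2, T)
#   >>> frames.shape   # -> (T, 512, 512, 3)
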


def save_image(image_numpy, image_path):
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)


def bounding_box(img):
    a = np.where(img != 0)
    bbox = np.min(a[0]), np.max(a[0]), np.min(a[1]), np.max(a[1])
    return bbox


def pose2im_all(all_peaks, H=512, W=512):
    limbSeq = [[1, 2], [2, 3], [3, 4],       # right arm
               [1, 5], [5, 6], [6, 7],       # left arm
               [8, 9], [9, 10], [10, 11],    # right leg
               [8, 12], [12, 13], [13, 14],  # left leg
               [1, 0],                       # head/neck
               [1, 8],                       # body
               ]

    limb_colors = [[0, 60, 255], [0, 120, 255], [0, 180, 255],
                   [180, 255, 0], [120, 255, 0], [60, 255, 0],
                   [170, 255, 0], [85, 255, 0], [0, 255, 0],
                   [255, 170, 0], [255, 85, 0], [255, 0, 0],
                   [0, 85, 255],
                   [0, 0, 255],
                   ]

    joint_colors = [[85, 0, 255], [0, 0, 255], [0, 60, 255], [0, 120, 255], [0, 180, 255],
                    [180, 255, 0], [120, 255, 0], [60, 255, 0], [0, 0, 255],
                    [170, 255, 0], [85, 255, 0], [0, 255, 0],
                    [255, 170, 0], [255, 85, 0], [255, 0, 0],
                    ]

    image = pose2im(all_peaks, limbSeq, limb_colors, joint_colors, H, W)
    return image


def pose2im(all_peaks, limbSeq, limb_colors, joint_colors, H, W, _circle=True, _limb=True, imtype=np.uint8):
    canvas = np.zeros(shape=(H, W, 3))
    canvas.fill(255)

    if _circle:
        for i in range(len(joint_colors)):
            cv2.circle(canvas, (int(all_peaks[i][0]), int(all_peaks[i][1])), 2, joint_colors[i], thickness=2)

    if _limb:
        stickwidth = 2

        for i in range(len(limbSeq)):
            limb = limbSeq[i]
            cur_canvas = canvas.copy()
            point1_index = limb[0]
            point2_index = limb[1]

            if len(all_peaks[point1_index]) > 0 and len(all_peaks[point2_index]) > 0:
                point1 = all_peaks[point1_index][0:2]
                point2 = all_peaks[point2_index][0:2]
                X = [point1[1], point2[1]]
                Y = [point1[0], point2[0]]
                mX = np.mean(X)
                mY = np.mean(Y)
                length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
                angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
                polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
                cv2.fillConvexPoly(cur_canvas, polygon, limb_colors[i])
                canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)

    return canvas.astype(imtype)


def visualize_motion_in_training(outputs, mean_pose, std_pose, nr_visual=4, H=512, W=512):
    ret = {}
    for k, out in outputs.items():
        motion = out[0].detach().cpu().numpy()
        inds = np.linspace(0, motion.shape[1] - 1, nr_visual, dtype=int)
        motion = motion[:, inds]
        motion = motion.reshape(-1, 2, motion.shape[-1])
        motion = normalize_motion_inv(motion, mean_pose, std_pose)
        peaks = globalize_motion(motion)

        heatmaps = []
        for i in range(peaks.shape[2]):
            skeleton = pose2im_all(peaks[:, :, i], H, W)
            heatmaps.append(skeleton)
        heatmaps = np.stack(heatmaps).transpose((0, 3, 1, 2)) / 255.0
        ret[k] = heatmaps

    return ret
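
# Rough sketch of how visualize_motion_in_training is consumed during training (the names
# and shapes here are assumptions, not part of this commit): `outputs` maps a label to a
# batch of decoded motion tensors, and each entry comes back as a (nr_visual, 3, H, W)
# float array in [0, 1], ready for a TensorBoard image grid.
#
#   >>> vis = visualize_motion_in_training({"recon": recon_batch}, mean_pose, std_pose)
#   >>> vis["recon"].shape   # -> (nr_visual, 3, 512, 512)
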


if __name__ == '__main__':
    # load the .npy motion file
    motion_data = np.load('/home/fazhong/studio/transmomo.pytorch/out/retarget_1_121.npy')

    # video settings
    height = 512                                        # video height
    width = 512                                         # video width
    save_path = 'Angry.mp4'                             # output video path
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]    # joint colours
    bg_color = (255, 255, 255)                          # background colour
    fps = 25                                            # video frame rate

    # render the video
    motion2video(motion_data, height, width, save_path, colors, bg_color=bg_color, transparency=False, fps=fps)
requirements.txt
ADDED
@@ -0,0 +1,17 @@
gradio
fastapi
aiohttp
easydict
imageio-ffmpeg
matplotlib
numpy
Pillow
protobuf
PyYAML
scikit-image
scikit-learn
scipy
tensorboardX
torch>=1.2.0
torchvision
tqdm