Fazhong Liu committed on
Commit
7ca9b42
1 Parent(s): a4660a7
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ # *.7z filter=lfs diff=lfs merge=lfs -text
2
+ # *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ # *.bin filter=lfs diff=lfs merge=lfs -text
4
+ # *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ # *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ # *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ # *.gz filter=lfs diff=lfs merge=lfs -text
8
+ # *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ # *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ # *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ # *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ # *.model filter=lfs diff=lfs merge=lfs -text
13
+ # *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ # *.npy filter=lfs diff=lfs merge=lfs -text
15
+ # *.npz filter=lfs diff=lfs merge=lfs -text
16
+ # *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ # *.ot filter=lfs diff=lfs merge=lfs -text
18
+ # *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ # *.pb filter=lfs diff=lfs merge=lfs -text
20
+ # *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ # *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ # *.pt filter=lfs diff=lfs merge=lfs -text
23
+ # *.pth filter=lfs diff=lfs merge=lfs -text
24
+ # *.rar filter=lfs diff=lfs merge=lfs -text
25
+ # *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ # saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ # *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ # *.tar filter=lfs diff=lfs merge=lfs -text
29
+ # *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ # *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ # *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ # *.xz filter=lfs diff=lfs merge=lfs -text
33
+ # *.zip filter=lfs diff=lfs merge=lfs -text
34
+ # *.zst filter=lfs diff=lfs merge=lfs -text
35
+ # *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
1
+ *.pt
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Momo
3
- emoji: 📊
4
- colorFrom: gray
5
- colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.24.0
8
  app_file: app.py
 
1
  ---
2
+ title: Transmomo
3
+ emoji: 📈
4
+ colorFrom: indigo
5
+ colorTo: red
6
  sdk: gradio
7
  sdk_version: 4.24.0
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,114 @@
 
 
1
+ import time
2
+ import shutil
3
+ import gradio as gr
4
+ import os
5
+ import json
6
+ import torch
7
+ import argparse
8
+ import numpy as np
9
+ from lib.data import get_meanpose
10
+ from lib.network import get_autoencoder
11
+ from lib.util.motion import preprocess_mixamo, preprocess_test, postprocess
12
+ from lib.util.general import get_config
13
+ from lib.operation import rotate_and_maybe_project_world
14
+ from itertools import combinations
15
+ from lib.util.visualization import motion2video
16
+ from PIL import Image
17
+
18
+ def load_and_preprocess(path, config, mean_pose, std_pose):
19
+
20
+ motion3d = np.load(path)
21
+
22
+ # length must be multiples of 8 due to the size of convolution
23
+ _, _, T = motion3d.shape
24
+ T = (T // 8) * 8
25
+ motion3d = motion3d[:, :, :T]
26
+
27
+ # project to 2d
28
+ motion_proj = motion3d[:, [0, 2], :]
29
+
30
+ # reformat for mixamo data
31
+ motion_proj = preprocess_mixamo(motion_proj, unit=1.0)
32
+
33
+ # preprocess for network input
34
+ motion_proj, start = preprocess_test(motion_proj, mean_pose, std_pose, config.data.unit)
35
+ motion_proj = motion_proj.reshape((-1, motion_proj.shape[-1]))
36
+ motion_proj = torch.from_numpy(motion_proj).float()
37
+
38
+ return motion_proj, start
39
+
40
+ def handle_motion_generation(npy1, npy2):
41
+ path1 = './data/a.npy'
42
+ path2 = './data/b.npy'
43
+ shutil.copy(npy1, path1)  # gr.File passes a file path; copy the uploaded .npy into place
44
+ shutil.copy(npy2, path2)  # likewise for the second upload
45
+ config_path = './configs/transmomo.yaml' # path to the config file
46
+ description_path = "./data/mse_description.json"
47
+ checkpoint_path = './data/autoencoder_00200000.pt'
48
+ out_dir_path = './output' # path to the output directory
49
+
50
+ config = get_config(config_path)
51
+ ae = get_autoencoder(config)
52
+ ae.load_state_dict(torch.load(checkpoint_path))
53
+ ae.cuda()
54
+ ae.eval()
55
+ mean_pose, std_pose = get_meanpose("test", config.data)
56
+ # print("loaded model")
57
+
58
+ description = json.load(open(description_path))
59
+ chars = list(description.keys())
60
+
61
+ os.makedirs(out_dir_path, exist_ok=True)
62
+
63
+ # path1 = '/home/fazhong/studio/transmomo.pytorch/data/mixamo/36_800_24/test/PUMPKINHULK_L/Back_Squat/motions/2.npy'
64
+ # path2 = '/home/fazhong/studio/transmomo.pytorch/data/mixamo/36_800_24/test/PUMPKINHULK_L/Golf_Post_Shot/motions/3.npy'
65
+ out_path1 = os.path.join(out_dir_path, "adv.npy")
66
+
67
+
68
+ x_a, x_a_start = load_and_preprocess(path1, config, mean_pose, std_pose)
69
+ x_b, x_b_start = load_and_preprocess(path2, config, mean_pose, std_pose)
70
+
71
+ x_a_batch = x_a.unsqueeze(0).cuda()
72
+ x_b_batch = x_b.unsqueeze(0).cuda()
73
+
74
+ x_ab = ae.cross2d(x_a_batch, x_b_batch, x_a_batch)
75
+ x_ab = postprocess(x_ab, mean_pose, std_pose, config.data.unit, start=x_a_start)
76
+
77
+ np.save(out_path1, x_ab)
78
+ motion_data = x_ab
79
+ height = 512 # video height
80
+ width = 512 # video width
81
+ save_path = './an.mp4' # path where the video is saved
82
+ colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] # joint colors
83
+ bg_color = (255, 255, 255) # background color
84
+ fps = 25 # video frame rate
85
+
86
+ # print(motion_data.shape)
87
+ # call the helper to render the video
88
+ motion2video(motion_data, height, width, save_path, colors, bg_color=bg_color, transparency=False, fps=fps)
89
+ first_frame_image = Image.open('./an-frames/0000.png')
90
+ return first_frame_image
91
+ # print('hi')
92
+
93
+ with gr.Blocks() as demo:
94
+ gr.Markdown("Upload two `.npy` files to generate motion and visualize the first frame of the output animation.")
95
+
96
+ with gr.Row():
97
+ file1 = gr.File(file_types=[".npy"], label="Upload first .npy file")
98
+ file2 = gr.File(file_types=[".npy"], label="Upload second .npy file")
99
+
100
+ with gr.Row():
101
+ generate_btn = gr.Button("Generate Motion")
102
+
103
+ output_image = gr.Image(label="First Frame of the Generated Animation")
104
+
105
+ generate_btn.click(
106
+ fn=handle_motion_generation,
107
+ inputs=[file1, file2],
108
+ outputs=output_image
109
+ )
110
+
111
+
112
+ if __name__ == "__main__":
113
+ # tsp_page.launch(debug = True)
114
+ demo.launch()
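Note: a minimal sketch of the input the demo expects, inferred from load_and_preprocess above: a .npy array of shape (n_joints, 3, n_frames), whose frame count is trimmed to a multiple of 8 before encoding. The joint count of 15 follows n_joints in configs/transmomo.yaml; the file name is illustrative.

```python
# Sketch: create a dummy motion clip in the layout load_and_preprocess expects.
# Assumes 15 joints (n_joints in configs/transmomo.yaml); the values are random,
# so the output is meaningless -- this only illustrates the expected array shape.
import numpy as np

n_joints, n_frames = 15, 64          # frame count is cut to a multiple of 8 internally
motion = np.random.randn(n_joints, 3, n_frames).astype(np.float32)
np.save("example_motion.npy", motion)  # upload this file as either .npy input in the Gradio UI
```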
configs/transmomo.yaml ADDED
@@ -0,0 +1,100 @@
 
 
1
+ trainer: TransmomoTrainer
2
+ K: 3
3
+ rotation_axes: &rotation_axes [0, 0, 1] # horizontal, depth, vertical
4
+ body_reference: &body_reference True # if set True, will use spine as vertical axis
5
+
6
+ # model options
7
+ n_joints: 15 # number of body joints
8
+ seq_len: 64 # length of motion sequence
9
+
10
+ # logger options
11
+ snapshot_save_iter: 20000
12
+ log_iter: 40
13
+ val_iter: 400
14
+ val_batches: 10
15
+
16
+ # optimization options
17
+ max_iter: 200000 # maximum number of training iterations
18
+ batch_size: 64 # batch size
19
+ weight_decay: 0.0001 # weight decay
20
+ beta1: 0.5 # Adam parameter
21
+ beta2: 0.999 # Adam parameter
22
+ init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal]
23
+ lr: 0.0002 # initial learning rate
24
+ lr_policy: step # learning rate scheduler
25
+ step_size: 20000 # how often to decay learning rate
26
+ gamma: 0.5 # how much to decay learning rate
27
+
28
+ trans_gan_w: 2 # weight of GAN loss
29
+ trans_gan_ls_w: 0 # if set > 0, will treat limb-scaled data as "real" data
30
+ recon_x_w: 10 # weight of reconstruction loss
31
+ cross_x_w: 4 # weight of cross reconstruction loss
32
+ inv_v_ls_w: 2 # weight of view invariance loss against limb scale
33
+ inv_m_ls_w: 2 # weight of motion invariance loss against limb scale
34
+ inv_b_trans_w: 2 # weight of body invariance loss against rotation
35
+ inv_m_trans_w: 2 # weight of motion invariance loss against rotation
36
+
37
+ triplet_b_w: 10 # weight of body triplet loss
38
+ triplet_v_w: 10 # weight of view triplet loss
39
+ triplet_margin: 0.2 # triplet loss: margin
40
+ triplet_neg_range: [0.0, 0.5] # triplet loss: range of negative examples
41
+
42
+ # network options
43
+ autoencoder:
44
+ cls: Autoencoder3f
45
+ body_reference: *body_reference
46
+ motion_encoder:
47
+ cls: ConvEncoder
48
+ channels: [30, 64, 128, 128]
49
+ padding: 3
50
+ kernel_size: 8
51
+ conv_stride: 2
52
+ conv_pool: null
53
+ body_encoder:
54
+ cls: ConvEncoder
55
+ channels: [28, 64, 128, 256]
56
+ padding: 2
57
+ kernel_size: 7
58
+ conv_stride: 1
59
+ conv_pool: AvgPool1d
60
+ global_pool: avg_pool1d
61
+ view_encoder:
62
+ cls: ConvEncoder
63
+ channels: [28, 64, 32, 8]
64
+ padding: 2
65
+ kernel_size: 7
66
+ conv_stride: 1
67
+ conv_pool: MaxPool1d
68
+ global_pool: max_pool1d
69
+ decoder:
70
+ channels: [392, 256, 128, 45]
71
+ kernel_size: 7
72
+
73
+ discriminator:
74
+ encoder_cls: ConvEncoder
75
+ gan_type: lsgan
76
+ channels: [30, 64, 96, 128]
77
+ padding: 3
78
+ kernel_size: 8
79
+ conv_stride: 2
80
+ conv_pool: null
81
+
82
+ body_discriminator:
83
+ gan_type: lsgan
84
+ channels: [512, 128, 32]
85
+
86
+ # data options
87
+ data:
88
+ train_cls: MixamoLimbScaleDataset
89
+ eval_cls: MixamoDataset
90
+ global_range: [0.5, 2.0] # limb scale: range of gamma_g
91
+ local_range: [0.5, 2.0] # limb scale: range of the gammas
92
+ rotation_axes: *rotation_axes
93
+ unit: 128
94
+ # train_dir: ./data/mixamo/36_800_24/train
95
+ # test_dir: ./data/mixamo/36_800_24/test
96
+ num_workers: 4
97
+ train_meanpose_path: ./data/meanpose_with_view.npy
98
+ train_stdpose_path: ./data/stdpose_with_view.npy
99
+ test_meanpose_path: ./data/meanpose_with_view.npy
100
+ test_stdpose_path: ./data/stdpose_with_view.npy
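A small sketch of how this YAML resolves when loaded. get_config comes from lib.util.general, which is not shown in this commit; it is assumed here to behave roughly like yaml.safe_load wrapped in an attribute-style dict, so the snippet is for inspection only, not the project's actual loader.

```python
# Sketch: inspect the config roughly the way the code consumes it.
# Assumption: lib.util.general.get_config is approximately safe_load + EasyDict.
import yaml
from easydict import EasyDict as edict

with open("configs/transmomo.yaml") as f:
    cfg = edict(yaml.safe_load(f))

print(cfg.autoencoder.motion_encoder.channels)    # [30, 64, 128, 128]
print(cfg.data.unit)                              # 128
# The *rotation_axes alias repeats the anchored value:
print(cfg.rotation_axes, cfg.data.rotation_axes)  # [0, 0, 1] [0, 0, 1]
```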
configs/transmomo_solo_dance.yaml ADDED
@@ -0,0 +1,100 @@
 
 
1
+ trainer: TransmomoTrainer
2
+ K: 3
3
+ rotation_axes: &rotation_axes [0, 0, 1] # horizontal, depth, vertical
4
+ body_reference: &body_reference True # if set True, will use spine as vertical axis
5
+
6
+ # model options
7
+ n_joints: 15 # number of body joints
8
+ seq_len: 64 # length of motion sequence
9
+
10
+ # logger options
11
+ snapshot_save_iter: 20000
12
+ log_iter: 40
13
+ val_iter: 400
14
+ val_batches: 10
15
+
16
+ # optimization options
17
+ max_iter: 200000 # maximum number of training iterations
18
+ batch_size: 64 # batch size
19
+ weight_decay: 0.0001 # weight decay
20
+ beta1: 0.5 # Adam parameter
21
+ beta2: 0.999 # Adam parameter
22
+ init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal]
23
+ lr: 0.0002 # initial learning rate
24
+ lr_policy: step # learning rate scheduler
25
+ step_size: 20000 # how often to decay learning rate
26
+ gamma: 0.5 # how much to decay learning rate
27
+
28
+ trans_gan_w: 2 # weight of GAN loss
29
+ trans_gan_ls_w: 0 # if set > 0, will treat limb-scaled data as "real" data
30
+ recon_x_w: 10 # weight of reconstruction loss
31
+ cross_x_w: 4 # weight of cross reconstruction loss
32
+ inv_v_ls_w: 2 # weight of view invariance loss against limb scale
33
+ inv_m_ls_w: 2 # weight of motion invariance loss against limb scale
34
+ inv_b_trans_w: 2 # weight of body invariance loss against rotation
35
+ inv_m_trans_w: 2 # weight of motion invariance loss against rotation
36
+
37
+ triplet_b_w: 10 # weight of body triplet loss
38
+ triplet_v_w: 10 # weight of view triplet loss
39
+ triplet_margin: 0.2 # triplet loss: margin
40
+ triplet_neg_range: [0.0, 0.5] # triplet loss: range of negative examples
41
+
42
+ # network options
43
+ autoencoder:
44
+ cls: Autoencoder3f
45
+ body_reference: *body_reference
46
+ motion_encoder:
47
+ cls: ConvEncoder
48
+ channels: [30, 64, 128, 128]
49
+ padding: 3
50
+ kernel_size: 8
51
+ conv_stride: 2
52
+ conv_pool: null
53
+ body_encoder:
54
+ cls: ConvEncoder
55
+ channels: [28, 64, 128, 256]
56
+ padding: 2
57
+ kernel_size: 7
58
+ conv_stride: 1
59
+ conv_pool: AvgPool1d
60
+ global_pool: avg_pool1d
61
+ view_encoder:
62
+ cls: ConvEncoder
63
+ channels: [28, 64, 32, 8]
64
+ padding: 2
65
+ kernel_size: 7
66
+ conv_stride: 1
67
+ conv_pool: MaxPool1d
68
+ global_pool: max_pool1d
69
+ decoder:
70
+ channels: [392, 256, 128, 45]
71
+ kernel_size: 7
72
+
73
+ discriminator:
74
+ encoder_cls: ConvEncoder
75
+ gan_type: lsgan
76
+ channels: [30, 64, 96, 128]
77
+ padding: 3
78
+ kernel_size: 8
79
+ conv_stride: 2
80
+ conv_pool: null
81
+
82
+ body_discriminator:
83
+ gan_type: lsgan
84
+ channels: [512, 128, 32]
85
+
86
+ # data options
87
+ data:
88
+ train_cls: SoloDanceDataset
89
+ eval_cls: MixamoDataset
90
+ global_range: [0.5, 2.0] # limb scale: range of gamma_g
91
+ local_range: [0.5, 2.0] # limb scale: range of the gammas
92
+ rotation_axes: *rotation_axes
93
+ unit: 128
94
+ train_dir: ./data/solo_dance/train
95
+ test_dir: ./data/mixamo/36_800_24/test
96
+ num_workers: 4
97
+ train_meanpose_path: ./data/mixamo/36_800_24/meanpose_with_view.npy
98
+ train_stdpose_path: ./data/mixamo/36_800_24/stdpose_with_view.npy
99
+ test_meanpose_path: ./data/mixamo/36_800_24/meanpose_with_view.npy
100
+ test_stdpose_path: ./data/mixamo/36_800_24/stdpose_with_view.npy
data/meanpose_with_view.npy ADDED
Binary file (488 Bytes). View file
 
data/mse_description.json ADDED
@@ -0,0 +1 @@
1
+ {"ANDROMEDA": ["Goalkeeper_Directing_(1)", "Standing_Aim_Idle_02_Looking", "Pilot_Flips_Switches_(1)", "Running_Tired", "Slide_Hip_Hop_Dance", "Military_Signaling_(3)", "Aim_Pistol", "Standing_Idle_(1)", "Drinking_Fountain", "Golf_Putt_Failure", "Standing_Torch_Burn_Webs", "Back_Squat", "Baseball_Pitching", "Dancing_Twerk", "Superhuman_Choke_Lift", "Zombie_Crawl"], "PUMPKINHULK_L": ["Front_Raises", "Rifle_Idle", "Defender", "Talking_Phone_Pacing", "Sitting_Clap_(2)", "Standing_Clap", "Samba_Dancing_(6)", "Golf_Bad_Shot_(1)", "Golf_Putt_Victory", "Grab_Rifle_And_Put_Back", "Salsa_Dancing_(4)", "Military_Signaling_(2)", "Robot_Hip_Hop_Dance", "Speedbag", "Jog_In_Circle", "Looking_Around"], "SPORTY_GRANY": ["Standing_Torch_Idle_04", "Look_Around_(1)", "Happy", "Standing_Torch_Inspect_Downward", "Quarterback_Pass", "Zombie_Stand_Up_(2)", "Struck_In_Head", "Shooting_Gun", "Zombie_Transition", "Hostage_Situation_Idle_-_Hostage", "Jazz_Dancing_(2)", "Samba_Dancing_(5)", "Knocked_Out_(1)", "Being_Electrocuted", "Falling_From_Losing_Balance", "Pulling_A_Rope"], "TY": ["Golf_Pre-Putt_(1)", "Back_Flip_To_Uppercut", "Standing_Torch_Idle_03", "Tonic_Seizure", "Talking_At_Watercooler", "Sitting_Clap", "Shuffling", "Tender_Placement_(1)", "Helping_Out", "Northern_Soul_Spin_Combo", "Golf_Post_Shot", "Salsa_Dancing_(3)", "Sword_And_Shield_Idle_(1)", "Hip_Hop_Dancing_(1)", "Golf_Tee_Up_(1)"]}
data/stdpose_with_view.npy ADDED
Binary file (488 Bytes). View file
 
lib/__init__.py ADDED
File without changes
lib/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (145 Bytes). View file
 
lib/__pycache__/data.cpython-38.pyc ADDED
Binary file (12.6 kB). View file
 
lib/__pycache__/network.cpython-38.pyc ADDED
Binary file (9.88 kB). View file
 
lib/__pycache__/operation.cpython-38.pyc ADDED
Binary file (6.16 kB). View file
 
lib/data.py ADDED
@@ -0,0 +1,421 @@
 
 
1
+ import sys, os
2
+ thismodule = sys.modules[__name__]
3
+
4
+ from lib.util.motion import preprocess_mixamo, rotate_motion_3d, limb_scale_motion_2d, normalize_motion, get_change_of_basis, localize_motion, scale_limbs
5
+
6
+ import torch
7
+ import glob
8
+ import numpy as np
9
+ import random
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from easydict import EasyDict as edict
12
+ from tqdm import tqdm
13
+
14
+ view_angles = np.array([ i * np.pi / 6 for i in range(-3, 4)])
15
+
16
+ def get_dataloader(phase, config):
17
+
18
+ config.data.batch_size = config.batch_size
19
+ config.data.seq_len = config.seq_len
20
+ dataset_cls_name = config.data.train_cls if phase == 'train' else config.data.eval_cls
21
+ dataset_cls = getattr(thismodule, dataset_cls_name)
22
+ dataset = dataset_cls(phase, config.data)
23
+
24
+ dataloader = DataLoader(dataset, shuffle=(phase=='train'),
25
+ batch_size=config.batch_size,
26
+ num_workers=(config.data.num_workers if phase == 'train' else 1),
27
+ worker_init_fn=lambda _: np.random.seed(),
28
+ drop_last=True)
29
+
30
+ return dataloader
31
+
32
+
33
+ class _MixamoDatasetBase(Dataset):
34
+ def __init__(self, phase, config):
35
+ super(_MixamoDatasetBase, self).__init__()
36
+
37
+ assert phase in ['train', 'test']
38
+ self.phase = phase
39
+ self.data_root = config.train_dir if phase=='train' else config.test_dir
40
+ self.meanpose_path = config.train_meanpose_path if phase=='train' else config.test_meanpose_path
41
+ self.stdpose_path = config.train_stdpose_path if phase=='train' else config.test_stdpose_path
42
+ self.unit = config.unit
43
+ self.aug = (phase == 'train')
44
+ self.character_names = sorted(os.listdir(self.data_root))
45
+
46
+ items = glob.glob(os.path.join(self.data_root, self.character_names[0], '*/motions/*.npy'))
47
+ self.motion_names = ['/'.join(x.split('/')[-3:]) for x in items]
48
+
49
+ self.meanpose, self.stdpose = get_meanpose(phase, config)
50
+ self.meanpose = self.meanpose.astype(np.float32)
51
+ self.stdpose = self.stdpose.astype(np.float32)
52
+
53
+ if 'preload' in config and config.preload:
54
+ self.preload()
55
+ self.cached = True
56
+ else:
57
+ self.cached = False
58
+
59
+ def build_item(self, mot_name, char_name):
60
+ """
61
+ :param mot_name: animation_name/motions/xxx.npy
62
+ :param char_name: character_name
63
+ :return:
64
+ """
65
+ return os.path.join(self.data_root, char_name, mot_name)
66
+
67
+ def load_item(self, item):
68
+ if self.cached:
69
+ data = self.cache[item]
70
+ else:
71
+ data = np.load(item)
72
+ return data
73
+
74
+ def preload(self):
75
+ print("pre-loading into memory")
76
+ pbar = tqdm(total=len(self))
77
+ self.cache = {}
78
+ for motion_name in self.motion_names:
79
+ for character_name in self.character_names:
80
+ item = self.build_item(motion_name, character_name)
81
+ motion3d = np.load(item)
82
+ self.cache[item] = motion3d
83
+ pbar.update(1)
84
+
85
+ @staticmethod
86
+ def gen_aug_params(rotate=False):
87
+ if rotate:
88
+ params = {'ratio': np.random.uniform(0.8, 1.2),
89
+ 'roll': np.random.uniform((-np.pi / 9, -np.pi / 9, -np.pi / 6), (np.pi / 9, np.pi / 9, np.pi / 6))}
90
+ else:
91
+ params = {'ratio': np.random.uniform(0.5, 1.5)}
92
+ return edict(params)
93
+
94
+ @staticmethod
95
+ def augmentation(data, params=None):
96
+ """
97
+ :param data: numpy array of size (joints, 3, len_frames)
98
+ :return:
99
+ """
100
+ if params is None:
101
+ return data, params
102
+
103
+ # rotate
104
+ if 'roll' in params.keys():
105
+ cx, cy, cz = np.cos(params.roll)
106
+ sx, sy, sz = np.sin(params.roll)
107
+ mat33_x = np.array([
108
+ [1, 0, 0],
109
+ [0, cx, -sx],
110
+ [0, sx, cx]
111
+ ], dtype='float')
112
+ mat33_y = np.array([
113
+ [cy, 0, sy],
114
+ [0, 1, 0],
115
+ [-sy, 0, cy]
116
+ ], dtype='float')
117
+ mat33_z = np.array([
118
+ [cz, -sz, 0],
119
+ [sz, cz, 0],
120
+ [0, 0, 1]
121
+ ], dtype='float')
122
+ data = mat33_x @ mat33_y @ mat33_z @ data
123
+
124
+ # scale
125
+ if 'ratio' in params.keys():
126
+ data = data * params.ratio
127
+
128
+ return data, params
129
+
130
+ def __getitem__(self, index):
131
+ raise NotImplementedError
132
+
133
+ def __len__(self):
134
+ return len(self.motion_names) * len(self.character_names)
135
+
136
+
137
+ def get_meanpose(phase, config):
138
+
139
+ meanpose_path = config.train_meanpose_path if phase == "train" else config.test_meanpose_path
140
+ stdpose_path = config.train_stdpose_path if phase == "train" else config.test_stdpose_path
141
+
142
+ if os.path.exists(meanpose_path) and os.path.exists(stdpose_path):
143
+ meanpose = np.load(meanpose_path)
144
+ stdpose = np.load(stdpose_path)
145
+ else:
146
+ meanpose, stdpose = gen_meanpose(phase, config)
147
+ np.save(meanpose_path, meanpose)
148
+ np.save(stdpose_path, stdpose)
149
+ print("meanpose saved at {}".format(meanpose_path))
150
+ print("stdpose saved at {}".format(stdpose_path))
151
+
152
+ if meanpose.shape[-1] == 2:
153
+ mean_x, mean_y = meanpose[:, 0], meanpose[:, 1]
154
+ meanpose = np.stack([mean_x, mean_x, mean_y], axis=1)
155
+
156
+ if stdpose.shape[-1] == 2:
157
+ std_x, std_y = stdpose[:, 0], stdpose[:, 1]
158
+ stdpose = np.stack([std_x, std_x, std_y], axis=1)
159
+
160
+ return meanpose, stdpose
161
+
162
+
163
+ def gen_meanpose(phase, config, n_samp=20000):
164
+
165
+ data_dir = config.train_dir if phase == "train" else config.test_dir
166
+ all_paths = glob.glob(os.path.join(data_dir, '*/*/motions/*.npy'))
167
+ random.shuffle(all_paths)
168
+ all_paths = all_paths[:n_samp]
169
+ all_joints = []
170
+
171
+ print("computing meanpose and stdpose")
172
+
173
+ for path in tqdm(all_paths):
174
+ motion = np.load(path)
175
+ if motion.shape[1] == 3:
176
+ basis = None
177
+ if sum(config.rotation_axes) > 0:
178
+ x_angles = view_angles if config.rotation_axes[0] else np.array([0])
179
+ z_angles = view_angles if config.rotation_axes[1] else np.array([0])
180
+ y_angles = view_angles if config.rotation_axes[2] else np.array([0])
181
+ x_angles, z_angles, y_angles = np.meshgrid(x_angles, z_angles, y_angles)
182
+ angles = np.stack([x_angles.flatten(), z_angles.flatten(), y_angles.flatten()], axis=1)
183
+ i = np.random.choice(len(angles))
184
+ basis = get_change_of_basis(motion, angles[i])
185
+ motion = preprocess_mixamo(motion)
186
+ motion = rotate_motion_3d(motion, basis)
187
+ motion = localize_motion(motion)
188
+ all_joints.append(motion)
189
+ else:
190
+ motion = preprocess_mixamo(motion)
191
+ motion = rotate_motion_3d(motion, basis)
192
+ motion = localize_motion(motion)
193
+ all_joints.append(motion)
194
+ else:
195
+ motion = motion * 128
196
+ motion_proj = localize_motion(motion)
197
+ all_joints.append(motion_proj)
198
+
199
+ all_joints = np.concatenate(all_joints, axis=2)
200
+
201
+ meanpose = np.mean(all_joints, axis=2)
202
+ stdpose = np.std(all_joints, axis=2)
203
+ stdpose[np.where(stdpose == 0)] = 1e-9
204
+
205
+ return meanpose, stdpose
206
+
207
+
208
+ class MixamoDataset(_MixamoDatasetBase):
209
+
210
+ def __init__(self, phase, config):
211
+ super(MixamoDataset, self).__init__(phase, config)
212
+ x_angles = view_angles if config.rotation_axes[0] else np.array([0])
213
+ z_angles = view_angles if config.rotation_axes[1] else np.array([0])
214
+ y_angles = view_angles if config.rotation_axes[2] else np.array([0])
215
+ x_angles, z_angles, y_angles = np.meshgrid(x_angles, z_angles, y_angles)
216
+ angles = np.stack([x_angles.flatten(), z_angles.flatten(), y_angles.flatten()], axis=1)
217
+ self.view_angles = angles
218
+
219
+ def preprocessing(self, motion3d, view_angle=None, params=None):
220
+ """
221
+ :param motion3d: 3D motion array loaded from an item built by self.build_item
222
+ :return:
223
+ """
224
+
225
+ if self.aug: motion3d, params = self.augmentation(motion3d, params)
226
+
227
+ basis = None
228
+ if view_angle is not None: basis = get_change_of_basis(motion3d, view_angle)
229
+
230
+ motion3d = preprocess_mixamo(motion3d)
231
+ motion3d = rotate_motion_3d(motion3d, basis)
232
+ motion3d = localize_motion(motion3d)
233
+ motion3d = normalize_motion(motion3d, self.meanpose, self.stdpose)
234
+
235
+ motion2d = motion3d[:, [0, 2], :]
236
+
237
+ motion3d = motion3d.reshape([-1, motion3d.shape[-1]])
238
+ motion2d = motion2d.reshape([-1, motion2d.shape[-1]])
239
+
240
+ motion3d = torch.from_numpy(motion3d).float()
241
+ motion2d = torch.from_numpy(motion2d).float()
242
+
243
+ return motion3d, motion2d
244
+
245
+ def __getitem__(self, index):
246
+ # select two motions
247
+ idx_a, idx_b = np.random.choice(len(self.motion_names), size=2, replace=False)
248
+ mot_a, mot_b = self.motion_names[idx_a], self.motion_names[idx_b]
249
+ # select two characters
250
+ idx_a, idx_b = np.random.choice(len(self.character_names), size=2, replace=False)
251
+ char_a, char_b = self.character_names[idx_a], self.character_names[idx_b]
252
+ idx_a, idx_b = np.random.choice(len(self.view_angles), size=2, replace=False)
253
+ view_a, view_b = self.view_angles[idx_a], self.view_angles[idx_b]
254
+
255
+ if self.aug:
256
+ param_a = self.gen_aug_params(rotate=False)
257
+ param_b = self.gen_aug_params(rotate=False)
258
+ else:
259
+ param_a = param_b = None
260
+
261
+ item_a = self.load_item(self.build_item(mot_a, char_a))
262
+ item_b = self.load_item(self.build_item(mot_b, char_b))
263
+ item_ab = self.load_item(self.build_item(mot_a, char_b))
264
+ item_ba = self.load_item(self.build_item(mot_b, char_a))
265
+
266
+ X_a, x_a = self.preprocessing(item_a, view_a, param_a)
267
+ X_b, x_b = self.preprocessing(item_b, view_b, param_b)
268
+
269
+ X_aab, x_aab = self.preprocessing(item_a, view_b, param_a)
270
+ X_bba, x_bba = self.preprocessing(item_b, view_a, param_b)
271
+ X_aba, x_aba = self.preprocessing(item_ab, view_a, param_b)
272
+ X_bab, x_bab = self.preprocessing(item_ba, view_b, param_a)
273
+ X_abb, x_abb = self.preprocessing(item_ab, view_b, param_b)
274
+ X_baa, x_baa = self.preprocessing(item_ba, view_a, param_a)
275
+
276
+ return {"X_a": X_a, "X_b": X_b,
277
+ "X_aab": X_aab, "X_bba": X_bba,
278
+ "X_aba": X_aba, "X_bab": X_bab,
279
+ "X_abb": X_abb, "X_baa": X_baa,
280
+ "x_a": x_a, "x_b": x_b,
281
+ "x_aab": x_aab, "x_bba": x_bba,
282
+ "x_aba": x_aba, "x_bab": x_bab,
283
+ "x_abb": x_abb, "x_baa": x_baa,
284
+ "mot_a": mot_a, "mot_b": mot_b,
285
+ "char_a": char_a, "char_b": char_b,
286
+ "view_a": view_a, "view_b": view_b,
287
+ "meanpose": self.meanpose, "stdpose": self.stdpose}
288
+
289
+
290
+ class MixamoLimbScaleDataset(_MixamoDatasetBase):
291
+
292
+ def __init__(self, phase, config):
293
+ super(MixamoLimbScaleDataset, self).__init__(phase, config)
294
+ self.global_range = config.global_range
295
+ self.local_range = config.local_range
296
+
297
+ x_angles = view_angles if config.rotation_axes[0] else np.array([0])
298
+ z_angles = view_angles if config.rotation_axes[1] else np.array([0])
299
+ y_angles = view_angles if config.rotation_axes[2] else np.array([0])
300
+ x_angles, z_angles, y_angles = np.meshgrid(x_angles, z_angles, y_angles)
301
+ angles = np.stack([x_angles.flatten(), z_angles.flatten(), y_angles.flatten()], axis=1)
302
+ self.view_angles = angles
303
+
304
+ def preprocessing(self, motion3d, view_angle=None, params=None):
305
+ if self.aug: motion3d, params = self.augmentation(motion3d, params)
306
+
307
+ basis = None
308
+ if view_angle is not None: basis = get_change_of_basis(motion3d, view_angle)
309
+
310
+ motion3d = preprocess_mixamo(motion3d)
311
+ motion3d = rotate_motion_3d(motion3d, basis)
312
+ motion2d = motion3d[:, [0, 2], :]
313
+ motion2d_scale = limb_scale_motion_2d(motion2d, self.global_range, self.local_range)
314
+
315
+ motion2d = localize_motion(motion2d)
316
+ motion2d_scale = localize_motion(motion2d_scale)
317
+
318
+ motion2d = normalize_motion(motion2d, self.meanpose, self.stdpose)
319
+ motion2d_scale = normalize_motion(motion2d_scale, self.meanpose, self.stdpose)
320
+
321
+ motion2d = motion2d.reshape([-1, motion2d.shape[-1]])
322
+ motion2d_scale = motion2d_scale.reshape((-1, motion2d_scale.shape[-1]))
323
+ motion2d = torch.from_numpy(motion2d).float()
324
+ motion2d_scale = torch.from_numpy(motion2d_scale).float()
325
+
326
+ return motion2d, motion2d_scale
327
+
328
+ def __getitem__(self, index):
329
+ # select two motions
330
+ motion_idx = np.random.choice(len(self.motion_names))
331
+ motion = self.motion_names[motion_idx]
332
+ # select two characters
333
+ char_idx = np.random.choice(len(self.character_names))
334
+ character = self.character_names[char_idx]
335
+ view_idx = np.random.choice(len(self.view_angles))
336
+ view = self.view_angles[view_idx]
337
+
338
+ if self.aug:
339
+ param = self.gen_aug_params(rotate=True)
340
+ else:
341
+ param = None
342
+
343
+ item = self.build_item(motion, character)
344
+
345
+ x, x_s = self.preprocessing(self.load_item(item), view, param)
346
+
347
+ return {"x": x, "x_s": x_s, "mot": motion, "char": character, "view": view,
348
+ "meanpose": self.meanpose, "stdpose": self.stdpose}
349
+
350
+
351
+ class SoloDanceDataset(Dataset):
352
+
353
+ def __init__(self, phase, config):
354
+ super(SoloDanceDataset, self).__init__()
355
+ self.global_range = config.global_range
356
+ self.local_range = config.local_range
357
+
358
+ assert phase in ['train', 'test']
359
+ self.data_root = config.train_dir if phase=='train' else config.test_dir
360
+ self.phase = phase
361
+ self.unit = config.unit
362
+ self.meanpose_path = config.train_meanpose_path if phase == 'train' else config.test_meanpose_path
363
+ self.stdpose_path = config.train_stdpose_path if phase == 'train' else config.test_stdpose_path
364
+ self.character_names = sorted(os.listdir(self.data_root))
365
+
366
+ self.items = glob.glob(os.path.join(self.data_root, '*/*/motions/*.npy'))
367
+ self.meanpose, self.stdpose = get_meanpose(phase, config)
368
+ self.meanpose = self.meanpose.astype(np.float32)
369
+ self.stdpose = self.stdpose.astype(np.float32)
370
+
371
+ if 'preload' in config and config.preload:
372
+ self.preload()
373
+ self.cached = True
374
+ else:
375
+ self.cached = False
376
+
377
+ def load_item(self, item):
378
+ if self.cached:
379
+ data = self.cache[item]
380
+ else:
381
+ data = np.load(item)
382
+ return data
383
+
384
+ def preload(self):
385
+ print("pre-loading into memory")
386
+ pbar = tqdm(total=len(self))
387
+ self.cache = {}
388
+ for item in self.items:
389
+ motion = np.load(item)
390
+ self.cache[item] = motion
391
+ pbar.update(1)
392
+
393
+ def preprocessing(self, motion):
394
+
395
+ motion = motion * self.unit
396
+
397
+ motion[1, :, :] = (motion[2, :, :] + motion[5, :, :]) / 2
398
+ motion[8, :, :] = (motion[9, :, :] + motion[12, :, :]) / 2
399
+
400
+ global_scale = self.global_range[0] + np.random.random() * (self.global_range[1] - self.global_range[0])
401
+ local_scales = self.local_range[0] + np.random.random([8]) * (self.local_range[1] - self.local_range[0])
402
+ motion_scale = scale_limbs(motion, global_scale, local_scales)
403
+
404
+ motion = localize_motion(motion)
405
+ motion_scale = localize_motion(motion_scale)
406
+ motion = normalize_motion(motion, self.meanpose, self.stdpose)
407
+ motion_scale = normalize_motion(motion_scale, self.meanpose, self.stdpose)
408
+ motion = motion.reshape((-1, motion.shape[-1]))
409
+ motion_scale = motion_scale.reshape((-1, motion_scale.shape[-1]))
410
+ motion = torch.from_numpy(motion).float()
411
+ motion_scale = torch.from_numpy(motion_scale).float()
412
+ return motion, motion_scale
413
+
414
+ def __len__(self):
415
+ return len(self.items)
416
+
417
+ def __getitem__(self, index):
418
+ item = self.items[index]
419
+ motion = self.load_item(item)
420
+ x, x_s = self.preprocessing(motion)
421
+ return {"x": x, "x_s": x_s, "meanpose": self.meanpose, "stdpose": self.stdpose}
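A short illustration of the 2D-to-3D expansion in get_meanpose above: when a stored meanpose/stdpose has only (x, y) columns, the x statistics are reused for the depth axis so the array matches the (horizontal, depth, vertical) layout used elsewhere. A sketch with made-up numbers.

```python
# Sketch: how get_meanpose widens a (J, 2) meanpose to (J, 3).
# The x column is duplicated for the depth axis; y becomes the vertical axis.
import numpy as np

meanpose_2d = np.array([[1.0, 5.0],
                        [2.0, 6.0]])            # (J=2, 2) -> columns (x, y)
mean_x, mean_y = meanpose_2d[:, 0], meanpose_2d[:, 1]
meanpose_3d = np.stack([mean_x, mean_x, mean_y], axis=1)
print(meanpose_3d)   # [[1. 1. 5.]
                     #  [2. 2. 6.]]  -> (horizontal, depth, vertical)
```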
lib/loss.py ADDED
@@ -0,0 +1,57 @@
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def kl_loss(code):
6
+ return torch.mean(torch.pow(code, 2))
7
+
8
+
9
+ def pairwise_cosine_similarity(seqs_i, seqs_j):
10
+ # seqs_i, seqs_j: [batch, statics, channel]
11
+ n_statics = seqs_i.size(1)
12
+ seqs_i_exp = seqs_i.unsqueeze(2).repeat(1, 1, n_statics, 1)
13
+ seqs_j_exp = seqs_j.unsqueeze(1).repeat(1, n_statics, 1, 1)
14
+ return F.cosine_similarity(seqs_i_exp, seqs_j_exp, dim=3)
15
+
16
+
17
+ def temporal_pairwise_cosine_similarity(seqs_i, seqs_j):
18
+ # seqs_i, seqs_j: [batch, channel, time]
19
+ seq_len = seqs_i.size(2)
20
+ seqs_i_exp = seqs_i.unsqueeze(3).repeat(1, 1, 1, seq_len)
21
+ seqs_j_exp = seqs_j.unsqueeze(2).repeat(1, 1, seq_len, 1)
22
+ return F.cosine_similarity(seqs_i_exp, seqs_j_exp, dim=1)
23
+
24
+
25
+ def consecutive_cosine_similarity(seqs):
26
+ # seqs: [batch, channel, time]
27
+ seqs_roll = seqs.roll(shifts=1, dims=2)[1:]
28
+ seqs = seqs[:-1]
29
+ return F.cosine_similarity(seqs, seqs_roll)
30
+
31
+
32
+ def triplet_margin_loss(seqs_a, seqs_b, neg_range=(0.0, 0.5), margin=0.2):
33
+ # seqs_a, seqs_b: [batch, channel, time]
34
+
35
+ neg_start, neg_end = neg_range
36
+ batch_size, _, seq_len = seqs_a.size()
37
+ n_neg_all = seq_len ** 2
38
+ n_neg = int(round(neg_end * n_neg_all))
39
+ n_neg_discard = int(round(neg_start * n_neg_all))
40
+
41
+ batch_size, _, seq_len = seqs_a.size()
42
+ sim_aa = temporal_pairwise_cosine_similarity(seqs_a, seqs_a)
43
+ sim_bb = temporal_pairwise_cosine_similarity(seqs_b, seqs_b)
44
+ sim_ab = temporal_pairwise_cosine_similarity(seqs_a, seqs_b)
45
+ sim_ba = sim_ab.transpose(1, 2)
46
+
47
+ diff_ab = (sim_ab - sim_aa).reshape(batch_size, -1)
48
+ diff_ba = (sim_ba - sim_bb).reshape(batch_size, -1)
49
+ diff = torch.cat([diff_ab, diff_ba], dim=0)
50
+ diff, _ = diff.topk(n_neg, dim=-1, sorted=True)
51
+ diff = diff[:, n_neg_discard:]
52
+
53
+ loss = diff + margin
54
+ loss = loss.clamp(min=0.)
55
+ loss = loss.mean()
56
+
57
+ return loss
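A quick usage sketch for triplet_margin_loss with the [batch, channel, time] layout documented above; the tensors are random placeholders (assuming the repository root is on PYTHONPATH), so the value itself is arbitrary.

```python
# Sketch: triplet_margin_loss expects [batch, channel, time] code sequences.
# With seq_len=16 and neg_range=(0.0, 0.5), the top 128 of the 256 pairwise
# similarity gaps per sample are penalized against the margin.
import torch
from lib.loss import triplet_margin_loss

seqs_a = torch.randn(4, 128, 16)   # e.g. body code sequences of clip A
seqs_b = torch.randn(4, 128, 16)   # corresponding sequences of clip B
loss = triplet_margin_loss(seqs_a, seqs_b, neg_range=(0.0, 0.5), margin=0.2)
print(loss.shape)                  # torch.Size([]) -- a scalar
```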
lib/network.py ADDED
@@ -0,0 +1,356 @@
 
 
1
+ import sys
2
+ thismodule = sys.modules[__name__]
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import random
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ torch.manual_seed(123)
11
+
12
+
13
+ def get_autoencoder(config):
14
+ ae_cls = getattr(thismodule, config.autoencoder.cls)
15
+ return ae_cls(config.autoencoder)
16
+
17
+
18
+ class ConvEncoder(nn.Module):
19
+
20
+ @classmethod
21
+ def build_from_config(cls, config):
22
+ conv_pool = None if config.conv_pool is None else getattr(nn, config.conv_pool)
23
+ encoder = cls(config.channels, config.padding, config.kernel_size, config.conv_stride, conv_pool)
24
+ return encoder
25
+
26
+ def __init__(self, channels, padding=3, kernel_size=8, conv_stride=2, conv_pool=None):
27
+ super(ConvEncoder, self).__init__()
28
+
29
+ self.in_channels = channels[0]
30
+
31
+ model = []
32
+ acti = nn.LeakyReLU(0.2)
33
+
34
+ nr_layer = len(channels) - 1
35
+
36
+ for i in range(nr_layer):
37
+ if conv_pool is None:
38
+ model.append(nn.ReflectionPad1d(padding))
39
+ model.append(nn.Conv1d(channels[i], channels[i+1], kernel_size=kernel_size, stride=conv_stride))
40
+ model.append(acti)
41
+ else:
42
+ model.append(nn.ReflectionPad1d(padding))
43
+ model.append(nn.Conv1d(channels[i], channels[i+1], kernel_size=kernel_size, stride=conv_stride))
44
+ model.append(acti)
45
+ model.append(conv_pool(kernel_size=2, stride=2))
46
+
47
+ self.model = nn.Sequential(*model)
48
+
49
+ def forward(self, x):
50
+ x = x[:, :self.in_channels, :]
51
+ x = self.model(x)
52
+ return x
53
+
54
+
55
+ class ConvDecoder(nn.Module):
56
+
57
+ @classmethod
58
+ def build_from_config(cls, config):
59
+ decoder = cls(config.channels, config.kernel_size)
60
+ return decoder
61
+
62
+ def __init__(self, channels, kernel_size=7):
63
+ super(ConvDecoder, self).__init__()
64
+
65
+ model = []
66
+ pad = (kernel_size - 1) // 2
67
+ acti = nn.LeakyReLU(0.2)
68
+
69
+ for i in range(len(channels) - 1):
70
+ model.append(nn.Upsample(scale_factor=2, mode='nearest'))
71
+ model.append(nn.ReflectionPad1d(pad))
72
+ model.append(nn.Conv1d(channels[i], channels[i + 1],
73
+ kernel_size=kernel_size, stride=1))
74
+ if i == 0 or i == 1:
75
+ model.append(nn.Dropout(p=0.2))
76
+ if not i == len(channels) - 2:
77
+ model.append(acti) # whether to add a tanh at the end?
78
+ #model.append(nn.Dropout(p=0.2))
79
+
80
+ self.model = nn.Sequential(*model)
81
+
82
+ def forward(self, x):
83
+ return self.model(x)
84
+
85
+
86
+ class Discriminator(nn.Module):
87
+
88
+ def __init__(self, config):
89
+ super(Discriminator, self).__init__()
90
+ self.gan_type = config.gan_type
91
+ encoder_cls = getattr(thismodule, config.encoder_cls)
92
+ self.encoder = encoder_cls.build_from_config(config)
93
+ self.linear = nn.Linear(config.channels[-1], 1)
94
+
95
+ def forward(self, seqs):
96
+
97
+ code_seq = self.encoder(seqs)
98
+ logits = self.linear(code_seq.permute(0, 2, 1))
99
+ return logits
100
+
101
+ def calc_dis_loss(self, x_gen, x_real):
102
+
103
+ fake_logits = self.forward(x_gen)
104
+ real_logits = self.forward(x_real)
105
+
106
+ if self.gan_type == 'lsgan':
107
+ loss = torch.mean((fake_logits - 0) ** 2) + torch.mean((real_logits - 1) ** 2)
108
+ elif self.gan_type == 'nsgan':
109
+ all0 = torch.zeros_like(fake_logits, requires_grad=False)
110
+ all1 = torch.ones_like(real_logits, requires_grad=False)
111
+ loss = torch.mean(F.binary_cross_entropy(F.sigmoid(fake_logits), all0) +
112
+ F.binary_cross_entropy(F.sigmoid(real_logits), all1))
113
+ else:
114
+ raise NotImplementedError
115
+
116
+ return loss
117
+
118
+ def calc_gen_loss(self, x_gen):
119
+
120
+ logits = self.forward(x_gen)
121
+ if self.gan_type == 'lsgan':
122
+ loss = torch.mean((logits - 1) ** 2)
123
+ elif self.gan_type == 'nsgan':
124
+ all1 = torch.ones_like(logits, requires_grad=False)
125
+ loss = torch.mean(F.binary_cross_entropy(F.sigmoid(logits), all1))
126
+ else:
127
+ raise NotImplementedError
128
+
129
+ return loss
130
+
131
+
132
+ class Autoencoder3f(nn.Module):
133
+
134
+ def __init__(self, config):
135
+ super(Autoencoder3f, self).__init__()
136
+
137
+ assert config.motion_encoder.channels[-1] + config.body_encoder.channels[-1] + \
138
+ config.view_encoder.channels[-1] == config.decoder.channels[0]
139
+
140
+ self.n_joints = config.decoder.channels[-1] // 3
141
+ self.body_reference = config.body_reference
142
+
143
+ motion_cls = getattr(thismodule, config.motion_encoder.cls)
144
+ body_cls = getattr(thismodule, config.body_encoder.cls)
145
+ view_cls = getattr(thismodule, config.view_encoder.cls)
146
+
147
+ self.motion_encoder = motion_cls.build_from_config(config.motion_encoder)
148
+ self.body_encoder = body_cls.build_from_config(config.body_encoder)
149
+ self.view_encoder = view_cls.build_from_config(config.view_encoder)
150
+ self.decoder = ConvDecoder.build_from_config(config.decoder)
151
+
152
+ self.body_pool = getattr(F, config.body_encoder.global_pool) if config.body_encoder.global_pool is not None else None
153
+ self.view_pool = getattr(F, config.view_encoder.global_pool) if config.view_encoder.global_pool is not None else None
154
+
155
+ def forward(self, seqs):
156
+ return self.reconstruct(seqs)
157
+
158
+ def encode_motion(self, seqs):
159
+ motion_code_seq = self.motion_encoder(seqs)
160
+ return motion_code_seq
161
+
162
+ def encode_body(self, seqs):
163
+ body_code_seq = self.body_encoder(seqs)
164
+ kernel_size = body_code_seq.size(-1)
165
+ body_code = self.body_pool(body_code_seq, kernel_size) if self.body_pool is not None else body_code_seq
166
+ return body_code, body_code_seq
167
+
168
+ def encode_view(self, seqs):
169
+ view_code_seq = self.view_encoder(seqs)
170
+ kernel_size = view_code_seq.size(-1)
171
+ view_code = self.view_pool(view_code_seq, kernel_size) if self.view_pool is not None else view_code_seq
172
+ return view_code, view_code_seq
173
+
174
+ def decode(self, motion_code, body_code, view_code):
175
+ if body_code.size(-1) == 1:
176
+ body_code = body_code.repeat(1, 1, motion_code.shape[-1])
177
+ if view_code.size(-1) == 1:
178
+ view_code = view_code.repeat(1, 1, motion_code.shape[-1])
179
+ complete_code = torch.cat([motion_code, body_code, view_code], dim=1)
180
+ out = self.decoder(complete_code)
181
+ return out
182
+
183
+ def cross3d(self, x_a, x_b, x_c):
184
+ motion_a = self.encode_motion(x_a)
185
+ body_b, _ = self.encode_body(x_b)
186
+ view_c, _ = self.encode_view(x_c)
187
+ out = self.decode(motion_a, body_b, view_c)
188
+ return out
189
+
190
+ def cross2d(self, x_a, x_b, x_c):
191
+ motion_a = self.encode_motion(x_a)
192
+ body_b, _ = self.encode_body(x_b)
193
+ view_c, _ = self.encode_view(x_c)
194
+ out = self.decode(motion_a, body_b, view_c)
195
+ batch_size, channels, seq_len = out.size()
196
+ n_joints = channels // 3
197
+ out = out.view(batch_size, n_joints, 3, seq_len)
198
+ out = out[:, :, [0, 2], :]
199
+ out = out.view(batch_size, n_joints * 2, seq_len)
200
+ return out
201
+
202
+ def cross2d_adv(self, x_a, x_b, x_c):
203
+ x_a.cpu()
204
+ x_a_shape = x_a.shape
205
+ print(x_a.shape)
206
+ #motion_a_org = self.encode_motion(x_a)
207
+ print(x_a)
208
+
209
+
210
+ # The heatmap image is saved as 'tensor_heatmap.png' in the current directory
211
+
212
+ # for i in range(0,119):
213
+ # x_a[0][11][i]+=1
214
+
215
+ #x_a[0][7][60]+=0.01
216
+
217
+ #motion_a = self.encode_motion(x_a)
218
+ # print(motion_a.shape)
219
+ # print(motion_a[0][0]-motion_a_org[0][0])
220
+ # res = motion_a[0] - motion_a_org[0]
221
+ # res = res.cpu().detach().numpy()
222
+ # # Code for plotting the heatmap
223
+ # plt.figure(figsize=(15, 10))
224
+ # plt.imshow(res, cmap='hot', interpolation='nearest')
225
+ # plt.colorbar()
226
+ # plt.title('Heatmap of the Tensor')
227
+
228
+ # # Save the heatmap to a local file
229
+ # plt.savefig('/home/fazhong/studio/transmomo.pytorch/tensor_heatmap2.png')
230
+ # plt.close()
231
+
232
+ initial_motion_a = self.encode_motion(x_a) # compute the initial motion code
233
+
234
+ # helper to measure how much the motion code changes
235
+ def motion_change(motion_a, initial_motion_a):
236
+ return (motion_a - initial_motion_a).norm()
237
+
238
+ # initialize the maximum change to zero
239
+ max_change = 0
240
+
241
+ # number of random perturbations to try; adjust as needed
242
+ num_perturbations = 10000
243
+ init_a = x_a.clone()
244
+ for _ in range(num_perturbations):
245
+ # clone x_a so the original data is not modified
246
+ x_a_perturbed = x_a.clone().cpu()
247
+
248
+ # pick a random element to perturb
249
+ batch_idx, seq_idx, feature_idx = (torch.randint(0, x_a.size(0), (1,)),
250
+ torch.randint(0, x_a.size(1), (1,)),
251
+ torch.randint(0, x_a.size(2), (1,)))
252
+
253
+ # add noise at the selected element
254
+ x_a_perturbed[batch_idx, seq_idx, feature_idx] += 10 * torch.randn(1)
255
+
256
+ # motion code of the perturbed input
257
+ perturbed_motion_a = self.encode_motion(x_a_perturbed.to('cuda:0'))
258
+
259
+ # measure the change
260
+ change = motion_change(perturbed_motion_a, initial_motion_a)
261
+
262
+ # keep this perturbation if it causes the largest change so far
263
+ if change > max_change:
264
+ x_a = x_a_perturbed
265
+ max_change = change
266
+
267
+ # in the end, x_a is the perturbed version that changed motion_a the most
268
+ # max_change holds that change
269
+ # print(max_change)
270
+ # print(max_change.shape)
271
+
272
+ print(x_a_perturbed - init_a.cpu())
273
+ motion_a = self.encode_motion(x_a_perturbed.to('cuda:0'))
274
+ # motion_a = self.encode_motion(x_a.to('cuda:0'))
275
+ body_b, _ = self.encode_body(x_b)
276
+ view_c, _ = self.encode_view(x_c)
277
+
278
+ out = self.decode(motion_a, body_b, view_c)
279
+ batch_size, channels, seq_len = out.size()
280
+ n_joints = channels // 3
281
+ out = out.view(batch_size, n_joints, 3, seq_len)
282
+ out = out[:, :, [0, 2], :]
283
+ out = out.view(batch_size, n_joints * 2, seq_len)
284
+ return out
285
+
286
+ def cross2d_one(self, x_a):
287
+ motion_a = self.encode_motion(x_a)
288
+ body_b, _ = self.encode_body(x_a)
289
+ view_c, _ = self.encode_view(x_a)
290
+
291
+ out = self.decode(motion_a, body_b, view_c)
292
+ batch_size, channels, seq_len = out.size()
293
+ n_joints = channels // 3
294
+ out = out.view(batch_size, n_joints, 3, seq_len)
295
+ out = out[:, :, [0, 2], :]
296
+ out = out.view(batch_size, n_joints * 2, seq_len)
297
+ return out
298
+
299
+ def adv_cross(self,x_a):
300
+ motion_a = self.encode_motion(x_a)
301
+ body_b, _ = self.encode_body(x_a)
302
+ view_c, _ = self.encode_view(x_a)
303
+ return motion_a
304
+
305
+ def reconstruct3d(self, x):
306
+ motion_code = self.encode_motion(x)
307
+ body_code, _ = self.encode_body(x)
308
+ view_code, _ = self.encode_view(x)
309
+ out = self.decode(motion_code, body_code, view_code)
310
+ return out
311
+
312
+ def reconstruct2d(self, x):
313
+ motion_code = self.encode_motion(x)
314
+ body_code, _ = self.encode_body(x)
315
+ view_code, _ = self.encode_view(x)
316
+ out = self.decode(motion_code, body_code, view_code)
317
+ batch_size, channels, seq_len = out.size()
318
+ n_joints = channels // 3
319
+ out = out.view(batch_size, n_joints, 3, seq_len)
320
+ out = out[:, :, [0, 2], :]
321
+ out = out.view(batch_size, n_joints * 2, seq_len)
322
+ return out
323
+
324
+ def interpolate(self, x_a, x_b, N):
325
+
326
+ step_size = 1. / (N-1)
327
+ batch_size, _, seq_len = x_a.size()
328
+
329
+ motion_a = self.encode_motion(x_a)
330
+ body_a, body_a_seq = self.encode_body(x_a)
331
+ view_a, view_a_seq = self.encode_view(x_a)
332
+
333
+ motion_b = self.encode_motion(x_b)
334
+ body_b, body_b_seq = self.encode_body(x_b)
335
+ view_b, view_b_seq = self.encode_view(x_b)
336
+
337
+ batch_out = torch.zeros([batch_size, N, N, 2 * self.n_joints, seq_len])
338
+
339
+ for i in range(N):
340
+ motion_weight = i * step_size
341
+ for j in range(N):
342
+ body_weight = j * step_size
343
+ motion = (1. - motion_weight) * motion_a + motion_weight * motion_b
344
+ body = (1. - body_weight) * body_a + body_weight * body_b
345
+ view = (1. - body_weight) * view_a + body_weight * view_b
346
+ out = self.decode(motion, body, view)
347
+ batch_size, channels, seq_len = out.size()
348
+ n_joints = channels // 3
349
+ out = out.view(batch_size, n_joints, 3, seq_len)
350
+ out = out[:, :, [0, 2], :]
351
+ out = out.view(batch_size, n_joints * 2, seq_len)
352
+ batch_out[:, i, j, :, :] = out
353
+
354
+ return batch_out
355
+
356
+
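A small shape check for the motion encoder defined by the config above: three stride-2 convolutions reduce the temporal length by a factor of 8, which is why app.py trims clips to a multiple of 8 frames. A sketch assuming the repository root is importable; the hyperparameters mirror configs/transmomo.yaml.

```python
# Sketch: the motion encoder halves the time axis three times (stride 2, no pooling),
# so a 64-frame clip becomes an 8-step motion code -- matching the multiple-of-8
# trimming done in app.py. Channels/padding/kernel follow configs/transmomo.yaml.
import torch
from lib.network import ConvEncoder

motion_encoder = ConvEncoder(channels=[30, 64, 128, 128], padding=3,
                             kernel_size=8, conv_stride=2, conv_pool=None)
x = torch.randn(1, 30, 64)          # [batch, 2 * n_joints, seq_len]
print(motion_encoder(x).shape)      # torch.Size([1, 128, 8])
```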
lib/operation.py ADDED
@@ -0,0 +1,219 @@
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ import imageio
5
+ from math import pi
6
+ from tqdm import tqdm
7
+ from lib.data import get_dataloader, get_meanpose
8
+ from lib.util.general import get_config
9
+ from lib.util.visualization import motion2video_np, hex2rgb
10
+ import os
11
+
12
+ eps = 1e-16
13
+
14
+
15
+ def localize_motion_torch(motion):
16
+ """
17
+ :param motion: B x J x D x T
18
+ :return:
19
+ """
20
+ B, J, D, T = motion.size()
21
+
22
+ # subtract centers to local coordinates
23
+ centers = motion[:, 8:9, :, :] # B x 1 x D x (T-1)
24
+ motion = motion - centers
25
+
26
+ # adding velocity
27
+ translation = centers[:, :, :, 1:] - centers[:, :, :, :-1] # B x 1 x D x (T-1)
28
+ velocity = F.pad(translation, [1, 0], "constant", 0.) # B x 1 x D x T
29
+ motion = torch.cat([motion[:, :8], motion[:, 9:], velocity], dim=1)
30
+
31
+ return motion
32
+
33
+
34
+ def normalize_motion_torch(motion, meanpose, stdpose):
35
+ """
36
+ :param motion: (B, J, D, T)
37
+ :param meanpose: (J, D)
38
+ :param stdpose: (J, D)
39
+ :return:
40
+ """
41
+ B, J, D, T = motion.size()
42
+ if D == 2 and meanpose.size(1) == 3:
43
+ meanpose = meanpose[:, [0, 2]]
44
+ if D == 2 and stdpose.size(1) == 3:
45
+ stdpose = stdpose[:, [0, 2]]
46
+ return (motion - meanpose.view(1, J, D, 1)) / stdpose.view(1, J, D, 1)
47
+
48
+
49
+ def normalize_motion_inv_torch(motion, meanpose, stdpose):
50
+ """
51
+ :param motion: (B, J, D, T)
52
+ :param meanpose: (J, D)
53
+ :param stdpose: (J, D)
54
+ :return:
55
+ """
56
+ B, J, D, T = motion.size()
57
+ if D == 2 and meanpose.size(1) == 3:
58
+ meanpose = meanpose[:, [0, 2]]
59
+ if D == 2 and stdpose.size(1) == 3:
60
+ stdpose = stdpose[:, [0, 2]]
61
+ return motion * stdpose.view(1, J, D, 1) + meanpose.view(1, J, D, 1)
62
+
63
+
64
+ def globalize_motion_torch(motion):
65
+ """
66
+ :param motion: B x J x D x T
67
+ :return:
68
+ """
69
+ B, J, D, T = motion.size()
70
+
71
+ motion_inv = torch.zeros_like(motion)
72
+ motion_inv[:, :8] = motion[:, :8]
73
+ motion_inv[:, 9:] = motion[:, 8:-1]
74
+
75
+ velocity = motion[:, -1:, :, :]
76
+ centers = torch.zeros_like(velocity)
77
+ displacement = torch.zeros_like(velocity[:, :, :, 0])
78
+
79
+ for t in range(T):
80
+ displacement += velocity[:, :, :, t]
81
+ centers[:, :, :, t] = displacement
82
+
83
+ motion_inv = motion_inv + centers
84
+
85
+ return motion_inv
86
+
87
+
88
+ def restore_world_space(motion, meanpose, stdpose, n_joints=15):
89
+ B, C, T = motion.size()
90
+ motion = motion.view(B, n_joints, C // n_joints, T)
91
+ motion = normalize_motion_inv_torch(motion, meanpose, stdpose)
92
+ motion = globalize_motion_torch(motion)
93
+ return motion
94
+
95
+
96
+ def convert_to_learning_space(motion, meanpose, stdpose):
97
+ B, J, D, T = motion.size()
98
+ motion = localize_motion_torch(motion)
99
+ motion = normalize_motion_torch(motion, meanpose, stdpose)
100
+ motion = motion.view(B, J*D, T)
101
+ return motion
102
+
103
+
104
+ # tensor operations for rotating and projecting 3d skeleton sequence
105
+
106
+ def get_body_basis(motion_3d):
107
+ """
108
+ Get the unit vectors for vector rectangular coordinates for given 3D motion
109
+ :param motion_3d: 3D motion from 3D joints positions, shape (B, n_joints, 3, seq_len).
110
+ :param angles: (K, 3), Rotation angles around each axis.
111
+ :return: unit vectors for vector rectangular coordinates's , shape (B, 3, 3).
112
+ """
113
+ B = motion_3d.size(0)
114
+
115
+ # 2 RightArm 5 LeftArm 9 RightUpLeg 12 LeftUpLeg
116
+ horizontal = (motion_3d[:, 2] - motion_3d[:, 5] + motion_3d[:, 9] - motion_3d[:, 12]) / 2 # [B, 3, seq_len]
117
+ horizontal = horizontal.mean(dim=-1) # [B, 3]
118
+ horizontal = horizontal / horizontal.norm(dim=-1).unsqueeze(-1) # [B, 3]
119
+
120
+ vector_z = torch.tensor([0., 0., 1.], device=motion_3d.device, dtype=motion_3d.dtype).unsqueeze(0).repeat(B, 1) # [B, 3]
121
+ vector_y = torch.cross(horizontal, vector_z) # [B, 3]
122
+ vector_y = vector_y / vector_y.norm(dim=-1).unsqueeze(-1)
123
+ vector_x = torch.cross(vector_y, vector_z)
124
+ vectors = torch.stack([vector_x, vector_y, vector_z], dim=2) # [B, 3, 3]
125
+
126
+ vectors = vectors.detach()
127
+
128
+ return vectors
129
+
130
+
131
+ def rotate_basis_euler(basis_vectors, angles):
132
+ """
133
+ Rotate vector rectangular coordinates from given angles.
134
+
135
+ :param basis_vectors: [B, 3, 3]
136
+ :param angles: [B, K, T, 3] Rotation angles around each axis.
137
+ :return: [B, K, T, 3, 3]
138
+ """
139
+ B, K, T, _ = angles.size()
140
+
141
+ cos, sin = torch.cos(angles), torch.sin(angles)
142
+ cx, cy, cz = cos[:, :, :, 0], cos[:, :, :, 1], cos[:, :, :, 2] # [B, K, T]
143
+ sx, sy, sz = sin[:, :, :, 0], sin[:, :, :, 1], sin[:, :, :, 2] # [B, K, T]
144
+
145
+ x = basis_vectors[:, 0, :] # [B, 3]
146
+ o = torch.zeros_like(x[:, 0]) # [B]
147
+
148
+ x_cpm_0 = torch.stack([o, -x[:, 2], x[:, 1]], dim=1) # [B, 3]
149
+ x_cpm_1 = torch.stack([x[:, 2], o, -x[:, 0]], dim=1) # [B, 3]
150
+ x_cpm_2 = torch.stack([-x[:, 1], x[:, 0], o], dim=1) # [B, 3]
151
+ x_cpm = torch.stack([x_cpm_0, x_cpm_1, x_cpm_2], dim=1) # [B, 3, 3]
152
+ x_cpm = x_cpm.unsqueeze(1).unsqueeze(2) # [B, 1, 1, 3, 3]
153
+
154
+ x = x.unsqueeze(-1) # [B, 3, 1]
155
+ xx = torch.matmul(x, x.transpose(-1, -2)).unsqueeze(1).unsqueeze(2) # [B, 1, 1, 3, 3]
156
+ eye = torch.eye(n=3, dtype=basis_vectors.dtype, device=basis_vectors.device)
157
+ eye = eye.unsqueeze(0).unsqueeze(0).unsqueeze(0) # [1, 1, 1, 3, 3]
158
+ mat33_x = cx.unsqueeze(-1).unsqueeze(-1) * eye \
159
+ + sx.unsqueeze(-1).unsqueeze(-1) * x_cpm \
160
+ + (1. - cx).unsqueeze(-1).unsqueeze(-1) * xx # [B, K, T, 3, 3]
161
+
162
+ o = torch.zeros_like(cz)
163
+ i = torch.ones_like(cz)
164
+ mat33_z_0 = torch.stack([cz, sz, o], dim=3) # [B, K, T, 3]
165
+ mat33_z_1 = torch.stack([-sz, cz, o], dim=3) # [B, K, T, 3]
166
+ mat33_z_2 = torch.stack([o, o, i], dim=3) # [B, K, T, 3]
167
+ mat33_z = torch.stack([mat33_z_0, mat33_z_1, mat33_z_2], dim=3) # [B, K, T, 3, 3]
168
+
169
+ basis_vectors = basis_vectors.unsqueeze(1).unsqueeze(2)
170
+ basis_vectors = basis_vectors @ mat33_x.transpose(-1, -2) @ mat33_z
171
+
172
+
173
+ return basis_vectors
174
+
175
+
176
+ def change_of_basis(motion_3d, basis_vectors=None, project_2d=False):
177
+ # motion_3d: (B, n_joints, 3, seq_len)
178
+ # basis_vectors: (B, K, T, 3, 3)
179
+
180
+ if basis_vectors is None:
181
+ motion_proj = motion_3d[:, :, [0, 2], :] # [B, n_joints, 2, seq_len]
182
+ else:
183
+ if project_2d: basis_vectors = basis_vectors[:, :, :, [0, 2], :]
184
+ _, K, seq_len, _, _ = basis_vectors.size()
185
+ motion_3d = motion_3d.unsqueeze(1).repeat(1, K, 1, 1, 1)
186
+ motion_3d = motion_3d.permute([0, 1, 4, 3, 2]) # [B, K, J, 3, T] -> [B, K, T, 3, J]
187
+ motion_proj = basis_vectors @ motion_3d # [B, K, T, 2, 3] @ [B, K, T, 3, J] -> [B, K, T, 2, J]
188
+ motion_proj = motion_proj.permute([0, 1, 4, 3, 2]) # [B, K, T, 3, J] -> [B, K, J, 3, T]
189
+
190
+ return motion_proj
191
+
192
+
193
+ def rotate_and_maybe_project_world(X, angles=None, body_reference=True, project_2d=False):
194
+
195
+ out_dim = 2 if project_2d else 3
196
+ batch_size, n_joints, _, seq_len = X.size()
197
+
198
+ if angles is not None:
199
+ K = angles.size(1)
200
+ basis_vectors = get_body_basis(X) if body_reference else \
201
+ torch.eye(3, device=X.device).unsqueeze(0).repeat(batch_size, 1, 1)
202
+ basis_vectors = rotate_basis_euler(basis_vectors, angles)
203
+ X_trans = change_of_basis(X, basis_vectors, project_2d=project_2d)
204
+ X_trans = X_trans.reshape(batch_size * K, n_joints, out_dim, seq_len)
205
+ else:
206
+ X_trans = change_of_basis(X, project_2d=project_2d)
207
+ X_trans = X_trans.reshape(batch_size, n_joints, out_dim, seq_len)
208
+
209
+ return X_trans
210
+
211
+
212
+
213
+ def rotate_and_maybe_project_learning(X, meanpose, stdpose, angles=None, body_reference=True, project_2d=False):
214
+ batch_size, channels, seq_len = X.size()
215
+ n_joints = channels // 3
216
+ X = restore_world_space(X, meanpose, stdpose, n_joints)
217
+ X = rotate_and_maybe_project_world(X, angles, body_reference, project_2d)
218
+ X = convert_to_learning_space(X, meanpose, stdpose)
219
+ return X
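A minimal usage sketch for the projection utilities in lib/operation.py above, with random tensors standing in for real data; the 15-joint layout expected by get_body_basis is assumed and the shapes are illustrative only:

    import torch
    from lib.operation import rotate_and_maybe_project_world

    B, J, T, K = 2, 15, 64, 3               # batch, joints, frames, sampled views
    X = torch.randn(B, J, 3, T)             # world-space 3D joint sequences
    angles = torch.rand(B, K, T, 3) * 3.14  # per-view Euler angles
    # rotate around the body-centric basis, then keep the x/z rows to project to 2D
    x2d = rotate_and_maybe_project_world(X, angles=angles, body_reference=True, project_2d=True)
    print(x2d.shape)                        # torch.Size([B * K, J, 2, T]) -> (6, 15, 2, 64)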
lib/trainer.py ADDED
@@ -0,0 +1,298 @@
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import random
6
+ import lib.network
7
+ from lib.loss import *
8
+ from lib.util.general import weights_init, get_model_list, get_scheduler
9
+ from lib.network import Discriminator
10
+ from lib.operation import rotate_and_maybe_project_learning
11
+
12
+ class BaseTrainer(nn.Module):
13
+
14
+ def __init__(self, config):
15
+ super(BaseTrainer, self).__init__()
16
+
17
+ lr = config.lr
18
+ autoencoder_cls = getattr(lib.network, config.autoencoder.cls)
19
+ self.autoencoder = autoencoder_cls(config.autoencoder)
20
+ self.discriminator = Discriminator(config.discriminator)
21
+
22
+ # Setup the optimizers
23
+ beta1 = config.beta1
24
+ beta2 = config.beta2
25
+ dis_params = list(self.discriminator.parameters())
26
+ ae_params = list(self.autoencoder.parameters())
27
+ self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
28
+ lr=lr, betas=(beta1, beta2), weight_decay=config.weight_decay)
29
+ self.ae_opt = torch.optim.Adam([p for p in ae_params if p.requires_grad],
30
+ lr=lr, betas=(beta1, beta2), weight_decay=config.weight_decay)
31
+ self.dis_scheduler = get_scheduler(self.dis_opt, config)
32
+ self.ae_scheduler = get_scheduler(self.ae_opt, config)
33
+
34
+ # Network weight initialization
35
+ self.apply(weights_init(config.init))
36
+ self.discriminator.apply(weights_init('gaussian'))
37
+
38
+ def forward(self, data):
39
+ x_a, x_b = data["x_a"], data["x_b"]
40
+ batch_size = x_a.size(0)
41
+ self.eval()
42
+ body_a, body_b = self.sample_body_code(batch_size)
43
+ motion_a = self.autoencoder.encode_motion(x_a)
44
+ body_a_enc, _ = self.autoencoder.encode_body(x_a)
45
+ motion_b = self.autoencoder.encode_motion(x_b)
46
+ body_b_enc, _ = self.autoencoder.encode_body(x_b)
47
+ x_ab = self.autoencoder.decode(motion_a, body_b)
48
+ x_ba = self.autoencoder.decode(motion_b, body_a)
49
+ self.train()
50
+ return x_ab, x_ba
51
+
52
+ def dis_update(self, data, config):
53
+ raise NotImplementedError
54
+
55
+ def ae_update(self, data, config):
56
+ raise NotImplementedError
57
+
58
+ def recon_criterion(self, input, target):
59
+ raise NotImplementedError
60
+
61
+ def update_learning_rate(self):
62
+ if self.dis_scheduler is not None:
63
+ self.dis_scheduler.step()
64
+ if self.ae_scheduler is not None:
65
+ self.ae_scheduler.step()
66
+
67
+ def resume(self, checkpoint_dir, config):
68
+ # Load generators
69
+ last_model_name = get_model_list(checkpoint_dir, "autoencoder")
70
+ state_dict = torch.load(last_model_name)
71
+ self.autoencoder.load_state_dict(state_dict)
72
+ iterations = int(last_model_name[-11:-3])
73
+ # Load discriminators
74
+ last_model_name = get_model_list(checkpoint_dir, "discriminator")
75
+ state_dict = torch.load(last_model_name)
76
+ self.discriminator.load_state_dict(state_dict)
77
+ # Load optimizers
78
+ state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
79
+ self.dis_opt.load_state_dict(state_dict['discriminator'])
80
+ self.ae_opt.load_state_dict(state_dict['autoencoder'])
81
+ # Reinitialize schedulers
82
+ self.dis_scheduler = get_scheduler(self.dis_opt, config, iterations)
83
+ self.ae_scheduler = get_scheduler(self.ae_opt, config, iterations)
84
+ print('Resume from iteration %d' % iterations)
85
+ return iterations
86
+
87
+ def save(self, snapshot_dir, iterations):
88
+ # Save generators, discriminators, and optimizers
89
+ ae_name = os.path.join(snapshot_dir, 'autoencoder_%08d.pt' % (iterations + 1))
90
+ dis_name = os.path.join(snapshot_dir, 'discriminator_%08d.pt' % (iterations + 1))
91
+ opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
92
+ torch.save(self.autoencoder.state_dict(), ae_name)
93
+ torch.save(self.discriminator.state_dict(), dis_name)
94
+ torch.save({'autoencoder': self.ae_opt.state_dict(), 'discriminator': self.dis_opt.state_dict()}, opt_name)
95
+
96
+ def validate(self, data, config):
97
+ re_dict = self.evaluate(self.autoencoder, data, config)
98
+ for key, val in re_dict.items():
99
+ setattr(self, key, val)
100
+
101
+ @staticmethod
102
+ def recon_criterion(input, target):
103
+ return torch.mean(torch.abs(input - target))
104
+
105
+ @classmethod
106
+ def evaluate(cls, autoencoder, data, config):
107
+ autoencoder.eval()
108
+ x_a, x_b = data["x_a"], data["x_b"]
109
+ x_aba, x_bab = data["x_aba"], data["x_bab"]
110
+ batch_size, _, seq_len = x_a.size()
111
+
112
+ re_dict = {}
113
+
114
+ with torch.no_grad(): # 2D eval
115
+
116
+ x_a_recon = autoencoder.reconstruct2d(x_a)
117
+ x_b_recon = autoencoder.reconstruct2d(x_b)
118
+ x_aba_recon = autoencoder.cross2d(x_a, x_b, x_a)
119
+ x_bab_recon = autoencoder.cross2d(x_b, x_a, x_b)
120
+
121
+ re_dict['loss_val_recon_x'] = cls.recon_criterion(x_a_recon, x_a) + cls.recon_criterion(x_b_recon, x_b)
122
+ re_dict['loss_val_cross_body'] = cls.recon_criterion(x_aba_recon, x_aba) + cls.recon_criterion(
123
+ x_bab_recon, x_bab)
124
+ re_dict['loss_val_total'] = 0.5 * re_dict['loss_val_recon_x'] + 0.5 * re_dict['loss_val_cross_body']
125
+
126
+ autoencoder.train()
127
+ return re_dict
128
+
129
+
130
+ class TransmomoTrainer(BaseTrainer):
131
+
132
+ def __init__(self, config):
133
+ super(TransmomoTrainer, self).__init__(config)
134
+
135
+ self.angle_unit = np.pi / (config.K + 1)
136
+ view_angles = np.array([i * self.angle_unit for i in range(1, config.K + 1)])
137
+ x_angles = view_angles if config.rotation_axes[0] else np.array([0])
138
+ z_angles = view_angles if config.rotation_axes[1] else np.array([0])
139
+ y_angles = view_angles if config.rotation_axes[2] else np.array([0])
140
+ x_angles, z_angles, y_angles = np.meshgrid(x_angles, z_angles, y_angles)
141
+ angles = np.stack([x_angles.flatten(), z_angles.flatten(), y_angles.flatten()], axis=1)
142
+ self.angles = torch.tensor(angles).float().cuda()
143
+ self.rotation_axes = torch.tensor(config.rotation_axes).float().cuda()
144
+ self.rotation_axes_mask = [(_ > 0) for _ in config.rotation_axes]
145
+
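For intuition, a standalone sketch of the view-angle grid built in __init__ above, assuming K = 3 and rotation_axes = [0, 1, 0] (both values are illustrative):

    import numpy as np

    K = 3
    angle_unit = np.pi / (K + 1)
    view_angles = np.array([i * angle_unit for i in range(1, K + 1)])
    print(np.round(view_angles, 3))  # [0.785 1.571 2.356] -> 45, 90, 135 degrees
    # with only one active rotation axis the meshgrid collapses to K rows of the
    # form [0., angle, 0.]: one candidate viewpoint per non-zero angle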
146
+ def dis_update(self, data, config):
147
+
148
+ x_a = data["x"]
149
+ x_s = data["x_s"] # the limb-scaled version of x_a
150
+ meanpose = data["meanpose"][0]
151
+ stdpose = data["stdpose"][0]
152
+
153
+ self.dis_opt.zero_grad()
154
+
155
+ # encode
156
+ motion_a = self.autoencoder.encode_motion(x_a)
157
+ body_a, body_a_seq = self.autoencoder.encode_body(x_a)
158
+ view_a, view_a_seq = self.autoencoder.encode_view(x_a)
159
+
160
+ motion_s = self.autoencoder.encode_motion(x_s)
161
+ body_s, body_s_seq = self.autoencoder.encode_body(x_s)
162
+ view_s, view_s_seq = self.autoencoder.encode_view(x_s)
163
+
164
+ # decode (reconstruct, transform)
165
+ inds = random.sample(list(range(self.angles.size(0))), config.K)
166
+ angles = self.angles[inds].clone().detach() # [K, 3]
167
+ angles += self.angle_unit * self.rotation_axes * torch.randn([3], device=x_a.device)
168
+ angles = angles.unsqueeze(0).unsqueeze(2) # [B=1, K, T=1, 3]
169
+
170
+ X_a_recon = self.autoencoder.decode(motion_a, body_a, view_a)
171
+ x_a_trans = rotate_and_maybe_project_learning(X_a_recon, meanpose, stdpose, angles=angles,
172
+ body_reference=config.autoencoder.body_reference, project_2d=True)
173
+
174
+ x_a_exp = x_a.repeat_interleave(config.K, dim=0)
175
+
176
+ self.loss_dis_trans = self.discriminator.calc_dis_loss(x_a_trans.detach(), x_a_exp)
177
+
178
+ if config.trans_gan_ls_w > 0:
179
+ X_s_recon = self.autoencoder.decode(motion_s, body_s, view_s)
180
+ x_s_trans = rotate_and_maybe_project_learning(X_s_recon, meanpose, stdpose, angles=angles,
181
+ body_reference=config.autoencoder.body_reference, project_2d=True)
182
+ x_s_exp = x_s.repeat_interleave(config.K, dim=0)
183
+ self.loss_dis_trans_ls = self.discriminator.calc_dis_loss(x_s_trans.detach(), x_s_exp)
184
+ else:
185
+ self.loss_dis_trans_ls = 0
186
+
187
+ self.loss_dis_total = config.trans_gan_w * self.loss_dis_trans + \
188
+ config.trans_gan_ls_w * self.loss_dis_trans_ls
189
+
190
+ self.loss_dis_total.backward()
191
+ self.dis_opt.step()
192
+
193
+ def ae_update(self, data, config):
194
+
195
+ x_a = data["x"]
196
+ x_s = data["x_s"]
197
+ meanpose = data["meanpose"][0]
198
+ stdpose = data["stdpose"][0]
199
+ self.ae_opt.zero_grad()
200
+
201
+ # encode
202
+ motion_a = self.autoencoder.encode_motion(x_a)
203
+ body_a, body_a_seq = self.autoencoder.encode_body(x_a)
204
+ view_a, view_a_seq = self.autoencoder.encode_view(x_a)
205
+
206
+ motion_s = self.autoencoder.encode_motion(x_s)
207
+ body_s, body_s_seq = self.autoencoder.encode_body(x_s)
208
+ view_s, view_s_seq = self.autoencoder.encode_view(x_s)
209
+
210
+ # invariance loss
211
+ self.loss_inv_v_ls = self.recon_criterion(view_a, view_s) if config.inv_v_ls_w > 0 else 0
212
+ self.loss_inv_m_ls = self.recon_criterion(motion_a, motion_s) if config.inv_m_ls_w > 0 else 0
213
+
214
+ # body triplet loss
215
+ if config.triplet_b_w > 0:
216
+ self.loss_triplet_b = triplet_margin_loss(
217
+ body_a_seq, body_s_seq,
218
+ neg_range=config.triplet_neg_range,
219
+ margin=config.triplet_margin)
220
+ else:
221
+ self.loss_triplet_b = 0
222
+
223
+ # reconstruction
224
+ X_a_recon = self.autoencoder.decode(motion_a, body_a, view_a)
225
+ x_a_recon = rotate_and_maybe_project_learning(X_a_recon, meanpose, stdpose, angles=None,
226
+ body_reference=config.autoencoder.body_reference, project_2d=True)
227
+
228
+ X_s_recon = self.autoencoder.decode(motion_s, body_s, view_s)
229
+ x_s_recon = rotate_and_maybe_project_learning(X_s_recon, meanpose, stdpose, angles=None,
230
+ body_reference=config.autoencoder.body_reference, project_2d=True)
231
+
232
+ self.loss_recon_x = 0.5 * self.recon_criterion(x_a_recon, x_a) +\
233
+ 0.5 * self.recon_criterion(x_s_recon, x_s)
234
+
235
+ # cross reconstruction
236
+ X_as_recon = self.autoencoder.decode(motion_a, body_s, view_s)
237
+ x_as_recon = rotate_and_maybe_project_learning(X_as_recon, meanpose, stdpose, angles=None,
238
+ body_reference=config.autoencoder.body_reference, project_2d=True)
239
+
240
+ X_sa_recon = self.autoencoder.decode(motion_s, body_a, view_a)
241
+ x_sa_recon = rotate_and_maybe_project_learning(X_sa_recon, meanpose, stdpose, angles=None,
242
+ body_reference=config.autoencoder.body_reference, project_2d=True)
243
+
244
+ self.loss_cross_x = 0.5 * self.recon_criterion(x_as_recon, x_s) + 0.5 * self.recon_criterion(x_sa_recon, x_a)
245
+
246
+ # apply transformation
247
+ inds = random.sample(list(range(self.angles.size(0))), config.K)
248
+ angles = self.angles[inds].clone().detach()
249
+ angles += self.angle_unit * self.rotation_axes * torch.randn([3], device=x_a.device)
250
+ angles = angles.unsqueeze(0).unsqueeze(2)
251
+
252
+ x_a_trans = rotate_and_maybe_project_learning(X_a_recon, meanpose, stdpose, angles=angles,
253
+ body_reference=config.autoencoder.body_reference, project_2d=True)
254
+ x_s_trans = rotate_and_maybe_project_learning(X_s_recon, meanpose, stdpose, angles=angles,
255
+ body_reference=config.autoencoder.body_reference, project_2d=True)
256
+
257
+ # GAN loss
258
+ self.loss_gan_trans = self.discriminator.calc_gen_loss(x_a_trans)
259
+ self.loss_gan_trans_ls = self.discriminator.calc_gen_loss(x_s_trans) if config.trans_gan_ls_w > 0 else 0
260
+
261
+ # encode again
262
+ motion_a_trans = self.autoencoder.encode_motion(x_a_trans)
263
+ body_a_trans, _ = self.autoencoder.encode_body(x_a_trans)
264
+ view_a_trans, view_a_trans_seq = self.autoencoder.encode_view(x_a_trans)
265
+
266
+ motion_s_trans = self.autoencoder.encode_motion(x_s_trans)
267
+ body_s_trans, _ = self.autoencoder.encode_body(x_s_trans)
268
+
269
+ self.loss_inv_m_trans = 0.5 * self.recon_criterion(motion_a_trans, motion_a.repeat_interleave(config.K, dim=0)) + \
270
+ 0.5 * self.recon_criterion(motion_s_trans, motion_s.repeat_interleave(config.K, dim=0))
271
+ self.loss_inv_b_trans = 0.5 * self.recon_criterion(body_a_trans, body_a.repeat_interleave(config.K, dim=0)) + \
272
+ 0.5 * self.recon_criterion(body_s_trans, body_s.repeat_interleave(config.K, dim=0))
273
+
274
+ # view triplet loss
275
+ if config.triplet_v_w > 0:
276
+ view_a_seq_exp = view_a_seq.repeat_interleave(config.K, dim=0)
277
+ self.loss_triplet_v = triplet_margin_loss(
278
+ view_a_seq_exp, view_a_trans_seq,
279
+ neg_range=config.triplet_neg_range, margin=config.triplet_margin)
280
+ else:
281
+ self.loss_triplet_v = 0
282
+
283
+ # add all losses
284
+ self.loss_total = torch.tensor(0.).float().cuda()
285
+ self.loss_total += config.recon_x_w * self.loss_recon_x
286
+ self.loss_total += config.cross_x_w * self.loss_cross_x
287
+ self.loss_total += config.inv_v_ls_w * self.loss_inv_v_ls
288
+ self.loss_total += config.inv_m_ls_w * self.loss_inv_m_ls
289
+ self.loss_total += config.inv_b_trans_w * self.loss_inv_b_trans
290
+ self.loss_total += config.inv_m_trans_w * self.loss_inv_m_trans
291
+ self.loss_total += config.trans_gan_w * self.loss_gan_trans
292
+ self.loss_total += config.trans_gan_ls_w * self.loss_gan_trans_ls
293
+ self.loss_total += config.triplet_b_w * self.loss_triplet_b
294
+ self.loss_total += config.triplet_v_w * self.loss_triplet_v
295
+
296
+ self.loss_total.backward()
297
+ self.ae_opt.step()
298
+
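A hedged sketch of how TransmomoTrainer is typically driven; the config fields snapshot_save_iter and snapshot_dir and the data loader are assumptions, while dis_update, ae_update, update_learning_rate and save come from the class above:

    import torch
    from lib.trainer import TransmomoTrainer
    from lib.util.general import get_config, to_gpu

    config = get_config("configs/transmomo.yaml")   # hypothetical config path
    trainer = TransmomoTrainer(config).cuda()

    # train_loader is assumed: a DataLoader yielding dicts with keys
    # "x", "x_s", "meanpose", "stdpose" (see dis_update / ae_update above)
    for it, data in enumerate(train_loader):
        data = to_gpu(data)
        trainer.dis_update(data, config)            # discriminator step
        trainer.ae_update(data, config)             # autoencoder step
        trainer.update_learning_rate()
        if (it + 1) % config.snapshot_save_iter == 0:
            trainer.save(config.snapshot_dir, it)   # writes autoencoder_*.pt / discriminator_*.pt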
lib/util/__init__.py ADDED
File without changes
lib/util/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (146 Bytes).
 
lib/util/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (150 Bytes).
 
lib/util/__pycache__/general.cpython-37.pyc ADDED
Binary file (13.3 kB).
 
lib/util/__pycache__/general.cpython-38.pyc ADDED
Binary file (13.4 kB).
 
lib/util/__pycache__/motion.cpython-37.pyc ADDED
Binary file (8.05 kB).
 
lib/util/__pycache__/motion.cpython-38.pyc ADDED
Binary file (8.07 kB).
 
lib/util/__pycache__/visualization.cpython-37.pyc ADDED
Binary file (12.6 kB).
 
lib/util/__pycache__/visualization.cpython-38.pyc ADDED
Binary file (12.7 kB).
 
lib/util/general.py ADDED
@@ -0,0 +1,361 @@
1
+ from PIL import Image
2
+ import os
3
+ import json
4
+ import logging
5
+ import shutil
6
+ import csv
7
+ # from lib.network.munit import Vgg16
8
+ from torch.autograd import Variable
9
+ from torch.optim import lr_scheduler
10
+ from easydict import EasyDict as edict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import os
15
+ import math
16
+ import torchvision.utils as vutils
17
+ import yaml
18
+ import numpy as np
19
+ import torch.nn.init as init
20
+ import time
21
+
22
+
23
+ def get_config(config_path):
24
+ with open(config_path, 'r') as stream:
25
+ config = yaml.load(stream, Loader=yaml.SafeLoader)
26
+ config = edict(config)
27
+ _, config_filename = os.path.split(config_path)
28
+ config_name, _ = os.path.splitext(config_filename)
29
+ config.name = config_name
30
+ return config
31
+
32
+ class TextLogger:
33
+
34
+ def __init__(self, log_path):
35
+ self.log_path = log_path
36
+ with open(self.log_path, "w") as f:
37
+ f.write("")
38
+ def log(self, log):
39
+ with open(self.log_path, "a+") as f:
40
+ f.write(log + "\n")
41
+
42
+ def eformat(f, prec):
43
+ s = "%.*e"%(prec, f)
44
+ mantissa, exp = s.split('e')
45
+ # re-format the exponent without leading zeros or an explicit plus sign
46
+ return "%se%d"%(mantissa, int(exp))
47
+
48
+
49
+ def __write_images(image_outputs, display_image_num, file_name):
50
+ image_outputs = [images.expand(-1, 3, -1, -1) for images in image_outputs] # expand gray-scale images to 3 channels
51
+ image_tensor = torch.cat([images[:display_image_num] for images in image_outputs], 0)
52
+ image_grid = vutils.make_grid(image_tensor.data, nrow=display_image_num, padding=0, normalize=True)
53
+ vutils.save_image(image_grid, file_name, nrow=1)
54
+
55
+
56
+ def write_2images(image_outputs, display_image_num, image_directory, postfix):
57
+ n = len(image_outputs)
58
+ __write_images(image_outputs[0:n//2], display_image_num, '%s/gen_a2b_%s.jpg' % (image_directory, postfix))
59
+ __write_images(image_outputs[n//2:n], display_image_num, '%s/gen_b2a_%s.jpg' % (image_directory, postfix))
60
+
61
+
62
+ def write_one_row_html(html_file, iterations, img_filename, all_size):
63
+ html_file.write("<h3>iteration [%d] (%s)</h3>" % (iterations,img_filename.split('/')[-1]))
64
+ html_file.write("""
65
+ <p><a href="%s">
66
+ <img src="%s" style="width:%dpx">
67
+ </a><br>
68
+ <p>
69
+ """ % (img_filename, img_filename, all_size))
70
+ return
71
+
72
+
73
+ def write_html(filename, iterations, image_save_iterations, image_directory, all_size=1536):
74
+ html_file = open(filename, "w")
75
+ html_file.write('''
76
+ <!DOCTYPE html>
77
+ <html>
78
+ <head>
79
+ <title>Experiment name = %s</title>
80
+ <meta http-equiv="refresh" content="30">
81
+ </head>
82
+ <body>
83
+ ''' % os.path.basename(filename))
84
+ html_file.write("<h3>current</h3>")
85
+ write_one_row_html(html_file, iterations, '%s/gen_a2b_train_current.jpg' % (image_directory), all_size)
86
+ write_one_row_html(html_file, iterations, '%s/gen_b2a_train_current.jpg' % (image_directory), all_size)
87
+ for j in range(iterations, image_save_iterations-1, -1):
88
+ if j % image_save_iterations == 0:
89
+ write_one_row_html(html_file, j, '%s/gen_a2b_test_%08d.jpg' % (image_directory, j), all_size)
90
+ write_one_row_html(html_file, j, '%s/gen_b2a_test_%08d.jpg' % (image_directory, j), all_size)
91
+ write_one_row_html(html_file, j, '%s/gen_a2b_train_%08d.jpg' % (image_directory, j), all_size)
92
+ write_one_row_html(html_file, j, '%s/gen_b2a_train_%08d.jpg' % (image_directory, j), all_size)
93
+ html_file.write("</body></html>")
94
+ html_file.close()
95
+
96
+
97
+ def write_loss(iterations, trainer, train_writer):
98
+ members = [attr for attr in dir(trainer) \
99
+ if not callable(getattr(trainer, attr)) and not attr.startswith("__") and ('loss' in attr or 'grad' in attr or 'nwd' in attr)]
100
+ for m in members:
101
+ train_writer.add_scalar(m, getattr(trainer, m), iterations + 1)
102
+
103
+
104
+ def slerp(val, low, high):
105
+ """
106
+ original: Animating Rotation with Quaternion Curves, Ken Shoemake
107
+ https://arxiv.org/abs/1609.04468
108
+ Code: https://github.com/soumith/dcgan.torch/issues/14, Tom White
109
+ """
110
+ omega = np.arccos(np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high)))
111
+ so = np.sin(omega)
112
+ return np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high
113
+
114
+
115
+ def get_slerp_interp(nb_latents, nb_interp, z_dim):
116
+ """
117
+ modified from: PyTorch inference for "Progressive Growing of GANs" with CelebA snapshot
118
+ https://github.com/ptrblck/prog_gans_pytorch_inference
119
+ """
120
+
121
+ latent_interps = np.empty(shape=(0, z_dim), dtype=np.float32)
122
+ for _ in range(nb_latents):
123
+ low = np.random.randn(z_dim)
124
+ high = np.random.randn(z_dim) # low + np.random.randn(512) * 0.7
125
+ interp_vals = np.linspace(0, 1, num=nb_interp)
126
+ latent_interp = np.array([slerp(v, low, high) for v in interp_vals],
127
+ dtype=np.float32)
128
+ latent_interps = np.vstack((latent_interps, latent_interp))
129
+
130
+ return latent_interps[:, :, np.newaxis, np.newaxis]
131
+
132
+
133
+ # Get model list for resume
134
+ def get_model_list(dirname, key):
135
+ if os.path.exists(dirname) is False:
136
+ return None
137
+ gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if
138
+ os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f]
139
+ if len(gen_models) == 0:
140
+ return None
141
+ gen_models.sort()
142
+ last_model_name = gen_models[-1]
143
+ return last_model_name
144
+
145
+
146
+ def get_scheduler(optimizer, hyperparameters, iterations=-1):
147
+ if 'lr_policy' not in hyperparameters or hyperparameters['lr_policy'] == 'constant':
148
+ scheduler = None # constant scheduler
149
+ elif hyperparameters['lr_policy'] == 'step':
150
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=hyperparameters['step_size'],
151
+ gamma=hyperparameters['gamma'], last_epoch=iterations)
152
+ else:
153
+ raise NotImplementedError('learning rate policy [%s] is not implemented' % hyperparameters['lr_policy'])
154
+ return scheduler
155
+
156
+
157
+ def weights_init(init_type='gaussian'):
158
+ def init_fun(m):
159
+ classname = m.__class__.__name__
160
+ if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
161
+ # print m.__class__.__name__
162
+ if init_type == 'gaussian':
163
+ init.normal_(m.weight.data, 0.0, 0.02)
164
+ elif init_type == 'xavier':
165
+ init.xavier_normal_(m.weight.data, gain=math.sqrt(2))
166
+ elif init_type == 'kaiming':
167
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
168
+ elif init_type == 'orthogonal':
169
+ init.orthogonal_(m.weight.data, gain=math.sqrt(2))
170
+ elif init_type == 'default':
171
+ pass
172
+ else:
173
+ assert 0, "Unsupported initialization: {}".format(init_type)
174
+ if hasattr(m, 'bias') and m.bias is not None:
175
+ init.constant_(m.bias.data, 0.0)
176
+
177
+ return init_fun
178
+
179
+
180
+ class Timer:
181
+ def __init__(self, msg):
182
+ self.msg = msg
183
+ self.start_time = None
184
+
185
+ def __enter__(self):
186
+ self.start_time = time.time()
187
+
188
+ def __exit__(self, exc_type, exc_value, exc_tb):
189
+ print(self.msg % (time.time() - self.start_time))
190
+
191
+
192
+ class TrainClock(object):
193
+ def __init__(self):
194
+ self.epoch = 1
195
+ self.minibatch = 0
196
+ self.step = 0
197
+
198
+ def tick(self):
199
+ self.minibatch += 1
200
+ self.step += 1
201
+
202
+ def tock(self):
203
+ self.epoch += 1
204
+ self.minibatch = 0
205
+
206
+ def make_checkpoint(self):
207
+ return {
208
+ 'epoch': self.epoch,
209
+ 'minibatch': self.minibatch,
210
+ 'step': self.step
211
+ }
212
+
213
+ def restore_checkpoint(self, clock_dict):
214
+ self.epoch = clock_dict['epoch']
215
+ self.minibatch = clock_dict['minibatch']
216
+ self.step = clock_dict['step']
217
+
218
+
219
+ class Table(object):
220
+ def __init__(self, filename):
221
+ '''
222
+ create a table for recording experiment results; the resulting .csv file can be opened in Excel
223
+ :param filename: output file name, must end with '.csv'
224
+ '''
225
+ assert '.csv' in filename
226
+ self.filename = filename
227
+
228
+ @staticmethod
229
+ def merge_headers(header1, header2):
230
+ #return list(set(header1 + header2))
231
+ if len(header1) > len(header2):
232
+ return header1
233
+ else:
234
+ return header2
235
+
236
+ def write(self, ordered_dict):
237
+ '''
238
+ write an entry
239
+ :param ordered_dict: something like {'name':'exp1', 'acc':90.5, 'epoch':50}
240
+ :return:
241
+ '''
242
+ if not os.path.exists(self.filename):
243
+ headers = list(ordered_dict.keys())
244
+ prev_rec = None
245
+ else:
246
+ with open(self.filename) as f:
247
+ reader = csv.DictReader(f)
248
+ headers = reader.fieldnames
249
+ prev_rec = [row for row in reader]
250
+ headers = self.merge_headers(headers, list(ordered_dict.keys()))
251
+
252
+ with open(self.filename, 'w', newline='') as f:
253
+ writer = csv.DictWriter(f, headers)
254
+ writer.writeheader()
255
+ if prev_rec is not None:
256
+ writer.writerows(prev_rec)
257
+ writer.writerow(ordered_dict)
258
+
259
+
260
+ class WorklogLogger:
261
+ def __init__(self, log_file):
262
+ logging.basicConfig(filename=log_file,
263
+ level=logging.DEBUG,
264
+ format='%(asctime)s - %(threadName)s - %(levelname)s - %(message)s')
265
+
266
+ self.logger = logging.getLogger()
267
+
268
+ def put_line(self, line):
269
+ self.logger.info(line)
270
+
271
+
272
+ class AverageMeter(object):
273
+ """Computes and stores the average and current value"""
274
+
275
+ def __init__(self, name):
276
+ self.name = name
277
+ self.reset()
278
+
279
+ def reset(self):
280
+ self.val = 0
281
+ self.avg = 0
282
+ self.sum = 0
283
+ self.count = 0
284
+
285
+ def update(self, val, n=1):
286
+ self.val = val
287
+ self.sum += val * n
288
+ self.count += n
289
+ self.avg = self.sum / self.count
290
+
291
+
292
+ def save_args(args, save_dir):
293
+ param_path = os.path.join(save_dir, 'params.json')
294
+
295
+ with open(param_path, 'w') as fp:
296
+ json.dump(args.__dict__, fp, indent=4, sort_keys=True)
297
+
298
+
299
+ def ensure_dir(path):
300
+ """
301
+ create the directory if it does not already exist
302
+ :param path: directory path to create
303
+ :return:
304
+ """
305
+ if not os.path.exists(path):
306
+ os.makedirs(path)
307
+
308
+
309
+ def ensure_dirs(paths):
310
+ """
311
+ create paths by first checking their existence
312
+ :param paths: list of path
313
+ :return:
314
+ """
315
+ if isinstance(paths, list) and not isinstance(paths, str):
316
+ for path in paths:
317
+ ensure_dir(path)
318
+ else:
319
+ ensure_dir(paths)
320
+
321
+
322
+ def remkdir(path):
323
+ """
324
+ if dir exists, remove it and create a new one
325
+ :param path:
326
+ :return:
327
+ """
328
+ if os.path.exists(path):
329
+ shutil.rmtree(path)
330
+ os.makedirs(path)
331
+
332
+
333
+ def cycle(iterable):
334
+ while True:
335
+ for x in iterable:
336
+ yield x
337
+
338
+
339
+ def save_image(image_numpy, image_path):
340
+ image_pil = Image.fromarray(image_numpy)
341
+ image_pil.save(image_path)
342
+
343
+
344
+ def pad_to_16x(x):
345
+ if x % 16 > 0:
346
+ return x - x % 16 + 16
347
+ return x
348
+
349
+
350
+ def pad_to_height(tar_height, img_height, img_width):
351
+ scale = tar_height / img_height
352
+ h = pad_to_16x(tar_height)
353
+ w = pad_to_16x(int(img_width * scale))
354
+ return h, w, scale
355
+
356
+
357
+ def to_gpu(data):
358
+ for key, item in data.items():
359
+ if torch.is_tensor(item):
360
+ data[key] = item.cuda()
361
+ return data
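A small usage sketch for a few of the helpers above; the file paths are illustrative:

    from collections import OrderedDict
    from lib.util.general import ensure_dirs, AverageMeter, Table

    ensure_dirs(["out/checkpoints", "out/logs"])    # create output folders if missing

    meter = AverageMeter("recon_loss")
    meter.update(0.42, n=8)                         # a batch of 8 samples with mean loss 0.42
    print(meter.avg)

    table = Table("out/logs/results.csv")           # filename must end with '.csv'
    table.write(OrderedDict(name="exp1", acc=90.5, epoch=50))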
lib/util/global_norm.py ADDED
@@ -0,0 +1,29 @@
1
+ import numpy as np
2
+
3
+ def get_box(pose):
4
+ #input: pose([15,2])
5
+ return(np.min(pose[:,0]), np.max(pose[:,0]), np.min(pose[:,1]), np.max(pose[:,1]))
6
+
7
+ def get_height(pose):
8
+ #input: pose([15,2])
9
+ mean_ankle = (pose[14]+pose[11])/2
10
+ nose = pose[0]
11
+ return np.linalg.norm(mean_ankle-nose)
12
+
13
+ def get_base_mean(pose):
14
+ #input: pose([15,2])
15
+ x1, x2, y1, y2 = get_box(pose)
16
+ return np.array([(x1+x2)/2, y2])
17
+
18
+ def global_norm(driving_npy, target_npy):
19
+ #input: pose([15,2,frame1]), pose([15,2,frame2])
20
+ target_mean = np.mean(target_npy, axis=2)
21
+ driving_mean = np.mean(driving_npy, axis=2)
22
+ k2 = get_height(target_mean)/get_height(driving_mean)
23
+ target_mean_base = get_base_mean(target_mean)
24
+ driving_mean_base = get_base_mean(driving_mean)
25
+ driving_npy_permuted = np.transpose(driving_npy, axes=[2, 0, 1])
26
+ k = [1, k2]
27
+ normalized_permuted = (driving_npy_permuted-driving_mean_base)*k+target_mean_base
28
+ normalized = np.transpose(normalized_permuted, axes=[1,2,0])
29
+ return normalized # pose([15,2,frame1])
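A usage sketch of global_norm with random 15-joint 2D sequences, shapes following the comments above:

    import numpy as np
    from lib.util.global_norm import global_norm

    driving = np.random.rand(15, 2, 120)    # [joints, xy, frames] driving sequence
    target = np.random.rand(15, 2, 80)      # [joints, xy, frames] target sequence
    aligned = global_norm(driving, target)  # rescale and translate driving into target's frame
    print(aligned.shape)                    # (15, 2, 120)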
lib/util/motion.py ADDED
@@ -0,0 +1,309 @@
1
+ from scipy.ndimage import gaussian_filter1d
2
+ import numpy as np
3
+ import json
4
+ import os
5
+ import torch
6
+
7
+
8
+ def preprocess_test(motion, meanpose, stdpose, unit=128):
9
+
10
+ motion = motion * unit
11
+
12
+ motion[1, :, :] = (motion[2, :, :] + motion[5, :, :]) / 2
13
+ motion[8, :, :] = (motion[9, :, :] + motion[12, :, :]) / 2
14
+
15
+ start = motion[8, :, 0]
16
+
17
+ motion = localize_motion(motion)
18
+ motion = normalize_motion(motion, meanpose, stdpose)
19
+
20
+ return motion, start
21
+
22
+
23
+ def postprocess(motion, meanpose, stdpose, unit=128, start=None):
24
+
25
+ motion = motion.detach().cpu().numpy()[0].reshape(-1, 2, motion.shape[-1])
26
+ motion = normalize_motion_inv(motion, meanpose, stdpose)
27
+ motion = globalize_motion(motion, start=start)
28
+ # motion = motion / unit
29
+
30
+ return motion
31
+
32
+
33
+ def preprocess_mixamo(motion, unit=128):
34
+
35
+ _, D, _ = motion.shape
36
+ horizontal_dim = 0
37
+ vertical_dim = D - 1
38
+
39
+ motion[1, :, :] = (motion[2, :, :] + motion[5, :, :]) / 2
40
+ motion[8, :, :] = (motion[9, :, :] + motion[12, :, :]) / 2
41
+
42
+ # rotate 180
43
+ motion[:, horizontal_dim, :] = - motion[:, horizontal_dim, :]
44
+ motion[:, vertical_dim, :] = - motion[:, vertical_dim, :]
45
+
46
+ motion = motion * unit
47
+
48
+ return motion
49
+
50
+
51
+ def rotate_motion_3d(motion3d, change_of_basis):
52
+
53
+ if change_of_basis is not None: motion3d = change_of_basis @ motion3d
54
+
55
+ return motion3d
56
+
57
+
58
+ def limb_scale_motion_2d(motion2d, global_range, local_range):
59
+
60
+ global_scale = global_range[0] + np.random.random() * (global_range[1] - global_range[0])
61
+ local_scales = local_range[0] + np.random.random([8]) * (local_range[1] - local_range[0])
62
+ motion_scale = scale_limbs(motion2d, global_scale, local_scales)
63
+
64
+ return motion_scale
65
+
66
+
67
+ def localize_motion(motion):
68
+ """
69
+ Motion fed into our network is the local motion, i.e. coordinates relative to the hip joint.
70
+ This function removes global motion of the hip joint, and instead represents global motion with velocity
71
+ """
72
+
73
+ D = motion.shape[1]
74
+
75
+ # subtract centers to local coordinates
76
+ centers = motion[8, :, :] # N_dim x T
77
+ motion = motion - centers
78
+
79
+ # adding velocity
80
+ translation = centers[:, 1:] - centers[:, :-1]
81
+ velocity = np.c_[np.zeros((D, 1)), translation]
82
+ velocity = velocity.reshape(1, D, -1)
83
+ motion = np.r_[motion[:8], motion[9:], velocity]
84
+ # motion_proj = np.r_[motion_proj[:8], motion_proj[9:]]
85
+
86
+ return motion
87
+
88
+
89
+ def globalize_motion(motion, start=None, velocity=None):
90
+ """
91
+ inverse process of localize_motion
92
+ """
93
+
94
+ if velocity is None: velocity = motion[-1].copy()
95
+ motion_inv = np.r_[motion[:8], np.zeros((1, 2, motion.shape[-1])), motion[8:-1]]
96
+
97
+ # restore centre position
98
+ centers = np.zeros_like(velocity)
99
+ total = 0  # cumulative displacement of the hip joint
100
+ for i in range(motion.shape[-1]):
101
+ total += velocity[:, i]
102
+ centers[:, i] = total
103
+ centers += start.reshape([2, 1])
104
+
105
+ return motion_inv + centers.reshape((1, 2, -1))
106
+
107
+
108
+ def normalize_motion(motion, meanpose, stdpose):
109
+ """
110
+ :param motion: (J, 2, T)
111
+ :param meanpose: (J, 2)
112
+ :param stdpose: (J, 2)
113
+ :return:
114
+ """
115
+ if motion.shape[1] == 2 and meanpose.shape[1] == 3:
116
+ meanpose = meanpose[:, [0, 2]]
117
+ if motion.shape[1] == 2 and stdpose.shape[1] == 3:
118
+ stdpose = stdpose[:, [0, 2]]
119
+ return (motion - meanpose[:, :, np.newaxis]) / stdpose[:, :, np.newaxis]
120
+
121
+
122
+ def normalize_motion_inv(motion, meanpose, stdpose):
123
+ if motion.shape[1] == 2 and meanpose.shape[1] == 3:
124
+ meanpose = meanpose[:, [0, 2]]
125
+ if motion.shape[1] == 2 and stdpose.shape[1] == 3:
126
+ stdpose = stdpose[:, [0, 2]]
127
+ return motion * stdpose[:, :, np.newaxis] + meanpose[:, :, np.newaxis]
128
+
129
+
130
+ def get_change_of_basis(motion3d, angles=None):
131
+ """
132
+ Get the unit vectors for local rectangular coordinates for given 3D motion
133
+ :param motion3d: numpy array. 3D motion from 3D joints positions, shape (nr_joints, 3, nr_frames).
134
+ :param angles: tuple of length 3. Rotation angles around each axis.
135
+ :return: numpy array. unit vectors for local rectangular coordinates's , shape (3, 3).
136
+ """
137
+ # 2 RightArm 5 LeftArm 9 RightUpLeg 12 LeftUpLeg
138
+ horizontal = (motion3d[2] - motion3d[5] + motion3d[9] - motion3d[12]) / 2
139
+ horizontal = np.mean(horizontal, axis=1)
140
+ horizontal = horizontal / np.linalg.norm(horizontal)
141
+ local_z = np.array([0, 0, 1])
142
+ local_y = np.cross(horizontal, local_z)  # note: horizontal and local_z are not necessarily perpendicular, so this basis is only approximate
143
+ local_y = local_y / np.linalg.norm(local_y)
144
+ local_x = np.cross(local_y, local_z)
145
+ local = np.stack([local_x, local_y, local_z], axis=0)
146
+
147
+ if angles is not None:
148
+ local = rotate_basis(local, angles)
149
+
150
+ return local
151
+
152
+
153
+ def rotate_basis(local3d, angles):
154
+ """
155
+ Rotate local rectangular coordinates from given view_angles.
156
+
157
+ :param local3d: numpy array. Unit vectors of the local rectangular coordinate system, shape (3, 3).
158
+ :param angles: tuple of length 3. Rotation angles around each axis.
159
+ :return:
160
+ """
161
+ cx, cy, cz = np.cos(angles)
162
+ sx, sy, sz = np.sin(angles)
163
+
164
+ x = local3d[0]
165
+ x_cpm = np.array([
166
+ [0, -x[2], x[1]],
167
+ [x[2], 0, -x[0]],
168
+ [-x[1], x[0], 0]
169
+ ], dtype='float')
170
+ x = x.reshape(-1, 1)
171
+ mat33_x = cx * np.eye(3) + sx * x_cpm + (1.0 - cx) * np.matmul(x, x.T)
172
+
173
+ mat33_z = np.array([
174
+ [cz, sz, 0],
175
+ [-sz, cz, 0],
176
+ [0, 0, 1]
177
+ ], dtype='float')
178
+
179
+ local3d = local3d @ mat33_x.T @ mat33_z
180
+ return local3d
181
+
182
+
183
+ def get_foot_vel(batch_motion, foot_idx):
184
+ return batch_motion[:, foot_idx, 1:] - batch_motion[:, foot_idx, :-1] + batch_motion[:, -2:, 1:].repeat(1, 2, 1)
185
+
186
+
187
+ def get_limbs(motion):
188
+ J, D, T = motion.shape
189
+ limbs = np.zeros([14, D, T])
190
+ limbs[0] = motion[0] - motion[1] # neck
191
+ limbs[1] = motion[2] - motion[1] # r_shoulder
192
+ limbs[2] = motion[3] - motion[2] # r_arm
193
+ limbs[3] = motion[4] - motion[3] # r_forearm
194
+ limbs[4] = motion[5] - motion[1] # l_shoulder
195
+ limbs[5] = motion[6] - motion[5] # l_arm
196
+ limbs[6] = motion[7] - motion[6] # l_forearm
197
+ limbs[7] = motion[1] - motion[8] # spine
198
+ limbs[8] = motion[9] - motion[8] # r_pelvis
199
+ limbs[9] = motion[10] - motion[9] # r_thigh
200
+ limbs[10] = motion[11] - motion[10] # r_shin
201
+ limbs[11] = motion[12] - motion[8] # l_pelvis
202
+ limbs[12] = motion[13] - motion[12] # l_thigh
203
+ limbs[13] = motion[14] - motion[13] # l_shin
204
+ return limbs
205
+
206
+
207
+ def scale_limbs(motion, global_scale, local_scales):
208
+ """
209
+ :param motion: joint sequence [J, 2, T]
210
+ :param local_scales: 8 numbers of scales
211
+ :return: scaled joint sequence
212
+ """
213
+
214
+ limb_dependents = [
215
+ [0],
216
+ [2, 3, 4],
217
+ [3, 4],
218
+ [4],
219
+ [5, 6, 7],
220
+ [6, 7],
221
+ [7],
222
+ [0, 1, 2, 3, 4, 5, 6, 7],
223
+ [9, 10, 11],
224
+ [10, 11],
225
+ [11],
226
+ [12, 13, 14],
227
+ [13, 14],
228
+ [14]
229
+ ]
230
+
231
+ limbs = get_limbs(motion)
232
+ scaled_limbs = limbs.copy() * global_scale
233
+ scaled_limbs[0] *= local_scales[0]
234
+ scaled_limbs[1] *= local_scales[1]
235
+ scaled_limbs[2] *= local_scales[2]
236
+ scaled_limbs[3] *= local_scales[3]
237
+ scaled_limbs[4] *= local_scales[1]
238
+ scaled_limbs[5] *= local_scales[2]
239
+ scaled_limbs[6] *= local_scales[3]
240
+ scaled_limbs[7] *= local_scales[4]
241
+ scaled_limbs[8] *= local_scales[5]
242
+ scaled_limbs[9] *= local_scales[6]
243
+ scaled_limbs[10] *= local_scales[7]
244
+ scaled_limbs[11] *= local_scales[5]
245
+ scaled_limbs[12] *= local_scales[6]
246
+ scaled_limbs[13] *= local_scales[7]
247
+
248
+ delta = scaled_limbs - limbs
249
+
250
+ scaled_motion = motion.copy()
251
+ scaled_motion[limb_dependents[7]] += delta[7] # spine
252
+ scaled_motion[limb_dependents[1]] += delta[1] # r_shoulder
253
+ scaled_motion[limb_dependents[4]] += delta[4] # l_shoulder
254
+ scaled_motion[limb_dependents[2]] += delta[2] # r_arm
255
+ scaled_motion[limb_dependents[5]] += delta[5] # l_arm
256
+ scaled_motion[limb_dependents[3]] += delta[3] # r_forearm
257
+ scaled_motion[limb_dependents[6]] += delta[6] # l_forearm
258
+ scaled_motion[limb_dependents[0]] += delta[0] # neck
259
+ scaled_motion[limb_dependents[8]] += delta[8] # r_pelvis
260
+ scaled_motion[limb_dependents[11]] += delta[11] # l_pelvis
261
+ scaled_motion[limb_dependents[9]] += delta[9] # r_thigh
262
+ scaled_motion[limb_dependents[12]] += delta[12] # l_thigh
263
+ scaled_motion[limb_dependents[10]] += delta[10] # r_shin
264
+ scaled_motion[limb_dependents[13]] += delta[13] # l_shin
265
+
266
+
267
+ return scaled_motion
268
+
269
+
270
+ def get_limb_lengths(x):
271
+ _, dims, _ = x.shape
272
+ if dims == 2:
273
+ limbs = np.max(np.linalg.norm(get_limbs(x), axis=1), axis=-1)
274
+ limb_lengths = np.array([
275
+ limbs[0], # neck
276
+ max(limbs[1], limbs[4]), # shoulders
277
+ max(limbs[2], limbs[5]), # arms
278
+ max(limbs[3], limbs[6]), # forearms
279
+ limbs[7], # spine
280
+ max(limbs[8], limbs[11]), # pelvis
281
+ max(limbs[9], limbs[12]), # thighs
282
+ max(limbs[10], limbs[13]) # shins
283
+ ])
284
+ else:
285
+ limbs = np.mean(np.linalg.norm(get_limbs(x), axis=1), axis=-1)
286
+ limb_lengths = np.array([
287
+ limbs[0], # neck
288
+ (limbs[1] + limbs[4]) / 2., # shoulders
289
+ (limbs[2] + limbs[5]) / 2., # arms
290
+ (limbs[3] + limbs[6]) / 2., # forearms
291
+ limbs[7], # spine
292
+ (limbs[8] + limbs[11]) / 2., # pelvis
293
+ (limbs[9] + limbs[12]) / 2., # thighs
294
+ (limbs[10] + limbs[13]) / 2. # shins
295
+ ])
296
+ return limb_lengths
297
+
298
+
299
+ def limb_norm(x_a, x_b):
300
+
301
+ limb_lengths_a = get_limb_lengths(x_a)
302
+ limb_lengths_b = get_limb_lengths(x_b)
303
+
304
+ limb_lengths_a[limb_lengths_a < 1e-3] = 1e-3
305
+ local_scales = limb_lengths_b / limb_lengths_a
306
+
307
+ x_ab = scale_limbs(x_a, global_scale=1.0, local_scales=local_scales)
308
+
309
+ return x_ab
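A usage sketch of the limb-retargeting helpers above on random data:

    import numpy as np
    from lib.util.motion import limb_norm, scale_limbs

    x_a = np.random.rand(15, 2, 64)     # source skeleton sequence [J, 2, T]
    x_b = np.random.rand(15, 2, 64)     # target skeleton sequence [J, 2, T]

    x_ab = limb_norm(x_a, x_b)          # rescale x_a's limbs to match x_b's limb lengths
    x_big = scale_limbs(x_a, global_scale=1.2, local_scales=np.ones(8))  # uniform 1.2x enlargement
    print(x_ab.shape, x_big.shape)      # (15, 2, 64) (15, 2, 64)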
lib/util/visualization.py ADDED
@@ -0,0 +1,448 @@
1
+ import numpy as np
2
+ import os
3
+ import cv2
4
+ import math
5
+ import imageio
6
+ from tqdm import tqdm
7
+ from PIL import Image
8
+ from lib.util.motion import normalize_motion_inv, globalize_motion
9
+ from lib.util.general import ensure_dir
10
+ from threading import Thread, Lock
11
+
12
+
13
+ def interpolate_color(color1, color2, alpha):
14
+ color_i = alpha * np.array(color1) + (1 - alpha) * np.array(color2)
15
+ return color_i.tolist()
16
+
17
+
18
+ def two_pts_to_rectangle(point1, point2):
19
+ X = [point1[1], point2[1]]
20
+ Y = [point1[0], point2[0]]
21
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
22
+ length = 5
23
+ alpha = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
24
+ beta = alpha - 90
25
+ if beta <= -180:
26
+ beta += 360
27
+ p1 = ( int(point1[0] - length*math.cos(math.radians(beta))) , int(point1[1] - length*math.sin(math.radians(beta))) )
28
+ p2 = ( int(point1[0] + length*math.cos(math.radians(beta))) , int(point1[1] + length*math.sin(math.radians(beta))) )
29
+ p3 = ( int(point2[0] + length*math.cos(math.radians(beta))) , int(point2[1] + length*math.sin(math.radians(beta))) )
30
+ p4 = ( int(point2[0] - length*math.cos(math.radians(beta))) , int(point2[1] - length*math.sin(math.radians(beta))) )
31
+ return [p1,p2,p3,p4]
32
+
33
+
34
+ def rgb2rgba(color):
35
+ return (color[0], color[1], color[2], 255)
36
+
37
+
38
+ def hex2rgb(hex, number_of_colors=3):
39
+ h = hex
40
+ rgb = []
41
+ for i in range(number_of_colors):
42
+ h = h.lstrip('#')
43
+ hex_color = h[0:6]
44
+ rgb_color = [int(hex_color[i:i+2], 16) for i in (0, 2 ,4)]
45
+ rgb.append(rgb_color)
46
+ h = h[6:]
47
+
48
+ return rgb
49
+
50
+ def normalize_joints(joints_position, H=512, W=512):
51
+ # find the min and max of the joint coordinates
52
+ min_x, min_y = np.min(joints_position, axis=0)
53
+ max_x, max_y = np.max(joints_position, axis=0)
54
+
55
+ # compute the coordinate range of the joints
56
+ range_x, range_y = max_x - min_x, max_y - min_y
57
+
58
+ # reserve a small margin so that scaled joints do not fall outside the canvas
59
+ buffer = 0.05 # e.g. a 5% border margin
60
+ scale_x, scale_y = (1 - buffer) * W / range_x, (1 - buffer) * H / range_y
61
+
62
+ # use the smaller scale so that all joints fit on the canvas
63
+ scale = min(scale_x, scale_y)
64
+
65
+ # scale the joint coordinates
66
+ joints_position_scaled = (joints_position - np.array([min_x, min_y])) * scale
67
+
68
+ # compute the new bounds of the scaled joint coordinates
69
+ new_min_x, new_min_y = np.min(joints_position_scaled, axis=0)
70
+ new_max_x, new_max_y = np.max(joints_position_scaled, axis=0)
71
+
72
+ # compute the translation that centers the joints on the canvas
73
+ translate_x = (W - (new_max_x - new_min_x)) / 2 - new_min_x
74
+ translate_y = (H - (new_max_y - new_min_y)) / 2 - new_min_y
75
+
76
+ # translate the joint coordinates
77
+ joints_position_normalized = joints_position_scaled + np.array([translate_x, translate_y])
78
+
79
+ return joints_position_normalized
80
+
81
+ def joints2image(joints_position, colors, transparency=False, H=512, W=512, nr_joints=15, imtype=np.uint8, grayscale=False, bg_color=(255, 255, 255)):
82
+ nr_joints = joints_position.shape[0]
83
+ joints_position=normalize_joints(joints_position)
84
+ if nr_joints == 49: # full joints(49): basic(15) + eyes(2) + toes(2) + hands(30)
85
+ limbSeq = [[0, 1], [1, 2], [1, 5], [1, 8], [2, 3], [3, 4], [5, 6], [6, 7], \
86
+ [8, 9], [8, 13], [9, 10], [10, 11], [11, 12], [13, 14], [14, 15], [15, 16],
87
+ ]#[0, 17], [0, 18]] #ignore eyes
88
+
89
+ L = rgb2rgba(colors[0]) if transparency else colors[0]
90
+ M = rgb2rgba(colors[1]) if transparency else colors[1]
91
+ R = rgb2rgba(colors[2]) if transparency else colors[2]
92
+
93
+ colors_joints = [M, M, L, L, L, R, R,
94
+ R, M, L, L, L, L, R, R, R,
95
+ R, R, L] + [L] * 15 + [R] * 15
96
+
97
+ colors_limbs = [M, L, R, M, L, L, R,
98
+ R, L, R, L, L, L, R, R, R,
99
+ R, R]
100
+ elif nr_joints == 15 or nr_joints == 17: # basic joints(15) + (eyes(2))
101
+ limbSeq = [[0, 1], [1, 2], [1, 5], [1, 8], [2, 3], [3, 4], [5, 6], [6, 7],
102
+ [8, 9], [8, 12], [9, 10], [10, 11], [12, 13], [13, 14]]
103
+ # [0, 15], [0, 16] two eyes are not drawn
104
+
105
+ L = rgb2rgba(colors[0]) if transparency else colors[0]
106
+ M = rgb2rgba(colors[1]) if transparency else colors[1]
107
+ R = rgb2rgba(colors[2]) if transparency else colors[2]
108
+
109
+ colors_joints = [M, M, L, L, L, R, R,
110
+ R, M, L, L, L, R, R, R]
111
+
112
+ colors_limbs = [M, L, R, M, L, L, R,
113
+ R, L, R, L, L, R, R]
114
+ else:
115
+ raise ValueError("Only support number of joints be 49 or 17 or 15")
116
+
117
+ if transparency:
118
+ canvas = np.zeros(shape=(H, W, 4))
119
+ else:
120
+ canvas = np.ones(shape=(H, W, 3)) * np.array(bg_color).reshape([1, 1, 3])
121
+ hips = joints_position[8]
122
+ neck = joints_position[1]
123
+ torso_length = ((hips[1] - neck[1]) ** 2 + (hips[0] - neck[0]) ** 2) ** 0.5
124
+
125
+ head_radius = int(torso_length/4.5)
126
+ end_effectors_radius = int(torso_length/15)
127
+ end_effectors_radius = 7
128
+ joints_radius = 7
129
+ # joints_position[0][0]*=200
130
+ # joints_position[0][1]*=200
131
+ cv2.circle(canvas, (int(joints_position[0][0]),int(joints_position[0][1])), head_radius, colors_joints[0], thickness=-1)
132
+
133
+ for i in range(1, len(colors_joints)):
134
+ # print(joints_position[i][0])
135
+ # joints_position[i][0]*=200
136
+ # joints_position[i][1]*=200
137
+ # print(joints_position[i][1])
138
+ if i in (17, 18):
139
+ continue
140
+ elif i > 18:
141
+ radius = 2
142
+ else:
143
+ radius = joints_radius
144
+ cv2.circle(canvas, (int(joints_position[i][0]),int(joints_position[i][1])), radius, colors_joints[i], thickness=-1)
145
+
146
+ stickwidth = 2
147
+
148
+ for i in range(len(limbSeq)):
149
+ limb = limbSeq[i]
150
+ cur_canvas = canvas.copy()
151
+ point1_index = limb[0]
152
+ point2_index = limb[1]
153
+
154
+ #if len(all_peaks[point1_index]) > 0 and len(all_peaks[point2_index]) > 0:
155
+ point1 = joints_position[point1_index]
156
+ point2 = joints_position[point2_index]
157
+ X = [point1[1], point2[1]]
158
+ Y = [point1[0], point2[0]]
159
+ mX = np.mean(X)
160
+ mY = np.mean(Y)
161
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
162
+ alpha = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
163
+
164
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(alpha), 0, 360, 1)
165
+ cv2.fillConvexPoly(cur_canvas, polygon, colors_limbs[i])
166
+ canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
167
+ bb = bounding_box(canvas)
168
+ canvas_cropped = canvas[:,bb[2]:bb[3], :]
169
+
170
+ canvas = canvas.astype(imtype)
171
+ canvas_cropped = canvas_cropped.astype(imtype)
172
+
173
+ if grayscale:
174
+ if transparency:
175
+ canvas = cv2.cvtColor(canvas, cv2.COLOR_RGBA2GRAY)
176
+ canvas_cropped = cv2.cvtColor(canvas_cropped, cv2.COLOR_RGBA2GRAY)
177
+ else:
178
+ canvas = cv2.cvtColor(canvas, cv2.COLOR_RGB2GRAY)
179
+ canvas_cropped = cv2.cvtColor(canvas_cropped, cv2.COLOR_RGB2GRAY)
180
+
181
+ return [canvas, canvas_cropped]
182
+
183
+
184
+ def joints2image_highlight(joints_position, colors, highlights, transparency=False, H=512, W=512, nr_joints=15, imtype=np.uint8, grayscale=False):
185
+ nr_joints = joints_position.shape[0]
186
+
187
+ limbSeq = [[0, 1], [1, 2], [1, 5], [1, 8], [2, 3], [3, 4], [5, 6], [6, 7],
188
+ [8, 9], [8, 12], [9, 10], [10, 11], [12, 13], [13, 14]]
189
+ # [0, 15], [0, 16] two eyes are not drawn
190
+
191
+ L = rgb2rgba(colors[0]) if transparency else colors[0]
192
+ M = rgb2rgba(colors[1]) if transparency else colors[1]
193
+ R = rgb2rgba(colors[2]) if transparency else colors[2]
194
+ Hi = rgb2rgba(colors[3]) if transparency else colors[3]
195
+
196
+ colors_joints = [M, M, L, L, L, R, R,
197
+ R, M, L, L, L, R, R, R]
198
+
199
+ colors_limbs = [M, L, R, M, L, L, R,
200
+ R, L, R, L, L, R, R]
201
+
202
+ for hi in highlights: colors_limbs[hi] = Hi
203
+
204
+ if transparency:
205
+ canvas = np.zeros(shape=(H, W, 4))
206
+ else:
207
+ canvas = np.ones(shape=(H, W, 3)) * 255
208
+ hips = joints_position[8]
209
+ neck = joints_position[1]
210
+ torso_length = ((hips[1] - neck[1]) ** 2 + (hips[0] - neck[0]) ** 2) ** 0.5
211
+
212
+ head_radius = int(torso_length/4.5)
213
+ end_effectors_radius = int(torso_length/15)
214
+ end_effectors_radius = 7
215
+ joints_radius = 7
216
+
217
+ cv2.circle(canvas, (int(joints_position[0][0]*500),int(joints_position[0][1]*500)), head_radius, colors_joints[0], thickness=-1)
218
+
219
+ for i in range(1, len(colors_joints)):
220
+ if i in (17, 18):
221
+ continue
222
+ elif i > 18:
223
+ radius = 2
224
+ else:
225
+ radius = joints_radius
226
+ cv2.circle(canvas, (int(joints_position[i][0]*500),int(joints_position[i][1]*500)), radius, colors_joints[i], thickness=-1)
227
+
228
+ stickwidth = 2
229
+
230
+ for i in range(len(limbSeq)):
231
+ limb = limbSeq[i]
232
+ cur_canvas = canvas.copy()
233
+ point1_index = limb[0]
234
+ point2_index = limb[1]
235
+
236
+ #if len(all_peaks[point1_index]) > 0 and len(all_peaks[point2_index]) > 0:
237
+ point1 = joints_position[point1_index]
238
+ point2 = joints_position[point2_index]
239
+ X = [point1[1], point2[1]]
240
+ Y = [point1[0], point2[0]]
241
+ mX = np.mean(X)
242
+ mY = np.mean(Y)
243
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
244
+ alpha = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
245
+
246
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(alpha), 0, 360, 1)
247
+ cv2.fillConvexPoly(cur_canvas, polygon, colors_limbs[i])
248
+ canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
249
+ bb = bounding_box(canvas)
250
+ canvas_cropped = canvas[:,bb[2]:bb[3], :]
251
+
252
+ canvas = canvas.astype(imtype)
253
+ canvas_cropped = canvas_cropped.astype(imtype)
254
+
255
+ if grayscale:
256
+ if transparency:
257
+ canvas = cv2.cvtColor(canvas, cv2.COLOR_RGBA2GRAY)
258
+ canvas_cropped = cv2.cvtColor(canvas_cropped, cv2.COLOR_RGBA2GRAY)
259
+ else:
260
+ canvas = cv2.cvtColor(canvas, cv2.COLOR_RGB2GRAY)
261
+ canvas_cropped = cv2.cvtColor(canvas_cropped, cv2.COLOR_RGB2GRAY)
262
+
263
+ return [canvas, canvas_cropped]
264
+
265
+
266
+ def motion2video(motion, h, w, save_path, colors, bg_color=(255, 255, 255), transparency=False, motion_tgt=None, fps=25, save_frame=True, grayscale=False, show_progress=True):
267
+ nr_joints = motion.shape[0]
268
+ as_array = save_path.endswith(".npy")
269
+ vlen = motion.shape[-1]
270
+
271
+ out_array = np.zeros([h, w, vlen]) if as_array else None
272
+ videowriter = None if as_array else imageio.get_writer(save_path, fps=fps, codec='libx264')
273
+
274
+ if save_frame:
275
+ frames_dir = save_path[:-4] + '-frames'
276
+ ensure_dir(frames_dir)
277
+
278
+ iterator = range(vlen)
279
+ if show_progress: iterator = tqdm(iterator)
280
+ for i in iterator:
281
+ [img, img_cropped] = joints2image(motion[:, :, i], colors, transparency=transparency, bg_color=bg_color, H=h, W=w, nr_joints=nr_joints, grayscale=grayscale)
282
+ if motion_tgt is not None:
283
+ [img_tgt, img_tgt_cropped] = joints2image(motion_tgt[:, :, i], colors, transparency=transparency, bg_color=bg_color, H=h, W=w, nr_joints=nr_joints, grayscale=grayscale)
284
+ img_ori = img.copy()
285
+ img = cv2.addWeighted(img_tgt, 0.3, img_ori, 0.7, 0)
286
+ img_cropped = cv2.addWeighted(img_tgt, 0.3, img_ori, 0.7, 0)
287
+ bb = bounding_box(img_cropped)
288
+ img_cropped = img_cropped[:, bb[2]:bb[3], :]
289
+ if save_frame:
290
+ save_image(img_cropped, os.path.join(frames_dir, "%04d.png" % i))
291
+ if as_array: out_array[:, :, i] = img
292
+ #else: videowriter.append_data(img)
293
+
294
+ if as_array: np.save(save_path, out_array)
295
+ else: videowriter.close()
296
+
297
+ return out_array
298
+
299
+
300
+ def motion2video_np(motion, h, w, colors, bg_color=(255, 255, 255), transparency=False, motion_tgt=None, show_progress=True, workers=6):
301
+
302
+ nr_joints = motion.shape[0]
303
+ vlen = motion.shape[-1]
304
+ out_array = np.zeros([vlen, h, w , 3])
305
+
306
+ queue = [i for i in range(vlen)]
307
+ lock = Lock()
308
+ pbar = tqdm(total=vlen) if show_progress else None
309
+
310
+ class Worker(Thread):
311
+
312
+ def __init__(self):
313
+ super(Worker, self).__init__()
314
+
315
+ def run(self):
316
+ while True:
317
+ lock.acquire()
318
+ if len(queue) == 0:
319
+ lock.release()
320
+ break
321
+ else:
322
+ i = queue.pop(0)
323
+ lock.release()
324
+ [img, img_cropped] = joints2image(motion[:, :, i], colors, transparency=transparency, bg_color=bg_color, H=h, W=w, nr_joints=nr_joints, grayscale=False)
325
+ if motion_tgt is not None:
326
+ [img_tgt, img_tgt_cropped] = joints2image(motion_tgt[:, :, i], colors, transparency=transparency, H=h, W=w, nr_joints=nr_joints, grayscale=False)
327
+ img_ori = img.copy()
328
+ img = cv2.addWeighted(img_tgt, 0.3, img_ori, 0.7, 0)
329
+ # img_cropped = cv2.addWeighted(img_tgt, 0.3, img_ori, 0.7, 0)
330
+ # bb = bounding_box(img_cropped)
331
+ # img_cropped = img_cropped[:, bb[2]:bb[3], :]
332
+ out_array[i, :, :] = img
333
+ if show_progress: pbar.update(1)
334
+
335
+ pool = [Worker() for _ in range(workers)]
336
+ for worker in pool: worker.start()
337
+ for worker in pool: worker.join()
338
+ for worker in pool: del worker
339
+
340
+ return out_array
341
+
342
+
343
+
344
+ def save_image(image_numpy, image_path):
345
+ image_pil = Image.fromarray(image_numpy)
346
+ image_pil.save(image_path)
347
+
348
+
349
+ def bounding_box(img):
350
+ a = np.where(img != 0)
351
+ bbox = np.min(a[0]), np.max(a[0]), np.min(a[1]), np.max(a[1])
352
+ return bbox
353
+
354
+
355
+ def pose2im_all(all_peaks, H=512, W=512):
356
+ limbSeq = [[1, 2], [2, 3], [3, 4], # right arm
357
+ [1, 5], [5, 6], [6, 7], # left arm
358
+ [8, 9], [9, 10], [10, 11], # right leg
359
+ [8, 12], [12, 13], [13, 14], # left leg
360
+ [1, 0], # head/neck
361
+ [1, 8], # body,
362
+ ]
363
+
364
+ limb_colors = [[0, 60, 255], [0, 120, 255], [0, 180, 255],
365
+ [180, 255, 0], [120, 255, 0], [60, 255, 0],
366
+ [170, 255, 0], [85, 255, 0], [0, 255, 0],
367
+ [255, 170, 0], [255, 85, 0], [255, 0, 0],
368
+ [0, 85, 255],
369
+ [0, 0, 255],
370
+ ]
371
+
372
+ joint_colors = [[85, 0, 255], [0, 0, 255], [0, 60, 255], [0, 120, 255], [0, 180, 255],
373
+ [180, 255, 0], [120, 255, 0], [60, 255, 0], [0, 0, 255],
374
+ [170, 255, 0], [85, 255, 0], [0, 255, 0],
375
+ [255, 170, 0], [255, 85, 0], [255, 0, 0],
376
+ ]
377
+
378
+ image = pose2im(all_peaks, limbSeq, limb_colors, joint_colors, H, W)
379
+ return image
380
+
381
+
382
+ def pose2im(all_peaks, limbSeq, limb_colors, joint_colors, H, W, _circle=True, _limb=True, imtype=np.uint8):
383
+ canvas = np.zeros(shape=(H, W, 3))
384
+ canvas.fill(255)
385
+
386
+ if _circle:
387
+ for i in range(len(joint_colors)):
388
+ cv2.circle(canvas, (int(all_peaks[i][0]), int(all_peaks[i][1])), 2, joint_colors[i], thickness=2)
389
+
390
+ if _limb:
391
+ stickwidth = 2
392
+
393
+ for i in range(len(limbSeq)):
394
+ limb = limbSeq[i]
395
+ cur_canvas = canvas.copy()
396
+ point1_index = limb[0]
397
+ point2_index = limb[1]
398
+
399
+ if len(all_peaks[point1_index]) > 0 and len(all_peaks[point2_index]) > 0:
400
+ point1 = all_peaks[point1_index][0:2]
401
+ point2 = all_peaks[point2_index][0:2]
402
+ X = [point1[1], point2[1]]
403
+ Y = [point1[0], point2[0]]
404
+ mX = np.mean(X)
405
+ mY = np.mean(Y)
406
+ # cv2.line()
407
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
408
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
409
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
410
+ cv2.fillConvexPoly(cur_canvas, polygon, limb_colors[i])
411
+ canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
412
+
413
+ return canvas.astype(imtype)
414
+
415
+
416
+ def visualize_motion_in_training(outputs, mean_pose, std_pose, nr_visual=4, H=512, W=512):
417
+ ret = {}
418
+ for k, out in outputs.items():
419
+ motion = out[0].detach().cpu().numpy()
420
+ inds = np.linspace(0, motion.shape[1] - 1, nr_visual, dtype=int)
421
+ motion = motion[:, inds]
422
+ motion = motion.reshape(-1, 2, motion.shape[-1])
423
+ motion = normalize_motion_inv(motion, mean_pose, std_pose)
424
+ peaks = globalize_motion(motion)
425
+
426
+ heatmaps = []
427
+ for i in range(peaks.shape[2]):
428
+ skeleton = pose2im_all(peaks[:, :, i], H, W)
429
+ heatmaps.append(skeleton)
430
+ heatmaps = np.stack(heatmaps).transpose((0, 3, 1, 2)) / 255.0
431
+ ret[k] = heatmaps
432
+
433
+ return ret
434
+
435
+ if __name__ == '__main__':
436
+ # load the .npy motion file
437
+ motion_data = np.load('/home/fazhong/studio/transmomo.pytorch/out/retarget_1_121.npy')
438
+
439
+ # video parameters
440
+ height = 512 # video height
441
+ width = 512 # video width
442
+ save_path = 'Angry.mp4' # output video path
443
+ colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] # joint colors
444
+ bg_color = (255, 255, 255) # background color
445
+ fps = 25 # video frame rate
446
+
447
+ # render the video
448
+ motion2video(motion_data, height, width, save_path, colors, bg_color=bg_color, transparency=False, fps=fps)
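A minimal sketch of rendering a single frame with joints2image; the joint coordinates and the three-color palette are arbitrary placeholders:

    import numpy as np
    from lib.util.visualization import joints2image, hex2rgb, save_image

    colors = hex2rgb("#a50b69#b73b87#db9dc3")   # left / middle / right colors, any hex triple works
    joints = np.random.rand(15, 2) * 400 + 50   # one frame of 15 (x, y) joint positions
    canvas, canvas_cropped = joints2image(joints, colors, H=512, W=512)
    save_image(canvas, "frame_0000.png")        # hypothetical output path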
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ gradio
2
+ fastapi
3
+ aiohttp
4
+ easydict
5
+ imageio-ffmpeg
6
+ matplotlib
7
+ numpy
8
+ Pillow
9
+ protobuf
10
+ PyYAML
11
+ scikit-image
12
+ scikit-learn
13
+ scipy
14
+ tensorboardX
15
+ torch>=1.2.0
16
+ torchvision
17
+ tqdm
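The dependencies above can be installed with "pip install -r requirements.txt"; imageio-ffmpeg provides the ffmpeg backend that motion2video relies on for writing .mp4 files.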