TMM / lib /network.py
Fazhong Liu
fin
7ca9b42
import sys
thismodule = sys.modules[__name__]
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import matplotlib.pyplot as plt
import numpy as np
torch.manual_seed(123)
def get_autoencoder(config):
ae_cls = getattr(thismodule, config.autoencoder.cls)
return ae_cls(config.autoencoder)
class ConvEncoder(nn.Module):
@classmethod
def build_from_config(cls, config):
conv_pool = None if config.conv_pool is None else getattr(nn, config.conv_pool)
encoder = cls(config.channels, config.padding, config.kernel_size, config.conv_stride, conv_pool)
return encoder
def __init__(self, channels, padding=3, kernel_size=8, conv_stride=2, conv_pool=None):
super(ConvEncoder, self).__init__()
self.in_channels = channels[0]
model = []
acti = nn.LeakyReLU(0.2)
nr_layer = len(channels) - 1
for i in range(nr_layer):
if conv_pool is None:
model.append(nn.ReflectionPad1d(padding))
model.append(nn.Conv1d(channels[i], channels[i+1], kernel_size=kernel_size, stride=conv_stride))
model.append(acti)
else:
model.append(nn.ReflectionPad1d(padding))
model.append(nn.Conv1d(channels[i], channels[i+1], kernel_size=kernel_size, stride=conv_stride))
model.append(acti)
model.append(conv_pool(kernel_size=2, stride=2))
self.model = nn.Sequential(*model)
def forward(self, x):
x = x[:, :self.in_channels, :]
x = self.model(x)
return x
class ConvDecoder(nn.Module):
@classmethod
def build_from_config(cls, config):
decoder = cls(config.channels, config.kernel_size)
return decoder
def __init__(self, channels, kernel_size=7):
super(ConvDecoder, self).__init__()
model = []
pad = (kernel_size - 1) // 2
acti = nn.LeakyReLU(0.2)
for i in range(len(channels) - 1):
model.append(nn.Upsample(scale_factor=2, mode='nearest'))
model.append(nn.ReflectionPad1d(pad))
model.append(nn.Conv1d(channels[i], channels[i + 1],
kernel_size=kernel_size, stride=1))
if i == 0 or i == 1:
model.append(nn.Dropout(p=0.2))
if not i == len(channels) - 2:
model.append(acti) # whether to add tanh a last?
#model.append(nn.Dropout(p=0.2))
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
class Discriminator(nn.Module):
def __init__(self, config):
super(Discriminator, self).__init__()
self.gan_type = config.gan_type
encoder_cls = getattr(thismodule, config.encoder_cls)
self.encoder = encoder_cls.build_from_config(config)
self.linear = nn.Linear(config.channels[-1], 1)
def forward(self, seqs):
code_seq = self.encoder(seqs)
logits = self.linear(code_seq.permute(0, 2, 1))
return logits
def calc_dis_loss(self, x_gen, x_real):
fake_logits = self.forward(x_gen)
real_logits = self.forward(x_real)
if self.gan_type == 'lsgan':
loss = torch.mean((fake_logits - 0) ** 2) + torch.mean((real_logits - 1) ** 2)
elif self.gan_type == 'nsgan':
all0 = torch.zeros_like(fake_logits, requires_grad=False)
all1 = torch.ones_like(real_logits, requires_grad=False)
loss = torch.mean(F.binary_cross_entropy(F.sigmoid(fake_logits), all0) +
F.binary_cross_entropy(F.sigmoid(real_logits), all1))
else:
raise NotImplementedError
return loss
def calc_gen_loss(self, x_gen):
logits = self.forward(x_gen)
if self.gan_type == 'lsgan':
loss = torch.mean((logits - 1) ** 2)
elif self.gan_type == 'nsgan':
all1 = torch.ones_like(logits, requires_grad=False)
loss = torch.mean(F.binary_cross_entropy(F.sigmoid(logits), all1))
else:
raise NotImplementedError
return loss
class Autoencoder3f(nn.Module):
def __init__(self, config):
super(Autoencoder3f, self).__init__()
assert config.motion_encoder.channels[-1] + config.body_encoder.channels[-1] + \
config.view_encoder.channels[-1] == config.decoder.channels[0]
self.n_joints = config.decoder.channels[-1] // 3
self.body_reference = config.body_reference
motion_cls = getattr(thismodule, config.motion_encoder.cls)
body_cls = getattr(thismodule, config.body_encoder.cls)
view_cls = getattr(thismodule, config.view_encoder.cls)
self.motion_encoder = motion_cls.build_from_config(config.motion_encoder)
self.body_encoder = body_cls.build_from_config(config.body_encoder)
self.view_encoder = view_cls.build_from_config(config.view_encoder)
self.decoder = ConvDecoder.build_from_config(config.decoder)
self.body_pool = getattr(F, config.body_encoder.global_pool) if config.body_encoder.global_pool is not None else None
self.view_pool = getattr(F, config.view_encoder.global_pool) if config.view_encoder.global_pool is not None else None
def forward(self, seqs):
return self.reconstruct(seqs)
def encode_motion(self, seqs):
motion_code_seq = self.motion_encoder(seqs)
return motion_code_seq
def encode_body(self, seqs):
body_code_seq = self.body_encoder(seqs)
kernel_size = body_code_seq.size(-1)
body_code = self.body_pool(body_code_seq, kernel_size) if self.body_pool is not None else body_code_seq
return body_code, body_code_seq
def encode_view(self, seqs):
view_code_seq = self.view_encoder(seqs)
kernel_size = view_code_seq.size(-1)
view_code = self.view_pool(view_code_seq, kernel_size) if self.view_pool is not None else view_code_seq
return view_code, view_code_seq
def decode(self, motion_code, body_code, view_code):
if body_code.size(-1) == 1:
body_code = body_code.repeat(1, 1, motion_code.shape[-1])
if view_code.size(-1) == 1:
view_code = view_code.repeat(1, 1, motion_code.shape[-1])
complete_code = torch.cat([motion_code, body_code, view_code], dim=1)
out = self.decoder(complete_code)
return out
def cross3d(self, x_a, x_b, x_c):
motion_a = self.encode_motion(x_a)
body_b, _ = self.encode_body(x_b)
view_c, _ = self.encode_view(x_c)
out = self.decode(motion_a, body_b, view_c)
return out
def cross2d(self, x_a, x_b, x_c):
motion_a = self.encode_motion(x_a)
body_b, _ = self.encode_body(x_b)
view_c, _ = self.encode_view(x_c)
out = self.decode(motion_a, body_b, view_c)
batch_size, channels, seq_len = out.size()
n_joints = channels // 3
out = out.view(batch_size, n_joints, 3, seq_len)
out = out[:, :, [0, 2], :]
out = out.view(batch_size, n_joints * 2, seq_len)
return out
def cross2d_adv(self, x_a, x_b, x_c):
x_a.cpu()
x_a_shape = x_a.shape
print(x_a.shape)
#motion_a_org = self.encode_motion(x_a)
print(x_a)
# The heatmap image is saved as 'tensor_heatmap.png' in the current directory
# for i in range(0,119):
# x_a[0][11][i]+=1
#x_a[0][7][60]+=0.01
#motion_a = self.encode_motion(x_a)
# print(motion_a.shape)
# print(motion_a[0][0]-motion_a_org[0][0])
# res = motion_a[0] - motion_a_org[0]
# res = res.cpu().detach().numpy()
# # Code for plotting the heatmap
# plt.figure(figsize=(15, 10))
# plt.imshow(res, cmap='hot', interpolation='nearest')
# plt.colorbar()
# plt.title('Heatmap of the Tensor')
# # Save the heatmap to a local file
# plt.savefig('/home/fazhong/studio/transmomo.pytorch/tensor_heatmap2.png')
# plt.close()
initial_motion_a = self.encode_motion(x_a) # 计算初始的motion_a
# 定义一个函数来计算motion的变化量
def motion_change(motion_a, initial_motion_a):
return (motion_a - initial_motion_a).norm()
# 设置初始的最大变化量为0
max_change = 0
# 扰动次数,可以根据需要更改
num_perturbations = 10000
init_a = x_a.clone()
for _ in range(num_perturbations):
# 复制x_a以避免在原始数据上修改
x_a_perturbed = x_a.clone().cpu()
# 选择要扰动的随机点
batch_idx, seq_idx, feature_idx = (torch.randint(0, x_a.size(0), (1,)),
torch.randint(0, x_a.size(1), (1,)),
torch.randint(0, x_a.size(2), (1,)))
# 在选定点上加上扰动
x_a_perturbed[batch_idx, seq_idx, feature_idx] += 10 * torch.randn(1)
# 计算扰动后的motion_a
perturbed_motion_a = self.encode_motion(x_a_perturbed.to('cuda:0'))
# 计算变化量
change = motion_change(perturbed_motion_a, initial_motion_a)
# 如果变化量大于之前保存的最大变化量,则更新x_a和最大变化量
if change > max_change:
x_a = x_a_perturbed
max_change = change
# 最后,x_a将是导致最大motion_a变化的扰动版本
# max_change是这个变化量
# print(max_change)
# print(max_change.shape)
print(x_a_perturbed - init_a.cpu())
motion_a = self.encode_motion(x_a_perturbed.to('cuda:0'))
# motion_a = self.encode_motion(x_a.to('cuda:0'))
body_b, _ = self.encode_body(x_b)
view_c, _ = self.encode_view(x_c)
out = self.decode(motion_a, body_b, view_c)
batch_size, channels, seq_len = out.size()
n_joints = channels // 3
out = out.view(batch_size, n_joints, 3, seq_len)
out = out[:, :, [0, 2], :]
out = out.view(batch_size, n_joints * 2, seq_len)
return out
def cross2d_one(self, x_a):
motion_a = self.encode_motion(x_a)
body_b, _ = self.encode_body(x_a)
view_c, _ = self.encode_view(x_a)
out = self.decode(motion_a, body_b, view_c)
batch_size, channels, seq_len = out.size()
n_joints = channels // 3
out = out.view(batch_size, n_joints, 3, seq_len)
out = out[:, :, [0, 2], :]
out = out.view(batch_size, n_joints * 2, seq_len)
return out
def adv_cross(self,x_a):
motion_a = self.encode_motion(x_a)
body_b, _ = self.encode_body(x_a)
view_c, _ = self.encode_view(x_a)
return motion_a
def reconstruct3d(self, x):
motion_code = self.encode_motion(x)
body_code, _ = self.encode_body(x)
view_code, _ = self.encode_view(x)
out = self.decode(motion_code, body_code, view_code)
return out
def reconstruct2d(self, x):
motion_code = self.encode_motion(x)
body_code, _ = self.encode_body(x)
view_code, _ = self.encode_view(x)
out = self.decode(motion_code, body_code, view_code)
batch_size, channels, seq_len = out.size()
n_joints = channels // 3
out = out.view(batch_size, n_joints, 3, seq_len)
out = out[:, :, [0, 2], :]
out = out.view(batch_size, n_joints * 2, seq_len)
return out
def interpolate(self, x_a, x_b, N):
step_size = 1. / (N-1)
batch_size, _, seq_len = x_a.size()
motion_a = self.encode_motion(x_a)
body_a, body_a_seq = self.encode_body(x_a)
view_a, view_a_seq = self.encode_view(x_a)
motion_b = self.encode_motion(x_b)
body_b, body_b_seq = self.encode_body(x_b)
view_b, view_b_seq = self.encode_view(x_b)
batch_out = torch.zeros([batch_size, N, N, 2 * self.n_joints, seq_len])
for i in range(N):
motion_weight = i * step_size
for j in range(N):
body_weight = j * step_size
motion = (1. - motion_weight) * motion_a + motion_weight * motion_b
body = (1. - body_weight) * body_a + body_weight * body_b
view = (1. - body_weight) * view_a + body_weight * view_b
out = self.decode(motion, body, view)
batch_size, channels, seq_len = out.size()
n_joints = channels // 3
out = out.view(batch_size, n_joints, 3, seq_len)
out = out[:, :, [0, 2], :]
out = out.view(batch_size, n_joints * 2, seq_len)
batch_out[:, i, j, :, :] = out
return batch_out