# TMM / lib/trainer.py
# Fazhong Liu
# fin
# 7ca9b42
import os
import torch
import torch.nn as nn
import numpy as np
import random
import lib.network
from lib.loss import *
from lib.util.general import weights_init, get_model_list, get_scheduler
from lib.network import Discriminator
from lib.operation import rotate_and_maybe_project_learning
class BaseTrainer(nn.Module):
    """Shared training harness: an autoencoder plus a discriminator.

    Builds both networks from ``config``, one Adam optimizer and one LR
    scheduler per network, and provides checkpoint save/resume and a 2D
    validation pass. Subclasses implement ``dis_update`` and ``ae_update``.
    """

    def __init__(self, config):
        super(BaseTrainer, self).__init__()
        lr = config.lr
        # The autoencoder class is looked up by name in lib.network.
        autoencoder_cls = getattr(lib.network, config.autoencoder.cls)
        self.autoencoder = autoencoder_cls(config.autoencoder)
        self.discriminator = Discriminator(config.discriminator)

        # Setup the optimizers; only parameters that require grad are optimized.
        beta1 = config.beta1
        beta2 = config.beta2
        dis_params = list(self.discriminator.parameters())
        ae_params = list(self.autoencoder.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=config.weight_decay)
        self.ae_opt = torch.optim.Adam([p for p in ae_params if p.requires_grad],
                                       lr=lr, betas=(beta1, beta2), weight_decay=config.weight_decay)
        self.dis_scheduler = get_scheduler(self.dis_opt, config)
        self.ae_scheduler = get_scheduler(self.ae_opt, config)

        # Network weight initialization: every submodule with config.init,
        # then the discriminator is re-initialized with 'gaussian'.
        self.apply(weights_init(config.init))
        self.discriminator.apply(weights_init('gaussian'))

    def forward(self, data):
        """Decode each input's motion code with the other sample's body code.

        Args:
            data: dict with tensors under ``"x_a"`` and ``"x_b"``; the batch
                size is read from dimension 0 of ``x_a``.

        Returns:
            ``(x_ab, x_ba)``: motion of a with body b, motion of b with body a.

        The module is switched to eval mode for the pass and its previous
        train/eval mode is restored afterwards, even if an exception occurs.
        """
        x_a, x_b = data["x_a"], data["x_b"]
        batch_size = x_a.size(0)
        was_training = self.training
        self.eval()
        try:
            # NOTE(review): sample_body_code is not defined in this file's
            # visible code, and decode is called with two arguments here while
            # TransmomoTrainer calls decode(motion, body, view) — confirm this
            # path against the autoencoder actually in use.
            body_a, body_b = self.sample_body_code(batch_size)
            motion_a = self.autoencoder.encode_motion(x_a)
            motion_b = self.autoencoder.encode_motion(x_b)
            x_ab = self.autoencoder.decode(motion_a, body_b)
            x_ba = self.autoencoder.decode(motion_b, body_a)
        finally:
            self.train(was_training)
        return x_ab, x_ba

    def dis_update(self, data, config):
        """Run one discriminator optimization step (subclass responsibility)."""
        raise NotImplementedError

    def ae_update(self, data, config):
        """Run one autoencoder optimization step (subclass responsibility)."""
        raise NotImplementedError

    def update_learning_rate(self):
        """Step each LR scheduler; either may be None (guarded below)."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.ae_scheduler is not None:
            self.ae_scheduler.step()

    def resume(self, checkpoint_dir, config):
        """Restore networks, optimizers and schedulers from ``checkpoint_dir``.

        The autoencoder/discriminator files are chosen by ``get_model_list``;
        optimizer states come from ``optimizer.pt``. Schedulers are rebuilt at
        the resumed iteration.

        Returns:
            The iteration number parsed from the autoencoder checkpoint name.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "autoencoder")
        state_dict = torch.load(last_model_name)
        self.autoencoder.load_state_dict(state_dict)
        # save() writes 'autoencoder_%08d.pt': the 8 characters before '.pt'
        # hold the iteration number.
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "discriminator")
        state_dict = torch.load(last_model_name)
        self.discriminator.load_state_dict(state_dict)
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['discriminator'])
        self.ae_opt.load_state_dict(state_dict['autoencoder'])
        # Reinitialize schedulers at the resumed iteration
        self.dis_scheduler = get_scheduler(self.dis_opt, config, iterations)
        self.ae_scheduler = get_scheduler(self.ae_opt, config, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save both networks (tagged with ``iterations + 1``) and optimizers.

        Network files are ``autoencoder_%08d.pt`` / ``discriminator_%08d.pt``;
        both optimizer states go to a single ``optimizer.pt`` that is
        overwritten on every call.
        """
        ae_name = os.path.join(snapshot_dir, 'autoencoder_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'discriminator_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save(self.autoencoder.state_dict(), ae_name)
        torch.save(self.discriminator.state_dict(), dis_name)
        torch.save({'autoencoder': self.ae_opt.state_dict(), 'discriminator': self.dis_opt.state_dict()}, opt_name)

    def validate(self, data, config):
        """Evaluate ``self.autoencoder`` and store each loss as an attribute."""
        re_dict = self.evaluate(self.autoencoder, data, config)
        for key, val in re_dict.items():
            setattr(self, key, val)

    @staticmethod
    def recon_criterion(input, target):
        """Mean absolute error (L1) between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    @classmethod
    def evaluate(cls, autoencoder, data, config):
        """Compute 2D validation losses without tracking gradients.

        Args:
            autoencoder: model exposing ``reconstruct2d`` and ``cross2d``.
            data: dict with ``"x_a"``, ``"x_b"`` and the cross-reconstruction
                targets ``"x_aba"``, ``"x_bab"``.
            config: unused here; kept for interface compatibility.

        Returns:
            dict with ``loss_val_recon_x``, ``loss_val_cross_body`` and
            ``loss_val_total`` (equal-weight average of the two).

        The autoencoder is put in eval mode for the pass and its previous
        train/eval mode is restored afterwards, even on error.
        """
        was_training = autoencoder.training
        autoencoder.eval()
        try:
            x_a, x_b = data["x_a"], data["x_b"]
            x_aba, x_bab = data["x_aba"], data["x_bab"]
            re_dict = {}
            with torch.no_grad():  # 2D eval
                x_a_recon = autoencoder.reconstruct2d(x_a)
                x_b_recon = autoencoder.reconstruct2d(x_b)
                x_aba_recon = autoencoder.cross2d(x_a, x_b, x_a)
                x_bab_recon = autoencoder.cross2d(x_b, x_a, x_b)
                re_dict['loss_val_recon_x'] = cls.recon_criterion(x_a_recon, x_a) + cls.recon_criterion(x_b_recon, x_b)
                re_dict['loss_val_cross_body'] = cls.recon_criterion(x_aba_recon, x_aba) + cls.recon_criterion(
                    x_bab_recon, x_bab)
                re_dict['loss_val_total'] = 0.5 * re_dict['loss_val_recon_x'] + 0.5 * re_dict['loss_val_cross_body']
        finally:
            autoencoder.train(was_training)
        return re_dict
class TransmomoTrainer(BaseTrainer):
    """Trainer that disentangles motion, body and view codes.

    Each step encodes a sample ``x`` and its limb-scaled version ``x_s``,
    reconstructs and cross-reconstructs them, re-renders the reconstructions
    under ``config.K`` random view rotations (projected to 2D), and trains a
    discriminator on those rotated views.
    """

    def __init__(self, config):
        super(TransmomoTrainer, self).__init__(config)
        # K evenly spaced angles strictly inside (0, pi): unit * 1 .. unit * K.
        self.angle_unit = np.pi / (config.K + 1)
        view_angles = np.array([i * self.angle_unit for i in range(1, config.K + 1)])
        # Axes disabled in config.rotation_axes (ordered x, z, y) only get angle 0.
        x_angles = view_angles if config.rotation_axes[0] else np.array([0])
        z_angles = view_angles if config.rotation_axes[1] else np.array([0])
        y_angles = view_angles if config.rotation_axes[2] else np.array([0])
        # Cartesian product of the per-axis candidates: one (x, z, y) triple per row.
        x_angles, z_angles, y_angles = np.meshgrid(x_angles, z_angles, y_angles)
        angles = np.stack([x_angles.flatten(), z_angles.flatten(), y_angles.flatten()], axis=1)
        self.angles = torch.tensor(angles).float().cuda()
        self.rotation_axes = torch.tensor(config.rotation_axes).float().cuda()
        self.rotation_axes_mask = [(_ > 0) for _ in config.rotation_axes]

    def _encode_all(self, x):
        """Run the motion, body and view encoders on ``x``, in that order.

        Returns ``(motion, body, body_seq, view, view_seq)``.
        """
        motion = self.autoencoder.encode_motion(x)
        body, body_seq = self.autoencoder.encode_body(x)
        view, view_seq = self.autoencoder.encode_view(x)
        return motion, body, body_seq, view, view_seq

    def _sample_angles(self, config, device):
        """Pick ``config.K`` distinct candidate angle triples and jitter them.

        ``random.sample`` raises ValueError if there are fewer than K
        candidates. Returns a tensor of shape ``[1, K, 1, 3]``.
        """
        inds = random.sample(range(self.angles.size(0)), config.K)
        angles = self.angles[inds].clone().detach()  # [K, 3]
        # One Gaussian draw shared by all K triples, scaled by angle_unit and
        # multiplied by rotation_axes (zero on disabled axes).
        angles += self.angle_unit * self.rotation_axes * torch.randn([3], device=device)
        return angles.unsqueeze(0).unsqueeze(2)  # [B=1, K, T=1, 3]

    @staticmethod
    def _project(X, meanpose, stdpose, config, angles=None):
        """Call rotate_and_maybe_project_learning with this trainer's fixed options."""
        return rotate_and_maybe_project_learning(
            X, meanpose, stdpose, angles=angles,
            body_reference=config.autoencoder.body_reference, project_2d=True)

    def dis_update(self, data, config):
        """One discriminator step on rotated, 2D-projected reconstructions.

        The autoencoder side runs under ``torch.no_grad()``: its outputs are
        only used as detached fakes, so building its autograd graph would just
        waste memory. Call order (and therefore RNG consumption) is unchanged.
        """
        x_a = data["x"]
        x_s = data["x_s"]  # the limb-scaled version of x_a
        meanpose = data["meanpose"][0]
        stdpose = data["stdpose"][0]
        self.dis_opt.zero_grad()

        with torch.no_grad():
            # encode
            motion_a, body_a, _, view_a, _ = self._encode_all(x_a)
            motion_s, body_s, _, view_s, _ = self._encode_all(x_s)
            # decode (reconstruct), then re-render under K random views
            angles = self._sample_angles(config, x_a.device)
            X_a_recon = self.autoencoder.decode(motion_a, body_a, view_a)
            x_a_trans = self._project(X_a_recon, meanpose, stdpose, config, angles=angles)
        # Repeat each real sample K times to match the K-view fakes.
        x_a_exp = x_a.repeat_interleave(config.K, dim=0)
        self.loss_dis_trans = self.discriminator.calc_dis_loss(x_a_trans.detach(), x_a_exp)

        if config.trans_gan_ls_w > 0:
            with torch.no_grad():
                X_s_recon = self.autoencoder.decode(motion_s, body_s, view_s)
                x_s_trans = self._project(X_s_recon, meanpose, stdpose, config, angles=angles)
            x_s_exp = x_s.repeat_interleave(config.K, dim=0)
            self.loss_dis_trans_ls = self.discriminator.calc_dis_loss(x_s_trans.detach(), x_s_exp)
        else:
            self.loss_dis_trans_ls = 0

        self.loss_dis_total = config.trans_gan_w * self.loss_dis_trans + \
            config.trans_gan_ls_w * self.loss_dis_trans_ls
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def ae_update(self, data, config):
        """One autoencoder step over all reconstruction/invariance/GAN losses.

        Each individual loss is stored on ``self`` (``loss_*``); terms whose
        config weight is not positive are set to 0 where the code guards them.
        The weighted sum is stored as ``self.loss_total`` and back-propagated.
        """
        x_a = data["x"]
        x_s = data["x_s"]  # the limb-scaled version of x_a
        meanpose = data["meanpose"][0]
        stdpose = data["stdpose"][0]
        K = config.K
        self.ae_opt.zero_grad()

        # encode
        motion_a, body_a, body_a_seq, view_a, view_a_seq = self._encode_all(x_a)
        motion_s, body_s, body_s_seq, view_s, _ = self._encode_all(x_s)

        # invariance loss between x_a and its limb-scaled version
        self.loss_inv_v_ls = self.recon_criterion(view_a, view_s) if config.inv_v_ls_w > 0 else 0
        self.loss_inv_m_ls = self.recon_criterion(motion_a, motion_s) if config.inv_m_ls_w > 0 else 0

        # body triplet loss
        if config.triplet_b_w > 0:
            self.loss_triplet_b = triplet_margin_loss(
                body_a_seq, body_s_seq,
                neg_range=config.triplet_neg_range,
                margin=config.triplet_margin)
        else:
            self.loss_triplet_b = 0

        # reconstruction (no rotation: angles=None)
        X_a_recon = self.autoencoder.decode(motion_a, body_a, view_a)
        x_a_recon = self._project(X_a_recon, meanpose, stdpose, config)
        X_s_recon = self.autoencoder.decode(motion_s, body_s, view_s)
        x_s_recon = self._project(X_s_recon, meanpose, stdpose, config)
        self.loss_recon_x = 0.5 * self.recon_criterion(x_a_recon, x_a) + \
            0.5 * self.recon_criterion(x_s_recon, x_s)

        # cross reconstruction: swap body and view codes between x_a and x_s
        X_as_recon = self.autoencoder.decode(motion_a, body_s, view_s)
        x_as_recon = self._project(X_as_recon, meanpose, stdpose, config)
        X_sa_recon = self.autoencoder.decode(motion_s, body_a, view_a)
        x_sa_recon = self._project(X_sa_recon, meanpose, stdpose, config)
        self.loss_cross_x = 0.5 * self.recon_criterion(x_as_recon, x_s) + \
            0.5 * self.recon_criterion(x_sa_recon, x_a)

        # apply transformation: re-render both reconstructions under K views
        angles = self._sample_angles(config, x_a.device)
        x_a_trans = self._project(X_a_recon, meanpose, stdpose, config, angles=angles)
        x_s_trans = self._project(X_s_recon, meanpose, stdpose, config, angles=angles)

        # GAN loss
        self.loss_gan_trans = self.discriminator.calc_gen_loss(x_a_trans)
        self.loss_gan_trans_ls = self.discriminator.calc_gen_loss(x_s_trans) if config.trans_gan_ls_w > 0 else 0

        # encode again: motion and body codes should survive the view change
        motion_a_trans = self.autoencoder.encode_motion(x_a_trans)
        body_a_trans, _ = self.autoencoder.encode_body(x_a_trans)
        _, view_a_trans_seq = self.autoencoder.encode_view(x_a_trans)
        motion_s_trans = self.autoencoder.encode_motion(x_s_trans)
        body_s_trans, _ = self.autoencoder.encode_body(x_s_trans)
        self.loss_inv_m_trans = 0.5 * self.recon_criterion(motion_a_trans, motion_a.repeat_interleave(K, dim=0)) + \
            0.5 * self.recon_criterion(motion_s_trans, motion_s.repeat_interleave(K, dim=0))
        self.loss_inv_b_trans = 0.5 * self.recon_criterion(body_a_trans, body_a.repeat_interleave(K, dim=0)) + \
            0.5 * self.recon_criterion(body_s_trans, body_s.repeat_interleave(K, dim=0))

        # view triplet loss
        if config.triplet_v_w > 0:
            view_a_seq_exp = view_a_seq.repeat_interleave(K, dim=0)
            self.loss_triplet_v = triplet_margin_loss(
                view_a_seq_exp, view_a_trans_seq,
                neg_range=config.triplet_neg_range, margin=config.triplet_margin)
        else:
            self.loss_triplet_v = 0

        # add all losses (same order as they are listed here)
        weighted_losses = (
            (config.recon_x_w, self.loss_recon_x),
            (config.cross_x_w, self.loss_cross_x),
            (config.inv_v_ls_w, self.loss_inv_v_ls),
            (config.inv_m_ls_w, self.loss_inv_m_ls),
            (config.inv_b_trans_w, self.loss_inv_b_trans),
            (config.inv_m_trans_w, self.loss_inv_m_trans),
            (config.trans_gan_w, self.loss_gan_trans),
            (config.trans_gan_ls_w, self.loss_gan_trans_ls),
            (config.triplet_b_w, self.loss_triplet_b),
            (config.triplet_v_w, self.loss_triplet_v),
        )
        # Accumulate in a float32 scalar on the same device as the inputs.
        self.loss_total = torch.zeros((), dtype=torch.float32, device=x_a.device)
        for weight, loss in weighted_losses:
            self.loss_total = self.loss_total + weight * loss
        self.loss_total.backward()
        self.ae_opt.step()