import torch
import torch.nn as nn
import torch.distributions as td
import numpy as np


class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)


def cycle_interval(starting_value, num_frames, min_val, max_val):
    """Cycle through ``[min_val, max_val]`` once, starting at ``starting_value``.

    The start is mapped to the unit interval and stepped linearly over a
    distance of 2. The path is then folded with a triangle wave: values above 1
    are reflected to ``2 - g`` and values below 0 to ``-g``. Assuming the start
    lies inside the interval, the path rises to ``max_val``, falls to
    ``min_val`` and heads back towards the start.

    Args:
        starting_value: Scalar (0-dim tensor or number) where the cycle begins.
        num_frames: Number of points to return.
        min_val: Lower bound of the interval (scalar tensor or number).
        max_val: Upper bound of the interval (scalar tensor or number).

    Returns:
        1-D float tensor with ``num_frames`` points. It lives on
        ``starting_value``'s device when that is a tensor.
    """
    span = max_val - min_val
    start01 = float((starting_value - min_val) / span)
    # Keep the grid on the caller's device. Forcing it to CPU breaks the final
    # affine map when the bounds are CUDA tensors.
    device = starting_value.device if isinstance(starting_value, torch.Tensor) else None
    # num_frames + 1 points with the last dropped, so the cycle does not
    # repeat its first frame.
    grid = torch.linspace(start01, start01 + 2.0, steps=num_frames + 1, device=device)[:-1]
    grid = grid - torch.clamp(2 * grid - 2, min=0)  # reflect values above 1
    grid = grid + torch.clamp(-2 * grid, min=0)  # reflect values below 0
    return grid * span + min_val


class BetaVAE_Linear(nn.Module):
    """Fully connected beta-VAE with a single hidden layer in encoder and decoder.

    Args:
        in_dim: Size of the last input dimension.
        n_hidden: Width of the hidden layer.
        latent: Number of latent variables.

    The decoder returns *logits*; apply ``torch.sigmoid`` to get Bernoulli
    probabilities.
    """

    def __init__(self, in_dim=1024, n_hidden=64, latent=8):
        super().__init__()
        self.n_hidden = n_hidden
        self.latent = latent

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, n_hidden),
            Swish(),
        )

        # Heads producing the posterior mean and log-variance
        self.mu = nn.Linear(n_hidden, latent)
        self.lv = nn.Linear(n_hidden, latent)

        # Decoder. No activation after the last layer: its output is fed to a
        # sigmoid, and a trailing Swish (range >= ~-0.278) would cap the
        # predicted probabilities at >= ~0.43. Parameter indices (0, 2) are the
        # same as before, so existing state_dicts still load.
        self.decoder = nn.Sequential(
            nn.Linear(latent, n_hidden),
            Swish(),
            nn.Linear(n_hidden, in_dim),
        )

    def BottomUp(self, x):
        """Encode ``x`` into posterior mean and log-variance ``(mu, lv)``."""
        out = self.encoder(x)
        return self.mu(out), self.lv(out)

    def reparameterize(self, mu, lv):
        """Draw ``z = mu + exp(lv / 2) * eps`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * lv)
        eps = torch.randn_like(std)
        return mu + std * eps

    def TopDown(self, z):
        """Decode latent ``z`` into output logits."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(logits, mu, lv)`` for one stochastic pass through the VAE."""
        mu, lv = self.BottomUp(x)
        z = self.reparameterize(mu, lv)
        out = self.TopDown(z)
        return out, mu, lv

    def calc_loss(self, x, beta):
        """Compute the beta-VAE loss for a batch.

        Args:
            x: Targets in ``[0, 1]``, same shape as the decoder output.
            beta: Weight of the KL term.

        Returns:
            ``(loss, kl, nll)``, all averaged over the batch:

            - ``loss``: ``-nll + beta * kl``.
            - ``kl``: KL(q(z|x) || N(0, I)).
            - ``nll``: despite its name, this is the Bernoulli log-likelihood
              (the negated BCE, so it is <= 0). The name and sign are kept for
              existing callers.
        """
        mu, lv = self.BottomUp(x)
        z = self.reparameterize(mu, lv)
        logits = self.TopDown(z)
        # The logits form is numerically stable, unlike sigmoid + BCE.
        nll = -nn.functional.binary_cross_entropy_with_logits(
            logits, x, reduction='sum'
        ) / x.shape[0]
        # The 1e-5 offset is kept from the original implementation.
        kl = (-0.5 * torch.sum(1 + lv - mu.pow(2) - lv.exp()) + 1e-5) / x.shape[0]
        return -nll + kl * beta, kl, nll

    def LT_fitted_gauss_2std(self, x, num_var=6, num_traversal=5):
        """Latent traversals over +-2 std of a Gaussian fitted to the batch's posteriors.

        For every latent dimension, a Gaussian is fitted to the aggregate
        posterior of the batch. Its variance is the mean posterior variance plus
        the variance of the posterior means. Each traversal cycles that
        dimension through ``[loc - 2*std, loc + 2*std]`` while the other
        dimensions stay at the sample's posterior mean.

        Args:
            x: Input batch; flattened to ``(batch, -1)`` before encoding.
            num_var: Number of leading samples to traverse.
            num_traversal: Number of frames per traversal.

        Returns:
            A list of probability tensors, computed without gradients. For each
            traversed sample it contains one ``(1, in_dim)`` reconstruction of
            its posterior mean, followed by ``latent`` tensors of shape
            ``(num_traversal, in_dim)``, one per latent dimension.
        """
        with torch.no_grad():
            x = x.reshape(x.shape[0], -1)
            mu, lv = self.BottomUp(x)

            # Fit the Gaussian once per latent dimension (independent of the sample).
            loc = mu.mean(dim=0)
            # With a single sample the unbiased variance is NaN; treat the
            # spread of the means as zero instead.
            spread = mu.var(dim=0) if mu.shape[0] > 1 else torch.zeros_like(loc)
            scale = (lv.exp().mean(dim=0) + spread).sqrt()
            lower, upper = loc - 2 * scale, loc + 2 * scale

            images = []
            for code in mu[:num_var]:
                images.append(torch.sigmoid(self.TopDown(code)).unsqueeze(0))
                for dim in range(code.shape[0]):
                    traversal = code.unsqueeze(0).repeat(num_traversal, 1)
                    traversal[:, dim] = cycle_interval(
                        code[dim], num_traversal, lower[dim], upper[dim]
                    )
                    images.append(torch.sigmoid(self.TopDown(traversal)))
        return images


if __name__ == "__main__":
    model = BetaVAE_Linear(in_dim=784)
    x = torch.rand(10, 784)
    out, mu, lv = model(x)
    print(out.shape, mu.shape, lv.shape)
    loss, kl, nll = model.calc_loss(x, 0.05)
    print(loss, kl, nll)
    images = model.LT_fitted_gauss_2std(x)
    print(len(images), images[0].shape)
    print(images[1].shape)