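"""Linear beta-VAE with Swish activations and latent-traversal utilities."""
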
import torch
import torch.nn as nn
class Swish(nn.Module):
    """Swish activation: x * sigmoid(x)."""

    def forward(self, x):
        return x * torch.sigmoid(x)
def cycle_interval(starting_value, num_frames, min_val, max_val):
    """Cycle once through [min_val, max_val], starting from starting_value."""
    starting_in_01 = ((starting_value - min_val) / (max_val - min_val)).cpu()
    grid = torch.linspace(starting_in_01.item(), starting_in_01.item() + 2., steps=num_frames + 1)[:-1]
    # Reflect values that leave [0, 1] back inside (triangle wave).
    grid -= torch.clamp(2 * grid - 2, min=0)
    grid += torch.clamp(-2 * grid, min=0)
    return grid * (max_val - min_val) + min_val
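
# Example: cycle_interval(torch.tensor(0.5), 4, 0.0, 1.0)
# -> tensor([0.5, 1.0, 0.5, 0.0]); rises to max_val, then falls back to min_val.
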
class BetaVAE_Linear(nn.Module):
    def __init__(self, in_dim=1024, n_hidden=64, latent=8):
        super(BetaVAE_Linear, self).__init__()
        self.n_hidden = n_hidden
        self.latent = latent

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, n_hidden), Swish(),
        )

        # Latent heads: posterior mean and log-variance
        self.mu = nn.Linear(n_hidden, latent)
        self.lv = nn.Linear(n_hidden, latent)

        # Decoder (sigmoid is applied downstream, in calc_loss and the traversals)
        self.decoder = nn.Sequential(
            nn.Linear(latent, n_hidden), Swish(),
            nn.Linear(n_hidden, in_dim), Swish()
        )
    def BottomUp(self, x):
        # Encode x into posterior parameters (mean, log-variance).
        out = self.encoder(x)
        mu, lv = self.mu(out), self.lv(out)
        return mu, lv

    def reparameterize(self, mu, lv):
        # Reparameterization trick: z = mu + sigma * eps keeps sampling differentiable.
        std = torch.exp(0.5 * lv)
        eps = torch.randn_like(std)
        return mu + std * eps

    def TopDown(self, z):
        # Decode a latent sample back to input space.
        return self.decoder(z)
    def forward(self, x):
        mu, lv = self.BottomUp(x)
        z = self.reparameterize(mu, lv)
        out = self.TopDown(z)
        return out, mu, lv
    def calc_loss(self, x, beta):
        mu, lv = self.BottomUp(x)
        z = self.reparameterize(mu, lv)
        out = torch.sigmoid(self.TopDown(z))
        # Negated BCE, so `nll` actually holds the Bernoulli log-likelihood per
        # example; the objective below re-negates it.
        nll = -nn.functional.binary_cross_entropy(out, x, reduction='sum') / x.shape[0]
        # KL(q(z|x) || N(0, I)) per example; the 1e-5 keeps it strictly positive.
        kl = (-0.5 * torch.sum(1 + lv - mu.pow(2) - lv.exp()) + 1e-5) / x.shape[0]
        # beta-VAE objective: reconstruction NLL + beta * KL.
        return -nll + kl * beta, kl, nll
    def LT_fitted_gauss_2std(self, x, num_var=6, num_traversal=5):
        # Cycle each latent linearly through +/-2 std dev of a Gaussian fitted
        # to the aggregate posterior of the batch.
        x = x.view(x.shape[0], -1)
        mu, lv = self.BottomUp(x)
        images = []
        for i, batch_mu in enumerate(mu[:num_var]):
            # Reconstruction from the unmodified posterior mean.
            images.append(torch.sigmoid(self.TopDown(batch_mu)).unsqueeze(0))
            for latent_var in range(batch_mu.shape[0]):
                new_mu = batch_mu.unsqueeze(0).repeat([num_traversal, 1])
                # Fit N(loc, scale^2) to this latent across the batch
                # (law of total variance: E[var] + var[E]).
                loc = mu[:, latent_var].mean()
                total_var = lv[:, latent_var].exp().mean() + mu[:, latent_var].var()
                scale = total_var.sqrt()
                new_mu[:, latent_var] = cycle_interval(batch_mu[latent_var], num_traversal,
                                                       loc - 2 * scale, loc + 2 * scale)
                images.append(torch.sigmoid(self.TopDown(new_mu)))
        return images
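
# A minimal training-loop sketch (not part of the original file): assumes the
# loader yields batches of inputs already scaled to [0, 1] (e.g. MNIST), and
# the optimizer, beta, and learning rate here are illustrative placeholders.
def train_sketch(model, loader, beta=0.05, epochs=1, lr=1e-3):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        for batch in loader:
            # Support (x, y) pairs as well as bare tensors.
            x = batch[0] if isinstance(batch, (tuple, list)) else batch
            x = x.view(x.shape[0], -1)  # flatten to (B, in_dim)
            loss, kl, nll = model.calc_loss(x, beta)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return model
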
if __name__ == "__main__":
    model = BetaVAE_Linear(in_dim=784)  # match the 784-dim test input below
    x = torch.rand(10, 784)
    out, mu, lv = model(x)
    print(out.shape)
    loss, kl, nll = model.calc_loss(x, 0.05)
    print(loss, kl, nll)
    images = model.LT_fitted_gauss_2std(x)
    print(len(images), images[0].shape)
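
    # Optional: tile the traversal outputs into an image grid for inspection.
    # A sketch assuming 28x28 inputs (784 = 28 * 28) and that torchvision is
    # available -- both are assumptions, not part of the original file.
    try:
        from torchvision.utils import make_grid, save_image
        grid = torch.cat(images, dim=0).view(-1, 1, 28, 28)
        save_image(make_grid(grid, nrow=5), "traversals.png")  # 5 = default num_traversal
    except ImportError:
        pass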