FAcodecV2 / modules /beta_vae.py
Plachta's picture
Upload 69 files
a4d0945 verified
raw
history blame
3.53 kB
import torch
import torch.nn as nn
import torch.distributions as td
import numpy as np
class Swish(nn.Module):
    """Swish activation: multiplies the input by its own sigmoid."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x * sigmoid(x)`` computed element-wise."""
        gate = torch.sigmoid(x)
        return x * gate
def cycle_interval(starting_value, num_frames, min_val, max_val):
    """Cycle through ``[min_val, max_val]`` once, starting at ``starting_value``.

    The start is normalised to the unit interval, a linear ramp covering two
    unit lengths is laid out from it, and the ramp is folded back at 1 and at
    0. For a start inside the interval the result rises to ``max_val``, falls
    to ``min_val`` and then heads back toward the start.

    Args:
        starting_value: First value of the cycle (Python number or 0-dim tensor).
        num_frames: Number of values to generate.
        min_val: Lower end of the interval (Python number or 0-dim tensor).
        max_val: Upper end of the interval (Python number or 0-dim tensor).

    Returns:
        A 1-D tensor with ``num_frames`` entries.
    """
    span = max_val - min_val
    # float() accepts Python numbers as well as 0-dim tensors on any device.
    start = float((starting_value - min_val) / span)
    # One extra step is generated and dropped so the end point (which would
    # coincide with the start of the next cycle) is not included.
    grid = torch.linspace(start, start + 2.0, steps=num_frames + 1)[:-1]
    # Reflect everything above 1 back down ...
    grid = grid - torch.clamp(2 * grid - 2, min=0)
    # ... and everything that fell below 0 back up.
    grid = grid + torch.clamp(-2 * grid, min=0)
    return grid * span + min_val
class BetaVAE_Linear(nn.Module):
    """Fully-connected beta-VAE with a single hidden layer on each side.

    Args:
        in_dim: Size of the input (and reconstructed) feature vector.
        n_hidden: Width of the hidden layer in encoder and decoder.
        latent: Number of latent variables.
    """

    def __init__(self, in_dim=1024, n_hidden=64, latent=8):
        super().__init__()
        self.n_hidden = n_hidden
        self.latent = latent
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, n_hidden), Swish(),
        )
        # Latent: separate heads for the mean and the log-variance
        self.mu = nn.Linear(n_hidden, latent)
        self.lv = nn.Linear(n_hidden, latent)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent, n_hidden), Swish(),
            nn.Linear(n_hidden, in_dim), Swish(),
        )

    def BottomUp(self, x):
        """Encode ``x`` and return the latent mean and log-variance."""
        hidden = self.encoder(x)
        return self.mu(hidden), self.lv(hidden)

    def reparameterize(self, mu, lv):
        """Sample ``mu + std * eps`` with ``std = exp(0.5 * lv)`` and ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * lv)
        eps = torch.randn_like(std)
        return mu + std * eps

    def TopDown(self, z):
        """Decode latent code ``z`` with the decoder network."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample and decode.

        Returns:
            Tuple ``(out, mu, lv)``: the decoder output for the sampled latent,
            the latent mean and the latent log-variance.
        """
        mu, lv = self.BottomUp(x)
        z = self.reparameterize(mu, lv)
        out = self.TopDown(z)
        return out, mu, lv

    def calc_loss(self, x, beta):
        """Compute the beta-VAE objective for a batch.

        The decoder output is squashed with a sigmoid and compared to ``x``
        with binary cross-entropy, so ``x`` is expected to hold values in
        ``[0, 1]``.

        Args:
            x: Input batch; the first dimension is used as the batch size.
            beta: Weight of the KL term.

        Returns:
            Tuple ``(loss, kl, nll)`` where ``loss = -nll + beta * kl``.
            NOTE: despite its name, ``nll`` is the *negated* summed BCE
            divided by the batch size, which is why it is negated again in
            ``loss``.
        """
        batch = x.shape[0]
        mu, lv = self.BottomUp(x)
        z = self.reparameterize(mu, lv)
        out = torch.sigmoid(self.TopDown(z))
        nll = -nn.functional.binary_cross_entropy(out, x, reduction='sum') / batch
        # KL(N(mu, exp(lv)) || N(0, 1)) summed over batch and latents; the
        # small constant is added before dividing by the batch size.
        kl = (-0.5 * torch.sum(1 + lv - mu.pow(2) - lv.exp()) + 1e-5) / batch
        return -nll + kl * beta, kl, nll

    def LT_fitted_gauss_2std(self, x, num_var=6, num_traversal=5):
        """Latent traversal over +-2 std dev of a Gaussian fitted to the batch.

        For each of the first ``num_var`` items of the batch, the decoded mean
        is emitted first, followed by one traversal per latent variable in
        which that variable is cycled through ``loc +- 2 * scale`` while the
        others stay at the item's mean.

        Args:
            x: Input batch; flattened to ``(batch, -1)`` before encoding.
            num_var: Number of leading batch items to traverse.
            num_traversal: Number of steps in each traversal.

        Returns:
            A list of tensors. Per item: one entry with a leading dimension of
            1 (the reconstruction) and then one entry per latent variable with
            a leading dimension of ``num_traversal``.
        """
        x = x.view(x.shape[0], -1)
        mu, lv = self.BottomUp(x)

        # The fitted Gaussian depends only on the batch, not on the item being
        # traversed, so compute it once for all latent variables.
        loc = mu.mean(dim=0)
        scale = (lv.exp().mean(dim=0) + mu.var(dim=0)).sqrt()
        lower = loc - 2 * scale
        upper = loc + 2 * scale

        images = []
        for batch_mu in mu[:num_var]:
            images.append(torch.sigmoid(self.TopDown(batch_mu)).unsqueeze(0))
            for latent_var in range(batch_mu.shape[0]):
                new_mu = batch_mu.unsqueeze(0).repeat([num_traversal, 1])
                new_mu[:, latent_var] = cycle_interval(
                    batch_mu[latent_var], num_traversal,
                    lower[latent_var], upper[latent_var])
                images.append(torch.sigmoid(self.TopDown(new_mu)))
        return images
if __name__ == "__main__":
    # Smoke test: build a model whose input size matches the dummy batch,
    # then exercise the forward pass, the loss and the latent traversal.
    model = BetaVAE_Linear(in_dim=784)
    x = torch.rand(10, 784)
    # forward() returns (reconstruction, mu, log-variance).
    out, mu, lv = model(x)
    print(out.shape)
    loss, kl, nll = model.calc_loss(x, 0.05)
    print(loss, kl, nll)
    images = model.LT_fitted_gauss_2std(x)
    print(len(images), images[0].shape)
    print(images[0].shape)