File size: 3,525 Bytes
a4d0945
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.nn as nn
import torch.distributions as td
import numpy as np

class Swish(nn.Module):
    """Swish activation: multiplies the input by its own sigmoid.

    The module holds no parameters or buffers, so the default
    ``nn.Module`` constructor is sufficient.
    """

    def forward(self, x):
        """Return ``x * sigmoid(x)``, computed element-wise."""
        gate = torch.sigmoid(x)
        return x * gate

def cycle_interval(starting_value, num_frames, min_val, max_val):
    """Cycles through the state space in a single cycle.

    The start position is normalised to the unit interval, advanced linearly
    over a span of 2 (endpoint excluded) and then folded back at 1 and at 0.
    For a start inside ``[min_val, max_val]`` this yields a triangle wave:
    up to ``max_val``, down to ``min_val`` and back towards the start.

    Args:
        starting_value: Scalar (0-dim tensor or Python number) where the
            cycle begins.
        num_frames: Number of values to generate.
        min_val: Scalar lower bound of the interval.
        max_val: Scalar upper bound of the interval.

    Returns:
        A 1-D tensor with ``num_frames`` entries; the first entry equals
        ``starting_value`` (up to rounding) when it lies inside the interval.
    """
    # ``as_tensor`` lets plain Python numbers through as well as tensors.
    start_in_01 = torch.as_tensor(
        (starting_value - min_val) / (max_val - min_val)
    )
    start = start_in_01.item()

    # Build the grid on the same device as the inputs so the final rescaling
    # never mixes devices. One extra step is generated and dropped so that
    # the end point (identical to the start of the next cycle) is excluded.
    grid = torch.linspace(
        start, start + 2., steps=num_frames + 1, device=start_in_01.device
    )[:-1]

    # Fold with torch ops only: no implicit tensor <-> numpy round trip.
    grid = grid - torch.clamp(2 * grid - 2, min=0)  # reflect values above 1
    grid = grid + torch.clamp(-2 * grid, min=0)     # reflect values below 0
    return grid * (max_val - min_val) + min_val
class BetaVAE_Linear(nn.Module):
    """Fully-connected beta-VAE with one hidden layer in encoder and decoder.

    The encoder maps ``in_dim`` features to ``n_hidden`` units, from which two
    linear heads produce the posterior mean and log-variance of a
    ``latent``-dimensional diagonal Gaussian. The decoder maps a latent
    sample back to ``in_dim`` features.

    Args:
        in_dim: Size of the (flattened) input and of the reconstruction.
        n_hidden: Width of the hidden layer on both sides.
        latent: Number of latent dimensions.
    """

    def __init__(self, in_dim=1024, n_hidden=64, latent=8):
        super().__init__()

        self.n_hidden = n_hidden
        self.latent = latent

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, n_hidden), Swish(),
        )

        # Latent heads: posterior mean and log-variance.
        self.mu = nn.Linear(n_hidden, latent)
        self.lv = nn.Linear(n_hidden, latent)

        # Decoder
        # NOTE(review): the decoder ends in a Swish, and calc_loss /
        # LT_fitted_gauss_2std apply a sigmoid on top of it. Confirm that the
        # double non-linearity on the output is intended.
        self.decoder = nn.Sequential(
            nn.Linear(latent, n_hidden), Swish(),
            nn.Linear(n_hidden, in_dim), Swish()
        )

    def BottomUp(self, x):
        """Encode ``x`` and return the posterior ``(mu, lv)`` tensors."""
        out = self.encoder(x)
        mu, lv = self.mu(out), self.lv(out)
        return mu, lv

    def reparameterize(self, mu, lv):
        """Draw ``mu + exp(0.5 * lv) * eps`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * lv)
        eps = torch.randn_like(std)
        return mu + std * eps

    def TopDown(self, z):
        """Decode latent ``z``; no sigmoid is applied here."""
        out = self.decoder(z)
        return out

    def forward(self, x):
        """Encode, sample and decode.

        ``x`` is fed to the first linear layer unchanged, so its last
        dimension must already equal ``in_dim``.

        Returns:
            Tuple ``(out, mu, lv)``: the raw decoder output (no sigmoid) and
            the posterior mean and log-variance.
        """
        mu, lv = self.BottomUp(x)
        z = self.reparameterize(mu, lv)
        out = self.TopDown(z)
        return out, mu, lv

    def calc_loss(self, x, beta):
        """Compute the beta-VAE objective for a batch.

        Args:
            x: Batch of targets for binary cross-entropy; it is also the
                encoder input and is not reshaped here.
            beta: Weight applied to the KL term.

        Returns:
            Tuple ``(loss, kl, nll)`` where ``loss = bce + beta * kl``.
            Both ``bce`` and ``kl`` are summed over features and divided by
            the batch size. NOTE: the third element, kept under its
            historical name ``nll``, is the *negated* cross-entropy
            (``-bce``), not the negative log-likelihood itself.
        """
        mu, lv = self.BottomUp(x)
        z = self.reparameterize(mu, lv)
        out = torch.sigmoid(self.TopDown(z))

        batch_size = x.shape[0]
        bce = nn.functional.binary_cross_entropy(out, x, reduction='sum') / batch_size
        nll = -bce
        # Closed-form KL(N(mu, exp(lv)) || N(0, I)); the 1e-5 offset is added
        # to the summed KL before dividing by the batch size.
        kl = (-0.5 * torch.sum(1 + lv - mu.pow(2) - lv.exp()) + 1e-5) / batch_size

        return bce + kl * beta, kl, nll

    def LT_fitted_gauss_2std(self, x,num_var=6, num_traversal=5):
        """Latent traversals over +-2 std dev of a Gaussian fitted per latent.

        For every latent dimension a Gaussian is fitted to the batch: its
        mean is the mean of the posterior means, and its variance is the mean
        posterior variance plus the variance of the posterior means. Each
        selected sample is decoded once as-is, and then once per latent
        dimension with that dimension swept by ``cycle_interval`` across
        ``loc +- 2 * scale``.

        Gradients are tracked as in any forward pass; wrap the call in
        ``torch.no_grad()`` when the result is only used for visualisation.

        Args:
            x: Input batch; flattened to ``(batch, -1)`` before encoding.
            num_var: Number of leading batch samples to traverse (despite
                the name, it slices the batch, not the latent dimensions).
            num_traversal: Number of frames per latent sweep.

        Returns:
            A flat list of tensors. For each traversed sample it holds one
            ``(1, in_dim)`` reconstruction followed by one
            ``(num_traversal, in_dim)`` tensor per latent dimension. All
            values have been passed through a sigmoid.
        """
        x = x.view(x.shape[0], -1)
        mu, lv = self.BottomUp(x)

        # The fitted bounds depend only on the batch statistics, so compute
        # them once per latent instead of once per (sample, latent) pair.
        single_sample = mu.shape[0] < 2
        bounds = []
        for latent_var in range(mu.shape[1]):
            loc = mu[:, latent_var].mean()
            if single_sample:
                # var() of one element is NaN; with a single sample there is
                # no spread between posterior means.
                mean_spread = torch.zeros_like(loc)
            else:
                mean_spread = mu[:, latent_var].var()
            total_var = lv[:, latent_var].exp().mean() + mean_spread
            scale = total_var.sqrt()
            bounds.append((loc - 2 * scale, loc + 2 * scale))

        images = []
        for batch_mu in mu[:num_var]:
            # Untouched reconstruction from the posterior mean.
            images.append(torch.sigmoid(self.TopDown(batch_mu)).unsqueeze(0))
            for latent_var, (low, high) in enumerate(bounds):
                new_mu = batch_mu.unsqueeze(0).repeat([num_traversal, 1])
                new_mu[:, latent_var] = cycle_interval(batch_mu[latent_var], num_traversal,
                                                       low, high)
                images.append(torch.sigmoid(self.TopDown(new_mu)))
        return images


if __name__ == "__main__":
    # Smoke test: run a random batch through every public entry point.
    in_dim = 784
    # The model must be built for the width of the data it is fed; the
    # default in_dim (1024) does not match a 784-wide batch.
    model = BetaVAE_Linear(in_dim=in_dim)
    x = torch.rand(10, in_dim)
    # forward() returns a tuple (reconstruction, mu, lv), not a single tensor.
    out, mu, lv = model(x)
    print(out.shape)
    loss, kl, nll = model.calc_loss(x, 0.05)
    print(loss, kl, nll)
    images = model.LT_fitted_gauss_2std(x)
    print(len(images), images[0].shape)
    print(images[0].shape)