import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
from easydict import EasyDict as edict

from .base import BaseModel
from .pc_encoder import LatentEncoder, SetXformer
from .flow.transforms import Transform
from .cVAE import CondVAE

class ACSetVAE(BaseModel):
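    """Arbitrary-conditioning set VAE.

    A set-level latent is inferred from the target elements by a posterior
    encoder and regularized, through a normalizing-flow Transform, toward a
    prior inferred from the observed elements. Each element is then modeled
    by a conditional VAE given the set-level latent and, optionally, a
    permutation-equivariant embedding of the observed set.
    """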
    def __init__(self, hps):
        # Sub-networks are created before BaseModel.__init__, which is
        # expected to invoke build_net() and therefore needs them in place.
        self.prior_net = LatentEncoder(hps, name='prior')
        self.posterior_net = LatentEncoder(hps, name='posterior')
        self.cvae = CondVAE(hps.vae_params, name='cvae')
        if hps.use_peq_embed:
            # Optional permutation-equivariant embedding of the input set.
            self.peq_embed = SetXformer(hps)
        super(ACSetVAE, self).__init__(hps)
    
    def build_net(self):
        with tf.compat.v1.variable_scope('acset_vae', reuse=tf.compat.v1.AUTO_REUSE):
            # x holds the set elements; b and m appear to be binary masks
            # selecting the observed (conditioning) and target elements,
            # respectively. All three are [batch, set_size, dimension].
            self.x = tf.compat.v1.placeholder(tf.float32, [None, self.hps.set_size, self.hps.dimension])
            self.b = tf.compat.v1.placeholder(tf.float32, [None, self.hps.set_size, self.hps.dimension])
            self.m = tf.compat.v1.placeholder(tf.float32, [None, self.hps.set_size, self.hps.dimension])

            # Bijective flow linking the posterior latent space to the prior's
            # base space.
            self.transform = Transform(edict(self.hps.trans_params))

            # Set-level prior, conditioned on the observed elements x*b.
            prior_inputs = tf.concat([self.x * self.b, self.b], axis=-1)  # [B, N, 2D]
            prior = self.prior_net(prior_inputs)  # distribution over [B, latent_dim]
            prior_sample = prior.sample()  # [B, latent_dim]
            prior_sample, _ = self.transform.inverse(prior_sample)  # pull back into posterior space

            # Optional permutation-equivariant embedding of the observed set.
            cm = peq_embed = None
            if self.hps.use_peq_embed:
                peq_embed = self.peq_embed(prior_inputs)  # [B, N, C]
                C = peq_embed.get_shape().as_list()[-1]
                cm = tf.reshape(peq_embed, [-1, C])  # flatten to [B*N, C]
            
            # Set-level posterior, conditioned on the target elements x*m.
            posterior_inputs = tf.concat([self.x * self.m, self.m], axis=-1)  # [B, N, 2D]
            posterior = self.posterior_net(posterior_inputs)
            posterior_sample = posterior.sample()  # [B, latent_dim]

            # Set-level KL term. Despite the name, `kl` is the *negative* KL:
            # -KL(q||p) = H[q] + E_q[log p(T(z)) + log|det dT/dz|],
            # estimated with a single posterior sample pushed through the flow.
            z_sample, logdet = self.transform.forward(posterior_sample)  # [B, latent_dim], [B]
            logp = tf.reduce_sum(input_tensor=prior.log_prob(z_sample), axis=1) + logdet  # [B]
            kl = tf.reduce_sum(input_tensor=posterior.entropy(), axis=1) + logp  # [B]

            # Per-element generator: flatten the set so the CondVAE operates on
            # individual elements of shape [B*N, D].
            x = tf.reshape(self.x, [-1, self.hps.dimension])
            b = tf.reshape(self.b, [-1, self.hps.dimension])
            m = tf.reshape(self.m, [-1, self.hps.dimension])
            # Broadcast the set-level latent to every element:
            # [B, latent_dim] -> [B, N, latent_dim] -> [B*N, latent_dim].
            cv = tf.reshape(
                tf.tile(tf.expand_dims(posterior_sample, axis=1), [1, self.hps.set_size, 1]),
                [-1, self.hps.latent_dim])
            if cm is not None:
                c = tf.concat([cv, cm], axis=-1)  # [B*N, latent_dim + C]
            else:
                c = cv
            # Element-wise posterior and reconstruction.
            vec_kl, vec_post_sample = self.cvae.enc(tf.concat([x, c], axis=-1), c)
            vec_kl = tf.reshape(vec_kl, [-1, self.hps.set_size])  # [B, N]
            recon_dist = self.cvae.dec(tf.concat([vec_post_sample, c], axis=-1), c)
            log_likel = recon_dist.log_prob(x)  # [B*N]
            log_likel = tf.reshape(log_likel, [-1, self.hps.set_size])  # [B, N]
            # Per-element ELBO; the set-level term is amortized over set_size.
            self.set_metric = self.set_elbo = log_likel + tf.expand_dims(kl, axis=1) / self.hps.set_size  # [B, N]
            log_likel = tf.reduce_mean(input_tensor=log_likel, axis=1)
            tf.compat.v1.summary.scalar('log_likel', tf.reduce_mean(input_tensor=log_likel))

            # ELBO per set: mean element log-likelihood, the set-level term
            # amortized over set_size, and the element-level term from the
            # CondVAE encoder (both enter with a positive sign, so they are
            # presumably already negated KLs).
            self.elbo = log_likel + kl / self.hps.set_size + tf.reduce_mean(input_tensor=vec_kl, axis=1)
            self.metric = self.elbo
            self.loss = tf.reduce_mean(input_tensor=-self.elbo)
            tf.compat.v1.summary.scalar('loss', self.loss)

            # Sampling path: condition on the prior sample instead of the
            # posterior sample (x, b, m were already flattened above).
            cv = tf.reshape(
                tf.tile(tf.expand_dims(prior_sample, axis=1), [1, self.hps.set_size, 1]),
                [-1, self.hps.latent_dim])
            if cm is not None:
                c = tf.concat([cv, cm], axis=-1)
            else:
                c = cv
            # Element-level latent drawn from a standard normal prior.
            vec_prior_sample = tf.random.normal(shape=tf.shape(input=vec_post_sample))
            sample_dist = self.cvae.dec(tf.concat([vec_prior_sample, c], axis=-1), c)
            log_likel = sample_dist.log_prob(x)
            sample = sample_dist.sample()
            self.sample = tf.reshape(sample, [-1, self.hps.set_size, self.hps.dimension])

            # Per-element log-likelihood under the sampling path, used for
            # compression-style evaluation.
            self.log_likel = tf.reshape(log_likel, [-1, self.hps.set_size]) + vec_kl
        
        # Optional debug: count trainable parameters.
        # total_params = 0
        # for variable in tf.compat.v1.trainable_variables():
        #     total_params += int(np.prod(variable.get_shape().as_list()))
        # print('Total number of parameters:', total_params)