Spaces:
Running
Running
File size: 1,347 Bytes
9016314 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
from .set_transformer import set_transformer
class LatentEncoder(object):
    """Pool a set of elements into a Normal distribution over a latent vector.

    The input set is passed through ``set_transformer``, mean-pooled over the
    set axis, and mapped by two dense layers to the location and (softplus)
    scale of an elementwise ``tfd.Normal``.

    Fields read from ``hps``:
        latent_encoder_hidden: forwarded unchanged to ``set_transformer``
            (its exact meaning is defined by that helper).
        latent_dim: size of the latent vector; the last dense layer emits
            ``2 * latent_dim`` units that are split into loc and scale.
    """

    # Lower bound added to the softplus scale. softplus can underflow to
    # exactly 0 for very negative inputs, which would make log_prob / KL of
    # the returned Normal inf or NaN.
    _MIN_SCALE = 1e-6

    def __init__(self, hps, name='latent'):
        # Hyper-parameter object; see the class docstring for the fields used.
        self.hps = hps
        # Variable scope under which all weights of this encoder are created.
        self.name = name

    def __call__(self, x):
        """Encode the set ``x`` into a latent Normal distribution.

        Args:
            x: tensor of shape [B, N, C]. Its static rank must be known
                and equal to 3.

        Returns:
            A ``tfd.Normal`` whose ``loc`` and ``scale`` are the two halves of
            the final dense layer's last axis (shape [B, latent_dim], assuming
            ``set_transformer`` returns a rank-3 tensor -- TODO confirm).

        Raises:
            ValueError: if the static rank of ``x`` is unknown or not 3.
        """
        # Fail fast on malformed input with a readable message.
        rank = x.get_shape().ndims
        if rank != 3:
            raise ValueError(
                'LatentEncoder expects x of shape [B,N,C], got rank %s' % rank)

        latent_dim = self.hps.latent_dim
        # AUTO_REUSE: variables are created on the first call and shared by
        # every later call made under the same scope name.
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            h = set_transformer(x, self.hps.latent_encoder_hidden, name='set_xformer')
            # Mean-pool over axis 1 (the set axis N).
            h = tf.reduce_mean(input_tensor=h, axis=1)
            h = tf.compat.v1.layers.dense(h, latent_dim * 2, name='d1')
            h = tf.nn.leaky_relu(h)
            h = tf.compat.v1.layers.dense(h, latent_dim * 2, name='d2')
            # First half of the last axis -> loc, second half -> raw scale.
            loc = h[..., :latent_dim]
            scale = self._MIN_SCALE + tf.nn.softplus(h[..., latent_dim:])
            return tfd.Normal(loc=loc, scale=scale)
class SetXformer(object):
    """Apply ``set_transformer`` inside a dedicated, reusable variable scope.

    ``hps.set_xformer_hids`` is forwarded unchanged to ``set_transformer``.
    """

    def __init__(self, hps, name='set_xformer'):
        # Scope name that owns every variable created by this module.
        self.name = name
        # Hyper-parameter object; only ``set_xformer_hids`` is read.
        self.hps = hps

    def __call__(self, x):
        """Run ``set_transformer`` on ``x`` and return its output tensor.

        Variables are created under the scope ``self.name``; with
        ``AUTO_REUSE`` repeated calls share the same weights.
        """
        reuse_mode = tf.compat.v1.AUTO_REUSE
        with tf.compat.v1.variable_scope(self.name, reuse=reuse_mode):
            transformed = set_transformer(
                x, self.hps.set_xformer_hids, name='set_xformer')
        return transformed