poem_analysis / vae_model.py
esocoder's picture
first commit
996aa19
raw
history blame
No virus
1.78 kB
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Lambda, Layer
from tensorflow.keras.models import Model
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras import backend as K
def sampling(args):
    """Draw a latent sample using the reparameterization trick.

    Args:
        args: A pair ``(z_mean, z_log_var)`` of tensors. The first two
            dimensions of ``z_mean`` determine the shape of the noise drawn.

    Returns:
        ``z_mean + exp(0.5 * z_log_var) * epsilon``, where ``epsilon`` comes
        from ``K.random_normal`` called with its default parameters.
    """
    mean, log_var = args
    mean_shape = tf.shape(mean)
    # Noise is shaped from the first two dimensions of the mean tensor.
    noise = K.random_normal(shape=(mean_shape[0], mean_shape[1]))
    # exp(0.5 * log-variance) turns the log-variance into a standard deviation.
    std_dev = K.exp(0.5 * log_var)
    return mean + std_dev * noise
def vae_loss(inputs, x_decoded_mean, z_log_var, z_mean):
    """Return the batch-averaged VAE objective (reconstruction + KL).

    Args:
        inputs: Original input batch. Assumed to carry features on the last
            axis (as produced by ``Input(shape=(input_dim,))``).
        x_decoded_mean: Decoder reconstruction compared against ``inputs``.
        z_log_var: Log-variance of the approximate posterior.
        z_mean: Mean of the approximate posterior.

    Returns:
        A scalar tensor: the mean of the per-sample reconstruction loss plus
        the per-sample KL divergence.
    """
    # Keras' binary_crossentropy averages over the last (feature) axis, while
    # the KL term below is *summed* over latent dimensions.  Rescale by the
    # feature count so the reconstruction term is also a per-sample sum;
    # otherwise the KL term dominates by a factor of input_dim.
    mean_xent = binary_crossentropy(inputs, x_decoded_mean)
    feature_count = tf.cast(tf.shape(inputs)[-1], mean_xent.dtype)
    xent_loss = feature_count * mean_xent
    # Closed-form KL divergence between N(z_mean, exp(z_log_var)) and N(0, I).
    kl_loss = -0.5 * K.sum(
        1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1
    )
    return K.mean(xent_loss + kl_loss)
class VAELossLayer(Layer):
    """Pass-through layer that attaches the VAE loss to the model."""

    def call(self, inputs):
        """Register the VAE loss and hand back the first input unchanged.

        Args:
            inputs: Sequence ordered as
                ``[x, x_decoded_mean, z_log_var, z_mean]``.

        Returns:
            The first element of ``inputs``.
        """
        original = inputs[0]
        reconstruction, log_var, mean = inputs[1], inputs[2], inputs[3]
        # The loss is attached via add_loss rather than returned.
        self.add_loss(vae_loss(original, reconstruction, log_var, mean))
        return original

    def compute_output_shape(self, input_shape):
        """Return the shape of the first input, matching what ``call`` emits."""
        first_input_shape = input_shape[0]
        return first_input_shape
def create_vae(input_dim, latent_dim, intermediate_dim=512):
    """Build and compile a dense VAE together with its encoder.

    Args:
        input_dim: Size of the flat input vector.
        latent_dim: Size of the latent space.
        intermediate_dim: Width of the hidden ``Dense`` layer used in both
            the encoder and the decoder. Defaults to 512, the value that was
            previously hard-coded.

    Returns:
        A tuple ``(vae, encoder)``. ``vae`` is compiled with the ``rmsprop``
        optimizer and takes its loss from ``VAELossLayer``; its output is
        whatever that layer returns (the pass-through input), not the decoder
        reconstruction. ``encoder`` maps the same inputs to ``z_mean`` and
        shares the encoder weights with ``vae``.

    Side effects:
        Prints the model summary via ``vae.summary()``.
    """
    # Encoder
    inputs = Input(shape=(input_dim,))
    h = Dense(intermediate_dim, activation='relu')(inputs)
    z_mean = Dense(latent_dim)(h)
    z_log_var = Dense(latent_dim)(h)

    # Reparameterization trick: sample z from (z_mean, z_log_var).
    z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

    # Decoder
    decoder_h = Dense(intermediate_dim, activation='relu')
    decoder_mean = Dense(input_dim, activation='sigmoid')
    h_decoded = decoder_h(z)
    x_decoded_mean = decoder_mean(h_decoded)

    # The loss layer registers the VAE objective through add_loss, so
    # compile() is called without an explicit loss.
    outputs = VAELossLayer()([inputs, x_decoded_mean, z_log_var, z_mean])
    vae = Model(inputs, outputs)
    vae.compile(optimizer='rmsprop')
    vae.summary()

    encoder = Model(inputs, z_mean)
    return vae, encoder