Spaces:
Sleeping
Sleeping
import tensorflow as tf | |
from tensorflow.keras.layers import Input, Dense, Lambda, Layer | |
from tensorflow.keras.models import Model | |
from tensorflow.keras.losses import binary_crossentropy | |
from tensorflow.keras import backend as K | |
def sampling(args):
    """Sample a latent vector with the reparameterization trick.

    Args:
        args: A pair ``(z_mean, z_log_var)``, the mean and log-variance
            produced by the encoder heads.

    Returns:
        ``z_mean + exp(z_log_var / 2) * eps``, where ``eps`` is standard
        normal noise. The noise shape is built from the first two
        dimensions of ``z_mean`` only, so 2-D ``(batch, latent_dim)``
        inputs are expected.
    """
    mu, log_var = args
    mu_shape = tf.shape(mu)
    noise = K.random_normal(shape=(mu_shape[0], mu_shape[1]))
    # exp(0.5 * log variance) is the standard deviation.
    std = K.exp(log_var * 0.5)
    return mu + std * noise
def vae_loss(inputs, x_decoded_mean, z_log_var, z_mean):
    """Return the batch-mean VAE loss: reconstruction cross-entropy plus KL.

    Args:
        inputs: Original inputs, used as targets for the reconstruction
            term. Binary cross-entropy is applied, so these are presumably
            scaled to [0, 1] (TODO confirm with the data pipeline).
        x_decoded_mean: Decoder output, compared to ``inputs`` with binary
            cross-entropy.
        z_log_var: Log-variance of the approximate posterior.
        z_mean: Mean of the approximate posterior.

    Returns:
        Scalar tensor: the mean over the batch of
        (per-sample reconstruction loss + per-sample KL divergence).
    """
    # Keras' binary_crossentropy averages over the last axis. The ELBO's
    # reconstruction term is a *sum* over features, so without rescaling
    # it is shrunk by a factor of input_dim relative to the KL term. That
    # imbalance can make the model ignore the latent code. Multiplying the
    # mean by the feature count turns it back into a sum.
    xent_loss = binary_crossentropy(inputs, x_decoded_mean)
    n_features = tf.cast(tf.shape(inputs)[-1], xent_loss.dtype)
    xent_loss = xent_loss * n_features
    # Closed-form KL(N(z_mean, exp(z_log_var)) || N(0, I)), summed over the
    # latent dimensions.
    kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return K.mean(xent_loss + kl_loss)
class VAELossLayer(Layer):
    """Pass-through layer that attaches the VAE loss to the model.

    It is called on ``[x, x_decoded_mean, z_log_var, z_mean]``. It registers
    ``vae_loss`` computed on those tensors via ``add_loss`` and returns
    ``x`` unchanged, so the layer's output is simply its first input.
    """

    def call(self, inputs):
        # Index the first four entries, matching vae_loss's argument order.
        original, reconstruction, log_var, mean = (inputs[k] for k in range(4))
        self.add_loss(vae_loss(original, reconstruction, log_var, mean))
        return original

    def compute_output_shape(self, input_shape):
        # The output is inputs[0], so it has the first input's shape.
        return input_shape[0]
def create_vae(input_dim, latent_dim, intermediate_dim=512, optimizer='rmsprop'):
    """Build and compile a fully-connected variational autoencoder.

    The training loss is not passed to ``compile``. ``VAELossLayer``
    registers it inside the graph via ``add_loss``, and the end-to-end
    model's output is its own input.

    Args:
        input_dim: Length of each flat input vector.
        latent_dim: Size of the latent code.
        intermediate_dim: Width of the single hidden layer used in both the
            encoder and the decoder. Defaults to 512, the previous
            hard-coded value.
        optimizer: Optimizer passed to ``Model.compile``. Defaults to
            ``'rmsprop'``, the previous hard-coded value.

    Returns:
        A tuple ``(vae, encoder)``:

        - ``vae`` is the compiled end-to-end model. Its summary is printed
          as a side effect.
        - ``encoder`` maps inputs to ``z_mean``, the latent mean, not a
          sampled ``z``.
    """
    # Encoder: one hidden layer, then two linear heads for the posterior
    # mean and log-variance.
    inputs = Input(shape=(input_dim,))
    h = Dense(intermediate_dim, activation='relu')(inputs)
    z_mean = Dense(latent_dim)(h)
    z_log_var = Dense(latent_dim)(h)

    # Use the reparameterization trick.
    z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

    # Decoder: mirrors the encoder. The sigmoid output pairs with the
    # binary cross-entropy reconstruction term in vae_loss.
    decoder_h = Dense(intermediate_dim, activation='relu')
    decoder_mean = Dense(input_dim, activation='sigmoid')
    x_decoded_mean = decoder_mean(decoder_h(z))

    # VAELossLayer adds the loss via add_loss and passes `inputs` through,
    # which is why compile() below receives no loss argument.
    outputs = VAELossLayer()([inputs, x_decoded_mean, z_log_var, z_mean])
    vae = Model(inputs, outputs)
    vae.compile(optimizer=optimizer)
    vae.summary()
    return vae, Model(inputs, z_mean)