import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Lambda, Layer
from tensorflow.keras.models import Model
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras import backend as K


def sampling(args):
    """Reparameterization trick: draw z = mean + sigma * eps, eps ~ N(0, I).

    Args:
        args: Tuple/list of two tensors ``(z_mean, z_log_var)``, each of
            shape ``(batch, latent_dim)``.

    Returns:
        A tensor of shape ``(batch, latent_dim)`` sampled from the
        approximate posterior, differentiable w.r.t. z_mean / z_log_var.
    """
    z_mean, z_log_var = args
    batch = tf.shape(z_mean)[0]
    dim = tf.shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim))
    # exp(0.5 * log_var) == standard deviation
    return z_mean + K.exp(0.5 * z_log_var) * epsilon


def vae_loss(inputs, x_decoded_mean, z_log_var, z_mean):
    """Negative ELBO: reconstruction cross-entropy + KL(q(z|x) || N(0, I)).

    Args:
        inputs: Original input batch, shape ``(batch, input_dim)``.
        x_decoded_mean: Decoder output (Bernoulli means), same shape.
        z_log_var: Log-variance of the approximate posterior,
            shape ``(batch, latent_dim)``.
        z_mean: Mean of the approximate posterior, same shape.

    Returns:
        Scalar mean loss over the batch.
    """
    # BUG FIX: Keras binary_crossentropy returns the MEAN over the feature
    # axis, while the KL term below is a SUM over latent dims. Left as-is,
    # the reconstruction term is under-weighted by a factor of input_dim,
    # which encourages posterior collapse. Scale by the feature count so
    # the reconstruction term is effectively a sum over features, matching
    # the standard ELBO (and the canonical Keras VAE example).
    reconstruction_dim = K.cast(tf.shape(inputs)[-1], K.floatx())
    xent_loss = reconstruction_dim * binary_crossentropy(inputs, x_decoded_mean)
    # Closed-form KL divergence between N(z_mean, exp(z_log_var)) and N(0, I),
    # summed over latent dimensions.
    kl_loss = -0.5 * K.sum(
        1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1
    )
    return K.mean(xent_loss + kl_loss)


class VAELossLayer(Layer):
    """Identity layer that attaches the VAE loss via ``add_loss``.

    This lets the model be compiled with no explicit ``loss=`` argument:
    the layer passes its first input through unchanged and registers the
    negative ELBO as a model loss.
    """

    def call(self, inputs):
        # inputs = [x, x_decoded_mean, z_log_var, z_mean]
        x = inputs[0]
        x_decoded_mean = inputs[1]
        z_log_var = inputs[2]
        z_mean = inputs[3]
        loss = vae_loss(x, x_decoded_mean, z_log_var, z_mean)
        self.add_loss(loss)
        # Pass-through so the layer can terminate the functional graph.
        return x

    def compute_output_shape(self, input_shape):
        # Output shape equals the shape of the first (pass-through) input.
        return input_shape[0]


def create_vae(input_dim, latent_dim):
    """Build and compile a variational autoencoder.

    Args:
        input_dim: Dimensionality of the (flattened) input vectors.
        latent_dim: Dimensionality of the latent space.

    Returns:
        Tuple ``(vae, encoder)``: the compiled end-to-end VAE model and an
        encoder model mapping inputs to ``z_mean``.
    """
    # Encoder: input -> hidden -> (z_mean, z_log_var)
    inputs = Input(shape=(input_dim,))
    h = Dense(512, activation='relu')(inputs)
    z_mean = Dense(latent_dim)(h)
    z_log_var = Dense(latent_dim)(h)

    # Reparameterization trick keeps sampling differentiable.
    z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

    # Decoder: z -> hidden -> Bernoulli means over the input space.
    decoder_h = Dense(512, activation='relu')
    decoder_mean = Dense(input_dim, activation='sigmoid')
    h_decoded = decoder_h(z)
    x_decoded_mean = decoder_mean(h_decoded)

    # The loss layer registers the negative ELBO; no loss= needed at compile.
    outputs = VAELossLayer()([inputs, x_decoded_mean, z_log_var, z_mean])
    vae = Model(inputs, outputs)
    vae.compile(optimizer='rmsprop')
    vae.summary()

    return vae, Model(inputs, z_mean)