diff --git a/aae/aae.py b/aae/aae.py
index c606d2ba38..99c9e9e982 100644
--- a/aae/aae.py
+++ b/aae/aae.py
@@ -1,21 +1,17 @@
 from __future__ import print_function, division
 
 from keras.datasets import mnist
-from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, GaussianNoise
-from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
-from keras.layers import MaxPooling2D, merge
+from keras.layers import Input, Dense, Reshape, Flatten, Lambda
 from keras.layers.advanced_activations import LeakyReLU
-from keras.layers.convolutional import UpSampling2D, Conv2D
 from keras.models import Sequential, Model
 from keras.optimizers import Adam
-from keras import losses
-from keras.utils import to_categorical
 import keras.backend as K
 
 import matplotlib.pyplot as plt
 
 import numpy as np
 
+
 class AdversarialAutoencoder():
     def __init__(self):
         self.img_rows = 28
@@ -26,11 +22,8 @@ def __init__(self):
 
         optimizer = Adam(0.0002, 0.5)
 
-        # Build and compile the discriminator
+        # Build the discriminator
         self.discriminator = self.build_discriminator()
-        self.discriminator.compile(loss='binary_crossentropy',
-            optimizer=optimizer,
-            metrics=['accuracy'])
 
         # Build the encoder / decoder
         self.encoder = self.build_encoder()
@@ -44,17 +37,21 @@ def __init__(self):
 
         # For the adversarial_autoencoder model we will only train the generator
         self.discriminator.trainable = False
+        self.discriminator.compile(
+            loss='binary_crossentropy',
+            optimizer=optimizer,
+            metrics=['accuracy'])
 
         # The discriminator determines validity of the encoding
         validity = self.discriminator(encoded_repr)
 
         # The adversarial_autoencoder model (stacked generator and discriminator)
         self.adversarial_autoencoder = Model(img, [reconstructed_img, validity])
-        self.adversarial_autoencoder.compile(loss=['mse', 'binary_crossentropy'],
+        self.adversarial_autoencoder.compile(
+            loss=['mse', 'binary_crossentropy'],
             loss_weights=[0.999, 0.001],
             optimizer=optimizer)
 
-
     def build_encoder(self):
         # Encoder
 
@@ -67,12 +64,15 @@ def build_encoder(self):
         h = LeakyReLU(alpha=0.2)(h)
         mu = Dense(self.latent_dim)(h)
         log_var = Dense(self.latent_dim)(h)
-        latent_repr = merge([mu, log_var],
-            mode=lambda p: p[0] + K.random_normal(K.shape(p[0])) * K.exp(p[1] / 2),
-            output_shape=lambda p: p[0])
+        latent_repr = Lambda(self.latent, output_shape=(self.latent_dim,))([mu, log_var])
 
         return Model(img, latent_repr)
 
+    def latent(self, p):
+        """Sample based on `mu` and `log_var`."""
+        mu, log_var = p
+        return mu + K.random_normal(K.shape(mu)) * K.exp(log_var / 2)
+
     def build_decoder(self):
 
         model = Sequential()
@@ -146,7 +146,7 @@ def train(self, epochs, batch_size=128, sample_interval=50):
             g_loss = self.adversarial_autoencoder.train_on_batch(imgs, [imgs, valid])
 
             # Plot the progress
-            print ("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[0], g_loss[1]))
+            print("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss[0], g_loss[1]))
 
             # If at save interval => save generated image samples
             if epoch % sample_interval == 0:
@@ -155,7 +155,7 @@ def sample_images(self, epoch):
         r, c = 5, 5
 
-        z = np.random.normal(size=(r*c, self.latent_dim))
+        z = np.random.normal(size=(r * c, self.latent_dim))
         gen_imgs = self.decoder.predict(z)
 
         gen_imgs = 0.5 * gen_imgs + 0.5
 
@@ -164,8 +164,8 @@ def sample_images(self, epoch):
         cnt = 0
         for i in range(r):
             for j in range(c):
-                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
-                axs[i,j].axis('off')
+                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
+                axs[i, j].axis('off')
                 cnt += 1
         fig.savefig("images/mnist_%d.png" % epoch)
         plt.close()
@@ -175,8 +175,9 @@ def save_model(self):
         def save(model, model_name):
             model_path = "saved_model/%s.json" % model_name
             weights_path = "saved_model/%s_weights.hdf5" % model_name
-            options = {"file_arch": model_path,
-                        "file_weight": weights_path}
+            options = {
+                "file_arch": model_path,
+                "file_weight": weights_path}
             json_string = model.to_json()
             open(options['file_arch'], 'w').write(json_string)
             model.save_weights(options['file_weight'])
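
Note on the encoder change: the old functional `merge(..., mode=...)` API was deprecated and later removed in Keras 2, so the stochastic latent sample is now produced by a `Lambda` layer wrapping an explicit reparameterization-trick sample, z = mu + eps * exp(log_var / 2) with eps ~ N(0, I). A minimal standalone sketch of the same sampling step, assuming Keras 2.x (the input and hidden sizes here are illustrative, not taken from the diff):

    import keras.backend as K
    from keras.layers import Input, Dense, Lambda
    from keras.models import Model

    latent_dim = 10

    def sample_latent(p):
        # Reparameterization trick: z = mu + eps * sigma, with sigma = exp(log_var / 2)
        mu, log_var = p
        eps = K.random_normal(K.shape(mu))
        return mu + eps * K.exp(log_var / 2)

    x = Input(shape=(784,))
    h = Dense(512, activation='relu')(x)
    mu = Dense(latent_dim)(h)
    log_var = Dense(latent_dim)(h)
    # output_shape is passed explicitly since not every backend can infer it
    z = Lambda(sample_latent, output_shape=(latent_dim,))([mu, log_var])
    encoder = Model(x, z)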
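
Note on the compile ordering: `self.discriminator.compile(...)` now runs after `self.discriminator.trainable = False`. In Keras 2 the `trainable` flag is read at compile time: a Model whose `trainable` is False reports an empty `trainable_weights` list, and any training function compiled afterwards treats those weights as frozen. A toy check (hypothetical model, assuming Keras 2.x):

    from keras.layers import Dense, Input
    from keras.models import Model

    inp = Input(shape=(4,))
    disc = Model(inp, Dense(1, activation='sigmoid')(inp))

    print(len(disc.trainable_weights))  # 2: the Dense kernel and bias
    disc.trainable = False
    print(len(disc.trainable_weights))  # 0: a model compiled now updates nothing

One consequence worth flagging: with this ordering the freeze applies to the discriminator's own compiled training function as well, not only to its copy inside adversarial_autoencoder, so `discriminator.train_on_batch` in `train()` would no longer update its weights.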
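
For reference, usage is unchanged by the cleanup; in the Keras-GAN repository aae.py ends with the usual entry point:

    if __name__ == '__main__':
        aae = AdversarialAutoencoder()
        aae.train(epochs=20000, batch_size=32, sample_interval=200)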