Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 53 additions & 42 deletions examples/generative/real_nvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@
from sklearn.datasets import make_moons
import numpy as np
import matplotlib.pyplot as plt

# Compatibility patch for TFP with Keras 3 / TF 2.19+
try:
if not hasattr(tf._api.v2.compat.v2.__internal__, "register_load_context_function"):
tf._api.v2.compat.v2.__internal__.register_load_context_function = (
tf._api.v2.compat.v2.__internal__.register_call_context_function
)
except AttributeError:
pass

import tensorflow_probability as tfp

"""
Expand Down Expand Up @@ -179,48 +189,49 @@ def test_step(self, data):
## Model training
"""

model = RealNVP(num_coupling_layers=6)

model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.0001))

history = model.fit(
normalized_data, batch_size=256, epochs=300, verbose=2, validation_split=0.2
)

"""
## Performance evaluation
"""

plt.figure(figsize=(15, 10))
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("model loss")
plt.legend(["train", "validation"], loc="upper right")
plt.ylabel("loss")
plt.xlabel("epoch")

# From data to latent space.
z, _ = model(normalized_data)

# From latent space to data.
samples = model.distribution.sample(3000)
x, _ = model.predict(samples)

f, axes = plt.subplots(2, 2)
f.set_size_inches(20, 15)

axes[0, 0].scatter(normalized_data[:, 0], normalized_data[:, 1], color="r")
axes[0, 0].set(title="Inference data space X", xlabel="x", ylabel="y")
axes[0, 1].scatter(z[:, 0], z[:, 1], color="r")
axes[0, 1].set(title="Inference latent space Z", xlabel="x", ylabel="y")
axes[0, 1].set_xlim([-3.5, 4])
axes[0, 1].set_ylim([-4, 4])
axes[1, 0].scatter(samples[:, 0], samples[:, 1], color="g")
axes[1, 0].set(title="Generated latent space Z", xlabel="x", ylabel="y")
axes[1, 1].scatter(x[:, 0], x[:, 1], color="g")
axes[1, 1].set(title="Generated data space X", label="x", ylabel="y")
axes[1, 1].set_xlim([-2, 2])
axes[1, 1].set_ylim([-2, 2])
if __name__ == "__main__":
model = RealNVP(num_coupling_layers=6)

model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.0001))

history = model.fit(
normalized_data, batch_size=256, epochs=300, verbose=2, validation_split=0.2
)

"""
## Performance evaluation
"""

plt.figure(figsize=(15, 10))
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("model loss")
plt.legend(["train", "validation"], loc="upper right")
plt.ylabel("loss")
plt.xlabel("epoch")

# From data to latent space.
z, _ = model(normalized_data)

# From latent space to data.
samples = model.distribution.sample(3000)
x, _ = model.predict(samples)

f, axes = plt.subplots(2, 2)
f.set_size_inches(20, 15)

axes[0, 0].scatter(normalized_data[:, 0], normalized_data[:, 1], color="r")
axes[0, 0].set(title="Inference data space X", xlabel="x", ylabel="y")
axes[0, 1].scatter(z[:, 0], z[:, 1], color="r")
axes[0, 1].set(title="Inference latent space Z", xlabel="x", ylabel="y")
axes[0, 1].set_xlim([-3.5, 4])
axes[0, 1].set_ylim([-4, 4])
axes[1, 0].scatter(samples[:, 0], samples[:, 1], color="g")
axes[1, 0].set(title="Generated latent space Z", xlabel="x", ylabel="y")
axes[1, 1].scatter(x[:, 0], x[:, 1], color="g")
axes[1, 1].set(title="Generated data space X", label="x", ylabel="y")
axes[1, 1].set_xlim([-2, 2])
axes[1, 1].set_ylim([-2, 2])

"""
## Relevant Chapters from Deep Learning with Python
Expand Down
59 changes: 34 additions & 25 deletions examples/generative/vq_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,19 @@

from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_probability as tfp
import tensorflow as tf

# Compatibility patch for TFP with Keras 3 / TF 2.19+
try:
if not hasattr(tf._api.v2.compat.v2.__internal__, "register_load_context_function"):
tf._api.v2.compat.v2.__internal__.register_load_context_function = (
tf._api.v2.compat.v2.__internal__.register_call_context_function
)
except AttributeError:
pass

import tensorflow_probability as tfp

"""
## `VectorQuantizer` layer

Expand Down Expand Up @@ -275,36 +285,35 @@ def train_step(self, x):
## Train the VQ-VAE model
"""

vqvae_trainer = VQVAETrainer(data_variance, latent_dim=16, num_embeddings=128)
vqvae_trainer.compile(optimizer=keras.optimizers.Adam())
vqvae_trainer.fit(x_train_scaled, epochs=30, batch_size=128)
if __name__ == "__main__":
vqvae_trainer = VQVAETrainer(data_variance, latent_dim=16, num_embeddings=128)
vqvae_trainer.compile(optimizer=keras.optimizers.Adam())
vqvae_trainer.fit(x_train_scaled, epochs=30, batch_size=128)

"""
## Reconstruction results on the test set
"""
"""
## Reconstruction results on the test set
"""

def show_subplot(original, reconstructed):
plt.subplot(1, 2, 1)
plt.imshow(original.squeeze() + 0.5)
plt.title("Original")
plt.axis("off")

def show_subplot(original, reconstructed):
plt.subplot(1, 2, 1)
plt.imshow(original.squeeze() + 0.5)
plt.title("Original")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(reconstructed.squeeze() + 0.5)
plt.title("Reconstructed")
plt.axis("off")

plt.show()
plt.subplot(1, 2, 2)
plt.imshow(reconstructed.squeeze() + 0.5)
plt.title("Reconstructed")
plt.axis("off")

plt.show()

trained_vqvae_model = vqvae_trainer.vqvae
idx = np.random.choice(len(x_test_scaled), 10)
test_images = x_test_scaled[idx]
reconstructions_test = trained_vqvae_model.predict(test_images)
trained_vqvae_model = vqvae_trainer.vqvae
idx = np.random.choice(len(x_test_scaled), 10)
test_images = x_test_scaled[idx]
reconstructions_test = trained_vqvae_model.predict(test_images)

for test_image, reconstructed_image in zip(test_images, reconstructions_test):
show_subplot(test_image, reconstructed_image)
for test_image, reconstructed_image in zip(test_images, reconstructions_test):
show_subplot(test_image, reconstructed_image)

"""
These results look decent. You are encouraged to play with different hyperparameters
Expand Down
Loading