diff --git a/examples/generative/vq_vae.py b/examples/generative/vq_vae.py index 5caa95d26b..0cf9d25f9f 100644 --- a/examples/generative/vq_vae.py +++ b/examples/generative/vq_vae.py @@ -2,9 +2,10 @@ Title: Vector-Quantized Variational Autoencoders Author: [Sayak Paul](https://twitter.com/RisingSayak) Date created: 2021/07/21 -Last modified: 2022/06/27 +Last modified: 2026/03/06 Description: Training a VQ-VAE for image reconstruction and codebook sampling for generation. Accelerator: GPU +Converted to Keras 3 by: [LakshmiKalaKadali](https://github.com/LakshmiKalaKadali) """ """ @@ -35,22 +36,30 @@ To run this example, you will need TensorFlow 2.5 or higher, as well as TensorFlow Probability, which can be installed using the command below. """ - -"""shell -pip install -q tensorflow-probability -""" - """ ## Imports """ +import os + +os.environ["KERAS_BACKEND"] = "tensorflow" # or "jax", "torch" + import numpy as np import matplotlib.pyplot as plt +import keras +from keras import layers +from keras import ops +from keras import random + +keras.utils.set_random_seed(42) + + +def show_figure(): + if "inline" in plt.get_backend().lower(): + plt.show() + else: + plt.close() -from tensorflow import keras -from tensorflow.keras import layers -import tensorflow_probability as tfp -import tensorflow as tf """ ## `VectorQuantizer` layer @@ -81,67 +90,57 @@ def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs): super().__init__(**kwargs) self.embedding_dim = embedding_dim self.num_embeddings = num_embeddings - - # The `beta` parameter is best kept between [0.25, 2] as per the paper. self.beta = beta - # Initialize the embeddings which we will quantize. 
- w_init = tf.random_uniform_initializer() - self.embeddings = tf.Variable( - initial_value=w_init( - shape=(self.embedding_dim, self.num_embeddings), dtype="float32" - ), + # Initialize the embeddings codebook + self.embeddings = self.add_weight( + shape=(self.embedding_dim, self.num_embeddings), + initializer="random_uniform", trainable=True, name="embeddings_vqvae", ) def call(self, x): - # Calculate the input shape of the inputs and - # then flatten the inputs keeping `embedding_dim` intact. - input_shape = tf.shape(x) - flattened = tf.reshape(x, [-1, self.embedding_dim]) + input_shape = ops.shape(x) + flattened = ops.reshape(x, [-1, self.embedding_dim]) - # Quantization. encoding_indices = self.get_code_indices(flattened) - encodings = tf.one_hot(encoding_indices, self.num_embeddings) - quantized = tf.matmul(encodings, self.embeddings, transpose_b=True) - - # Reshape the quantized values back to the original input shape - quantized = tf.reshape(quantized, input_shape) - - # Calculate vector quantization loss and add that to the layer. You can learn more - # about adding losses to different layers here: - # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. Check - # the original paper to get a handle on the formulation of the loss function. - commitment_loss = tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2) - codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2) + # Reshape indices to match spatial dimensions (e.g., 7x7) + encoding_indices = ops.reshape(encoding_indices, input_shape[:-1]) + + encodings = ops.one_hot(encoding_indices, self.num_embeddings) + quantized = ops.matmul(encodings, ops.transpose(self.embeddings)) + quantized = ops.reshape(quantized, input_shape) + + commitment_loss = ops.mean((ops.stop_gradient(quantized) - x) ** 2) + codebook_loss = ops.mean((quantized - ops.stop_gradient(x)) ** 2) self.add_loss(self.beta * commitment_loss + codebook_loss) - # Straight-through estimator. 
- quantized = x + tf.stop_gradient(quantized - x) - return quantized + quantized = x + ops.stop_gradient(quantized - x) + + # RETURN BOTH: The quantized tensor and the indices + return [quantized, encoding_indices] def get_code_indices(self, flattened_inputs): - # Calculate L2-normalized distance between the inputs and the codes. - similarity = tf.matmul(flattened_inputs, self.embeddings) + # Calculate L2-normalized distance + similarity = ops.matmul(flattened_inputs, self.embeddings) distances = ( - tf.reduce_sum(flattened_inputs**2, axis=1, keepdims=True) - + tf.reduce_sum(self.embeddings**2, axis=0) + ops.sum(flattened_inputs**2, axis=1, keepdims=True) + + ops.sum(self.embeddings**2, axis=0) - 2 * similarity ) - - # Derive the indices for minimum distances. - encoding_indices = tf.argmin(distances, axis=1) - return encoding_indices + return ops.argmin(distances, axis=1) """ **A note on straight-through estimation**: -This line of code does the straight-through estimation part: `quantized = x + -tf.stop_gradient(quantized - x)`. During backpropagation, `(quantized - x)` won't be -included in the computation graph and the gradients obtained for `quantized` -will be copied for `inputs`. Thanks to [this video](https://youtu.be/VZFVUrYcig0?t=1393) +This line of code implements the straight-through estimator: quantized = x + ops.stop_gradient(quantized - x). +In the forward pass, the terms cancel out (x + quantized - x = quantized), and the layer outputs the discrete quantized vectors. +In the backward pass, since the gradient of ops.stop_gradient is zero, +the gradient of the loss with respect to the output is effectively copied directly +to the input x (the encoder's output). This allows the model to bypass the non-differentiable quantization step +and train the encoder using the decoder's gradients. Thanks to [this video](https://youtu.be/VZFVUrYcig0?t=1393) for helping me understand this technique. 
""" @@ -170,7 +169,7 @@ def get_encoder(latent_dim=16): def get_decoder(latent_dim=16): - latent_inputs = keras.Input(shape=get_encoder(latent_dim).output.shape[1:]) + latent_inputs = keras.Input(shape=(7, 7, latent_dim)) x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")( latent_inputs ) @@ -188,11 +187,17 @@ def get_vqvae(latent_dim=16, num_embeddings=64): vq_layer = VectorQuantizer(num_embeddings, latent_dim, name="vector_quantizer") encoder = get_encoder(latent_dim) decoder = get_decoder(latent_dim) + inputs = keras.Input(shape=(28, 28, 1)) encoder_outputs = encoder(inputs) - quantized_latents = vq_layer(encoder_outputs) + + # quantized_latents and codebook_indices are now KerasTensors + quantized_latents, codebook_indices = vq_layer(encoder_outputs) + reconstructions = decoder(quantized_latents) - return keras.Model(inputs, reconstructions, name="vq_vae") + + # Return a model with two outputs + return keras.Model(inputs, [reconstructions, codebook_indices], name="vq_vae") get_vqvae().summary() @@ -207,7 +212,7 @@ def get_vqvae(latent_dim=16, num_embeddings=64): """ -class VQVAETrainer(keras.models.Model): +class VQVAETrainer(keras.Model): def __init__(self, train_variance, latent_dim=32, num_embeddings=128, **kwargs): super().__init__(**kwargs) self.train_variance = train_variance @@ -230,32 +235,27 @@ def metrics(self): self.vq_loss_tracker, ] - def train_step(self, x): - with tf.GradientTape() as tape: - # Outputs from the VQ-VAE. - reconstructions = self.vqvae(x) + def call(self, x, training=False): + return self.vqvae(x, training=training) - # Calculate the losses. - reconstruction_loss = ( - tf.reduce_mean((x - reconstructions) ** 2) / self.train_variance - ) - total_loss = reconstruction_loss + sum(self.vqvae.losses) + def compute_loss( + self, x=None, y=None, y_pred=None, sample_weight=None, training=True + ): + # y_pred is now [reconstructions, indices] + reconstructions = y_pred[0] - # Backpropagation. 
- grads = tape.gradient(total_loss, self.vqvae.trainable_variables) - self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables)) + # Rule 6: Stability + reconstruction_loss = ops.mean((x - reconstructions) ** 2) / ( + self.train_variance + 1e-7 + ) + total_loss = reconstruction_loss + ops.sum(self.vqvae.losses) - # Loss tracking. + # Update trackers self.total_loss_tracker.update_state(total_loss) self.reconstruction_loss_tracker.update_state(reconstruction_loss) - self.vq_loss_tracker.update_state(sum(self.vqvae.losses)) + self.vq_loss_tracker.update_state(ops.sum(self.vqvae.losses)) - # Log results. - return { - "loss": self.total_loss_tracker.result(), - "reconstruction_loss": self.reconstruction_loss_tracker.result(), - "vqvae_loss": self.vq_loss_tracker.result(), - } + return total_loss """ @@ -263,12 +263,12 @@ def train_step(self, x): """ (x_train, _), (x_test, _) = keras.datasets.mnist.load_data() +x_train = np.expand_dims(x_train, -1).astype("float32") +x_test = np.expand_dims(x_test, -1).astype("float32") -x_train = np.expand_dims(x_train, -1) -x_test = np.expand_dims(x_test, -1) +# Integrated Preprocessing (Rule 4) x_train_scaled = (x_train / 255.0) - 0.5 x_test_scaled = (x_test / 255.0) - 0.5 - data_variance = np.var(x_train / 255.0) """ @@ -294,16 +294,16 @@ def show_subplot(original, reconstructed): plt.imshow(reconstructed.squeeze() + 0.5) plt.title("Reconstructed") plt.axis("off") - - plt.show() + show_figure() trained_vqvae_model = vqvae_trainer.vqvae idx = np.random.choice(len(x_test_scaled), 10) test_images = x_test_scaled[idx] -reconstructions_test = trained_vqvae_model.predict(test_images) -for test_image, reconstructed_image in zip(test_images, reconstructions_test): +reconstructions_output, _ = trained_vqvae_model.predict(test_images) + +for test_image, reconstructed_image in zip(test_images, reconstructions_output): show_subplot(test_image, reconstructed_image) """ @@ -316,25 +316,22 @@ def show_subplot(original, 
reconstructed): ## Visualizing the discrete codes """ -encoder = vqvae_trainer.vqvae.get_layer("encoder") -quantizer = vqvae_trainer.vqvae.get_layer("vector_quantizer") +_, codebook_indices = vqvae_trainer.vqvae.predict(test_images) -encoded_outputs = encoder.predict(test_images) -flat_enc_outputs = encoded_outputs.reshape(-1, encoded_outputs.shape[-1]) -codebook_indices = quantizer.get_code_indices(flat_enc_outputs) -codebook_indices = codebook_indices.numpy().reshape(encoded_outputs.shape[:-1]) +codebook_indices_np = ops.convert_to_numpy(codebook_indices) for i in range(len(test_images)): + plt.figure(figsize=(6, 3)) plt.subplot(1, 2, 1) plt.imshow(test_images[i].squeeze() + 0.5) plt.title("Original") plt.axis("off") plt.subplot(1, 2, 2) - plt.imshow(codebook_indices[i]) + plt.imshow(codebook_indices_np[i]) plt.title("Code") plt.axis("off") - plt.show() + show_figure() """ The figure above shows that the discrete codes have been able to capture some @@ -360,6 +357,9 @@ def show_subplot(original, reconstructed): num_residual_blocks = 2 num_pixelcnn_layers = 2 +encoder = vqvae_trainer.vqvae.get_layer("encoder") +quantizer = vqvae_trainer.vqvae.get_layer("vector_quantizer") +encoded_outputs = encoder.predict(x_train_scaled) pixelcnn_input_shape = encoded_outputs.shape[1:-1] print(f"Input shape of the PixelCNN: {pixelcnn_input_shape}") @@ -400,8 +400,6 @@ def show_subplot(original, reconstructed): """ -# The first layer is the PixelCNN layer. This layer simply -# builds on the 2D convolutional layer, but includes masking. 
class PixelConvLayer(layers.Layer): def __init__(self, mask_type, **kwargs): super().__init__() @@ -409,29 +407,41 @@ def __init__(self, mask_type, **kwargs): self.conv = layers.Conv2D(**kwargs) def build(self, input_shape): - # Build the conv2d layer to initialize kernel variables self.conv.build(input_shape) - # Use the initialized kernel to create the mask - kernel_shape = self.conv.kernel.get_shape() - self.mask = np.zeros(shape=kernel_shape) - self.mask[: kernel_shape[0] // 2, ...] = 1.0 - self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0 + kernel_shape = self.conv.kernel.shape + + mask = np.zeros(shape=kernel_shape) + mask[: kernel_shape[0] // 2, ...] = 1.0 + mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0 if self.mask_type == "B": - self.mask[kernel_shape[0] // 2, kernel_shape[1] // 2, ...] = 1.0 + mask[kernel_shape[0] // 2, kernel_shape[1] // 2, ...] = 1.0 + + self.mask = self.add_weight( + name="mask", + shape=kernel_shape, + initializer=keras.initializers.Constant(mask), + trainable=False, + ) def call(self, inputs): - self.conv.kernel.assign(self.conv.kernel * self.mask) - return self.conv(inputs) + # Mask the kernel functionally + masked_kernel = self.conv.kernel * self.mask + return ( + ops.conv( + inputs, + masked_kernel, + strides=self.conv.strides, + padding=self.conv.padding.upper(), + data_format="channels_last", + ) + + self.conv.bias + ) -# Next, we build our residual block layer. -# This is just a normal residual block, but based on the PixelConvLayer. 
-class ResidualBlock(keras.layers.Layer): +class ResidualBlock(layers.Layer): def __init__(self, filters, **kwargs): super().__init__(**kwargs) - self.conv1 = keras.layers.Conv2D( - filters=filters, kernel_size=1, activation="relu" - ) + self.conv1 = layers.Conv2D(filters, 1, activation="relu") self.pixel_conv = PixelConvLayer( mask_type="B", filters=filters // 2, @@ -439,43 +449,31 @@ def __init__(self, filters, **kwargs): activation="relu", padding="same", ) - self.conv2 = keras.layers.Conv2D( - filters=filters, kernel_size=1, activation="relu" - ) + self.conv2 = layers.Conv2D(filters, 1, activation="relu") def call(self, inputs): x = self.conv1(inputs) x = self.pixel_conv(x) x = self.conv2(x) - return keras.layers.add([inputs, x]) + return layers.add([inputs, x]) -pixelcnn_inputs = keras.Input(shape=pixelcnn_input_shape, dtype=tf.int32) -ohe = tf.one_hot(pixelcnn_inputs, vqvae_trainer.num_embeddings) +# Build PixelCNN +pixelcnn_inputs = keras.Input(shape=(7, 7), dtype="int32") +ohe = ops.one_hot(pixelcnn_inputs, 128) x = PixelConvLayer( mask_type="A", filters=128, kernel_size=7, activation="relu", padding="same" )(ohe) - for _ in range(num_residual_blocks): x = ResidualBlock(filters=128)(x) - for _ in range(num_pixelcnn_layers): x = PixelConvLayer( - mask_type="B", - filters=128, - kernel_size=1, - strides=1, - activation="relu", - padding="valid", + mask_type="B", filters=128, kernel_size=1, activation="relu", padding="valid" )(x) - -out = keras.layers.Conv2D( - filters=vqvae_trainer.num_embeddings, kernel_size=1, strides=1, padding="valid" -)(x) - +out = layers.Conv2D(filters=128, kernel_size=1, padding="valid")(x) pixel_cnn = keras.Model(pixelcnn_inputs, out, name="pixel_cnn") -pixel_cnn.summary() +pixel_cnn.summary() """ ## Prepare data to train the PixelCNN @@ -488,13 +486,9 @@ def call(self, inputs): it gets its generative capabilities from. """ -# Generate the codebook indices. 
-encoded_outputs = encoder.predict(x_train_scaled) flat_enc_outputs = encoded_outputs.reshape(-1, encoded_outputs.shape[-1]) -codebook_indices = quantizer.get_code_indices(flat_enc_outputs) - -codebook_indices = codebook_indices.numpy().reshape(encoded_outputs.shape[:-1]) -print(f"Shape of the training data for PixelCNN: {codebook_indices.shape}") +codebook_indices = ops.convert_to_numpy(quantizer.get_code_indices(flat_enc_outputs)) +codebook_indices = codebook_indices.reshape(encoded_outputs.shape[:-1]) """ ## PixelCNN training @@ -524,12 +518,13 @@ def call(self, inputs): them to our decoder to generate novel images. """ + # Create a mini sampler model. -inputs = layers.Input(shape=pixel_cnn.input_shape[1:]) -outputs = pixel_cnn(inputs, training=False) -categorical_layer = tfp.layers.DistributionLambda(tfp.distributions.Categorical) -outputs = categorical_layer(outputs) -sampler = keras.Model(inputs, outputs) +def sample_from_logits(logits): + logits_flat = ops.reshape(logits, (-1, 128)) + sampled = random.categorical(logits_flat, 1) + return ops.reshape(sampled, ops.shape(logits)[:-1]) + """ We now construct a prior to generate images. Here, we will generate 10 images. @@ -537,18 +532,15 @@ def call(self, inputs): # Create an empty array of priors. batch = 10 -priors = np.zeros(shape=(batch,) + (pixel_cnn.input_shape)[1:]) -batch, rows, cols = priors.shape - -# Iterate over the priors because generation has to be done sequentially pixel by pixel. -for row in range(rows): - for col in range(cols): - # Feed the whole array and retrieving the pixel value probabilities for the next - # pixel. - probs = sampler.predict(priors) - # Use the probabilities to pick pixel values and append the values to the priors. 
- priors[:, row, col] = probs[:, row, col] - +priors = np.zeros((batch, 7, 7), dtype="int32") +for row in range(7): + for col in range(7): + logits = pixel_cnn.predict(priors, verbose=0) + # sampled_indices is a Keras tensor + sampled_indices = sample_from_logits(logits) + # Convert to numpy to avoid JAX tracer/index errors + sampled_indices_np = ops.convert_to_numpy(sampled_indices) + priors[:, row, col] = sampled_indices_np[:, row, col] print(f"Prior shape: {priors.shape}") """ @@ -557,13 +549,10 @@ def call(self, inputs): # Perform an embedding lookup. pretrained_embeddings = quantizer.embeddings -priors_ohe = tf.one_hot(priors.astype("int32"), vqvae_trainer.num_embeddings).numpy() -quantized = tf.matmul( - priors_ohe.astype("float32"), pretrained_embeddings, transpose_b=True -) -quantized = tf.reshape(quantized, (-1, *(encoded_outputs.shape[1:]))) +priors_ohe = ops.one_hot(priors, 128) +quantized = ops.matmul(priors_ohe, ops.transpose(pretrained_embeddings)) +quantized = ops.reshape(quantized, (-1, 7, 7, 16)) -# Generate novel images. decoder = vqvae_trainer.vqvae.get_layer("decoder") generated_samples = decoder.predict(quantized) @@ -577,8 +566,7 @@ def call(self, inputs): plt.imshow(generated_samples[i].squeeze() + 0.5) plt.title("Generated Sample") plt.axis("off") - plt.show() - + show_figure() """ We can enhance the quality of these generated samples by tweaking the PixelCNN. """