
Migrating Vector Quantized Variational Autoencoders to Keras 3 #2309

Open
LakshmiKalaKadali wants to merge 1 commit into keras-team:master from LakshmiKalaKadali:vq_vae_br

Conversation

@LakshmiKalaKadali
Contributor

Major change: TensorFlow Probability is not backend-agnostic. To support JAX and PyTorch, I replaced tfp.layers.DistributionLambda with custom sampling logic using keras.random.categorical. This ensures that PixelCNN autoregressive generation works natively on all backends. Here is the reference Colab notebook.
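For reviewers unfamiliar with the replacement, sampling from per-pixel logits without TensorFlow Probability reduces to plain categorical sampling, which keras.random.categorical provides on every backend. Below is a minimal NumPy sketch of the equivalent operation via the Gumbel-max trick; the function name `sample_from_logits` is an illustrative stand-in, not the PR's exact implementation:

```python
import numpy as np


def sample_from_logits(logits, rng=np.random.default_rng(0)):
    """Draw one categorical sample per row of `logits`.

    NumPy stand-in for `keras.random.categorical(logits, 1)`, using the
    Gumbel-max trick: argmax(logits + Gumbel noise) is distributed as a
    sample from softmax(logits).
    """
    gumbel = -np.log(-np.log(rng.uniform(size=logits.shape)))
    return np.argmax(logits + gumbel, axis=-1)


# One row strongly favouring index 2: the sample is (almost surely) 2.
logits = np.array([[0.0, 0.0, 50.0, 0.0]])
print(sample_from_logits(logits))  # -> [2]
```

In the PixelCNN generation loop, this kind of sampler is called once per pixel position, feeding each sampled code back into the input before predicting the next position.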

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on migrating the Vector Quantized Variational Autoencoder (VQ-VAE) example to Keras 3. The primary goal was to achieve full backend compatibility, particularly for the PixelCNN component, by replacing TensorFlow Probability-specific layers with Keras 3's backend-agnostic ops and custom implementations. This involved significant refactoring of the VectorQuantizer and VQVAETrainer classes, as well as the PixelCNN's convolutional masking mechanism, to ensure seamless operation across different deep learning frameworks.

Highlights

  • Keras 3 Migration: The entire example has been updated to be compatible with Keras 3, ensuring it can run on TensorFlow, JAX, and PyTorch backends.
  • Backend-Agnostic Sampling: Replaced tfp.layers.DistributionLambda with custom sampling logic utilizing keras.random.categorical to enable backend-agnostic PixelCNN autoregressive generation.
  • VectorQuantizer Layer Output: The VectorQuantizer layer's call method now returns both the quantized tensor and the codebook indices, providing more direct access to the discrete codes.
  • VQVAETrainer Refactor: The VQVAETrainer class has been refactored to align with Keras 3's Model subclassing patterns, moving training logic from train_step to call and compute_loss.
  • Functional PixelCNN Masking: The PixelConvLayer now implements functional masking using keras.ops.conv and self.add_weight for the mask, improving flexibility and Keras 3 compatibility.
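The masked-convolution idea behind the "Functional PixelCNN Masking" highlight can be illustrated without any framework code. Here is a minimal NumPy sketch of the standard PixelCNN mask construction (per van den Oord et al.); the helper name `pixelcnn_mask` is hypothetical and not taken from the PR, which stores the mask via self.add_weight and applies it with keras.ops.conv:

```python
import numpy as np


def pixelcnn_mask(kernel_size, mask_type="B"):
    """Binary mask for a masked convolution kernel.

    Positions in rows above the centre, and to the left of the centre
    within the centre row, are kept. Type "A" (first layer) also zeroes
    the centre pixel so a pixel never sees itself; type "B" keeps it.
    """
    mask = np.zeros((kernel_size, kernel_size), dtype="float32")
    centre = kernel_size // 2
    mask[:centre, :] = 1.0          # rows strictly above the centre
    mask[centre, :centre] = 1.0     # centre row, left of the centre
    if mask_type == "B":
        mask[centre, centre] = 1.0  # type B may see the current pixel
    return mask


print(pixelcnn_mask(3, "A"))
# [[1. 1. 1.]
#  [1. 0. 0.]
#  [0. 0. 0.]]
```

Multiplying the kernel by this mask before the convolution is what preserves the autoregressive ordering during a single forward pass.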


Changelog
  • examples/generative/vq_vae.py
    • Updated file metadata, including the last modified date and the Keras 3 conversion author.
    • Replaced TensorFlow and TensorFlow Probability imports with Keras 3 specific imports (keras, keras.layers, keras.ops, keras.random).
    • Added os.environ["KERAS_BACKEND"] configuration for backend selection.
    • Introduced a show_figure helper function to manage plot display.
    • Modified VectorQuantizer layer to use self.add_weight for embeddings and keras.ops for tensor operations, and updated its call method to return both quantized latents and codebook indices.
    • Adjusted the input shape for the get_decoder function.
    • Updated get_vqvae to handle the new dual output from VectorQuantizer and return a model with two outputs.
    • Changed VQVAETrainer to inherit directly from keras.Model and refactored its training loop to use call and compute_loss.
    • Updated MNIST dataset preprocessing to explicitly cast to float32.
    • Modified image display calls to use the new show_figure utility.
    • Updated the visualization of discrete codes to correctly extract and convert codebook indices from the VQ-VAE model's prediction.
    • Initialized encoder and quantizer earlier for use in PixelCNN setup.
    • Refactored PixelConvLayer to use self.add_weight for the mask and keras.ops.conv for functional masking, removing direct kernel assignment.
    • Simplified ResidualBlock to use layers.Conv2D directly.
    • Updated PixelCNN model input definition and ops.one_hot usage.
    • Replaced tfp.layers.DistributionLambda with a custom sample_from_logits function using keras.random.categorical for sequential sampling.
    • Modified the image generation loop to use pixel_cnn.predict and sample_from_logits, ensuring proper conversion to NumPy for indexing.
    • Updated embedding lookup and quantization steps to use keras.ops equivalents.
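The VectorQuantizer changes listed above (returning both quantized latents and codebook indices) boil down to a nearest-neighbour lookup in the embedding table. A minimal NumPy sketch of that step follows; `quantize` and its argument names are illustrative, not the PR's actual code, which performs the same computation with keras.ops:

```python
import numpy as np


def quantize(latents, codebook):
    """Map each latent vector to its nearest codebook entry.

    latents:  (n, d) encoder outputs, flattened over spatial positions.
    codebook: (k, d) embedding table (the layer's `embeddings` weight).
    Returns the quantized vectors and the integer codebook indices that
    the updated `call` method now also exposes.
    """
    # Squared Euclidean distance between every latent and every code,
    # expanded as |x|^2 - 2 x.e + |e|^2 to stay fully vectorised.
    d2 = (
        np.sum(latents**2, axis=1, keepdims=True)
        - 2.0 * latents @ codebook.T
        + np.sum(codebook**2, axis=1)
    )
    indices = np.argmin(d2, axis=1)
    return codebook[indices], indices


codebook = np.array([[0.0, 0.0], [1.0, 1.0]])
latents = np.array([[0.1, -0.1], [0.9, 1.2]])
quantized, idx = quantize(latents, codebook)
print(idx)  # -> [0 1]
```

The indices are exactly what the PixelCNN prior is trained on, which is why exposing them directly from the layer's call method simplifies the rest of the example.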


@gemini-code-assist bot left a comment


Code Review

This pull request does a great job of migrating the Vector Quantized Variational Autoencoder example to be compatible with Keras 3. The changes, such as replacing TensorFlow-specific calls with keras.ops and refactoring the training loop to use compute_loss, are well-implemented and align with multi-backend best practices. My main feedback is on improving maintainability by removing hardcoded values for num_embeddings and latent_dim that were introduced during the migration. I've left a few suggestions to replace these magic numbers with variables derived from the model configuration, which will make the code more robust and easier to modify in the future.

ohe = tf.one_hot(pixelcnn_inputs, vqvae_trainer.num_embeddings)
# Build PixelCNN
pixelcnn_inputs = keras.Input(shape=(7, 7), dtype="int32")
ohe = ops.one_hot(pixelcnn_inputs, 128)

Severity: medium

The number of embeddings 128 is hardcoded here. This makes the code less maintainable. It would be better to use the num_embeddings from the vqvae_trainer instance to ensure consistency if the model's hyperparameters are changed.

Suggested change
ohe = ops.one_hot(pixelcnn_inputs, 128)
ohe = ops.one_hot(pixelcnn_inputs, vqvae_trainer.num_embeddings)

filters=vqvae_trainer.num_embeddings, kernel_size=1, strides=1, padding="valid"
)(x)

out = layers.Conv2D(filters=128, kernel_size=1, padding="valid")(x)

Severity: medium

The number of filters 128 is hardcoded, which should correspond to the number of embeddings for the output layer. To improve maintainability and prevent potential bugs if hyperparameters change, please use the num_embeddings attribute from the vqvae_trainer instance.

Suggested change
out = layers.Conv2D(filters=128, kernel_size=1, padding="valid")(x)
out = layers.Conv2D(filters=vqvae_trainer.num_embeddings, kernel_size=1, padding="valid")(x)

outputs = categorical_layer(outputs)
sampler = keras.Model(inputs, outputs)
def sample_from_logits(logits):
    logits_flat = ops.reshape(logits, (-1, 128))

Severity: medium

The dimension 128 is hardcoded when reshaping the logits. This value corresponds to num_embeddings. Using a variable, like vqvae_trainer.num_embeddings, would make this function more robust to changes in model hyperparameters.

Suggested change
logits_flat = ops.reshape(logits, (-1, 128))
logits_flat = ops.reshape(logits, (-1, vqvae_trainer.num_embeddings))

Comment on lines +552 to +554
priors_ohe = ops.one_hot(priors, 128)
quantized = ops.matmul(priors_ohe, ops.transpose(pretrained_embeddings))
quantized = ops.reshape(quantized, (-1, 7, 7, 16))

Severity: medium

The values for num_embeddings (128) and latent_dim (16) are hardcoded in this block. This can lead to errors if these hyperparameters are changed elsewhere. Please use the attributes from the vqvae_trainer instance, e.g., vqvae_trainer.num_embeddings and vqvae_trainer.latent_dim, to make the code more maintainable and robust.

Suggested change
priors_ohe = ops.one_hot(priors, 128)
quantized = ops.matmul(priors_ohe, ops.transpose(pretrained_embeddings))
quantized = ops.reshape(quantized, (-1, 7, 7, 16))
priors_ohe = ops.one_hot(priors, vqvae_trainer.num_embeddings)
quantized = ops.matmul(priors_ohe, ops.transpose(pretrained_embeddings))
quantized = ops.reshape(quantized, (-1, 7, 7, vqvae_trainer.latent_dim))
