Migrating Vector Quantized Variational Autoencoders to Keras 3 #2309
LakshmiKalaKadali wants to merge 1 commit into keras-team:master
Conversation
Summary of Changes
Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request migrates the Vector Quantized Variational Autoencoder (VQ-VAE) example to Keras 3. The primary goal is full backend compatibility, particularly for the PixelCNN component, achieved by replacing TensorFlow Probability-specific layers with Keras 3's backend-agnostic ops.
Code Review
This pull request does a great job of migrating the Vector Quantized Variational Autoencoder example to be compatible with Keras 3. The changes, such as replacing TensorFlow-specific calls with keras.ops and refactoring the training loop to use compute_loss, are well-implemented and align with multi-backend best practices. My main feedback is on improving maintainability by removing hardcoded values for num_embeddings and latent_dim that were introduced during the migration. I've left a few suggestions to replace these magic numbers with variables derived from the model configuration, which will make the code more robust and easier to modify in the future.
```diff
- ohe = tf.one_hot(pixelcnn_inputs, vqvae_trainer.num_embeddings)
+ # Build PixelCNN
+ pixelcnn_inputs = keras.Input(shape=(7, 7), dtype="int32")
+ ohe = ops.one_hot(pixelcnn_inputs, 128)
```
The number of embeddings 128 is hardcoded here. This makes the code less maintainable. It would be better to use the num_embeddings from the vqvae_trainer instance to ensure consistency if the model's hyperparameters are changed.
```diff
- ohe = ops.one_hot(pixelcnn_inputs, 128)
+ ohe = ops.one_hot(pixelcnn_inputs, vqvae_trainer.num_embeddings)
```
```diff
-     filters=vqvae_trainer.num_embeddings, kernel_size=1, strides=1, padding="valid"
- )(x)
+ out = layers.Conv2D(filters=128, kernel_size=1, padding="valid")(x)
```
The number of filters 128 is hardcoded, which should correspond to the number of embeddings for the output layer. To improve maintainability and prevent potential bugs if hyperparameters change, please use the num_embeddings attribute from the vqvae_trainer instance.
```diff
- out = layers.Conv2D(filters=128, kernel_size=1, padding="valid")(x)
+ out = layers.Conv2D(filters=vqvae_trainer.num_embeddings, kernel_size=1, padding="valid")(x)
```
```python
outputs = categorical_layer(outputs)
sampler = keras.Model(inputs, outputs)

def sample_from_logits(logits):
    logits_flat = ops.reshape(logits, (-1, 128))
```
The dimension 128 is hardcoded when reshaping the logits. This value corresponds to num_embeddings. Using a variable, like vqvae_trainer.num_embeddings, would make this function more robust to changes in model hyperparameters.
```diff
- logits_flat = ops.reshape(logits, (-1, 128))
+ logits_flat = ops.reshape(logits, (-1, vqvae_trainer.num_embeddings))
```
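For context on why the logits are flattened to `(-1, num_embeddings)` first: `keras.random.categorical` draws one index per row of a 2-D logits tensor. A backend-free sketch of the same sampling step, using NumPy and the Gumbel-max trick (names and the `num_embeddings = 128` value are illustrative, not taken from the PR):

```python
import numpy as np

def sample_from_logits(logits, rng=np.random.default_rng(0)):
    """Draw one categorical sample per row of a 2-D logits array."""
    # Gumbel-max trick: argmax(logits + Gumbel noise) ~ Categorical(softmax(logits))
    gumbel = rng.gumbel(size=logits.shape)
    return np.argmax(logits + gumbel, axis=-1)

num_embeddings = 128  # illustrative hyperparameter
logits = np.zeros((7 * 7, num_embeddings))  # one flattened 7x7 grid of logits
samples = sample_from_logits(logits)        # shape (49,), values in [0, 128)
```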
```python
priors_ohe = ops.one_hot(priors, 128)
quantized = ops.matmul(priors_ohe, ops.transpose(pretrained_embeddings))
quantized = ops.reshape(quantized, (-1, 7, 7, 16))
```
The values for num_embeddings (128) and latent_dim (16) are hardcoded in this block. This can lead to errors if these hyperparameters are changed elsewhere. Please use the attributes from the vqvae_trainer instance, e.g., vqvae_trainer.num_embeddings and vqvae_trainer.latent_dim, to make the code more maintainable and robust.
```diff
- priors_ohe = ops.one_hot(priors, 128)
- quantized = ops.matmul(priors_ohe, ops.transpose(pretrained_embeddings))
- quantized = ops.reshape(quantized, (-1, 7, 7, 16))
+ priors_ohe = ops.one_hot(priors, vqvae_trainer.num_embeddings)
+ quantized = ops.matmul(priors_ohe, ops.transpose(pretrained_embeddings))
+ quantized = ops.reshape(quantized, (-1, 7, 7, vqvae_trainer.latent_dim))
```
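The `one_hot` → `matmul` pair here is an embedding lookup written with backend-agnostic ops: multiplying a one-hot matrix by the transposed codebook selects columns of the codebook. A small NumPy check of that identity (shapes mirror the example; the concrete `latent_dim`/`num_embeddings` values are illustrative):

```python
import numpy as np

latent_dim, num_embeddings = 16, 128  # illustrative hyperparameters
rng = np.random.default_rng(0)
# Codebook stored as (latent_dim, num_embeddings), hence the transpose below
embeddings = rng.normal(size=(latent_dim, num_embeddings))

priors = rng.integers(0, num_embeddings, size=(2, 7, 7))  # sampled code indices
priors_ohe = np.eye(num_embeddings)[priors]               # one-hot: (2, 7, 7, 128)
quantized = priors_ohe @ embeddings.T                     # (2, 7, 7, 16)

# Identical to directly indexing the transposed codebook by the sampled indices
lookup = embeddings.T[priors]
```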
Major change: TensorFlow Probability is not backend-agnostic. To support JAX and PyTorch, I replaced tfp.layers.DistributionLambda with custom sampling logic using keras.random.categorical. This ensures the PixelCNN's autoregressive generation works natively on all backends. Here is the reference colab notebook.
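For readers following along, the replacement amounts to an autoregressive loop: at each pixel position the PixelCNN produces logits and one code index is drawn categorically, where the draw previously came from the tfp distribution layer. A minimal backend-free sketch of that loop with a dummy stand-in for the sampler model (all names and shapes illustrative):

```python
import numpy as np

def generate_priors(sampler, batch, rows=7, cols=7, num_embeddings=128,
                    rng=np.random.default_rng(0)):
    """Autoregressively fill a (batch, rows, cols) grid of code indices."""
    priors = np.zeros((batch, rows, cols), dtype=np.int32)
    for r in range(rows):
        for c in range(cols):
            logits = sampler(priors)  # (batch, rows, cols, num_embeddings)
            # Softmax over the codebook axis at the current pixel
            e = np.exp(logits[:, r, c] - logits[:, r, c].max(-1, keepdims=True))
            probs = e / e.sum(-1, keepdims=True)
            for b in range(batch):
                priors[b, r, c] = rng.choice(num_embeddings, p=probs[b])
    return priors

# Dummy sampler with uniform logits, just to exercise the loop
dummy = lambda p: np.zeros(p.shape + (128,))
priors = generate_priors(dummy, batch=2)  # shape (2, 7, 7)
```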