Migrated the data pipeline and training loop to use keras 3 APIs in DEIT Tutorial#2314
Migrated the data pipeline and training loop to use keras 3 APIs in DEIT Tutorial#2314kharshith-k wants to merge 2 commits intokeras-team:masterfrom
Conversation
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request modernizes the DeiT tutorial by migrating its data pipeline and training components to Keras 3. The changes ensure the example is fully backend-agnostic, replacing TensorFlow-specific data handling and model loading with Keras native utilities and layers. This update improves compatibility and demonstrates best practices for Keras 3 development, particularly for complex models involving distillation. Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Changelog
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request does a great job of migrating the data pipeline and training loop to use Keras 3 APIs, making the example backend-agnostic. The use of keras.utils.PyDataset, keras.ops, and overriding compute_loss are all excellent changes. I've identified a few areas for improvement related to efficiency in the new data pipeline and consistency between the Python script and the Jupyter notebook.
| image = keras.utils.load_img( | ||
| self.image_paths[i], | ||
| target_size=(RESOLUTION + 20, RESOLUTION + 20), | ||
| ) |
There was a problem hiding this comment.
For validation (self.shuffle=False), images are loaded at (RESOLUTION + 20, RESOLUTION + 20) and then resized to (RESOLUTION, RESOLUTION) by the augmenter. This is inefficient. It would be better to load them directly at (RESOLUTION, RESOLUTION), which would also be consistent with the previous tf.data implementation.
This can be achieved by conditioning the target_size on self.shuffle inside __getitem__:
if self.shuffle:
target_size = (RESOLUTION + 20, RESOLUTION + 20)
else:
target_size = (RESOLUTION, RESOLUTION)
image = keras.utils.load_img(
self.image_paths[i],
target_size=target_size,
)And then the validation augmenter in get_augmenter can be an empty keras.Sequential([]).
| [ | ||
| layers.Resizing(RESOLUTION + 20, RESOLUTION + 20), | ||
| layers.RandomCrop(RESOLUTION, RESOLUTION), | ||
| layers.RandomFlip("horizontal"), | ||
| ], |
There was a problem hiding this comment.
The images are already loaded with size (RESOLUTION + 20, RESOLUTION + 20) in FlowersDataset.__getitem__. This Resizing layer is redundant and can be removed for efficiency. The corresponding implementation in deit.ipynb is correct and does not include this layer, so this change would also make the two files consistent.
| [ | |
| layers.Resizing(RESOLUTION + 20, RESOLUTION + 20), | |
| layers.RandomCrop(RESOLUTION, RESOLUTION), | |
| layers.RandomFlip("horizontal"), | |
| ], | |
| [ | |
| layers.RandomCrop(RESOLUTION, RESOLUTION), | |
| layers.RandomFlip("horizontal"), | |
| ], | |
| teacher_backbone.trainable = False | ||
|
|
||
| teacher_model = keras.Sequential( | ||
| [ | ||
| teacher_backbone, | ||
| layers.Dense(NUM_CLASSES), | ||
| ], |
There was a problem hiding this comment.
The corresponding notebook deit.ipynb includes a very helpful comment explaining why an extra Rescaling layer should not be added for the teacher model. It would be great to add this comment here as well for maintainability and clarity, to keep the Python script and notebook in sync.
| teacher_backbone.trainable = False | |
| teacher_model = keras.Sequential( | |
| [ | |
| teacher_backbone, | |
| layers.Dense(NUM_CLASSES), | |
| ], | |
| # EfficientNetV2B0 includes its own preprocessing (include_preprocessing=True by | |
| # default), which maps raw [0, 255] pixel values to [-1, 1]. Do NOT add an | |
| # external Rescaling layer here — that would cause double-normalisation and | |
| # destroy all input signal (all values collapse to ~-1, rendering the backbone | |
| # unable to extract meaningful features). | |
| teacher_model = keras.Sequential( | |
| [ | |
| teacher_backbone, | |
| layers.Dense(NUM_CLASSES), | |
| ], | |
| name="teacher", | |
| ) | |
There was a problem hiding this comment.
Code Review
This pull request does a great job of migrating the data pipeline and training loop to be backend-agnostic using Keras 3 APIs. The transition from tf.data and tfds to a custom keras.utils.PyDataset is well-implemented, and replacing the custom train_step with compute_loss aligns with modern Keras practices. I've found one minor issue regarding a redundant operation in the data augmentation pipeline in the Python script, which I've detailed in a specific comment. Overall, this is a solid improvement for the example.
| layers.Resizing(RESOLUTION + 20, RESOLUTION + 20), | ||
| layers.RandomCrop(RESOLUTION, RESOLUTION), |
There was a problem hiding this comment.
This Resizing layer is redundant. The FlowersDataset.__getitem__ method already resizes images to (RESOLUTION + 20, RESOLUTION + 20) when keras.utils.load_img is called. Applying another Resizing layer with the same dimensions is an unnecessary operation that can be removed for efficiency. The corresponding notebook file (deit.ipynb) has the correct implementation without this layer.
| layers.Resizing(RESOLUTION + 20, RESOLUTION + 20), | |
| layers.RandomCrop(RESOLUTION, RESOLUTION), | |
| layers.RandomCrop(RESOLUTION, RESOLUTION), |
Please find my PR for migrating data pipeline and residual tensorflow specific APIs to keras 3 in Deit tutorial
Colab Gist