
Migrated the data pipeline and training loop to use keras 3 APIs in DEIT Tutorial#2314

Open
kharshith-k wants to merge 2 commits into keras-team:master from kharshith-k:deit_training_step_migration

Conversation

@kharshith-k
Member

Please find my PR migrating the data pipeline and the remaining TensorFlow-specific APIs to Keras 3 in the DeiT tutorial.
Colab Gist

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request modernizes the DeiT tutorial by migrating its data pipeline and training components to Keras 3. The changes ensure the example is fully backend-agnostic, replacing TensorFlow-specific data handling and model loading with Keras native utilities and layers. This update improves compatibility and demonstrates best practices for Keras 3 development, particularly for complex models involving distillation.

Highlights

  • Keras 3 API Migration: The data pipeline and training loop have been updated to leverage Keras 3 APIs, ensuring backend-agnostic operation across different frameworks.
  • Backend-Agnostic Data Pipeline: Replaced TensorFlow-specific data loading (tf.data, tensorflow_datasets) with a custom keras.utils.PyDataset implementation and Keras native augmentation layers for improved compatibility.
  • Teacher Model Refactor: The teacher model, previously loaded via keras.layers.TFSMLayer, is now constructed directly using Keras layers (EfficientNetV2B0) for full Keras 3 compatibility and explicit definition.
  • Training Loop Update: The custom train_step in the DeiT distiller model was refactored to use the compute_loss method, aligning with Keras 3's Model.compile API for custom training logic.
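As a hedged illustration of the PyDataset pattern highlighted above (not the PR's actual FlowersDataset), the sketch below mirrors the keras.utils.PyDataset contract in plain Python with numpy stand-ins for the image data: __len__ returns the number of batches per epoch, __getitem__ returns one (images, labels) batch, and on_epoch_end reshuffles. The class name, the RESOLUTION constant, and the fake pixel loading are assumptions for illustration only; real code would call keras.utils.load_img.

```python
import math
import numpy as np

RESOLUTION = 224  # image side length, assumed to match the tutorial's constant


class FlowersDatasetSketch:
    """Minimal stand-in mirroring the keras.utils.PyDataset interface."""

    def __init__(self, image_paths, labels, batch_size=32, shuffle=True):
        self.image_paths = list(image_paths)
        self.labels = np.asarray(labels)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(self.image_paths))

    def __len__(self):
        # Number of batches per epoch (the last batch may be smaller).
        return math.ceil(len(self.image_paths) / self.batch_size)

    def __getitem__(self, idx):
        batch = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size]
        # Real code would load and decode each path via keras.utils.load_img;
        # here we fabricate zero-valued pixel data of the right shape.
        images = np.zeros((len(batch), RESOLUTION, RESOLUTION, 3), dtype="float32")
        return images, self.labels[batch]

    def on_epoch_end(self):
        # PyDataset hook: reshuffle sample order between epochs when training.
        if self.shuffle:
            np.random.shuffle(self.indices)
```

Because the batching and shuffling live in plain Python rather than tf.data, the same dataset object works unchanged across TensorFlow, JAX, and PyTorch backends.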


Changelog
  • examples/vision/deit.py
    • Updated metadata for last modification date and Keras 3 conversion author.
    • Replaced TensorFlow and TensorFlow Datasets imports with pathlib.Path and numpy.
    • Removed tf.data.AUTOTUNE constant.
    • Introduced FlowersDataset (inheriting keras.utils.PyDataset) and helper functions (get_augmenter, load_flower_file_paths) to create a backend-agnostic data loading and augmentation pipeline, replacing the previous tf.data based approach.
    • Converted tf.ones to keras.ops.ones for dummy input creation.
    • Refactored the DeiT class's custom train_step into a compute_loss method, aligning with Keras 3's training loop customization.
    • Replaced the TFSMLayer based teacher model with a keras.Sequential model using EfficientNetV2B0 as a backbone, including its fine-tuning.
    • Adjusted comments and notes to reflect the Keras 3 migration and the new teacher model implementation.
  • examples/vision/ipynb/deit.ipynb
    • Updated notebook metadata for last modification date and Keras 3 conversion author.
    • Modified import statements to remove TensorFlow and TensorFlow Datasets, adding pathlib.Path and numpy.
    • Removed tf.data.AUTOTUNE constant.
    • Updated markdown cells to describe the new Keras 3 backend-agnostic data pipeline and teacher model.
    • Replaced the tf.data based data loading and preprocessing code with the FlowersDataset class and associated helper functions.
    • Changed tf.ones to keras.ops.ones for dummy input.
    • Refactored the DeiT class's custom train_step into a compute_loss method.
    • Replaced the TFSMLayer based teacher model with a keras.Sequential model using EfficientNetV2B0, including its fine-tuning.
    • Updated notes section to highlight Keras 3 backend-agnosticism and the new teacher model.
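To make the compute_loss refactor listed above concrete, here is a hedged numpy sketch of the arithmetic a hard-distillation objective computes, assuming the tutorial's setup in which the student's class token is scored against ground-truth labels and its distillation token against the teacher's argmax predictions. The function names and the 50/50 weighting are illustrative, not the PR's actual code.

```python
import numpy as np


def softmax(logits):
    # Numerically stable softmax over the last axis.
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def sparse_ce(logits, labels):
    # Mean sparse categorical cross-entropy computed from raw logits.
    probs = softmax(logits)
    return -np.mean(np.log(probs[np.arange(len(labels)), labels] + 1e-12))


def hard_distillation_loss(cls_logits, dist_logits, y_true, teacher_logits):
    # Hard distillation: the distillation head learns the teacher's argmax
    # labels, while the classification head learns the ground truth.
    teacher_labels = teacher_logits.argmax(axis=-1)
    return 0.5 * sparse_ce(cls_logits, y_true) + 0.5 * sparse_ce(
        dist_logits, teacher_labels
    )
```

In Keras 3, returning this value from an overridden Model.compute_loss lets the stock train_step handle gradient computation and optimizer updates across all backends.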

@gemini-code-assist (bot) left a comment

Code Review

This pull request does a great job of migrating the data pipeline and training loop to use Keras 3 APIs, making the example backend-agnostic. The use of keras.utils.PyDataset, keras.ops, and overriding compute_loss are all excellent changes. I've identified a few areas for improvement related to efficiency in the new data pipeline and consistency between the Python script and the Jupyter notebook.

Comment on lines +154 to 157
image = keras.utils.load_img(
    self.image_paths[i],
    target_size=(RESOLUTION + 20, RESOLUTION + 20),
)

medium

For validation (self.shuffle=False), images are loaded at (RESOLUTION + 20, RESOLUTION + 20) and then resized to (RESOLUTION, RESOLUTION) by the augmenter. This is inefficient. It would be better to load them directly at (RESOLUTION, RESOLUTION), which would also be consistent with the previous tf.data implementation.

This can be achieved by conditioning the target_size on self.shuffle inside __getitem__:

if self.shuffle:
    target_size = (RESOLUTION + 20, RESOLUTION + 20)
else:
    target_size = (RESOLUTION, RESOLUTION)
image = keras.utils.load_img(
    self.image_paths[i],
    target_size=target_size,
)

And then the validation augmenter in get_augmenter can be an empty keras.Sequential([]).

Comment on lines +171 to +175
[
    layers.Resizing(RESOLUTION + 20, RESOLUTION + 20),
    layers.RandomCrop(RESOLUTION, RESOLUTION),
    layers.RandomFlip("horizontal"),
],

medium

The images are already loaded with size (RESOLUTION + 20, RESOLUTION + 20) in FlowersDataset.__getitem__. This Resizing layer is redundant and can be removed for efficiency. The corresponding implementation in deit.ipynb is correct and does not include this layer, so this change would also make the two files consistent.

Suggested change (before → after):

Before:

    [
        layers.Resizing(RESOLUTION + 20, RESOLUTION + 20),
        layers.RandomCrop(RESOLUTION, RESOLUTION),
        layers.RandomFlip("horizontal"),
    ],

After:

    [
        layers.RandomCrop(RESOLUTION, RESOLUTION),
        layers.RandomFlip("horizontal"),
    ],

Comment on lines +635 to +641
teacher_backbone.trainable = False

teacher_model = keras.Sequential(
    [
        teacher_backbone,
        layers.Dense(NUM_CLASSES),
    ],

medium

The corresponding notebook deit.ipynb includes a very helpful comment explaining why an extra Rescaling layer should not be added for the teacher model. It would be great to add this comment here as well for maintainability and clarity, to keep the Python script and notebook in sync.

Suggested change (before → after):

Before:

    teacher_backbone.trainable = False

    teacher_model = keras.Sequential(
        [
            teacher_backbone,
            layers.Dense(NUM_CLASSES),
        ],

After:

    teacher_backbone.trainable = False

    # EfficientNetV2B0 includes its own preprocessing (include_preprocessing=True by
    # default), which maps raw [0, 255] pixel values to [-1, 1]. Do NOT add an
    # external Rescaling layer here — that would cause double-normalisation and
    # destroy all input signal (all values collapse to ~-1, rendering the backbone
    # unable to extract meaningful features).
    teacher_model = keras.Sequential(
        [
            teacher_backbone,
            layers.Dense(NUM_CLASSES),
        ],
        name="teacher",
    )
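The double-normalisation warning in that comment can be checked numerically. Assuming the standard [0, 255] to [-1, 1] mapping (x / 127.5 - 1, as a Rescaling(scale=1/127.5, offset=-1) layer would apply), one pass spreads pixels across the full range, while a second pass collapses them into a sliver just below -1:

```python
import numpy as np


def rescale(x):
    # The [0, 255] -> [-1, 1] mapping a Rescaling(1/127.5, offset=-1) layer applies.
    return x / 127.5 - 1.0


pixels = np.array([0.0, 64.0, 128.0, 255.0])

once = rescale(pixels)   # spans the full [-1, 1] range
twice = rescale(once)    # every value squashed into a narrow band near -1

print("once: ", once)
print("twice:", twice)
```

After the second pass the spread between the brightest and darkest pixel drops from 2.0 to under 0.02, which is why the backbone can no longer extract meaningful features.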

@gemini-code-assist (bot) left a comment

Code Review

This pull request does a great job of migrating the data pipeline and training loop to be backend-agnostic using Keras 3 APIs. The transition from tf.data and tfds to a custom keras.utils.PyDataset is well-implemented, and replacing the custom train_step with compute_loss aligns with modern Keras practices. I've found one minor issue regarding a redundant operation in the data augmentation pipeline in the Python script, which I've detailed in a specific comment. Overall, this is a solid improvement for the example.

Comment on lines +172 to +173
    layers.Resizing(RESOLUTION + 20, RESOLUTION + 20),
    layers.RandomCrop(RESOLUTION, RESOLUTION),

medium

This Resizing layer is redundant. The FlowersDataset.__getitem__ method already resizes images to (RESOLUTION + 20, RESOLUTION + 20) when keras.utils.load_img is called. Applying another Resizing layer with the same dimensions is an unnecessary operation that can be removed for efficiency. The corresponding notebook file (deit.ipynb) has the correct implementation without this layer.

Suggested change (before → after):

Before:

    layers.Resizing(RESOLUTION + 20, RESOLUTION + 20),
    layers.RandomCrop(RESOLUTION, RESOLUTION),

After:

    layers.RandomCrop(RESOLUTION, RESOLUTION),
@kharshith-k kharshith-k changed the title Migrated the data pipeline and training loop to use keras 3 APIs Migrated the data pipeline and training loop to use keras 3 APIs in DEIT Tutorial Mar 12, 2026