
Barlow twins for Contrastive SSL migrate to keras 3 #2315

Open
maitry63 wants to merge 1 commit into keras-team:master from maitry63:contrastive_ssl_keras3_migration

Conversation

@maitry63 (Contributor)

This PR migrates the Barlow Twins for Contrastive SSL tutorial to Keras 3.

Colab file: Notebook

@gemini-code-assist (Contributor)

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request modernizes the Barlow Twins for Contrastive SSL tutorial by migrating its codebase to Keras 3. The changes encompass a comprehensive update of API calls, data handling mechanisms, and the core loss function implementation, ensuring the tutorial remains current and functional with the latest Keras version. The refactoring aims to streamline the code, improve compatibility, and leverage Keras 3's unified backend capabilities.

Highlights

  • Keras 3 Migration: The entire tutorial has been updated to be compatible with Keras 3, replacing tf.keras imports and functionalities with their keras equivalents, and adapting to the new Keras 3 API.
  • Data Augmentation Refactor: The data augmentation pipeline, including RandomToGrayscale, RandomColorJitter, RandomFlip, RandomResizedCrop, RandomSolarize, and RandomBlur, has been refactored to use NumPy-based operations and standard Python __call__ methods instead of tf.function decorated call methods and tf.image operations. The Augmentation base class was removed.
  • Dataset Handling Update: The BTDatasetCreator class has been completely rewritten to leverage keras.utils.Sequence for data loading and augmentation, moving away from tf.data.Dataset for better Keras 3 compatibility and flexibility.
  • Barlow Loss Implementation: The BarlowLoss class has been updated to align with Keras 3's loss API, including changes to its initialization, normalization logic using keras.ops, and the calculation of the cross-correlation matrix and loss components.
  • Model Architecture and Training Loop: The ResNet34 and build_twin functions now exclusively use keras.layers and keras.models. The BarlowModel class has been removed and replaced with a build_barlow_model function that directly constructs the Barlow Twins model, simplifying the training setup.
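
The Sequence-based loader mentioned above follows the __len__/__getitem__ contract of keras.utils.Sequence. Here is a minimal sketch of that contract for a paired-view loader (a plain class with hypothetical names, written without a Keras dependency so it is self-contained; the actual PR subclasses keras.utils.Sequence):

```python
import math
import numpy as np

class PairSequence:
    """Sketch of a Sequence-style loader yielding two augmented
    views of each batch, as Barlow Twins requires."""

    def __init__(self, images, batch_size, augment):
        self.images = images
        self.batch_size = batch_size
        self.augment = augment  # callable: image -> augmented image

    def __len__(self):
        # number of batches per epoch
        return math.ceil(len(self.images) / self.batch_size)

    def __getitem__(self, idx):
        batch = self.images[idx * self.batch_size : (idx + 1) * self.batch_size]
        # two independently augmented views of the same batch
        view_a = np.stack([self.augment(img) for img in batch])
        view_b = np.stack([self.augment(img) for img in batch])
        return view_a, view_b

# usage with a no-op "augmentation"
data = np.random.rand(10, 32, 32, 3)
seq = PairSequence(data, batch_size=4, augment=lambda x: x)
```

In the real tutorial the class would inherit from keras.utils.Sequence, which lets model.fit consume it directly and run it safely with multiprocessing.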


Changelog
  • examples/vision/barlow_twins.py
    • Updated metadata including last modified date and Keras 3 conversion author.
    • Replaced TensorFlow and TensorFlow Addons imports with Keras and NumPy/SciPy equivalents.
    • Refactored all augmentation classes (RandomToGrayscale, RandomColorJitter, RandomFlip, RandomResizedCrop, RandomSolarize, RandomBlur) to use NumPy-based operations and Python __call__ methods, removing tf.function decorators and Augmentation inheritance.
    • Rewrote BTDatasetCreator to utilize keras.utils.Sequence for efficient data loading and augmentation.
    • Modified plot_values function to adapt to the new dataset output format.
    • Updated BarlowLoss class to inherit from keras.losses.Loss and use keras.ops for tensor operations, simplifying cross-correlation and loss calculation.
    • Converted ResNet34 and build_twin model construction to use Keras 3 layers and models.
    • Replaced the BarlowModel class with a functional build_barlow_model approach for constructing the main training model.
    • Updated the training and linear evaluation loops to use Keras 3's model.compile and model.fit with keras.optimizers.Lamb and the new XYDataset Sequence.
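
For intuition, the Barlow loss logic summarized above can be sketched in plain NumPy (an illustrative re-derivation of the objective, not the tutorial's keras.ops code): standardize each embedding dimension across the batch, form the cross-correlation matrix, then push its diagonal toward 1 and its off-diagonal entries toward 0.

```python
import numpy as np

def barlow_loss(z_a, z_b, lambd=5e-3):
    """Barlow Twins loss on two batches of embeddings, each of shape (N, D)."""
    n, d = z_a.shape
    # standardize each feature dimension across the batch
    z_a = (z_a - z_a.mean(0)) / (z_a.std(0) + 1e-8)
    z_b = (z_b - z_b.mean(0)) / (z_b.std(0) + 1e-8)
    # cross-correlation matrix of shape (D, D)
    c = z_a.T @ z_b / n
    # invariance term: diagonal entries should be 1
    on_diag = np.sum((np.diag(c) - 1.0) ** 2)
    # redundancy-reduction term: off-diagonal entries should be 0
    off_diag = np.sum(c**2) - np.sum(np.diag(c) ** 2)
    return on_diag + lambd * off_diag
```

The lambd weight (5e-3 here) is the trade-off between the two terms; identical embeddings give a near-zero invariance term.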
Activity
  • The pull request was created by maitry63 to migrate the Barlow Twins tutorial to Keras 3.
  • A Colab notebook link was provided for testing the changes.

@gemini-code-assist (bot) left a comment

Code Review

This pull request refactors the Barlow Twins implementation to be compatible with Keras 3, primarily by migrating from TensorFlow-specific operations and tf.keras layers to Keras 3's multi-backend keras API and NumPy for augmentation logic. Key changes include updating imports, modifying augmentation classes (RandomToGrayscale, RandomColorJitter, RandomFlip, RandomResizedCrop, RandomSolarize, RandomBlur, RandomAugmentor) to use NumPy-based operations and removing tf.function decorators and keras.layers.Layer inheritance where appropriate. The BarlowLoss class was updated to use Keras ops for tensor operations, and the BTDatasetCreator was rewritten to use keras.utils.Sequence for data loading and augmentation. The ResNet34 model and the build_twin projector network were also updated to use Keras 3 layers. Review comments highlight several critical issues, including runtime errors due to incorrect inheritance in augmentation classes, a missing hue jitter component in RandomColorJitter that could impact performance, and an inefficient implementation in BTDatasetCreator. Additionally, there are suggestions to improve documentation, correct grayscale conversion logic to use a weighted average, and remove redundant variable definitions.

Comment on lines +261 to +262
def __init__(self, prob=0.8):
    self.prob = prob

critical

The RandomColorJitter class inherits from Augmentation (which is a keras.layers.Layer), but its __init__ method does not call super().__init__(). This will cause a runtime error. To fix this and be consistent with other refactored augmentation classes like RandomFlip, you should remove the inheritance from Augmentation. This also applies to RandomResizedCrop, RandomSolarize, and RandomBlur classes.
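
A minimal sketch of the suggested fix (body abridged from the PR's NumPy version, not its exact code): with the Augmentation/Layer inheritance dropped, the class is a plain callable and no super().__init__() is needed.

```python
import numpy as np

class RandomColorJitter:  # plain class: no Augmentation / Layer inheritance
    def __init__(self, prob=0.8):
        self.prob = prob  # no super().__init__() required

    def __call__(self, x):
        if np.random.rand() < self.prob:
            x = x + np.random.uniform(-0.2, 0.2)                # brightness
            x = (x - 0.5) * np.random.uniform(0.8, 1.2) + 0.5   # contrast
            gray = np.mean(x, axis=-1, keepdims=True)
            x = gray + (x - gray) * np.random.uniform(0.8, 1.2) # saturation
            x = np.clip(x, 0, 1)
        return x
```

The same pattern applies to RandomResizedCrop, RandomSolarize, and RandomBlur.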

Comment on lines 308 to +310
def __init__(self, image_size):
    super().__init__()
    self.image_size = image_size
    self.resize_layer = layers.Resizing(image_size, image_size)

critical

The RandomResizedCrop class inherits from Augmentation (a keras.layers.Layer), but super().__init__() was removed from its constructor. This will cause a runtime error. For consistency with the other refactored augmentation classes, the inheritance from Augmentation should be removed from the class definition.

Comment on lines +332 to +334
def __init__(self, prob=0.2, threshold=0.5):
    self.prob = prob
    self.threshold = threshold

critical

The RandomSolarize class inherits from Augmentation (a keras.layers.Layer), but its __init__ method does not call super().__init__(). This will cause a runtime error. To fix this, you should remove the inheritance from Augmentation in the class definition, which would be consistent with other refactored augmentation classes.

Comment on lines +362 to +363
def __init__(self, prob=0.2):
    self.prob = prob

critical

The RandomBlur class inherits from Augmentation (a keras.layers.Layer), but its __init__ method does not call super().__init__(). This will cause a runtime error. To fix this, you should remove the inheritance from Augmentation in the class definition, which would be consistent with other refactored augmentation classes.

Comment on lines +264 to 271
def __call__(self, x):
    if np.random.rand() < self.prob:
        x = x + np.random.uniform(-0.2, 0.2)
        x = (x - 0.5) * np.random.uniform(0.8, 1.2) + 0.5
        gray = np.mean(x, axis=-1, keepdims=True)
        x = gray + (x - gray) * np.random.uniform(0.8, 1.2)
        x = np.clip(x, 0, 1)
    return x

high

The new implementation of RandomColorJitter is missing the hue jitter component, which was present in the original implementation (tf.image.random_hue). The class docstring also mentions hue jitter. This change in augmentation logic could negatively impact the model's performance. Please consider adding hue jitter back into this augmentation.
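
One way to restore a hue jitter in pure NumPy is to rotate RGB values around the gray (R=G=B) axis. This is a sketch of an approximate hue shift, not tf.image.random_hue's exact algorithm (which goes through HSV); a zero angle leaves the image unchanged, and gray pixels are invariant under any rotation.

```python
import numpy as np

def rotate_hue(x, theta):
    """Approximate hue shift: rotate RGB values around the gray axis.

    x: float image in [0, 1] with shape (..., 3); theta: angle in radians.
    """
    cos_a, sin_a = np.cos(theta), np.sin(theta)
    k = (1.0 - cos_a) / 3.0
    s = np.sqrt(1.0 / 3.0) * sin_a
    # rotation matrix about the (1, 1, 1) axis in RGB space
    m = np.array([
        [cos_a + k, k - s, k + s],
        [k + s, cos_a + k, k - s],
        [k - s, k + s, cos_a + k],
    ])
    return np.clip(x @ m.T, 0.0, 1.0)

def random_hue(x, max_delta=0.1):
    # max_delta is a fraction of a full trip around the hue circle
    return rotate_hue(x, np.random.uniform(-max_delta, max_delta) * 2 * np.pi)
```

random_hue could then be called inside RandomColorJitter.__call__ alongside the other jitter components.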

Comment on lines 144 to +146
# Allocates two threads for a gpu private which allows more operations to be
# done faster
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"
os.environ["KERAS_BACKEND"] = "tensorflow"

medium

The comment on lines 144-145 is now misleading as it describes the previous TF_GPU_THREAD_MODE setting, not the new KERAS_BACKEND setting. To avoid confusion, this comment should be removed.

Suggested change
- # Allocates two threads for a gpu private which allows more operations to be
- # done faster
- os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"
  os.environ["KERAS_BACKEND"] = "tensorflow"

Comment on lines +240 to 245
def __call__(self, x):
    if np.random.rand() < self.prob:
        # average channels to get grayscale
        gray = np.mean(x, axis=-1, keepdims=True)
        x = np.repeat(gray, 3, axis=-1)
    return x

medium

The grayscale conversion logic has been changed from tf.image.rgb_to_grayscale, which uses a weighted average of channels (luma), to np.mean, which is a simple average. This changes the augmentation behavior and may impact model performance. For consistency with standard image processing libraries, consider using a weighted average similar to tf.image.rgb_to_grayscale (0.299*R + 0.587*G + 0.114*B).
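
The weighted average suggested here (BT.601 luma weights, the convention tf.image.rgb_to_grayscale follows) is a one-liner in NumPy; this sketch also repeats the result to three channels as the PR's augmentation does:

```python
import numpy as np

# BT.601 luma weights; these sum to exactly 1.0
LUMA_WEIGHTS = np.array([0.299, 0.587, 0.114])

def rgb_to_grayscale(x):
    """Weighted grayscale of an RGB image (..., 3), repeated to 3 channels."""
    gray = np.tensordot(x, LUMA_WEIGHTS, axes=([-1], [0]))[..., None]
    return np.repeat(gray, 3, axis=-1)
```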

def __call__(self, x):
    """call function.

    Randomly solarizes the image.

medium

The docstring for this function seems to be a copy-paste error from RandomSolarize. It should describe the blur operation.

Suggested change
- Randomly solarizes the image.
+ Randomly blurs the image.

return ZippedSequence(seq1, seq2)


bt_augmentor = RandomAugmentor(IMAGE_SIZE)

medium

The bt_augmentor variable is being redefined here. It was already defined on line 427. This redundant definition should be removed.

Comment on lines 664 to 666
Returns:
-     Returns a tf.tensor which represents the cross correlation
+     Returns a tensor which represents the cross correlation
      matrix with its diagonals as zeros.

medium

The docstring for the return value is incorrect. This method returns a scalar tensor representing the loss, not the cross-correlation matrix.

Suggested change
- Returns:
-     Returns a tensor which represents the cross correlation
-     matrix with its diagonals as zeros.
+ Returns:
+     Returns a scalar tensor which represents the loss.
