Tests/jax splash attention regression #22016
Conversation
Summary of Changes: Hello @Sikandar1310291, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request introduces a regression test to prevent a crash in splash_attention when it is JIT-compiled with a mask.
Code Review
This pull request adds a regression test for an issue where splash_attention in JAX would crash when JIT-compiled with a mask. The new test correctly mocks a TPU environment and verifies that the implementation now gracefully falls back to a native implementation, preventing the crash. The test is well-structured and effectively covers the reported issue. I've only found a minor style issue with trailing whitespaces in the new test file. Additionally, this PR includes an unrelated but useful documentation update for keras.device() in the training guide, which is also well-written.
```python
# We also need to mock _can_use_flash_attention to return True
# so we enter the block where the check happens.

with unittest.mock.patch(
    "keras.src.backend.jax.nn._can_use_flash_attention", return_value=True
):
    # We mock jax.devices() to simulate TPU platform.
    # The actual device object needs a 'platform' attribute.
    mock_device = unittest.mock.Mock()
    mock_device.platform = "tpu"

    with unittest.mock.patch("jax.devices", return_value=[mock_device]):
```
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##           master   #22016      +/-   ##
==========================================
+ Coverage   82.73%   82.81%    +0.08%
==========================================
  Files         592      592
  Lines       62072    62142       +70
  Branches     9723     9735       +12
==========================================
+ Hits        51353    51466      +113
+ Misses       8197     8138       -59
- Partials     2522     2538       +16
```

Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
hertschuh left a comment
Thanks for looking into this.
| """ | ||
|
|
||
| """ | ||
| ## Controlling device placement |
Actually, you're right that we're missing documentation for this feature.
Can you please add it as a docstring here: https://github.com/keras-team/keras/blob/master/keras/src/backend/__init__.py#L77
This would be a separate PR, so undo this file.
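For illustration, here is a minimal sketch of what such a docstring could look like. The wording, the example, and the `device_name` parameter name are assumptions for illustration, not the actual Keras source:

```python
def device(device_name):
    """Context manager that places operations on the given device.

    Within the scope, tensors and operations are created on `device_name`,
    e.g. `"cpu:0"`, `"gpu:0"`, or `"tpu:0"`, instead of the default device.

    Example:

    >>> import keras
    >>> with keras.device("cpu:0"):
    ...     x = keras.ops.ones((2, 2))  # created on the CPU

    Args:
        device_name: Device string such as `"cpu:0"` or `"gpu:0"`.
    """
    ...
```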
```python
# We ensure it falls back gracefully instead of crashing.

# Mock is_tpu=True to trigger the Splash Attention path
# We can't actually run on TPU in CI, but we want to test the logic path
```
Actually, we are running on TPU in CI: https://github.com/keras-team/keras/actions/workflows/tpu_tests.yml
We already have a unit test that covers splash attention:
https://github.com/keras-team/keras/blob/master/keras/src/ops/nn_test.py#L1328
Do we need to add something to it? If so, please add to this other file instead.
Thanks for the review! ✅ I've consolidated the splash attention regression test into keras/src/ops/nn_test.py.
```python
# Mock is_tpu=True to trigger the Splash Attention path
# We can't actually run on TPU in CI, but we want to test the logic path
# up to the fallback check.

# We also need to mock _can_use_flash_attention to return True
# so we enter the block where the check happens.

with unittest.mock.patch(
    "keras.src.backend.jax.nn._can_use_flash_attention",
    return_value=True,
):
    # We mock jax.devices() to simulate TPU platform.
    # The actual device object needs a 'platform' attribute.
    mock_device = unittest.mock.Mock()
    mock_device.platform = "tpu"

    with unittest.mock.patch("jax.devices", return_value=[mock_device]):
```
I don't think we should check this by mocking everything. It only gives me very weak confidence that it will actually behave correctly on TPU.
The good news is that we run tests on actual TPUs. Look at the test right below this one for an example of how to skip when you're not on TPU. Then, you can remove all the mocking and everything else should just work.
I can trigger the TPU tests for you once it's ready, just let me know.
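As a rough sketch of that direction, the test below skips unless it is actually running on a TPU and exercises the jit-compiled call directly, with no mocking. The `_running_on_tpu` helper and the shape conventions (matching jax.nn.dot_product_attention) are illustrative assumptions; the real change should reuse the skip guard and test-class structure already used by the neighbouring test in nn_test.py:

```python
import jax
import numpy as np
import pytest

from keras import backend
from keras import ops


def _running_on_tpu():
    # Illustrative guard; reuse the skip condition of the neighbouring
    # TPU test in nn_test.py instead of redefining it.
    return backend.backend() == "jax" and jax.devices()[0].platform == "tpu"


@pytest.mark.skipif(not _running_on_tpu(), reason="Requires a real TPU.")
def test_dot_product_attention_jit_with_mask():
    # Regression check for #21916: jit-compiling with a mask used to raise
    # ConcretizationTypeError inside the splash-attention path.
    query = np.random.rand(1, 8, 4, 16).astype("float32")
    key = np.random.rand(1, 8, 4, 16).astype("float32")
    value = np.random.rand(1, 8, 4, 16).astype("float32")
    mask = np.tril(np.ones((8, 8), dtype=bool))[None, None, :, :]

    @jax.jit
    def attend(q, k, v, m):
        return ops.dot_product_attention(q, k, v, mask=m)

    out = attend(query, key, value, mask)
    assert out.shape == query.shape
```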
Description
This PR adds a regression test for issue #21916, where `splash_attention` caused a `ConcretizationTypeError` when compiled with `jax.jit` because the mask became a Tracer.

Changes

keras/src/backend/jax/splash_attention_test.py

Testing
The test mocks a TPU environment (where Splash Attention is active) and confirms that `dot_product_attention` gracefully falls back to the native implementation instead of crashing when a Tracer mask is encountered.

Related Issues
Closes #21916
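For context, a minimal sketch of the scenario the test guards against, using the public keras.ops API; the shapes and values here are illustrative and not taken from the original report:

```python
import jax
import numpy as np
from keras import ops

query = np.random.rand(1, 8, 4, 16).astype("float32")
key = np.random.rand(1, 8, 4, 16).astype("float32")
value = np.random.rand(1, 8, 4, 16).astype("float32")
mask = np.tril(np.ones((8, 8), dtype=bool))[None, None, :, :]

@jax.jit
def attend(q, k, v, m):
    # Under jit the mask is a Tracer; on TPU the splash-attention path used
    # to inspect it concretely and raise ConcretizationTypeError. With the
    # fix it falls back to the native implementation instead.
    return ops.dot_product_attention(q, k, v, mask=m)

print(attend(query, key, value, mask).shape)  # (1, 8, 4, 16)
```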