Add optimized GRU for JAX backend #22401
Conversation
Summary of Changes: This pull request introduces a highly optimized Gated Recurrent Unit (GRU) implementation for the JAX backend in Keras, replacing the previous stub.
Activity
Code Review
This pull request introduces an optimized GRU implementation for the JAX backend, which was previously a NotImplementedError stub. The optimization pre-computes input projections to improve performance, especially on longer sequences. The implementation seems correct for the supported configuration (reset_after=True, no masking, no unrolling). My review includes a suggestion to improve the error message for unsupported configurations and a strong recommendation to add unit tests to ensure correctness and prevent future regressions.
```diff
-def gru(*args, **kwargs):
-    raise NotImplementedError
+def gru(
```
This new optimized GRU implementation for the JAX backend lacks unit tests. As you mentioned in the PR description, tests are needed. Please add a new test file (e.g., keras/src/backend/jax/rnn_test.py) to verify the correctness of this implementation. The tests should compare the output with the generic keras.layers.RNN(keras.layers.GRUCell(...)) implementation for various configurations (e.g., go_backwards=True/False, return_sequences=True/False) to ensure correctness and prevent future regressions.
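A minimal sketch of the kind of equivalence test being requested, using a plain per-step Python loop as a stand-in for the generic path and a `lax.scan` version for the optimized path. The function names (`gru_step`, `gru_scan`, `gru_loop`) and the weight layout are hypothetical, not the actual Keras internals:

```python
import numpy as np
import jax
import jax.numpy as jnp


def gru_step(h, xp, w_h, b_h):
    # One reset_after=True GRU step; `xp` is the precomputed input
    # projection for this timestep, shape (B, 3H).
    hz, hr, hn = jnp.split(h @ w_h + b_h, 3, axis=-1)
    xz, xr, xn = jnp.split(xp, 3, axis=-1)
    z = jax.nn.sigmoid(xz + hz)
    r = jax.nn.sigmoid(xr + hr)
    n = jnp.tanh(xn + r * hn)
    return z * h + (1.0 - z) * n


def gru_scan(x_proj, h0, w_h, b_h):
    # Optimized path: lax.scan over precomputed projections (T, B, 3H).
    def step(h, xp):
        h_new = gru_step(h, xp, w_h, b_h)
        return h_new, h_new

    return jax.lax.scan(step, h0, x_proj)


def gru_loop(x_proj, h0, w_h, b_h):
    # Reference path: a plain Python loop, one step at a time.
    h = h0
    for t in range(x_proj.shape[0]):
        h = gru_step(h, x_proj[t], w_h, b_h)
    return h
```

A test would then assert `gru_scan` and `gru_loop` agree to numerical tolerance across configurations, which is the shape the suggested `rnn_test.py` comparisons against `keras.layers.RNN(keras.layers.GRUCell(...))` would take.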
```python
    reset_after=True,
):
    if not reset_after or unroll or mask is not None:
        raise NotImplementedError
```
The NotImplementedError could be more descriptive to aid in debugging, even if it's caught by the GRU layer. According to the Keras API design guidelines (line 139), error messages should be informative. Consider providing details on which specific arguments are not supported in this optimized path.
```diff
-    raise NotImplementedError
+    raise NotImplementedError(
+        "Optimized JAX GRU implementation only supports "
+        "`reset_after=True`, `unroll=False`, and no masking."
+    )
```
References
- Error messages should be contextual, informative, and actionable, explaining what happened, what was expected, and how to fix it.
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master   #22401      +/-   ##
==========================================
+ Coverage   83.00%   83.02%   +0.02%
==========================================
  Files         596      596
  Lines       66588    66710     +122
  Branches    10370    10384      +14
==========================================
+ Hits        55272    55388     +116
- Misses       8678     8680       +2
- Partials     2638     2642       +4
```

View full report in Codecov by Sentry.
I added tests for it. Hope that helps.
Summary
The JAX backend's gru() was just a NotImplementedError stub, so every GRU call went through the generic rnn() loop. This
replaces it with a lax.scan implementation that precomputes all the input projections in one big matmul before entering the
loop, cutting the per-step work down to a single recurrent matmul.
Same constraints as the torch backend: only works with reset_after=True, and masked sequences still go through the generic
path since the layer handles zero_output_for_mask separately.
There's no cuDNN GRU primitive in JAX (jax-ml/jax#18867), so this is a manual scan. The precomputed projections plus XLA fusion should still help though, especially on longer sequences.
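The precompute-then-scan structure described above can be sketched roughly as follows. This is an illustrative reset_after=True GRU with hypothetical weight names (`w_x`, `w_h`, etc.), not the actual Keras code:

```python
import numpy as np
import jax
import jax.numpy as jnp


def gru(x, h0, w_x, b_x, w_h, b_h):
    # One big matmul up front: project every timestep's input at once.
    # x: (T, B, I) @ w_x: (I, 3H) -> x_proj: (T, B, 3H), covering the
    # z/r/n gates for all steps before the loop starts.
    x_proj = x @ w_x + b_x

    def step(h, xp):
        # Per-step work is now a single recurrent matmul (h @ w_h).
        hz, hr, hn = jnp.split(h @ w_h + b_h, 3, axis=-1)
        xz, xr, xn = jnp.split(xp, 3, axis=-1)
        z = jax.nn.sigmoid(xz + hz)   # update gate
        r = jax.nn.sigmoid(xr + hr)   # reset gate
        n = jnp.tanh(xn + r * hn)     # candidate state (reset_after form)
        h_new = z * h + (1.0 - z) * n
        return h_new, h_new

    last_h, all_h = jax.lax.scan(step, h0, x_proj)
    return last_h, all_h
```

The reset_after form is what makes the single recurrent matmul possible: the reset gate `r` is applied to the already-projected `hn` term, so `h @ w_h` can be computed once per step for all three gates.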
If needed, I can add tests for this.