
Add optimized GRU for JAX backend#22401

Open
MarcosAsh wants to merge 3 commits into keras-team:master from MarcosAsh:jax-gru

Conversation

@MarcosAsh
Contributor

Summary

The JAX backend's gru() was just a NotImplementedError stub, so every GRU call went through the generic rnn() loop. This
replaces it with a lax.scan implementation that precomputes all the input projections in one big matmul before entering the
loop, cutting the per-step work down to a single recurrent matmul.

Same constraints as the torch backend: only works with reset_after=True, and masked sequences still go through the generic
path since the layer handles zero_output_for_mask separately.

There's no cuDNN GRU primitive in JAX (jax-ml/jax#18867), so this is a manual scan. The precomputed projections plus XLA fusion should still help though, especially on longer sequences.
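The approach described above can be sketched as follows. This is a minimal, hypothetical illustration of the technique, not the PR's actual code; the function name `gru_scan` is made up, and the weight layout assumes the Keras `reset_after=True` convention with gate columns ordered `[z | r | h]`:

```python
import jax
import jax.numpy as jnp

def gru_scan(inputs, h0, kernel, recurrent_kernel, bias, recurrent_bias):
    # inputs: (time, batch, features); kernel: (features, 3 * units),
    # columns ordered [z | r | h] as in Keras.
    # Precompute every timestep's input projection in one big matmul,
    # so the scan body is left with a single recurrent matmul.
    x_proj = jnp.matmul(inputs, kernel) + bias  # (time, batch, 3 * units)

    def step(h, x_t):
        h_proj = jnp.matmul(h, recurrent_kernel) + recurrent_bias
        x_z, x_r, x_h = jnp.split(x_t, 3, axis=-1)
        h_z, h_r, h_h = jnp.split(h_proj, 3, axis=-1)
        z = jax.nn.sigmoid(x_z + h_z)  # update gate
        r = jax.nn.sigmoid(x_r + h_r)  # reset gate
        # reset_after=True: reset gate applied after the recurrent matmul
        hh = jnp.tanh(x_h + r * h_h)
        h_new = z * h + (1.0 - z) * hh
        return h_new, h_new

    last_h, all_h = jax.lax.scan(step, h0, x_proj)
    return last_h, all_h
```

Because `x_proj` is computed outside the loop, XLA can fuse it into one large GEMM over all timesteps instead of `time` small ones.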

If needed, I could add tests for this.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a highly optimized Gated Recurrent Unit (GRU) implementation for the JAX backend in Keras. By replacing the previous NotImplementedError with a lax.scan-based approach that precomputes input transformations, it aims to drastically improve the performance of GRU layers, especially for longer sequences, while adhering to specific operational constraints.

Highlights

  • JAX GRU Implementation: Replaced the NotImplementedError stub for the JAX backend's gru() function with an optimized implementation using lax.scan.
  • Performance Optimization: The new gru() implementation precomputes all input projections in a single matrix multiplication before entering the recurrent loop, significantly reducing per-step computation.
  • Constraints and Limitations: The optimized gru() currently only supports reset_after=True. Masked sequences and unroll=True still fall back to a NotImplementedError or the generic RNN path.


Changelog
  • keras/src/backend/jax/rnn.py
    • Implemented the gru function with an optimized lax.scan approach.
Activity
  • The pull request was opened by MarcosAsh.
  • No review comments or other human activity have been recorded yet.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an optimized GRU implementation for the JAX backend, which was previously a NotImplementedError stub. The optimization pre-computes input projections to improve performance, especially on longer sequences. The implementation seems correct for the supported configuration (reset_after=True, no masking, no unrolling). My review includes a suggestion to improve the error message for unsupported configurations and a strong recommendation to add unit tests to ensure the correctness and prevent future regressions.


-def gru(*args, **kwargs):
-    raise NotImplementedError
+def gru(


Severity: high

This new optimized GRU implementation for the JAX backend lacks unit tests. As you mentioned in the PR description, tests are needed. Please add a new test file (e.g., keras/src/backend/jax/rnn_test.py) to verify the correctness of this implementation. The tests should compare the output with the generic keras.layers.RNN(keras.layers.GRUCell(...)) implementation for various configurations (e.g., go_backwards=True/False, return_sequences=True/False) to ensure correctness and prevent future regressions.
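A test along those lines could compare the scan path against a plain Python reference loop. The sketch below is hypothetical, not the PR's test file: `gru_scan` and `gru_loop` are illustrative names, and the weight layout assumes Keras's `reset_after=True` gate order `[z | r | h]`:

```python
import jax
import jax.numpy as jnp
import numpy as np

def _gru_step(h, x_t, recurrent_kernel, recurrent_bias):
    # One GRU step with reset_after=True semantics; x_t is a precomputed
    # input projection with gate columns ordered [z | r | h].
    h_proj = jnp.matmul(h, recurrent_kernel) + recurrent_bias
    x_z, x_r, x_h = jnp.split(x_t, 3, axis=-1)
    h_z, h_r, h_h = jnp.split(h_proj, 3, axis=-1)
    z = jax.nn.sigmoid(x_z + h_z)
    r = jax.nn.sigmoid(x_r + h_r)
    hh = jnp.tanh(x_h + r * h_h)
    return z * h + (1.0 - z) * hh

def gru_scan(inputs, h0, kernel, rk, b, rb):
    # Optimized path: one big matmul up front, then lax.scan.
    x_proj = jnp.matmul(inputs, kernel) + b
    def step(h, x_t):
        h_new = _gru_step(h, x_t, rk, rb)
        return h_new, h_new
    return jax.lax.scan(step, h0, x_proj)

def gru_loop(inputs, h0, kernel, rk, b, rb):
    # Reference: a plain Python loop, mirroring the generic rnn() path.
    h, outs = h0, []
    for t in range(inputs.shape[0]):
        h = _gru_step(h, inputs[t] @ kernel + b, rk, rb)
        outs.append(h)
    return h, jnp.stack(outs)

# Random weights; shapes: time=5, batch=2, features=3, units=4.
ks = jax.random.split(jax.random.PRNGKey(0), 6)
T, B, F, U = 5, 2, 3, 4
inputs = jax.random.normal(ks[0], (T, B, F))
h0 = jax.random.normal(ks[1], (B, U))
kernel = jax.random.normal(ks[2], (F, 3 * U))
rk = jax.random.normal(ks[3], (U, 3 * U))
b = jax.random.normal(ks[4], (3 * U,))
rb = jax.random.normal(ks[5], (3 * U,))

h_scan, seq_scan = gru_scan(inputs, h0, kernel, rk, b, rb)
h_loop, seq_loop = gru_loop(inputs, h0, kernel, rk, b, rb)
np.testing.assert_allclose(
    np.asarray(seq_scan), np.asarray(seq_loop), rtol=1e-5, atol=1e-5
)
```

A real test in the Keras tree would instead compare against `keras.layers.RNN(keras.layers.GRUCell(...))` across `go_backwards` and `return_sequences` settings, as suggested above.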

    reset_after=True,
):
    if not reset_after or unroll or mask is not None:
        raise NotImplementedError


Severity: medium

The NotImplementedError could be more descriptive to aid in debugging, even if it's caught by the GRU layer. According to the Keras API design guidelines (line 139), error messages should be informative. Consider providing details on which specific arguments are not supported in this optimized path.

Suggested change:
-        raise NotImplementedError
+        raise NotImplementedError(
+            "Optimized JAX GRU implementation only supports "
+            "`reset_after=True`, `unroll=False`, and no masking."
+        )
References
  1. Error messages should be contextual, informative, and actionable, explaining what happened, what was expected, and how to fix it. (link)

@codecov-commenter

codecov-commenter commented Mar 11, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 83.02%. Comparing base (e65cfb8) to head (ccc6907).
⚠️ Report is 6 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #22401      +/-   ##
==========================================
+ Coverage   83.00%   83.02%   +0.02%     
==========================================
  Files         596      596              
  Lines       66588    66710     +122     
  Branches    10370    10384      +14     
==========================================
+ Hits        55272    55388     +116     
- Misses       8678     8680       +2     
- Partials     2638     2642       +4     
Flag Coverage Δ
keras 82.85% <100.00%> (+0.02%) ⬆️
keras-jax 60.57% <100.00%> (-0.06%) ⬇️
keras-numpy 54.76% <3.12%> (-0.09%) ⬇️
keras-openvino 49.90% <3.12%> (+0.48%) ⬆️
keras-tensorflow 61.75% <3.12%> (-0.10%) ⬇️
keras-torch 60.60% <3.12%> (-0.10%) ⬇️

Flags with carried forward coverage won't be shown.


@MarcosAsh
Contributor Author

I added tests for it. Hope that helps.
