
Add cuDNN LSTM for JAX backend #22399

Open
MarcosAsh wants to merge 2 commits into keras-team:master from MarcosAsh:jax-lstm

Conversation

@MarcosAsh
Contributor

Summary

The JAX backend's lstm() was a NotImplementedError stub, so it always fell back to the generic lax.scan loop. This wires it up to jax.experimental.rnn.lstm, which goes through cuDNN directly on GPU.

The weight conversion is straightforward: Keras and cuDNN use the same gate order [i, f, g, o], so the kernels just need to be transposed and flattened. Masking works by computing seq_lengths from the boolean mask, with the same right-padded constraint the torch backend has.
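A minimal sketch of that conversion, using NumPy in place of jnp and a hypothetical helper name (the real flat-weight layout cuDNN expects also includes biases and may differ in detail -- this only illustrates the transpose-and-flatten step, with no gate reordering needed since both sides use [i, f, g, o]):

```python
import numpy as np

def keras_to_cudnn_weights(kernel, recurrent_kernel, units):
    """Hypothetical sketch: flatten Keras LSTM kernels into a single
    cuDNN-style flat weight vector. Gate order matches, so each
    per-gate block is just transposed and flattened."""
    flat = []
    for w in (kernel, recurrent_kernel):
        # Split the fused (dim, 4 * units) kernel into the four gates.
        for gate in np.split(w, 4, axis=1):
            # cuDNN stores each gate as a (units, dim) row-major block.
            flat.append(gate.T.reshape(-1))
    return np.concatenate(flat)

units, input_dim = 3, 5
kernel = np.random.randn(input_dim, 4 * units)
recurrent_kernel = np.random.randn(units, 4 * units)
flat = keras_to_cudnn_weights(kernel, recurrent_kernel, units)
# Total size: 4 gates of (units, input_dim) plus 4 of (units, units).
assert flat.shape == (4 * units * input_dim + 4 * units * units,)
```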

Falls back to lax.scan on CPU or when cuDNN isn't available, so nothing changes for non-GPU users.

Benchmarks from the JAX repo put the cuDNN path at around 5x faster than lax.scan for single-layer LSTMs.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces cuDNN LSTM support for the JAX backend, enhancing performance on GPUs. It includes weight conversion and masking support while maintaining a fallback mechanism for CPU execution. The changes aim to improve the speed of LSTM computations using the cuDNN library when available, with benchmarks showing significant performance gains.

Highlights

  • cuDNN LSTM Integration: The JAX backend now supports cuDNN LSTM, leveraging jax.experimental.rnn.lstm for GPU acceleration.
  • Weight Conversion: Weight conversion is implemented to align Keras and cuDNN gate orders, involving transposing and flattening kernels.
  • Masking Support: Masking is supported by computing seq_lengths from boolean masks, with a constraint for right-padded sequences.
  • Fallback Mechanism: The implementation falls back to lax.scan on CPU or when cuDNN is unavailable, ensuring functionality on non-GPU setups.
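The seq_lengths computation behind the masking highlight can be sketched like so (NumPy stand-in for jnp, hypothetical helper name): for a right-padded boolean mask, the per-sample length is simply the count of True entries along the time axis.

```python
import numpy as np

def mask_to_seq_lengths(mask):
    """Sketch: for a right-padded boolean mask of shape (batch, time),
    the per-sample sequence length is the per-row count of True."""
    return mask.astype(np.int32).sum(axis=1)

mask = np.array([[True, True, True, False],
                 [True, True, False, False]])
assert mask_to_seq_lengths(mask).tolist() == [3, 2]
```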


Changelog
  • keras/src/backend/jax/rnn.py
    • Implemented cuDNN LSTM support for JAX backend.

@gemini-code-assist
Contributor

Warning

Gemini is experiencing higher than usual traffic and was unable to create the review. Please try again in a few hours by commenting /gemini review.

@codecov-commenter

codecov-commenter commented Mar 11, 2026

Codecov Report

❌ Patch coverage is 29.50820% with 43 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.97%. Comparing base (69f9311) to head (22a8319).
⚠️ Report is 5 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/backend/jax/rnn.py 29.50% 43 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #22399      +/-   ##
==========================================
- Coverage   83.00%   82.97%   -0.04%     
==========================================
  Files         596      596              
  Lines       66586    66737     +151     
  Branches    10368    10391      +23     
==========================================
+ Hits        55272    55372     +100     
- Misses       8677     8724      +47     
- Partials     2637     2641       +4     
Flag Coverage Δ
keras 82.79% <29.50%> (-0.04%) ⬇️
keras-jax 60.53% <29.50%> (-0.10%) ⬇️
keras-numpy 54.74% <6.55%> (-0.11%) ⬇️
keras-openvino 49.88% <6.55%> (+0.46%) ⬆️
keras-tensorflow 61.73% <6.55%> (-0.13%) ⬇️
keras-torch 60.58% <6.55%> (-0.12%) ⬇️

Flags with carried forward coverage won't be shown.

@MarcosAsh
Contributor Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds cuDNN-accelerated LSTM support for the JAX backend, which is a great performance improvement. The implementation correctly handles weight conversion, masking for right-padded sequences, and the fallback to the generic implementation when cuDNN is not available or applicable.

I have a couple of suggestions to improve robustness and efficiency:

  • In cudnn_ok, I suggest adding a check for jax.nn.sigmoid for consistency with the tanh check.
  • I've also identified a small redundant computation in the mask handling logic and suggest a refactoring to improve efficiency.

Comment on lines +226 to +235
    from keras.src import activations
    from keras.src import ops

    return (
        activation in (activations.tanh, jnp.tanh, ops.tanh)
        and recurrent_activation in (activations.sigmoid, ops.sigmoid)  # noqa: E501
        and not unroll
        and use_bias
        and _is_gpu_available()
    )
Copy link
Contributor


medium

For consistency with the tanh activation check and to make the cudnn_ok check more robust, you should also check for jax.nn.sigmoid. The tanh check includes jnp.tanh, which is the base JAX function. The equivalent for sigmoid is jax.nn.sigmoid. This will ensure that if a user passes the base JAX function directly, cudnn_ok will correctly identify it.

Suggested change

Before:
    from keras.src import activations
    from keras.src import ops

    return (
        activation in (activations.tanh, jnp.tanh, ops.tanh)
        and recurrent_activation in (activations.sigmoid, ops.sigmoid)  # noqa: E501
        and not unroll
        and use_bias
        and _is_gpu_available()
    )

After:
    from keras.src import activations
    from keras.src import ops
    from jax import nn

    return (
        activation in (activations.tanh, jnp.tanh, ops.tanh)
        and recurrent_activation
        in (activations.sigmoid, ops.sigmoid, nn.sigmoid)
        and not unroll
        and use_bias
        and _is_gpu_available()
    )

Comment on lines +325 to +326
seq_lengths = jnp.full((batch_size,), inputs.shape[1], dtype=jnp.int32)

Contributor


medium

There is a redundant computation here. jnp.sum(mask.astype(jnp.int32), axis=1) is calculated here to get seq_lengths, but it's also calculated inside _assert_valid_mask as count_of_true.

To improve efficiency, you can modify _assert_valid_mask to return count_of_true and use that value here.

  1. In _assert_valid_mask, add return count_of_true at the end.
  2. Then, you can replace these two lines with just seq_lengths = _assert_valid_mask(mask).
Suggested change

Before:
    seq_lengths = jnp.full((batch_size,), inputs.shape[1], dtype=jnp.int32)

After:
    seq_lengths = _assert_valid_mask(mask)

@MarcosAsh
Contributor Author

I added a test for this! Hope it helps.
