Qwen3-Next Gated FullAttention Implementation #2529
Conversation
Thanks for the great work! LGTM at high level.
Let's try to keep the Qwen3 decoder layer simple, calling the needed functions from embeddings, normalizations, attentions, etc.
A few minor comments.
Oh, I think you will need to squash those 26 commits into 1.
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This Pull Request introduces the Qwen3-Next Gated FullAttention implementation, including custom RMSNorm and partial Rotary Embedding. The changes are well-structured, with new components moved to appropriate layers and a comprehensive test added for verification.
🔍 General Feedback
- The refactoring of `Qwen3NextRMSNorm` and `Qwen3NextRMSNormGated` to `normalizations.py` is a good improvement for modularity.
- The addition of `test_full_attention_jax_vs_pytorch_attention` is crucial for ensuring correctness and alignment with the PyTorch reference.
- Some minor improvements in code clarity and consistency in variable naming and docstrings have been suggested.
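For context, a gated RMSNorm of the kind `Qwen3NextRMSNormGated` describes typically normalizes the hidden states and then modulates them with an activated gate. A minimal JAX sketch follows; the function name, the silu gating, and the scale placement are assumptions based on common gated-norm variants, not taken from this PR:

```python
import jax
import jax.numpy as jnp

def rms_norm_gated(x, gate, scale, eps=1e-6):
  # Standard RMS normalization over the feature axis, computed in float32.
  var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)
  normed = x * jax.lax.rsqrt(var + eps)
  # Gated variant: modulate the normalized activations with an activated gate
  # before applying the learned scale (silu is one common choice).
  return normed * jax.nn.silu(gate) * scale
```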
Force-pushed from 2cecb2e to f853c15
Thank you! LGTM.
Force-pushed from a8ef44a to 4b30745
🤖 Hi @Rohan-Bierneni, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This Pull Request introduces the Qwen3-Next Gated FullAttention implementation, including custom RMSNorm and partial Rotary Embedding. The changes are well-structured, integrating new components into existing layers and configurations. The addition of a comprehensive test case comparing JAX and PyTorch implementations is a significant improvement for verifying correctness.
🔍 General Feedback
- The refactoring of normalization classes into `normalizations.py` enhances modularity and reusability.
- The detailed docstrings for new classes and functions are highly beneficial for understanding the Qwen3-Next specific implementations.
- The validation logic in `pyconfig.py` for `partial_rotary_factor` ensures proper configuration usage.
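A plausible shape for that validation, as a sketch only (the function name and error messages are hypothetical; the actual check in `pyconfig.py` may differ):

```python
def validate_partial_rotary_factor(config):
  # partial_rotary_factor controls what fraction of head_dim receives RoPE;
  # Qwen3-Next uses 0.25. It must yield a valid, even rotary dimension.
  factor = config.partial_rotary_factor
  if not 0.0 < factor <= 1.0:
    raise ValueError(f"partial_rotary_factor must be in (0, 1], got {factor}")
  rotary_dim = int(config.head_dim * factor)
  if rotary_dim % 2 != 0:
    raise ValueError(f"rotary_dim ({rotary_dim}) must be even to pair channels for rotation")
```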
Awesome work! LGTM
Force-pushed from 4195b97 to 53ccaa6
- Ported some of the pytorch ref functions
- Added all test code and verified testcase passes
- Removed caching logic and debug statements
- Fixed testcase and jax gating logic
- Resolved scaling factor adjustment
- Remove debug statements
- move partial rope logic to embeddings.py
- Moved partial rope logic to embeddings.py
- remove old partial rope code
- Resolved comments from pr review
- Removed qwen3rmsnorm function from qwen3.py
- Removed initialization for using Attention()
- Qwen3NextFullAttention working with Attention() instead of attention_op()
- resolved some comments from pr related to Qwen3NextRMSNorm
- Cleaned up code and now works with Attention() integration
- Add pyconfig check for rotary_dim
- Change Qwen3NextRMSNorm to match base RMSNorm impl
- Fixed bug with running maxtext train command with qwen3 next
- Updated pytorch partial ROPE impl for unit test
- Fix indentation
- Fixed failing qwen3nextrmsnorm tests
- Update Qwen3NextRMSNormGated to also use scale for checkpointing
- Remove debug statements, now all tests pass for rebase
- Resolved gemini-code-review bot comments
- Fixed nit comments based on review
- Undo commented out code for jax 0.7.0 compatability
- Run linter
- Fixed pyink error in embeddings.py
- Use nnx.data to wrap rmsnorm in qwen3nextrmsnorm
- Add qwen3 next flash attention test
- Remove skip_jax_distributed_system flag
- Add sharding for 4 devices
- Update ici fsdp param
- Update tpu sharding params
- revert test code
- increase batch size
- Try with dot_product
- try with relaxed atol rtol
- Update with dot product & flash attention tests
- add condition rtol & atol
- Create new jax pyconfig based on attention_type
- convert to helper function so pytest doesn't pick it up
Force-pushed from f7551ea to ee4b38a
Thank you!
Manually adding the pull ready tag since all tests pass, there are 3 approvals, and all comments are resolved. There seems to be a bug with a skipped check causing the tag to not be added.
Description
This PR adds the Qwen3-Next Gated Full Attention implementation to the existing Qwen3-Next code in qwen3.py.
The current implementation in MaxText uses the standard Attention, but Qwen3-Next requires attention with two slight tweaks: an output gate and partial RoPE applied to 25% of head_dim. This PR adds that functionality by building a custom attention class for Qwen3-Next on top of AttentionOp for the core attention calculation.
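For illustration, here is a minimal JAX sketch of those two tweaks. All names here (`apply_partial_rope`, `gated_attention_output`) and the cos/sin layout are assumptions for exposition, not the identifiers used in qwen3.py:

```python
import jax.numpy as jnp
from jax import nn

def apply_partial_rope(x, cos, sin, partial_rotary_factor=0.25):
  # Apply rotary embeddings to only the first `rotary_dim` channels of each
  # head (25% of head_dim for Qwen3-Next); the remainder passes through.
  # cos/sin are assumed broadcastable to the rotary slice's shape.
  rotary_dim = int(x.shape[-1] * partial_rotary_factor)
  x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
  half = rotary_dim // 2
  x1, x2 = x_rot[..., :half], x_rot[..., half:]
  rotated = jnp.concatenate([-x2, x1], axis=-1)  # rotate_half convention
  x_rot = x_rot * cos + rotated * sin
  return jnp.concatenate([x_rot, x_pass], axis=-1)

def gated_attention_output(attn_out, gate):
  # Output gate: elementwise sigmoid gate on the attention output,
  # applied before the final output projection.
  return attn_out * nn.sigmoid(gate)
```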
FIXES: b/448407748
Tests
I have added a test case that compares the PyTorch reference to the JAX implementation, checking the output tensors after the gated full-attention layer.
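The comparison presumably reduces to an elementwise tolerance check between the two outputs; a sketch of the pattern (the helper name and tolerances are illustrative, not taken from the test file):

```python
import numpy as np

def assert_outputs_match(jax_out, torch_out, rtol=1e-3, atol=1e-3):
  # Compare the JAX layer output against the PyTorch reference tensor.
  np.testing.assert_allclose(
      np.asarray(jax_out), torch_out.detach().cpu().numpy(), rtol=rtol, atol=atol
  )
```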
This test passes after running this command:
`pytest -vvs tests/check_qwen3_next_vs_reference.py::TestQwen3Next::test_full_attention_jax_vs_pytorch`
Logs: https://paste.googleplex.com/5626786360721408
Checklist
Before submitting this PR, please make sure (put X in square brackets):
- `gemini-review` label.