Conversation


@Rohan-Bierneni Rohan-Bierneni commented Oct 21, 2025

Description

This PR adds the Qwen3-Next Gated Full Attention implementation to the existing Qwen3-Next code in qwen3.py.

The current implementation in MaxText uses the standard Attention, but Qwen3-Next requires attention with two slight tweaks: an output gate and partial RoPE applied to 25% of head_dim. This PR adds that functionality by building a custom attention class for Qwen3-Next on top of AttentionOp for the core attention computation.
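For illustration, here is a minimal NumPy sketch of the two tweaks described above. This is not the MaxText implementation; the shapes, the rotate-half convention, and the function names are assumptions for the sketch only.

```python
import numpy as np

def partial_rope(x, positions, partial_rotary_factor=0.25, base=10000.0):
  """Apply rotary embeddings to only the leading fraction of head_dim.

  x: [seq, num_heads, head_dim]. Only the first
  head_dim * partial_rotary_factor features are rotated (25% for
  Qwen3-Next); the remaining features pass through unchanged.
  """
  seq, num_heads, head_dim = x.shape
  rotary_dim = int(head_dim * partial_rotary_factor)
  x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]

  half = rotary_dim // 2
  # Standard RoPE frequency schedule over the rotated dims only.
  inv_freq = 1.0 / (base ** (np.arange(half) * 2.0 / rotary_dim))
  angles = positions[:, None] * inv_freq[None, :]   # [seq, half]
  cos = np.cos(angles)[:, None, :]                  # broadcast over heads
  sin = np.sin(angles)[:, None, :]

  # Rotate-half convention: pair feature i with feature i + half.
  x1, x2 = x_rot[..., :half], x_rot[..., half:]
  rotated = np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
  return np.concatenate([rotated, x_pass], axis=-1)

def gated_attention_output(attn_out, gate_logits):
  """Output gate: modulate the attention output elementwise by a sigmoid
  gate (the gate logits would come from a learned projection of the input)."""
  return attn_out * (1.0 / (1.0 + np.exp(-gate_logits)))
```

The pass-through features keep full positional-encoding-free capacity, while position 0 is always an identity rotation, which makes the sketch easy to sanity-check.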

FIXES: b/448407748

Tests

I have added a test case that runs the PyTorch reference and the JAX implementation side by side and compares the output tensors after the gated full-attention layer.

This test passes when run with: pytest -vvs tests/check_qwen3_next_vs_reference.py::TestQwen3Next::test_full_attention_jax_vs_pytorch (log: https://paste.googleplex.com/5626786360721408)
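A typical shape for such a JAX-vs-PyTorch comparison uses tolerances that depend on the attention kernel, since flash-attention kernels accumulate in a different order than a plain dot-product implementation. This is a sketch of the assertion pattern only; the function name and tolerance values are assumptions, not the actual test code.

```python
import numpy as np

def assert_outputs_match(jax_out, torch_out, attention_kernel="dot_product"):
  """Compare JAX and PyTorch output tensors with kernel-dependent tolerances."""
  if attention_kernel == "flash":
    # Relaxed: fused kernels reorder floating-point accumulation.
    rtol, atol = 1e-2, 1e-2
  else:
    rtol, atol = 1e-5, 1e-5
  np.testing.assert_allclose(
      np.asarray(jax_out), np.asarray(torch_out), rtol=rtol, atol=atol)
```

In practice both models would be loaded with identical weights and fed the same inputs before the comparison.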

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@Rohan-Bierneni Rohan-Bierneni changed the title Rbierneni qwen3 next fullattention Qwen3 next fullattention Oct 21, 2025
@Rohan-Bierneni Rohan-Bierneni changed the title Qwen3 next fullattention Qwen3-Next Gated FullAttention Implementation Oct 21, 2025
Collaborator

@RissyRan RissyRan left a comment

Thanks for the great work! LGTM at high level.

Let's try to keep Qwen3 decoder layer simple and calling needed functions from embeddings, normalizations, and attentions, etc.

Collaborator

@RissyRan RissyRan left a comment

A few minor comments.

@RissyRan
Collaborator

Oh, I think you will need to squash those 26 commits into 1

@github-actions

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions bot left a comment

📋 Review Summary

This Pull Request introduces the Qwen3-Next Gated FullAttention implementation, including custom RMSNorm and partial Rotary Embedding. The changes are well-structured, with new components moved to appropriate layers and a comprehensive test added for verification.

🔍 General Feedback

  • The refactoring of Qwen3NextRMSNorm and Qwen3NextRMSNormGated to normalizations.py is a good improvement for modularity.
  • The addition of test_full_attention_jax_vs_pytorch_attention is crucial for ensuring correctness and alignment with the PyTorch reference.
  • Some minor improvements in code clarity and consistency in variable naming and docstrings have been suggested.
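For readers unfamiliar with the gated variant mentioned above, here is a minimal NumPy sketch of the two normalizations being discussed. The norm-then-gate ordering with a SiLU gate follows the common Qwen3-Next convention; the actual MaxText classes live in normalizations.py and may differ in detail.

```python
import numpy as np

def rms_norm(x, scale, eps=1e-6):
  """Standard RMSNorm: normalize by the root-mean-square of the last axis,
  then apply a learned per-feature scale."""
  rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
  return x / rms * scale

def gated_rms_norm(x, gate, scale, eps=1e-6):
  """Gated variant: RMS-normalize x, then modulate elementwise by SiLU of a
  separate gate tensor (silu(g) = g * sigmoid(g))."""
  silu = gate / (1.0 + np.exp(-gate))
  return rms_norm(x, scale, eps) * silu
```

Factoring both into a shared normalizations module keeps the decoder layer itself thin, in line with the review comment above.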

@Rohan-Bierneni Rohan-Bierneni force-pushed the rbierneni-qwen3-next-fullattention branch from 2cecb2e to f853c15 on November 1, 2025 00:09
Collaborator

@shuningjin shuningjin left a comment

Thank you! LGTM.

@Rohan-Bierneni Rohan-Bierneni force-pushed the rbierneni-qwen3-next-fullattention branch 3 times, most recently from a8ef44a to 4b30745 on November 3, 2025 20:09
@github-actions

github-actions bot commented Nov 3, 2025

🤖 Hi @Rohan-Bierneni, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions bot left a comment

📋 Review Summary

This Pull Request introduces the Qwen3-Next Gated FullAttention implementation, including custom RMSNorm and partial Rotary Embedding. The changes are well-structured, integrating new components into existing layers and configurations. The addition of a comprehensive test case comparing JAX and PyTorch implementations is a significant improvement for verifying correctness.

🔍 General Feedback

  • The refactoring of normalization classes into normalizations.py enhances modularity and reusability.
  • The detailed docstrings for new classes and functions are highly beneficial for understanding the Qwen3-Next specific implementations.
  • The validation logic in pyconfig.py for partial_rotary_factor ensures proper configuration usage.
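A hypothetical sketch of what such a configuration check might look like: the factor must lie in (0, 1], and the resulting rotary dimension should be even so the rotate-half split works cleanly. The field names and error messages here are assumptions, not the actual pyconfig.py code.

```python
def validate_partial_rotary_factor(config):
  """Validate partial_rotary_factor and derive the rotary dimension."""
  factor = config["partial_rotary_factor"]
  head_dim = config["head_dim"]
  if not 0.0 < factor <= 1.0:
    raise ValueError(f"partial_rotary_factor must be in (0, 1], got {factor}")
  rotary_dim = int(head_dim * factor)
  if rotary_dim % 2 != 0:
    raise ValueError(f"rotary dimension {rotary_dim} must be even")
  return rotary_dim
```

For Qwen3-Next's 25% partial RoPE, a head_dim of 128 would yield a rotary dimension of 32.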

@parambole parambole self-requested a review November 3, 2025 21:23
Collaborator

@parambole parambole left a comment

Awesome work ! LGTM

@Rohan-Bierneni Rohan-Bierneni force-pushed the rbierneni-qwen3-next-fullattention branch 2 times, most recently from 4195b97 to 53ccaa6 on November 4, 2025 02:59
Ported some of the pytorch ref functions

Added all test code and verified testcase passes

Removed caching logic and debug statements

Fixed testcase and jax gating logic

Resolved scaling factor adjustment

Remove debug statements

move partial rope logic to embeddings.py

Moved partial rope logic to embeddings.py

remove old partial rope code

Resolved comments from pr review

Removed qwen3rmsnorm function from qwen3.py

Removed initialization for using Attention()

Qwen3NextFullAttention working with Attention() instead of attention_op()

resolved some comments from pr related to Qwen3NextRMSNorm

Cleaned up code and now works with Attention() integration

Add pyconfig check for rotary_dim

Change Qwen3NextRMSNorm to match base RMSNorm impl

Fixed bug with running maxtext train command with qwen3 next

Updated pytorch partial ROPE impl for unit test

Fix indentation

Fixed failing qwen3nextrmsnorm tests

Update Qwen3NextRMSNormGated to also use scale for checkpointing

Remove debug statements now all tests pass

for rebase

Resolved gemini-code-review bot comments

Fixed nit comments based on review

Undo commented out code for jax 0.7.0 compatability

Run linter

Fixed pyink error in embeddings.py

Use nnx.data to wrap rmsnorm in qwen3nextrmsnorm

Add qwen3 next flash attention test

Remove skip_jax_distributed_system flag

Add sharding for 4 devices

Update ici fsdp param

Update tpu sharding params

revert test code

increase batch size

Try with dot_product

try with relaxed atol rtol

Update with dot product & flash attention tests

add condition rtol & atol

Create new jax pyconfig based on attention_type

convert to helper function so pytest doesn't pick it up

@Rohan-Bierneni Rohan-Bierneni force-pushed the rbierneni-qwen3-next-fullattention branch from f7551ea to ee4b38a on November 4, 2025 03:11
Collaborator

@RissyRan RissyRan left a comment

Thank you!

@Rohan-Bierneni
Collaborator Author

Rohan-Bierneni commented Nov 4, 2025

Manually adding the pull-ready tag since all tests pass, there are 3 approvals, and all comments are resolved. There seems to be a bug with a skipped check preventing the tag from being added automatically.

@copybara-service copybara-service bot merged commit ff5be4a into main Nov 4, 2025
38 checks passed
@copybara-service copybara-service bot deleted the rbierneni-qwen3-next-fullattention branch November 4, 2025 16:56