Ported some of the PyTorch reference functions
Added all test code and verified the test case passes
Removed caching logic and debug statements
Fixed test case and JAX gating logic
Resolved scaling factor adjustment
Remove debug statements
Move partial RoPE logic to embeddings.py
Moved partial RoPE logic to embeddings.py (see the partial RoPE sketch after this list)
Remove old partial RoPE code
Resolved comments from PR review
Removed qwen3rmsnorm function from qwen3.py
Removed initialization for using Attention()
Qwen3NextFullAttention working with Attention() instead of attention_op()
Resolved some PR comments related to Qwen3NextRMSNorm
Cleaned up code; now works with the Attention() integration
Add pyconfig check for rotary_dim
Change Qwen3NextRMSNorm to match base RMSNorm impl (see the RMSNorm sketch after this list)
Fixed bug when running the MaxText train command with Qwen3-Next
Updated PyTorch partial RoPE impl for unit test
Fix indentation
Fixed failing Qwen3NextRMSNorm tests
Update Qwen3NextRMSNormGated to also use scale for checkpointing (see the gated-norm sketch after this list)
Remove debug statements now that all tests pass
for rebase
Resolved gemini-code-review bot comments
Addressed nit comments from review
Undo commented-out code for JAX 0.7.0 compatibility
Run linter
Fixed pyink error in embeddings.py
Use nnx.data to wrap the RMSNorm in Qwen3NextRMSNorm
Add Qwen3-Next flash attention test
Remove skip_jax_distributed_system flag
Add sharding for 4 devices
Update ici fsdp param
Update TPU sharding params
Revert test code
Increase batch size
Try with dot_product
Try with relaxed atol/rtol
Update with dot product & flash attention tests
Add conditional rtol & atol based on attention type (see the tolerance sketch after this list)
Create new JAX pyconfig based on attention_type
Convert to helper function so pytest doesn't pick it up (see the helper sketch after this list)
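
The partial-RoPE commits above apply rotary position embedding only to the first `rotary_dim` channels of each head and pass the rest through. Below is a minimal sketch of that idea, assuming a rotate-half formulation; the function name, argument shapes, and `base` default are assumptions, not MaxText's actual embeddings.py API.

```python
# Minimal sketch of partial RoPE, assuming a rotate-half formulation.
# `apply_partial_rope` and its argument shapes are illustrative only.
import jax.numpy as jnp


def apply_partial_rope(x, positions, rotary_dim, base=10000.0):
  """x: [batch, seq, heads, head_dim]; positions: [batch, seq] (int)."""
  rot, passthrough = x[..., :rotary_dim], x[..., rotary_dim:]
  half = rotary_dim // 2
  inv_freq = 1.0 / (base ** (jnp.arange(half, dtype=jnp.float32) * 2.0 / rotary_dim))
  # angles: [batch, seq, 1, half], broadcast over heads.
  angles = positions[..., None, None].astype(jnp.float32) * inv_freq
  sin, cos = jnp.sin(angles), jnp.cos(angles)
  x1, x2 = rot[..., :half], rot[..., half:]
  rotated = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
  # Channels beyond rotary_dim are passed through unrotated.
  return jnp.concatenate([rotated, passthrough], axis=-1)
```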
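
The RMSNorm commits align Qwen3NextRMSNorm with the base RMSNorm behavior: normalize by the root-mean-square of the features, then apply a learned per-feature `scale`. A minimal functional sketch of that computation, assuming float32 accumulation (names are illustrative, not MaxText's module):

```python
import jax
import jax.numpy as jnp


def rms_norm(x, scale, eps=1e-6):
  """RMSNorm sketch: normalize by root-mean-square, then apply `scale`.

  x: [..., features]; scale: [features]. Illustrative only.
  """
  x32 = x.astype(jnp.float32)
  var = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
  return (x32 * jax.lax.rsqrt(var + eps) * scale).astype(x.dtype)
```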
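
The Qwen3NextRMSNormGated commit keeps the learned `scale` parameter so checkpoints line up with the ungated norm; the gating itself is commonly a SiLU of a second input applied after normalization. A hedged sketch under that assumption:

```python
import jax
import jax.numpy as jnp


def rms_norm_gated(x, gate, scale, eps=1e-6):
  """Gated RMSNorm sketch: normalize, gate with silu(gate), then apply `scale`.

  Keeping `scale` as the parameter makes it checkpoint-compatible with the
  plain RMSNorm above. The gating formulation here is an assumption.
  """
  x32 = x.astype(jnp.float32)
  normed = x32 * jax.lax.rsqrt(jnp.mean(jnp.square(x32), axis=-1, keepdims=True) + eps)
  gated = normed * jax.nn.silu(gate.astype(jnp.float32))
  return (gated * scale).astype(x.dtype)
```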
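
The atol/rtol commits condition the test tolerances on the attention kernel, since flash attention is numerically looser than dot_product. A sketch of that pattern; the specific tolerance values and the `attention_type` strings are assumptions:

```python
import numpy as np


def assert_logits_close(actual, expected, attention_type):
  # Flash attention accumulates with more numerical slack than dot_product,
  # so widen the tolerance for it. Values here are placeholders.
  if attention_type == "flash":
    atol, rtol = 1e-2, 1e-2
  else:  # "dot_product"
    atol, rtol = 1e-5, 1e-5
  np.testing.assert_allclose(actual, expected, atol=atol, rtol=rtol)
```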
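
The last two commits build a per-attention-type config and factor the shared logic into a helper whose name does not start with `test_`, so pytest does not collect it as a standalone test. A sketch of that pattern with hypothetical names:

```python
import pytest


def _run_qwen3_next_attention_check(attention_type):
  # Hypothetical helper: not collected by pytest because its name does not
  # start with "test_". It would build a pyconfig for `attention_type`,
  # run the model, and compare outputs (details elided).
  ...


@pytest.mark.parametrize("attention_type", ["dot_product", "flash"])
def test_qwen3_next_attention(attention_type):
  _run_qwen3_next_attention_check(attention_type)
```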