
Commit ee4b38a

Added jax impl for gated full attention
Ported some of the pytorch ref functions
Added all test code and verified testcase passes
Removed caching logic and debug statements
Fixed testcase and jax gating logic
Resolved scaling factor adjustment
Remove debug statements; move partial rope logic to embeddings.py
Moved partial rope logic to embeddings.py; remove old partial rope code
Resolved comments from pr review
Removed qwen3rmsnorm function from qwen3.py
Removed initialization for using Attention()
Qwen3NextFullAttention working with Attention() instead of attention_op(); resolved some comments from pr related to Qwen3NextRMSNorm
Cleaned up code and now works with Attention() integration
Add pyconfig check for rotary_dim
Change Qwen3NextRMSNorm to match base RMSNorm impl
Fixed bug with running maxtext train command with qwen3 next
Updated pytorch partial ROPE impl for unit test
Fix indentation
Fixed failing qwen3nextrmsnorm tests
Update Qwen3NextRMSNormGated to also use scale for checkpointing
Remove debug statements; now all tests pass for rebase
Resolved gemini-code-review bot comments
Fixed nit comments based on review
Undo commented out code for jax 0.7.0 compatibility
Run linter
Fixed pyink error in embeddings.py
Use nnx.data to wrap rmsnorm in qwen3nextrmsnorm
Add qwen3 next flash attention test
Remove skip_jax_distributed_system flag
Add sharding for 4 devices
Update ici fsdp param
Update tpu sharding params; revert test code; increase batch size
Try with dot_product; try with relaxed atol rtol
Update with dot product & flash attention tests; add condition rtol & atol
Create new jax pyconfig based on attention_type; convert to helper function so pytest doesn't pick it up
1 parent 84f3ad6 commit ee4b38a

8 files changed: +703, -142 lines

src/MaxText/configs/base.yml

Lines changed: 2 additions & 0 deletions

@@ -905,6 +905,8 @@ gdn_num_value_heads: 32
 gdn_chunk_size: 64
 # Whether to apply L2 normalization to query and key tensors inside the Gated Delta Rule kernel.
 use_qk_norm_in_gdn: True
+# The ratio of dimension to apply ROPE on
+partial_rotary_factor: 1.0

 # Use tokamax library for gmm kernel implementation
 use_tokamax_gmm: false

src/MaxText/configs/models/qwen3-next-80b-a3b.yml

Lines changed: 1 addition & 0 deletions

@@ -45,3 +45,4 @@ gdn_chunk_size: 64

 # RoPE Settings
 rope_max_timescale: 10000000
+partial_rotary_factor: 0.25
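
For orientation, a minimal sketch (plain Python, illustrative only) of what the new partial_rotary_factor setting means: it is the fraction of each attention head's dimensions that receives RoPE. The head_dim of 256 used below is an assumed example value, not something read from these configs.

    # Hypothetical helper, not part of MaxText: shows how the factor maps to rotary dims.
    def rotary_dim_for(head_dim: int, partial_rotary_factor: float) -> int:
      """Number of leading head_dim features that get RoPE; the rest pass through."""
      return int(head_dim * partial_rotary_factor)

    print(rotary_dim_for(256, 1.0))   # 256 -> full RoPE (base.yml default of 1.0)
    print(rotary_dim_for(256, 0.25))  # 64  -> partial RoPE (qwen3-next-80b-a3b.yml)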

src/MaxText/layers/attentions.py

Lines changed: 64 additions & 11 deletions

@@ -65,10 +65,11 @@
     LlamaVisionRotaryEmbedding,
     RotaryEmbedding,
     YarnRotaryEmbedding,
+    Qwen3NextRotaryEmbedding,
 )
 from MaxText.layers.initializers import nd_dense_init, NdInitializer, variable_to_logically_partitioned, default_bias_init
 from MaxText.layers.linears import DenseGeneral, canonicalize_tuple, normalize_axes
-from MaxText.layers.normalizations import RMSNorm
+from MaxText.layers.normalizations import RMSNorm, Qwen3NextRMSNorm
 from MaxText.layers.quantizations import AqtQuantization as Quant
 
 # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
@@ -416,6 +417,8 @@ def __init__(
     self.model_mode = model_mode
     self.rngs = rngs
 
+    self.is_qwen3_next = self.config.decoder_block == DecoderBlockType.QWEN3_NEXT
+
     # Module attribute names must match names previously passed to Linen for checkpointing
     self.KVCache_0 = (
         self.init_kv_caches(inputs_kv_shape=inputs_kv_shape)
@@ -478,6 +481,9 @@ def __init__(
     else:
       self.sinks = None
 
+    self.query_norm = None
+    self.key_norm = None
+
     is_llama4_decoder_block = self.config.decoder_block == DecoderBlockType.LLAMA4
     if self.use_qk_norm and not is_llama4_decoder_block:
       self.query_norm = RMSNorm(
@@ -498,9 +504,21 @@ def __init__(
           kernel_axes=("norm",),
           rngs=self.rngs,
       )
-    else:
-      self.query_norm = None
-      self.key_norm = None
+    elif self.is_qwen3_next:
+      self.query_norm = Qwen3NextRMSNorm(
+          num_features=self.config.head_dim,
+          eps=self.config.normalization_layer_epsilon,
+          dtype=self.config.dtype,
+          weight_dtype=self.config.weight_dtype,
+          rngs=self.rngs,
+      )
+      self.key_norm = Qwen3NextRMSNorm(
+          num_features=self.config.head_dim,
+          eps=self.config.normalization_layer_epsilon,
+          dtype=self.config.dtype,
+          weight_dtype=self.config.weight_dtype,
+          rngs=self.rngs,
+      )
 
     self._maybe_shard_with_logical = functools.partial(
         maybe_shard_with_logical,
@@ -538,9 +556,15 @@ def query_init(*args):
     kernel_axes = (
         (None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("embed", "q_heads", "kv")
     )
+    in_features = self.convert_dense_general_inputs_shape(inputs_q_shape)
+    out_features = (self.num_query_heads, self.head_dim)
+
+    if self.is_qwen3_next:
+      out_features = (self.num_query_heads, self.head_dim * 2)
+
     return DenseGeneral(
-        in_features_shape=self.convert_dense_general_inputs_shape(inputs_q_shape),
-        out_features_shape=(self.num_query_heads, self.head_dim),
+        in_features_shape=in_features,
+        out_features_shape=out_features,
         axis=-1,
         kernel_init=query_init,
         kernel_axes=kernel_axes,
@@ -642,13 +666,22 @@ def qkv_projection(self, inputs: Array, proj_name: str, out_sharding: NamedShard
 
   def init_out_w(self, output_dim: int) -> nnx.Module:
     """out projection"""
+    in_features = (self.num_query_heads, self.head_dim)
+    out_features = output_dim
     out_kernel_axis = (
         (None, None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("heads", "kv", "embed")
     )
+    axis = (-2, -1)
+
+    if self.is_qwen3_next:
+      in_features = self.num_query_heads * self.head_dim
+      out_kernel_axis = ("mlp", "embed")
+      axis = (-1,)
+
     return DenseGeneral(
-        in_features_shape=(self.num_query_heads, self.head_dim),
-        out_features_shape=output_dim,
-        axis=(-2, -1),
+        in_features_shape=in_features,
+        out_features_shape=out_features,
+        axis=axis,
         kernel_init=self.kernel_init,
         kernel_axes=out_kernel_axis,  # trade speed with memory
         dtype=self.dtype,
@@ -720,6 +753,16 @@ def init_rotary_embedding(self):
           attention_scaling=self.config.rope_attention_scaling,
           rngs=self.rngs,
       )
+    elif self.is_qwen3_next:
+      rotary_embedding = Qwen3NextRotaryEmbedding(
+          min_timescale=self.config.rope_min_timescale,
+          max_timescale=self.config.rope_max_timescale,
+          embedding_dims=self.config.head_dim,
+          partial_rotary_factor=self.config.partial_rotary_factor,
+          cast_as_fprop_dtype=True,
+          fprop_dtype=self.config.dtype,
+          rngs=self.rngs,
+      )
     else:
       max_timescale = self.config.rope_max_timescale
       # For local attention use local_rope_max_timescale if it's is positive
@@ -890,9 +933,17 @@ def __call__(
     value_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes(self.value_axis_names))
     value = self.kv_projection(inputs_kv, proj_name="value", out_sharding=value_sharding)
 
+    gate = None
+    if self.is_qwen3_next:
+      # Split query into query & gate.
+      query, gate = jnp.split(query, 2, axis=-1)
+      batch_size, seq_len, _, _ = gate.shape
+      gate = gate.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
+
     is_llama4_decoder_block = self.config.decoder_block == DecoderBlockType.LLAMA4
     # NOTE: llama 4 does L2 normalization after RoPE
-    if self.use_qk_norm and not is_llama4_decoder_block:
+    # Apply Qwen3Next specific RMS Norm
+    if (self.use_qk_norm and not is_llama4_decoder_block) or self.is_qwen3_next:
       query = self.query_norm(query)
       key = self.key_norm(key)
 
@@ -964,7 +1015,9 @@ def __call__(
         bidirectional_mask,
         self.sinks,
     )
-
+    if self.is_qwen3_next:
+      out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
+      out = out * jax.nn.sigmoid(gate)
     if model_mode == MODEL_MODE_PREFILL:
       out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names)
    elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
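
As a standalone illustration of the gated full-attention path added above: the query projection now produces 2 * head_dim features per head, which are split into query and gate, and after the attention op the output is multiplied by sigmoid(gate) before the output projection. The sketch below uses plain dot-product attention with made-up shapes and no causal mask; it is not the MaxText Attention API.

    import jax
    import jax.numpy as jnp

    B, S, H, D = 2, 8, 4, 16  # batch, seq, query heads, head_dim (illustrative values)
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

    qg = jax.random.normal(k1, (B, S, H, 2 * D))  # fused query+gate projection output
    k = jax.random.normal(k2, (B, S, H, D))
    v = jax.random.normal(k3, (B, S, H, D))

    query, gate = jnp.split(qg, 2, axis=-1)  # each [B, S, H, D]
    gate = gate.reshape(B, S, H * D)  # flattened to match the attention output layout

    # Plain dot-product attention as a stand-in for the attention op (no mask for brevity).
    logits = jnp.einsum("bqhd,bkhd->bhqk", query, k) / jnp.sqrt(D)
    out = jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(logits, axis=-1), v)

    # Qwen3-Next style output gating before the out projection.
    out = out.reshape(B, S, H * D) * jax.nn.sigmoid(gate)
    print(out.shape)  # (2, 8, 64)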

src/MaxText/layers/embeddings.py

Lines changed: 91 additions & 0 deletions

@@ -380,6 +380,97 @@ def llama_rotary_embedding_as_linen(
     )
 
 
+def qwen3_next_rotary_embedding_as_linen(
+    *,
+    min_timescale: int,
+    max_timescale: int,
+    embedding_dims: int = 0,
+    partial_rotary_factor: float = 0.25,
+    cast_as_fprop_dtype: bool = True,
+    fprop_dtype: DType = jnp.bfloat16,
+    name: str | None = None,
+):
+  """Initializes the Qwen3NextRotaryEmbedding module and returns it as a Linen module.
+
+  Args:
+    min_timescale: Start of the geometric index. Determines the periodicity of
+      the added signal.
+    max_timescale: End of the geometric index. Determines the frequency of the
+      added signal.
+    embedding_dims: Dimension of the embedding to be generated.
+    partial_rotary_factor: Ratio of dimensions to apply ROPE to.
+    cast_as_fprop_dtype: Whether to cast the output to the fprop dtype.
+    fprop_dtype: The dtype of the output.
+    name: Name of the Linen module.
+  """
+  return nnx_wrappers.to_linen(
+      Qwen3NextRotaryEmbedding,
+      min_timescale=min_timescale,
+      max_timescale=max_timescale,
+      embedding_dims=embedding_dims,
+      partial_rotary_factor=partial_rotary_factor,
+      cast_as_fprop_dtype=cast_as_fprop_dtype,
+      fprop_dtype=fprop_dtype,
+      metadata_fn=variable_to_logically_partitioned,
+      name=name,
+  )
+
+
+class Qwen3NextRotaryEmbedding(RotaryEmbedding):
+  """Qwen3 Next variant of ROPE (partial ROPE)."""
+
+  def __init__(
+      self,
+      min_timescale: int,
+      max_timescale: int,
+      embedding_dims: int = 0,
+      cast_as_fprop_dtype: bool = True,
+      fprop_dtype: DType = jnp.bfloat16,
+      partial_rotary_factor: float = 0.25,
+      rngs: nnx.Rngs = None,
+  ):
+    """Initializes the Qwen3NextRotaryEmbedding module.
+
+    Args:
+      min_timescale: Start of the geometric index. Determines the periodicity of
+        the added signal.
+      max_timescale: End of the geometric index. Determines the frequency of the
+        added signal.
+      embedding_dims: Dimension of the embedding to be generated.
+      partial_rotary_factor: Ratio of dimensions to apply ROPE to.
+      rngs: rng keys passed in by nnx.bridge.to_linen.
+    """
+    self.head_dim = embedding_dims
+    self.partial_rotary_factor = partial_rotary_factor
+    self.rotary_dim = int(self.head_dim * self.partial_rotary_factor)
+
+    super().__init__(
+        min_timescale=min_timescale,
+        max_timescale=max_timescale,
+        embedding_dims=self.rotary_dim,
+        cast_as_fprop_dtype=cast_as_fprop_dtype,
+        fprop_dtype=fprop_dtype,
+        rngs=rngs,
+    )
+
+  def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax.Array:
+    """Applies the Qwen3-Next variant of rotary position embedding (partial RoPE).
+
+    Args:
+      inputs: The input sequence on which to apply the Rotary position
+        embedding. It is assumed of shape [B, S, H, D].
+      position: Optional position array [B, S]. Only needed when the sequence
+        is packed.
+
+    Returns:
+      A jax.Array of shape [B, S, H, D] with rotary position embeddings applied
+      to the first rotary_dim features.
+    """
+    inputs_rot, inputs_pass = jnp.split(inputs, [self.rotary_dim], axis=-1)
+    inputs_rot = super().__call__(inputs_rot, position)
+    inputs = jnp.concatenate([inputs_rot, inputs_pass], axis=-1)
+    return inputs
+
+
 class LLaMARotaryEmbedding(RotaryEmbedding):
   """LLaMA variant of ROPE."""
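
A self-contained sketch of the partial-RoPE behavior implemented by Qwen3NextRotaryEmbedding above: only the first rotary_dim features of each head are rotated, and the remainder passes through unchanged. The rotation below is a generic half-split RoPE written purely for illustration; in MaxText the parent RotaryEmbedding performs the actual rotation.

    import jax.numpy as jnp

    def toy_rope(x, positions, max_timescale=10_000_000):
      """Generic half-split RoPE over the last axis of x: [B, S, H, R] (illustrative only)."""
      r = x.shape[-1]
      freqs = 1.0 / (max_timescale ** (jnp.arange(0, r, 2) / r))  # [R/2]
      angles = positions[..., None, None] * freqs  # [B, S, 1, R/2]
      x1, x2 = jnp.split(x, 2, axis=-1)
      return jnp.concatenate(
          [x1 * jnp.cos(angles) - x2 * jnp.sin(angles),
           x2 * jnp.cos(angles) + x1 * jnp.sin(angles)], axis=-1)

    def partial_rope(inputs, positions, rotary_dim):
      """Rotate only the first rotary_dim features; concatenate the untouched remainder."""
      rot, passthrough = jnp.split(inputs, [rotary_dim], axis=-1)
      return jnp.concatenate([toy_rope(rot, positions), passthrough], axis=-1)

    x = jnp.ones((1, 4, 2, 256))  # [B, S, H, head_dim]
    pos = jnp.arange(4)[None, :]  # [B, S]
    print(partial_rope(x, pos, rotary_dim=int(256 * 0.25)).shape)  # (1, 4, 2, 256)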

src/MaxText/layers/normalizations.py

Lines changed: 80 additions & 2 deletions

@@ -18,6 +18,7 @@
 
 from flax import linen as nn
 from flax import nnx
+from flax.linen import initializers as linen_initializers
 from jax import lax
 import jax
 import jax.numpy as jnp
@@ -26,7 +27,7 @@
 from MaxText import max_utils
 from MaxText.layers import nnx_wrappers
 from MaxText.layers.initializers import Initializer, variable_to_logically_partitioned
-from MaxText.common_types import Array, ShardMode
+from MaxText.common_types import Array, DType, ShardMode
 
 
 class RMSNorm(nnx.Module):
@@ -42,6 +43,7 @@ def __init__(
       kernel_axes: tuple[None | str, ...] = (),
       scale_init: Initializer = nn.initializers.ones,
      parameter_memory_host_offload: bool = False,
+      scale_offset: float = 0.0,
      *,
      rngs: nnx.Rngs,
  ):
@@ -53,6 +55,7 @@ def __init__(
     self.kernel_axes = kernel_axes
     self.scale_init = scale_init
     self.parameter_memory_host_offload = parameter_memory_host_offload
+    self.scale_offset = scale_offset
     self.scale = nnx.Param(
         scale_init(rngs.params(), (num_features,), weight_dtype),
         sharding=kernel_axes,
@@ -73,8 +76,83 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
       out_sharding = None
 
     scale = jnp.asarray(scale, self.dtype)
+    effective_scale = scale + self.scale_offset  # Apply offset
     # broadcast 2nd input then element-wise mul
-    return jnp.einsum("i...k,...k->i...k", y, scale, out_sharding=out_sharding)
+    return jnp.einsum("i...k,...k->i...k", y, effective_scale, out_sharding=out_sharding)
+
+
+def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
+  """
+  Used for input and post attention layernorms
+  in Qwen3NextDecoderLayer.
+
+  This normalization layer is specific to Qwen3-Next. Key characteristics:
+  1. The learnable scale parameter `scale` is initialized to ZEROS.
+  2. The scale is applied as `(1.0 + self.scale)`, making the initial scale effectively 1.0.
+  This matches the PyTorch implementation of Qwen3NextRMSNorm.
+  """
+  return nnx.data(
+      RMSNorm(
+          num_features=num_features,
+          epsilon=eps,
+          dtype=dtype,
+          weight_dtype=weight_dtype,
+          scale_init=linen_initializers.zeros,
+          scale_offset=1.0,
+          rngs=rngs,
+      )
+  )
+
+
+class Qwen3NextRMSNormGated(nnx.Module):
+  """
+  This applies RMS Normalization and then a gated activation function (SiLU).
+  This is used within the Qwen3NextGatedDeltaNet.
+
+  The normalization is performed by an internal `RMSNorm` instance (`self.rms_norm`),
+  which has its own learnable `scale` parameter, initialized to ONES.
+
+  Attributes:
+    num_features: The number of features in the input.
+    eps: A small epsilon value to prevent division by zero in RMSNorm.
+    dtype: The datatype of the computation.
+    weight_dtype: The datatype of the internal RMSNorm scale.
+  """
+
+  def __init__(self, num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
+    self.num_features = num_features
+    self.eps = eps
+    self.dtype = dtype
+    self.weight_dtype = weight_dtype
+    self.rms_norm = nnx.data(
+        RMSNorm(
+            num_features=num_features,
+            epsilon=eps,
+            dtype=dtype,
+            weight_dtype=weight_dtype,
+            scale_init=nnx.initializers.ones,
+            rngs=rngs,
+        )
+    )
+
+  def __call__(self, hidden_states: Array, gate: Array) -> Array:
+    """
+    Applies RMSNorm and then a SiLU gate.
+
+    Args:
+      hidden_states: The input array to be normalized (o). Shape: (..., F)
+      gate: The gating array for the activation (z). Shape: (..., F)
+        where F is num_features.
+
+    Returns:
+      The normalized and gated output array. Shape: (..., F)
+    """
+    normalized_states = self.rms_norm(hidden_states)
+
+    # Gated Activation using SiLU (Sigmoid-weighted Linear Unit)
+    gated_states = normalized_states * jax.nn.silu(gate.astype(jnp.float32))
+
+    return gated_states.astype(self.dtype)
 
 
 def rms_norm(
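
A small functional sketch of the two normalization variants added above, written against plain jax.numpy rather than the MaxText RMSNorm class: Qwen3NextRMSNorm uses a zero-initialized scale with a scale_offset of 1.0 (so the effective scale starts at 1.0, matching the PyTorch (1 + weight) formulation), and Qwen3NextRMSNormGated multiplies the normalized activations by silu(gate). Function names, shapes, and the epsilon value here are illustrative assumptions.

    import jax
    import jax.numpy as jnp

    def rms_normalize(x, eps=1e-6):
      """Core RMS normalization, no scale applied."""
      return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)

    def toy_qwen3_next_rmsnorm(x, scale, eps=1e-6):
      """Zero-initialized scale plus offset 1.0, i.e. (1 + scale) * normalized(x)."""
      return rms_normalize(x, eps) * (scale + 1.0)

    def toy_qwen3_next_rmsnorm_gated(x, scale, gate, eps=1e-6):
      """RMSNorm (ones-initialized scale) followed by a SiLU gate."""
      return rms_normalize(x, eps) * scale * jax.nn.silu(gate)

    features = 8
    x = jax.random.normal(jax.random.PRNGKey(0), (2, features))
    gate = jax.random.normal(jax.random.PRNGKey(1), (2, features))

    zero_scale = jnp.zeros((features,))  # Qwen3NextRMSNorm initialization
    ones_scale = jnp.ones((features,))   # Qwen3NextRMSNormGated initialization
    print(toy_qwen3_next_rmsnorm(x, zero_scale).shape)              # (2, 8)
    print(toy_qwen3_next_rmsnorm_gated(x, ones_scale, gate).shape)  # (2, 8)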
