Skip to content

Commit edbbf29

Browse files
Migrate Decoder (Gemma3/Deepseek/Llama4) and utils to NNX
1 parent efce27d commit edbbf29

File tree

6 files changed

+212
-33
lines changed

6 files changed

+212
-33
lines changed

src/MaxText/configs/base.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1035,7 +1035,7 @@ subslice_shape: ""
10351035

10361036
# NNX
10371037
enable_nnx: false
1038-
pure_nnx_decoder: false
1038+
pure_nnx_decoder: true
10391039

10401040
################################## Qwen3-Next Specific Configs ##################################
10411041
# Kernel size for the 1D convolution in the Gated Delta Net

src/MaxText/layers/gemma3.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ def __init__(
9191

9292
batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode)
9393
dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)
94-
9594
self.pre_self_attention_norm = RMSNorm(
9695
num_features=config.emb_dim,
9796
dtype=config.dtype,
@@ -198,7 +197,6 @@ def __call__(
198197
inputs = inputs[0]
199198
inputs = nn.with_logical_constraint(inputs, self.activation_axis_names)
200199
inputs = checkpoint_name(inputs, "decoder_layer_input")
201-
202200
lnx = self.pre_self_attention_norm(inputs)
203201
lnx = nn.with_logical_constraint(lnx, self.activation_axis_names)
204202

src/MaxText/layers/nnx_decoders.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@
3333
from MaxText.common_types import DecoderBlockType, ShardMode, Config, EP_AS_CONTEXT
3434
from MaxText.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
3535
from MaxText.sharding import create_sharding
36+
from maxtext.inference import page_manager
3637
from MaxText.layers import linears
3738
from MaxText.layers import initializers
3839
from MaxText.layers import quantizations
39-
from MaxText import multimodal_utils
4040
from MaxText import sharding
4141
from MaxText.layers.attentions import Attention
4242
from MaxText.layers.normalizations import RMSNorm
@@ -61,6 +61,7 @@
6161
from maxtext.inference import page_manager
6262
from maxtext.utils import max_logging
6363
from maxtext.utils import maxtext_utils
64+
from maxtext.multimodal import utils as mm_utils
6465

6566
# ------------------------------------------------------------------------------
6667
# The network: Decoder Definitions
@@ -284,19 +285,28 @@ def __init__(
284285
attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
285286
scan_length = config.num_decoder_layers // attention_pattern_length
286287
num_remaining_layers = config.num_decoder_layers % attention_pattern_length
288+
layer_kwargs = {"num_of_layers": attention_pattern_length}
289+
287290
rem_layer_kwargs = {"num_of_layers": num_remaining_layers}
288291

289292
RemattedGemma3Block = gemma3.Gemma3ScannableBlock
290293

291294
if scan_length > 0:
292-
self.layers = self._create_scanned_layers(RemattedGemma3Block, length=scan_length, rngs=rngs)
295+
self.layers = self._create_scanned_layers(RemattedGemma3Block, length=scan_length, rngs=rngs, **layer_kwargs)
293296
self.layers_remainder = RemattedGemma3Block(
294297
config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs
295298
) # pytype: disable=wrong-keyword-args
296299
else:
297300
layer_cls = decoder_block_classes[0]
298-
num_layers = config.num_decoder_layers
299-
self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs)
301+
num_layers = int(config.num_decoder_layers / config.inhomogeneous_layer_cycle_interval)
302+
layer_kwargs = {}
303+
if config.decoder_block == DecoderBlockType.LLAMA4:
304+
layer_kwargs = {
305+
"nope_layer_interval": self.config.nope_layer_interval,
306+
"interleave_moe_layer_step": self.config.interleave_moe_layer_step,
307+
}
308+
309+
self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs, **layer_kwargs)
300310
else:
301311
self.layers = nnx.List([])
302312
if self.is_deepseek:
@@ -366,7 +376,6 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
366376

367377
layer_cls = layers.__class__ # Access the underlying class
368378
sig = inspect.signature(layer_cls.__call__)
369-
370379
# Filter kwargs to only include keys that exist in the layer's signature
371380
valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
372381

@@ -584,7 +593,7 @@ def _apply_embedding(
584593
"llama4-17b-128e",
585594
"qwen3-omni-30b-a3b",
586595
]:
587-
y = multimodal_utils.merge_mm_embeddings(
596+
y = mm_utils.merge_mm_embeddings(
588597
text_embeddings=y,
589598
multimodal_embeddings=image_embeddings,
590599
mask=bidirectional_mask,
@@ -596,7 +605,7 @@ def _apply_embedding(
596605

597606
if audio_embeddings is not None and cfg.use_audio:
598607
if cfg.model_name in ["qwen3-omni-30b-a3b"]:
599-
y = multimodal_utils.merge_mm_embeddings(
608+
y = mm_utils.merge_mm_embeddings(
600609
text_embeddings=y,
601610
multimodal_embeddings=audio_embeddings,
602611
mask=audio_masks,
@@ -698,7 +707,6 @@ def __call__(
698707
"previous_chunk": previous_chunk,
699708
"page_state": page_state,
700709
"slot": slot,
701-
"attention_metadata": attention_metadata,
702710
}
703711

704712
if cfg.decoder_block == DecoderBlockType.GEMMA3:
@@ -775,17 +783,12 @@ def _apply_gemma3_scanned_blocks(
775783
attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
776784
scan_length = cfg.num_decoder_layers // attention_pattern_length
777785

778-
layer_call_kwargs = {"bidirectional_mask": bidirectional_mask}
786+
layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode)
787+
layer_kwargs = {"bidirectional_mask": bidirectional_mask}
779788

780789
# Apply the main scan over the full blocks
781790
if scan_length > 0:
782-
broadcast_args = (
783-
decoder_segment_ids,
784-
decoder_positions,
785-
deterministic,
786-
model_mode,
787-
)
788-
y, _ = self.layers(y, *broadcast_args, **layer_call_kwargs)
791+
y, _ = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs)
789792

790793
# Apply any remaining layers that did not fit into a full scanned block
791794
num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length
@@ -800,8 +803,9 @@ def _apply_gemma3_scanned_blocks(
800803
previous_chunk=previous_chunk,
801804
page_state=page_state,
802805
slot=slot,
803-
**layer_call_kwargs,
806+
**layer_kwargs,
804807
)
808+
805809
return y
806810

807811

src/MaxText/utils/ckpt_conversion/to_maxtext.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -385,14 +385,22 @@ def _build_single_axis_stacked_tensor(
385385
The final, assembled NumPy array for the MaxText parameter.
386386
"""
387387
tensors_to_stack = []
388+
# Heuristic to determine if we are stacking layers or experts.
389+
# If the number of items to stack equals the number of layers, it's a standard
390+
# scanned layer, and we use the configured param_scan_axis. Otherwise, it's
391+
# an unscanned MoE layer, and we stack along the expert axis (0).
392+
"""
393+
axis_to_stack = config.param_scan_axis if len(hf_source_keys) == config.base_num_decoder_layers else 0
394+
"""
388395

389-
if config.scan_layers:
390-
# If it's a standard scanned layer, we use the configured param_scan_axis.
391-
axis_to_stack = config.param_scan_axis
396+
# Workaround to load the HF model due to mismatched tensor ordering
397+
if len(hf_source_keys) == config.base_num_decoder_layers:
398+
if getattr(config, "enable_nnx", False):
399+
axis_to_stack = 0
400+
else:
401+
axis_to_stack = config.param_scan_axis
392402
else:
393-
# Otherwise, if an unscanned MoE layer, and we stack along the expert axis (0).
394403
axis_to_stack = 0
395-
396404
# The hook function needs the shape of an individual slice, not the full stacked tensor.
397405
# We calculate it by removing the stacking dimension from the final target shape.
398406
mt_slice_shape_list = list(target_shape)

tests/checkpoint_compare.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import jax
2+
import jax.numpy as jnp
3+
import orbax.checkpoint as ocp
4+
import numpy as np
5+
from typing import Any, Dict, Sequence, Tuple
6+
from jax.tree_util import tree_flatten_with_path, keystr, tree_structure, tree_map_with_path
7+
from absl import app
8+
from absl import flags
9+
10+
11+
# Command-line flags: paths to the Orbax checkpoint "items" directories of the
# two checkpoints being compared (Linen baseline vs. migrated NNX model).
_LINEN_CKPT_PATH = flags.DEFINE_string(
    "linen_ckpt_path", None, "Path to the Linen model checkpoint items directory.", required=True
)
_NNX_CKPT_PATH = flags.DEFINE_string(
    "nnx_ckpt_path", None, "Path to the NNX model checkpoint items directory.", required=True
)
17+
18+
19+
def load_checkpoint_params(path: str) -> Dict[str, Any]:
  """Restore an Orbax checkpoint and return its parameter pytree.

  If the restored object is a dict that wraps the weights under a
  "params" key, the inner tree is returned; otherwise the restored
  tree itself is returned.

  Raises:
    ValueError: if Orbax returns None for the given path.
  """
  print(f"Loading checkpoint from: {path}")
  restored = ocp.PyTreeCheckpointer().restore(path)
  if restored is None:
    raise ValueError(f"Failed to restore checkpoint from {path}")
  if isinstance(restored, dict) and "params" in restored:
    restored = restored["params"]
  return restored
29+
30+
31+
def transform_nnx_params(nnx_params: Dict[str, Any]) -> Dict[str, Any]:
  """Return a copy of the NNX pytree with stacked-layer arrays transposed.

  Every array whose key path contains "layers" and that has rank >= 2 gets
  its first two axes swapped (remaining axes untouched); all other leaves
  pass through unchanged. Each decision is logged to stdout.
  """

  def _maybe_transpose(path, leaf):
    key_str = keystr(path)
    if "layers" in key_str and hasattr(leaf, "ndim") and leaf.ndim >= 2:
      print(f"TRANSPOSING: {key_str} with shape {leaf.shape}")
      perm = (1, 0) + tuple(range(2, leaf.ndim))
      return jnp.transpose(leaf, axes=perm)
    # Leaf is left as-is; explain why in the log.
    if "token_embedder" in key_str:
      print(f"SKIPPING Transpose: {key_str} because it is token_embedder")
    else:
      shape = getattr(leaf, "shape", "N/A")
      print(f"SKIPPING Transpose: {key_str} with shape {shape} (ndim < 2)")
    return leaf

  print("Applying transformations to NNX params...")
  return tree_map_with_path(_maybe_transpose, nnx_params)
51+
52+
53+
def get_tree_structure_info(tree: Dict[str, Any]):
  """Map each leaf's key-path string to a (shape, dtype-name) pair.

  Leaves without a ``shape``/``dtype`` attribute report "N/A" and their
  Python type name instead. Helper only used if structures differ.
  """
  info = {}
  for path, leaf in tree_flatten_with_path(tree)[0]:
    shape = getattr(leaf, "shape", "N/A")
    dtype = str(getattr(leaf, "dtype", type(leaf).__name__))
    info[keystr(path)] = (shape, dtype)
  return info
57+
58+
59+
def print_structure_diff(params1, params2):
  """Report leaf key-paths present in one tree but not the other.

  params1 is treated as the Linen tree, params2 as the NNX tree.
  """
  linen_keys = set(get_tree_structure_info(params1))
  nnx_keys = set(get_tree_structure_info(params2))

  for key in sorted(nnx_keys - linen_keys):
    print(f" + Added in NNX: {key}")
  for key in sorted(linen_keys - nnx_keys):
    print(f" - Missing in NNX: {key}")
69+
70+
71+
def compare_params(params1: Dict[str, Any], params2: Dict[str, Any]) -> bool:
  """Compare two parameter pytrees leaf-by-leaf with a per-leaf report.

  Returns True only when the tree structures match and every leaf pair
  agrees in shape, dtype, and value (within jnp.allclose tolerance).
  A structure mismatch short-circuits and prints the key-level diff.
  """
  if tree_structure(params1) != tree_structure(params2):
    print("[] Tree structures differ.")
    print_structure_diff(params1, params2)
    return False

  print("[] Tree structures are the same.")

  # Collect the key paths of mismatching leaves; empty list means all match.
  mismatches = []

  def _check_leaf(path, linen_leaf, nnx_leaf):
    key_str = keystr(path)
    try:
      linen_shape = getattr(linen_leaf, "shape", "N/A")
      nnx_shape = getattr(nnx_leaf, "shape", "N/A")
      if linen_shape != nnx_shape:
        print(f"[{key_str}] SHAPE MISMATCH: {linen_shape} vs {nnx_shape}")
        mismatches.append(key_str)
        return

      linen_dtype = getattr(linen_leaf, "dtype", type(linen_leaf))
      nnx_dtype = getattr(nnx_leaf, "dtype", type(nnx_leaf))
      if linen_dtype != nnx_dtype:
        print(f"[{key_str}] DTYPE MISMATCH: {linen_dtype} vs {nnx_dtype}")
        mismatches.append(key_str)
        return

      abs_diff = jnp.abs(linen_leaf - nnx_leaf)
      mean_diff = float(jnp.mean(abs_diff))
      max_diff = float(jnp.max(abs_diff))
      is_close = bool(jnp.allclose(linen_leaf, nnx_leaf))

      print(
          f"[{key_str}] "
          f"Shape(Linen/NNX): {linen_shape} / {nnx_shape} — "
          f"Mean abs diff: {mean_diff:.2e}, "
          f"Max abs diff: {max_diff:.2e}, "
          f"AllClose: {is_close}"
      )
      if not is_close:
        mismatches.append(key_str)

    except Exception as e:  # keep scanning remaining leaves on a per-leaf failure
      print(f"[{key_str}] Error during comparison: {e}")
      mismatches.append(key_str)

  tree_map_with_path(_check_leaf, params1, params2)

  return not mismatches
130+
131+
132+
def main(argv: Sequence[str]):
  """Load both checkpoints, align the NNX layout, and report differences."""
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  linen_path = _LINEN_CKPT_PATH.value
  nnx_path = _NNX_CKPT_PATH.value
  print(f"Linen Checkpoint Path: {linen_path}")
  print(f"NNX Checkpoint Path: {nnx_path}")

  print("Loading Linen params...")
  linen_params = load_checkpoint_params(linen_path)
  print("Loading NNX params...")
  nnx_params = load_checkpoint_params(nnx_path)

  # Guard clause: bail out early if either checkpoint failed to load.
  if linen_params is None or nnx_params is None:
    print("Failed to load params from one or both checkpoints.")
    return

  # Transpose the NNX stacked-layer arrays so both trees share one layout.
  transformed = transform_nnx_params(nnx_params)

  print("\nComparing Linen params with Transformed NNX params...")
  if compare_params(linen_params, transformed):
    print("\nCheckpoints are considered the same (within np.allclose tolerance) after transformation!")
  else:
    print("\nCheckpoints DIFFER after transformation.")
157+
158+
159+
if __name__ == "__main__":
  # absl entry point: app.run parses the flags defined above, then calls main(argv).
  app.run(main)

tests/unit/multi_token_prediction_test.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,23 @@ def setUp(self):
5555
devices_array = maxtext_utils.create_device_mesh(self.cfg)
5656
self.mesh = Mesh(devices_array, self.cfg.mesh_axes)
5757

58-
# Instantiate the Layer
59-
self.mtp_layer = multi_token_prediction.MultiTokenPredictionLayer(
60-
config=self.cfg,
61-
mesh=self.mesh,
62-
layer_number=TEST_LAYER_NUM,
63-
transformer_layer_module=DecoderLayer,
64-
rngs=self.rngs,
65-
)
58+
if self.cfg.pure_nnx:
59+
# Instantiate the Layer
60+
self.mtp_layer = multi_token_prediction.MultiTokenPredictionLayer(
61+
config=self.cfg,
62+
mesh=self.mesh,
63+
layer_number=TEST_LAYER_NUM,
64+
transformer_layer_module=DecoderLayer,
65+
rngs=self.rngs,
66+
)
67+
else:
68+
# Instantiate the Layer
69+
self.mtp_layer = multi_token_prediction.MultiTokenPredictionLayerLinen(
70+
config=self.cfg,
71+
mesh=self.mesh,
72+
layer_number=TEST_LAYER_NUM,
73+
transformer_layer_module=DecoderLayer,
74+
)
6675

6776
# Dimensions directly from the config object
6877
self.batch_size = int(self.cfg.per_device_batch_size)

0 commit comments

Comments
 (0)