37 changes: 20 additions & 17 deletions src/MaxText/layers/deepseek.py
@@ -15,18 +15,29 @@
"""Transformer model definition."""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module

from typing import Optional
from flax import linen as nn
from flax import nnx

from jax.ad_checkpoint import checkpoint_name
import jax.numpy as jnp
from jax.sharding import Mesh
import jax.numpy as jnp

from flax import linen as nn
from flax import nnx

from MaxText import max_utils
from MaxText.common_types import Config, MODEL_MODE_PREFILL
from MaxText.common_types import Config
from MaxText.common_types import MODEL_MODE_PREFILL
from MaxText.inference import page_manager
from MaxText.layers import attention_mla, initializers, linears, moe, nnx_wrappers, quantizations
from MaxText.layers import attention_mla
from MaxText.layers import initializers
from MaxText.layers import linears
from MaxText.layers import moe
from MaxText.layers import nnx_wrappers
from MaxText.layers import quantizations
from MaxText.layers.linears import Dropout
from MaxText.layers.normalizations import RMSNorm

# -----------------------------------------
# The Decoder Layer for DeepSeek v3
# -----------------------------------------
@@ -54,9 +65,7 @@ def __init__(
self.quant = quant
self.rngs = rngs

batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(
self.config, self.model_mode
)
batch_size, sequence_length = max_utils.get_batch_seq_len_for_mode(self.config, self.model_mode)
self.dummy_inputs_shape = (batch_size, sequence_length, self.config.emb_dim)

self.pre_self_attention_layer_norm = RMSNorm(
@@ -108,9 +117,7 @@ def __init__(
rngs=rngs,
)

self.dropout = Dropout(
rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs
)
self.dropout = Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs)

def __call__(
self,
@@ -151,9 +158,7 @@ def with_logical_constraint(self, x):
return nn.with_logical_constraint(x, self.logical_axis_names)

def dropout_op(self, x, deterministic):
return self.with_logical_constraint(
self.dropout(x, deterministic=deterministic)
)
return self.with_logical_constraint(self.dropout(x, deterministic=deterministic))

def pre_attention_norm_op(self, x):
return self.with_logical_constraint(self.pre_self_attention_layer_norm(x))
@@ -300,9 +305,7 @@ def __init__(
self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE(
config=self.config,
mesh=mesh,
kernel_init=initializers.nd_dense_init(
1.0, "fan_in", "truncated_normal"
),
kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
kernel_axes=("embed", None),
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
11 changes: 7 additions & 4 deletions src/MaxText/layerwise_quantization.py
@@ -39,7 +39,7 @@
from absl import app

from flax.linen import partitioning as nn_partitioning

from flax import nnx

from MaxText import checkpointing
from MaxText import max_utils
@@ -93,9 +93,12 @@ def load_and_quantize(self, rng: None | PRNGKeyType = None) -> None:

self.quant.quant_mode = quantizations.get_quant_mode("convert")

model_mode = common_types.MODEL_MODE_PREFILL
rngs = nnx.Rngs(0)

layers = [
deepseek.DeepSeekDenseLayer(config, mesh=self._mesh, quant=self.quant),
deepseek.DeepSeekMoELayer(config, mesh=self._mesh, quant=self.quant),
deepseek.DeepSeekDenseLayerToLinen(config, mesh=self._mesh, quant=self.quant, model_mode=model_mode, rngs=rngs),
deepseek.DeepSeekMoELayerToLinen(config, mesh=self._mesh, quant=self.quant, model_mode=model_mode, rngs=rngs),
Comment on lines -97 to +101
Collaborator:

Why does this need to be changed here?

Contributor Author (@mesakhcienet, Nov 7, 2025):

If we use the nnx modules DeepSeekDenseLayer or DeepSeekMoELayer instead of the Linen-converted DeepSeekDenseLayerToLinen or DeepSeekMoELayerToLinen, these parts fail. Apologies, I didn't screenshot the error, but I remember it came from the unit test.

Also, when calling this function, I believe the model should be Linen rather than nnx, since .apply is not a method of an nnx module.
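
(For illustration only: a minimal sketch of the Linen-vs-nnx distinction discussed above, using generic Flax Dense/Linear layers rather than the MaxText DeepSeek layers. It only shows why a code path built around .apply(...) expects a Linen module; it is not the quantization code itself.)

```python
# Minimal sketch: why code that drives a layer through `.apply(...)` needs a
# Linen module. Generic Flax layers are used here, not the MaxText DeepSeek layers.
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx

x = jnp.ones((1, 4))

# Linen: parameters live in an external pytree and are passed to `.apply`.
linen_layer = nn.Dense(features=8)
variables = linen_layer.init(jax.random.PRNGKey(0), x)
y_linen = linen_layer.apply(variables, x)  # `.apply` is the Linen entry point

# nnx: parameters live on the module instance, which is called directly,
# so there is no `.apply` method for the quantization loop to drive.
nnx_layer = nnx.Linear(4, 8, rngs=nnx.Rngs(0))
y_nnx = nnx_layer(x)
```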

]
layer_prefixes = ["dense_layers", "moe_layers"]
num_moe_layers = config.num_decoder_layers - config.first_num_dense_layers
@@ -108,7 +111,7 @@ def model_apply(_p, _rng, layer):
None,
jnp.zeros((1, self.config.max_prefill_predict_length), dtype=jnp.int32),
True,
model_mode=common_types.MODEL_MODE_PREFILL,
model_mode=model_mode,
rngs={"params": _rng},
mutable=True,
)
10 changes: 3 additions & 7 deletions tests/pipeline_parallelism_test.py
@@ -69,13 +69,9 @@ def assert_pipeline_same_output_and_grad(self, config, single_pipeline_stage_cla
else:
if issubclass(single_pipeline_stage_class, nnx_wrappers.ToLinen):
rngs = nnx.Rngs(params=0)
single_pipeline_stage = single_pipeline_stage_class(
config=config, mesh=mesh, model_mode=model_mode, rngs=rngs
)
single_pipeline_stage = single_pipeline_stage_class(config=config, mesh=mesh, model_mode=model_mode, rngs=rngs)
else:
single_pipeline_stage = single_pipeline_stage_class(
config=config, mesh=mesh, model_mode=model_mode
)
single_pipeline_stage = single_pipeline_stage_class(config=config, mesh=mesh, model_mode=model_mode)

def get_inputs(batch_size, sequence, features):
"""Get random inputs, and random dummy targets
@@ -238,7 +234,7 @@ def test_circular_deepseek_megablox_same_output_and_grad(self):
capacity_factor=1,
decoder_block="deepseek",
)
self.assert_pipeline_same_output_and_grad(config, single_pipeline_stage_class=deepseek.DeepSeekMoELayer)
self.assert_pipeline_same_output_and_grad(config, single_pipeline_stage_class=deepseek.DeepSeekMoELayerToLinen)

@pytest.mark.tpu_only
def test_circular_ag_once(self):