Skip to content

Commit d8b3334

Browse files
committed
Store fused MoE wi weight as (G,K,2N) when fused_mlp=True
When fused_mlp is enabled, initialize self.wi as a single (G,K,2N) parameter instead of two separate wi_0/wi_1 (G,K,N) tensors. This loads expert weights from HBM once per forward pass; the concat in sparse_matmul becomes a view of adjacent slices that XLA elides.
1 parent 84b4290 commit d8b3334

1 file changed

Lines changed: 7 additions & 2 deletions

File tree

src/maxtext/layers/moe.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -410,7 +410,7 @@ def __init__(
410410
self.wi_0 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
411411
self.wi_1 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
412412
self.wo = jnp.zeros((num_experts, intermediate_dim, self.moe_expert_input_dim))
413-
elif self.config.prefuse_moe_weights and self.config.attention == "vllm_rpa":
413+
elif (self.config.prefuse_moe_weights and self.config.attention == "vllm_rpa") or self.config.fused_mlp:
414414
self.wi = nnx.Param(
415415
self.kernel_init(
416416
self.rngs.params(),
@@ -1319,7 +1319,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
13191319
)
13201320

13211321
if self.config.fused_mlp:
1322-
# Fuse wi_0 and wi_1: [G,K,N] + [G,K,N] -> [G,K,2N], one GEMM, split result.
1322+
# Weights are stored as (G,K,2N); w0/w1 are adjacent slices so XLA elides this concat.
13231323
w_fused = jnp.concatenate([w0, w1], axis=-1)
13241324
out = gmm_fn(x, w_fused, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes)
13251325
n = w0.shape[-1]
@@ -2159,6 +2159,11 @@ def __call__(
21592159
w1_kernel = None
21602160
if cfg.prefuse_moe_weights and cfg.attention == "vllm_rpa":
21612161
fused_kernel = jnp.asarray(self.wi[...], self.dtype)
2162+
elif cfg.fused_mlp:
2163+
wi = jnp.asarray(self.wi[...], self.dtype)
2164+
n = wi.shape[-1] // 2
2165+
w0_kernel = wi[..., :n]
2166+
w1_kernel = wi[..., n:]
21622167
else:
21632168
w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
21642169
w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)

0 commit comments

Comments (0)