
Commit 134d7b4

Add production-safe local MoE MLP backends
1 parent 6b640d9 commit 134d7b4

4 files changed

Lines changed: 735 additions & 34 deletions

File tree

experiments/grug/moe/model.py

Lines changed: 20 additions & 19 deletions
@@ -26,7 +26,12 @@
 from jax.experimental.shard_map import shard_map
 from jaxtyping import Array, Float, Int, PRNGKeyArray
 from levanter.grug.attention import AttentionMask, RotaryConfig, align_kv_heads, apply_rotary_embedding, attention
-from levanter.grug.grug_moe import MoeActivation, MoeImplementation, moe_mlp
+from levanter.grug.grug_moe import (
+    MoeActivation,
+    MoEExpertMlp,
+    MoeImplementation,
+    resolve_moe_implementation,
+)
 from levanter.grug.loss import fused_linear_softmax_cross_entropy_loss
 from levanter.grug.sharding import Pembed_vocab, Plm_head, unshard
 from levanter.tracker.histogram import Histogram
@@ -89,6 +94,7 @@ def __post_init__(self) -> None:
             raise ValueError("num_experts_per_token must be <= num_experts")
         if self.shared_expert_intermediate_dim < 0:
             raise ValueError("shared_expert_intermediate_dim must be non-negative")
+        resolve_moe_implementation(self.moe_implementation)
 
     @property
     def inferred_head_dim(self) -> int:
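Context for the new `resolve_moe_implementation` call above: validating the backend in `__post_init__` makes a bad `moe_implementation` fail when the config is built rather than deep inside the first training step. A minimal sketch of that fail-fast pattern, with hypothetical enum members (the real resolver lives in `levanter.grug.grug_moe` and is not shown in this diff):

from enum import Enum

class MoeImplementationSketch(str, Enum):
    # Hypothetical members for illustration; the real MoeImplementation may differ.
    REFERENCE = "reference"
    LOCAL = "local"

def resolve_moe_implementation_sketch(impl):
    # Normalize strings to enum members; unknown values raise here, at config
    # construction time, instead of mid-training.
    try:
        return MoeImplementationSketch(impl)
    except ValueError:
        raise ValueError(f"unknown moe_implementation: {impl!r}") from None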
@@ -312,32 +318,33 @@ class MoEMLP(eqx.Module):
 
     router: jax.Array
     router_bias: jax.Array
-    w_gate_up: jax.Array
-    w_down: jax.Array
+    expert_mlp: MoEExpertMlp
     cfg: GrugModelConfig = eqx.field(static=True)
 
     @staticmethod
     def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "MoEMLP":
-        k_router, k_gate, k_up, k_down = random.split(key, 4)
+        k_router, k_expert_mlp = random.split(key, 2)
         mesh = get_abstract_mesh()
 
         expert_axis_size = _mesh_axis_size(mesh, "expert")
         if cfg.num_experts % expert_axis_size != 0:
             raise ValueError(f"num_experts={cfg.num_experts} must be divisible by expert axis size={expert_axis_size}")
 
         d, e, i = cfg.hidden_dim, cfg.num_experts, cfg.intermediate_dim
-        w_gate = _init_weight(k_gate, (e, d, i), cfg.initializer_std)
-        w_up = _init_weight(k_up, (e, d, i), cfg.initializer_std)
-        # TODO: Explore whether concatenating gate/up at init (instead of keeping separate params)
-        # is (1) a meaningful MFU speedup and (2) a meaningful perf hit due to AdamH treating the
-        # concatenated tensor as a single parameter for its scale-invariant norm computation.
-        w_gate_up = jnp.concatenate([w_gate, w_up], axis=-1)
 
         return MoEMLP(
             router=reshard(_init_weight(k_router, (d, e), cfg.initializer_std), P(None, None)),
             router_bias=jnp.zeros((e,)),
-            w_gate_up=reshard(w_gate_up, P("expert", "data", "model")),
-            w_down=reshard(_init_weight(k_down, (e, i, d), cfg.initializer_std), P("expert", "model", "data")),
+            expert_mlp=MoEExpertMlp.init(
+                num_experts=e,
+                hidden_dim=d,
+                intermediate_dim=i,
+                initializer_std=cfg.initializer_std,
+                key=k_expert_mlp,
+                implementation=cfg.moe_implementation,
+                activation=ActivationFunctionEnum.silu,
+                capacity_factor=_DEFAULT_EP_CAPACITY_FACTOR,
+            ),
             cfg=cfg,
         )
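For reference, the parameter layout the removed inline init produced, which `MoEExpertMlp.init` presumably reproduces per backend (d = hidden_dim, e = num_experts, i = intermediate_dim; the sizes below are illustrative, not from the config):

import jax.numpy as jnp
from jax import random

d, e, i = 8, 4, 16
k_gate, k_up = random.split(random.PRNGKey(0))
w_gate = random.normal(k_gate, (e, d, i)) * 0.02  # stand-in for _init_weight(..., initializer_std)
w_up = random.normal(k_up, (e, d, i)) * 0.02
w_gate_up = jnp.concatenate([w_gate, w_up], axis=-1)
assert w_gate_up.shape == (e, d, 2 * i)  # fused gate/up, previously sharded P("expert", "data", "model")
# w_down had shape (e, i, d), previously sharded P("expert", "model", "data")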

@@ -389,16 +396,11 @@ def _local_qb_beta(s_ma):
             out_specs=P(),
         )(s_minus_alpha)
 
-        routed_flat = moe_mlp(
+        routed_flat = self.expert_mlp(
             x_flat,
             selected_experts.astype(jnp.int32),
             combine_weights,
-            self.w_gate_up,
-            self.w_down,
-            activation=ActivationFunctionEnum.silu,
-            implementation=self.cfg.moe_implementation,
             mesh=get_abstract_mesh(),
-            capacity_factor=_DEFAULT_EP_CAPACITY_FACTOR,
         )
 
         routed = rearrange(routed_flat, "(b s) d -> b s d", b=b, s=s)
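The updated call site's shape contract, as far as the diff shows it (the per-argument shapes are assumptions inferred from the surrounding flatten/rearrange, not stated explicitly):

# Assumed shapes, with t = b * s tokens, k = num_experts_per_token, d = hidden_dim:
routed_flat = self.expert_mlp(
    x_flat,                              # (t, d), from flattening (b, s, d)
    selected_experts.astype(jnp.int32),  # (t, k) expert indices from the router
    combine_weights,                     # (t, k) per-assignment combine weights
    mesh=get_abstract_mesh(),
)                                        # -> (t, d), rearranged back to (b, s, d)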
@@ -592,5 +594,4 @@ def debug_mesh_and_token_pspec(num_devices: int) -> tuple[jax.sharding.AbstractM
     "RMSNorm",
     "Transformer",
     "debug_mesh_and_token_pspec",
-    "moe_mlp",
 ]
Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
+# Copyright The Levanter Authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Custom-VJP down/gather implementation for local Grug MoE."""
+
+from __future__ import annotations
+
+import jax
+import jax.numpy as jnp
+
+from haliax.nn.ragged_dot import ragged_dot
+
+from levanter.grug.grug_moe import _gather_sum_reference
+
+
+def _custom_vjp_down_bwd(
+    dout: jax.Array,
+    h_interleaved: jax.Array,
+    w_down: jax.Array,
+    combine_weights: jax.Array,
+    token_ids_sort: jax.Array,
+    sorted_assignment_ids: jax.Array,
+    expert_frequency_offset: jax.Array,
+    dispatch_output: jax.Array,
+) -> tuple[jax.Array, jax.Array, jax.Array]:
+    group_sizes = jnp.diff(expert_frequency_offset)
+    assignments = h_interleaved.shape[0]
+    sorted_scores = combine_weights.reshape(assignments)[sorted_assignment_ids].astype(jnp.float32)
+    dout_sorted = dout[token_ids_sort]
+
+    def activation_forward(h: jax.Array) -> jax.Array:
+        gate = h[:, 0::2]
+        up = h[:, 1::2]
+        return jax.nn.silu(gate) * up
+
+    hidden, activation_pullback = jax.vjp(activation_forward, h_interleaved)
+    weighted_dout = (dout_sorted.astype(jnp.float32) * sorted_scores[:, None]).astype(dispatch_output.dtype)
+    _, down_pullback = jax.vjp(lambda h, w: ragged_dot(h, w, group_sizes), hidden, w_down)
+    d_hidden, d_w_down = down_pullback(weighted_dout)
+    (d_h_interleaved,) = activation_pullback(d_hidden)
+    d_scores_sorted = jnp.sum(dout_sorted.astype(jnp.float32) * dispatch_output.astype(jnp.float32), axis=-1)
+    d_scores = jnp.zeros_like(sorted_scores).at[sorted_assignment_ids].set(d_scores_sorted)
+    return d_h_interleaved, d_scores, d_w_down
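The `d_scores` computation in `_custom_vjp_down_bwd` relies on a simple identity: when a token's output is out = Σ_k w_k · y_k over its assignments, the cotangent of each combine weight w_k is ⟨dout, y_k⟩, which is exactly the row-wise sum of `dout_sorted * dispatch_output` above. A self-contained toy check of that identity (this uses none of the functions in this file):

import jax
import jax.numpy as jnp

y = jnp.array([[1.0, 2.0], [3.0, -1.0]])  # two expert outputs for one token, d = 2
w = jnp.array([0.25, 0.75])               # combine weights

def combine(w):
    return w[0] * y[0] + w[1] * y[1]

dout = jnp.array([1.0, 1.0])
_, pullback = jax.vjp(combine, w)
(dw,) = pullback(dout)
# dw[k] == <dout, y[k]>, here [3.0, 2.0]
assert jnp.allclose(dw, jnp.array([jnp.dot(dout, y[0]), jnp.dot(dout, y[1])]))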
+
+
+@jax.custom_vjp
+def custom_vjp_interleaved_down_gather_sum(
+    w13_out_interleaved: jax.Array,
+    combine_weights: jax.Array,
+    w_down: jax.Array,
+    token_ids_sort: jax.Array,
+    sorted_assignment_ids: jax.Array,
+    dispatch_positions: jax.Array,
+    group_sizes: jax.Array,
+) -> jax.Array:
+    out, _ = _custom_vjp_interleaved_down_gather_sum_forward(
+        w13_out_interleaved,
+        combine_weights,
+        w_down,
+        token_ids_sort,
+        sorted_assignment_ids,
+        dispatch_positions,
+        group_sizes,
+    )
+    return out
+
+
+def _custom_vjp_interleaved_down_gather_sum_forward(
+    w13_out_interleaved: jax.Array,
+    combine_weights: jax.Array,
+    w_down: jax.Array,
+    token_ids_sort: jax.Array,
+    sorted_assignment_ids: jax.Array,
+    dispatch_positions: jax.Array,
+    group_sizes: jax.Array,
+) -> tuple[jax.Array, tuple[jax.Array, ...]]:
+    del sorted_assignment_ids
+    hidden = jax.nn.silu(w13_out_interleaved[:, 0::2]) * w13_out_interleaved[:, 1::2]
+    dispatch_output = ragged_dot(hidden, w_down, group_sizes)
+    out = _gather_sum_reference(dispatch_output, dispatch_positions, combine_weights)
+    expert_frequency_offset = jnp.concatenate(
+        [jnp.zeros((1,), dtype=jnp.int32), jnp.cumsum(group_sizes, dtype=jnp.int32)]
+    )
+    return out, (
+        w13_out_interleaved,
+        combine_weights,
+        w_down,
+        token_ids_sort,
+        expert_frequency_offset,
+        dispatch_output,
+    )
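Note the activation layout in the forward pass above: gate and up pre-activations are interleaved on the last axis (even columns gate, odd columns up), whereas the old inline init in `model.py` concatenated them into two blocks. A toy sketch showing the two layouts yield the same SwiGLU hidden state (the explicit stacking here is illustrative, not how the backend materializes the interleaved tensor):

import jax
import jax.numpy as jnp

kg, ku = jax.random.split(jax.random.PRNGKey(0))
g = jax.random.normal(kg, (4, 3))  # gate pre-activations, (tokens, i)
u = jax.random.normal(ku, (4, 3))  # up projections, (tokens, i)

blocked = jnp.concatenate([g, u], axis=-1)              # old layout: [g | u]
interleaved = jnp.stack([g, u], axis=-1).reshape(4, 6)  # new layout: g0, u0, g1, u1, ...

a = jax.nn.silu(blocked[:, :3]) * blocked[:, 3:]
b = jax.nn.silu(interleaved[:, 0::2]) * interleaved[:, 1::2]
assert jnp.allclose(a, b)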
+
+
+def _custom_vjp_interleaved_down_gather_sum_fwd(
+    w13_out_interleaved: jax.Array,
+    combine_weights: jax.Array,
+    w_down: jax.Array,
+    token_ids_sort: jax.Array,
+    sorted_assignment_ids: jax.Array,
+    dispatch_positions: jax.Array,
+    group_sizes: jax.Array,
+) -> tuple[jax.Array, tuple[jax.Array, ...]]:
+    out, residuals = _custom_vjp_interleaved_down_gather_sum_forward(
+        w13_out_interleaved,
+        combine_weights,
+        w_down,
+        token_ids_sort,
+        sorted_assignment_ids,
+        dispatch_positions,
+        group_sizes,
+    )
+    return out, (*residuals, sorted_assignment_ids)
+
+
+def _custom_vjp_interleaved_down_gather_sum_bwd(
+    residuals: tuple[jax.Array, ...],
+    dout: jax.Array,
+) -> tuple[jax.Array, jax.Array, jax.Array, None, None, None, None]:
+    (
+        w13_out_interleaved,
+        combine_weights,
+        w_down,
+        token_ids_sort,
+        expert_frequency_offset,
+        dispatch_output,
+        sorted_assignment_ids,
+    ) = residuals
+    d_h_interleaved, d_scores_flat, d_w_down = _custom_vjp_down_bwd(
+        dout,
+        w13_out_interleaved,
+        w_down,
+        combine_weights,
+        token_ids_sort,
+        sorted_assignment_ids,
+        expert_frequency_offset,
+        dispatch_output,
+    )
+    d_combine_weights = d_scores_flat.reshape(combine_weights.shape).astype(combine_weights.dtype)
+    return d_h_interleaved, d_combine_weights, d_w_down.astype(w_down.dtype), None, None, None, None
+
+
+custom_vjp_interleaved_down_gather_sum.defvjp(
+    _custom_vjp_interleaved_down_gather_sum_fwd,
+    _custom_vjp_interleaved_down_gather_sum_bwd,
+)
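A natural way to validate a custom VJP like this is `jax.test_util.check_grads`, which compares analytic gradients against finite differences. Because `_gather_sum_reference` and the exact routing-metadata layout are not shown in this diff, the sketch below exercises a self-contained toy with the same structure (grouped down-projection, then a weighted scatter back to tokens) rather than the real function:

import jax.numpy as jnp
from jax.test_util import check_grads

# Toy down/gather-sum stage: 2 experts, 3 tokens, top-1 routing. Rows of
# `hidden` are pre-sorted by expert (expert 0: rows 0-1, expert 1: row 2),
# mirroring the grouped layout ragged_dot consumes.
hidden = jnp.arange(6.0).reshape(3, 2)  # (assignments, i)
w_down = jnp.full((2, 2, 2), 0.5)       # (experts, i, d)
expert_of_row = jnp.array([0, 0, 1])    # dense stand-in for group_sizes = [2, 1]
token_ids_sort = jnp.array([1, 2, 0])   # original token of each sorted row
weights = jnp.array([0.7, 1.0, 0.3])    # combine weight per sorted row

def down_gather_sum(hidden, weights, w_down):
    y = jnp.einsum("ai,aid->ad", hidden, w_down[expert_of_row])  # grouped matmul
    return jnp.zeros((3, 2)).at[token_ids_sort].add(y * weights[:, None])

# Checks reverse-mode gradients for all three differentiable inputs; the real
# custom_vjp_interleaved_down_gather_sum can be validated the same way once
# consistent routing metadata is constructed.
check_grads(down_gather_sum, (hidden, weights, w_down), order=1, modes=("rev",))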
