Commit 4a55948

Add optional QuACK MoE local backend
1 parent 13a0ef6

4 files changed: 931 additions & 34 deletions

experiments/grug/moe/model.py (20 additions & 19 deletions)
@@ -26,7 +26,12 @@
 from jax.experimental.shard_map import shard_map
 from jaxtyping import Array, Float, Int, PRNGKeyArray
 from levanter.grug.attention import AttentionMask, RotaryConfig, align_kv_heads, apply_rotary_embedding, attention
-from levanter.grug.grug_moe import MoeActivation, MoeImplementation, moe_mlp
+from levanter.grug.grug_moe import (
+    MoeActivation,
+    MoEExpertMlp,
+    MoeImplementation,
+    resolve_moe_implementation,
+)
 from levanter.grug.loss import fused_linear_softmax_cross_entropy_loss
 from levanter.grug.sharding import Pembed_vocab, Plm_head, unshard
 from levanter.tracker.histogram import Histogram
@@ -89,6 +94,7 @@ def __post_init__(self) -> None:
             raise ValueError("num_experts_per_token must be <= num_experts")
         if self.shared_expert_intermediate_dim < 0:
             raise ValueError("shared_expert_intermediate_dim must be non-negative")
+        resolve_moe_implementation(self.moe_implementation)

     @property
     def inferred_head_dim(self) -> int:
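
The eager call in __post_init__ means the configured backend string is checked when the config is built rather than at the first forward pass. A hypothetical sketch of what such a resolver could do (names, backend strings, and the optional-import check are assumptions for illustration, not the actual levanter.grug.grug_moe code):

# Hypothetical sketch only; the real resolve_moe_implementation may differ. The idea is to
# fail fast in GrugModelConfig.__post_init__ if the requested backend, e.g. the optional
# QuACK local backend, is unknown or its optional dependency is missing.
def resolve_moe_implementation_sketch(implementation: str) -> str:
    known = ("dense", "ragged_dot", "quack")  # assumed backend names
    if implementation not in known:
        raise ValueError(f"unknown moe_implementation {implementation!r}; expected one of {known}")
    if implementation == "quack":
        try:
            import quack  # noqa: F401  # assumed package name for the optional QuACK kernels
        except ImportError as err:
            raise ValueError("moe_implementation='quack' requires the optional QuACK package") from err
    return implementation
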
@@ -312,32 +318,33 @@ class MoEMLP(eqx.Module):

     router: jax.Array
     router_bias: jax.Array
-    w_gate_up: jax.Array
-    w_down: jax.Array
+    expert_mlp: MoEExpertMlp
     cfg: GrugModelConfig = eqx.field(static=True)

     @staticmethod
     def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "MoEMLP":
-        k_router, k_gate, k_up, k_down = random.split(key, 4)
+        k_router, k_expert_mlp = random.split(key, 2)
         mesh = get_abstract_mesh()

         expert_axis_size = _mesh_axis_size(mesh, "expert")
         if cfg.num_experts % expert_axis_size != 0:
             raise ValueError(f"num_experts={cfg.num_experts} must be divisible by expert axis size={expert_axis_size}")

         d, e, i = cfg.hidden_dim, cfg.num_experts, cfg.intermediate_dim
-        w_gate = _init_weight(k_gate, (e, d, i), cfg.initializer_std)
-        w_up = _init_weight(k_up, (e, d, i), cfg.initializer_std)
-        # TODO: Explore whether concatenating gate/up at init (instead of keeping separate params)
-        # is (1) a meaningful MFU speedup and (2) a meaningful perf hit due to AdamH treating the
-        # concatenated tensor as a single parameter for its scale-invariant norm computation.
-        w_gate_up = jnp.concatenate([w_gate, w_up], axis=-1)

         return MoEMLP(
             router=reshard(_init_weight(k_router, (d, e), cfg.initializer_std), P(None, None)),
             router_bias=jnp.zeros((e,)),
-            w_gate_up=reshard(w_gate_up, P("expert", "data", "model")),
-            w_down=reshard(_init_weight(k_down, (e, i, d), cfg.initializer_std), P("expert", "model", "data")),
+            expert_mlp=MoEExpertMlp.init(
+                num_experts=e,
+                hidden_dim=d,
+                intermediate_dim=i,
+                initializer_std=cfg.initializer_std,
+                key=k_expert_mlp,
+                implementation=cfg.moe_implementation,
+                activation=ActivationFunctionEnum.silu,
+                capacity_factor=_DEFAULT_EP_CAPACITY_FACTOR,
+            ),
             cfg=cfg,
         )

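Judging from the init call above, MoEExpertMlp absorbs the per-expert gate/up and down projections that MoEMLP previously owned directly, along with the static dispatch settings (implementation, activation, capacity factor). A minimal sketch of such a container, assuming the concatenated gate/up layout the old code built by hand (an illustration under those assumptions, not the levanter definition):

import equinox as eqx
import jax
from jax import random


class MoEExpertMlpSketch(eqx.Module):
    # Per-expert weights; shapes mirror the old MoEMLP fields.
    w_gate_up: jax.Array  # (num_experts, hidden_dim, 2 * intermediate_dim)
    w_down: jax.Array     # (num_experts, intermediate_dim, hidden_dim)
    implementation: str = eqx.field(static=True)
    activation: str = eqx.field(static=True)
    capacity_factor: float = eqx.field(static=True)

    @staticmethod
    def init(*, num_experts, hidden_dim, intermediate_dim, initializer_std, key,
             implementation="dense", activation="silu", capacity_factor=1.25):
        k_gate_up, k_down = random.split(key, 2)
        w_gate_up = initializer_std * random.normal(k_gate_up, (num_experts, hidden_dim, 2 * intermediate_dim))
        w_down = initializer_std * random.normal(k_down, (num_experts, intermediate_dim, hidden_dim))
        return MoEExpertMlpSketch(w_gate_up=w_gate_up, w_down=w_down, implementation=implementation,
                                  activation=activation, capacity_factor=capacity_factor)

Bundling these fields into one module is what lets the call site below shrink: the backend choice travels with the parameters instead of being re-threaded through every forward call.
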
@@ -389,16 +396,11 @@ def _local_qb_beta(s_ma):
             out_specs=P(),
         )(s_minus_alpha)

-        routed_flat = moe_mlp(
+        routed_flat = self.expert_mlp(
            x_flat,
            selected_experts.astype(jnp.int32),
            combine_weights,
-            self.w_gate_up,
-            self.w_down,
-            activation=ActivationFunctionEnum.silu,
-            implementation=self.cfg.moe_implementation,
            mesh=get_abstract_mesh(),
-            capacity_factor=_DEFAULT_EP_CAPACITY_FACTOR,
        )

        routed = rearrange(routed_flat, "(b s) d -> b s d", b=b, s=s)
@@ -592,5 +594,4 @@ def debug_mesh_and_token_pspec(num_devices: int) -> tuple[jax.sharding.AbstractM
     "RMSNorm",
     "Transformer",
     "debug_mesh_and_token_pspec",
-    "moe_mlp",
 ]
