26 | 26 | from jax.experimental.shard_map import shard_map |
27 | 27 | from jaxtyping import Array, Float, Int, PRNGKeyArray |
28 | 28 | from levanter.grug.attention import AttentionMask, RotaryConfig, align_kv_heads, apply_rotary_embedding, attention |
29 | | -from levanter.grug.grug_moe import MoeActivation, MoeImplementation, moe_mlp |
| 29 | +from levanter.grug.grug_moe import ( |
| 30 | + MoeActivation, |
| 31 | + MoEExpertMlp, |
| 32 | + MoeImplementation, |
| 33 | + resolve_moe_implementation, |
| 34 | +) |
30 | 35 | from levanter.grug.loss import fused_linear_softmax_cross_entropy_loss |
31 | 36 | from levanter.grug.sharding import Pembed_vocab, Plm_head, unshard |
32 | 37 | from levanter.tracker.histogram import Histogram |
@@ -89,6 +94,7 @@ def __post_init__(self) -> None: |
89 | 94 | raise ValueError("num_experts_per_token must be <= num_experts") |
90 | 95 | if self.shared_expert_intermediate_dim < 0: |
91 | 96 | raise ValueError("shared_expert_intermediate_dim must be non-negative") |
| 97 | + resolve_moe_implementation(self.moe_implementation) |
92 | 98 |
93 | 99 | @property |
94 | 100 | def inferred_head_dim(self) -> int: |
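
Reviewer note: the new `__post_init__` line validates `cfg.moe_implementation` eagerly, so a bad config fails at construction time rather than inside the forward pass. A minimal sketch of the assumed behavior of `resolve_moe_implementation` (the real helper lives in `levanter.grug.grug_moe`; the accepted option names below are illustrative, not the actual values):

```python
# Illustrative sketch only -- the real resolve_moe_implementation (and the
# MoeImplementation enum it works with) are defined in levanter.grug.grug_moe,
# and the accepted option names may differ.
_VALID_IMPLEMENTATIONS = ("dense", "ragged_dot")  # assumed names


def resolve_moe_implementation(impl):
    """Normalize a config value to a canonical implementation name, failing early."""
    name = impl.value if hasattr(impl, "value") else str(impl)
    if name not in _VALID_IMPLEMENTATIONS:
        raise ValueError(
            f"unknown moe_implementation {name!r}; expected one of {list(_VALID_IMPLEMENTATIONS)}"
        )
    return name
```
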
@@ -312,32 +318,33 @@ class MoEMLP(eqx.Module): |
312 | 318 |
313 | 319 | router: jax.Array |
314 | 320 | router_bias: jax.Array |
315 | | - w_gate_up: jax.Array |
316 | | - w_down: jax.Array |
| 321 | + expert_mlp: MoEExpertMlp |
317 | 322 | cfg: GrugModelConfig = eqx.field(static=True) |
318 | 323 |
319 | 324 | @staticmethod |
320 | 325 | def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "MoEMLP": |
321 | | - k_router, k_gate, k_up, k_down = random.split(key, 4) |
| 326 | + k_router, k_expert_mlp = random.split(key, 2) |
322 | 327 | mesh = get_abstract_mesh() |
323 | 328 |
324 | 329 | expert_axis_size = _mesh_axis_size(mesh, "expert") |
325 | 330 | if cfg.num_experts % expert_axis_size != 0: |
326 | 331 | raise ValueError(f"num_experts={cfg.num_experts} must be divisible by expert axis size={expert_axis_size}") |
327 | 332 |
328 | 333 | d, e, i = cfg.hidden_dim, cfg.num_experts, cfg.intermediate_dim |
329 | | - w_gate = _init_weight(k_gate, (e, d, i), cfg.initializer_std) |
330 | | - w_up = _init_weight(k_up, (e, d, i), cfg.initializer_std) |
331 | | - # TODO: Explore whether concatenating gate/up at init (instead of keeping separate params) |
332 | | - # is (1) a meaningful MFU speedup and (2) a meaningful perf hit due to AdamH treating the |
333 | | - # concatenated tensor as a single parameter for its scale-invariant norm computation. |
334 | | - w_gate_up = jnp.concatenate([w_gate, w_up], axis=-1) |
335 | 334 |
336 | 335 | return MoEMLP( |
337 | 336 | router=reshard(_init_weight(k_router, (d, e), cfg.initializer_std), P(None, None)), |
338 | 337 | router_bias=jnp.zeros((e,)), |
339 | | - w_gate_up=reshard(w_gate_up, P("expert", "data", "model")), |
340 | | - w_down=reshard(_init_weight(k_down, (e, i, d), cfg.initializer_std), P("expert", "model", "data")), |
| 338 | + expert_mlp=MoEExpertMlp.init( |
| 339 | + num_experts=e, |
| 340 | + hidden_dim=d, |
| 341 | + intermediate_dim=i, |
| 342 | + initializer_std=cfg.initializer_std, |
| 343 | + key=k_expert_mlp, |
| 344 | + implementation=cfg.moe_implementation, |
| 345 | + activation=ActivationFunctionEnum.silu, |
| 346 | + capacity_factor=_DEFAULT_EP_CAPACITY_FACTOR, |
| 347 | + ), |
341 | 348 | cfg=cfg, |
342 | 349 | ) |
343 | 350 |
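
The fused gate/up and down projections (previously the `w_gate_up` / `w_down` fields on `MoEMLP`) now live inside the new `MoEExpertMlp` module. A rough skeleton of the interface this hunk relies on, inferred only from the two call sites in this diff; field names and init details are assumptions, not the actual `levanter.grug.grug_moe` code:

```python
# Skeleton inferred from the call sites in this diff; not the real
# levanter.grug.grug_moe.MoEExpertMlp implementation.
from typing import Any

import equinox as eqx
import jax
from jax import random


class MoEExpertMlp(eqx.Module):
    w_gate_up: jax.Array  # (num_experts, hidden_dim, 2 * intermediate_dim)
    w_down: jax.Array     # (num_experts, intermediate_dim, hidden_dim)
    implementation: Any = eqx.field(static=True)
    activation: Any = eqx.field(static=True)
    capacity_factor: float = eqx.field(static=True)

    @staticmethod
    def init(*, num_experts, hidden_dim, intermediate_dim, initializer_std,
             key, implementation, activation, capacity_factor) -> "MoEExpertMlp":
        k_gate_up, k_down = random.split(key, 2)
        # simple normal init for the sketch; the real code presumably reuses _init_weight
        w_gate_up = initializer_std * random.normal(k_gate_up, (num_experts, hidden_dim, 2 * intermediate_dim))
        w_down = initializer_std * random.normal(k_down, (num_experts, intermediate_dim, hidden_dim))
        return MoEExpertMlp(w_gate_up, w_down, implementation, activation, capacity_factor)

    def __call__(self, x_flat, selected_experts, combine_weights, *, mesh):
        # dispatch tokens to their routed experts, apply the gated MLP,
        # and combine the per-expert outputs with `combine_weights`
        ...
```
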
@@ -389,16 +396,11 @@ def _local_qb_beta(s_ma): |
389 | 396 | out_specs=P(), |
390 | 397 | )(s_minus_alpha) |
391 | 398 |
392 | | - routed_flat = moe_mlp( |
| 399 | + routed_flat = self.expert_mlp( |
393 | 400 | x_flat, |
394 | 401 | selected_experts.astype(jnp.int32), |
395 | 402 | combine_weights, |
396 | | - self.w_gate_up, |
397 | | - self.w_down, |
398 | | - activation=ActivationFunctionEnum.silu, |
399 | | - implementation=self.cfg.moe_implementation, |
400 | 403 | mesh=get_abstract_mesh(), |
401 | | - capacity_factor=_DEFAULT_EP_CAPACITY_FACTOR, |
402 | 404 | ) |
403 | 405 |
404 | 406 | routed = rearrange(routed_flat, "(b s) d -> b s d", b=b, s=s) |
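
For reference, a hypothetical dense sketch of what the `self.expert_mlp(...)` call above computes, assuming the fused SwiGLU layout of the removed `w_gate_up`/`w_down` parameters (gate and up concatenated along the last axis); the real path dispatches to a sparse or expert-parallel kernel chosen by `cfg.moe_implementation`:

```python
# Hypothetical dense reference for the routed expert MLP; function name and
# structure are assumptions made for illustration only.
import jax
import jax.numpy as jnp


def dense_moe_reference(x_flat, selected_experts, combine_weights, w_gate_up, w_down):
    # x_flat: (T, D); selected_experts, combine_weights: (T, K)
    # w_gate_up: (E, D, 2*I); w_down: (E, I, D)
    gate_up = jnp.einsum("td,edi->tei", x_flat, w_gate_up)        # (T, E, 2*I)
    gate, up = jnp.split(gate_up, 2, axis=-1)
    h = jax.nn.silu(gate) * up                                    # SwiGLU, (T, E, I)
    per_expert = jnp.einsum("tei,eid->ted", h, w_down)            # (T, E, D)
    # pick each token's K routed experts and mix with its combine weights
    picked = jnp.take_along_axis(per_expert, selected_experts[..., None], axis=1)  # (T, K, D)
    return jnp.einsum("tk,tkd->td", combine_weights, picked)      # (T, D)
```

Shapes, as implied by the call site: `x_flat` is `(tokens, hidden_dim)`, `selected_experts` and `combine_weights` are `(tokens, num_experts_per_token)`, and the output matches `x_flat`.
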
@@ -592,5 +594,4 @@ def debug_mesh_and_token_pspec(num_devices: int) -> tuple[jax.sharding.AbstractM |
592 | 594 | "RMSNorm", |
593 | 595 | "Transformer", |
594 | 596 | "debug_mesh_and_token_pspec", |
595 | | - "moe_mlp", |
596 | 597 | ] |