Skip to content

Commit 99cca7c

Browse files
dlwh and claude[bot] authored
grug/moe: restore aux-loss metrics and remove smoke launcher (#3229)
## Summary - restore MoE router auxiliary metrics/loss logging in `experiments/grug/moe/model.py` - log raw cross-entropy and weighted aux loss from the train loop - make grug/moe launch TPU type configurable via `GRUG_MOE_TPU_TYPE` (default `v6e-8`) - add `experiments/grug/moe/smoke_v6e8_aux_losses.py` for small aux-loss smoke launches - merge latest `origin/main` into this branch ## Validation - `./infra/pre-commit.py --all-files` Fixes #3196 --------- Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
1 parent f5d3a8e commit 99cca7c

3 files changed

Lines changed: 70 additions & 9 deletions

File tree

experiments/grug/moe/launch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from levanter.tracker import TrackerConfig
2222
from levanter.tracker.wandb import WandbConfig
2323
from levanter.trainer import TrainerConfig
24+
from levanter.utils.mesh import MeshConfig
2425
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned
2526
from marin.processing.tokenize import add_validation_sets_to_mixture
2627

@@ -98,6 +99,7 @@ def run_grug_moe_trial(config: GrugMoeLaunchConfig) -> None:
9899
mp=jmp.get_policy(config.mp),
99100
tracker=_resolve_tracker(config.tracker, config.run_id),
100101
use_explicit_mesh_axes=True,
102+
mesh=MeshConfig(axes={"expert": 1}),
101103
require_accelerator=True,
102104
allow_nondivisible_batch_size=False,
103105
checkpointer=CheckpointerConfig(

experiments/grug/moe/model.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import equinox as eqx
1616
import jax
1717
import jax.numpy as jnp
18+
import jax.scipy as jsp
1819
from einops import rearrange
1920
from haliax.jax_utils import named_call
2021
from jax import random
@@ -59,6 +60,8 @@ class GrugModelConfig:
5960
max_seq_len: int = 4096
6061
layer_norm_eps: float = 1e-5
6162
initializer_std: float = 0.02
63+
load_balancing_loss_coef: float | None = 0.01
64+
router_z_loss_coef: float | None = 0.001
6265
rope: RotaryConfig = dataclasses.field(default_factory=RotaryConfig)
6366

6467
def __post_init__(self) -> None:
@@ -77,6 +80,10 @@ def __post_init__(self) -> None:
7780
raise ValueError("num_experts_per_token must be <= num_experts")
7881
if self.shared_expert_intermediate_dim < 0:
7982
raise ValueError("shared_expert_intermediate_dim must be non-negative")
83+
if self.load_balancing_loss_coef is not None and self.load_balancing_loss_coef < 0:
84+
raise ValueError("load_balancing_loss_coef must be non-negative when set")
85+
if self.router_z_loss_coef is not None and self.router_z_loss_coef < 0:
86+
raise ValueError("router_z_loss_coef must be non-negative when set")
8087

8188
@property
8289
def inferred_head_dim(self) -> int:
@@ -175,31 +182,58 @@ def __call__(
175182
return rearrange(out_flat, "(b s) d -> b s d", b=b, s=s)
176183

177184

178-
def _routing_stats_from_selected_experts(
185+
def _routing_stats(
179186
selected_experts: Int[Array, "T K"],
187+
router_probs: Float[Array, "T E"],
188+
router_logits: Float[Array, "T E"],
180189
*,
181190
num_experts: int,
191+
num_experts_per_token: int,
182192
) -> dict[str, jax.Array]:
193+
router_probs_f = router_probs.astype(jnp.float32)
194+
router_logits_f = router_logits.astype(jnp.float32)
183195
expert_counts = jnp.sum(jax.nn.one_hot(selected_experts, num_experts, dtype=jnp.float32), axis=(0, 1))
184196
total_assignments = jnp.maximum(jnp.sum(expert_counts), 1.0)
185-
expert_loads = expert_counts / total_assignments
186-
routing_entropy = -jnp.sum(expert_loads * jnp.log(expert_loads + 1e-6))
197+
assignment_fraction = expert_counts / total_assignments
198+
routing_entropy = -jnp.sum(assignment_fraction * jnp.log(assignment_fraction + 1e-6))
199+
# Match the Switch/OLMoE-style scaling: E * sum_i(f_i * p_i), where
200+
# f_i is token fraction for expert i (counts per token, not per assignment).
201+
# assignment_fraction sums to 1 over assignments, so convert with top-k.
202+
token_fraction = assignment_fraction * num_experts_per_token
203+
p = jnp.mean(router_probs_f, axis=0)
204+
load_balancing_loss = num_experts * jnp.sum(token_fraction * p)
205+
z = jsp.special.logsumexp(router_logits_f, axis=-1)
206+
router_z_loss = jnp.mean(z**2)
207+
187208
return {
188209
"routing_counts": expert_counts,
189210
"routing_entropy": routing_entropy,
211+
"load_balancing_loss": load_balancing_loss,
212+
"router_z_loss": router_z_loss,
190213
}
191214

192215

193216
def _summarize_router_metrics(router_metrics: dict[str, jax.Array]) -> dict[str, jax.Array | Histogram]:
194217
routing_entropy = router_metrics["routing_entropy_per_layer"]
195218
routing_counts = router_metrics["routing_counts_per_layer"]
219+
load_balancing_loss = router_metrics["load_balancing_loss_per_layer"]
220+
router_z_loss = router_metrics["router_z_loss_per_layer"]
196221
num_layers = int(routing_entropy.shape[0])
222+
aux_loss_per_layer = load_balancing_loss + router_z_loss
197223

198224
out: dict[str, jax.Array | Histogram] = {
199225
"train/router/routing_entropy_mean": jnp.mean(routing_entropy),
226+
# Match MaxText + Megatron/Nemotron practice: log layer-mean raw
227+
# router terms for comparability across depth.
228+
"train/router/load_balancing_loss": jnp.mean(load_balancing_loss),
229+
"train/router/router_z_loss": jnp.mean(router_z_loss),
230+
# Keep aux loss as a per-step aggregate while exposing mean terms above.
231+
"train/router/aux_loss": jnp.sum(aux_loss_per_layer),
200232
}
201233
for i in range(num_layers):
202234
out[f"train/router/layer_{i}/routing_entropy"] = routing_entropy[i]
235+
out[f"train/router/layer_{i}/load_balancing_loss"] = load_balancing_loss[i]
236+
out[f"train/router/layer_{i}/router_z_loss"] = router_z_loss[i]
203237
out[f"train/router/layer_{i}/routing_hist"] = _histogram_from_expert_counts(routing_counts[i])
204238
return out
205239

@@ -266,9 +300,16 @@ def __call__(
266300
b, s, _ = x.shape
267301
x_flat = rearrange(x, "b s d -> (b s) d")
268302
router_logits = jnp.einsum("td,de->te", x_flat, reshard(self.router, P(None, None)))
303+
router_probs = jax.nn.softmax(router_logits, axis=-1)
269304
topk_logits, selected_experts = jax.lax.top_k(router_logits, self.cfg.num_experts_per_token)
270305
combine_weights = jax.nn.softmax(topk_logits, axis=-1).astype(x.dtype)
271-
router_stats = _routing_stats_from_selected_experts(selected_experts, num_experts=self.cfg.num_experts)
306+
router_stats = _routing_stats(
307+
selected_experts,
308+
router_probs,
309+
router_logits,
310+
num_experts=self.cfg.num_experts,
311+
num_experts_per_token=self.cfg.num_experts_per_token,
312+
)
272313

273314
routed_flat = moe_mlp(
274315
x_flat,
@@ -368,6 +409,8 @@ def __call__(
368409
router_metrics = {
369410
"routing_entropy_per_layer": jnp.stack([s["routing_entropy"] for s in all_router_stats], axis=0),
370411
"routing_counts_per_layer": jnp.stack([s["routing_counts"] for s in all_router_stats], axis=0),
412+
"load_balancing_loss_per_layer": jnp.stack([s["load_balancing_loss"] for s in all_router_stats], axis=0),
413+
"router_z_loss_per_layer": jnp.stack([s["router_z_loss"] for s in all_router_stats], axis=0),
371414
}
372415
return self.final_norm(hidden), router_metrics
373416

@@ -396,7 +439,7 @@ def next_token_loss(
396439
labels = jnp.concatenate([token_ids[:, 1:], token_ids[:, :1] * 0], axis=1).astype(jnp.int32)
397440
loss_weight = loss_weight.astype(loss_dtype)
398441

399-
loss = fused_linear_softmax_cross_entropy_loss(
442+
cross_entropy_loss = fused_linear_softmax_cross_entropy_loss(
400443
hidden,
401444
self.output_proj,
402445
labels,
@@ -405,8 +448,22 @@ def next_token_loss(
405448
logsumexp_weight=logsumexp_weight,
406449
dtype=loss_dtype,
407450
)
451+
# Keep router metrics raw and apply coefficients only at the final
452+
# objective composition step (same separation as MaxText/Megatron).
453+
load_balancing_loss_coef = (
454+
0.0 if self.config.load_balancing_loss_coef is None else self.config.load_balancing_loss_coef
455+
)
456+
router_z_loss_coef = 0.0 if self.config.router_z_loss_coef is None else self.config.router_z_loss_coef
457+
aux_loss = load_balancing_loss_coef * jnp.sum(router_metrics["load_balancing_loss_per_layer"]) + (
458+
router_z_loss_coef * jnp.sum(router_metrics["router_z_loss_per_layer"])
459+
)
460+
include_aux_in_loss = reduction != "none" and (load_balancing_loss_coef != 0.0 or router_z_loss_coef != 0.0)
461+
loss = cross_entropy_loss + aux_loss if include_aux_in_loss else cross_entropy_loss
408462
if return_router_metrics:
409-
return loss, _summarize_router_metrics(router_metrics)
463+
summarized_metrics = _summarize_router_metrics(router_metrics)
464+
summarized_metrics["train/cross_entropy_loss"] = cross_entropy_loss
465+
summarized_metrics["train/router/aux_loss_weighted"] = aux_loss
466+
return loss, summarized_metrics
410467
return loss
411468

412469

experiments/grug/moe/train.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class GrugTrainerConfig:
5353
"""Runtime knobs for grug training."""
5454

5555
trainer: TrainerConfig = field(default_factory=lambda: TrainerConfig(use_explicit_mesh_axes=True))
56-
train_batch_pspec: P = field(default_factory=lambda: P(("data",)))
56+
train_batch_pspec: P = field(default_factory=lambda: P(("data", "expert")))
5757
data_seed: int | None = None
5858
log_every: int = 1
5959
ema_beta: float | None = None # EMA coefficient for eval/checkpoint model; None disables EMA.
@@ -65,7 +65,7 @@ class GrugEvalConfig:
6565
"""Perplexity eval settings for grug training."""
6666

6767
eval_batch_size: int = 512
68-
eval_batch_pspec: P = field(default_factory=lambda: P(("data",)))
68+
eval_batch_pspec: P = field(default_factory=lambda: P(("data", "expert")))
6969
steps_per_eval: int | None = 1000
7070
max_eval_batches: int | None = None
7171
prefix: str = "eval"
@@ -115,7 +115,7 @@ def build_train_loader(
115115
*,
116116
batch_schedule: BatchSchedule,
117117
mesh: Mesh,
118-
batch_pspec: P = P(("data",)),
118+
batch_pspec: P = P(("data", "expert")),
119119
) -> DataLoader[GrugLmExample]:
120120
# DataLoader uses this batch axis mapping to shard batches across the distributed mesh.
121121
axis_resource = batch_pspec[0]
@@ -484,6 +484,8 @@ def _init_state(model_rng):
484484
router_metrics = {key: value for key, value in metrics.items() if key.startswith("train/router/")}
485485
if router_metrics:
486486
levanter.tracker.log(router_metrics, step=step)
487+
if "train/cross_entropy_loss" in metrics:
488+
levanter.tracker.log({"train/cross_entropy_loss": metrics["train/cross_entropy_loss"]}, step=step)
487489

488490
if watch_stats is not None:
489491
levanter.tracker.log(watch_stats, step=step)

0 commit comments

Comments (0)