Skip to content

Commit df28bf6

Browse files
[moe] Wire capacity overflow reporting into Grug MoE training metrics
Enable report_capacity_overflow=True in MoEMLP.__call__ and propagate dropped_count and overflow_fraction through router_stats into the summarized training metrics logged to wandb. Fixes #4016 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a243fe5 commit df28bf6

2 files changed

Lines changed: 39 additions & 1 deletion

File tree

experiments/grug/moe/model.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@ def _summarize_router_metrics(router_metrics: dict[str, jax.Array]) -> dict[str,
227227
routing_counts = router_metrics["routing_counts_per_layer"]
228228
load_balancing_loss = router_metrics["load_balancing_loss_per_layer"]
229229
router_z_loss = router_metrics["router_z_loss_per_layer"]
230+
dropped_count = router_metrics["dropped_count_per_layer"]
231+
overflow_fraction = router_metrics["overflow_fraction_per_layer"]
230232
num_layers = int(routing_entropy.shape[0])
231233
aux_loss_per_layer = load_balancing_loss + router_z_loss
232234

@@ -238,12 +240,17 @@ def _summarize_router_metrics(router_metrics: dict[str, jax.Array]) -> dict[str,
238240
"train/router/router_z_loss": jnp.mean(router_z_loss),
239241
# Keep aux loss as a per-step aggregate while exposing mean terms above.
240242
"train/router/aux_loss": jnp.sum(aux_loss_per_layer),
243+
# Capacity overflow: total dropped assignments and fraction across all layers.
244+
"train/router/dropped_count_total": jnp.sum(dropped_count),
245+
"train/router/overflow_fraction_mean": jnp.mean(overflow_fraction),
241246
}
242247
for i in range(num_layers):
243248
out[f"train/router/layer_{i}/routing_entropy"] = routing_entropy[i]
244249
out[f"train/router/layer_{i}/load_balancing_loss"] = load_balancing_loss[i]
245250
out[f"train/router/layer_{i}/router_z_loss"] = router_z_loss[i]
246251
out[f"train/router/layer_{i}/routing_hist"] = _histogram_from_expert_counts(routing_counts[i])
252+
out[f"train/router/layer_{i}/dropped_count"] = dropped_count[i]
253+
out[f"train/router/layer_{i}/overflow_fraction"] = overflow_fraction[i]
247254
return out
248255

249256

@@ -325,7 +332,7 @@ def __call__(
325332
num_experts_per_token=self.cfg.num_experts_per_token,
326333
)
327334

328-
routed_flat = moe_mlp(
335+
routed_flat, dropped_count = moe_mlp(
329336
x_flat,
330337
selected_experts.astype(jnp.int32),
331338
combine_weights,
@@ -335,7 +342,11 @@ def __call__(
335342
implementation=self.cfg.moe_implementation,
336343
mesh=get_abstract_mesh(),
337344
capacity_factor=_DEFAULT_EP_CAPACITY_FACTOR,
345+
report_capacity_overflow=True,
338346
)
347+
total_assignments = x_flat.shape[0] * self.cfg.num_experts_per_token
348+
router_stats["dropped_count"] = dropped_count
349+
router_stats["overflow_fraction"] = dropped_count.astype(jnp.float32) / max(total_assignments, 1)
339350
routed = rearrange(routed_flat, "(b s) d -> b s d", b=b, s=s)
340351
routed = reshard(routed, _batch_spec())
341352
return routed, router_stats
@@ -433,6 +444,8 @@ def __call__(
433444
"routing_counts_per_layer": jnp.stack([s["routing_counts"] for s in all_router_stats], axis=0),
434445
"load_balancing_loss_per_layer": jnp.stack([s["load_balancing_loss"] for s in all_router_stats], axis=0),
435446
"router_z_loss_per_layer": jnp.stack([s["router_z_loss"] for s in all_router_stats], axis=0),
447+
"dropped_count_per_layer": jnp.stack([s["dropped_count"] for s in all_router_stats], axis=0),
448+
"overflow_fraction_per_layer": jnp.stack([s["overflow_fraction"] for s in all_router_stats], axis=0),
436449
}
437450
return self.final_norm(hidden), router_metrics
438451

tests/test_grug_variant_contracts.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,28 @@ def test_grug_base_run_emits_expected_metrics_with_json_tracker(tmp_path: Path):
263263
]
264264
for key in required_keys:
265265
assert key in summary
266+
267+
268+
def test_moe_summarize_router_metrics_includes_overflow():
269+
"""Capacity overflow metrics must appear in summarized router output."""
270+
from experiments.grug.moe.model import _summarize_router_metrics
271+
272+
num_layers = 2
273+
num_experts = 4
274+
router_metrics = {
275+
"routing_entropy_per_layer": jnp.ones(num_layers),
276+
"routing_counts_per_layer": jnp.ones((num_layers, num_experts)),
277+
"load_balancing_loss_per_layer": jnp.ones(num_layers),
278+
"router_z_loss_per_layer": jnp.ones(num_layers),
279+
"dropped_count_per_layer": jnp.array([10, 5], dtype=jnp.int32),
280+
"overflow_fraction_per_layer": jnp.array([0.1, 0.05], dtype=jnp.float32),
281+
}
282+
out = _summarize_router_metrics(router_metrics)
283+
284+
assert "train/router/dropped_count_total" in out
285+
assert "train/router/overflow_fraction_mean" in out
286+
assert float(out["train/router/dropped_count_total"]) == 15.0
287+
assert abs(float(out["train/router/overflow_fraction_mean"]) - 0.075) < 1e-6
288+
for i in range(num_layers):
289+
assert f"train/router/layer_{i}/dropped_count" in out
290+
assert f"train/router/layer_{i}/overflow_fraction" in out

0 commit comments

Comments
 (0)