Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion experiments/grug/moe/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ def _summarize_router_metrics(router_metrics: dict[str, jax.Array]) -> dict[str,
routing_counts = router_metrics["routing_counts_per_layer"]
load_balancing_loss = router_metrics["load_balancing_loss_per_layer"]
router_z_loss = router_metrics["router_z_loss_per_layer"]
dropped_count = router_metrics["dropped_count_per_layer"]
overflow_fraction = router_metrics["overflow_fraction_per_layer"]
num_layers = int(routing_entropy.shape[0])
aux_loss_per_layer = load_balancing_loss + router_z_loss

Expand All @@ -238,12 +240,17 @@ def _summarize_router_metrics(router_metrics: dict[str, jax.Array]) -> dict[str,
"train/router/router_z_loss": jnp.mean(router_z_loss),
# Keep aux loss as a per-step aggregate while exposing mean terms above.
"train/router/aux_loss": jnp.sum(aux_loss_per_layer),
# Capacity overflow: total dropped assignments and fraction across all layers.
"train/router/dropped_count_total": jnp.sum(dropped_count),
"train/router/overflow_fraction_mean": jnp.mean(overflow_fraction),
}
for i in range(num_layers):
out[f"train/router/layer_{i}/routing_entropy"] = routing_entropy[i]
out[f"train/router/layer_{i}/load_balancing_loss"] = load_balancing_loss[i]
out[f"train/router/layer_{i}/router_z_loss"] = router_z_loss[i]
out[f"train/router/layer_{i}/routing_hist"] = _histogram_from_expert_counts(routing_counts[i])
out[f"train/router/layer_{i}/dropped_count"] = dropped_count[i]
out[f"train/router/layer_{i}/overflow_fraction"] = overflow_fraction[i]
return out


Expand Down Expand Up @@ -325,7 +332,7 @@ def __call__(
num_experts_per_token=self.cfg.num_experts_per_token,
)

routed_flat = moe_mlp(
routed_flat, dropped_count = moe_mlp(
x_flat,
selected_experts.astype(jnp.int32),
combine_weights,
Expand All @@ -335,7 +342,11 @@ def __call__(
implementation=self.cfg.moe_implementation,
mesh=get_abstract_mesh(),
capacity_factor=_DEFAULT_EP_CAPACITY_FACTOR,
report_capacity_overflow=True,
)
total_assignments = x_flat.shape[0] * self.cfg.num_experts_per_token
router_stats["dropped_count"] = dropped_count
router_stats["overflow_fraction"] = dropped_count.astype(jnp.float32) / max(total_assignments, 1)
routed = rearrange(routed_flat, "(b s) d -> b s d", b=b, s=s)
routed = reshard(routed, _batch_spec())
return routed, router_stats
Expand Down Expand Up @@ -433,6 +444,8 @@ def __call__(
"routing_counts_per_layer": jnp.stack([s["routing_counts"] for s in all_router_stats], axis=0),
"load_balancing_loss_per_layer": jnp.stack([s["load_balancing_loss"] for s in all_router_stats], axis=0),
"router_z_loss_per_layer": jnp.stack([s["router_z_loss"] for s in all_router_stats], axis=0),
"dropped_count_per_layer": jnp.stack([s["dropped_count"] for s in all_router_stats], axis=0),
"overflow_fraction_per_layer": jnp.stack([s["overflow_fraction"] for s in all_router_stats], axis=0),
}
return self.final_norm(hidden), router_metrics

Expand Down
25 changes: 25 additions & 0 deletions tests/test_grug_variant_contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,28 @@ def test_grug_base_run_emits_expected_metrics_with_json_tracker(tmp_path: Path):
]
for key in required_keys:
assert key in summary


def test_moe_summarize_router_metrics_includes_overflow():
    """Summarized router output must expose the capacity-overflow metrics.

    Builds hand-written per-layer router metrics for two layers, runs them
    through ``_summarize_router_metrics``, and checks that:

    * the aggregate entries (total dropped count, mean overflow fraction)
      exist and carry the expected values, and
    * every layer has its own dropped-count and overflow-fraction entry.
    """
    from experiments.grug.moe.model import _summarize_router_metrics

    layer_count, expert_count = 2, 4

    # Non-overflow inputs are filled with ones; only their presence matters here.
    fake_metrics = dict(
        routing_entropy_per_layer=jnp.ones(layer_count),
        routing_counts_per_layer=jnp.ones((layer_count, expert_count)),
        load_balancing_loss_per_layer=jnp.ones(layer_count),
        router_z_loss_per_layer=jnp.ones(layer_count),
        dropped_count_per_layer=jnp.array([10, 5], dtype=jnp.int32),
        overflow_fraction_per_layer=jnp.array([0.1, 0.05], dtype=jnp.float32),
    )
    summary = _summarize_router_metrics(fake_metrics)

    # Aggregate overflow entries must be present before their values are read.
    assert "train/router/dropped_count_total" in summary
    assert "train/router/overflow_fraction_mean" in summary

    # 10 + 5 dropped assignments in total; (0.1 + 0.05) / 2 mean overflow.
    dropped_total = float(summary["train/router/dropped_count_total"])
    assert dropped_total == 15.0
    overflow_mean = float(summary["train/router/overflow_fraction_mean"])
    assert abs(overflow_mean - 0.075) < 1e-6

    # Each layer must also report its own overflow entries.
    i = 0
    while i < layer_count:
        assert f"train/router/layer_{i}/dropped_count" in summary
        assert f"train/router/layer_{i}/overflow_fraction" in summary
        i += 1
Loading