@@ -227,6 +227,8 @@ def _summarize_router_metrics(router_metrics: dict[str, jax.Array]) -> dict[str,
227227 routing_counts = router_metrics ["routing_counts_per_layer" ]
228228 load_balancing_loss = router_metrics ["load_balancing_loss_per_layer" ]
229229 router_z_loss = router_metrics ["router_z_loss_per_layer" ]
230+ dropped_count = router_metrics ["dropped_count_per_layer" ]
231+ overflow_fraction = router_metrics ["overflow_fraction_per_layer" ]
230232 num_layers = int (routing_entropy .shape [0 ])
231233 aux_loss_per_layer = load_balancing_loss + router_z_loss
232234
@@ -238,12 +240,17 @@ def _summarize_router_metrics(router_metrics: dict[str, jax.Array]) -> dict[str,
238240 "train/router/router_z_loss" : jnp .mean (router_z_loss ),
239241 # Keep aux loss as a per-step aggregate while exposing mean terms above.
240242 "train/router/aux_loss" : jnp .sum (aux_loss_per_layer ),
243+ # Capacity overflow: dropped assignments summed over layers, and overflow fraction averaged over layers.
244+ "train/router/dropped_count_total" : jnp .sum (dropped_count ),
245+ "train/router/overflow_fraction_mean" : jnp .mean (overflow_fraction ),
241246 }
242247 for i in range (num_layers ):
243248 out [f"train/router/layer_{ i } /routing_entropy" ] = routing_entropy [i ]
244249 out [f"train/router/layer_{ i } /load_balancing_loss" ] = load_balancing_loss [i ]
245250 out [f"train/router/layer_{ i } /router_z_loss" ] = router_z_loss [i ]
246251 out [f"train/router/layer_{ i } /routing_hist" ] = _histogram_from_expert_counts (routing_counts [i ])
252+ out [f"train/router/layer_{ i } /dropped_count" ] = dropped_count [i ]
253+ out [f"train/router/layer_{ i } /overflow_fraction" ] = overflow_fraction [i ]
247254 return out
248255
249256
@@ -325,7 +332,7 @@ def __call__(
325332 num_experts_per_token = self .cfg .num_experts_per_token ,
326333 )
327334
328- routed_flat = moe_mlp (
335+ routed_flat , dropped_count = moe_mlp (
329336 x_flat ,
330337 selected_experts .astype (jnp .int32 ),
331338 combine_weights ,
@@ -335,7 +342,11 @@ def __call__(
335342 implementation = self .cfg .moe_implementation ,
336343 mesh = get_abstract_mesh (),
337344 capacity_factor = _DEFAULT_EP_CAPACITY_FACTOR ,
345+ report_capacity_overflow = True ,
338346 )
347+ total_assignments = x_flat .shape [0 ] * self .cfg .num_experts_per_token
348+ router_stats ["dropped_count" ] = dropped_count
349+ router_stats ["overflow_fraction" ] = dropped_count .astype (jnp .float32 ) / max (total_assignments , 1 )
339350 routed = rearrange (routed_flat , "(b s) d -> b s d" , b = b , s = s )
340351 routed = reshard (routed , _batch_spec ())
341352 return routed , router_stats
@@ -433,6 +444,8 @@ def __call__(
433444 "routing_counts_per_layer" : jnp .stack ([s ["routing_counts" ] for s in all_router_stats ], axis = 0 ),
434445 "load_balancing_loss_per_layer" : jnp .stack ([s ["load_balancing_loss" ] for s in all_router_stats ], axis = 0 ),
435446 "router_z_loss_per_layer" : jnp .stack ([s ["router_z_loss" ] for s in all_router_stats ], axis = 0 ),
447+ "dropped_count_per_layer" : jnp .stack ([s ["dropped_count" ] for s in all_router_stats ], axis = 0 ),
448+ "overflow_fraction_per_layer" : jnp .stack ([s ["overflow_fraction" ] for s in all_router_stats ], axis = 0 ),
436449 }
437450 return self .final_norm (hidden ), router_metrics
438451
0 commit comments