 import equinox as eqx
 import jax
 import jax.numpy as jnp
+import jax.scipy as jsp
 from einops import rearrange
 from haliax.jax_utils import named_call
 from jax import random
@@ -59,6 +60,8 @@ class GrugModelConfig:
     max_seq_len: int = 4096
     layer_norm_eps: float = 1e-5
     initializer_std: float = 0.02
+    load_balancing_loss_coef: float | None = 0.01
+    router_z_loss_coef: float | None = 0.001
     rope: RotaryConfig = dataclasses.field(default_factory=RotaryConfig)

     def __post_init__(self) -> None:
@@ -77,6 +80,10 @@ def __post_init__(self) -> None:
             raise ValueError("num_experts_per_token must be <= num_experts")
         if self.shared_expert_intermediate_dim < 0:
             raise ValueError("shared_expert_intermediate_dim must be non-negative")
+        if self.load_balancing_loss_coef is not None and self.load_balancing_loss_coef < 0:
+            raise ValueError("load_balancing_loss_coef must be non-negative when set")
+        if self.router_z_loss_coef is not None and self.router_z_loss_coef < 0:
+            raise ValueError("router_z_loss_coef must be non-negative when set")

     @property
     def inferred_head_dim(self) -> int:
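Both coefficients are optional: `None` and `0.0` each drop the corresponding term from the training objective, and validation only rejects negative values when a coefficient is set. A minimal standalone sketch of this pattern; `_AuxLossConfig` is a hypothetical stand-in for the fields added to `GrugModelConfig`:

```python
# Minimal sketch of the optional-coefficient pattern above. _AuxLossConfig is
# a made-up stand-in for the new GrugModelConfig fields.
import dataclasses


@dataclasses.dataclass
class _AuxLossConfig:
    load_balancing_loss_coef: float | None = 0.01
    router_z_loss_coef: float | None = 0.001

    def __post_init__(self) -> None:
        for name in ("load_balancing_loss_coef", "router_z_loss_coef"):
            value = getattr(self, name)
            if value is not None and value < 0:
                raise ValueError(f"{name} must be non-negative when set")


_AuxLossConfig(load_balancing_loss_coef=None)  # fine: None disables the term
try:
    _AuxLossConfig(router_z_loss_coef=-1.0)
except ValueError as err:
    print(err)  # router_z_loss_coef must be non-negative when set
```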
@@ -175,31 +182,58 @@ def __call__(
         return rearrange(out_flat, "(b s) d -> b s d", b=b, s=s)


-def _routing_stats_from_selected_experts(
+def _routing_stats(
     selected_experts: Int[Array, "T K"],
+    router_probs: Float[Array, "T E"],
+    router_logits: Float[Array, "T E"],
     *,
     num_experts: int,
+    num_experts_per_token: int,
 ) -> dict[str, jax.Array]:
+    router_probs_f = router_probs.astype(jnp.float32)
+    router_logits_f = router_logits.astype(jnp.float32)
     expert_counts = jnp.sum(jax.nn.one_hot(selected_experts, num_experts, dtype=jnp.float32), axis=(0, 1))
     total_assignments = jnp.maximum(jnp.sum(expert_counts), 1.0)
-    expert_loads = expert_counts / total_assignments
-    routing_entropy = -jnp.sum(expert_loads * jnp.log(expert_loads + 1e-6))
+    assignment_fraction = expert_counts / total_assignments
+    routing_entropy = -jnp.sum(assignment_fraction * jnp.log(assignment_fraction + 1e-6))
+    # Match the Switch/OLMoE-style scaling: E * sum_i(f_i * p_i), where f_i is
+    # the token fraction for expert i (counts per token, not per assignment).
+    # assignment_fraction sums to 1 over assignments, so convert to per-token
+    # fractions by scaling with the top-k.
+    token_fraction = assignment_fraction * num_experts_per_token
+    p = jnp.mean(router_probs_f, axis=0)
+    load_balancing_loss = num_experts * jnp.sum(token_fraction * p)
+    z = jsp.special.logsumexp(router_logits_f, axis=-1)
+    router_z_loss = jnp.mean(z**2)
+
     return {
         "routing_counts": expert_counts,
         "routing_entropy": routing_entropy,
+        "load_balancing_loss": load_balancing_loss,
+        "router_z_loss": router_z_loss,
     }

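For intuition on the two new terms: the load-balancing loss equals exactly 1.0 when assignments and router probabilities are both uniform and grows as routing collapses onto fewer experts, while the z-loss penalizes the squared logsumexp of the router logits so their overall magnitude stays small even when relative preferences are unchanged. A quick numeric check with invented values:

```python
import jax.numpy as jnp
import jax.scipy as jsp

E = 4
# Load-balancing term E * sum_i(f_i * p_i): 1.0 at perfect balance,
# larger as routing collapses onto fewer experts.
uniform = jnp.full((E,), 1.0 / E)
print(E * jnp.sum(uniform * uniform))  # 1.0
skewed = jnp.array([0.7, 0.1, 0.1, 0.1])
print(E * jnp.sum(skewed * skewed))  # 2.08

# Z-loss: mean squared logsumexp over the expert axis. Shifting all logits
# up by a constant leaves the softmax unchanged but inflates the z-loss.
small = jnp.array([[1.0, 0.0, 0.5, -0.5]])
print(jnp.mean(jsp.special.logsumexp(small, axis=-1) ** 2))  # ~3.2
print(jnp.mean(jsp.special.logsumexp(small + 5.0, axis=-1) ** 2))  # ~46, much larger
```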
 def _summarize_router_metrics(router_metrics: dict[str, jax.Array]) -> dict[str, jax.Array | Histogram]:
     routing_entropy = router_metrics["routing_entropy_per_layer"]
     routing_counts = router_metrics["routing_counts_per_layer"]
+    load_balancing_loss = router_metrics["load_balancing_loss_per_layer"]
+    router_z_loss = router_metrics["router_z_loss_per_layer"]
     num_layers = int(routing_entropy.shape[0])
+    aux_loss_per_layer = load_balancing_loss + router_z_loss

     out: dict[str, jax.Array | Histogram] = {
         "train/router/routing_entropy_mean": jnp.mean(routing_entropy),
+        # Match MaxText + Megatron/Nemotron practice: log layer-mean raw
+        # router terms for comparability across depth.
+        "train/router/load_balancing_loss": jnp.mean(load_balancing_loss),
+        "train/router/router_z_loss": jnp.mean(router_z_loss),
+        # Keep aux loss as a per-step aggregate while exposing mean terms above.
+        "train/router/aux_loss": jnp.sum(aux_loss_per_layer),
     }
     for i in range(num_layers):
         out[f"train/router/layer_{i}/routing_entropy"] = routing_entropy[i]
+        out[f"train/router/layer_{i}/load_balancing_loss"] = load_balancing_loss[i]
+        out[f"train/router/layer_{i}/router_z_loss"] = router_z_loss[i]
         out[f"train/router/layer_{i}/routing_hist"] = _histogram_from_expert_counts(routing_counts[i])
     return out

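A toy illustration of the aggregation choices in `_summarize_router_metrics`: the raw terms are logged as layer means, while `train/router/aux_loss` is the unweighted per-step sum across layers. All numbers here are fabricated:

```python
import jax.numpy as jnp

lb = jnp.array([1.10, 0.90])  # per-layer load_balancing_loss (made up)
zl = jnp.array([0.02, 0.04])  # per-layer router_z_loss (made up)
print(jnp.mean(lb))  # 1.0  -> train/router/load_balancing_loss
print(jnp.mean(zl))  # 0.03 -> train/router/router_z_loss
print(jnp.sum(lb + zl))  # 2.06 -> train/router/aux_loss
```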
@@ -266,9 +300,16 @@ def __call__(
         b, s, _ = x.shape
         x_flat = rearrange(x, "b s d -> (b s) d")
         router_logits = jnp.einsum("td,de->te", x_flat, reshard(self.router, P(None, None)))
+        router_probs = jax.nn.softmax(router_logits, axis=-1)
         topk_logits, selected_experts = jax.lax.top_k(router_logits, self.cfg.num_experts_per_token)
         combine_weights = jax.nn.softmax(topk_logits, axis=-1).astype(x.dtype)
-        router_stats = _routing_stats_from_selected_experts(selected_experts, num_experts=self.cfg.num_experts)
+        router_stats = _routing_stats(
+            selected_experts,
+            router_probs,
+            router_logits,
+            num_experts=self.cfg.num_experts,
+            num_experts_per_token=self.cfg.num_experts_per_token,
+        )

         routed_flat = moe_mlp(
             x_flat,
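Note the two softmaxes do different jobs: `router_probs` is a full softmax over all experts and feeds the load-balancing loss, while `combine_weights` renormalizes only the top-k logits that actually mix expert outputs. A toy run with invented logits:

```python
import jax
import jax.numpy as jnp

router_logits = jnp.array([[2.0, 0.5, 1.0, -1.0]])  # (T=1, E=4), invented
router_probs = jax.nn.softmax(router_logits, axis=-1)  # over all E experts
topk_logits, selected_experts = jax.lax.top_k(router_logits, 2)  # K=2
combine_weights = jax.nn.softmax(topk_logits, axis=-1)  # over the K picks only
print(selected_experts)  # [[0 2]]
print(combine_weights)  # [[0.731 0.269]], sums to 1 over the selected experts
print(router_probs[0, 0])  # ~0.61, smaller: mass also goes to unpicked experts
```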
@@ -368,6 +409,8 @@ def __call__(
         router_metrics = {
             "routing_entropy_per_layer": jnp.stack([s["routing_entropy"] for s in all_router_stats], axis=0),
             "routing_counts_per_layer": jnp.stack([s["routing_counts"] for s in all_router_stats], axis=0),
+            "load_balancing_loss_per_layer": jnp.stack([s["load_balancing_loss"] for s in all_router_stats], axis=0),
+            "router_z_loss_per_layer": jnp.stack([s["router_z_loss"] for s in all_router_stats], axis=0),
         }
         return self.final_norm(hidden), router_metrics

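Stacking turns the list of per-layer scalar stats into `(num_layers,)` arrays so the summarizer can take means and sums directly. A tiny sketch with fabricated two-layer stats:

```python
import jax.numpy as jnp

all_router_stats = [
    {"load_balancing_loss": jnp.float32(1.1), "router_z_loss": jnp.float32(0.02)},
    {"load_balancing_loss": jnp.float32(0.9), "router_z_loss": jnp.float32(0.04)},
]
lb_per_layer = jnp.stack([s["load_balancing_loss"] for s in all_router_stats], axis=0)
print(lb_per_layer.shape)  # (2,), one scalar per layer
```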
@@ -396,7 +439,7 @@ def next_token_loss(
         labels = jnp.concatenate([token_ids[:, 1:], token_ids[:, :1] * 0], axis=1).astype(jnp.int32)
         loss_weight = loss_weight.astype(loss_dtype)

-        loss = fused_linear_softmax_cross_entropy_loss(
+        cross_entropy_loss = fused_linear_softmax_cross_entropy_loss(
             hidden,
             self.output_proj,
             labels,
@@ -405,8 +448,22 @@ def next_token_loss(
             logsumexp_weight=logsumexp_weight,
             dtype=loss_dtype,
         )
+        # Keep router metrics raw and apply coefficients only at the final
+        # objective composition step (same separation as MaxText/Megatron).
+        load_balancing_loss_coef = (
+            0.0 if self.config.load_balancing_loss_coef is None else self.config.load_balancing_loss_coef
+        )
+        router_z_loss_coef = 0.0 if self.config.router_z_loss_coef is None else self.config.router_z_loss_coef
+        aux_loss = load_balancing_loss_coef * jnp.sum(router_metrics["load_balancing_loss_per_layer"]) + (
+            router_z_loss_coef * jnp.sum(router_metrics["router_z_loss_per_layer"])
+        )
+        include_aux_in_loss = reduction != "none" and (load_balancing_loss_coef != 0.0 or router_z_loss_coef != 0.0)
+        loss = cross_entropy_loss + aux_loss if include_aux_in_loss else cross_entropy_loss
         if return_router_metrics:
-            return loss, _summarize_router_metrics(router_metrics)
+            summarized_metrics = _summarize_router_metrics(router_metrics)
+            summarized_metrics["train/cross_entropy_loss"] = cross_entropy_loss
+            summarized_metrics["train/router/aux_loss_weighted"] = aux_loss
+            return loss, summarized_metrics
         return loss

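End to end, the objective is the cross-entropy plus coefficient-weighted sums of the raw per-layer router terms, skipped when `reduction == "none"` or both coefficients are zero. A worked example with placeholder values:

```python
import jax.numpy as jnp

cross_entropy_loss = jnp.float32(3.2)  # placeholder
lb_per_layer = jnp.array([1.1, 0.9])  # placeholder raw terms
zl_per_layer = jnp.array([0.02, 0.04])
lb_coef, zl_coef = 0.01, 0.001  # the new config defaults
aux_loss = lb_coef * jnp.sum(lb_per_layer) + zl_coef * jnp.sum(zl_per_layer)
print(aux_loss)  # ~0.02006
print(cross_entropy_loss + aux_loss)  # ~3.22006
```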