@@ -263,9 +263,7 @@ def init(cfg: GrugModelConfig, *, key: PRNGKeyArray) -> "MoEMLP":
263263 def __call__ (
264264 self ,
265265 x : Float [Array , "B S D" ],
266- * ,
267- return_router_stats : bool = False ,
268- ) -> Float [Array , "B S D" ] | tuple [Float [Array , "B S D" ], dict [str , jax .Array ]]:
266+ ) -> tuple [Float [Array , "B S D" ], dict [str , jax .Array ]]:
269267 b , s , _ = x .shape
270268 x_flat = rearrange (x , "b s d -> (b s) d" )
271269 router_logits = jnp .einsum ("td,de->te" , x_flat , reshard (self .router , P (None , None )))
@@ -293,11 +291,7 @@ def __call__(
293291 )
294292 out = routed + shared_out
295293
296- if return_router_stats :
297- assert router_stats is not None
298- return out , router_stats
299-
300- return out
294+ return out , router_stats
301295
302296 def _has_shared (self ):
303297 if self .shared_w_up_gate is not None or self .shared_w_down is not None :
@@ -327,17 +321,11 @@ def __call__(
327321 self ,
328322 x : Float [Array , "B S D" ],
329323 mask : AttentionMask | jax .Array ,
330- * ,
331- return_router_stats : bool = False ,
332- ) -> Float [Array , "B S D" ] | tuple [Float [Array , "B S D" ], dict [str , jax .Array ]]:
324+ ) -> tuple [Float [Array , "B S D" ], dict [str , jax .Array ]]:
333325 x = x + self .attn (self .rms_attn (x ), mask )
334- if return_router_stats :
335- mlp_out , router_stats = self .mlp (self .rms_mlp (x ), return_router_stats = True )
336- x = x + mlp_out
337- return x , router_stats
338-
339- x = x + self .mlp (self .rms_mlp (x ))
340- return x
326+ mlp_out , router_stats = self .mlp (self .rms_mlp (x ))
327+ x = x + mlp_out
328+ return x , router_stats
341329
342330
343331class Transformer (eqx .Module ):
@@ -368,29 +356,22 @@ def __call__(
368356 self ,
369357 token_ids : Int [Array , "B S" ],
370358 mask : AttentionMask | jax .Array | None = None ,
371- * ,
372- return_router_stats : bool = False ,
373- ) -> Float [Array , "B S D" ] | tuple [Float [Array , "B S D" ], dict [str , jax .Array ]]:
359+ ) -> tuple [Float [Array , "B S D" ], dict [str , jax .Array ]]:
374360 if mask is None :
375361 mask = AttentionMask .causal ()
376362
377363 batch_spec = _batch_spec ()
378364 hidden = self .token_embed .at [token_ids ].get (out_sharding = batch_spec )
379- if return_router_stats :
380- all_router_stats : list [dict [str , jax .Array ]] = []
381- for block in self .blocks :
382- hidden , router_stats = eqx .filter_checkpoint (block )(hidden , mask , return_router_stats = True )
383- all_router_stats .append (router_stats )
384-
385- router_metrics = {
386- "routing_entropy_per_layer" : jnp .stack ([s ["routing_entropy" ] for s in all_router_stats ], axis = 0 ),
387- "routing_counts_per_layer" : jnp .stack ([s ["routing_counts" ] for s in all_router_stats ], axis = 0 ),
388- }
389- return self .final_norm (hidden ), router_metrics
390-
365+ all_router_stats : list [dict [str , jax .Array ]] = []
391366 for block in self .blocks :
392- hidden = eqx .filter_checkpoint (block )(hidden , mask )
393- return self .final_norm (hidden )
367+ hidden , router_stats = eqx .filter_checkpoint (block )(hidden , mask )
368+ all_router_stats .append (router_stats )
369+
370+ router_metrics = {
371+ "routing_entropy_per_layer" : jnp .stack ([s ["routing_entropy" ] for s in all_router_stats ], axis = 0 ),
372+ "routing_counts_per_layer" : jnp .stack ([s ["routing_counts" ] for s in all_router_stats ], axis = 0 ),
373+ }
374+ return self .final_norm (hidden ), router_metrics
394375
395376 @named_call
396377 def logits (
@@ -399,10 +380,10 @@ def logits(
399380 mask : AttentionMask | jax .Array | None = None ,
400381 ) -> Float [Array , "B S V" ]:
401382 batch_spec = _batch_spec ()
402- hidden = self (token_ids , mask = mask )
383+ hidden , _ = self (token_ids , mask = mask )
403384 return jnp .einsum ("bsh,hd->bsd" , hidden , self .output_proj , out_sharding = batch_spec )
404385
405- def compute_loss (
386+ def next_token_loss (
406387 self ,
407388 token_ids : Int [Array , "B S" ],
408389 loss_weight : Float [Array , "B S" ],
@@ -413,12 +394,7 @@ def compute_loss(
413394 loss_dtype : jnp .dtype = jnp .float32 ,
414395 return_router_metrics : bool = False ,
415396 ) -> jax .Array | tuple [jax .Array , dict [str , jax .Array | Histogram ]]:
416- """Compute next-token cross-entropy loss for a batch."""
417- router_metrics : dict [str , jax .Array ] | None = None
418- if return_router_metrics :
419- hidden , router_metrics = self (token_ids , mask = mask , return_router_stats = True )
420- else :
421- hidden = self (token_ids , mask = mask )
397+ hidden , router_metrics = self (token_ids , mask = mask )
422398 labels = jnp .concatenate ([token_ids [:, 1 :], token_ids [:, :1 ] * 0 ], axis = 1 ).astype (jnp .int32 )
423399 loss_weight = loss_weight .astype (loss_dtype )
424400
@@ -432,31 +408,9 @@ def compute_loss(
432408 dtype = loss_dtype ,
433409 )
434410 if return_router_metrics :
435- assert router_metrics is not None
436411 return loss , _summarize_router_metrics (router_metrics )
437412 return loss
438413
439- def next_token_loss (
440- self ,
441- token_ids : Int [Array , "B S" ],
442- loss_weight : Float [Array , "B S" ],
443- * ,
444- mask : AttentionMask | jax .Array | None = None ,
445- reduction : str = "mean" ,
446- logsumexp_weight : float | None = None ,
447- loss_dtype : jnp .dtype = jnp .float32 ,
448- return_router_metrics : bool = False ,
449- ) -> jax .Array | tuple [jax .Array , dict [str , jax .Array | Histogram ]]:
450- return self .compute_loss (
451- token_ids ,
452- loss_weight ,
453- mask = mask ,
454- reduction = reduction ,
455- logsumexp_weight = logsumexp_weight ,
456- loss_dtype = loss_dtype ,
457- return_router_metrics = return_router_metrics ,
458- )
459-
460414
461415def _init_weight (key : PRNGKeyArray , shape : tuple [int , ...], std : float ) -> Float [Array , "..." ]:
462416 return std * random .truncated_normal (key , - 3 , 3 , shape )