5757
@dataclasses.dataclass(frozen=True)
class LossLabelSpec:
    """Names exclusive loss labels and defines metric rollups.

    `id_to_name` names the leaf span types stored in `LabeledLmExample.loss_labels`.
    `aggregates` maps metric names to one or more leaf label ids, so callers can
    report both specific span types and rollups such as assistant = assistant
    text plus assistant tool calls. If aggregates is omitted, each non-ignored
    label id gets its own metric.

    Raises:
        TypeError: If a label id (in `id_to_name` or inside an aggregate) is not
            an int, or a label name or aggregate name is not a str.
        ValueError: If label names are not unique, or an aggregate is empty,
            contains `dont_score_label`, or references an id that is missing
            from `id_to_name`.
    """

    id_to_name: Mapping[int, str]
    # None (the default) means "derive one single-label aggregate per label";
    # an explicitly passed mapping is used as-is, even when it is empty.
    aggregates: Mapping[str, Sequence[int]] | None = None
    # Label id left out of the derived aggregates and rejected in explicit ones.
    dont_score_label: int = LOSS_IGNORE_LABEL

    def __post_init__(self):
        """Validate the spec; fields are checked but never coerced or rewritten."""
        for label_id, name in self.id_to_name.items():
            if not isinstance(label_id, int):
                raise TypeError(f"label id must be an int, got {label_id!r}")
            if not isinstance(name, str):
                raise TypeError(f"label name for id {label_id} must be a str, got {name!r}")
        if len(set(self.id_to_name.values())) != len(self.id_to_name):
            raise ValueError("label names must be unique")

        # Validate the effective mapping so the derived defaults pass through
        # the same checks as an explicitly supplied `aggregates`.
        for name, label_ids in self._aggregate_mapping().items():
            if not isinstance(name, str):
                raise TypeError(f"aggregate name must be a str, got {name!r}")
            if not label_ids:
                raise ValueError(f"aggregate {name!r} must include at least one label id")
            if self.dont_score_label in label_ids:
                raise ValueError(f"aggregate {name!r} includes dont_score_label={self.dont_score_label}")
            for label_id in label_ids:
                if not isinstance(label_id, int):
                    raise TypeError(f"aggregate {name!r} label id must be an int, got {label_id!r}")
                if label_id not in self.id_to_name:
                    raise ValueError(f"aggregate {name!r} references unknown label id {label_id}")

    def _aggregate_mapping(self) -> Mapping[str, Sequence[int]]:
        """Return the explicit aggregates, or one aggregate per scored label.

        The fallback is rebuilt on every call rather than stored on the
        instance, so the `aggregates` field keeps the value the caller passed.
        """
        if self.aggregates is not None:
            return self.aggregates
        return {
            label_name: (label_id,)
            for label_id, label_name in self.id_to_name.items()
            if label_id != self.dont_score_label
        }

    @property
    def aggregate_names(self) -> tuple[str, ...]:
        """Metric names, in the iteration order of the effective aggregate mapping."""
        return tuple(self._aggregate_mapping().keys())

    @property
    def aggregate_label_ids(self) -> tuple[tuple[int, ...], ...]:
        """Label-id tuples, in the same order as `aggregate_names`."""
        return tuple(tuple(label_ids) for label_ids in self._aggregate_mapping().values())
104111
105112
106113@dataclasses .dataclass
@@ -275,18 +282,18 @@ def _default_lm_eval_loss_fn(
def _ensure_named_labeled_lm_example(
    batch: LabeledLmExample,
    *,
    batch_axis_name: str,
    pos_axis_name: str,
) -> tuple[LmExample, hax.NamedArray]:
    """Convert a rank-1 or rank-2 `LabeledLmExample` via `named_lm_example_from_labeled`.

    The position axis is built from `pos_axis_name` and the size of the last
    dimension of `batch.tokens`. For rank-2 tokens, `batch_axis_name` is also
    forwarded as `batch_axis`.

    Args:
        batch: The labeled example to convert.
        batch_axis_name: Name passed as `batch_axis` when tokens are rank-2.
        pos_axis_name: Name of the position axis to construct.

    Returns:
        The result of `named_lm_example_from_labeled`, unchanged.

    Raises:
        TypeError: If `batch` is not a `LabeledLmExample`.
        ValueError: If `batch.tokens` is neither rank-1 nor rank-2.
    """
    if not isinstance(batch, LabeledLmExample):
        raise TypeError(f"Unsupported labeled eval batch type: {type(batch)}")

    rank = batch.tokens.ndim
    if rank not in (1, 2):
        raise ValueError(f"LabeledLmExample tokens must be rank-1 or rank-2 for eval, got rank={rank}")

    # In both supported ranks the position dimension is the last one
    # (shape[0] for rank-1, shape[1] for rank-2).
    Pos = hax.Axis(pos_axis_name, batch.tokens.shape[-1])
    if rank == 1:
        return named_lm_example_from_labeled(batch, Pos=Pos)
    return named_lm_example_from_labeled(batch, Pos=Pos, batch_axis=batch_axis_name)
292299
@@ -299,7 +306,11 @@ def _default_labeled_lm_eval_loss_fn(
299306 mp : jmp .Policy | None ,
300307) -> LabeledLossFnOutput :
301308 model = inference_mode (model , True )
302- named_batch , loss_labels = _ensure_named_labeled_lm_example (batch , EvalBatch = EvalBatch , model_pos = model .Pos )
309+ named_batch , loss_labels = _ensure_named_labeled_lm_example (
310+ batch ,
311+ batch_axis_name = EvalBatch .name ,
312+ pos_axis_name = model .Pos .name ,
313+ )
303314 if mp is not None :
304315 model = mp .cast_to_compute (model )
305316 per_pos_loss = model .compute_next_token_loss (named_batch , reduction = None , reduction_axis = ()).array
@@ -450,11 +461,11 @@ def cb_labeled_evaluate(
450461 * ,
451462 prefix : str = "labeled_eval" ,
452463 eval_current : bool = True ,
453- eval_ema : bool = True ,
464+ eval_model : bool = True ,
454465) -> Callable [[StepInfo ], None ]:
455- """Build a callback that logs labeled eval metrics for current and/or eval model."""
456- if not eval_current and not eval_ema :
457- raise ValueError ("At least one of eval_current or eval_ema should be True" )
466+ """Build a callback that logs labeled eval metrics for current and/or eval-mode model."""
467+ if not eval_current and not eval_model :
468+ raise ValueError ("At least one of eval_current or eval_model should be True" )
458469
459470 last_eval_step : int | None = None
460471
@@ -472,8 +483,8 @@ def eval_callback(step: StepInfo, force: bool = False):
472483 log_dict = eval_labeled_model (evaluator , step .model , prefix = prefix )
473484 levanter .tracker .log (log_dict , step = step_count )
474485
475- if eval_ema :
476- log_dict = eval_labeled_model (evaluator , step .eval_model , prefix = _join_prefix (prefix , "ema " ))
486+ if eval_model :
487+ log_dict = eval_labeled_model (evaluator , step .eval_model , prefix = _join_prefix (prefix , "eval_model " ))
477488 levanter .tracker .log (log_dict , step = step_count )
478489
479490 last_eval_step = step_count
@@ -724,6 +735,14 @@ def _construct_tag_hierarchy(self) -> dict[str, list[int]]:
724735
725736
726737class LabeledEvaluator (Generic [Ex , M ]):
738+ """Evaluator that aggregates LM loss over exclusive token-label groups.
739+
740+ The loss callback returns per-position losses, exclusive integer labels, and
741+ next-token ids. `LossLabelSpec` then rolls leaf labels up into named metrics,
742+ so one example can report loss for both fine-grained span types and broader
743+ groups without overlapping per-target masks.
744+ """
745+
727746 loss_fn : Callable [[M , Ex ], LabeledLossFnOutput ]
728747
729748 def __init__ (
@@ -738,7 +757,7 @@ def __init__(
738757 ):
739758 if isinstance (EvalBatch , int ):
740759 EvalBatch = hax .Axis ("batch" , EvalBatch )
741- if not label_spec .aggregates :
760+ if not label_spec .aggregate_names :
742761 raise ValueError ("label_spec must define at least one aggregate to score" )
743762
744763 self .loss_fn = loss_fn
@@ -766,7 +785,7 @@ def __init__(
766785 self .accum_for_batch = self ._make_accum_for_batch ()
767786
768787 @classmethod
769- def from_labeled_lm (
788+ def for_labeled_examples (
770789 cls ,
771790 EvalBatch : hax .Axis | int ,
772791 eval_set : AsyncDataset [LabeledLmExample ],
0 commit comments