
Commit 69e090b

[evals] Merge labeled eval API updates into trace scoring

# Conflicts:
#	lib/marin/src/marin/evaluation/trace_labeled_eval.py

2 parents: 09ea842 + 57f9013

4 files changed: 76 additions & 46 deletions


lib/levanter/src/levanter/data/text/examples.py

Lines changed: 12 additions & 4 deletions
@@ -22,7 +22,13 @@
 
 @dataclass(frozen=True)
 class LossLabelSpan:
-    """An exclusive label span over next-token loss positions."""
+    """An exclusive typed span used to evaluate per-span-type LM loss.
+
+    The span is over next-token loss positions, not raw token positions:
+    `[start, end)` labels losses for predicting `tokens[start + 1]` through
+    `tokens[end]`. Spans are intentionally exclusive so each loss position has
+    one semantic type before aggregate metrics roll those types up.
+    """
 
     start: int
     end: int
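The off-by-one in this convention is easy to trip over, so here is a minimal illustration of the rule the new docstring states; the token strings and span values are hypothetical:

```python
# Hypothetical data illustrating the docstring's convention: a span
# [start, end) over loss positions labels the losses for predicting
# tokens[start + 1] through tokens[end].
tokens = ["<bos>", "The", "cat", "sat", "<eos>"]
start, end = 1, 3  # exclusive span over loss positions

# Loss position i scores the prediction of tokens[i + 1]:
predicted = [tokens[i + 1] for i in range(start, end)]
assert predicted == ["cat", "sat"]
```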
@@ -143,10 +149,12 @@ def causal_loss_mask(seq_len: int, prompt_length: NamedOrNumeric | None = None)
 @register_dataclass
 @dataclass(frozen=True)
 class LabeledLmExample:
-    """A grug-conformant LM example with exclusive per-loss-position labels.
+    """A grug-conformant LM example with exclusive labels for loss evaluation.
 
-    `loss_labels[i]` labels the loss for predicting `tokens[i + 1]`. The final
-    position should normally use `LOSS_IGNORE_LABEL`, because it has no next
+    Use this when an eval needs to report loss by token or span type, such as
+    assistant text, tool calls, observations, or derived answer spans.
+    `loss_labels[i]` labels the loss for predicting `tokens[i + 1]`; the final
+    position should normally use `LOSS_IGNORE_LABEL` because it has no next
     token to predict.
     """
 
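A sketch of what `loss_labels` looks like under this contract; the label ids and token ids below are made up, and only the indexing rule comes from the docstring:

```python
IGNORE, ASSISTANT = 0, 1  # hypothetical label ids

tokens = [11, 12, 13, 14]
# loss_labels[i] labels the loss for predicting tokens[i + 1], so the
# final position is ignored: there is no tokens[4] to predict.
loss_labels = [IGNORE, ASSISTANT, ASSISTANT, IGNORE]
assert len(loss_labels) == len(tokens)
```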
lib/levanter/src/levanter/eval.py

Lines changed: 59 additions & 40 deletions
@@ -57,50 +57,57 @@
 
 @dataclasses.dataclass(frozen=True)
 class LossLabelSpec:
-    """Names exclusive per-loss-position labels and the aggregates to score."""
+    """Names exclusive loss labels and defines metric rollups.
+
+    `id_to_name` names the leaf span types stored in `LabeledLmExample.loss_labels`.
+    `aggregates` maps metric names to one or more leaf label ids, so callers can
+    report both specific span types and rollups such as assistant = assistant
+    text plus assistant tool calls. If aggregates is omitted, each non-ignored
+    label id gets its own metric.
+    """
 
     id_to_name: Mapping[int, str]
-    aggregates: Mapping[str, Sequence[int]] = dataclasses.field(default_factory=dict)
+    aggregates: Mapping[str, Sequence[int]] | None = None
     dont_score_label: int = LOSS_IGNORE_LABEL
 
     def __post_init__(self):
-        id_to_name = {int(label_id): str(name) for label_id, name in self.id_to_name.items()}
-        dont_score_label = int(self.dont_score_label)
-        if dont_score_label not in id_to_name:
-            raise ValueError(f"id_to_name must include dont_score_label={dont_score_label}")
-        if len(set(id_to_name.values())) != len(id_to_name):
+        for label_id, name in self.id_to_name.items():
+            if not isinstance(label_id, int):
+                raise TypeError(f"label id must be an int, got {label_id!r}")
+            if not isinstance(name, str):
+                raise TypeError(f"label name for id {label_id} must be a str, got {name!r}")
+        if len(set(self.id_to_name.values())) != len(self.id_to_name):
             raise ValueError("label names must be unique")
 
-        if self.aggregates:
-            aggregates = {
-                str(name): tuple(int(label_id) for label_id in label_ids)
-                for name, label_ids in self.aggregates.items()
-            }
-        else:
-            aggregates = {
-                label_name: (label_id,) for label_id, label_name in id_to_name.items() if label_id != dont_score_label
-            }
-
-        for name, label_ids in aggregates.items():
+        for name, label_ids in self._aggregate_mapping().items():
+            if not isinstance(name, str):
+                raise TypeError(f"aggregate name must be a str, got {name!r}")
             if not label_ids:
                 raise ValueError(f"aggregate {name!r} must include at least one label id")
-            if dont_score_label in label_ids:
-                raise ValueError(f"aggregate {name!r} includes dont_score_label={dont_score_label}")
+            if self.dont_score_label in label_ids:
+                raise ValueError(f"aggregate {name!r} includes dont_score_label={self.dont_score_label}")
             for label_id in label_ids:
-                if label_id not in id_to_name:
+                if not isinstance(label_id, int):
+                    raise TypeError(f"aggregate {name!r} label id must be an int, got {label_id!r}")
+                if label_id not in self.id_to_name:
                     raise ValueError(f"aggregate {name!r} references unknown label id {label_id}")
 
-        object.__setattr__(self, "id_to_name", id_to_name)
-        object.__setattr__(self, "aggregates", aggregates)
-        object.__setattr__(self, "dont_score_label", dont_score_label)
+    def _aggregate_mapping(self) -> Mapping[str, Sequence[int]]:
+        if self.aggregates is not None:
+            return self.aggregates
+        return {
+            label_name: (label_id,)
+            for label_id, label_name in self.id_to_name.items()
+            if label_id != self.dont_score_label
+        }
 
     @property
     def aggregate_names(self) -> tuple[str, ...]:
-        return tuple(self.aggregates.keys())
+        return tuple(self._aggregate_mapping().keys())
 
     @property
     def aggregate_label_ids(self) -> tuple[tuple[int, ...], ...]:
-        return tuple(tuple(label_ids) for label_ids in self.aggregates.values())
+        return tuple(tuple(label_ids) for label_ids in self._aggregate_mapping().values())
 
 
 @dataclasses.dataclass
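A sketch of the rollup behavior the new docstring and `_aggregate_mapping` describe. The import path is assumed from this file's location, the label ids are hypothetical, and it assumes `LOSS_IGNORE_LABEL` is neither 1 nor 2 (the updated test further down uses 0 for the ignore label):

```python
from levanter.eval import LossLabelSpec  # assumed import path

# Explicit aggregates: "assistant" rolls two leaf span types into one metric.
spec = LossLabelSpec(
    id_to_name={1: "assistant_text", 2: "assistant_tool_call"},
    aggregates={"assistant": (1, 2), "assistant_text": (1,)},
)
assert spec.aggregate_names == ("assistant", "assistant_text")

# aggregates omitted: each non-ignored leaf label becomes its own metric.
default_spec = LossLabelSpec(id_to_name={1: "assistant_text", 2: "assistant_tool_call"})
assert default_spec.aggregate_names == ("assistant_text", "assistant_tool_call")
```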
@@ -275,18 +282,18 @@ def _default_lm_eval_loss_fn(
 def _ensure_named_labeled_lm_example(
     batch: LabeledLmExample,
     *,
-    EvalBatch: hax.Axis,
-    model_pos: hax.Axis,
+    batch_axis_name: str,
+    pos_axis_name: str,
 ) -> tuple[LmExample, hax.NamedArray]:
     if not isinstance(batch, LabeledLmExample):
         raise TypeError(f"Unsupported labeled eval batch type: {type(batch)}")
 
     if batch.tokens.ndim == 1:
-        Pos = model_pos.resize(batch.tokens.shape[0])
+        Pos = hax.Axis(pos_axis_name, batch.tokens.shape[0])
         return named_lm_example_from_labeled(batch, Pos=Pos)
     if batch.tokens.ndim == 2:
-        Pos = model_pos.resize(batch.tokens.shape[1])
-        return named_lm_example_from_labeled(batch, Pos=Pos, batch_axis=EvalBatch)
+        Pos = hax.Axis(pos_axis_name, batch.tokens.shape[1])
+        return named_lm_example_from_labeled(batch, Pos=Pos, batch_axis=batch_axis_name)
 
     raise ValueError(f"LabeledLmExample tokens must be rank-1 or rank-2 for eval, got rank={batch.tokens.ndim}")
 
@@ -299,7 +306,11 @@ def _default_labeled_lm_eval_loss_fn(
     mp: jmp.Policy | None,
 ) -> LabeledLossFnOutput:
     model = inference_mode(model, True)
-    named_batch, loss_labels = _ensure_named_labeled_lm_example(batch, EvalBatch=EvalBatch, model_pos=model.Pos)
+    named_batch, loss_labels = _ensure_named_labeled_lm_example(
+        batch,
+        batch_axis_name=EvalBatch.name,
+        pos_axis_name=model.Pos.name,
+    )
     if mp is not None:
         model = mp.cast_to_compute(model)
     per_pos_loss = model.compute_next_token_loss(named_batch, reduction=None, reduction_axis=()).array
@@ -450,11 +461,11 @@ def cb_labeled_evaluate(
     *,
     prefix: str = "labeled_eval",
     eval_current: bool = True,
-    eval_ema: bool = True,
+    eval_model: bool = True,
 ) -> Callable[[StepInfo], None]:
-    """Build a callback that logs labeled eval metrics for current and/or eval model."""
-    if not eval_current and not eval_ema:
-        raise ValueError("At least one of eval_current or eval_ema should be True")
+    """Build a callback that logs labeled eval metrics for current and/or eval-mode model."""
+    if not eval_current and not eval_model:
+        raise ValueError("At least one of eval_current or eval_model should be True")
 
     last_eval_step: int | None = None
 
@@ -472,8 +483,8 @@ def eval_callback(step: StepInfo, force: bool = False):
         log_dict = eval_labeled_model(evaluator, step.model, prefix=prefix)
         levanter.tracker.log(log_dict, step=step_count)
 
-        if eval_ema:
-            log_dict = eval_labeled_model(evaluator, step.eval_model, prefix=_join_prefix(prefix, "ema"))
+        if eval_model:
+            log_dict = eval_labeled_model(evaluator, step.eval_model, prefix=_join_prefix(prefix, "eval_model"))
             levanter.tracker.log(log_dict, step=step_count)
 
         last_eval_step = step_count
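A usage sketch for the renamed flag. The positional parameters of `cb_labeled_evaluate` fall outside these hunks, so the `evaluator` argument below is an assumption, as is the exact joined prefix:

```python
# Hypothetical call: the positional parameters before `*` are not shown
# in this diff, so `evaluator` is an assumed first argument.
callback = cb_labeled_evaluate(
    evaluator,
    prefix="labeled_eval",
    eval_current=True,  # score step.model under "labeled_eval"
    eval_model=True,    # also score step.eval_model under the joined "eval_model" prefix
)

# Passing eval_current=False and eval_model=False raises:
#   ValueError: At least one of eval_current or eval_model should be True
```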
@@ -724,6 +735,14 @@ def _construct_tag_hierarchy(self) -> dict[str, list[int]]:
 
 
 class LabeledEvaluator(Generic[Ex, M]):
+    """Evaluator that aggregates LM loss over exclusive token-label groups.
+
+    The loss callback returns per-position losses, exclusive integer labels, and
+    next-token ids. `LossLabelSpec` then rolls leaf labels up into named metrics,
+    so one example can report loss for both fine-grained span types and broader
+    groups without overlapping per-target masks.
+    """
+
     loss_fn: Callable[[M, Ex], LabeledLossFnOutput]
 
     def __init__(
@@ -738,7 +757,7 @@ def __init__(
     ):
         if isinstance(EvalBatch, int):
             EvalBatch = hax.Axis("batch", EvalBatch)
-        if not label_spec.aggregates:
+        if not label_spec.aggregate_names:
             raise ValueError("label_spec must define at least one aggregate to score")
 
         self.loss_fn = loss_fn
@@ -766,7 +785,7 @@ def __init__(
         self.accum_for_batch = self._make_accum_for_batch()
 
     @classmethod
-    def from_labeled_lm(
+    def for_labeled_examples(
         cls,
         EvalBatch: hax.Axis | int,
         eval_set: AsyncDataset[LabeledLmExample],
lib/levanter/tests/test_eval.py

Lines changed: 4 additions & 1 deletion
@@ -256,7 +256,7 @@ def test_labeled_lm_evaluator_accepts_labeled_lm_examples():
     label_spec = LossLabelSpec(id_to_name={0: "dont_score", 1: "assistant"})
 
     with use_test_mesh(tensor_parallelism=1) as mesh:
-        evaluator = LabeledEvaluator.from_labeled_lm(
+        evaluator = LabeledEvaluator.for_labeled_examples(
             EvalBatch=EvalBatch,
             eval_set=ListAsyncDataset(examples),
             label_spec=label_spec,
@@ -271,6 +271,9 @@ def test_labeled_lm_evaluator_accepts_labeled_lm_examples():
 
 
 def test_loss_label_spec_validates_aggregates():
+    label_spec = LossLabelSpec(id_to_name={1: "assistant"})
+    assert label_spec.aggregate_names == ("assistant",)
+
     with np.testing.assert_raises_regex(ValueError, "unknown label id"):
         LossLabelSpec(
             id_to_name={0: "dont_score", 1: "assistant"},

lib/marin/src/marin/evaluation/trace_labeled_eval.py

Lines changed: 1 addition & 1 deletion
@@ -684,7 +684,7 @@ def evaluate_dataset(
         block_cross_document_attention=True,
     )
 
-    evaluator = LabeledEvaluator.from_labeled_lm(
+    evaluator = LabeledEvaluator.for_labeled_examples(
         EvalBatch,
         dataset,
         label_spec=current_dataset_config.trace_format.loss_label_spec(),
