[levanter] Add labeled eval spans #5723

Merged
dlwh merged 2 commits into main from codex/labeled-eval-spans on May 14, 2026

Conversation

dlwh (Member) commented on May 14, 2026

Add an exclusive integer label contract for per-token LM evaluation, plus a LabeledEvaluator that aggregates losses and BPB over named label groups. This separates training loss weights from evaluation annotations and gives future trace-style evals a generic span-label target.

dlwh added the agent-generated label on May 14, 2026
dlwh (Member, Author) commented on May 14, 2026

🤖 Specification

Problem
levanter.eval only had dataset-level tags and per-position loss weights. That made training masks and evaluation annotations share the same channel, and there was no generic way to score named token/span categories such as assistant text, tool calls, or other trace regions. The relevant surfaces are lib/levanter/src/levanter/eval.py and lib/levanter/src/levanter/data/text/examples.py.

Approach
Add LabeledLmExample with exclusive integer loss_labels aligned to next-token loss positions, plus LossLabelSpan and loss_labels_from_spans for non-overlapping span construction. Add LossLabelSpec to map label ids to named aggregates, and LabeledEvaluator to aggregate loss, token counts, and BPB over those aggregate label sets. This keeps loss_weight as the training contract and makes labels the evaluation annotation contract.
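
To make the contract concrete, here is a minimal self-contained sketch of the span-to-label construction. The real loss_labels_from_spans lives in lib/levanter/src/levanter/data/text/examples.py; the signature, the span tuples, and the LOSS_IGNORE_LABEL value below are assumptions for illustration only:

    import numpy as np

    LOSS_IGNORE_LABEL = 0  # assumed: positions that carry no evaluation label

    def loss_labels_from_spans(seq_len: int, spans: list[tuple[int, int, int]]) -> np.ndarray:
        """Build an exclusive integer label array from (start, end, label) spans.

        Spans are half-open [start, end), must be in range and non-empty, and
        may not overlap, so each next-token position gets at most one label.
        """
        labels = np.full(seq_len, LOSS_IGNORE_LABEL, dtype=np.int32)
        for start, end, label in spans:
            if label == LOSS_IGNORE_LABEL:
                raise ValueError("spans must use a non-ignore label")
            if not (0 <= start < end <= seq_len):
                raise ValueError(f"span ({start}, {end}) is empty or out of range")
            if np.any(labels[start:end] != LOSS_IGNORE_LABEL):
                raise ValueError(f"span ({start}, {end}) overlaps an earlier span")
            labels[start:end] = label
        return labels

    # e.g. label 1 = assistant text, label 2 = tool calls
    print(loss_labels_from_spans(8, [(1, 4, 1), (5, 7, 2)]))  # [0 1 1 1 0 2 2 0]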

Key code
LabeledEvaluator precomputes padded aggregate label ids, then in the jitted accumulator compares batch labels against each aggregate, reduces per-aggregate weights, and feeds RunningMean with token-count weighting. The default LM path converts LabeledLmExample to LmExample by scoring nonzero labels while returning labels separately for evaluation aggregation.
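
A rough sketch of that compare-and-reduce step on plain arrays (toy shapes and a hypothetical helper; the real accumulator also threads RunningMean state, sharding, and bytes-per-token weights through hax.named_jit):

    import jax.numpy as jnp

    # Padded aggregate table: one row per named aggregate, -1 pads ragged rows,
    # e.g. "assistant" = labels {1, 3}, "tool_calls" = label {2}.
    aggregate_label_ids = jnp.array([[1, 3], [2, -1]])
    valid_label_ids = aggregate_label_ids >= 0  # mask out the -1 padding

    def per_aggregate_sums(losses, labels):
        """losses, labels: (batch, pos) -> per-aggregate (sum_loss, n_tokens)."""
        # (n_agg, n_ids, batch, pos): does each position's label hit each aggregate?
        hits = labels[None, None, :, :] == aggregate_label_ids[:, :, None, None]
        mask = (hits & valid_label_ids[:, :, None, None]).any(axis=1)
        sum_loss = (mask * losses[None, :, :]).sum(axis=(1, 2))
        n_tokens = mask.sum(axis=(1, 2))
        return sum_loss, n_tokens

    losses = jnp.array([[0.5, 1.0, 2.0]])
    labels = jnp.array([[1, 2, 0]])
    print(per_aggregate_sums(losses, labels))  # assistant: (0.5, 1); tool_calls: (1.0, 1)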

Tests
lib/levanter/tests/test_eval.py covers aggregate math over exclusive labels, default labeled LM model evaluation, LossLabelSpec validation, and overlapping span rejection. The branch also passed ./infra/pre-commit.py --all-files --fix and the focused eval test module.
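
As a flavor of the overlap test, against the loss_labels_from_spans sketch above (hypothetical test body; the real assertions in test_eval.py may use np.testing.assert_raises_regex instead of pytest.raises):

    import pytest

    def test_loss_labels_from_spans_rejects_overlapping_spans():
        # Spans [1, 4) and [3, 6) share position 3; construction must fail loudly
        # rather than letting one label silently clobber the other.
        with pytest.raises(ValueError, match="overlaps"):
            loss_labels_from_spans(8, [(1, 4, 1), (3, 6, 2)])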

Comment thread lib/levanter/src/levanter/eval.py Outdated
    Pos = model_pos.resize(batch.tokens.shape[0])
    return named_lm_example_from_labeled(batch, Pos=Pos)
if batch.tokens.ndim == 2:
    Pos = model_pos.resize(batch.tokens.shape[1])
dlwh (Member, Author):

might be easiest to just pass axis names instead?
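
For context, the two options side by side (pos_axis_for is hypothetical; the snippet above uses the resize pattern):

    import haliax as hax

    # Current pattern: thread a template Axis through and resize it per batch.
    model_pos = hax.Axis("position", 1024)
    Pos = model_pos.resize(512)  # Axis(name="position", size=512)

    # Suggested alternative: pass only the axis name and build the Axis
    # from the batch shape at the point of use.
    def pos_axis_for(tokens, pos_name: str = "position") -> hax.Axis:
        return hax.Axis(pos_name, tokens.shape[-1])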

Comment thread lib/levanter/src/levanter/eval.py Outdated
    eval_current: bool = True,
    eval_ema: bool = True,
) -> Callable[[StepInfo], None]:
    """Build a callback that logs labeled eval metrics for current and/or eval model."""
dlwh (Member, Author):

ema model is different from eval model


Comment thread lib/levanter/src/levanter/eval.py
for i in range(vocab_size):
    bytes[i] = byte_length_of_token(tokenizer, i)

class LabeledEvaluator(Generic[Ex, M]):
dlwh (Member, Author):

needs a doc comment

Comment thread lib/levanter/src/levanter/eval.py Outdated
        self.accum_for_batch = self._make_accum_for_batch()

    @classmethod
    def from_labeled_lm(
dlwh (Member, Author):

function name weird. there is no labeled lm right?

dlwh marked this pull request as ready for review on May 14, 2026 at 23:49
claude (bot, Contributor) commented on May 14, 2026

Claude finished @dlwh's task in 8m 52s.


Code review

No high-signal issues found. Checked the diff for bugs and AGENTS.md/CLAUDE.md compliance.

What I checked

  • Aggregate / BPB math in LabeledEvaluator._make_accum_for_batch: broadcast-and-mask
    over padded aggregate_label_ids with valid_label_ids = aggregate_label_ids >= 0
    correctly excludes the -1 padding, safe_mean + token-count weighting is sound, and
    the BPB formula sum_loss / max(bytes, 1) * log2(e) weighted by bytes_per_label
    aggregates to the correct overall bits-per-byte across batches (see the worked
    example after this list). The reviewed excerpt:

        self.device_mesh = device_mesh
        self.tokenizer = tokenizer
        self.axis_mapping = axis_mapping
        self.aggregate_names = label_spec.aggregate_names
        self.aggregate_label_ids = self._padded_aggregate_label_ids(label_spec)
        self.per_pos_out_sharding = None
        if device_mesh is not None and axis_mapping is not None:
            batch_axis_resource = axis_mapping.get(EvalBatch.name, axis_mapping.get("batch"))
            if batch_axis_resource is not None and axis_resource_is_explicit(device_mesh, batch_axis_resource):
                self.per_pos_out_sharding = NamedSharding(device_mesh, P(batch_axis_resource, None))
        self.bytes_per_token = _calculate_bytes_per_token_type(tokenizer)
        self.accum_for_batch = self._make_accum_for_batch()

    @classmethod
    def for_labeled_examples(
        cls,
        EvalBatch: hax.Axis | int,
        eval_set: AsyncDataset[LabeledLmExample],
        label_spec: LossLabelSpec,
        *,
        tokenizer: Optional[MarinTokenizer] = None,
        device_mesh=None,
        axis_mapping=None,
        mp: jmp.Policy = None,
    ) -> "LabeledEvaluator[LabeledLmExample, LmHeadModel]":
        if isinstance(EvalBatch, int):
            EvalBatch = hax.Axis("batch", EvalBatch)
        resolved_eval_batch = EvalBatch

        def loss_fn(model: LmHeadModel, batch: LabeledLmExample) -> LabeledLossFnOutput:
            return _default_labeled_lm_eval_loss_fn(model, batch, EvalBatch=resolved_eval_batch, mp=mp)

        return cls(
            EvalBatch=resolved_eval_batch,
            eval_set=eval_set,
            label_spec=label_spec,
            loss_fn=loss_fn,
            tokenizer=tokenizer,
            device_mesh=device_mesh,
            axis_mapping=axis_mapping,
        )

    def _make_accum_for_batch(self) -> Callable[[M, "_LabeledEvalRunningMeans", Ex], "_LabeledEvalRunningMeans"]:
        bytes_per_token = self.bytes_per_token
        aggregate_label_ids = self.aggregate_label_ids
        valid_label_ids = aggregate_label_ids >= 0
        log2e = jnp.log2(jnp.e)
        per_pos_out_sharding = self.per_pos_out_sharding

        @hax.named_jit(axis_resources=self.axis_mapping)
        def accum_for_batch(model: M, state: _LabeledEvalRunningMeans, batch: Ex):
            losses, labels, token_ids = self.loss_fn(model, batch)
  • _LabeledEvalRunningMeans.zeros_like returns the same RunningMean instance for both
    fields. Fine in practice: RunningMean is a frozen eqx.Module and the fields diverge
    through dataclasses.replace. The reviewed excerpt:

        @staticmethod
        def zeros_like(per_label: Float[Array, "label"]) -> "_LabeledEvalRunningMeans":
            per_label_mean = RunningMean.zeros_like(per_label)
            return _LabeledEvalRunningMeans(per_label_mean, per_label_mean)
  • loss_labels_from_spans rejects negative starts, out-of-range ends, empty/inverted spans, and overlaps before writing labels — covered by test_loss_labels_from_spans_rejects_overlapping_spans.
  • LossLabelSpec.__post_init__ rejects aggregates that name unknown ids or include dont_score_label; covered by test_loss_label_spec_validates_aggregates.
  • Default labeled LM path: named_lm_example_from_labeled builds loss_weight from loss_labels != LOSS_IGNORE_LABEL (or from scored_labels when set), so positions with ignored labels do not contribute loss; aggregate validation guarantees no aggregate references the ignore label, so that masking is consistent with eval aggregation.
  • Tests use np.testing.assert_raises_regex as a context manager — supported (it delegates to unittest.TestCase.assertRaisesRegex).
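
The worked example promised in the first bullet: the per-aggregate BPB arithmetic on toy numbers (a sketch of the formula only, not the eval.py implementation):

    import math

    # Per-aggregate totals accumulated across batches (toy numbers):
    sum_loss_nats = 12.0  # summed next-token loss over the aggregate's positions
    sum_bytes = 40        # summed byte lengths of the scored tokens

    # Bits-per-byte: convert nats to bits with log2(e) and normalize by bytes;
    # max(bytes, 1) guards empty aggregates against division by zero.
    bpb = sum_loss_nats / max(sum_bytes, 1) * math.log2(math.e)
    print(f"{bpb:.4f}")  # ≈ 0.4328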

Reviewer comments already on the PR
Your four self-review comments cover the items I'd otherwise raise (EvalBatch axis vs name; _ensure_named_labeled_lm_example naming / there is no "labeled lm"; eval_model ≠ ema model in cb_labeled_evaluate; doc-comment at the LabeledEvaluator site). No need to duplicate.
Branch: codex/labeled-eval-spans

Helw150 (Member) left a comment

Thank you! This is very clean

dlwh merged commit b6fbb13 into main on May 14, 2026
32 of 33 checks passed
dlwh deleted the codex/labeled-eval-spans branch on May 14, 2026 at 23:58