Skip to content

Commit bda192c

Browse files
committed
[evals] Break trace label import cycle
1 parent ebd035e commit bda192c

4 files changed

Lines changed: 60 additions & 58 deletions

File tree

lib/levanter/src/levanter/data/text/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
GrugLmExample,
3131
LabeledLmExample,
3232
LOSS_IGNORE_LABEL,
33+
LossLabelSpec,
3334
LossLabelSpan,
3435
grug_attention_mask_from_named,
3536
grug_lm_example_from_named,
@@ -88,6 +89,7 @@
8889
"GrugLmExample",
8990
"LabeledLmExample",
9091
"LOSS_IGNORE_LABEL",
92+
"LossLabelSpec",
9193
"LossLabelSpan",
9294
"grug_attention_mask_from_named",
9395
"grug_lm_example_from_named",

lib/levanter/src/levanter/data/text/examples.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from dataclasses import dataclass
5-
from typing import Sequence
5+
from typing import Mapping, Sequence
66

77
import jax
88
import jax.numpy as jnp
@@ -35,6 +35,61 @@ class LossLabelSpan:
3535
label: int
3636

3737

@dataclass(frozen=True)
class LossLabelSpec:
    """Describe the loss-label vocabulary and how labels roll up into metrics.

    Attributes:
        id_to_name: Name for each leaf label id. Names must be unique.
        aggregates: Optional mapping from a metric name to the leaf label ids
            that contribute to it. When left as ``None``, one metric is
            derived per entry of ``id_to_name`` whose id is not
            ``dont_score_label``, named after that label.
        dont_score_label: Label id that is excluded from the derived metrics
            and may not appear in any aggregate.

    Raises:
        TypeError: If a label id is not an ``int``, or a label name or
            aggregate name is not a ``str``.
        ValueError: If label names repeat, an aggregate is empty, an aggregate
            contains ``dont_score_label``, or an aggregate refers to an id
            that is missing from ``id_to_name``.
    """

    id_to_name: Mapping[int, str]
    aggregates: Mapping[str, Sequence[int]] | None = None
    dont_score_label: int = LOSS_IGNORE_LABEL

    def __post_init__(self):
        known = self.id_to_name

        # Leaf labels first: key/value types, then name uniqueness.
        for leaf_id, leaf_name in known.items():
            if not isinstance(leaf_id, int):
                raise TypeError(f"label id must be an int, got {leaf_id!r}")
            if not isinstance(leaf_name, str):
                raise TypeError(f"label name for id {leaf_id} must be a str, got {leaf_name!r}")
        distinct_names = set(known.values())
        if len(distinct_names) != len(known):
            raise ValueError("label names must be unique")

        # Then every aggregate, whether supplied explicitly or derived.
        ignored = self.dont_score_label
        for metric, members in self._aggregate_mapping().items():
            if not isinstance(metric, str):
                raise TypeError(f"aggregate name must be a str, got {metric!r}")
            if not members:
                raise ValueError(f"aggregate {metric!r} must include at least one label id")
            if ignored in members:
                raise ValueError(f"aggregate {metric!r} includes dont_score_label={self.dont_score_label}")
            for member in members:
                if not isinstance(member, int):
                    raise TypeError(f"aggregate {metric!r} label id must be an int, got {member!r}")
                if member not in known:
                    raise ValueError(f"aggregate {metric!r} references unknown label id {member}")

    def _aggregate_mapping(self) -> Mapping[str, Sequence[int]]:
        """Return the explicit aggregates, or one single-id aggregate per scored label."""
        explicit = self.aggregates
        if explicit is None:
            return {
                leaf_name: (leaf_id,)
                for leaf_id, leaf_name in self.id_to_name.items()
                if leaf_id != self.dont_score_label
            }
        return explicit

    @property
    def aggregate_names(self) -> tuple[str, ...]:
        """Metric names, in the iteration order of the aggregate mapping."""
        return tuple(self._aggregate_mapping())

    @property
    def aggregate_label_ids(self) -> tuple[tuple[int, ...], ...]:
        """Label ids for each metric, in the same order as `aggregate_names`."""
        return tuple(map(tuple, self._aggregate_mapping().values()))


3893
def loss_labels_from_spans(
3994
seq_len: int,
4095
spans: Sequence[LossLabelSpan],

lib/levanter/src/levanter/data/text/trace_chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import numpy as np
1111

1212
from levanter.data._preprocessor import BatchProcessor
13-
from levanter.eval import LossLabelSpec
13+
from levanter.data.text.examples import LossLabelSpec
1414
from levanter.tokenizers import MarinTokenizer
1515

1616

lib/levanter/src/levanter/eval.py

Lines changed: 1 addition & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
from levanter.callbacks import StepInfo
2828
from levanter.data import AsyncDataset, DataLoader
2929
from levanter.data.text.examples import (
30-
LOSS_IGNORE_LABEL,
3130
GrugLmExample,
3231
LabeledLmExample,
32+
LossLabelSpec,
3333
named_lm_example_from_grug,
3434
named_lm_example_from_labeled,
3535
)
@@ -55,61 +55,6 @@
5555
BatchedTagArray = Int[Array, "... tag"]
5656

5757

58-
@dataclasses.dataclass(frozen=True)
59-
class LossLabelSpec:
60-
"""Names exclusive loss labels and defines metric rollups.
61-
62-
`id_to_name` names the leaf span types stored in `LabeledLmExample.loss_labels`.
63-
`aggregates` maps metric names to one or more leaf label ids, so callers can
64-
report both specific span types and rollups such as assistant = assistant
65-
text plus assistant tool calls. If aggregates is omitted, each non-ignored
66-
label id gets its own metric.
67-
"""
68-
69-
id_to_name: Mapping[int, str]
70-
aggregates: Mapping[str, Sequence[int]] | None = None
71-
dont_score_label: int = LOSS_IGNORE_LABEL
72-
73-
def __post_init__(self):
74-
for label_id, name in self.id_to_name.items():
75-
if not isinstance(label_id, int):
76-
raise TypeError(f"label id must be an int, got {label_id!r}")
77-
if not isinstance(name, str):
78-
raise TypeError(f"label name for id {label_id} must be a str, got {name!r}")
79-
if len(set(self.id_to_name.values())) != len(self.id_to_name):
80-
raise ValueError("label names must be unique")
81-
82-
for name, label_ids in self._aggregate_mapping().items():
83-
if not isinstance(name, str):
84-
raise TypeError(f"aggregate name must be a str, got {name!r}")
85-
if not label_ids:
86-
raise ValueError(f"aggregate {name!r} must include at least one label id")
87-
if self.dont_score_label in label_ids:
88-
raise ValueError(f"aggregate {name!r} includes dont_score_label={self.dont_score_label}")
89-
for label_id in label_ids:
90-
if not isinstance(label_id, int):
91-
raise TypeError(f"aggregate {name!r} label id must be an int, got {label_id!r}")
92-
if label_id not in self.id_to_name:
93-
raise ValueError(f"aggregate {name!r} references unknown label id {label_id}")
94-
95-
def _aggregate_mapping(self) -> Mapping[str, Sequence[int]]:
96-
if self.aggregates is not None:
97-
return self.aggregates
98-
return {
99-
label_name: (label_id,)
100-
for label_id, label_name in self.id_to_name.items()
101-
if label_id != self.dont_score_label
102-
}
103-
104-
@property
105-
def aggregate_names(self) -> tuple[str, ...]:
106-
return tuple(self._aggregate_mapping().keys())
107-
108-
@property
109-
def aggregate_label_ids(self) -> tuple[tuple[int, ...], ...]:
110-
return tuple(tuple(label_ids) for label_ids in self._aggregate_mapping().values())
111-
112-
11358
@dataclasses.dataclass
11459
class EvalResult:
11560
micro_avg_loss: float # per token across all datasets

0 commit comments

Comments
 (0)