|
27 | 27 | from levanter.callbacks import StepInfo |
28 | 28 | from levanter.data import AsyncDataset, DataLoader |
29 | 29 | from levanter.data.text.examples import ( |
30 | | - LOSS_IGNORE_LABEL, |
31 | 30 | GrugLmExample, |
32 | 31 | LabeledLmExample, |
| 32 | + LossLabelSpec, |
33 | 33 | named_lm_example_from_grug, |
34 | 34 | named_lm_example_from_labeled, |
35 | 35 | ) |
|
55 | 55 | BatchedTagArray = Int[Array, "... tag"] |
56 | 56 |
|
57 | 57 |
|
58 | | -@dataclasses.dataclass(frozen=True) |
59 | | -class LossLabelSpec: |
60 | | - """Names exclusive loss labels and defines metric rollups. |
61 | | -
|
62 | | - `id_to_name` names the leaf span types stored in `LabeledLmExample.loss_labels`. |
63 | | - `aggregates` maps metric names to one or more leaf label ids, so callers can |
64 | | - report both specific span types and rollups such as assistant = assistant |
65 | | - text plus assistant tool calls. If aggregates is omitted, each non-ignored |
66 | | - label id gets its own metric. |
67 | | - """ |
68 | | - |
69 | | - id_to_name: Mapping[int, str] |
70 | | - aggregates: Mapping[str, Sequence[int]] | None = None |
71 | | - dont_score_label: int = LOSS_IGNORE_LABEL |
72 | | - |
73 | | - def __post_init__(self): |
74 | | - for label_id, name in self.id_to_name.items(): |
75 | | - if not isinstance(label_id, int): |
76 | | - raise TypeError(f"label id must be an int, got {label_id!r}") |
77 | | - if not isinstance(name, str): |
78 | | - raise TypeError(f"label name for id {label_id} must be a str, got {name!r}") |
79 | | - if len(set(self.id_to_name.values())) != len(self.id_to_name): |
80 | | - raise ValueError("label names must be unique") |
81 | | - |
82 | | - for name, label_ids in self._aggregate_mapping().items(): |
83 | | - if not isinstance(name, str): |
84 | | - raise TypeError(f"aggregate name must be a str, got {name!r}") |
85 | | - if not label_ids: |
86 | | - raise ValueError(f"aggregate {name!r} must include at least one label id") |
87 | | - if self.dont_score_label in label_ids: |
88 | | - raise ValueError(f"aggregate {name!r} includes dont_score_label={self.dont_score_label}") |
89 | | - for label_id in label_ids: |
90 | | - if not isinstance(label_id, int): |
91 | | - raise TypeError(f"aggregate {name!r} label id must be an int, got {label_id!r}") |
92 | | - if label_id not in self.id_to_name: |
93 | | - raise ValueError(f"aggregate {name!r} references unknown label id {label_id}") |
94 | | - |
95 | | - def _aggregate_mapping(self) -> Mapping[str, Sequence[int]]: |
96 | | - if self.aggregates is not None: |
97 | | - return self.aggregates |
98 | | - return { |
99 | | - label_name: (label_id,) |
100 | | - for label_id, label_name in self.id_to_name.items() |
101 | | - if label_id != self.dont_score_label |
102 | | - } |
103 | | - |
104 | | - @property |
105 | | - def aggregate_names(self) -> tuple[str, ...]: |
106 | | - return tuple(self._aggregate_mapping().keys()) |
107 | | - |
108 | | - @property |
109 | | - def aggregate_label_ids(self) -> tuple[tuple[int, ...], ...]: |
110 | | - return tuple(tuple(label_ids) for label_ids in self._aggregate_mapping().values()) |
111 | | - |
112 | | - |
113 | 58 | @dataclasses.dataclass |
114 | 59 | class EvalResult: |
115 | 60 | micro_avg_loss: float # per token across all datasets |
|
0 commit comments