
Commit 9ec4ccd

[levanter] Defer token boundary rendering (#5023)
Move token-boundary rendering for perplexity-gap literal examples out of the document-scan hot path. While scoring, the report builder now keeps only compact token-span metadata for long-document reports and renders boundaries just for the literals that survive into the final report. Adds regression coverage for the lazy-rendering contract.
1 parent dab18a5 commit 9ec4ccd

2 files changed

Lines changed: 193 additions & 29 deletions
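The pattern behind the change, in miniature (a hypothetical sketch with made-up names such as Example, render_boundaries, and top_rows, not the levanter code; the real implementation is in the diff below): while scanning documents, the builder keeps only compact span metadata per literal, and the expensive boundary rendering runs once per literal that actually makes it into the report.

from dataclasses import dataclass


@dataclass(frozen=True)
class Example:
    delta_bits: float
    segment_text: str
    token_spans: tuple[tuple[int, int], ...]  # compact metadata kept during the scan


def render_boundaries(segment_text: str, token_spans: tuple[tuple[int, int], ...]) -> str:
    # The expensive string work happens here, outside the scan loop.
    return "|" + "|".join(segment_text[start:end] for start, end in token_spans) + "|"


def top_rows(examples: list[Example], limit: int) -> list[dict]:
    # Pick the winners first, then render boundaries only for those.
    winners = sorted(examples, key=lambda e: e.delta_bits, reverse=True)[:limit]
    return [
        {"delta_bits": e.delta_bits, "boundaries": render_boundaries(e.segment_text, e.token_spans)}
        for e in winners
    ]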


lib/levanter/src/levanter/analysis/perplexity_gap.py

Lines changed: 129 additions & 29 deletions
@@ -63,9 +63,32 @@ class TokenizedChunk:
     byte_ends: np.ndarray


+@dataclass(frozen=True)
+class _TokenBoundarySpans:
+    byte_starts: tuple[int, ...]
+    byte_ends: tuple[int, ...]
+
+
 @dataclass(frozen=True)
 class LiteralExample:
     abs_delta_bits: float
+    dataset_name: str
+    doc_preview: str
+    segment_text: str
+    segment_byte_start: int
+    segment_byte_end: int
+    model_a_token_spans: _TokenBoundarySpans
+    model_b_token_spans: _TokenBoundarySpans
+
+
+@dataclass(frozen=True)
+class _TokenBoundaryIndex:
+    byte_starts: np.ndarray
+    byte_ends: np.ndarray
+
+
+@dataclass(frozen=True)
+class _RenderedLiteralExample:
     dataset_name: str
     doc_preview: str
     model_a_token_boundaries: str
@@ -175,6 +198,8 @@ def add_document(
         byte_offsets = char_to_byte_offsets(document.text)
         worst_positive_segment: WorstSegment | None = None
         worst_negative_segment: WorstSegment | None = None
+        token_boundary_index_a = _token_boundary_index(tokenized_a) if tokenized_a is not None else None
+        token_boundary_index_b = _token_boundary_index(tokenized_b) if tokenized_b is not None else None

         for match in _SEGMENT_RE.finditer(document.text):
             segment = match.group(0)
@@ -217,7 +242,7 @@ def add_document(
                     loss_b=segment_loss_b,
                     num_bytes=segment_bytes,
                 )
-            if tokenized_a is not None and tokenized_b is not None:
+            if token_boundary_index_a is not None and token_boundary_index_b is not None:
                 self._maybe_record_literal_example(
                     literal_key=literal_key,
                     document=document,
@@ -227,8 +252,8 @@ def add_document(
                     segment_char_start=match.start(),
                     segment_char_end=match.end(),
                     segment_delta_bits=segment_delta_bits,
-                    tokenized_a=tokenized_a,
-                    tokenized_b=tokenized_b,
+                    token_boundary_index_a=token_boundary_index_a,
+                    token_boundary_index_b=token_boundary_index_b,
                 )

             if segment_delta_bits == 0.0:
@@ -363,29 +388,32 @@ def _maybe_record_literal_example(
         segment_char_start: int,
         segment_char_end: int,
         segment_delta_bits: float,
-        tokenized_a: TokenizedDocument,
-        tokenized_b: TokenizedDocument,
+        token_boundary_index_a: _TokenBoundaryIndex,
+        token_boundary_index_b: _TokenBoundaryIndex,
     ) -> None:
-        candidate = LiteralExample(
-            abs_delta_bits=abs(segment_delta_bits),
+        abs_delta_bits = abs(segment_delta_bits)
+        current = self.literal_examples.get(literal_key)
+        if current is not None and abs_delta_bits <= current.abs_delta_bits:
+            return
+
+        self.literal_examples[literal_key] = LiteralExample(
+            abs_delta_bits=abs_delta_bits,
             dataset_name=document.dataset_name,
             doc_preview=preview_text_window(document.text, segment_char_start, segment_char_end),
-            model_a_token_boundaries=render_token_boundaries(
-                segment_text=segment_text,
+            segment_text=segment_text,
+            segment_byte_start=segment_byte_start,
+            segment_byte_end=segment_byte_end,
+            model_a_token_spans=_overlapping_token_spans(
+                token_boundary_index_a,
                 segment_byte_start=segment_byte_start,
                 segment_byte_end=segment_byte_end,
-                tokenized=tokenized_a,
             ),
-            model_b_token_boundaries=render_token_boundaries(
-                segment_text=segment_text,
+            model_b_token_spans=_overlapping_token_spans(
+                token_boundary_index_b,
                 segment_byte_start=segment_byte_start,
                 segment_byte_end=segment_byte_end,
-                tokenized=tokenized_b,
             ),
         )
-        current = self.literal_examples.get(literal_key)
-        if current is None or candidate.abs_delta_bits > current.abs_delta_bits:
-            self.literal_examples[literal_key] = candidate


     def iter_raw_text_documents(
@@ -655,6 +683,68 @@ def render_token_boundaries(
     return "|" + "|".join(pieces) + "|"


+def _token_boundary_index(tokenized: TokenizedDocument) -> _TokenBoundaryIndex:
+    valid = (tokenized.byte_starts >= 0) & (tokenized.byte_ends > tokenized.byte_starts)
+    return _TokenBoundaryIndex(
+        byte_starts=tokenized.byte_starts[valid].astype(np.int64, copy=False),
+        byte_ends=tokenized.byte_ends[valid].astype(np.int64, copy=False),
+    )
+
+
+def _overlapping_token_spans(
+    token_boundary_index: _TokenBoundaryIndex,
+    *,
+    segment_byte_start: int,
+    segment_byte_end: int,
+) -> _TokenBoundarySpans:
+    first_index = int(np.searchsorted(token_boundary_index.byte_ends, segment_byte_start, side="right"))
+    byte_starts: list[int] = []
+    byte_ends: list[int] = []
+    for token_start, token_end in zip(
+        token_boundary_index.byte_starts[first_index:],
+        token_boundary_index.byte_ends[first_index:],
+        strict=True,
+    ):
+        if token_start >= segment_byte_end:
+            break
+        if token_end <= segment_byte_start:
+            continue
+        byte_starts.append(int(token_start))
+        byte_ends.append(int(token_end))
+
+    return _TokenBoundarySpans(byte_starts=tuple(byte_starts), byte_ends=tuple(byte_ends))
+
+
+def _tokenized_document_from_boundary_spans(token_spans: _TokenBoundarySpans) -> TokenizedDocument:
+    byte_starts = np.asarray(token_spans.byte_starts, dtype=np.int32)
+    byte_ends = np.asarray(token_spans.byte_ends, dtype=np.int32)
+    return TokenizedDocument(
+        token_ids=np.zeros(len(byte_starts), dtype=np.int32),
+        byte_starts=byte_starts,
+        byte_ends=byte_ends,
+        num_bytes=0,
+    )
+
+
+def _render_literal_example(example: LiteralExample) -> _RenderedLiteralExample:
+    return _RenderedLiteralExample(
+        dataset_name=example.dataset_name,
+        doc_preview=example.doc_preview,
+        model_a_token_boundaries=render_token_boundaries(
+            segment_text=example.segment_text,
+            segment_byte_start=example.segment_byte_start,
+            segment_byte_end=example.segment_byte_end,
+            tokenized=_tokenized_document_from_boundary_spans(example.model_a_token_spans),
+        ),
+        model_b_token_boundaries=render_token_boundaries(
+            segment_text=example.segment_text,
+            segment_byte_start=example.segment_byte_start,
+            segment_byte_end=example.segment_byte_end,
+            tokenized=_tokenized_document_from_boundary_spans(example.model_b_token_spans),
+        ),
+    )
+
+
 def bucket_for_segment(segment: str) -> str:
     if segment.isspace():
         if "\t" in segment or "\r" in segment:
@@ -810,31 +900,41 @@ def _top_literal_rows(
     direction: str,
     limit: int,
 ) -> list[dict[str, Any]]:
-    rows: list[dict[str, Any]] = []
-    for (bucket, literal), stats in literal_stats.items():
-        example = literal_examples.get((bucket, literal))
+    rows: list[tuple[tuple[str, str], dict[str, Any]]] = []
+    for literal_key, stats in literal_stats.items():
+        bucket, literal = literal_key
         stats_row = stats.as_dict(literal)
         row = {
             "name": stats_row["name"],
             "bucket": bucket,
-            "example_dataset": example.dataset_name if example is not None else None,
-            "example_doc_preview": example.doc_preview if example is not None else None,
-            "model_a_token_boundaries": example.model_a_token_boundaries if example is not None else None,
-            "model_b_token_boundaries": example.model_b_token_boundaries if example is not None else None,
             "documents": stats_row["documents"],
             "bytes": stats_row["bytes"],
             "model_a_bpb": stats_row["model_a_bpb"],
             "model_b_bpb": stats_row["model_b_bpb"],
             "gap_bpb": stats_row["gap_bpb"],
             "delta_bits": stats_row["delta_bits"],
         }
-        rows.append(row)
+        rows.append((literal_key, row))

     if direction == "positive":
-        rows = [row for row in rows if row["delta_bits"] > 0]
-        rows.sort(key=lambda row: row["delta_bits"], reverse=True)
+        rows = [(literal_key, row) for literal_key, row in rows if row["delta_bits"] > 0]
+        rows.sort(key=lambda item: item[1]["delta_bits"], reverse=True)
     else:
-        rows = [row for row in rows if row["delta_bits"] < 0]
-        rows.sort(key=lambda row: row["delta_bits"])
+        rows = [(literal_key, row) for literal_key, row in rows if row["delta_bits"] < 0]
+        rows.sort(key=lambda item: item[1]["delta_bits"])
+
+    selected_rows: list[dict[str, Any]] = []
+    for literal_key, row in rows[:limit]:
+        example = literal_examples.get(literal_key)
+        rendered_example = _render_literal_example(example) if example is not None else None
+        row["example_dataset"] = rendered_example.dataset_name if rendered_example is not None else None
+        row["example_doc_preview"] = rendered_example.doc_preview if rendered_example is not None else None
+        row["model_a_token_boundaries"] = (
+            rendered_example.model_a_token_boundaries if rendered_example is not None else None
+        )
+        row["model_b_token_boundaries"] = (
+            rendered_example.model_b_token_boundaries if rendered_example is not None else None
+        )
+        selected_rows.append(row)

-    return rows[:limit]
+    return selected_rows
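A note on the span lookup in _overlapping_token_spans: np.searchsorted(byte_ends, segment_byte_start, side="right") skips every token that ends at or before the segment start, and the loop breaks at the first token that starts at or after the segment end, so only potentially overlapping tokens are visited. A toy run with illustrative values (not taken from the repo):

import numpy as np

# Token byte ranges: [0, 3), [3, 5), [5, 9), [9, 12)
byte_starts = np.array([0, 3, 5, 9], dtype=np.int64)
byte_ends = np.array([3, 5, 9, 12], dtype=np.int64)

segment_byte_start, segment_byte_end = 4, 10

# Index of the first token whose end is strictly greater than the segment start.
first = int(np.searchsorted(byte_ends, segment_byte_start, side="right"))
assert first == 1  # the token covering [0, 3) ends at 3 <= 4 and is skipped

overlaps = [
    (int(s), int(e))
    for s, e in zip(byte_starts[first:], byte_ends[first:])
    if s < segment_byte_end and e > segment_byte_start
]
assert overlaps == [(3, 5), (5, 9), (9, 12)]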

lib/levanter/tests/test_perplexity_gap.py

Lines changed: 64 additions & 0 deletions
@@ -13,6 +13,7 @@

 import haliax

+import levanter.analysis.perplexity_gap as gap_analysis
 from levanter.analysis.perplexity_gap import (
     GapReportBuilder,
     RawTextDocument,
@@ -127,6 +128,69 @@ def test_gap_report_builder_records_per_model_literal_boundaries():
     assert literal_row["example_dataset"] == "paloma/example"


+def test_gap_report_builder_renders_literal_boundaries_only_for_reported_literals(monkeypatch):
+    report = GapReportBuilder(model_a_name="a", model_b_name="b", output_path="/tmp/report", top_k_literals=1)
+    calls: list[str] = []
+
+    def fake_render_token_boundaries(**kwargs):
+        calls.append(kwargs["segment_text"])
+        return f"|{kwargs['segment_text']}|"
+
+    monkeypatch.setattr(gap_analysis, "render_token_boundaries", fake_render_token_boundaries)
+
+    weaker_document = RawTextDocument(
+        dataset_name="paloma/example",
+        tags=("paloma/example",),
+        shard_name="docs",
+        row_index=0,
+        text="aaa",
+    )
+    stronger_document = RawTextDocument(
+        dataset_name="paloma/example",
+        tags=("paloma/example",),
+        shard_name="docs",
+        row_index=1,
+        text="bbb",
+    )
+    tokenized_weaker = TokenizedDocument(
+        token_ids=np.asarray([1], dtype=np.int32),
+        byte_starts=np.asarray([0], dtype=np.int32),
+        byte_ends=np.asarray([3], dtype=np.int32),
+        num_bytes=3,
+    )
+    tokenized_stronger = TokenizedDocument(
+        token_ids=np.asarray([2], dtype=np.int32),
+        byte_starts=np.asarray([0], dtype=np.int32),
+        byte_ends=np.asarray([3], dtype=np.int32),
+        num_bytes=3,
+    )
+
+    report.add_document(
+        document=weaker_document,
+        per_byte_loss_a=np.asarray([1.0, 1.0, 1.0], dtype=np.float64),
+        per_byte_loss_b=np.asarray([0.0, 0.0, 0.0], dtype=np.float64),
+        tokenized_a=tokenized_weaker,
+        tokenized_b=tokenized_weaker,
+    )
+    report.add_document(
+        document=stronger_document,
+        per_byte_loss_a=np.asarray([2.0, 2.0, 2.0], dtype=np.float64),
+        per_byte_loss_b=np.asarray([0.0, 0.0, 0.0], dtype=np.float64),
+        tokenized_a=tokenized_stronger,
+        tokenized_b=tokenized_stronger,
+    )
+
+    assert calls == []
+
+    summary = report.build_summary()
+    literal_rows = summary["top_literals"]["model_a_worse"]
+
+    assert [row["name"] for row in literal_rows] == ["bbb"]
+    assert literal_rows[0]["model_a_token_boundaries"] == "|bbb|"
+    assert literal_rows[0]["model_b_token_boundaries"] == "|bbb|"
+    assert calls == ["bbb", "bbb"]
+
+
 def test_gap_report_builder_previews_worst_region():
     report = GapReportBuilder(model_a_name="a", model_b_name="b", output_path="/tmp/report")
     prefix = "safe " * 40
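The regression test works by monkeypatching render_token_boundaries through the gap_analysis module alias, so any eager rendering during add_document would show up in calls; the assertions check that rendering happens only at build_summary time and only for the single surviving literal (once per model). To exercise just this test locally, a pytest -k filter along the lines of pytest lib/levanter/tests/test_perplexity_gap.py -k renders_literal_boundaries_only_for_reported_literals should work, though the exact invocation may vary with the repo's test setup.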
