Skip to content

Commit dad6f61

Browse files
committed
Add long-tail PPL gap rerun registry
1 parent 4f0728d commit dad6f61

8 files changed

Lines changed: 872 additions & 3 deletions

File tree

experiments/evals/long_tail_ppl.py

Lines changed: 497 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Runnable first-pass long-tail PPL slices backed by public Hugging Face datasets."""
5+
6+
from __future__ import annotations
7+
8+
import posixpath
9+
from dataclasses import dataclass
10+
11+
from marin.evaluation.perplexity_gap import RawTextEvaluationDataset, raw_text_dataset
12+
from marin.processing.tokenize import HfDatasetSpec
13+
14+
from experiments.evals.long_tail_ppl import LongTailPplFamily
15+
16+
# Provenance note surfaced in the rendered registry markdown (see
# render_runnable_long_tail_registry_markdown below).
RUNNABLE_LONG_TAIL_SOURCE_NOTE = (
    "These slices are directly executable from public Hugging Face datasets and do not require a bulk mirror."
)
19+
20+
21+
@dataclass(frozen=True)
class RunnableLongTailPplSlice:
    """A runnable long-tail slice backed by a small public Hugging Face dataset."""

    name: str  # slice identifier, unique within its family
    family: LongTailPplFamily  # long-tail family this slice belongs to
    source_url: str  # human-readable provenance link for the dataset
    hf_dataset: HfDatasetSpec  # Hugging Face dataset the slice streams from
    text_key: str  # dataset column holding the raw document text
    split: str  # dataset split to evaluate
    notes: str = ""  # free-form maintainer guidance

    @property
    def registry_key(self) -> str:
        """Stable registry path: long_tail_ppl_runnable/<family>/<name>."""
        return posixpath.join("long_tail_ppl_runnable", self.family.value, self.name)

    @property
    def tags(self) -> tuple[str, ...]:
        """Tags propagated onto the materialized evaluation dataset."""
        split_tag = f"split:{self.split}"
        return ("long_tail_ppl", "long_tail_ppl_runnable", self.family.value, split_tag)

    def to_raw_text_dataset(self) -> RawTextEvaluationDataset:
        """Materialize this slice as a raw-text evaluation dataset."""
        return raw_text_dataset(
            self.hf_dataset,
            text_key=self.text_key,
            split=self.split,
            tags=self.tags,
        )
43+
44+
45+
# Registry of runnable slices. Each entry pins a public HF dataset, the text
# column to read, and the split to evaluate. Keep text preprocessing minimal so
# markup/code structure survives into the perplexity evaluation.
RUNNABLE_LONG_TAIL_PPL_SLICES: tuple[RunnableLongTailPplSlice, ...] = (
    # SVG markup slices (web markup / image-adjacent text family).
    RunnableLongTailPplSlice(
        name="svg_stack_val",
        family=LongTailPplFamily.WEB_MARKUP_IMAGE_TEXT,
        source_url="https://huggingface.co/datasets/starvector/svg-stack",
        hf_dataset=HfDatasetSpec(id="starvector/svg-stack"),
        # NOTE: the column is capitalized "Svg" upstream — do not lowercase.
        text_key="Svg",
        split="val",
        notes="Preserve SVG XML and caption-adjacent markup in the validation split.",
    ),
    RunnableLongTailPplSlice(
        name="svg_stack_test",
        family=LongTailPplFamily.WEB_MARKUP_IMAGE_TEXT,
        source_url="https://huggingface.co/datasets/starvector/svg-stack",
        hf_dataset=HfDatasetSpec(id="starvector/svg-stack"),
        text_key="Svg",
        split="test",
        notes="Preserve SVG XML in the held-out test split.",
    ),
    # VerilogEval slices (formal hardware-description family); the same split is
    # read twice with different text columns (problem text vs. reference code).
    RunnableLongTailPplSlice(
        name="verilogeval_prompt",
        family=LongTailPplFamily.FORMAL_HARDWARE,
        source_url="https://huggingface.co/datasets/dakies/nvlabs-verilogeval",
        hf_dataset=HfDatasetSpec(id="dakies/nvlabs-verilogeval"),
        text_key="prompt",
        split="test",
        notes="Keep VerilogEval problem statements and interface text intact.",
    ),
    RunnableLongTailPplSlice(
        name="verilogeval_canonical_solution",
        family=LongTailPplFamily.FORMAL_HARDWARE,
        source_url="https://huggingface.co/datasets/dakies/nvlabs-verilogeval",
        hf_dataset=HfDatasetSpec(id="dakies/nvlabs-verilogeval"),
        text_key="canonical_solution",
        split="test",
        notes="Keep VerilogEval reference implementations and formatting intact.",
    ),
)
83+
84+
# Lookup table keyed by registry path ("long_tail_ppl_runnable/<family>/<name>").
RUNNABLE_LONG_TAIL_PPL_REGISTRY: dict[str, RunnableLongTailPplSlice] = {
    slice_.registry_key: slice_ for slice_ in RUNNABLE_LONG_TAIL_PPL_SLICES
}
87+
88+
89+
def runnable_long_tail_ppl_slices(*, family: LongTailPplFamily | None = None) -> tuple[RunnableLongTailPplSlice, ...]:
    """Return the runnable slices, optionally filtered to a single family."""
    if family is not None:
        return tuple(entry for entry in RUNNABLE_LONG_TAIL_PPL_SLICES if entry.family == family)
    return RUNNABLE_LONG_TAIL_PPL_SLICES
93+
94+
95+
def runnable_long_tail_raw_validation_sets() -> dict[str, RawTextEvaluationDataset]:
    """Materialize the runnable HF-backed slices into raw-text datasets."""

    datasets: dict[str, RawTextEvaluationDataset] = {}
    for entry in RUNNABLE_LONG_TAIL_PPL_SLICES:
        datasets[entry.registry_key] = entry.to_raw_text_dataset()
    return datasets
99+
100+
101+
def render_runnable_long_tail_registry_markdown() -> str:
    """Render the runnable slice registry as markdown, grouped by family.

    Families with no runnable slices are omitted; output always ends with a
    single trailing newline.
    """
    out = ["# Runnable long-tail PPL registry", "", RUNNABLE_LONG_TAIL_SOURCE_NOTE, ""]
    for fam in LongTailPplFamily:
        members = runnable_long_tail_ppl_slices(family=fam)
        if not members:
            continue
        out.append(f"## {fam.value}")
        for entry in members:
            out.append(f"- `{entry.registry_key}`: split={entry.split} | {entry.text_key} | {entry.source_url}")
            if entry.notes:
                out.append(f" - {entry.notes}")
        out.append("")
    rendered = "\n".join(out)
    return rendered.rstrip() + "\n"
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from fray.v2.types import ResourceConfig
5+
6+
from experiments.evals.long_tail_ppl_runnable import runnable_long_tail_raw_validation_sets
7+
from marin.evaluation.perplexity_gap import GapFinderModelConfig, default_model_perplexity_gap
8+
from marin.execution.executor import executor_main
9+
10+
# Single v5p-8 TPU slice in us-central1, shared by both gap runs below.
RESOURCE_CONFIG = ResourceConfig.with_tpu("v5p-8", regions=["us-central1"])
# Cap per-dataset document count and per-document size to keep the rerun cheap.
MAX_DOCS_PER_DATASET = 256
MAX_DOC_BYTES = 32_768

# Runnable HF-backed long-tail slices, keyed by registry path.
DATASETS = runnable_long_tail_raw_validation_sets()
15+
16+
# Marin 8B base checkpoint from the HF hub; configured with the Llama 3.1
# tokenizer (model_a in both comparisons below).
MARIN_MODEL = GapFinderModelConfig(
    checkpoint_path="marin-community/marin-8b-base",
    checkpoint_is_hf=True,
    tokenizer="meta-llama/Llama-3.1-8B",
)
21+
22+
# Perplexity-gap report: Marin 8B base (model_a) vs Llama 3.1 8B base (model_b)
# over the runnable long-tail slices.
MARIN_VS_LLAMA = default_model_perplexity_gap(
    name="long-tail-runnable-marin-8b-base-vs-llama-3.1-8b-base-doccap256",
    model_a=MARIN_MODEL,
    model_b=GapFinderModelConfig(
        checkpoint_path="meta-llama/Llama-3.1-8B",
        checkpoint_is_hf=True,
        tokenizer="meta-llama/Llama-3.1-8B",
    ),
    datasets=DATASETS,
    resource_config=RESOURCE_CONFIG,
    per_device_batch_size=4,
    max_eval_length=4096,  # presumably the per-example sequence cap — confirm in default_model_perplexity_gap
    max_docs_per_dataset=MAX_DOCS_PER_DATASET,
    max_doc_bytes=MAX_DOC_BYTES,
    # W&B tags encode the run configuration so reports can be filtered later;
    # keep them in sync with the config values above.
    wandb_tags=[
        "eval=perplexity-gap",
        "rerun=long-tail-runnable-first-pass",
        "model_a=marin-community/marin-8b-base",
        "model_b=meta-llama/Llama-3.1-8B",
        "dataset_bundle=runnable_long_tail_hf_backed",
        "source_split=hf_dataset",
        "region=us-central1",
        f"max_docs_per_dataset={MAX_DOCS_PER_DATASET}",
    ],
)
47+
48+
# Perplexity-gap report: Marin 8B base (model_a) vs Qwen3 8B base (model_b)
# over the same runnable long-tail slices.
MARIN_VS_QWEN3 = default_model_perplexity_gap(
    name="long-tail-runnable-marin-8b-base-vs-qwen3-8b-base-doccap256",
    model_a=MARIN_MODEL,
    model_b=GapFinderModelConfig(
        checkpoint_path="Qwen/Qwen3-8B-Base",
        checkpoint_is_hf=True,
        # NOTE(review): tokenizer repo ("Qwen/Qwen3-8B") differs from the
        # checkpoint repo ("Qwen/Qwen3-8B-Base") — confirm both ship the same
        # tokenizer, otherwise perplexities are not comparable.
        tokenizer="Qwen/Qwen3-8B",
    ),
    datasets=DATASETS,
    resource_config=RESOURCE_CONFIG,
    per_device_batch_size=4,
    max_eval_length=4096,
    max_docs_per_dataset=MAX_DOCS_PER_DATASET,
    max_doc_bytes=MAX_DOC_BYTES,
    # W&B tags mirror the run configuration for later filtering.
    wandb_tags=[
        "eval=perplexity-gap",
        "rerun=long-tail-runnable-first-pass",
        "model_a=marin-community/marin-8b-base",
        "model_b=Qwen/Qwen3-8B-Base",
        "dataset_bundle=runnable_long_tail_hf_backed",
        "source_split=hf_dataset",
        "region=us-central1",
        f"max_docs_per_dataset={MAX_DOCS_PER_DATASET}",
    ],
)
73+
74+
75+
if __name__ == "__main__":
    # Submit both gap reports through the Marin executor.
    executor_main(
        [MARIN_VS_LLAMA, MARIN_VS_QWEN3],
        description="Run Marin perplexity-gap reports on runnable first-pass long-tail PPL slices.",
    )
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from fray.v2.types import ResourceConfig
5+
6+
from experiments.evals.long_tail_ppl_runnable import runnable_long_tail_raw_validation_sets
7+
from marin.evaluation.perplexity_gap import GapFinderModelConfig, default_model_perplexity_gap
8+
from marin.execution.executor import executor_main
9+
10+
# Single v5p-8 TPU slice in us-central1 for the smoke run.
RESOURCE_CONFIG = ResourceConfig.with_tpu("v5p-8", regions=["us-central1"])
# Smoke-sized caps: far fewer docs than the full rerun (32 vs 256).
MAX_DOCS_PER_DATASET = 32
MAX_DOC_BYTES = 32_768

# Runnable HF-backed long-tail slices, keyed by registry path.
DATASETS = runnable_long_tail_raw_validation_sets()
15+
16+
# Marin 8B base checkpoint from the HF hub; configured with the Llama 3.1
# tokenizer.
MARIN_MODEL = GapFinderModelConfig(
    checkpoint_path="marin-community/marin-8b-base",
    checkpoint_is_hf=True,
    tokenizer="meta-llama/Llama-3.1-8B",
)
21+
22+
# Smoke-test perplexity-gap report: Marin 8B base vs Llama 3.1 8B base with a
# small per-dataset document cap, to validate the pipeline end-to-end cheaply.
STEP = default_model_perplexity_gap(
    name="long-tail-smoke-marin-8b-base-vs-llama-3.1-8b-base-doccap32",
    model_a=MARIN_MODEL,
    model_b=GapFinderModelConfig(
        checkpoint_path="meta-llama/Llama-3.1-8B",
        checkpoint_is_hf=True,
        tokenizer="meta-llama/Llama-3.1-8B",
    ),
    datasets=DATASETS,
    resource_config=RESOURCE_CONFIG,
    per_device_batch_size=4,
    max_eval_length=4096,
    max_docs_per_dataset=MAX_DOCS_PER_DATASET,
    max_doc_bytes=MAX_DOC_BYTES,
    # W&B tags encode the run configuration for later filtering.
    wandb_tags=[
        "eval=perplexity-gap",
        "smoke=long-tail-ppl",
        "source_split=hf_dataset",
        "region=us-central1",
        "dataset_bundle=runnable_long_tail_hf_backed",
        f"max_docs_per_dataset={MAX_DOCS_PER_DATASET}",
    ],
)
45+
46+
47+
if __name__ == "__main__":
    # Submit the single smoke step through the Marin executor.
    executor_main([STEP], description="Smoke-run runnable long-tail PPL slices from public Hugging Face datasets.")

lib/levanter/src/levanter/analysis/perplexity_gap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ def iter_raw_text_documents(
431431
"Gap finding currently supports TextLmDatasetFormat only."
432432
)
433433

434-
source = component.source.get_shard_source("validation")
434+
source = component.source.get_shard_source(component.split)
435435
if source is None:
436436
continue
437437

lib/levanter/src/levanter/data/text/datasets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ class DatasetComponent(DatasetComponentBase):
335335
format: LmDatasetFormatBase = field(default_factory=TextLmDatasetFormat)
336336
pack: bool | int | Literal["pad"] | None = None
337337
tags: list[str] | None = None
338+
split: str = "validation"
338339

339340

340341
@DatasetComponentBase.register_subclass("direct")

lib/marin/src/marin/evaluation/perplexity_gap.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class RawTextEvaluationDataset:
4040
hf_dataset_id: str | None = None
4141
hf_dataset_name: str | None = None
4242
text_key: str = "text"
43+
split: str = "validation"
4344
tags: tuple[str, ...] = ()
4445

4546

@@ -63,16 +64,20 @@ def raw_text_dataset(
6364
source: str | InputName | ExecutorStep | HfDatasetSpec,
6465
*,
6566
text_key: str = "text",
67+
split: str = "validation",
6668
tags: tuple[str, ...] = (),
6769
) -> RawTextEvaluationDataset:
6870
if isinstance(source, HfDatasetSpec):
6971
return RawTextEvaluationDataset(
7072
hf_dataset_id=source.id,
7173
hf_dataset_name=source.name,
7274
text_key=text_key,
75+
split=split,
7376
tags=tags,
7477
)
75-
return RawTextEvaluationDataset(input_path=source, text_key=text_key, tags=tags)
78+
if split != "validation":
79+
raise ValueError("split is only supported for Hugging Face dataset sources; file paths use validation.")
80+
return RawTextEvaluationDataset(input_path=source, text_key=text_key, split=split, tags=tags)
7681

7782

7883
def default_model_perplexity_gap(
@@ -184,10 +189,13 @@ def _to_dataset_component(config: RawTextEvaluationDataset) -> DatasetComponent:
184189
id=config.hf_dataset_id,
185190
name=config.hf_dataset_name,
186191
format=dataset_format,
192+
splits=[config.split],
187193
)
188194
else:
189195
if config.input_path is None:
190196
raise ValueError("RawTextEvaluationDataset requires either input_path or hf_dataset_id.")
197+
if config.split != "validation":
198+
raise ValueError("RawTextEvaluationDataset split is only supported for Hugging Face dataset sources.")
191199
input_path = config.input_path
192200
if isinstance(input_path, ExecutorStep):
193201
input_path = input_path.as_input_name()
@@ -196,7 +204,7 @@ def _to_dataset_component(config: RawTextEvaluationDataset) -> DatasetComponent:
196204
validation_urls=[input_path], # type: ignore[list-item]
197205
format=dataset_format,
198206
)
199-
return DatasetComponent(source=source, format=dataset_format, tags=list(config.tags))
207+
return DatasetComponent(source=source, format=dataset_format, tags=list(config.tags), split=config.split)
200208

201209

202210
def _default_step_name(model_a: GapFinderModelConfig, model_b: GapFinderModelConfig) -> str:
@@ -234,5 +242,6 @@ def _cache_key_for_dataset(dataset: RawTextEvaluationDataset) -> dict[str, Any]:
234242
"hf_dataset_id": dataset.hf_dataset_id,
235243
"hf_dataset_name": dataset.hf_dataset_name,
236244
"text_key": dataset.text_key,
245+
"split": dataset.split,
237246
"tags": dataset.tags,
238247
}

0 commit comments

Comments
 (0)