|
| 1 | +# Copyright The Marin Authors |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +"""Runnable first-pass long-tail PPL slices backed by public Hugging Face datasets.""" |
| 5 | + |
| 6 | +from __future__ import annotations |
| 7 | + |
| 8 | +import posixpath |
| 9 | +from dataclasses import dataclass |
| 10 | + |
| 11 | +from marin.evaluation.perplexity_gap import RawTextEvaluationDataset, raw_text_dataset |
| 12 | +from marin.processing.tokenize import HfDatasetSpec |
| 13 | + |
| 14 | +from experiments.evals.long_tail_ppl import LongTailPplFamily |
| 15 | + |
| 16 | +RUNNABLE_LONG_TAIL_SOURCE_NOTE = ( |
| 17 | + "These slices are directly executable from public Hugging Face datasets and do not require a bulk mirror." |
| 18 | +) |
| 19 | + |
| 20 | + |
| 21 | +@dataclass(frozen=True) |
| 22 | +class RunnableLongTailPplSlice: |
| 23 | + """A runnable long-tail slice backed by a small public Hugging Face dataset.""" |
| 24 | + |
| 25 | + name: str |
| 26 | + family: LongTailPplFamily |
| 27 | + source_url: str |
| 28 | + hf_dataset: HfDatasetSpec |
| 29 | + text_key: str |
| 30 | + split: str |
| 31 | + notes: str = "" |
| 32 | + |
| 33 | + @property |
| 34 | + def registry_key(self) -> str: |
| 35 | + return posixpath.join("long_tail_ppl_runnable", self.family.value, self.name) |
| 36 | + |
| 37 | + @property |
| 38 | + def tags(self) -> tuple[str, ...]: |
| 39 | + return ("long_tail_ppl", "long_tail_ppl_runnable", self.family.value, f"split:{self.split}") |
| 40 | + |
| 41 | + def to_raw_text_dataset(self) -> RawTextEvaluationDataset: |
| 42 | + return raw_text_dataset(self.hf_dataset, text_key=self.text_key, split=self.split, tags=self.tags) |
| 43 | + |
| 44 | + |
| 45 | +RUNNABLE_LONG_TAIL_PPL_SLICES: tuple[RunnableLongTailPplSlice, ...] = ( |
| 46 | + RunnableLongTailPplSlice( |
| 47 | + name="svg_stack_val", |
| 48 | + family=LongTailPplFamily.WEB_MARKUP_IMAGE_TEXT, |
| 49 | + source_url="https://huggingface.co/datasets/starvector/svg-stack", |
| 50 | + hf_dataset=HfDatasetSpec(id="starvector/svg-stack"), |
| 51 | + text_key="Svg", |
| 52 | + split="val", |
| 53 | + notes="Preserve SVG XML and caption-adjacent markup in the validation split.", |
| 54 | + ), |
| 55 | + RunnableLongTailPplSlice( |
| 56 | + name="svg_stack_test", |
| 57 | + family=LongTailPplFamily.WEB_MARKUP_IMAGE_TEXT, |
| 58 | + source_url="https://huggingface.co/datasets/starvector/svg-stack", |
| 59 | + hf_dataset=HfDatasetSpec(id="starvector/svg-stack"), |
| 60 | + text_key="Svg", |
| 61 | + split="test", |
| 62 | + notes="Preserve SVG XML in the held-out test split.", |
| 63 | + ), |
| 64 | + RunnableLongTailPplSlice( |
| 65 | + name="verilogeval_prompt", |
| 66 | + family=LongTailPplFamily.FORMAL_HARDWARE, |
| 67 | + source_url="https://huggingface.co/datasets/dakies/nvlabs-verilogeval", |
| 68 | + hf_dataset=HfDatasetSpec(id="dakies/nvlabs-verilogeval"), |
| 69 | + text_key="prompt", |
| 70 | + split="test", |
| 71 | + notes="Keep VerilogEval problem statements and interface text intact.", |
| 72 | + ), |
| 73 | + RunnableLongTailPplSlice( |
| 74 | + name="verilogeval_canonical_solution", |
| 75 | + family=LongTailPplFamily.FORMAL_HARDWARE, |
| 76 | + source_url="https://huggingface.co/datasets/dakies/nvlabs-verilogeval", |
| 77 | + hf_dataset=HfDatasetSpec(id="dakies/nvlabs-verilogeval"), |
| 78 | + text_key="canonical_solution", |
| 79 | + split="test", |
| 80 | + notes="Keep VerilogEval reference implementations and formatting intact.", |
| 81 | + ), |
| 82 | +) |
| 83 | + |
| 84 | +RUNNABLE_LONG_TAIL_PPL_REGISTRY: dict[str, RunnableLongTailPplSlice] = { |
| 85 | + slice_.registry_key: slice_ for slice_ in RUNNABLE_LONG_TAIL_PPL_SLICES |
| 86 | +} |
| 87 | + |
| 88 | + |
| 89 | +def runnable_long_tail_ppl_slices(*, family: LongTailPplFamily | None = None) -> tuple[RunnableLongTailPplSlice, ...]: |
| 90 | + if family is None: |
| 91 | + return RUNNABLE_LONG_TAIL_PPL_SLICES |
| 92 | + return tuple(slice_ for slice_ in RUNNABLE_LONG_TAIL_PPL_SLICES if slice_.family == family) |
| 93 | + |
| 94 | + |
| 95 | +def runnable_long_tail_raw_validation_sets() -> dict[str, RawTextEvaluationDataset]: |
| 96 | + """Materialize the runnable HF-backed slices into raw-text datasets.""" |
| 97 | + |
| 98 | + return {slice_.registry_key: slice_.to_raw_text_dataset() for slice_ in RUNNABLE_LONG_TAIL_PPL_SLICES} |
| 99 | + |
| 100 | + |
| 101 | +def render_runnable_long_tail_registry_markdown() -> str: |
| 102 | + lines = ["# Runnable long-tail PPL registry", "", RUNNABLE_LONG_TAIL_SOURCE_NOTE, ""] |
| 103 | + for current_family in LongTailPplFamily: |
| 104 | + family_slices = runnable_long_tail_ppl_slices(family=current_family) |
| 105 | + if not family_slices: |
| 106 | + continue |
| 107 | + lines.append(f"## {current_family.value}") |
| 108 | + for slice_ in family_slices: |
| 109 | + lines.append(f"- `{slice_.registry_key}`: split={slice_.split} | {slice_.text_key} | {slice_.source_url}") |
| 110 | + if slice_.notes: |
| 111 | + lines.append(f" - {slice_.notes}") |
| 112 | + lines.append("") |
| 113 | + return "\n".join(lines).rstrip() + "\n" |
0 commit comments