# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Registry for the first-pass long-tail diagnostic PPL slices.

This module is intentionally metadata-first. It records the initial family/source
coverage plan from epic #5005 and child issues #5056-#5062 without downloading or
mirroring any of the large corpora referenced there.
"""

from __future__ import annotations

import posixpath
from dataclasses import dataclass
from enum import StrEnum

from marin.evaluation.perplexity_gap import RawTextEvaluationDataset, raw_text_dataset

# GitHub issue numbers (marin-community/marin) tracking each coverage family.
# They are embedded in slice tags so eval runs can be traced back to the plan.
EPIC_5005 = 5005
WEB_RAW_ISSUE = 5056
BINARY_RAW_ISSUE = 5057
BIO_CHEM_ISSUE = 5058
TIME_SERIES_ISSUE = 5059
FORMAL_HARDWARE_ISSUE = 5060
PACKAGE_METADATA_ISSUE = 5061
GAME_MUSIC_ISSUE = 5062


class LongTailPplFamily(StrEnum):
    """High-level coverage families for the long-tail PPL registry."""

    WEB_MARKUP_IMAGE_TEXT = "web_markup_image_text"
    BINARY_NETWORK_SECURITY = "binary_network_security"
    BIO_CHEM = "bio_chem"
    TIME_SERIES_TABLE_GEO = "time_series_table_geo"
    FORMAL_HARDWARE = "formal_hardware"
    PACKAGE_METADATA = "package_metadata"
    GAME_MUSIC = "game_music"


@dataclass(frozen=True)
class LongTailPplSlice:
    """A single diagnostic slice in the long-tail PPL registry."""

    name: str
    family: LongTailPplFamily
    issue_number: int
    source_url: str
    surface_form: str
    raw_relative_path: str
    notes: str = ""

    @property
    def registry_key(self) -> str:
        # Stable lookup key of the form "long_tail_ppl/<family>/<name>".
        return posixpath.join("long_tail_ppl", self.family.value, self.name)

    @property
    def tags(self) -> tuple[str, ...]:
        # Tags link each eval back to the epic/issue that motivated the slice.
        return ("long_tail_ppl", f"epic:{EPIC_5005}", f"issue:{self.issue_number}", self.family.value)

    def to_raw_text_dataset(self, raw_root: str) -> RawTextEvaluationDataset:
        """Render the slice as a raw-text eval dataset rooted at ``raw_root``."""

        return raw_text_dataset(posixpath.join(raw_root, self.raw_relative_path), tags=self.tags)


# Keyword-only constructor wrapper: keeps the large literal table below readable
# and forces every field to be named at each call site.
def _slice(
    *,
    name: str,
    family: LongTailPplFamily,
    issue_number: int,
    source_url: str,
    surface_form: str,
    raw_relative_path: str,
    notes: str = "",
) -> LongTailPplSlice:
    return LongTailPplSlice(
        name=name,
        family=family,
        issue_number=issue_number,
        source_url=source_url,
        surface_form=surface_form,
        raw_relative_path=raw_relative_path,
        notes=notes,
    )


LONG_TAIL_PPL_SLICES: tuple[LongTailPplSlice, ...] = (
    # Web / markup / image-text
    _slice(
        name="common_crawl_warc",
        family=LongTailPplFamily.WEB_MARKUP_IMAGE_TEXT,
        issue_number=WEB_RAW_ISSUE,
        source_url="https://commoncrawl.org/get-started",
        surface_form="warc",
        raw_relative_path="web/common_crawl/warc.jsonl.gz",
        notes="Keep WARC headers, HTTP metadata, URLs, and raw response bodies.",
    ),
    _slice(
        name="common_crawl_wat",
        family=LongTailPplFamily.WEB_MARKUP_IMAGE_TEXT,
        issue_number=WEB_RAW_ISSUE,
        source_url="https://commoncrawl.org/get-started",
        surface_form="wat_json",
        raw_relative_path="web/common_crawl/wat.jsonl.gz",
        notes="Keep WAT JSON and extracted text without cleaning away structure.",
    ),
    _slice(
        name="web_data_commons_web_tables",
        family=LongTailPplFamily.WEB_MARKUP_IMAGE_TEXT,
        issue_number=WEB_RAW_ISSUE,
        source_url="https://webdatacommons.org/webtables/englishTables.html",
        surface_form="html_tables",
        raw_relative_path="web/web_data_commons/web_tables.jsonl.gz",
        notes="Preserve HTML tables, delimiters, and table metadata.",
    ),
    _slice(
        name="svg_stack",
        family=LongTailPplFamily.WEB_MARKUP_IMAGE_TEXT,
        issue_number=WEB_RAW_ISSUE,
        source_url="https://huggingface.co/datasets/starvector/svg-stack",
        surface_form="svg_xml",
        raw_relative_path="web/svg_stack/svg_stack.jsonl.gz",
        notes="Keep SVG XML, path data, and caption text intact.",
    ),
    _slice(
        name="textocr",
        family=LongTailPplFamily.WEB_MARKUP_IMAGE_TEXT,
        issue_number=WEB_RAW_ISSUE,
        source_url="https://textvqa.org/textocr/",
        surface_form="ocr_text",
        raw_relative_path="web/textocr/textocr.jsonl.gz",
        notes="Preserve OCR strings, annotations, and layout hints.",
    ),
    _slice(
        name="ocr_vqa",
        family=LongTailPplFamily.WEB_MARKUP_IMAGE_TEXT,
        issue_number=WEB_RAW_ISSUE,
        source_url="https://ocr-vqa.github.io/",
        surface_form="ocr_question_context",
        raw_relative_path="web/ocr_vqa/ocr_vqa.jsonl.gz",
        notes="Keep book-cover OCR text and question context surface forms.",
    ),
    _slice(
        name="laion_metadata",
        family=LongTailPplFamily.WEB_MARKUP_IMAGE_TEXT,
        issue_number=WEB_RAW_ISSUE,
        source_url="https://laion.ai/blog/laion-400-open-dataset/",
        surface_form="url_alt_text_metadata",
        raw_relative_path="web/laion_metadata/laion_metadata.jsonl.gz",
        notes="Treat as metadata only; later pipeline work should sample conservatively.",
    ),
    # Binary / network / security
    _slice(
        name="microsoft_malware_bytes",
        family=LongTailPplFamily.BINARY_NETWORK_SECURITY,
        issue_number=BINARY_RAW_ISSUE,
        source_url="https://www.kaggle.com/c/malware-classification",
        surface_form="hex_dump",
        raw_relative_path="binary/microsoft_malware/bytes.jsonl.gz",
        notes="Render binary as hex text only; preserve line breaks and offsets.",
    ),
    _slice(
        name="microsoft_malware_asm",
        family=LongTailPplFamily.BINARY_NETWORK_SECURITY,
        issue_number=BINARY_RAW_ISSUE,
        source_url="https://www.kaggle.com/c/malware-classification",
        surface_form="disassembly_text",
        raw_relative_path="binary/microsoft_malware/asm.jsonl.gz",
        notes="Keep assembler syntax, labels, comments, and identifier casing.",
    ),
    _slice(
        name="wireshark_rendered_text",
        family=LongTailPplFamily.BINARY_NETWORK_SECURITY,
        issue_number=BINARY_RAW_ISSUE,
        source_url="https://wiki.wireshark.org/SampleCaptures",
        surface_form="protocol_tree_text",
        raw_relative_path="binary/wireshark/rendered_text.jsonl.gz",
        notes="Use rendered protocol trees / hex views instead of raw PCAPs.",
    ),
    _slice(
        name="mawi_zeek",
        family=LongTailPplFamily.BINARY_NETWORK_SECURITY,
        issue_number=BINARY_RAW_ISSUE,
        source_url="https://mawi.wide.ad.jp/mawi/",
        surface_form="network_flow_records",
        raw_relative_path="binary/mawi/zeek.jsonl.gz",
        notes="Keep flow records and timestamp / IP / port fields literal.",
    ),
    _slice(
        name="cicids_flow",
        family=LongTailPplFamily.BINARY_NETWORK_SECURITY,
        issue_number=BINARY_RAW_ISSUE,
        source_url="https://www.unb.ca/cic/datasets/ids-2017.html",
        surface_form="flow_csv",
        raw_relative_path="binary/cicids/flow.csv.jsonl.gz",
        notes="Preserve CSV delimiters, labels, and flow statistics.",
    ),
    _slice(
        name="uwf_zeek",
        family=LongTailPplFamily.BINARY_NETWORK_SECURITY,
        issue_number=BINARY_RAW_ISSUE,
        source_url="https://datasets.uwf.edu/",
        surface_form="zeek_logs",
        raw_relative_path="binary/uwf/zeek.jsonl.gz",
        notes="Preserve Zeek field names, hashes, IPs, and delimiter structure.",
    ),
    # Bio / chemistry
    _slice(
        name="refseq_fasta",
        family=LongTailPplFamily.BIO_CHEM,
        issue_number=BIO_CHEM_ISSUE,
        source_url="https://www.ncbi.nlm.nih.gov/refseq/",
        surface_form="fasta",
        raw_relative_path="bio/refseq/fasta.jsonl.gz",
        notes="Keep sequence IDs, wrapping, and nucleotide / amino-acid characters.",
    ),
    _slice(
        name="refseq_gff",
        family=LongTailPplFamily.BIO_CHEM,
        issue_number=BIO_CHEM_ISSUE,
        source_url="https://www.ncbi.nlm.nih.gov/refseq/",
        surface_form="gff",
        raw_relative_path="bio/refseq/gff.jsonl.gz",
        notes="Preserve coordinate columns, attributes, and record boundaries.",
    ),
    _slice(
        name="uniprot_fasta",
        family=LongTailPplFamily.BIO_CHEM,
        issue_number=BIO_CHEM_ISSUE,
        source_url="https://www.ebi.ac.uk/uniprot/download-center",
        surface_form="protein_fasta",
        raw_relative_path="bio/uniprot/fasta.jsonl.gz",
        notes="Keep UniProt headers and wrapped sequence bodies unchanged.",
    ),
    _slice(
        name="pubchem_smiles",
        family=LongTailPplFamily.BIO_CHEM,
        issue_number=BIO_CHEM_ISSUE,
        source_url="https://pubchem.ncbi.nlm.nih.gov/docs/downloads",
        surface_form="smiles",
        raw_relative_path="bio/pubchem/smiles.jsonl.gz",
        notes="Preserve atom / bond notation, stereochemistry markers, and IDs.",
    ),
    _slice(
        name="pubchem_sdf",
        family=LongTailPplFamily.BIO_CHEM,
        issue_number=BIO_CHEM_ISSUE,
        source_url="https://pubchem.ncbi.nlm.nih.gov/docs/downloads",
        surface_form="sdf",
        raw_relative_path="bio/pubchem/sdf.jsonl.gz",
        notes="Keep block separators, metadata fields, and record delimiters.",
    ),
    _slice(
        name="rcsb_mmcif",
        family=LongTailPplFamily.BIO_CHEM,
        issue_number=BIO_CHEM_ISSUE,
        source_url="https://www.rcsb.org/docs/programmatic-access/file-download-services",
        surface_form="mmcif",
        raw_relative_path="bio/rcsb/mmcif.jsonl.gz",
        notes="Preserve crystallographic tags, atom tables, and field punctuation.",
    ),
    # Time-series / tables / geo
    _slice(
        name="monash_tsf",
        family=LongTailPplFamily.TIME_SERIES_TABLE_GEO,
        issue_number=TIME_SERIES_ISSUE,
        source_url="https://forecastingdata.org/",
        surface_form="tsf",
        raw_relative_path="time_series/monash/tsf.jsonl.gz",
        notes="Preserve horizon metadata, units, missing markers, and line layout.",
    ),
    _slice(
        name="gittables_csv",
        family=LongTailPplFamily.TIME_SERIES_TABLE_GEO,
        issue_number=TIME_SERIES_ISSUE,
        source_url="https://gittables.github.io/",
        surface_form="csv_table",
        raw_relative_path="time_series/gittables/csv.jsonl.gz",
        notes="Keep CSV structure, headers, quoted cells, and cell delimiters.",
    ),
    _slice(
        name="web_data_commons_tables",
        family=LongTailPplFamily.TIME_SERIES_TABLE_GEO,
        issue_number=TIME_SERIES_ISSUE,
        source_url="https://webdatacommons.org/webtables/englishTables.html",
        surface_form="html_csv_json_tables",
        raw_relative_path="time_series/web_data_commons/tables.jsonl.gz",
        notes="Preserve extracted table text, HTML table context, and JSON metadata.",
    ),
    _slice(
        name="whos_on_first_geojson",
        family=LongTailPplFamily.TIME_SERIES_TABLE_GEO,
        issue_number=TIME_SERIES_ISSUE,
        source_url="https://whosonfirst.org/download/",
        surface_form="geojson",
        raw_relative_path="time_series/whos_on_first/geojson.jsonl.gz",
        notes="Keep feature IDs, coordinates, and nested metadata fields literal.",
    ),
    _slice(
        name="openstreetmap_extract",
        family=LongTailPplFamily.TIME_SERIES_TABLE_GEO,
        issue_number=TIME_SERIES_ISSUE,
        source_url="https://planet.openstreetmap.org/",
        surface_form="osm_text",
        raw_relative_path="time_series/openstreetmap/extract.jsonl.gz",
        notes="Use a small textual extract; preserve tags, nodes, and relation structure.",
    ),
    # Formal methods / hardware
    _slice(
        name="smtlib",
        family=LongTailPplFamily.FORMAL_HARDWARE,
        issue_number=FORMAL_HARDWARE_ISSUE,
        source_url="https://smt-lib.org/benchmarks.shtml",
        surface_form="smt2",
        raw_relative_path="formal/smtlib/smt2.jsonl.gz",
        notes="Preserve solver syntax, symbols, comments, and status markers.",
    ),
    _slice(
        name="tptp",
        family=LongTailPplFamily.FORMAL_HARDWARE,
        issue_number=FORMAL_HARDWARE_ISSUE,
        source_url="https://www.tptp.org/",
        surface_form="tptp",
        raw_relative_path="formal/tptp/tptp.jsonl.gz",
        notes="Keep theorem-proving problem syntax and generated identifiers intact.",
    ),
    _slice(
        name="coqgym",
        family=LongTailPplFamily.FORMAL_HARDWARE,
        issue_number=FORMAL_HARDWARE_ISSUE,
        source_url="https://github.com/princeton-vl/CoqGym",
        surface_form="coq_proof_script",
        raw_relative_path="formal/coqgym/coq.jsonl.gz",
        notes="Keep proof scripts and proof-state text together for later pipeline work.",
    ),
    _slice(
        name="dimacs_cnf",
        family=LongTailPplFamily.FORMAL_HARDWARE,
        issue_number=FORMAL_HARDWARE_ISSUE,
        source_url="https://satcompetition.github.io/2022/benchmarks.html",
        surface_form="cnf",
        raw_relative_path="formal/dimacs/cnf.jsonl.gz",
        notes="Preserve DIMACS headers, clauses, and comment lines.",
    ),
    _slice(
        name="verilogeval",
        family=LongTailPplFamily.FORMAL_HARDWARE,
        issue_number=FORMAL_HARDWARE_ISSUE,
        source_url="https://github.com/NVlabs/verilog-eval",
        surface_form="verilog",
        raw_relative_path="formal/verilogeval/verilog.jsonl.gz",
        notes="Keep module boundaries, long symbols, and hardware-description syntax.",
    ),
    _slice(
        name="rtl_repo",
        family=LongTailPplFamily.FORMAL_HARDWARE,
        issue_number=FORMAL_HARDWARE_ISSUE,
        source_url="https://github.com/AUCOHL/RTL-Repo",
        surface_form="verilog_repo_context",
        raw_relative_path="formal/rtl_repo/verilog.jsonl.gz",
        notes="Preserve repository-context completions and module-local identifiers.",
    ),
    _slice(
        name="rtl_coder",
        family=LongTailPplFamily.FORMAL_HARDWARE,
        issue_number=FORMAL_HARDWARE_ISSUE,
        source_url="https://github.com/hkust-zhiyao/RTL-Coder",
        surface_form="verilog_instruction_text",
        raw_relative_path="formal/rtl_coder/verilog.jsonl.gz",
        notes="Keep instruction text, generated code, and hardware tokens together.",
    ),
    _slice(
        name="hwmcc_aiger_btor",
        family=LongTailPplFamily.FORMAL_HARDWARE,
        issue_number=FORMAL_HARDWARE_ISSUE,
        source_url="https://fmv.jku.at/hwmcc11/benchmarks.html",
        surface_form="aiger_btor_text",
        raw_relative_path="formal/hwmcc/aiger_btor.jsonl.gz",
        notes="Use textual renderings only; preserve solver and model-checking syntax.",
    ),
    # Package metadata
    _slice(
        name="deps_dev",
        family=LongTailPplFamily.PACKAGE_METADATA,
        issue_number=PACKAGE_METADATA_ISSUE,
        source_url="https://docs.deps.dev/bigquery/v1/",
        surface_form="dependency_rows",
        raw_relative_path="packages/deps_dev/rows.jsonl.gz",
        notes="Preserve package names, semver constraints, hashes, and dependency edges.",
    ),
    _slice(
        name="ecosystem_ms_libraries_io",
        family=LongTailPplFamily.PACKAGE_METADATA,
        issue_number=PACKAGE_METADATA_ISSUE,
        source_url="https://repos.ecosyste.ms/open-data",
        surface_form="ecosystem_metadata",
        raw_relative_path="packages/ecosystems_ms/metadata.jsonl.gz",
        notes="Keep repository/package metadata, licenses, and release records literal.",
    ),
    _slice(
        name="npm_registry_metadata",
        family=LongTailPplFamily.PACKAGE_METADATA,
        issue_number=PACKAGE_METADATA_ISSUE,
        source_url="https://docs.npmjs.com/policies/crawlers/",
        surface_form="registry_json",
        raw_relative_path="packages/npm/registry.jsonl.gz",
        notes="Preserve CouchDB-style package JSON and nested version fields.",
    ),
    _slice(
        name="package_lock_corpora",
        family=LongTailPplFamily.PACKAGE_METADATA,
        issue_number=PACKAGE_METADATA_ISSUE,
        source_url="https://github.com/marin-community/marin/issues/4961",
        surface_form="lockfile",
        raw_relative_path="packages/package_lock/lockfiles.jsonl.gz",
        notes="Later pipeline work should keep lockfile structure, URLs, and checksums intact.",
    ),
    # Game / music
    _slice(
        name="lichess_pgn",
        family=LongTailPplFamily.GAME_MUSIC,
        issue_number=GAME_MUSIC_ISSUE,
        source_url="https://database.lichess.org/",
        surface_form="pgn",
        raw_relative_path="games/lichess/pgn.jsonl.gz",
        notes="Keep move text, headers, comments, and result markers.",
    ),
    _slice(
        name="kernscores_humdrum",
        family=LongTailPplFamily.GAME_MUSIC,
        issue_number=GAME_MUSIC_ISSUE,
        source_url="https://kern.ccarh.org/",
        surface_form="humdrum_kern",
        raw_relative_path="music/kernscores/humdrum.jsonl.gz",
        notes="Preserve **kern syntax, comments, and note boundaries.",
    ),
    _slice(
        name="abc_notation",
        family=LongTailPplFamily.GAME_MUSIC,
        issue_number=GAME_MUSIC_ISSUE,
        source_url="https://abcnotation.com/",
        surface_form="abc_notation",
        raw_relative_path="music/abc/notation.jsonl.gz",
        notes="Keep ABC headers, barlines, and note-length annotations.",
    ),
)

# Lookup table keyed by LongTailPplSlice.registry_key ("long_tail_ppl/<family>/<name>").
LONG_TAIL_PPL_REGISTRY: dict[str, LongTailPplSlice] = {slice_.registry_key: slice_ for slice_ in LONG_TAIL_PPL_SLICES}


def long_tail_ppl_slices(*, family: LongTailPplFamily | None = None) -> tuple[LongTailPplSlice, ...]:
    """Return all registered long-tail slices, optionally filtered by family."""

    if family is None:
        return LONG_TAIL_PPL_SLICES
    return tuple(slice_ for slice_ in LONG_TAIL_PPL_SLICES if slice_.family == family)


def long_tail_raw_validation_sets(
    raw_root: str = "raw/long_tail_ppl",
    *,
    family: LongTailPplFamily | None = None,
) -> dict[str, RawTextEvaluationDataset]:
    """Materialize the registry into raw-text evaluation datasets.

    The returned datasets point at deterministic paths under ``raw_root``. The
    registry itself does not download or mirror any corpus.
    """

    datasets: dict[str, RawTextEvaluationDataset] = {}
    for slice_ in long_tail_ppl_slices(family=family):
        datasets[slice_.registry_key] = slice_.to_raw_text_dataset(raw_root)
    return datasets


def render_long_tail_ppl_registry_markdown(*, family: LongTailPplFamily | None = None) -> str:
    """Render the registry as a compact markdown summary."""

    lines = ["# Long-tail PPL registry", ""]
    for current_family in LongTailPplFamily:
        if family is not None and current_family != family:
            continue

        family_slices = long_tail_ppl_slices(family=current_family)
        if not family_slices:
            continue

        lines.append(f"## {current_family.value}")
        for slice_ in family_slices:
            lines.append(
                f"- `{slice_.registry_key}`: #{slice_.issue_number} | {slice_.surface_form} | {slice_.source_url}"
            )
            if slice_.notes:
                lines.append(f"  - {slice_.notes}")
        lines.append("")
    return "\n".join(lines).rstrip() + "\n"
# --- patch continues: new file experiments/evals/long_tail_ppl_runnable.py (113 lines) ---
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Runnable first-pass long-tail PPL slices backed by public Hugging Face datasets."""

from __future__ import annotations

import posixpath
from dataclasses import dataclass

from marin.evaluation.perplexity_gap import RawTextEvaluationDataset, raw_text_dataset
from marin.processing.tokenize import HfDatasetSpec

from experiments.evals.long_tail_ppl import LongTailPplFamily

RUNNABLE_LONG_TAIL_SOURCE_NOTE = (
    "These slices are directly executable from public Hugging Face datasets and do not require a bulk mirror."
)


@dataclass(frozen=True)
class RunnableLongTailPplSlice:
    """A runnable long-tail slice backed by a small public Hugging Face dataset."""

    name: str
    family: LongTailPplFamily
    source_url: str
    hf_dataset: HfDatasetSpec
    text_key: str
    split: str
    notes: str = ""

    @property
    def registry_key(self) -> str:
        # Stable lookup key of the form "long_tail_ppl_runnable/<family>/<name>".
        return posixpath.join("long_tail_ppl_runnable", self.family.value, self.name)

    @property
    def tags(self) -> tuple[str, ...]:
        return ("long_tail_ppl", "long_tail_ppl_runnable", self.family.value, f"split:{self.split}")

    def to_raw_text_dataset(self) -> RawTextEvaluationDataset:
        """Render the slice as a raw-text eval dataset reading ``text_key`` from ``split``."""
        return raw_text_dataset(self.hf_dataset, text_key=self.text_key, split=self.split, tags=self.tags)


RUNNABLE_LONG_TAIL_PPL_SLICES: tuple[RunnableLongTailPplSlice, ...] = (
    RunnableLongTailPplSlice(
        name="svg_stack_val",
        family=LongTailPplFamily.WEB_MARKUP_IMAGE_TEXT,
        source_url="https://huggingface.co/datasets/starvector/svg-stack",
        hf_dataset=HfDatasetSpec(id="starvector/svg-stack"),
        text_key="Svg",
        split="val",
        notes="Preserve SVG XML and caption-adjacent markup in the validation split.",
    ),
    RunnableLongTailPplSlice(
        name="svg_stack_test",
        family=LongTailPplFamily.WEB_MARKUP_IMAGE_TEXT,
        source_url="https://huggingface.co/datasets/starvector/svg-stack",
        hf_dataset=HfDatasetSpec(id="starvector/svg-stack"),
        text_key="Svg",
        split="test",
        notes="Preserve SVG XML in the held-out test split.",
    ),
    RunnableLongTailPplSlice(
        name="verilogeval_prompt",
        family=LongTailPplFamily.FORMAL_HARDWARE,
        source_url="https://huggingface.co/datasets/dakies/nvlabs-verilogeval",
        hf_dataset=HfDatasetSpec(id="dakies/nvlabs-verilogeval"),
        text_key="prompt",
        split="test",
        notes="Keep VerilogEval problem statements and interface text intact.",
    ),
    RunnableLongTailPplSlice(
        name="verilogeval_canonical_solution",
        family=LongTailPplFamily.FORMAL_HARDWARE,
        source_url="https://huggingface.co/datasets/dakies/nvlabs-verilogeval",
        hf_dataset=HfDatasetSpec(id="dakies/nvlabs-verilogeval"),
        text_key="canonical_solution",
        split="test",
        notes="Keep VerilogEval reference implementations and formatting intact.",
    ),
)

# Lookup table keyed by RunnableLongTailPplSlice.registry_key.
RUNNABLE_LONG_TAIL_PPL_REGISTRY: dict[str, RunnableLongTailPplSlice] = {
    slice_.registry_key: slice_ for slice_ in RUNNABLE_LONG_TAIL_PPL_SLICES
}


def runnable_long_tail_ppl_slices(*, family: LongTailPplFamily | None = None) -> tuple[RunnableLongTailPplSlice, ...]:
    """Return all runnable slices, optionally filtered by family."""
    if family is None:
        return RUNNABLE_LONG_TAIL_PPL_SLICES
    return tuple(slice_ for slice_ in RUNNABLE_LONG_TAIL_PPL_SLICES if slice_.family == family)


def runnable_long_tail_raw_validation_sets() -> dict[str, RawTextEvaluationDataset]:
    """Materialize the runnable HF-backed slices into raw-text datasets."""

    return {slice_.registry_key: slice_.to_raw_text_dataset() for slice_ in RUNNABLE_LONG_TAIL_PPL_SLICES}


def render_runnable_long_tail_registry_markdown() -> str:
    """Render the runnable registry as a compact markdown summary."""
    lines = ["# Runnable long-tail PPL registry", "", RUNNABLE_LONG_TAIL_SOURCE_NOTE, ""]
    for current_family in LongTailPplFamily:
        family_slices = runnable_long_tail_ppl_slices(family=current_family)
        if not family_slices:
            continue
        lines.append(f"## {current_family.value}")
        for slice_ in family_slices:
            lines.append(f"- `{slice_.registry_key}`: split={slice_.split} | {slice_.text_key} | {slice_.source_url}")
            if slice_.notes:
                lines.append(f"  - {slice_.notes}")
        lines.append("")
    return "\n".join(lines).rstrip() + "\n"
# --- patch continues: new file experiments/exp_model_perplexity_gap_long_tail_runnable.py (84 lines) ---
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Run reusable long-tail perplexity-gap reports for epic #5005.

See https://github.com/marin-community/marin/issues/5005.
"""

from fray.v2.types import ResourceConfig

from experiments.evals.long_tail_ppl_runnable import runnable_long_tail_raw_validation_sets
from marin.evaluation.perplexity_gap import GapFinderModelConfig, default_model_perplexity_gap
from marin.execution.executor import executor_main

# Single TPU slice; per-dataset doc and byte caps keep each slice cheap to score.
RESOURCE_CONFIG = ResourceConfig.with_tpu("v5p-8", regions=["us-central1"])
MAX_DOCS_PER_DATASET = 256
MAX_DOC_BYTES = 32_768

# HF-backed runnable slices; no bulk mirror required.
DATASETS = runnable_long_tail_raw_validation_sets()

# Model A in both comparisons below.
MARIN_MODEL = GapFinderModelConfig(
    checkpoint_path="marin-community/marin-8b-base",
    checkpoint_is_hf=True,
    tokenizer="meta-llama/Llama-3.1-8B",
)

MARIN_VS_LLAMA = default_model_perplexity_gap(
    name="long-tail-runnable-marin-8b-base-vs-llama-3.1-8b-base-doccap256",
    model_a=MARIN_MODEL,
    model_b=GapFinderModelConfig(
        checkpoint_path="meta-llama/Llama-3.1-8B",
        checkpoint_is_hf=True,
        tokenizer="meta-llama/Llama-3.1-8B",
    ),
    datasets=DATASETS,
    resource_config=RESOURCE_CONFIG,
    per_device_batch_size=4,
    max_eval_length=4096,
    max_docs_per_dataset=MAX_DOCS_PER_DATASET,
    max_doc_bytes=MAX_DOC_BYTES,
    wandb_tags=[
        "eval=perplexity-gap",
        "rerun=long-tail-runnable-first-pass",
        "model_a=marin-community/marin-8b-base",
        "model_b=meta-llama/Llama-3.1-8B",
        "dataset_bundle=runnable_long_tail_hf_backed",
        "source_split=hf_dataset",
        "region=us-central1",
        f"max_docs_per_dataset={MAX_DOCS_PER_DATASET}",
    ],
)

MARIN_VS_QWEN3 = default_model_perplexity_gap(
    name="long-tail-runnable-marin-8b-base-vs-qwen3-8b-base-doccap256",
    model_a=MARIN_MODEL,
    model_b=GapFinderModelConfig(
        checkpoint_path="Qwen/Qwen3-8B-Base",
        checkpoint_is_hf=True,
        tokenizer="Qwen/Qwen3-8B",
    ),
    datasets=DATASETS,
    resource_config=RESOURCE_CONFIG,
    per_device_batch_size=4,
    max_eval_length=4096,
    max_docs_per_dataset=MAX_DOCS_PER_DATASET,
    max_doc_bytes=MAX_DOC_BYTES,
    wandb_tags=[
        "eval=perplexity-gap",
        "rerun=long-tail-runnable-first-pass",
        "model_a=marin-community/marin-8b-base",
        "model_b=Qwen/Qwen3-8B-Base",
        "dataset_bundle=runnable_long_tail_hf_backed",
        "source_split=hf_dataset",
        "region=us-central1",
        f"max_docs_per_dataset={MAX_DOCS_PER_DATASET}",
    ],
)


if __name__ == "__main__":
    executor_main(
        [MARIN_VS_LLAMA, MARIN_VS_QWEN3],
        description="Run Marin perplexity-gap reports on runnable first-pass long-tail PPL slices.",
    )
# --- patch continues: modification hunks preserved verbatim (enclosing definitions not fully in view) ---
diff --git a/lib/levanter/src/levanter/analysis/perplexity_gap.py b/lib/levanter/src/levanter/analysis/perplexity_gap.py index 246350e2fd..29bf9a52ae 100644 --- a/lib/levanter/src/levanter/analysis/perplexity_gap.py +++ b/lib/levanter/src/levanter/analysis/perplexity_gap.py @@ -431,7 +431,7 @@ def iter_raw_text_documents( "Gap finding currently supports TextLmDatasetFormat only." ) - source = component.source.get_shard_source("validation") + source = component.source.get_shard_source(component.split) if source is None: continue diff --git a/lib/levanter/src/levanter/data/text/datasets.py b/lib/levanter/src/levanter/data/text/datasets.py index 9e7ddabbf8..175547c769 100644 --- a/lib/levanter/src/levanter/data/text/datasets.py +++ b/lib/levanter/src/levanter/data/text/datasets.py @@ -335,6 +335,7 @@ class DatasetComponent(DatasetComponentBase): format: LmDatasetFormatBase = field(default_factory=TextLmDatasetFormat) pack: bool | int | Literal["pad"] | None = None tags: list[str] | None = None + split: str = "validation" @DatasetComponentBase.register_subclass("direct") diff --git a/lib/marin/src/marin/evaluation/perplexity_gap.py b/lib/marin/src/marin/evaluation/perplexity_gap.py index da59009b89..7d1f43dba0 100644 --- a/lib/marin/src/marin/evaluation/perplexity_gap.py +++ b/lib/marin/src/marin/evaluation/perplexity_gap.py @@ -40,6 +40,7 @@ class RawTextEvaluationDataset: hf_dataset_id: str | None = None hf_dataset_name: str | None = None text_key: str = "text" + split: str = "validation" tags: tuple[str, ...] = () @@ -63,6 +64,7 @@ def raw_text_dataset( source: str | InputName | ExecutorStep | HfDatasetSpec, *, text_key: str = "text", + split: str = "validation", tags: tuple[str, ...] = (), ) -> RawTextEvaluationDataset: if isinstance(source, HfDatasetSpec): @@ -70,9 +72,12 @@ hf_dataset_id=source.id, hf_dataset_name=source.name, text_key=text_key, + split=split, tags=tags, ) - return RawTextEvaluationDataset(input_path=source, text_key=text_key, tags=tags) + if split != "validation": + raise ValueError("split is only supported for Hugging Face dataset sources; file paths use validation.") + return RawTextEvaluationDataset(input_path=source, text_key=text_key, split=split, tags=tags) def default_model_perplexity_gap( @@ -184,10 +189,13 @@ id=config.hf_dataset_id, name=config.hf_dataset_name, format=dataset_format, + splits=[config.split], ) else: if config.input_path is None: raise ValueError("RawTextEvaluationDataset requires either input_path or hf_dataset_id.") + if config.split != "validation": + raise ValueError("RawTextEvaluationDataset split is only supported for Hugging Face dataset sources.") input_path = config.input_path if isinstance(input_path, ExecutorStep): input_path = input_path.as_input_name() @@ -196,7 +204,7 @@ validation_urls=[input_path], # type: ignore[list-item] format=dataset_format, ) - return DatasetComponent(source=source, format=dataset_format, tags=list(config.tags)) + return DatasetComponent(source=source, format=dataset_format, tags=list(config.tags), split=config.split) def _default_step_name(model_a: GapFinderModelConfig, model_b: GapFinderModelConfig) -> str: @@ -234,5 +242,6 @@ def _cache_key_for_dataset(dataset: RawTextEvaluationDataset) -> dict[str, Any]: "hf_dataset_id": dataset.hf_dataset_id, "hf_dataset_name": dataset.hf_dataset_name, "text_key": 
dataset.text_key, + "split": dataset.split, "tags": dataset.tags, }
diff --git a/tests/evals/test_long_tail_ppl.py b/tests/evals/test_long_tail_ppl.py new file mode 100644 index 0000000000..f00e2c8b63 --- /dev/null +++ b/tests/evals/test_long_tail_ppl.py @@ -0,0 +1,52 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

import pytest

from experiments.evals.long_tail_ppl import (
    GAME_MUSIC_ISSUE,
    LongTailPplFamily,
    long_tail_raw_validation_sets,
    render_long_tail_ppl_registry_markdown,
)
from levanter.data.text import HfDatasetSourceConfig
from marin.evaluation.perplexity_gap import _to_dataset_component, raw_text_dataset
from marin.processing.tokenize import HfDatasetSpec


def test_long_tail_raw_validation_sets_render_deterministic_paths_and_tags():
    # Registry keys, file paths, and tags must be stable across runs.
    datasets = long_tail_raw_validation_sets(raw_root="gs://example-bucket/raw/long_tail")

    web_key = "long_tail_ppl/web_markup_image_text/common_crawl_warc"
    game_key = "long_tail_ppl/game_music/lichess_pgn"

    assert datasets[web_key].input_path == "gs://example-bucket/raw/long_tail/web/common_crawl/warc.jsonl.gz"
    assert datasets[web_key].tags == ("long_tail_ppl", "epic:5005", "issue:5056", "web_markup_image_text")
    assert datasets[web_key].text_key == "text"

    assert datasets[game_key].input_path == "gs://example-bucket/raw/long_tail/games/lichess/pgn.jsonl.gz"
    assert datasets[game_key].tags == ("long_tail_ppl", "epic:5005", f"issue:{GAME_MUSIC_ISSUE}", "game_music")


def test_long_tail_registry_rendering_mentions_issue_links():
    # The markdown summary must surface the key, source URL, and tracking issue.
    markdown = render_long_tail_ppl_registry_markdown(family=LongTailPplFamily.GAME_MUSIC)

    assert "long_tail_ppl/game_music/lichess_pgn" in markdown
    assert "database.lichess.org" in markdown
    assert f"#{GAME_MUSIC_ISSUE}" in markdown


def test_hf_backed_raw_dataset_preserves_requested_split():
    # HF-backed sources may request a non-default split; it must flow through
    # to both the DatasetComponent and the underlying source config.
    dataset = raw_text_dataset(HfDatasetSpec(id="example/dataset"), text_key="body", split="test")

    component = _to_dataset_component(dataset)

    assert component.split == "test"
    assert component.format.text_key == "body"
    assert isinstance(component.source, HfDatasetSourceConfig)
    assert component.source.splits == ["test"]


def test_file_backed_raw_dataset_rejects_non_validation_split():
    # File-path sources only support the implicit "validation" split.
    with pytest.raises(ValueError, match="Hugging Face dataset sources"):
        raw_text_dataset("gs://example-bucket/eval.jsonl", split="test")