-
Notifications
You must be signed in to change notification settings - Fork 111
[evals] Add capped ASR/OCR noisy-text perplexity-gap slices #5118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,188 @@ | ||
| # Copyright The Marin Authors | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Opt-in ASR/OCR noisy-text raw eval slices for perplexity-gap reports. | ||
|
|
||
| This module materializes paired noisy/clean text from ASR and OCR sources, then | ||
| registers both variants as raw-text datasets so gap reports can compute deltas. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import json | ||
| import os | ||
| from collections.abc import Iterable, Mapping, Sequence | ||
| from dataclasses import dataclass, field | ||
| from enum import StrEnum | ||
|
|
||
| import fsspec | ||
| from datasets import load_dataset | ||
| from fray.v2 import ResourceConfig | ||
| from levanter.utils import fsspec_utils | ||
|
|
||
| from marin.evaluation.perplexity_gap import RawTextEvaluationDataset, raw_text_dataset | ||
| from marin.execution.executor import ExecutorStep, this_output_path | ||
| from marin.execution.remote import remote | ||
| from marin.processing.tokenize import HfDatasetSpec | ||
|
|
||
| ASR_OCR_NOISY_DATASET_ROOT = "asr_ocr_noisy_ppl" | ||
| NOISY_TEXT_FIELD = "noisy_text" | ||
| CLEAN_TEXT_FIELD = "clean_text" | ||
| DEFAULT_RAW_SHARD_NAME = "data-00000-of-00001.jsonl.gz" | ||
|
|
||
|
|
||
class NoisyTextFamily(StrEnum):
    """Provenance of the noisy text in a slice: ASR hypotheses or OCR output."""

    ASR = "asr"
    OCR = "ocr"
|
|
||
|
|
||
@dataclass(frozen=True)
class NoisyTextSlice:
    """One paired noisy/clean eval slice sourced from a Hugging Face dataset."""

    registry_name: str  # stable key used in output paths and registry keys
    family: NoisyTextFamily  # ASR or OCR provenance
    source_url: str  # human-readable provenance link for the source dataset
    hf_dataset: HfDatasetSpec  # Hugging Face dataset id/config to stream from
    split: str  # split name passed to load_dataset
    noisy_key: str  # source column holding the noisy hypothesis text
    clean_key: str  # source column holding the clean reference text
    max_rows: int  # per-slice cap on materialized rows (must be positive)
    notes: str = ""  # licensing / usage caveats for reviewers

    @property
    def tags(self) -> tuple[str, ...]:
        # These tags are attached to both the noisy and clean registered
        # variants so gap reports can group/filter by bundle, family, source.
        return (ASR_OCR_NOISY_DATASET_ROOT, f"family:{self.family.value}", f"source:{self.registry_name}")
|
|
||
|
|
||
# Default slice registry: one ASR source (HypR LibriSpeech hypotheses) and one
# OCR source (ReadingTimeMachine post-correction pairs). Each entry carries its
# own licensing note; treat both as eval-only until terms are reviewed.
ASR_OCR_NOISY_SLICES: tuple[NoisyTextSlice, ...] = (
    NoisyTextSlice(
        registry_name="hypr_librispeech_without_lm_test_clean",
        family=NoisyTextFamily.ASR,
        source_url="https://huggingface.co/datasets/ASR-HypR/LibriSpeech_withoutLM",
        hf_dataset=HfDatasetSpec(id="ASR-HypR/LibriSpeech_withoutLM"),
        split="test_clean",
        noisy_key="hyps",
        clean_key="ref",
        max_rows=512,
        notes=(
            "HypR exposes n-best ASR hypotheses per utterance. We linearize top-1 for noisy_text and keep ref "
            "as clean_text. Verify downstream use remains compatible with LibriSpeech-derived licensing terms."
        ),
    ),
    NoisyTextSlice(
        registry_name="rtm_sgt_ocr_v1_train",
        family=NoisyTextFamily.OCR,
        source_url="https://huggingface.co/datasets/ReadingTimeMachine/rtm-sgt-ocr-v1",
        hf_dataset=HfDatasetSpec(id="ReadingTimeMachine/rtm-sgt-ocr-v1"),
        split="train",
        noisy_key="source",
        clean_key="target",
        max_rows=512,
        notes=(
            "ReadingTimeMachine OCR post-correction pairs may inherit source-specific archival rights. Treat as "
            "eval-only until redistribution terms are reviewed per source collection."
        ),
    ),
)
|
|
||
|
|
||
@dataclass(frozen=True)
class NoisyAsrOcrRawConfig:
    """Config for materializing paired noisy/clean raw shards."""

    # Destination root; defaults to the executor-provided output path for the step.
    output_path: str = field(default_factory=this_output_path)  # type: ignore[arg-type]
    # When set, overrides every slice's max_rows (useful for smoke runs).
    max_rows_per_slice_override: int | None = None
    # Slices to materialize; defaults to the full registry above.
    slices: tuple[NoisyTextSlice, ...] = ASR_OCR_NOISY_SLICES
|
|
||
|
|
||
| def _coerce_text(value: object) -> str | None: | ||
| if isinstance(value, str): | ||
| return value if value.strip() else None | ||
| if isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray, str)): | ||
| for item in value: | ||
| text = _coerce_text(item) | ||
| if text is not None: | ||
| return text | ||
| return None | ||
| return None | ||
|
|
||
|
|
||
def linearize_noisy_clean_row(
    row: Mapping[str, object],
    *,
    noisy_key: str,
    clean_key: str,
) -> dict[str, str] | None:
    """Extract paired noisy/clean text fields from one source row.

    Returns None when either side is missing or cannot be coerced to a
    non-blank string, so callers can simply skip unusable rows.
    """
    noisy = _coerce_text(row.get(noisy_key))
    clean = _coerce_text(row.get(clean_key))
    if noisy is None or clean is None:
        return None
    return {NOISY_TEXT_FIELD: noisy, CLEAN_TEXT_FIELD: clean}
|
|
||
|
|
||
def _iter_linearized_rows(slice_: NoisyTextSlice) -> Iterable[dict[str, str]]:
    """Stream the slice's HF split, yielding only rows with both text fields."""
    spec = slice_.hf_dataset
    # Streaming mode avoids downloading the full dataset; we only need up to
    # the per-slice row cap downstream.
    stream = load_dataset(
        spec.id,
        name=spec.name,
        split=slice_.split,
        streaming=True,
    )
    for raw_row in stream:
        record = linearize_noisy_clean_row(
            raw_row,
            noisy_key=slice_.noisy_key,
            clean_key=slice_.clean_key,
        )
        if record is None:
            continue
        yield record
|
|
||
|
|
||
def _slice_output_path(output_path: str, registry_name: str) -> str:
    """Return the path of the single jsonl.gz shard for one slice."""
    return os.path.join(output_path, registry_name, DEFAULT_RAW_SHARD_NAME)
|
|
||
|
|
||
def materialize_noisy_asr_ocr_raw(config: NoisyAsrOcrRawConfig) -> None:
    """Materialize paired noisy/clean text rows into jsonl.gz shards.

    One shard per slice is written under ``config.output_path``. Each line is a
    JSON object with ``noisy_text`` and ``clean_text``, capped at the slice's
    ``max_rows`` (or the config-wide override when set).
    """
    fsspec_utils.mkdirs(config.output_path)
    override = config.max_rows_per_slice_override
    for slice_ in config.slices:
        shard_path = _slice_output_path(config.output_path, slice_.registry_name)
        fsspec_utils.mkdirs(os.path.dirname(shard_path))
        cap = slice_.max_rows if override is None else override
        if cap <= 0:
            raise ValueError(f"row cap must be positive, got {cap}.")
        with fsspec.open(shard_path, "wt", compression="gzip") as sink:
            written = 0
            for record in _iter_linearized_rows(slice_):
                if written >= cap:
                    break
                # ensure_ascii keeps shards 7-bit clean regardless of source encoding.
                sink.write(json.dumps(record, ensure_ascii=True))
                sink.write("\n")
                written += 1
|
|
||
|
|
||
# Executor step that materializes every registered slice into jsonl.gz shards
# under raw/evals/<root>/. Runs remotely on CPU-only resources.
noisy_asr_ocr_raw = ExecutorStep(
    name=os.path.join("raw", "evals", ASR_OCR_NOISY_DATASET_ROOT),
    description="Materialize paired ASR/OCR noisy-clean raw eval slices from Hugging Face.",
    fn=remote(
        materialize_noisy_asr_ocr_raw,
        resources=ResourceConfig.with_cpu(cpu=4, ram="32g", disk="40g"),
        pip_dependency_groups=["cpu"],
    ),
    config=NoisyAsrOcrRawConfig(),
)
|
|
||
|
|
||
def noisy_asr_ocr_raw_validation_sets(
    *,
    noisy_asr_ocr_raw: ExecutorStep = noisy_asr_ocr_raw,
) -> dict[str, RawTextEvaluationDataset]:
    """Register clean and noisy variants for each ASR/OCR raw slice.

    Keys have the shape ``<root>/<registry_name>/<variant>`` so gap reports
    can pair the two variants of each slice when computing deltas.
    """
    registered: dict[str, RawTextEvaluationDataset] = {}
    for slice_ in ASR_OCR_NOISY_SLICES:
        shard_glob = os.path.join(slice_.registry_name, "data-*.jsonl.gz")
        base_key = os.path.join(ASR_OCR_NOISY_DATASET_ROOT, slice_.registry_name)
        # Both variants read the same shard; only the text field differs.
        for variant, text_key in (("noisy", NOISY_TEXT_FIELD), ("clean", CLEAN_TEXT_FIELD)):
            registered[os.path.join(base_key, variant)] = raw_text_dataset(
                noisy_asr_ocr_raw.cd(shard_glob),
                text_key=text_key,
                tags=(*slice_.tags, f"variant:{variant}"),
            )
    return registered
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| # Copyright The Marin Authors | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Run opt-in ASR/OCR noisy-text perplexity-gap reports for issue #5097.""" | ||
|
|
||
| from fray.v2.types import ResourceConfig | ||
|
|
||
| from experiments.defaults import default_raw_validation_sets | ||
| from experiments.evals.asr_ocr_noisy_ppl import noisy_asr_ocr_raw_validation_sets | ||
| from marin.evaluation.perplexity_gap import GapFinderModelConfig, default_model_perplexity_gap | ||
| from marin.execution.executor import executor_main | ||
|
|
||
# Shared run-shape knobs applied to both gap reports below.
RESOURCE_CONFIG = ResourceConfig.with_tpu("v5p-8", regions=["us-central1"])
MAX_DOCS_PER_DATASET = 256
MAX_DOC_BYTES = 32_768

# Standard raw validation sets plus the opt-in ASR/OCR noisy/clean pairs.
DATASETS = {
    **default_raw_validation_sets(),
    **noisy_asr_ocr_raw_validation_sets(),
}

# Model A in both comparisons; tokenizer is pinned explicitly to Llama 3.1 8B.
MARIN_MODEL = GapFinderModelConfig(
    checkpoint_path="marin-community/marin-8b-base",
    checkpoint_is_hf=True,
    tokenizer="meta-llama/Llama-3.1-8B",
)
|
|
||
# Gap report: marin-8b-base (model A) vs Llama-3.1-8B base (model B),
# capped at MAX_DOCS_PER_DATASET documents per dataset.
MARIN_VS_LLAMA = default_model_perplexity_gap(
    name="asr-ocr-noisy-marin-8b-base-vs-llama-3.1-8b-base-doccap256",
    model_a=MARIN_MODEL,
    model_b=GapFinderModelConfig(
        checkpoint_path="meta-llama/Llama-3.1-8B",
        checkpoint_is_hf=True,
        tokenizer="meta-llama/Llama-3.1-8B",
    ),
    datasets=DATASETS,
    resource_config=RESOURCE_CONFIG,
    per_device_batch_size=4,
    max_eval_length=4096,
    max_docs_per_dataset=MAX_DOCS_PER_DATASET,
    max_doc_bytes=MAX_DOC_BYTES,
    wandb_tags=[
        "eval=perplexity-gap",
        "bundle=asr_ocr_noisy_ppl",
        "model_a=marin-community/marin-8b-base",
        "model_b=meta-llama/Llama-3.1-8B",
        f"max_docs_per_dataset={MAX_DOCS_PER_DATASET}",
    ],
)
|
|
||
# Gap report: marin-8b-base (model A) vs Qwen3-8B-Base (model B).
# NOTE(review): tokenizer is "Qwen/Qwen3-8B" while the checkpoint is
# "Qwen/Qwen3-8B-Base" — presumably intentional (shared tokenizer); confirm.
MARIN_VS_QWEN3 = default_model_perplexity_gap(
    name="asr-ocr-noisy-marin-8b-base-vs-qwen3-8b-base-doccap256",
    model_a=MARIN_MODEL,
    model_b=GapFinderModelConfig(
        checkpoint_path="Qwen/Qwen3-8B-Base",
        checkpoint_is_hf=True,
        tokenizer="Qwen/Qwen3-8B",
    ),
    datasets=DATASETS,
    resource_config=RESOURCE_CONFIG,
    per_device_batch_size=4,
    max_eval_length=4096,
    max_docs_per_dataset=MAX_DOCS_PER_DATASET,
    max_doc_bytes=MAX_DOC_BYTES,
    wandb_tags=[
        "eval=perplexity-gap",
        "bundle=asr_ocr_noisy_ppl",
        "model_a=marin-community/marin-8b-base",
        "model_b=Qwen/Qwen3-8B-Base",
        f"max_docs_per_dataset={MAX_DOCS_PER_DATASET}",
    ],
)
|
|
||
|
|
||
# Entry point: submit both gap reports through the executor.
if __name__ == "__main__":
    executor_main(
        [MARIN_VS_LLAMA, MARIN_VS_QWEN3],
        description="Run Marin perplexity-gap reports on opt-in ASR/OCR noisy-text slices.",
    )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| # Copyright The Marin Authors | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
import gzip
import json

from experiments.evals import asr_ocr_noisy_ppl
from experiments.evals.asr_ocr_noisy_ppl import (
    ASR_OCR_NOISY_DATASET_ROOT,
    ASR_OCR_NOISY_SLICES,
    NoisyAsrOcrRawConfig,
    NoisyTextFamily,
    NoisyTextSlice,
    linearize_noisy_clean_row,
    materialize_noisy_asr_ocr_raw,
    noisy_asr_ocr_raw_validation_sets,
)
from marin.processing.tokenize import HfDatasetSpec
|
|
||
|
|
||
def test_linearize_noisy_clean_row_uses_first_hypothesis_and_preserves_reference():
    """Top-1 hypothesis becomes noisy_text; the reference is kept verbatim."""
    source_row = {"hyps": ["THE CAT SAT", "THE CATS AT"], "ref": "the cat sat"}

    result = linearize_noisy_clean_row(source_row, noisy_key="hyps", clean_key="ref")

    assert result == {"noisy_text": "THE CAT SAT", "clean_text": "the cat sat"}
|
|
||
|
|
||
def test_noisy_asr_ocr_raw_validation_sets_registers_clean_and_noisy_slices():
    """Each slice registers a noisy and a clean variant sharing one shard glob."""

    class _SyntheticRawStep:
        # Minimal stand-in for the ExecutorStep: .cd just prefixes a bucket.
        def cd(self, path: str) -> str:
            return f"gs://synthetic/{path}"

    registered = noisy_asr_ocr_raw_validation_sets(noisy_asr_ocr_raw=_SyntheticRawStep())

    slice_ = ASR_OCR_NOISY_SLICES[0]
    base = f"{ASR_OCR_NOISY_DATASET_ROOT}/{slice_.registry_name}"
    noisy = registered[f"{base}/noisy"]
    clean = registered[f"{base}/clean"]

    assert noisy.text_key == "noisy_text"
    assert clean.text_key == "clean_text"
    assert isinstance(noisy.input_path, str)
    assert f"/{slice_.registry_name}/data-*.jsonl.gz" in noisy.input_path
    assert noisy.tags[-1] == "variant:noisy"
    assert clean.tags[-1] == "variant:clean"
|
|
||
|
|
||
def test_materialize_noisy_asr_ocr_raw_respects_per_slice_cap(tmp_path, monkeypatch):
    """materialize_noisy_asr_ocr_raw writes at most max_rows rows per slice."""
    rows = [
        {"hyps": ["NOISY ONE"], "ref": "clean one"},
        {"hyps": ["NOISY TWO"], "ref": "clean two"},
        {"hyps": ["NOISY THREE"], "ref": "clean three"},
    ]

    def _fake_load_dataset(*args, **kwargs):
        # Signature-compatible stand-in: ignore loader arguments and return
        # synthetic rows without any network access.
        del args, kwargs
        return rows

    # Patch the loader on the module object. The module is imported at the top
    # of this file (per the repo rule that all imports live at the top; the
    # previous function-local import was flagged in review).
    monkeypatch.setattr(asr_ocr_noisy_ppl, "load_dataset", _fake_load_dataset)
    slice_ = NoisyTextSlice(
        registry_name="synthetic",
        family=NoisyTextFamily.ASR,
        source_url="https://example.com",
        hf_dataset=HfDatasetSpec(id="synthetic/dataset"),
        split="test",
        noisy_key="hyps",
        clean_key="ref",
        max_rows=2,  # cap below the three available rows
    )

    materialize_noisy_asr_ocr_raw(NoisyAsrOcrRawConfig(output_path=str(tmp_path), slices=(slice_,)))

    with gzip.open(tmp_path / "synthetic" / "data-00000-of-00001.jsonl.gz", "rt") as handle:
        materialized = [json.loads(line) for line in handle]

    # Only the first two rows survive the cap, in source order.
    assert materialized == [
        {"noisy_text": "NOISY ONE", "clean_text": "clean one"},
        {"noisy_text": "NOISY TWO", "clean_text": "clean two"},
    ]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The local import
from experiments.evals import asr_ocr_noisy_pplviolates the repository rule in/workspace/marin/AGENTS.mdthat imports must be at the top of the file except for circular-dependency or optional-dependency cases, and neither exception applies here because the same module is already imported at module load time. Keeping it local introduces unnecessary inconsistency and can delay import-time failures to test execution instead of collection.Useful? React with 👍 / 👎.