188 changes: 188 additions & 0 deletions experiments/evals/asr_ocr_noisy_ppl.py
@@ -0,0 +1,188 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Opt-in ASR/OCR noisy-text raw eval slices for perplexity-gap reports.

This module materializes paired noisy/clean text from ASR and OCR sources, then
registers both variants as raw-text datasets so gap reports can compute deltas.
"""

from __future__ import annotations

import json
import os
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, field
from enum import StrEnum

import fsspec
from datasets import load_dataset
from fray.v2 import ResourceConfig
from levanter.utils import fsspec_utils

from marin.evaluation.perplexity_gap import RawTextEvaluationDataset, raw_text_dataset
from marin.execution.executor import ExecutorStep, this_output_path
from marin.execution.remote import remote
from marin.processing.tokenize import HfDatasetSpec

ASR_OCR_NOISY_DATASET_ROOT = "asr_ocr_noisy_ppl"
NOISY_TEXT_FIELD = "noisy_text"
CLEAN_TEXT_FIELD = "clean_text"
DEFAULT_RAW_SHARD_NAME = "data-00000-of-00001.jsonl.gz"
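# Each shard row is one JSON object pairing the two variants:
#   {"noisy_text": "...", "clean_text": "..."}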


class NoisyTextFamily(StrEnum):
    ASR = "asr"
    OCR = "ocr"


@dataclass(frozen=True)
class NoisyTextSlice:
    registry_name: str
    family: NoisyTextFamily
    source_url: str
    hf_dataset: HfDatasetSpec
    split: str
    noisy_key: str
    clean_key: str
    max_rows: int
    notes: str = ""

    @property
    def tags(self) -> tuple[str, ...]:
        return (ASR_OCR_NOISY_DATASET_ROOT, f"family:{self.family.value}", f"source:{self.registry_name}")


ASR_OCR_NOISY_SLICES: tuple[NoisyTextSlice, ...] = (
    NoisyTextSlice(
        registry_name="hypr_librispeech_without_lm_test_clean",
        family=NoisyTextFamily.ASR,
        source_url="https://huggingface.co/datasets/ASR-HypR/LibriSpeech_withoutLM",
        hf_dataset=HfDatasetSpec(id="ASR-HypR/LibriSpeech_withoutLM"),
        split="test_clean",
        noisy_key="hyps",
        clean_key="ref",
        max_rows=512,
        notes=(
            "HypR exposes n-best ASR hypotheses per utterance. We linearize top-1 for noisy_text and keep ref "
            "as clean_text. Verify downstream use remains compatible with LibriSpeech-derived licensing terms."
        ),
    ),
    NoisyTextSlice(
        registry_name="rtm_sgt_ocr_v1_train",
        family=NoisyTextFamily.OCR,
        source_url="https://huggingface.co/datasets/ReadingTimeMachine/rtm-sgt-ocr-v1",
        hf_dataset=HfDatasetSpec(id="ReadingTimeMachine/rtm-sgt-ocr-v1"),
        split="train",
        noisy_key="source",
        clean_key="target",
        max_rows=512,
        notes=(
            "ReadingTimeMachine OCR post-correction pairs may inherit source-specific archival rights. Treat as "
            "eval-only until redistribution terms are reviewed per source collection."
        ),
    ),
)
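
# Illustrative pairing, mirrored in the tests: a HypR row
# {"hyps": ["THE CAT SAT", "THE CATS AT"], "ref": "the cat sat"} linearizes to
# noisy_text="THE CAT SAT" (the top-1 hypothesis) and clean_text="the cat sat".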


@dataclass(frozen=True)
class NoisyAsrOcrRawConfig:
    output_path: str = field(default_factory=this_output_path)  # type: ignore[arg-type]
    max_rows_per_slice_override: int | None = None
    slices: tuple[NoisyTextSlice, ...] = ASR_OCR_NOISY_SLICES


def _coerce_text(value: object) -> str | None:
    """Return the first non-empty string found in ``value``, descending into sequences."""
    if isinstance(value, str):
        return value if value.strip() else None
    if isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray, str)):
        for item in value:
            text = _coerce_text(item)
            if text is not None:
                return text
        return None
    return None


def linearize_noisy_clean_row(
    row: Mapping[str, object],
    *,
    noisy_key: str,
    clean_key: str,
) -> dict[str, str] | None:
    """Extract paired noisy/clean text fields from one source row."""
    noisy_text = _coerce_text(row.get(noisy_key))
    clean_text = _coerce_text(row.get(clean_key))
    if noisy_text is None or clean_text is None:
        return None
    return {NOISY_TEXT_FIELD: noisy_text, CLEAN_TEXT_FIELD: clean_text}


def _iter_linearized_rows(slice_: NoisyTextSlice) -> Iterable[dict[str, str]]:
    # Stream so materialization reads only as much of the source as the row cap consumes.
    dataset = load_dataset(
        slice_.hf_dataset.id,
        name=slice_.hf_dataset.name,
        split=slice_.split,
        streaming=True,
    )
    for row in dataset:
        linearized = linearize_noisy_clean_row(row, noisy_key=slice_.noisy_key, clean_key=slice_.clean_key)
        if linearized is not None:
            yield linearized


def _slice_output_path(output_path: str, registry_name: str) -> str:
    return os.path.join(output_path, registry_name, DEFAULT_RAW_SHARD_NAME)


def materialize_noisy_asr_ocr_raw(config: NoisyAsrOcrRawConfig) -> None:
    """Materialize paired noisy/clean text rows into jsonl.gz shards."""
    fsspec_utils.mkdirs(config.output_path)
    for slice_ in config.slices:
        output_file = _slice_output_path(config.output_path, slice_.registry_name)
        fsspec_utils.mkdirs(os.path.dirname(output_file))
        row_cap = slice_.max_rows if config.max_rows_per_slice_override is None else config.max_rows_per_slice_override
        if row_cap <= 0:
            raise ValueError(f"row cap must be positive, got {row_cap}.")
        with fsspec.open(output_file, "wt", compression="gzip") as sink:
            for index, record in enumerate(_iter_linearized_rows(slice_)):
                if index >= row_cap:
                    break
                sink.write(json.dumps(record, ensure_ascii=True))
                sink.write("\n")


noisy_asr_ocr_raw = ExecutorStep(
    name=os.path.join("raw", "evals", ASR_OCR_NOISY_DATASET_ROOT),
    description="Materialize paired ASR/OCR noisy-clean raw eval slices from Hugging Face.",
    fn=remote(
        materialize_noisy_asr_ocr_raw,
        resources=ResourceConfig.with_cpu(cpu=4, ram="32g", disk="40g"),
        pip_dependency_groups=["cpu"],
    ),
    config=NoisyAsrOcrRawConfig(),
)


def noisy_asr_ocr_raw_validation_sets(
    *,
    noisy_asr_ocr_raw: ExecutorStep = noisy_asr_ocr_raw,
) -> dict[str, RawTextEvaluationDataset]:
    """Register clean and noisy variants for each ASR/OCR raw slice."""
    datasets: dict[str, RawTextEvaluationDataset] = {}
    for slice_ in ASR_OCR_NOISY_SLICES:
        raw_pattern = os.path.join(slice_.registry_name, "data-*.jsonl.gz")
        key_root = os.path.join(ASR_OCR_NOISY_DATASET_ROOT, slice_.registry_name)

        datasets[os.path.join(key_root, "noisy")] = raw_text_dataset(
            noisy_asr_ocr_raw.cd(raw_pattern),
            text_key=NOISY_TEXT_FIELD,
            tags=(*slice_.tags, "variant:noisy"),
        )
        datasets[os.path.join(key_root, "clean")] = raw_text_dataset(
            noisy_asr_ocr_raw.cd(raw_pattern),
            text_key=CLEAN_TEXT_FIELD,
            tags=(*slice_.tags, "variant:clean"),
        )

    return datasets
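
# Sketch of how the paired registrations line up (keys derived from the constants
# above; both slice names are the real ones registered in this module):
#
#   sets = noisy_asr_ocr_raw_validation_sets()
#   noisy = sets["asr_ocr_noisy_ppl/hypr_librispeech_without_lm_test_clean/noisy"]
#   clean = sets["asr_ocr_noisy_ppl/hypr_librispeech_without_lm_test_clean/clean"]
#
# A gap report can then subtract the two per-dataset perplexities to get the
# noisy-minus-clean delta for each slice.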
79 changes: 79 additions & 0 deletions experiments/exp_model_perplexity_gap_asr_ocr_noisy.py
@@ -0,0 +1,79 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Run opt-in ASR/OCR noisy-text perplexity-gap reports for issue #5097."""

from fray.v2.types import ResourceConfig

from experiments.defaults import default_raw_validation_sets
from experiments.evals.asr_ocr_noisy_ppl import noisy_asr_ocr_raw_validation_sets
from marin.evaluation.perplexity_gap import GapFinderModelConfig, default_model_perplexity_gap
from marin.execution.executor import executor_main

RESOURCE_CONFIG = ResourceConfig.with_tpu("v5p-8", regions=["us-central1"])
MAX_DOCS_PER_DATASET = 256
MAX_DOC_BYTES = 32_768
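
# The doc and byte caps keep each report's eval cost bounded; the "doccap256"
# suffix in the report names below mirrors MAX_DOCS_PER_DATASET.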

DATASETS = {
    **default_raw_validation_sets(),
    **noisy_asr_ocr_raw_validation_sets(),
}

MARIN_MODEL = GapFinderModelConfig(
    checkpoint_path="marin-community/marin-8b-base",
    checkpoint_is_hf=True,
    tokenizer="meta-llama/Llama-3.1-8B",
)

MARIN_VS_LLAMA = default_model_perplexity_gap(
    name="asr-ocr-noisy-marin-8b-base-vs-llama-3.1-8b-base-doccap256",
    model_a=MARIN_MODEL,
    model_b=GapFinderModelConfig(
        checkpoint_path="meta-llama/Llama-3.1-8B",
        checkpoint_is_hf=True,
        tokenizer="meta-llama/Llama-3.1-8B",
    ),
    datasets=DATASETS,
    resource_config=RESOURCE_CONFIG,
    per_device_batch_size=4,
    max_eval_length=4096,
    max_docs_per_dataset=MAX_DOCS_PER_DATASET,
    max_doc_bytes=MAX_DOC_BYTES,
    wandb_tags=[
        "eval=perplexity-gap",
        "bundle=asr_ocr_noisy_ppl",
        "model_a=marin-community/marin-8b-base",
        "model_b=meta-llama/Llama-3.1-8B",
        f"max_docs_per_dataset={MAX_DOCS_PER_DATASET}",
    ],
)

MARIN_VS_QWEN3 = default_model_perplexity_gap(
    name="asr-ocr-noisy-marin-8b-base-vs-qwen3-8b-base-doccap256",
    model_a=MARIN_MODEL,
    model_b=GapFinderModelConfig(
        checkpoint_path="Qwen/Qwen3-8B-Base",
        checkpoint_is_hf=True,
        tokenizer="Qwen/Qwen3-8B",
    ),
    datasets=DATASETS,
    resource_config=RESOURCE_CONFIG,
    per_device_batch_size=4,
    max_eval_length=4096,
    max_docs_per_dataset=MAX_DOCS_PER_DATASET,
    max_doc_bytes=MAX_DOC_BYTES,
    wandb_tags=[
        "eval=perplexity-gap",
        "bundle=asr_ocr_noisy_ppl",
        "model_a=marin-community/marin-8b-base",
        "model_b=Qwen/Qwen3-8B-Base",
        f"max_docs_per_dataset={MAX_DOCS_PER_DATASET}",
    ],
)


if __name__ == "__main__":
executor_main(
[MARIN_VS_LLAMA, MARIN_VS_QWEN3],
description="Run Marin perplexity-gap reports on opt-in ASR/OCR noisy-text slices.",
)
80 changes: 80 additions & 0 deletions tests/evals/test_asr_ocr_noisy_ppl.py
@@ -0,0 +1,80 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

import gzip
import json

from experiments.evals.asr_ocr_noisy_ppl import (
    ASR_OCR_NOISY_DATASET_ROOT,
    ASR_OCR_NOISY_SLICES,
    NoisyAsrOcrRawConfig,
    NoisyTextFamily,
    NoisyTextSlice,
    linearize_noisy_clean_row,
    materialize_noisy_asr_ocr_raw,
    noisy_asr_ocr_raw_validation_sets,
)
from marin.processing.tokenize import HfDatasetSpec


def test_linearize_noisy_clean_row_uses_first_hypothesis_and_preserves_reference():
    row = {"hyps": ["THE CAT SAT", "THE CATS AT"], "ref": "the cat sat"}

    linearized = linearize_noisy_clean_row(row, noisy_key="hyps", clean_key="ref")

    assert linearized == {"noisy_text": "THE CAT SAT", "clean_text": "the cat sat"}


def test_noisy_asr_ocr_raw_validation_sets_registers_clean_and_noisy_slices():
    class _SyntheticRawStep:
        def cd(self, path: str) -> str:
            return f"gs://synthetic/{path}"

    datasets = noisy_asr_ocr_raw_validation_sets(noisy_asr_ocr_raw=_SyntheticRawStep())

    first_slice = ASR_OCR_NOISY_SLICES[0]
    noisy_key = f"{ASR_OCR_NOISY_DATASET_ROOT}/{first_slice.registry_name}/noisy"
    clean_key = f"{ASR_OCR_NOISY_DATASET_ROOT}/{first_slice.registry_name}/clean"

    assert datasets[noisy_key].text_key == "noisy_text"
    assert datasets[clean_key].text_key == "clean_text"
    assert isinstance(datasets[noisy_key].input_path, str)
    assert f"/{first_slice.registry_name}/data-*.jsonl.gz" in datasets[noisy_key].input_path
    assert datasets[noisy_key].tags[-1] == "variant:noisy"
    assert datasets[clean_key].tags[-1] == "variant:clean"


def test_materialize_noisy_asr_ocr_raw_respects_per_slice_cap(tmp_path, monkeypatch):
Review comment (P1): Move test import to module top level

The local import `from experiments.evals import asr_ocr_noisy_ppl` violates the repository rule in /workspace/marin/AGENTS.md that imports must be at the top of the file except for circular-dependency or optional-dependency cases, and neither exception applies here because the same module is already imported at module load time. Keeping it local introduces unnecessary inconsistency and can delay import-time failures from collection to test execution.

    from experiments.evals import asr_ocr_noisy_ppl
Review comment (Contributor):

This local import violates the AGENTS.md rule: "All imports at the top of the file. No local imports except to break circular dependencies or guard optional deps."

The module object isn't imported at the top (the top-level import at L7-L16 pulls specific names, not the module itself), and this usage, monkeypatch.setattr(asr_ocr_noisy_ppl, "load_dataset", ...), is neither a circular-dependency workaround nor an optional-dependency guard.

Move the import to the top alongside the existing imports:

    from experiments.evals import asr_ocr_noisy_ppl

Then drop the local import here.

    rows = [
        {"hyps": ["NOISY ONE"], "ref": "clean one"},
        {"hyps": ["NOISY TWO"], "ref": "clean two"},
        {"hyps": ["NOISY THREE"], "ref": "clean three"},
    ]

    def _fake_load_dataset(*args, **kwargs):
        del args, kwargs
        return rows

    monkeypatch.setattr(asr_ocr_noisy_ppl, "load_dataset", _fake_load_dataset)
    slice_ = NoisyTextSlice(
        registry_name="synthetic",
        family=NoisyTextFamily.ASR,
        source_url="https://example.com",
        hf_dataset=HfDatasetSpec(id="synthetic/dataset"),
        split="test",
        noisy_key="hyps",
        clean_key="ref",
        max_rows=2,
    )

    materialize_noisy_asr_ocr_raw(NoisyAsrOcrRawConfig(output_path=str(tmp_path), slices=(slice_,)))

    with gzip.open(tmp_path / "synthetic" / "data-00000-of-00001.jsonl.gz", "rt") as handle:
        materialized = [json.loads(line) for line in handle]

    assert materialized == [
        {"noisy_text": "NOISY ONE", "clean_text": "clean one"},
        {"noisy_text": "NOISY TWO", "clean_text": "clean two"},
    ]