diff --git a/experiments/midtraining_datasets.py b/experiments/midtraining_datasets.py index 5f01705b8c..b50ea5957a 100644 --- a/experiments/midtraining_datasets.py +++ b/experiments/midtraining_datasets.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from marin.datakit.download.huggingface import DownloadConfig, download_hf +from marin.datakit.download.openmathinstruct2 import openmathinstruct2_normalize_steps from marin.execution import versioned from marin.execution.types import ExecutorStep, this_output_path from marin.processing.tokenize import lm_mixture_data_config @@ -54,6 +55,13 @@ tokenizer=llama3_tokenizer, ) +openmathinstruct2_full = openmathinstruct2_normalize_steps()[-1].as_executor_step() +openmathinstruct2_full_tokenized = default_tokenize( + name="openmathinstruct2_full", + dataset=openmathinstruct2_full, + tokenizer=llama3_tokenizer, +) + # Define MegaMath dataset source megamath_source = default_download( name="raw/llm360/megamath", diff --git a/lib/marin/src/marin/datakit/download/openmathinstruct2.py b/lib/marin/src/marin/datakit/download/openmathinstruct2.py new file mode 100644 index 0000000000..cab21728da --- /dev/null +++ b/lib/marin/src/marin/datakit/download/openmathinstruct2.py @@ -0,0 +1,132 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""nvidia/OpenMathInstruct-2 dataset download and transform. + +OpenMathInstruct-2 is a synthetic math reasoning corpus derived from GSM8K, +MATH, and augmented variants. This transform materializes the full train split +as tagged transcript documents and preserves source metadata for downstream +contamination and mixture analysis. +""" + +import hashlib + +from fray import ResourceConfig +from zephyr import Dataset, ZephyrContext, counters, load_parquet + +from marin.datakit.download.huggingface import download_hf_step +from marin.datakit.normalize import normalize_step +from marin.execution.step_spec import StepSpec + +HF_DATASET_ID = "nvidia/OpenMathInstruct-2" +HF_REVISION = "469216e" +TRAIN_PARQUET_GLOB = "data/train-*.parquet" +OPENMATHINSTRUCT2_ROUGH_TOKENS_B = 4.0 +EXPECTED_PROBLEM_SOURCES = frozenset({"math", "gsm8k", "augmented_math", "augmented_gsm8k"}) +LONG_PROBLEM_CHARS = 1_376 +LONG_SOLUTION_CHARS = 5_237 + + +def _clean_text(row: dict, key: str) -> str | None: + value = row.get(key) + if not isinstance(value, str): + return None + + text = value.strip() + if not text: + return None + + return text + + +def _optional_text(row: dict, key: str) -> str: + value = row.get(key) + if not isinstance(value, str): + return "" + + return value.strip() + + +def row_to_doc(row: dict) -> list[dict]: + problem = _clean_text(row, "problem") + if problem is None: + counters.increment("openmathinstruct2/dropped_empty_problem") + return [] + + solution = _clean_text(row, "generated_solution") + if solution is None: + counters.increment("openmathinstruct2/dropped_empty_solution") + return [] + + problem_source = _optional_text(row, "problem_source") + if problem_source not in EXPECTED_PROBLEM_SOURCES: + counters.increment("openmathinstruct2/dropped_unknown_problem_source") + return [] + + expected_answer = _optional_text(row, "expected_answer") + if not expected_answer: + counters.increment("openmathinstruct2/empty_expected_answer") + + if len(problem) > LONG_PROBLEM_CHARS: + counters.increment("openmathinstruct2/long_problem") + if len(solution) > LONG_SOLUTION_CHARS: + counters.increment("openmathinstruct2/long_solution") + + text = f"\n{problem}\n\n\n\n{solution}\n" + + counters.increment("openmathinstruct2/kept") + counters.increment(f"openmathinstruct2/source/{problem_source}") + return [ + { + "id": hashlib.sha256(text.encode("utf-8")).hexdigest(), + "problem_hash": hashlib.sha256(problem.encode("utf-8")).hexdigest(), + "text": text, + "source": HF_DATASET_ID, + "problem_source": problem_source, + "expected_answer": expected_answer, + "synthetic": True, + "benchmark_adjacent": True, + "hf_revision": HF_REVISION, + "split": "train", + } + ] + + +def transform(input_path: str, output_path: str) -> None: + pipeline = ( + Dataset.from_files(f"{input_path}/**/*.parquet") + .flat_map(load_parquet) + .flat_map(row_to_doc) + .write_parquet(f"{output_path}/data-{{shard:05d}}-of-{{total:05d}}.parquet", skip_existing=True) + ) + ctx = ZephyrContext(name="openmathinstruct2-transform", resources=ResourceConfig(cpu=1, ram="8g")) + ctx.execute(pipeline) + + +def download_openmathinstruct2_step() -> StepSpec: + """Download and transform the full OpenMathInstruct-2 train split.""" + dl = download_hf_step( + "raw/openmathinstruct2", + hf_dataset_id=HF_DATASET_ID, + revision=HF_REVISION, + hf_urls_glob=[TRAIN_PARQUET_GLOB], + ) + + return StepSpec( + name="processed/openmathinstruct2", + deps=[dl], + fn=lambda output_path: transform( + input_path=dl.output_path, + output_path=output_path, + ), + hash_attrs={"version": "v1", "split": "train"}, + ) + + +def openmathinstruct2_normalize_steps() -> tuple[StepSpec, ...]: + """Return the full ``(download+transform, normalize)`` chain for OpenMathInstruct-2.""" + processed = download_openmathinstruct2_step() + return ( + processed, + normalize_step(name="normalized/openmathinstruct2", download=processed), + ) diff --git a/lib/marin/src/marin/datakit/sources.py b/lib/marin/src/marin/datakit/sources.py index 629f3871f9..454fc60f2f 100644 --- a/lib/marin/src/marin/datakit/sources.py +++ b/lib/marin/src/marin/datakit/sources.py @@ -38,6 +38,10 @@ from marin.datakit.download.nsf_awards import nsf_awards_normalize_steps from marin.datakit.download.numinamath_tir import numinamath_tir_normalize_steps from marin.datakit.download.numinamath_v1_5 import numinamath_v1_5_normalize_steps +from marin.datakit.download.openmathinstruct2 import ( + OPENMATHINSTRUCT2_ROUGH_TOKENS_B, + openmathinstruct2_normalize_steps, +) from marin.datakit.download.starcoder2_extras import starcoder2_extras_normalize_steps from marin.datakit.download.superior_reasoning import superior_reasoning_normalize_steps from marin.datakit.download.svgfind import svgfind_creativecommons_normalize_steps @@ -158,6 +162,7 @@ def all_sources() -> dict[str, DatakitSource]: ("nsf_awards", nsf_awards_normalize_steps, 0.17), ("numinamath-1.5", numinamath_v1_5_normalize_steps, 0.40), ("numinamath-tir", numinamath_tir_normalize_steps, 0.08), + ("openmathinstruct2", openmathinstruct2_normalize_steps, OPENMATHINSTRUCT2_ROUGH_TOKENS_B), ("superior-reasoning", superior_reasoning_normalize_steps, 7.08), ("svg", svgfind_creativecommons_normalize_steps, 8.95), ("swe-rebench-openhands", swe_rebench_openhands_normalize_steps, 2.47), diff --git a/tests/datakit/download/test_openmathinstruct2.py b/tests/datakit/download/test_openmathinstruct2.py new file mode 100644 index 0000000000..e281ff52ff --- /dev/null +++ b/tests/datakit/download/test_openmathinstruct2.py @@ -0,0 +1,123 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +import hashlib +from pathlib import Path + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest +from marin.datakit.download.openmathinstruct2 import ( + HF_DATASET_ID, + HF_REVISION, + download_openmathinstruct2_step, + row_to_doc, + transform, +) + + +def _valid_row(**overrides) -> dict: + row = { + "problem": "Solve for $x$: $x + 2 = 5$.", + "generated_solution": "Subtracting 2 from both sides gives $x = 3$.", + "expected_answer": "3", + "problem_source": "augmented_math", + } + row.update(overrides) + return row + + +def test_row_to_doc_renders_problem_solution_pair(): + expected_text = ( + "\n" + "Solve for $x$: $x + 2 = 5$.\n" + "\n\n" + "\n" + "Subtracting 2 from both sides gives $x = 3$.\n" + "" + ) + + [doc] = row_to_doc(_valid_row()) + + assert doc == { + "id": hashlib.sha256(expected_text.encode("utf-8")).hexdigest(), + "problem_hash": hashlib.sha256(b"Solve for $x$: $x + 2 = 5$.").hexdigest(), + "text": expected_text, + "source": HF_DATASET_ID, + "problem_source": "augmented_math", + "expected_answer": "3", + "synthetic": True, + "benchmark_adjacent": True, + "hf_revision": HF_REVISION, + "split": "train", + } + + +def test_problem_hash_is_stable_across_solution_variants(): + first = row_to_doc(_valid_row(generated_solution="Solution A."))[0] + second = row_to_doc(_valid_row(generated_solution="Solution B."))[0] + + assert first["problem_hash"] == second["problem_hash"] + assert first["id"] != second["id"] + + +@pytest.mark.parametrize( + "problem_source", + ["augmented_gsm8k", "augmented_math", "gsm8k", "math"], +) +def test_row_to_doc_accepts_expected_problem_sources(problem_source): + [doc] = row_to_doc(_valid_row(problem_source=problem_source)) + + assert doc["problem_source"] == problem_source + + +@pytest.mark.parametrize( + "overrides", + [ + {"problem": ""}, + {"problem": " "}, + {"problem": None}, + {"generated_solution": ""}, + {"generated_solution": " "}, + {"generated_solution": None}, + {"problem_source": ""}, + {"problem_source": None}, + {"problem_source": "other"}, + ], +) +def test_row_to_doc_drops_invalid_or_empty_rows(overrides): + assert row_to_doc(_valid_row(**overrides)) == [] + + +def test_row_to_doc_preserves_empty_expected_answer(): + [doc] = row_to_doc(_valid_row(expected_answer=None)) + + assert doc["expected_answer"] == "" + + +def test_download_step_uses_full_train_split(): + processed = download_openmathinstruct2_step() + [download] = processed.deps + + assert download.hash_attrs["hf_dataset_id"] == HF_DATASET_ID + assert download.hash_attrs["revision"] == HF_REVISION + assert download.hash_attrs["hf_urls_glob"] == ["data/train-*.parquet"] + assert processed.hash_attrs["split"] == "train" + + +def test_transform_reads_parquet_and_writes_valid_docs(tmp_path: Path): + raw_dir = tmp_path / "raw" / "data" + raw_dir.mkdir(parents=True) + table = pa.Table.from_pylist( + [ + _valid_row(), + _valid_row(problem_source="other"), + ] + ) + pq.write_table(table, raw_dir / "train-00000-of-00001.parquet") + + output_dir = tmp_path / "processed" + transform(str(tmp_path / "raw"), str(output_dir)) + + rows = [row for path in output_dir.rglob("*.parquet") for row in pq.read_table(path).to_pylist()] + assert rows == row_to_doc(_valid_row())