From 48b3f0d7c994e28d66f2dbf2a16058bf567287a3 Mon Sep 17 00:00:00 2001
From: taivu1998 <46636857+taivu1998@users.noreply.github.com>
Date: Sun, 7 Jun 2026 06:52:31 -0700
Subject: [PATCH] [datakit] Add OpenMathInstruct-2 midtraining dataset
Register the full OpenMathInstruct-2 train split as a Datakit source and midtraining tokenization input. The transform preserves source metadata and renders problem-solution rows as tagged transcripts so the synthetic math corpus can be identified and weighted deliberately in training mixtures.
---
experiments/midtraining_datasets.py | 8 ++
.../datakit/download/openmathinstruct2.py | 132 ++++++++++++++++++
lib/marin/src/marin/datakit/sources.py | 5 +
.../download/test_openmathinstruct2.py | 123 ++++++++++++++++
4 files changed, 268 insertions(+)
create mode 100644 lib/marin/src/marin/datakit/download/openmathinstruct2.py
create mode 100644 tests/datakit/download/test_openmathinstruct2.py
diff --git a/experiments/midtraining_datasets.py b/experiments/midtraining_datasets.py
index 5f01705b8c..b50ea5957a 100644
--- a/experiments/midtraining_datasets.py
+++ b/experiments/midtraining_datasets.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
from marin.datakit.download.huggingface import DownloadConfig, download_hf
+from marin.datakit.download.openmathinstruct2 import openmathinstruct2_normalize_steps
from marin.execution import versioned
from marin.execution.types import ExecutorStep, this_output_path
from marin.processing.tokenize import lm_mixture_data_config
@@ -54,6 +55,13 @@
tokenizer=llama3_tokenizer,
)
+openmathinstruct2_full = openmathinstruct2_normalize_steps()[-1].as_executor_step()  # [-1]: normalize step
+openmathinstruct2_full_tokenized = default_tokenize(
+ name="openmathinstruct2_full",
+ dataset=openmathinstruct2_full,
+ tokenizer=llama3_tokenizer,
+)
+
# Define MegaMath dataset source
megamath_source = default_download(
name="raw/llm360/megamath",
diff --git a/lib/marin/src/marin/datakit/download/openmathinstruct2.py b/lib/marin/src/marin/datakit/download/openmathinstruct2.py
new file mode 100644
index 0000000000..cab21728da
--- /dev/null
+++ b/lib/marin/src/marin/datakit/download/openmathinstruct2.py
@@ -0,0 +1,132 @@
+# Copyright The Marin Authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""nvidia/OpenMathInstruct-2 dataset download and transform.
+
+OpenMathInstruct-2 is a synthetic math reasoning corpus derived from GSM8K,
+MATH, and augmented variants. This transform materializes the full train split
+as tagged transcript documents and preserves source metadata for downstream
+contamination and mixture analysis.
+"""
+
+import hashlib
+
+from fray import ResourceConfig
+from zephyr import Dataset, ZephyrContext, counters, load_parquet
+
+from marin.datakit.download.huggingface import download_hf_step
+from marin.datakit.normalize import normalize_step
+from marin.execution.step_spec import StepSpec
+
+HF_DATASET_ID = "nvidia/OpenMathInstruct-2"
+HF_REVISION = "469216e"
+TRAIN_PARQUET_GLOB = "data/train-*.parquet"
+OPENMATHINSTRUCT2_ROUGH_TOKENS_B = 4.0
+EXPECTED_PROBLEM_SOURCES = frozenset({"math", "gsm8k", "augmented_math", "augmented_gsm8k"})
+LONG_PROBLEM_CHARS = 1_376
+LONG_SOLUTION_CHARS = 5_237
+
+
+def _clean_text(row: dict, key: str) -> str | None:  # Stripped row[key]; None if missing, non-str, or blank.
+ value = row.get(key)
+ if not isinstance(value, str):  # absent key, None, or any non-string value
+ return None
+
+ text = value.strip()
+ if not text:  # empty or whitespace-only after stripping
+ return None
+
+ return text
+
+
+def _optional_text(row: dict, key: str) -> str:  # Stripped row[key]; "" (never None) if missing or non-str.
+ value = row.get(key)
+ if not isinstance(value, str):  # absent key, None, or any non-string value
+ return ""
+
+ return value.strip()
+
+
+def row_to_doc(row: dict) -> list[dict]:  # Returns [doc] for a kept row, [] for a dropped one.
+ problem = _clean_text(row, "problem")
+ if problem is None:  # missing/blank problem: count and drop
+ counters.increment("openmathinstruct2/dropped_empty_problem")
+ return []
+
+ solution = _clean_text(row, "generated_solution")
+ if solution is None:  # missing/blank solution: count and drop
+ counters.increment("openmathinstruct2/dropped_empty_solution")
+ return []
+
+ problem_source = _optional_text(row, "problem_source")
+ if problem_source not in EXPECTED_PROBLEM_SOURCES:  # also catches "" from a missing source
+ counters.increment("openmathinstruct2/dropped_unknown_problem_source")
+ return []
+
+ expected_answer = _optional_text(row, "expected_answer")
+ if not expected_answer:  # counted only; the row is kept with expected_answer == ""
+ counters.increment("openmathinstruct2/empty_expected_answer")
+
+ if len(problem) > LONG_PROBLEM_CHARS:  # length counters are diagnostic; long rows are kept
+ counters.increment("openmathinstruct2/long_problem")
+ if len(solution) > LONG_SOLUTION_CHARS:
+ counters.increment("openmathinstruct2/long_solution")
+
+ text = f"\n{problem}\n\n\n\n{solution}\n"  # NOTE(review): commit message says "tagged"; no tags here -- confirm
+
+ counters.increment("openmathinstruct2/kept")
+ counters.increment(f"openmathinstruct2/source/{problem_source}")  # per-source kept count
+ return [
+ {
+ "id": hashlib.sha256(text.encode("utf-8")).hexdigest(),  # hash of the rendered text
+ "problem_hash": hashlib.sha256(problem.encode("utf-8")).hexdigest(),  # problem only, not solution
+ "text": text,
+ "source": HF_DATASET_ID,
+ "problem_source": problem_source,
+ "expected_answer": expected_answer,
+ "synthetic": True,
+ "benchmark_adjacent": True,
+ "hf_revision": HF_REVISION,
+ "split": "train",
+ }
+ ]
+
+
+def transform(input_path: str, output_path: str) -> None:  # Convert raw parquet rows into document parquet shards.
+ pipeline = (  # built here, run by ctx.execute(pipeline) below
+ Dataset.from_files(f"{input_path}/**/*.parquet")  # parquet files at any depth under input_path
+ .flat_map(load_parquet)  # file -> rows
+ .flat_map(row_to_doc)  # row -> [doc] or []
+ .write_parquet(f"{output_path}/data-{{shard:05d}}-of-{{total:05d}}.parquet", skip_existing=True)
+ )
+ ctx = ZephyrContext(name="openmathinstruct2-transform", resources=ResourceConfig(cpu=1, ram="8g"))
+ ctx.execute(pipeline)
+
+
+def download_openmathinstruct2_step() -> StepSpec:
+ """Build the transform step for the full OpenMathInstruct-2 train split (depends on the HF download)."""
+ dl = download_hf_step(
+ "raw/openmathinstruct2",
+ hf_dataset_id=HF_DATASET_ID,
+ revision=HF_REVISION,  # pinned revision
+ hf_urls_glob=[TRAIN_PARQUET_GLOB],  # only files matching data/train-*.parquet
+ )
+
+ return StepSpec(
+ name="processed/openmathinstruct2",
+ deps=[dl],
+ fn=lambda output_path: transform(  # closes over dl to read the download step's output
+ input_path=dl.output_path,
+ output_path=output_path,
+ ),
+ hash_attrs={"version": "v1", "split": "train"},  # NOTE(review): presumably bump "version" when output changes
+ )
+
+
+def openmathinstruct2_normalize_steps() -> tuple[StepSpec, ...]:
+ """Return the full ``(download+transform, normalize)`` chain for OpenMathInstruct-2."""
+ processed = download_openmathinstruct2_step()
+ return (  # normalize step is last in the tuple
+ processed,
+ normalize_step(name="normalized/openmathinstruct2", download=processed),  # fed the transformed step
+ )
diff --git a/lib/marin/src/marin/datakit/sources.py b/lib/marin/src/marin/datakit/sources.py
index 629f3871f9..454fc60f2f 100644
--- a/lib/marin/src/marin/datakit/sources.py
+++ b/lib/marin/src/marin/datakit/sources.py
@@ -38,6 +38,10 @@
from marin.datakit.download.nsf_awards import nsf_awards_normalize_steps
from marin.datakit.download.numinamath_tir import numinamath_tir_normalize_steps
from marin.datakit.download.numinamath_v1_5 import numinamath_v1_5_normalize_steps
+from marin.datakit.download.openmathinstruct2 import (
+ OPENMATHINSTRUCT2_ROUGH_TOKENS_B,
+ openmathinstruct2_normalize_steps,
+)
from marin.datakit.download.starcoder2_extras import starcoder2_extras_normalize_steps
from marin.datakit.download.superior_reasoning import superior_reasoning_normalize_steps
from marin.datakit.download.svgfind import svgfind_creativecommons_normalize_steps
@@ -158,6 +162,7 @@ def all_sources() -> dict[str, DatakitSource]:
("nsf_awards", nsf_awards_normalize_steps, 0.17),
("numinamath-1.5", numinamath_v1_5_normalize_steps, 0.40),
("numinamath-tir", numinamath_tir_normalize_steps, 0.08),
+ ("openmathinstruct2", openmathinstruct2_normalize_steps, OPENMATHINSTRUCT2_ROUGH_TOKENS_B),
("superior-reasoning", superior_reasoning_normalize_steps, 7.08),
("svg", svgfind_creativecommons_normalize_steps, 8.95),
("swe-rebench-openhands", swe_rebench_openhands_normalize_steps, 2.47),
diff --git a/tests/datakit/download/test_openmathinstruct2.py b/tests/datakit/download/test_openmathinstruct2.py
new file mode 100644
index 0000000000..e281ff52ff
--- /dev/null
+++ b/tests/datakit/download/test_openmathinstruct2.py
@@ -0,0 +1,123 @@
+# Copyright The Marin Authors
+# SPDX-License-Identifier: Apache-2.0
+
+import hashlib
+from pathlib import Path
+
+import pyarrow as pa
+import pyarrow.parquet as pq
+import pytest
+from marin.datakit.download.openmathinstruct2 import (
+ HF_DATASET_ID,
+ HF_REVISION,
+ download_openmathinstruct2_step,
+ row_to_doc,
+ transform,
+)
+
+
+def _valid_row(**overrides) -> dict:  # Baseline row that row_to_doc keeps; overrides replace individual fields.
+ row = {
+ "problem": "Solve for $x$: $x + 2 = 5$.",
+ "generated_solution": "Subtracting 2 from both sides gives $x = 3$.",
+ "expected_answer": "3",
+ "problem_source": "augmented_math",
+ }
+ row.update(overrides)
+ return row
+
+
+def test_row_to_doc_renders_problem_solution_pair():
+ expected_text = (  # adjacent literals concatenate; must equal row_to_doc's rendered text
+ "\n"
+ "Solve for $x$: $x + 2 = 5$.\n"
+ "\n\n"
+ "\n"
+ "Subtracting 2 from both sides gives $x = 3$.\n"
+ ""  # NOTE(review): empty literal adds nothing -- confirm it is intentional
+ )
+
+ [doc] = row_to_doc(_valid_row())  # unpacking fails unless exactly one doc is returned
+
+ assert doc == {  # exact match: no extra or missing keys
+ "id": hashlib.sha256(expected_text.encode("utf-8")).hexdigest(),
+ "problem_hash": hashlib.sha256(b"Solve for $x$: $x + 2 = 5$.").hexdigest(),
+ "text": expected_text,
+ "source": HF_DATASET_ID,
+ "problem_source": "augmented_math",
+ "expected_answer": "3",
+ "synthetic": True,
+ "benchmark_adjacent": True,
+ "hf_revision": HF_REVISION,
+ "split": "train",
+ }
+
+
+def test_problem_hash_is_stable_across_solution_variants():  # Same problem, different solutions.
+ first = row_to_doc(_valid_row(generated_solution="Solution A."))[0]
+ second = row_to_doc(_valid_row(generated_solution="Solution B."))[0]
+
+ assert first["problem_hash"] == second["problem_hash"]  # hash covers the problem only
+ assert first["id"] != second["id"]  # id covers the rendered text, which includes the solution
+
+
+@pytest.mark.parametrize(
+ "problem_source",
+ ["augmented_gsm8k", "augmented_math", "gsm8k", "math"],  # the members of EXPECTED_PROBLEM_SOURCES
+)
+def test_row_to_doc_accepts_expected_problem_sources(problem_source):
+ [doc] = row_to_doc(_valid_row(problem_source=problem_source))
+
+ assert doc["problem_source"] == problem_source  # value passes through unchanged
+
+
+@pytest.mark.parametrize(
+ "overrides",
+ [
+ {"problem": ""},
+ {"problem": " "},  # whitespace-only
+ {"problem": None},  # non-string
+ {"generated_solution": ""},
+ {"generated_solution": " "},  # whitespace-only
+ {"generated_solution": None},  # non-string
+ {"problem_source": ""},
+ {"problem_source": None},  # becomes "" and fails the membership check
+ {"problem_source": "other"},  # not in EXPECTED_PROBLEM_SOURCES
+ ],
+)
+def test_row_to_doc_drops_invalid_or_empty_rows(overrides):
+ assert row_to_doc(_valid_row(**overrides)) == []  # dropped rows yield an empty list
+
+
+def test_row_to_doc_preserves_empty_expected_answer():  # A missing answer does not drop the row.
+ [doc] = row_to_doc(_valid_row(expected_answer=None))
+
+ assert doc["expected_answer"] == ""  # None is stored as the empty string
+
+
+def test_download_step_uses_full_train_split():
+ processed = download_openmathinstruct2_step()
+ [download] = processed.deps  # unpacking fails unless there is exactly one dependency
+
+ assert download.hash_attrs["hf_dataset_id"] == HF_DATASET_ID
+ assert download.hash_attrs["revision"] == HF_REVISION
+ assert download.hash_attrs["hf_urls_glob"] == ["data/train-*.parquet"]  # literal pins TRAIN_PARQUET_GLOB
+ assert processed.hash_attrs["split"] == "train"
+
+
+def test_transform_reads_parquet_and_writes_valid_docs(tmp_path: Path):
+ raw_dir = tmp_path / "raw" / "data"  # nested below the input path passed to transform
+ raw_dir.mkdir(parents=True)
+ table = pa.Table.from_pylist(
+ [
+ _valid_row(),
+ _valid_row(problem_source="other"),  # row_to_doc drops this one
+ ]
+ )
+ pq.write_table(table, raw_dir / "train-00000-of-00001.parquet")
+
+ output_dir = tmp_path / "processed"
+ transform(str(tmp_path / "raw"), str(output_dir))
+
+ rows = [row for path in output_dir.rglob("*.parquet") for row in pq.read_table(path).to_pylist()]
+ assert rows == row_to_doc(_valid_row())  # only the valid row survives