Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions experiments/midtraining_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from marin.datakit.download.huggingface import DownloadConfig, download_hf
from marin.datakit.download.openmathinstruct2 import openmathinstruct2_normalize_steps
from marin.execution import versioned
from marin.execution.types import ExecutorStep, this_output_path
from marin.processing.tokenize import lm_mixture_data_config
Expand Down Expand Up @@ -54,6 +55,13 @@
tokenizer=llama3_tokenizer,
)

openmathinstruct2_full = openmathinstruct2_normalize_steps()[-1].as_executor_step()
openmathinstruct2_full_tokenized = default_tokenize(
name="openmathinstruct2_full",
dataset=openmathinstruct2_full,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Tokenize only normalized main shards

When this normalized StepSpec is passed as the dataset, default_tokenize treats it as a directory and expand_tokenize_paths expands directories to recursive **/*.parquet globs; the normalize step writes both outputs/main and outputs/dups parquet shards by default. For any OpenMathInstruct-2 duplicate rows, this tokenization step will read the duplicate side-output too and put data that normalization intentionally removed back into the training cache; point the dataset at openmathinstruct2_full / "outputs/main/*.parquet" instead.

Useful? React with 👍 / 👎.

tokenizer=llama3_tokenizer,
)

# Define MegaMath dataset source
megamath_source = default_download(
name="raw/llm360/megamath",
Expand Down
132 changes: 132 additions & 0 deletions lib/marin/src/marin/datakit/download/openmathinstruct2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""nvidia/OpenMathInstruct-2 dataset download and transform.

OpenMathInstruct-2 is a synthetic math reasoning corpus derived from GSM8K,
MATH, and augmented variants. This transform materializes the full train split
as tagged transcript documents and preserves source metadata for downstream
contamination and mixture analysis.
"""

import hashlib

from fray import ResourceConfig
from zephyr import Dataset, ZephyrContext, counters, load_parquet

from marin.datakit.download.huggingface import download_hf_step
from marin.datakit.normalize import normalize_step
from marin.execution.step_spec import StepSpec

HF_DATASET_ID = "nvidia/OpenMathInstruct-2"
HF_REVISION = "469216e"
TRAIN_PARQUET_GLOB = "data/train-*.parquet"
OPENMATHINSTRUCT2_ROUGH_TOKENS_B = 4.0
EXPECTED_PROBLEM_SOURCES = frozenset({"math", "gsm8k", "augmented_math", "augmented_gsm8k"})
LONG_PROBLEM_CHARS = 1_376
LONG_SOLUTION_CHARS = 5_237


def _clean_text(row: dict, key: str) -> str | None:
value = row.get(key)
if not isinstance(value, str):
return None

text = value.strip()
if not text:
return None

return text


def _optional_text(row: dict, key: str) -> str:
value = row.get(key)
if not isinstance(value, str):
return ""

return value.strip()


def row_to_doc(row: dict) -> list[dict]:
problem = _clean_text(row, "problem")
if problem is None:
counters.increment("openmathinstruct2/dropped_empty_problem")
return []

solution = _clean_text(row, "generated_solution")
if solution is None:
counters.increment("openmathinstruct2/dropped_empty_solution")
return []

problem_source = _optional_text(row, "problem_source")
if problem_source not in EXPECTED_PROBLEM_SOURCES:
counters.increment("openmathinstruct2/dropped_unknown_problem_source")
return []

expected_answer = _optional_text(row, "expected_answer")
if not expected_answer:
counters.increment("openmathinstruct2/empty_expected_answer")

if len(problem) > LONG_PROBLEM_CHARS:
counters.increment("openmathinstruct2/long_problem")
if len(solution) > LONG_SOLUTION_CHARS:
counters.increment("openmathinstruct2/long_solution")

text = f"<user>\n{problem}\n</user>\n\n<assistant>\n{solution}\n</assistant>"

counters.increment("openmathinstruct2/kept")
counters.increment(f"openmathinstruct2/source/{problem_source}")
return [
{
"id": hashlib.sha256(text.encode("utf-8")).hexdigest(),
"problem_hash": hashlib.sha256(problem.encode("utf-8")).hexdigest(),
"text": text,
"source": HF_DATASET_ID,
"problem_source": problem_source,
"expected_answer": expected_answer,
"synthetic": True,
"benchmark_adjacent": True,
"hf_revision": HF_REVISION,
"split": "train",
}
]


def transform(input_path: str, output_path: str) -> None:
pipeline = (
Dataset.from_files(f"{input_path}/**/*.parquet")
.flat_map(load_parquet)
.flat_map(row_to_doc)
.write_parquet(f"{output_path}/data-{{shard:05d}}-of-{{total:05d}}.parquet", skip_existing=True)
)
ctx = ZephyrContext(name="openmathinstruct2-transform", resources=ResourceConfig(cpu=1, ram="8g"))
ctx.execute(pipeline)


def download_openmathinstruct2_step() -> StepSpec:
"""Download and transform the full OpenMathInstruct-2 train split."""
dl = download_hf_step(
"raw/openmathinstruct2",
hf_dataset_id=HF_DATASET_ID,
revision=HF_REVISION,
hf_urls_glob=[TRAIN_PARQUET_GLOB],
)

return StepSpec(
name="processed/openmathinstruct2",
deps=[dl],
fn=lambda output_path: transform(
input_path=dl.output_path,
output_path=output_path,
),
hash_attrs={"version": "v1", "split": "train"},
)


def openmathinstruct2_normalize_steps() -> tuple[StepSpec, ...]:
"""Return the full ``(download+transform, normalize)`` chain for OpenMathInstruct-2."""
processed = download_openmathinstruct2_step()
return (
processed,
normalize_step(name="normalized/openmathinstruct2", download=processed),
)
5 changes: 5 additions & 0 deletions lib/marin/src/marin/datakit/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
from marin.datakit.download.nsf_awards import nsf_awards_normalize_steps
from marin.datakit.download.numinamath_tir import numinamath_tir_normalize_steps
from marin.datakit.download.numinamath_v1_5 import numinamath_v1_5_normalize_steps
from marin.datakit.download.openmathinstruct2 import (
OPENMATHINSTRUCT2_ROUGH_TOKENS_B,
openmathinstruct2_normalize_steps,
)
from marin.datakit.download.starcoder2_extras import starcoder2_extras_normalize_steps
from marin.datakit.download.superior_reasoning import superior_reasoning_normalize_steps
from marin.datakit.download.svgfind import svgfind_creativecommons_normalize_steps
Expand Down Expand Up @@ -158,6 +162,7 @@ def all_sources() -> dict[str, DatakitSource]:
("nsf_awards", nsf_awards_normalize_steps, 0.17),
("numinamath-1.5", numinamath_v1_5_normalize_steps, 0.40),
("numinamath-tir", numinamath_tir_normalize_steps, 0.08),
("openmathinstruct2", openmathinstruct2_normalize_steps, OPENMATHINSTRUCT2_ROUGH_TOKENS_B),
("superior-reasoning", superior_reasoning_normalize_steps, 7.08),
("svg", svgfind_creativecommons_normalize_steps, 8.95),
("swe-rebench-openhands", swe_rebench_openhands_normalize_steps, 2.47),
Expand Down
123 changes: 123 additions & 0 deletions tests/datakit/download/test_openmathinstruct2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

import hashlib
from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from marin.datakit.download.openmathinstruct2 import (
HF_DATASET_ID,
HF_REVISION,
download_openmathinstruct2_step,
row_to_doc,
transform,
)


def _valid_row(**overrides) -> dict:
row = {
"problem": "Solve for $x$: $x + 2 = 5$.",
"generated_solution": "Subtracting 2 from both sides gives $x = 3$.",
"expected_answer": "3",
"problem_source": "augmented_math",
}
row.update(overrides)
return row


def test_row_to_doc_renders_problem_solution_pair():
expected_text = (
"<user>\n"
"Solve for $x$: $x + 2 = 5$.\n"
"</user>\n\n"
"<assistant>\n"
"Subtracting 2 from both sides gives $x = 3$.\n"
"</assistant>"
)

[doc] = row_to_doc(_valid_row())

assert doc == {
"id": hashlib.sha256(expected_text.encode("utf-8")).hexdigest(),
"problem_hash": hashlib.sha256(b"Solve for $x$: $x + 2 = 5$.").hexdigest(),
"text": expected_text,
"source": HF_DATASET_ID,
"problem_source": "augmented_math",
"expected_answer": "3",
"synthetic": True,
"benchmark_adjacent": True,
"hf_revision": HF_REVISION,
"split": "train",
}


def test_problem_hash_is_stable_across_solution_variants():
first = row_to_doc(_valid_row(generated_solution="Solution A."))[0]
second = row_to_doc(_valid_row(generated_solution="Solution B."))[0]

assert first["problem_hash"] == second["problem_hash"]
assert first["id"] != second["id"]


@pytest.mark.parametrize(
"problem_source",
["augmented_gsm8k", "augmented_math", "gsm8k", "math"],
)
def test_row_to_doc_accepts_expected_problem_sources(problem_source):
[doc] = row_to_doc(_valid_row(problem_source=problem_source))

assert doc["problem_source"] == problem_source


@pytest.mark.parametrize(
"overrides",
[
{"problem": ""},
{"problem": " "},
{"problem": None},
{"generated_solution": ""},
{"generated_solution": " "},
{"generated_solution": None},
{"problem_source": ""},
{"problem_source": None},
{"problem_source": "other"},
],
)
def test_row_to_doc_drops_invalid_or_empty_rows(overrides):
assert row_to_doc(_valid_row(**overrides)) == []


def test_row_to_doc_preserves_empty_expected_answer():
[doc] = row_to_doc(_valid_row(expected_answer=None))

assert doc["expected_answer"] == ""


def test_download_step_uses_full_train_split():
processed = download_openmathinstruct2_step()
[download] = processed.deps

assert download.hash_attrs["hf_dataset_id"] == HF_DATASET_ID
assert download.hash_attrs["revision"] == HF_REVISION
assert download.hash_attrs["hf_urls_glob"] == ["data/train-*.parquet"]
assert processed.hash_attrs["split"] == "train"


def test_transform_reads_parquet_and_writes_valid_docs(tmp_path: Path):
raw_dir = tmp_path / "raw" / "data"
raw_dir.mkdir(parents=True)
table = pa.Table.from_pylist(
[
_valid_row(),
_valid_row(problem_source="other"),
]
)
pq.write_table(table, raw_dir / "train-00000-of-00001.parquet")

output_dir = tmp_path / "processed"
transform(str(tmp_path / "raw"), str(output_dir))

rows = [row for path in output_dir.rglob("*.parquet") for row in pq.read_table(path).to_pylist()]
assert rows == row_to_doc(_valid_row())
Loading