
Commit 2cb7f0c

ravwojdyla and claude committed

Add datakit testbed ferry + training harness

Builds the per-source ferry DAG and training wrapper on top of the merged
marin.datakit.sources registry (#5105):

* experiments/datakit_testbed/settings.py — testbed-wide constants
  (TESTBED_TOKENIZER, TESTBED_SEQ_LEN, TESTBED_STAGING_REGION,
  RAW_TARGET_TOTAL_TOKENS_B)
* experiments/datakit_testbed/noop_dedup.py — metadata-only stand-in for
  fuzzy-dup marking; emits empty attr parquets so consolidate's 1:1
  attr-file invariant holds without reading data
* experiments/datakit_testbed/sampler.py — post-normalize by-provenance
  sampler (first K shards by filename; normalize's uniform partitioning
  makes this byte-fair and content-fair)
* experiments/datakit_testbed/dag.py — wires download -> normalize ->
  sample -> noop_dedup -> consolidate, with downloads grouped by
  (hf_id, revision, staged_path, urls_glob)
* experiments/datakit_testbed/mixture.py — proportional mixture builder
  over tokenized caches, weighting by rough_token_count_b
* experiments/datakit_testbed/train.py — Grug-MoE harness with simulated
  epoching (target_budget / experiment_budget on LmDataConfig)
* experiments/ferries/datakit_testbed_ferry.py — entry point with
  us-central1 region guard

42 offline tests cover DAG shape, sampler behavior, noop dedup end-to-end
with consolidate, mixture arithmetic, and simulated-epoching budget math.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

1 parent fdc9349 · commit 2cb7f0c
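
For orientation, a minimal sketch of how the pieces compose end to end. The
StepRunner import path and the run_id value are illustrative assumptions, not
confirmed by this diff:

# Hypothetical wiring; StepRunner's import path is an assumption.
from marin.execution.step_runner import StepRunner

from experiments.datakit_testbed.dag import build_testbed_steps

dag = build_testbed_steps(run_id="testbed-smoke-001")
StepRunner().run(dag.all_steps)  # download -> normalize -> sample -> noop_dedup -> consolidate

# The training harness picks up each consolidate step from here and builds
# tokenize ExecutorSteps on top (see experiments/datakit_testbed/train.py).
consolidated = dag.consolidated_by_source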

15 files changed: 1651 additions & 0 deletions

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

experiments/datakit_testbed/dag.py

Lines changed: 270 additions & 0 deletions
@@ -0,0 +1,270 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0


"""Datakit Testbed ferry DAG builder.

Composes the canonical Datakit stages into one multi-source pipeline:

    download[hf] ─┐
                  ├─► normalize[source] ─► sample[source] ─┬─► noop_dedup (all sampled)
    download[hf] ─┘                                        │              │
                                                           ▼              ▼
                                                          consolidate[source]

Key shape choices:

* Downloads are grouped by ``(hf_dataset_id, revision)`` so Nemotron v2.1
  subsets that share a family dump do not re-download.
* **Sampling happens post-normalize.** Normalize produces uniform-size shards
  (``target_partition_bytes``), making "first K by filename" both byte-fair
  and content-fair. Downstream stages (dedup, consolidate, tokenize) pay
  O(sampled) per experiment; normalize is a one-time cost cached by
  ``override_output_path``.
* Dedup is a single step with one dep per sampled source, emitting
  per-source attr directories.
* Consolidate fans back out to one step per ``DatakitSource`` so each
  mixture component has its own filtered parquet output.

The ferry stops at ``consolidate`` on purpose. Tokenize runs in the training
executor graph (see ``experiments/datakit_testbed/train.py``), not the ferry,
because ``lm_mixture_data_config`` needs real ``ExecutorStep[TokenizeConfig]``
components — not the ``StepSpec`` instances this ferry produces. The training
harness converts each consolidate ``StepSpec`` to an ``ExecutorStep`` via
``.as_executor_step()`` and builds a proper ``TokenizerStep`` on top, which
preserves cross-layer dep tracking for the executor.
"""

from __future__ import annotations

import logging
from collections.abc import Sequence
from dataclasses import dataclass

from fray import ResourceConfig
from marin.datakit.download.huggingface import download_hf_step
from marin.datakit.normalize import NormalizedData, normalize_step
from marin.datakit.sources import DatakitSource, pinned_sources
from marin.execution.artifact import Artifact
from marin.execution.step_spec import StepSpec
from marin.processing.classification.consolidate import (
    FilterConfig,
    FilterType,
    consolidate,
)
from marin.processing.classification.deduplication.fuzzy_dups import FuzzyDupsAttrData

from experiments.datakit_testbed.noop_dedup import compute_noop_dedup_attrs_step
from experiments.datakit_testbed.sampler import (
    proportional_sample_fractions,
    sample_normalized_shards_step,
)
from experiments.datakit_testbed.settings import RAW_TARGET_TOTAL_TOKENS_B

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class TestbedDAG:
    """Handle to the built ferry DAG.

    All steps must be passed to ``StepRunner().run(...)`` for execution.
    ``consolidated_by_source`` is exposed separately so the training harness
    can reference each source's consolidate step when building the tokenize
    ExecutorSteps that feed the mixture.
    """

    all_steps: list[StepSpec]
    consolidated_by_source: dict[str, StepSpec]


DownloadKey = tuple[str, str, str | None, tuple[str, ...] | None]


def _download_key(src: DatakitSource) -> DownloadKey:
    """Key that uniquely identifies a download step.

    Includes ``staged_path`` and ``hf_urls_glob`` because some HF repos (e.g.
    ``bigcode/StarCoder2-Extras``) are downloaded per-subset with distinct
    ``override_output_path`` values — so the same ``(hf_repo, revision)`` can
    correspond to multiple physical download steps.
    """
    return (src.hf_dataset_id, src.revision or "", src.staged_path, src.hf_urls_glob)


def _download_step_name(src: DatakitSource) -> str:
    """Stable download step name. Includes the last staged_path segment when
    multiple sources share a repo but stage separately (StarCoder2-Extras)."""
    base = src.hf_dataset_id.replace("/", "__")
    if src.staged_path:
        tail = src.staged_path.rstrip("/").rsplit("/", 1)[-1]
        if tail and tail != base:
            return f"datakit-testbed/download/{base}__{tail}"
    return f"datakit-testbed/download/{base}"


def _build_downloads(sources: Sequence[DatakitSource]) -> dict[DownloadKey, StepSpec]:
    """One download step per unique ``(hf_repo, revision, staged_path, urls_glob)``."""
    by_key: dict[DownloadKey, DatakitSource] = {}
    for src in sources:
        key = _download_key(src)
        if key in by_key:
            continue
        by_key[key] = src

    downloads: dict[DownloadKey, StepSpec] = {}
    seen_names: set[str] = set()
    for key, src in by_key.items():
        step_name = _download_step_name(src)
        if step_name in seen_names:
            raise ValueError(
                f"Duplicate download step name {step_name!r} — extend _download_step_name to disambiguate"
            )
        seen_names.add(step_name)
        assert src.revision is not None, f"{src.name}: cannot build download for unpinned revision"
        downloads[key] = download_hf_step(
            step_name,
            hf_dataset_id=src.hf_dataset_id,
            revision=src.revision,
            hf_urls_glob=list(src.hf_urls_glob) if src.hf_urls_glob else None,
            override_output_path=src.staged_path,
        )
    return downloads


def _normalize_step_for(
    src: DatakitSource,
    download: StepSpec,
) -> StepSpec:
    """Per-source normalize. ``input_path`` points at ``data_subdir`` inside the download.

    Output lands at ``$MARIN_PREFIX/normalized/<src.name>-<hash>/`` — a
    canonical, run-independent artifact that any downstream consumer (testbed
    or otherwise) can point at. Matches the convention used by
    ``marin.datakit.download.nemotron_v2.normalize_nemotron_v2_step``.
    """
    input_path = f"{download.output_path}/{src.data_subdir}" if src.data_subdir else download.output_path
    return normalize_step(
        name=f"normalized/{src.name}",
        download=download,
        text_field=src.text_field,
        id_field=src.id_field,
        input_path=input_path,
        file_extensions=src.file_extensions,
        worker_resources=ResourceConfig(cpu=2, ram="16g", disk="20g"),
    )


def _sample_step_for(
    src: DatakitSource,
    normalized: StepSpec,
    sample_fraction: float,
    base: str,
) -> StepSpec:
    """Per-source post-normalize sampler. Copies the first ceil(N * fraction) shards."""
    return sample_normalized_shards_step(
        name=f"datakit-testbed/sample/{src.name}",
        normalized=normalized,
        sample_fraction=sample_fraction,
        override_output_path=f"{base}/sample/{src.name}",
    )


def _consolidate_step_for(
    src: DatakitSource,
    normalized: StepSpec,
    deduped: StepSpec,
    base: str,
) -> StepSpec:
    """Per-source consolidate. Resolves attr_dir at runtime via Artifact.load."""
    return StepSpec(
        name=f"datakit-testbed/consolidate/{src.name}",
        deps=[normalized, deduped],
        fn=lambda output_path: consolidate(
            input_path=Artifact.load(normalized, NormalizedData).main_output_dir,
            output_path=output_path,
            filetype="parquet",
            filters=[
                FilterConfig(
                    type=FilterType.KEEP_DOC,
                    attribute_path=Artifact.load(deduped, FuzzyDupsAttrData)
                    .sources[Artifact.load(normalized, NormalizedData).main_output_dir]
                    .attr_dir,
                    name="is_cluster_canonical",
                    attribute_filetype="parquet",
                    keep_if_missing=True,
                ),
            ],
            worker_resources=ResourceConfig(cpu=1, ram="8g"),
        ),
        override_output_path=f"{base}/consolidate/{src.name}",
    )


def build_testbed_steps(
    run_id: str,
    sources: Sequence[DatakitSource] | None = None,
    target_total_tokens_b: float = RAW_TARGET_TOTAL_TOKENS_B,
) -> TestbedDAG:
    """Build the full Datakit Testbed ferry DAG.

    Args:
        run_id: Per-run identifier; output paths are ``datakit-testbed/{run_id}/...``
            under ``MARIN_PREFIX`` so reruns are isolated.
        sources: DatakitSource list to ferry. ``None`` means the default set
            produced by :func:`pinned_sources` — every registry entry that
            has a pinned HF revision and a non-empty repo. Pass explicitly if
            you need to include unpinned or API-sourced entries with custom
            download wiring.
        target_total_tokens_b: Target total token count (in billions) across
            the sampled set. Drives per-source sample fractions via
            :func:`proportional_sample_fractions`. Default is
            :data:`RAW_TARGET_TOTAL_TOKENS_B` (1000B = 1T per RFC).

    Returns:
        A ``TestbedDAG`` whose ``all_steps`` list is safe to pass directly to
        ``StepRunner().run(...)``.
    """
    if sources is None:
        sources = tuple(pinned_sources().values())
    if not sources:
        raise ValueError("build_testbed_steps requires at least one source")

    base = f"datakit-testbed/{run_id}"

    fractions = proportional_sample_fractions(sources, target_total_tokens_b=target_total_tokens_b)

    downloads = _build_downloads(sources)
    normalized: dict[str, StepSpec] = {}
    sampled: dict[str, StepSpec] = {}
    for src in sources:
        normalized[src.name] = _normalize_step_for(src, downloads[_download_key(src)])
        sampled[src.name] = _sample_step_for(src, normalized[src.name], fractions[src.name], base)

    deduped = compute_noop_dedup_attrs_step(
        name="datakit-testbed/noop_dedup",
        normalized_steps=list(sampled.values()),
        override_output_path=f"{base}/noop_dedup",
    )

    consolidated: dict[str, StepSpec] = {
        src.name: _consolidate_step_for(src, sampled[src.name], deduped, base) for src in sources
    }

    all_steps: list[StepSpec] = []
    all_steps.extend(downloads.values())
    all_steps.extend(normalized.values())
    all_steps.extend(sampled.values())
    all_steps.append(deduped)
    all_steps.extend(consolidated.values())

    logger.info(
        "Built testbed DAG: %d sources, %d downloads, %d samplers (target %.0fB tokens), "
        "1 noop_dedup, %d consolidated outputs",
        len(sources),
        len(downloads),
        len(sampled),
        target_total_tokens_b,
        len(consolidated),
    )

    return TestbedDAG(all_steps=all_steps, consolidated_by_source=consolidated)
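
proportional_sample_fractions itself is defined in sampler.py, which is not
part of this excerpt, so the following is a sketch of the arithmetic the DAG
docstring implies, under the assumption that every source is sampled at one
uniform fraction so natural proportions are preserved:

import math

def proportional_sample_fractions_sketch(
    rough_counts_b: dict[str, float], target_total_tokens_b: float
) -> dict[str, float]:
    """Assumed behavior: a single uniform fraction, capped at 1.0."""
    total_b = sum(rough_counts_b.values())
    fraction = min(1.0, target_total_tokens_b / total_b)
    return {name: fraction for name in rough_counts_b}

# 2000B of raw tokens against a 1000B target halves every source; the sampler
# then copies ceil(N * 0.5) of each source's uniform-size normalized shards,
# which is what makes "first K by filename" byte-fair.
fractions = proportional_sample_fractions_sketch(
    {"nemotron-cc": 1500.0, "starcoder2": 500.0}, target_total_tokens_b=1000.0
)
assert fractions == {"nemotron-cc": 0.5, "starcoder2": 0.5}
assert math.ceil(10 * fractions["starcoder2"]) == 5  # 10 shards -> copy 5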
experiments/datakit_testbed/mixture.py

Lines changed: 118 additions & 0 deletions

@@ -0,0 +1,118 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0


"""Proportional-mixing builder for Datakit Testbed training runs.

Maps each testbed ``DatakitSource`` to a mixture weight proportional to its
``rough_token_count_b`` and wraps the resulting
``dict[source_name -> TokenizerStep]`` in an ``LmDataConfig`` via
``lm_mixture_data_config``.

Weights are raw; ``MixtureDataset`` normalizes them at sampling time (see
``levanter/data/text/datasets.py:701``), so there's no need to pre-normalize.

For simulated epoching, training callers set ``target_budget`` and
``experiment_budget`` on the returned config via ``dataclasses.replace``
(see ``experiments/defaults.py:321``). That slicing preserves the per-source
proportions over the shortened horizon.
"""

from __future__ import annotations

import logging

from levanter.data.text import LMMixtureDatasetConfig
from marin.datakit.sources import DatakitSource, all_sources
from marin.processing.tokenize import lm_mixture_data_config
from marin.processing.tokenize.data_configs import TokenizerStep

logger = logging.getLogger(__name__)

# Applied to any source whose ``rough_token_count_b`` is ``None``. Picking 1.0
# — roughly a billion tokens — avoids biasing the mixture toward tiny sources
# when the known-count sources are in the hundreds or thousands of billions.
_UNKNOWN_WEIGHT_FALLBACK_B: float = 1.0


def weights_from_rough_counts(sources: list[DatakitSource]) -> dict[str, float]:
    """Use each source's ``rough_token_count_b`` as its mixture weight.

    Sources with ``rough_token_count_b is None`` get ``_UNKNOWN_WEIGHT_FALLBACK_B``
    and are logged. Measured counts from the tokenize step's ``stats.json``
    can replace this later.
    """
    weights: dict[str, float] = {}
    unknown: list[str] = []
    for src in sources:
        if src.rough_token_count_b is None:
            weights[src.name] = _UNKNOWN_WEIGHT_FALLBACK_B
            unknown.append(src.name)
        else:
            weights[src.name] = src.rough_token_count_b
    if unknown:
        logger.warning(
            "testbed mixture: %d source(s) have rough_token_count_b=None; using fallback %s: %s",
            len(unknown),
            _UNKNOWN_WEIGHT_FALLBACK_B,
            sorted(unknown),
        )
    return weights


def build_testbed_mixture(
    tokenized_by_source: dict[str, TokenizerStep],
    *,
    weights: dict[str, float] | None = None,
    sources: list[DatakitSource] | None = None,
) -> LMMixtureDatasetConfig:
    """Build the proportional mixture over a set of tokenized caches.

    Args:
        tokenized_by_source: Mapping from ``DatakitSource.name`` to its
            ``tokenize`` step. Typically ``TestbedDAG.tokenized_by_source``.
        weights: Optional explicit mixture weights. Keys must match
            ``tokenized_by_source``. Raw values — not pre-normalized.
        sources: Optional source list to pull ``rough_token_count_b`` from
            when ``weights`` is not provided. Defaults to the full 102-entry
            set from :func:`all_sources`.

    Returns:
        An ``LMMixtureDatasetConfig`` ready to hand to ``default_train`` or
        ``simulated_epoching_train`` (after setting ``target_budget`` and
        ``experiment_budget``).

    Raises:
        ValueError: If ``tokenized_by_source`` is empty, or ``weights`` keys
            don't match ``tokenized_by_source`` keys exactly.
    """
    if not tokenized_by_source:
        raise ValueError("tokenized_by_source must be non-empty")

    if weights is None:
        resolved_sources = sources if sources is not None else list(all_sources().values())
        known = {s.name: s for s in resolved_sources}
        missing = set(tokenized_by_source) - set(known)
        if missing:
            raise ValueError(
                f"No DatakitSource metadata for tokenized components: {sorted(missing)}. "
                "Pass weights=... explicitly or extend sources=..."
            )
        selected = [known[name] for name in tokenized_by_source]
        weights = weights_from_rough_counts(selected)
    elif set(weights) != set(tokenized_by_source):
        raise ValueError(
            f"weights keys {sorted(weights)} must match tokenized_by_source keys {sorted(tokenized_by_source)}"
        )

    logger.info(
        "testbed mixture: %d components, total raw weight %.1fB",
        len(tokenized_by_source),
        sum(weights.values()),
    )

    return lm_mixture_data_config(
        components=tokenized_by_source,
        weights=weights,
    )
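
To make the weighting concrete, a small worked example of the raw-weight
arithmetic (source names and counts are hypothetical):

# Raw weights are just rough_token_count_b values, with 1.0 substituted for
# unknown counts; MixtureDataset normalizes them at sampling time.
raw = {"web": 1500.0, "code": 500.0, "tiny-unknown": 1.0}  # 1.0 = fallback
total = sum(raw.values())  # 2001.0
shares = {name: w / total for name, w in raw.items()}
print(shares)  # {'web': ~0.7496, 'code': ~0.2499, 'tiny-unknown': ~0.0005}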
