# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Datakit Testbed ferry DAG builder.

Composes the canonical Datakit stages into one multi-source pipeline:

    download[hf] ─┐
                  ├─► normalize[source] ─► sample[source] ─┐
    download[hf] ─┘                                        ├─► noop_dedup (all sampled)
                                                           │
                    sample[source] ────────────────────────┴─► consolidate[source]

Key shape choices:
* Downloads are grouped by ``(hf_dataset_id, revision)`` (extended with
  ``staged_path`` and ``hf_urls_glob``; see ``_download_key``) so Nemotron
  v2.1 subsets that share a family dump do not re-download.
* **Sampling happens post-normalize.** Normalize produces uniform-size shards
  (``target_partition_bytes``), making "first K by filename" both byte-fair
  and content-fair. Downstream stages (dedup, consolidate, tokenize) pay
  O(sampled) per experiment; normalize is a one-time cost cached by
  ``override_output_path``.
* Dedup is a single step with one dep per sampled source, emitting
  per-source attr directories.
* Consolidate fans back out to one step per ``DatakitSource`` so each
  mixture component has its own filtered parquet output.

The ferry stops at ``consolidate`` on purpose. Tokenize runs in the training
executor graph (see ``experiments/datakit_testbed/train.py``), not the ferry,
because ``lm_mixture_data_config`` needs real ``ExecutorStep[TokenizeConfig]``
components — not the ``StepSpec`` instances this ferry produces. The training
harness converts each consolidate ``StepSpec`` to an ``ExecutorStep`` via
``.as_executor_step()`` and builds a proper ``TokenizerStep`` on top, which
preserves cross-layer dep tracking for the executor.
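
Typical wiring (a sketch; the ``StepRunner`` import path shown is an
assumption, not something this module pins down)::

    from marin.execution.step_runner import StepRunner  # hypothetical path

    dag = build_testbed_steps(run_id="2025-01-rfc")
    StepRunner().run(dag.all_steps)
    # The training harness then takes dag.consolidated_by_source and calls
    # .as_executor_step() per source to build the tokenize ExecutorSteps.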
| 34 | +""" |
| 35 | + |
from __future__ import annotations

import logging
from collections.abc import Sequence
from dataclasses import dataclass

from fray import ResourceConfig
from marin.datakit.download.huggingface import download_hf_step
from marin.datakit.normalize import NormalizedData, normalize_step
from marin.datakit.sources import DatakitSource, pinned_sources
from marin.execution.artifact import Artifact
from marin.execution.step_spec import StepSpec
from marin.processing.classification.consolidate import (
    FilterConfig,
    FilterType,
    consolidate,
)
from marin.processing.classification.deduplication.fuzzy_dups import FuzzyDupsAttrData

from experiments.datakit_testbed.noop_dedup import compute_noop_dedup_attrs_step
from experiments.datakit_testbed.sampler import (
    proportional_sample_fractions,
    sample_normalized_shards_step,
)
from experiments.datakit_testbed.settings import RAW_TARGET_TOTAL_TOKENS_B

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class TestbedDAG:
    """Handle to the built ferry DAG.

    All steps must be passed to ``StepRunner().run(...)`` for execution.
    ``consolidated_by_source`` is exposed separately so the training harness
    can reference each source's consolidate step when building the tokenize
    ExecutorSteps that feed the mixture.
    """

    all_steps: list[StepSpec]
    consolidated_by_source: dict[str, StepSpec]


DownloadKey = tuple[str, str, str | None, tuple[str, ...] | None]


def _download_key(src: DatakitSource) -> DownloadKey:
    """Key that uniquely identifies a download step.

    Includes ``staged_path`` and ``hf_urls_glob`` because some HF repos (e.g.
    ``bigcode/StarCoder2-Extras``) are downloaded per-subset with distinct
    ``override_output_path`` values — so the same ``(hf_repo, revision)`` can
    correspond to multiple physical download steps.
    """
    return (src.hf_dataset_id, src.revision or "", src.staged_path, src.hf_urls_glob)


def _download_step_name(src: DatakitSource) -> str:
    """Stable download step name.

    Includes the last ``staged_path`` segment when multiple sources share a
    repo but stage separately (StarCoder2-Extras).
    """
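    # Hypothetical illustration (staged_path value invented for this comment):
    # hf_dataset_id "bigcode/StarCoder2-Extras" staged at ".../extras/jupyter"
    # yields "datakit-testbed/download/bigcode__StarCoder2-Extras__jupyter".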
    base = src.hf_dataset_id.replace("/", "__")
    if src.staged_path:
        tail = src.staged_path.rstrip("/").rsplit("/", 1)[-1]
        if tail and tail != base:
            return f"datakit-testbed/download/{base}__{tail}"
    return f"datakit-testbed/download/{base}"


def _build_downloads(sources: Sequence[DatakitSource]) -> dict[DownloadKey, StepSpec]:
    """One download step per unique ``(hf_repo, revision, staged_path, urls_glob)``."""
    by_key: dict[DownloadKey, DatakitSource] = {}
    for src in sources:
        key = _download_key(src)
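        # First source with a given key wins; later duplicates reuse its download.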
        if key in by_key:
            continue
        by_key[key] = src

    downloads: dict[DownloadKey, StepSpec] = {}
    seen_names: set[str] = set()
    for key, src in by_key.items():
        step_name = _download_step_name(src)
        if step_name in seen_names:
            raise ValueError(
                f"Duplicate download step name {step_name!r} — extend "
                "_download_step_name to disambiguate"
            )
        seen_names.add(step_name)
        assert src.revision is not None, f"{src.name}: cannot build download for unpinned revision"
        downloads[key] = download_hf_step(
            step_name,
            hf_dataset_id=src.hf_dataset_id,
            revision=src.revision,
            hf_urls_glob=list(src.hf_urls_glob) if src.hf_urls_glob else None,
            override_output_path=src.staged_path,
        )
    return downloads


def _normalize_step_for(
    src: DatakitSource,
    download: StepSpec,
) -> StepSpec:
    """Per-source normalize. ``input_path`` points at ``data_subdir`` inside the download.

    Output lands at ``$MARIN_PREFIX/normalized/<src.name>-<hash>/`` — a
    canonical, run-independent artifact that any downstream consumer (testbed
    or otherwise) can point at. Matches the convention used by
    ``marin.datakit.download.nemotron_v2.normalize_nemotron_v2_step``.
    """
    input_path = f"{download.output_path}/{src.data_subdir}" if src.data_subdir else download.output_path
    return normalize_step(
        name=f"normalized/{src.name}",
        download=download,
        text_field=src.text_field,
        id_field=src.id_field,
        input_path=input_path,
        file_extensions=src.file_extensions,
        worker_resources=ResourceConfig(cpu=2, ram="16g", disk="20g"),
    )


def _sample_step_for(
    src: DatakitSource,
    normalized: StepSpec,
    sample_fraction: float,
    base: str,
) -> StepSpec:
    """Per-source post-normalize sampler. Copies first ``ceil(N * fraction)`` shards."""
    return sample_normalized_shards_step(
        name=f"datakit-testbed/sample/{src.name}",
        normalized=normalized,
        sample_fraction=sample_fraction,
        override_output_path=f"{base}/sample/{src.name}",
    )


def _consolidate_step_for(
    src: DatakitSource,
    sampled: StepSpec,
    deduped: StepSpec,
    base: str,
) -> StepSpec:
    """Per-source consolidate over the sampled shards (still ``NormalizedData``
    shaped). Resolves ``attr_dir`` at runtime via ``Artifact.load``."""
    return StepSpec(
        name=f"datakit-testbed/consolidate/{src.name}",
        deps=[sampled, deduped],
        fn=lambda output_path: consolidate(
            input_path=Artifact.load(sampled, NormalizedData).main_output_dir,
            output_path=output_path,
            filetype="parquet",
            filters=[
                FilterConfig(
                    type=FilterType.KEEP_DOC,
                    attribute_path=Artifact.load(deduped, FuzzyDupsAttrData)
                    .sources[Artifact.load(sampled, NormalizedData).main_output_dir]
                    .attr_dir,
                    name="is_cluster_canonical",
                    attribute_filetype="parquet",
                    keep_if_missing=True,
                ),
            ],
            worker_resources=ResourceConfig(cpu=1, ram="8g"),
        ),
        override_output_path=f"{base}/consolidate/{src.name}",
    )


def build_testbed_steps(
    run_id: str,
    sources: Sequence[DatakitSource] | None = None,
    target_total_tokens_b: float = RAW_TARGET_TOTAL_TOKENS_B,
) -> TestbedDAG:
    """Build the full Datakit Testbed ferry DAG.

    Args:
        run_id: Per-run identifier; output paths are ``datakit-testbed/{run_id}/...``
            under ``MARIN_PREFIX`` so reruns are isolated.
        sources: DatakitSource list to ferry. ``None`` means the default set
            produced by :func:`pinned_sources` — every registry entry that
            has a pinned HF revision and a non-empty repo. Pass explicitly if
            you need to include unpinned or API-sourced entries with custom
            download wiring.
        target_total_tokens_b: Target total token count (in billions) across
            the sampled set. Drives per-source sample fractions via
            :func:`proportional_sample_fractions`. Default is
            :data:`RAW_TARGET_TOTAL_TOKENS_B` (1000B = 1T per RFC).

    Returns:
        A ``TestbedDAG`` whose ``all_steps`` list is safe to pass directly to
        ``StepRunner().run(...)``.
    """
    if sources is None:
        sources = tuple(pinned_sources().values())
    if not sources:
        raise ValueError("build_testbed_steps requires at least one source")

    base = f"datakit-testbed/{run_id}"

    fractions = proportional_sample_fractions(sources, target_total_tokens_b=target_total_tokens_b)

    downloads = _build_downloads(sources)
    normalized: dict[str, StepSpec] = {}
    sampled: dict[str, StepSpec] = {}
    for src in sources:
        normalized[src.name] = _normalize_step_for(src, downloads[_download_key(src)])
        sampled[src.name] = _sample_step_for(src, normalized[src.name], fractions[src.name], base)

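    # The dedup step takes the *sampled* steps as its normalized inputs, so its
    # per-source attr dirs are keyed by the sampled main_output_dir values,
    # which are exactly the keys the consolidate filter looks up.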
    deduped = compute_noop_dedup_attrs_step(
        name="datakit-testbed/noop_dedup",
        normalized_steps=list(sampled.values()),
        override_output_path=f"{base}/noop_dedup",
    )

    consolidated: dict[str, StepSpec] = {
        src.name: _consolidate_step_for(src, sampled[src.name], deduped, base) for src in sources
    }

    all_steps: list[StepSpec] = []
    all_steps.extend(downloads.values())
    all_steps.extend(normalized.values())
    all_steps.extend(sampled.values())
    all_steps.append(deduped)
    all_steps.extend(consolidated.values())

    logger.info(
        "Built testbed DAG: %d sources, %d downloads, %d samplers (target %.0fB tokens), "
        "1 noop_dedup, %d consolidated outputs",
        len(sources),
        len(downloads),
        len(sampled),
        target_total_tokens_b,
        len(consolidated),
    )

    return TestbedDAG(all_steps=all_steps, consolidated_by_source=consolidated)