Skip to content

Commit d6679bf

Browse files
authored
[executor] Add mirrored() for region-agnostic data references (#3990)
Add MirroredValue/mirrored() config wrapper that marks input paths for cross-region mirroring with per-path transfer budgets. The executor resolves mirrored paths to the local marin prefix and copies data from whichever regional bucket has it before step execution. Also fixes MirrorFileSystem.ls to union local and remote prefixes so glob() discovers files across regions, and makes the transfer budget configurable via MARIN_MIRROR_BUDGET_GB env var. Converts hardcoded gs:// input paths in exp934_hq_vs_pt.py (stackexchange, wikipedia, ar5iv, medu QA) to use mirrored() so the initial ferry can run from any region. Fixes #3989
1 parent 9406cc0 commit d6679bf

6 files changed

Lines changed: 326 additions & 25 deletions

File tree

docs/explanations/executor.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,42 @@ Coordination between multiple pipelines is handled via lease files. This
3535
prevents duplicate execution if, for example, 2 Executor pipelines share common
3636
ancestor steps.
3737

38+
## Mirrored inputs
39+
40+
Some datasets live in a specific regional bucket (e.g.
41+
`gs://marin-us-central2/documents/stackexchange/...`) but experiments may run
42+
from any region. The `mirrored()` wrapper marks an input path for
43+
**cross-region mirroring** so that the executor copies the data to the local
44+
marin prefix before the step runs.
45+
46+
```python
47+
from marin.execution.executor import mirrored, versioned
48+
49+
step = ExecutorStep(
50+
name="train",
51+
fn=my_training_fn,
52+
config=TrainConfig(
53+
dataset=mirrored(versioned("documents/stackexchange/v1"), budget_gb=50),
54+
),
55+
)
56+
```
57+
58+
At config instantiation time, `mirrored()` rewrites the path to use the
59+
`mirror://` protocol. When the step's function opens the path via `fsspec`,
60+
the `MirrorFileSystem` transparently copies data from whichever regional bucket
61+
has it into the local marin prefix, respecting the per-path transfer budget.
62+
63+
**Key details:**
64+
65+
- `budget_gb` (default 10) caps how much data (in GB) a single step may copy
66+
cross-region. The budget is enforced via the `mirror_budget` context manager
67+
from `iris.marin_fs`.
68+
- Paths that already exist in the local prefix are not re-copied.
69+
- `mirrored()` can wrap plain strings or `VersionedValue` / `InputName`
70+
references.
71+
- To adjust the global mirror budget default, set the `MARIN_MIRROR_BUDGET_GB`
72+
environment variable before the process starts.
73+
3874
## Ray
3975

4076
Recall that a step's function can either be a normal Python function or in most realistic cases,

experiments/exp934_hq_vs_pt.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
datasets used by various training experiments.
99
"""
1010

11-
from marin.execution.executor import ExecutorStep, this_output_path, versioned
11+
from marin.execution.executor import ExecutorStep, mirrored, this_output_path, versioned
1212
from marin.schemas.web.convert import HtmlToMarkdownConfig, ResiliparseConfig
1313
from marin.schemas.web.selectors import ARXIV_BLACKLISTED_SELECTORS, WIKI_BLACKLISTED_SELECTORS
1414
from marin.transform.ar5iv.transform_ar5iv import Ar5ivExtractionConfig, process_ar5iv_dump
@@ -29,7 +29,7 @@
2929
name="documents/stackexchange-resiliparse-custom-fork",
3030
fn=process_stackexchange_dump,
3131
config=StackExchangeExtractionConfig(
32-
input_path=versioned("gs://marin-us-central2/documents/stackexchange/v2024-04-02/md-complete"),
32+
input_path=mirrored(versioned("documents/stackexchange/v2024-04-02/md-complete"), budget_gb=50),
3333
output_path=this_output_path(),
3434
extract_method="resiliparse",
3535
extract_config=ResiliparseConfig(
@@ -48,7 +48,7 @@
4848
name="documents/wikipedia-resiliparse-custom-fork",
4949
fn=process_wiki_dump,
5050
config=WikiExtractionConfig(
51-
input_path="gs://marin-us-central2/raw/wikipedia-a7dad0/20241201",
51+
input_path=mirrored("raw/wikipedia-a7dad0/20241201", budget_gb=1),
5252
revision=versioned("20241201"),
5353
output_path=this_output_path(),
5454
extract_method="resiliparse",
@@ -72,7 +72,7 @@
7272
name="documents/ar5iv/ar5iv-04-2024-no-problem",
7373
fn=process_ar5iv_dump,
7474
config=Ar5ivExtractionConfig(
75-
input_path="gs://marin-us-central2/raw/ar5iv/ar5iv-04-2024-no-problem-49c4e3/202404",
75+
input_path=mirrored("raw/ar5iv/ar5iv-04-2024-no-problem-49c4e3/202404", budget_gb=1),
7676
revision="042024",
7777
output_path=this_output_path("resiliparse-custom-fork"),
7878
extract_method=versioned("resiliparse"),
@@ -88,7 +88,7 @@
8888
# MMLU Science QA tokenization
8989
medu_mmlu_science_qa_tokenized = default_tokenize(
9090
name="medu-mmlu-science-qa",
91-
dataset="gs://marin-us-east1/documents/medu-mmlu-science-llama8b-qa-whole-1a419d",
91+
dataset=mirrored("documents/medu-mmlu-science-llama8b-qa-whole-1a419d", budget_gb=30),
9292
tokenizer=llama3_tokenizer,
9393
).with_output_path("tokenized/medu-mmlu-science-qa-c64fda")
9494

lib/iris/src/iris/marin_fs.py

Lines changed: 81 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
the ``MARIN_I_WILL_PAY_FOR_ALL_FEES`` env var to override the guard.
2424
"""
2525

26+
import contextlib
27+
import contextvars
2628
import dataclasses
2729
import functools
2830
import logging
@@ -33,7 +35,7 @@
3335
import time
3436
import urllib.error
3537
import urllib.request
36-
from collections.abc import Callable, Sequence
38+
from collections.abc import Callable, Generator, Sequence
3739
from pathlib import PurePath
3840
from typing import Any
3941

@@ -375,7 +377,18 @@ def _normalize_path_like(path: str | os.PathLike) -> str:
375377
# ---------------------------------------------------------------------------
376378

377379
MARIN_CROSS_REGION_OVERRIDE_ENV: str = "MARIN_I_WILL_PAY_FOR_ALL_FEES"
378-
CROSS_REGION_TRANSFER_LIMIT_BYTES: int = 10 * 1024 * 1024 * 1024 # 10 GB
380+
MARIN_MIRROR_BUDGET_ENV: str = "MARIN_MIRROR_BUDGET_GB"
381+
_DEFAULT_TRANSFER_LIMIT_GB: int = 10
382+
383+
384+
def _transfer_limit_bytes() -> int:
385+
raw = os.environ.get(MARIN_MIRROR_BUDGET_ENV, "")
386+
if raw:
387+
return int(float(raw) * 1024 * 1024 * 1024)
388+
return _DEFAULT_TRANSFER_LIMIT_GB * 1024 * 1024 * 1024
389+
390+
391+
CROSS_REGION_TRANSFER_LIMIT_BYTES: int = _transfer_limit_bytes()
379392

380393
# GCS multi-region bucket locations are returned as "us", "eu", or "asia"
381394
# rather than a specific region like "us-central1". European regions use the
@@ -450,6 +463,34 @@ def reset(self, limit_bytes: int | None = None) -> None:
450463

451464
_global_transfer_budget = TransferBudget()
452465

466+
_mirror_budget_ctx: contextvars.ContextVar[TransferBudget | None] = contextvars.ContextVar(
467+
"_mirror_budget_ctx", default=None
468+
)
469+
470+
471+
def set_mirror_budget(budget_gb: float) -> contextvars.Token:
472+
"""Set the MirrorFileSystem transfer budget for the current context.
473+
474+
Returns a token that can be used to reset the budget.
475+
"""
476+
budget = TransferBudget(limit_bytes=int(budget_gb * 1024 * 1024 * 1024))
477+
return _mirror_budget_ctx.set(budget)
478+
479+
480+
def reset_mirror_budget(token: contextvars.Token) -> None:
481+
"""Reset the MirrorFileSystem transfer budget to its previous value."""
482+
_mirror_budget_ctx.reset(token)
483+
484+
485+
@contextlib.contextmanager
486+
def mirror_budget(budget_gb: float) -> Generator[None, None, None]:
487+
"""Context manager to scope a MirrorFileSystem transfer budget."""
488+
token = set_mirror_budget(budget_gb)
489+
try:
490+
yield
491+
finally:
492+
reset_mirror_budget(token)
493+
453494

454495
@functools.lru_cache(maxsize=1)
455496
def _cached_marin_region() -> str | None:
@@ -684,6 +725,15 @@ def __init__(
684725
self._budget = budget if budget is not None else _global_transfer_budget
685726
self._worker_id = default_worker_id()
686727

728+
# -- budget resolution ----------------------------------------------------
729+
730+
def _active_budget(self) -> TransferBudget:
731+
"""Return the contextvar budget if set, otherwise the instance budget."""
732+
ctx_budget = _mirror_budget_ctx.get()
733+
if ctx_budget is not None:
734+
return ctx_budget
735+
return self._budget
736+
687737
# -- underlying fs helpers ------------------------------------------------
688738

689739
def _get_fs_and_path(self, url: str) -> tuple[Any, str]:
@@ -755,7 +805,7 @@ def _copy_to_local(self, source_prefix: str, path: str) -> None:
755805

756806
size = self._fs_size(remote_url)
757807
if size is not None:
758-
self._budget.record(size, remote_url)
808+
self._active_budget().record(size, remote_url)
759809

760810
logger.info("Mirror: copying %s → %s", remote_url, local_url)
761811
self._fs_copy(remote_url, local_url)
@@ -785,27 +835,38 @@ def _info(self, path: str, **kwargs: Any) -> dict[str, Any]:
785835
info["name"] = path
786836
return info
787837

838+
@staticmethod
839+
def _stripped_prefix(bucket_prefix: str) -> str:
840+
"""Return the bucket prefix without scheme, with trailing slash."""
841+
return bucket_prefix.rstrip("/").replace("gs://", "").replace("file://", "") + "/"
842+
788843
def ls(self, path: str, detail: bool = True, **kwargs: Any) -> list[Any]:
789844
path = self._strip_protocol(path)
790-
local_url = self._local_url(path)
791-
fs, fspath = self._get_fs_and_path(local_url)
792-
try:
793-
results = fs.ls(fspath, detail=detail, **kwargs)
794-
except FileNotFoundError:
795-
results = []
796-
797-
prefix = self._local_prefix.rstrip("/") + "/"
798-
# For GCS, fsspec strips the scheme, so the prefix in results won't have gs://
799-
stripped_prefix = prefix.replace("gs://", "").replace("file://", "")
800-
845+
# Union listings from local + all remote prefixes so that glob()
846+
# discovers files that only exist in other regions. Local entries
847+
# take precedence when a relative path appears in multiple buckets.
848+
seen: dict[str, dict[str, Any]] = {}
849+
850+
for prefix in [self._local_prefix, *self._remote_prefixes]:
851+
url = f"{prefix}/{path}"
852+
fs, fspath = self._get_fs_and_path(url)
853+
try:
854+
entries = fs.ls(fspath, detail=True, **kwargs)
855+
except FileNotFoundError:
856+
continue
857+
858+
stripped = self._stripped_prefix(prefix)
859+
for entry in entries:
860+
rel_name = entry["name"]
861+
if rel_name.startswith(stripped):
862+
rel_name = rel_name[len(stripped) :]
863+
if rel_name not in seen:
864+
seen[rel_name] = {**entry, "name": rel_name}
865+
866+
results = list(seen.values())
801867
if detail:
802-
for entry in results:
803-
name = entry["name"]
804-
if name.startswith(stripped_prefix):
805-
entry["name"] = name[len(stripped_prefix) :]
806868
return results
807-
else:
808-
return [r[len(stripped_prefix) :] if r.startswith(stripped_prefix) else r for r in results]
869+
return [e["name"] for e in results]
809870

810871
def exists(self, path: str, **kwargs: Any) -> bool:
811872
path = self._strip_protocol(path)

lib/iris/tests/test_mirror_fs.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,88 @@ def test_second_read_uses_local_cache(mirror_fs, mirror_env):
104104
def test_read_finds_file_in_second_remote(mirror_fs, mirror_env):
105105
_write_file(mirror_env["remote2"], "data/remote2.txt", b"from-remote2")
106106
assert mirror_fs.cat_file("data/remote2.txt") == b"from-remote2"
107+
108+
109+
# ---------------------------------------------------------------------------
110+
# ls / glob tests
111+
# ---------------------------------------------------------------------------
112+
113+
114+
def test_ls_returns_local_entries(mirror_fs, mirror_env):
115+
_write_file(mirror_env["local_dir"], "data/a.jsonl", b"a")
116+
_write_file(mirror_env["local_dir"], "data/b.jsonl", b"b")
117+
118+
entries = mirror_fs.ls("data", detail=False)
119+
assert sorted(entries) == ["data/a.jsonl", "data/b.jsonl"]
120+
121+
122+
def test_ls_discovers_remote_only_entries(mirror_fs, mirror_env):
123+
_write_file(mirror_env["remote1"], "data/remote.jsonl", b"r")
124+
125+
entries = mirror_fs.ls("data", detail=False)
126+
assert "data/remote.jsonl" in entries
127+
128+
129+
def test_ls_unions_local_and_remote(mirror_fs, mirror_env):
130+
_write_file(mirror_env["local_dir"], "data/local.jsonl", b"l")
131+
_write_file(mirror_env["remote1"], "data/remote.jsonl", b"r")
132+
133+
entries = mirror_fs.ls("data", detail=False)
134+
assert sorted(entries) == ["data/local.jsonl", "data/remote.jsonl"]
135+
136+
137+
def test_ls_local_takes_precedence(mirror_fs, mirror_env):
138+
_write_file(mirror_env["local_dir"], "data/file.jsonl", b"local-version")
139+
_write_file(mirror_env["remote1"], "data/file.jsonl", b"remote-version")
140+
141+
entries = mirror_fs.ls("data", detail=True)
142+
assert len(entries) == 1
143+
assert entries[0]["name"] == "data/file.jsonl"
144+
145+
146+
def test_ls_unions_multiple_remotes(mirror_fs, mirror_env):
147+
_write_file(mirror_env["remote1"], "data/from_r1.jsonl", b"r1")
148+
_write_file(mirror_env["remote2"], "data/from_r2.jsonl", b"r2")
149+
150+
entries = mirror_fs.ls("data", detail=False)
151+
assert sorted(entries) == ["data/from_r1.jsonl", "data/from_r2.jsonl"]
152+
153+
154+
def test_ls_empty_when_path_missing_everywhere(mirror_fs):
155+
entries = mirror_fs.ls("nonexistent/path", detail=False)
156+
assert entries == []
157+
158+
159+
def test_glob_discovers_remote_files(mirror_fs, mirror_env):
160+
_write_file(mirror_env["remote1"], "docs/a.jsonl.gz", b"a")
161+
_write_file(mirror_env["remote1"], "docs/b.jsonl.gz", b"b")
162+
_write_file(mirror_env["remote1"], "docs/skip.txt", b"skip")
163+
164+
matched = mirror_fs.glob("docs/*.jsonl.gz")
165+
assert sorted(matched) == ["docs/a.jsonl.gz", "docs/b.jsonl.gz"]
166+
167+
168+
# ---------------------------------------------------------------------------
169+
# mirror_budget context manager tests
170+
# ---------------------------------------------------------------------------
171+
172+
173+
def test_mirror_budget_context_manager(mirror_fs, mirror_env):
174+
"""Transfer budget set via context manager is used for copies."""
175+
from iris.marin_fs import mirror_budget
176+
177+
_write_file(mirror_env["remote1"], "data/big.bin", b"x" * 1000)
178+
179+
with mirror_budget(budget_gb=0.001): # ~1MB
180+
assert mirror_fs.cat_file("data/big.bin") == b"x" * 1000
181+
182+
183+
def test_mirror_budget_context_manager_blocks_over_budget(mirror_fs, mirror_env):
184+
from iris.marin_fs import mirror_budget
185+
186+
mirror_fs._budget.reset(limit_bytes=10 * 1024 * 1024 * 1024) # high instance budget
187+
_write_file(mirror_env["remote1"], "data/big.bin", b"x" * 2000)
188+
189+
with mirror_budget(budget_gb=0.000001): # ~1KB — too small
190+
with pytest.raises(TransferBudgetExceeded):
191+
mirror_fs.cat_file("data/big.bin")

0 commit comments

Comments
 (0)