Skip to content
Open
8 changes: 4 additions & 4 deletions experiments/references/reference_hyperparameter_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def _build_update_step(
study_resource_name=study_resource_name,
input_db_path=input_db_path,
suggestions_path=suggestions_path,
run_paths=[step.as_input_name() for step in training_steps],
run_paths=[step.as_mirrored_value() for step in training_steps],
metric_file=SWEEP.metric_file,
metric_key=SWEEP.metric_key,
mode=SWEEP.metric_mode,
Expand Down Expand Up @@ -648,7 +648,7 @@ def _build_optimal_step(
base_launch_config = _build_base_launch_config()

for loop_index in range(num_loops):
input_db_path = previous_update_step / VIZIER_DB_FILENAME if previous_update_step else None
input_db_path = previous_update_step.as_mirrored_value() / VIZIER_DB_FILENAME if previous_update_step else None
suggest_step = _build_suggest_step(loop_index=loop_index, input_db_path=input_db_path)

suggestions_path = suggest_step / SUGGESTIONS_FILENAME
Expand All @@ -665,14 +665,14 @@ def _build_optimal_step(
update_step = _build_update_step(
loop_index=loop_index,
study_resource_name=SWEEP.study_resource_name,
input_db_path=suggest_step / VIZIER_DB_FILENAME,
input_db_path=suggest_step.as_mirrored_value() / VIZIER_DB_FILENAME,
suggestions_path=suggestions_path,
training_steps=training_steps,
)
previous_update_step = update_step

optimal_step = _build_optimal_step(
input_db_path=previous_update_step / VIZIER_DB_FILENAME,
input_db_path=previous_update_step.as_mirrored_value() / VIZIER_DB_FILENAME,
study_resource_name=SWEEP.study_resource_name,
)
executor_main(steps=[optimal_step])
49 changes: 43 additions & 6 deletions lib/marin/src/marin/execution/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,16 @@ def with_output_path(self, output_path: str) -> "ExecutorStep":
def as_input_name(self) -> "InputName":
return InputName(step=self, name=None)

def as_mirrored_value(self, budget_gb: float = 10) -> "MirroredValue[InputName]":
"""Return a ``MirroredValue`` wrapping this step's output as an ``InputName``.

This is the step-reference analogue of the free ``mirrored()`` helper
used for raw paths. Usage::

default_eval(step=training_run.as_mirrored_value())
"""
return MirroredValue(value=self.as_input_name(), budget_gb=budget_gb)


@dataclass(frozen=True)
class InputName:
Expand All @@ -805,7 +815,11 @@ class InputName:
"""

def cd(self, name: str) -> "InputName":
return InputName(self.step, name=os.path.join(self.name, name) if self.name else name)
return InputName(
self.step,
name=os.path.join(self.name, name) if self.name else name,
block_on_step=self.block_on_step,
)

def __truediv__(self, other: str) -> "InputName":
"""Alias for `cd` that looks more Pythonic."""
Expand All @@ -827,6 +841,10 @@ def nonblocking(self) -> "InputName":
"""
return dataclasses.replace(self, block_on_step=False)

def as_mirrored_value(self, budget_gb: float = 10) -> "MirroredValue[InputName]":
"""Wrap this input in a ``MirroredValue`` for cross-region mirroring."""
return MirroredValue(value=self, budget_gb=budget_gb)


def get_executor_step(run: ExecutorStep | InputName) -> ExecutorStep:
"""
Expand Down Expand Up @@ -943,11 +961,27 @@ class MirroredValue(Generic[T_co]):
value: T_co
budget_gb: float = 10

def cd(self, name: str) -> "MirroredValue":
"""Navigate into a subdirectory, keeping the mirror wrapper."""
inner = self.value
if isinstance(inner, (ExecutorStep, InputName)):
return MirroredValue(value=inner.cd(name), budget_gb=self.budget_gb)
if isinstance(inner, str):
return MirroredValue(value=os.path.join(inner, name), budget_gb=self.budget_gb)
raise TypeError(f"cd() not supported on MirroredValue wrapping {type(inner)}")

def __truediv__(self, other: str) -> "MirroredValue":
"""Alias for ``cd`` that looks more Pythonic."""
return self.cd(other)


def mirrored(value: str | VersionedValue[str], budget_gb: float = 10) -> MirroredValue:
def mirrored(value: str | VersionedValue[str] | InputName, budget_gb: float = 10) -> MirroredValue:
"""Mark a path for cross-region mirroring with a transfer budget.

Usage: input_path=mirrored(versioned("documents/stackexchange/..."), budget_gb=50)
Usage::

input_path=mirrored(versioned("documents/stackexchange/..."), budget_gb=50)
model_path=mirrored(training_step.as_input_name() / "hf", budget_gb=25)
"""
if isinstance(value, MirroredValue):
raise ValueError("Can't nest MirroredValue")
Expand Down Expand Up @@ -1117,7 +1151,9 @@ def recurse(obj: Any) -> None:
if isinstance(obj, VersionedValue):
recurse(obj.value)
return
if isinstance(obj, InputName | ExecutorStep):
if isinstance(obj, InputName):
return
if isinstance(obj, ExecutorStep):
return
if is_dataclass(obj):
for field in fields(obj):
Expand Down Expand Up @@ -1163,9 +1199,10 @@ def recurse(obj: Any):

if isinstance(obj, InputName):
if obj.step is None:
return _make_prefix_absolute_path(prefix, obj.name)
resolved = _make_prefix_absolute_path(prefix, obj.name)
else:
return join_path(output_paths[obj.step], obj.name)
resolved = join_path(output_paths[obj.step], obj.name)
return resolved
elif isinstance(obj, OutputName):
return join_path(output_path, obj.name)
elif isinstance(obj, VersionedValue):
Expand Down
58 changes: 32 additions & 26 deletions tests/execution/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
_get_info_path,
collect_dependencies_and_version,
instantiate_config,
mirrored,
output_path_of,
this_output_path,
versioned,
Expand Down Expand Up @@ -615,47 +614,54 @@ def test_parent_will_run_if_some_child_is_not_skippable():
assert os.path.exists(os.path.join(executor.output_paths[parent], "dummy", "done.txt"))


def test_mirrored_versioning():
"""MirroredValue wrapping VersionedValue should version the inner value."""
def _dummy_fn(config):
pass


def test_mirrored_input_name_instantiate_config():
"""MirroredValue wrapping InputName resolves to mirror:// path."""

@dataclass(frozen=True)
class Cfg:
input_path: str
model_path: str
output_path: str

deps = collect_dependencies_and_version(
Cfg(input_path=mirrored(versioned("some/path"), budget_gb=50), output_path="out")
)
assert deps.version == {"input_path": "some/path"}
step = ExecutorStep(name="train", fn=_dummy_fn, config={})
cfg = Cfg(model_path=step.as_mirrored_value(), output_path="out")
output_paths = {step: "/bucket/train/abc123"}
resolved = instantiate_config(cfg, output_path="/out", output_paths=output_paths, prefix="/bucket")
assert resolved.model_path == "mirror:///bucket/train/abc123"


def test_mirrored_instantiate_config():
"""MirroredValue should resolve to mirror:// path."""
def test_mirrored_input_name_does_not_affect_version():
"""Wrapping InputName in MirroredValue should not change the version hash."""

@dataclass(frozen=True)
class Cfg:
input_path: str
model_path: str
output_path: str

cfg = Cfg(input_path=mirrored(versioned("documents/data"), budget_gb=10), output_path="out")
resolved = instantiate_config(cfg, output_path="/out", output_paths={}, prefix="/bucket")
assert resolved.input_path == "mirror://documents/data"
step = ExecutorStep(name="train", fn=_dummy_fn, config={})
deps_plain = collect_dependencies_and_version(Cfg(model_path=output_path_of(step, "hf"), output_path="out"))
deps_mirrored = collect_dependencies_and_version(
Cfg(model_path=(step.as_input_name() / "hf").as_mirrored_value(budget_gb=50), output_path="out")
)
assert deps_plain.version == deps_mirrored.version


def test_mirrored_nesting_raises():
with pytest.raises(ValueError, match="nest"):
mirrored(mirrored("x"))
def test_mirrored_value_truediv_instantiate():
"""MirroredValue with / subdirs resolves correctly via instantiate_config."""

@dataclass(frozen=True)
class Cfg:
model_path: str
output_path: str

def test_mirrored_changes_version():
"""Changing the path inside mirrored() should change the version hash."""
deps1 = collect_dependencies_and_version(
MyConfig(input_path=mirrored(versioned("data/v1")), output_path="out", n=versioned(1), m=1)
)
deps2 = collect_dependencies_and_version(
MyConfig(input_path=mirrored(versioned("data/v2")), output_path="out", n=versioned(1), m=1)
)
assert deps1.version != deps2.version
step = ExecutorStep(name="train", fn=_dummy_fn, config={})
cfg = Cfg(model_path=step.as_mirrored_value(budget_gb=5) / "hf", output_path="out")
output_paths = {step: "/bucket/train/abc123"}
resolved = instantiate_config(cfg, output_path="/out", output_paths=output_paths, prefix="/bucket")
assert resolved.model_path == "mirror:///bucket/train/abc123/hf"


def test_status_file_takeover_stale_lock_then_refresh(tmp_path):
Expand Down
Loading