Skip to content

Commit 04dda35

Browse files
committed
StepRunner: walk transitive deps before scheduling
Add a post-order walk (deduped by output_path) so callers pass only terminal steps; runner resolves the full graph, and already-succeeded deps short-circuit via the existing cache check. The prior "Iterable exhausted" branch is unreachable under the new traversal and is removed. Fixes #5146
1 parent 30f6b6c commit 04dda35

2 files changed

Lines changed: 115 additions & 32 deletions

File tree

lib/marin/src/marin/execution/step_runner.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,34 @@ def _write_executor_info(step: StepSpec) -> None:
8888
# ---------------------------------------------------------------------------
8989

9090

91+
def _flatten_transitive_deps(steps: Iterable[StepSpec]) -> list[StepSpec]:
    """Walk transitive deps in post-order, deduping by ``output_path``.

    Callers can pass only terminal steps; the runner resolves the full graph.
    A cycle in the graph is a construction-time invariant violation and raises.

    The walk uses an explicit stack rather than recursion, so a very long
    dependency chain cannot exhaust the interpreter's recursion limit.

    Args:
        steps: Steps to resolve. Each step is read through its
            ``output_path``, ``deps`` and ``name_with_hash`` attributes.

    Returns:
        Every reachable step exactly once (the first step encountered for a
        given ``output_path`` wins), ordered so that each step appears after
        all of its deps. Deps are visited in the order ``step.deps`` lists them.

    Raises:
        ValueError: If a step is reachable from itself through ``deps``.
    """
    seen: set[str] = set()  # output paths already emitted into ``ordered``
    in_stack: set[str] = set()  # output paths on the current DFS path
    ordered: list[StepSpec] = []
    exhausted = object()  # sentinel: distinguishes "no more deps" from any dep value

    for root in steps:
        if root.output_path in seen:
            continue
        in_stack.add(root.output_path)
        # Each frame pairs a step with an iterator over its not-yet-visited deps.
        stack = [(root, iter(root.deps))]
        while stack:
            step, pending = stack[-1]
            dep = next(pending, exhausted)
            if dep is exhausted:
                # All deps are emitted, so the step itself can be emitted (post-order).
                stack.pop()
                in_stack.remove(step.output_path)
                seen.add(step.output_path)
                ordered.append(step)
                continue
            dep_path = dep.output_path
            if dep_path in seen:
                continue
            if dep_path in in_stack:
                # The dep is an ancestor on the current path: the graph has a cycle.
                raise ValueError(f"Cycle detected in step graph involving {dep.name_with_hash}")
            in_stack.add(dep_path)
            stack.append((dep, iter(dep.deps)))
    return ordered
117+
118+
91119
class StepRunner:
92120
"""Runs ``StepSpec`` objects respecting their dependencies.
93121
@@ -106,10 +134,11 @@ def run(
106134
) -> None:
107135
"""Eagerly run steps, launching each as soon as its deps are satisfied.
108136
109-
Concurrency is bounded by the thread pool (``max_concurrent`` workers,
110-
default 8). If a step's dependencies haven't been seen yet, it is
111-
buffered until they complete. Raises if the iterable is exhausted and
112-
buffered steps still have unsatisfied dependencies.
137+
The input iterable may contain only the terminal steps you want to
138+
reach; the runner walks transitive deps (deduped by ``output_path``)
139+
before scheduling. Concurrency is bounded by the thread pool
140+
(``max_concurrent`` workers, default 8). Already-succeeded deps
141+
(``STATUS_SUCCESS`` on disk) resolve via the cache check.
113142
"""
114143
max_workers = max_concurrent or 8
115144
if max_workers < 1:
@@ -196,7 +225,8 @@ def _do_launch(step: StepSpec) -> None:
196225
else:
197226
completed.add(path)
198227

199-
for step in steps:
228+
flattened = _flatten_transitive_deps(steps)
229+
for step in flattened:
200230
path_to_name[step.output_path] = step.name_with_hash
201231

202232
_harvest()
@@ -212,22 +242,6 @@ def _do_launch(step: StepSpec) -> None:
212242

213243
# Drain remaining running and waiting steps
214244
while running or waiting:
215-
if not running and waiting:
216-
_flush_waiting()
217-
if not running and waiting:
218-
missing = []
219-
for s in waiting:
220-
unmet = [
221-
_display_name(d.output_path)
222-
for d in s.deps
223-
if d.output_path not in completed and d.output_path not in failed
224-
]
225-
missing.append(f" {s.name_with_hash}: needs {unmet}")
226-
raise RuntimeError(
227-
f"Iterable exhausted with {len(waiting)} step(s) with unsatisfied dependencies:\n"
228-
+ "\n".join(missing)
229-
)
230-
231245
_harvest(block=True)
232246
_flush_waiting()
233247

tests/execution/test_step_runner.py

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -372,29 +372,98 @@ def test_runner_max_concurrent(tmp_path: Path):
372372
assert train_artifact.tokens_seen > 0
373373

374374

375-
def test_runner_raises_clear_error_for_unmet_deps(tmp_path: Path):
376-
"""When the iterable omits a dependency, the runner must name the offending step and
377-
its unmet dep paths — not crash with ``TypeError: unhashable type: 'list'``."""
375+
def test_runner_walks_transitive_deps(tmp_path: Path):
376+
"""Passing only terminal steps should cause the runner to walk and run transitive deps."""
377+
executed: list[str] = []
378+
379+
def record(name: str):
380+
def _fn(output_path: str) -> PathMetadata:
381+
executed.append(name)
382+
return PathMetadata(path=output_path)
383+
384+
return _fn
378385

379386
dep = StepSpec(
380-
name="missing_upstream",
381-
override_output_path=(tmp_path / "missing_upstream").as_posix(),
387+
name="dep",
388+
override_output_path=(tmp_path / "dep").as_posix(),
389+
fn=record("dep"),
390+
)
391+
mid = StepSpec(
392+
name="mid",
393+
override_output_path=(tmp_path / "mid").as_posix(),
394+
deps=[dep],
395+
fn=record("mid"),
396+
)
397+
terminal = StepSpec(
398+
name="terminal",
399+
override_output_path=(tmp_path / "terminal").as_posix(),
400+
deps=[mid],
401+
fn=record("terminal"),
402+
)
403+
404+
StepRunner().run([terminal])
405+
406+
assert executed == ["dep", "mid", "terminal"]
407+
408+
409+
def test_runner_walks_transitive_deps_with_cache_hit(tmp_path: Path):
    """Passing only ``downstream`` must work when its dep was already run by an earlier runner."""
    dep_path = (tmp_path / "dep").as_posix()
    downstream_path = (tmp_path / "downstream").as_posix()

    dep = StepSpec(
        name="dep",
        override_output_path=dep_path,
        fn=lambda output_path: PathMetadata(path=output_path),
    )

    # Records every output path the downstream step function is invoked with.
    downstream_ran: list[str] = []

    def run_downstream(output_path: str) -> PathMetadata:
        downstream_ran.append(output_path)
        return PathMetadata(path=output_path)

    downstream = StepSpec(
        name="downstream",
        override_output_path=downstream_path,
        deps=[dep],
        fn=run_downstream,
    )

    # First run: only ``dep`` is handed to the runner, so downstream must not execute.
    StepRunner().run([dep])
    assert downstream_ran == []

    # Second run: only ``downstream`` is handed to a fresh runner; ``dep`` is
    # reached through the dependency walk and downstream executes exactly once.
    StepRunner().run([downstream])
    assert downstream_ran == [downstream_path]
436+
437+
438+
def test_runner_dedups_shared_deps(tmp_path: Path):
439+
"""A dep shared by multiple terminals must be executed exactly once."""
440+
dep_runs: list[str] = []
441+
442+
def run_dep(output_path: str) -> PathMetadata:
443+
dep_runs.append(output_path)
444+
return PathMetadata(path=output_path)
445+
446+
dep = StepSpec(
447+
name="shared_dep",
448+
override_output_path=(tmp_path / "shared_dep").as_posix(),
449+
fn=run_dep,
450+
)
451+
a = StepSpec(
452+
name="a",
453+
override_output_path=(tmp_path / "a").as_posix(),
454+
deps=[dep],
455+
fn=lambda output_path: PathMetadata(path=output_path),
456+
)
457+
b = StepSpec(
458+
name="b",
459+
override_output_path=(tmp_path / "b").as_posix(),
460+
deps=[dep],
388461
fn=lambda output_path: PathMetadata(path=output_path),
389462
)
390463

391-
runner = StepRunner()
392-
with pytest.raises(RuntimeError, match=r"Iterable exhausted .* unsatisfied dependencies") as exc_info:
393-
runner.run([downstream])
464+
StepRunner().run([a, b])
394465

395-
message = str(exc_info.value)
396-
assert downstream.name_with_hash in message
397-
assert dep.output_path in message
466+
assert dep_runs == [(tmp_path / "shared_dep").as_posix()]
398467

399468

400469
def test_runner_preserves_underlying_step_exception(tmp_path: Path):

0 commit comments

Comments
 (0)