Skip to content

Commit 645cf56

Browse files
gwarmstrongKipok
andauthored
fix(pipeline): honor keep_mounts=False to prevent sandbox mount leak (#1394)
Signed-off-by: gwarmstrong <gwarmstrong@users.noreply.github.com> Co-authored-by: gwarmstrong <gwarmstrong@users.noreply.github.com> Co-authored-by: Igor Gitman <igitman@nvidia.com>
1 parent 786edf9 commit 645cf56

2 files changed

Lines changed: 177 additions & 6 deletions

File tree

nemo_skills/pipeline/utils/declarative.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -298,10 +298,13 @@ def prepare_for_execution(self, cluster_config: Dict) -> Tuple[run.Script, Dict]
298298
# For SandboxScript, keep_mounts=False (the safe default) maps to mounts=[]
299299
# so the sandbox container has no access to cluster filesystems.
300300
# keep_mounts=True maps to mounts=None, which inherits cluster mounts.
301+
# keep_mounts is propagated separately so Stage B (_create_executor) can
302+
# honor the isolation request even when Command.mounts is an explicit list
303+
# (in which case Stage A's resolved_mounts alone loses that signal).
304+
keep_mounts = getattr(self.script, "keep_mounts", True)
301305
if self.mounts is not None:
302306
resolved_mounts = self.mounts
303307
else:
304-
keep_mounts = getattr(self.script, "keep_mounts", True)
305308
resolved_mounts = None if keep_mounts else []
306309

307310
merged_env = dict(runtime_metadata.get("environment", {}))
@@ -311,6 +314,7 @@ def prepare_for_execution(self, cluster_config: Dict) -> Tuple[run.Script, Dict]
311314
"log_prefix": getattr(self.script, "log_prefix", "main"),
312315
"environment": merged_env,
313316
"mounts": resolved_mounts,
317+
"keep_mounts": keep_mounts,
314318
"container": self.container,
315319
}
316320

@@ -647,13 +651,28 @@ def _create_executor(
647651
else (hardware.num_tasks if hardware and hardware.num_tasks is not None else 1)
648652
)
649653

650-
# Allow per-command extra mounts without requiring editing the cluster YAML.
651-
# We treat exec_config["mounts"] as additive and merge it with mounts from cluster_config.
652-
mounts = None
653-
extra_mounts = exec_config["mounts"] or None
654-
if extra_mounts:
654+
# Resolve mounts based on Stage A output and the script's keep_mounts flag:
655+
# - mounts=None: inherit cluster mounts (Stage C default).
656+
# - keep_mounts=False: the script asked for filesystem isolation. Pass its
657+
# mounts list verbatim (even empty) so cluster mounts are NOT merged in.
658+
# - keep_mounts=True + non-empty extras: additive merge with cluster mounts.
659+
# - keep_mounts=True + empty extras: inherit cluster mounts.
660+
# Stage A invariant: mounts=None is only produced when keep_mounts=True
661+
# (keep_mounts=False with no explicit Command.mounts is normalized to []),
662+
# so the `extra_mounts is None` branch below is safe to take before
663+
# consulting keep_mounts. `.get(..., True)` defends against exec_configs
664+
# built by callers that bypass Stage A.
665+
extra_mounts = exec_config["mounts"]
666+
keep_mounts = exec_config.get("keep_mounts", True)
667+
if extra_mounts is None:
668+
mounts = None
669+
elif not keep_mounts:
670+
mounts = list(extra_mounts)
671+
elif extra_mounts:
655672
base_mounts = get_mounts_from_config(cluster_config)
656673
mounts = base_mounts + [m for m in extra_mounts if m not in base_mounts]
674+
else:
675+
mounts = None
657676

658677
# Sandbox-specific srun overrides: allow the sandbox to survive individual
659678
# worker crashes (e.g. SIGILL from libraries compiled for a different CPU).

tests/test_declarative_pipeline.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,5 +922,157 @@ def capture_env_update(cluster_config, updates):
922922
)
923923

924924

925+
class TestMountsResolution:
926+
"""Regression tests for the Command/Pipeline mounts resolution flow.
927+
928+
Covers the full (Command.mounts x script.keep_mounts) matrix described in
929+
the sandbox-mount-leak bug analysis. The three bug rows share keep_mounts=False
930+
and must NOT receive cluster mounts back via the Stage B additive merge.
931+
"""
932+
933+
CLUSTER_MOUNTS = ["/cluster/a:/cluster/a", "/cluster/b:/cluster/b"]
934+
935+
def _make_script(self, *, keep_mounts=None):
936+
"""Return a DummyScript with an optional keep_mounts attribute."""
937+
script = DummyScript(inline="echo test")
938+
if keep_mounts is not None:
939+
script.keep_mounts = keep_mounts
940+
return script
941+
942+
# -------------------- Stage A: Command.prepare_for_execution --------------------
943+
944+
@pytest.mark.parametrize(
945+
"command_mounts, keep_mounts_attr, expected_mounts, expected_keep_mounts",
946+
[
947+
# Command.mounts=None
948+
(None, None, None, True), # non-sandbox (keep_mounts attr absent -> defaults True)
949+
(None, True, None, True), # sandbox opt-in: inherit cluster mounts
950+
(None, False, [], False), # sandbox default: empty list, flag propagated
951+
# Command.mounts=[]
952+
([], None, [], True),
953+
([], True, [], True),
954+
([], False, [], False),
955+
# Command.mounts=[/a:/b]
956+
(["/a:/b"], None, ["/a:/b"], True),
957+
(["/a:/b"], True, ["/a:/b"], True),
958+
(["/a:/b"], False, ["/a:/b"], False),
959+
],
960+
)
961+
def test_stage_a_resolved_mounts_and_keep_mounts(
962+
self, command_mounts, keep_mounts_attr, expected_mounts, expected_keep_mounts
963+
):
964+
"""Stage A must store mounts and the keep_mounts flag in execution_config."""
965+
script = self._make_script(keep_mounts=keep_mounts_attr)
966+
cmd = Command(script=script, name="c", mounts=command_mounts)
967+
cluster_config = {"executor": "local", "containers": {}}
968+
969+
_, exec_config = cmd.prepare_for_execution(cluster_config)
970+
971+
assert exec_config["mounts"] == expected_mounts
972+
assert exec_config["keep_mounts"] is expected_keep_mounts
973+
974+
# -------------------- Stage B/C: end-to-end mounts passed to get_executor --------------------
975+
976+
def _run_pipeline_and_capture_mounts(self, command_mounts, keep_mounts_attr):
977+
"""Run a one-command Pipeline with mocks and return the mounts kwarg passed to get_executor."""
978+
captured = {}
979+
980+
def mock_get_executor(**kwargs):
981+
captured["mounts"] = kwargs.get("mounts")
982+
executor = MagicMock()
983+
executor.packager = MagicMock()
984+
return executor
985+
986+
cluster_config = {
987+
"executor": "slurm",
988+
"containers": {"nemo-skills": "test/container"},
989+
"account": "test",
990+
"env_vars": {"HF_HOME": "/hf"},
991+
"mounts": self.CLUSTER_MOUNTS,
992+
}
993+
994+
script = self._make_script(keep_mounts=keep_mounts_attr)
995+
cmd = Command(script=script, name="c", mounts=command_mounts)
996+
group = CommandGroup(commands=[cmd], name="g", log_dir="/logs")
997+
998+
with (
999+
patch("nemo_skills.pipeline.utils.declarative.get_executor", side_effect=mock_get_executor),
1000+
patch(
1001+
"nemo_skills.pipeline.utils.declarative.get_mounts_from_config",
1002+
return_value=list(self.CLUSTER_MOUNTS),
1003+
),
1004+
patch("nemo_skills.pipeline.utils.declarative.get_env_variables", return_value={"HF_HOME": "/hf"}),
1005+
patch("nemo_skills.pipeline.utils.declarative.get_exp") as mock_get_exp,
1006+
patch("nemo_skills.pipeline.utils.declarative.run_exp"),
1007+
):
1008+
mock_exp = MagicMock()
1009+
mock_exp.__enter__ = MagicMock(return_value=mock_exp)
1010+
mock_exp.__exit__ = MagicMock(return_value=False)
1011+
mock_exp.add = MagicMock(return_value="handle")
1012+
mock_get_exp.return_value = mock_exp
1013+
1014+
Pipeline(
1015+
name="test",
1016+
cluster_config=cluster_config,
1017+
jobs=[{"name": "j", "group": group}],
1018+
skip_hf_home_check=True,
1019+
reuse_code=False,
1020+
).run(dry_run=True)
1021+
1022+
assert "mounts" in captured, "get_executor was not called"
1023+
return captured["mounts"]
1024+
1025+
# ---- Non-bug rows: expected pre-fix behavior is preserved ----
1026+
1027+
def test_mounts_none_no_keep_mounts_attr_inherits_cluster(self):
1028+
"""Non-sandbox script with no explicit mounts inherits cluster mounts."""
1029+
mounts = self._run_pipeline_and_capture_mounts(command_mounts=None, keep_mounts_attr=None)
1030+
# Stage C falls back to cluster mounts when mounts kwarg is None
1031+
assert mounts is None
1032+
1033+
def test_mounts_none_keep_mounts_true_inherits_cluster(self):
1034+
"""keep_mounts=True with no explicit list inherits cluster mounts."""
1035+
mounts = self._run_pipeline_and_capture_mounts(command_mounts=None, keep_mounts_attr=True)
1036+
assert mounts is None
1037+
1038+
def test_mounts_empty_no_keep_mounts_attr_inherits_cluster(self):
1039+
"""Empty Command.mounts on a non-sandbox script is treated as 'no extras' -> inherit."""
1040+
mounts = self._run_pipeline_and_capture_mounts(command_mounts=[], keep_mounts_attr=None)
1041+
assert mounts is None
1042+
1043+
def test_mounts_empty_keep_mounts_true_inherits_cluster(self):
1044+
"""Empty Command.mounts with keep_mounts=True also inherits cluster mounts."""
1045+
mounts = self._run_pipeline_and_capture_mounts(command_mounts=[], keep_mounts_attr=True)
1046+
assert mounts is None
1047+
1048+
def test_mounts_extra_no_keep_mounts_attr_additive_merge(self):
1049+
"""Non-sandbox extras are additively merged with cluster mounts."""
1050+
mounts = self._run_pipeline_and_capture_mounts(command_mounts=["/a:/b"], keep_mounts_attr=None)
1051+
assert mounts == self.CLUSTER_MOUNTS + ["/a:/b"]
1052+
1053+
def test_mounts_extra_keep_mounts_true_additive_merge(self):
1054+
"""keep_mounts=True with extras: additive merge (opt-in inherit + extras)."""
1055+
mounts = self._run_pipeline_and_capture_mounts(command_mounts=["/a:/b"], keep_mounts_attr=True)
1056+
assert mounts == self.CLUSTER_MOUNTS + ["/a:/b"]
1057+
1058+
# ---- Bug rows: keep_mounts=False must isolate from cluster mounts ----
1059+
1060+
def test_bug_row_1_mounts_none_keep_mounts_false_no_cluster_leak(self):
1061+
"""Sandbox default (Command.mounts=None, keep_mounts=False): no cluster mounts leak through."""
1062+
mounts = self._run_pipeline_and_capture_mounts(command_mounts=None, keep_mounts_attr=False)
1063+
# Must be an empty list passed to get_executor so Stage C does NOT fall back to cluster mounts
1064+
assert mounts == [], f"keep_mounts=False leaked cluster mounts: {mounts}"
1065+
1066+
def test_bug_row_2_mounts_empty_keep_mounts_false_no_cluster_leak(self):
1067+
"""Sandbox with explicit empty list (Command.mounts=[], keep_mounts=False): no cluster mounts leak."""
1068+
mounts = self._run_pipeline_and_capture_mounts(command_mounts=[], keep_mounts_attr=False)
1069+
assert mounts == [], f"keep_mounts=False leaked cluster mounts: {mounts}"
1070+
1071+
def test_bug_row_3_mounts_extra_keep_mounts_false_no_cluster_merge(self):
1072+
"""Sandbox with explicit extras (Command.mounts=[/a:/b], keep_mounts=False): extras verbatim, no cluster merge."""
1073+
mounts = self._run_pipeline_and_capture_mounts(command_mounts=["/a:/b"], keep_mounts_attr=False)
1074+
assert mounts == ["/a:/b"], f"keep_mounts=False merged cluster mounts into sandbox: {mounts}"
1075+
1076+
9251077
if __name__ == "__main__":
9261078
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)