@@ -922,5 +922,157 @@ def capture_env_update(cluster_config, updates):
922922 )
923923
924924
925+ class TestMountsResolution :
926+ """Regression tests for the Command/Pipeline mounts resolution flow.
927+
928+ Covers the full (Command.mounts x script.keep_mounts) matrix described in
929+ the sandbox-mount-leak bug analysis. The three bug rows share keep_mounts=False
930+ and must NOT receive cluster mounts back via the Stage B additive merge.
931+ """
932+
933+ CLUSTER_MOUNTS = ["/cluster/a:/cluster/a" , "/cluster/b:/cluster/b" ]
934+
935+ def _make_script (self , * , keep_mounts = None ):
936+ """Return a DummyScript with an optional keep_mounts attribute."""
937+ script = DummyScript (inline = "echo test" )
938+ if keep_mounts is not None :
939+ script .keep_mounts = keep_mounts
940+ return script
941+
942+ # -------------------- Stage A: Command.prepare_for_execution --------------------
943+
944+ @pytest .mark .parametrize (
945+ "command_mounts, keep_mounts_attr, expected_mounts, expected_keep_mounts" ,
946+ [
947+ # Command.mounts=None
948+ (None , None , None , True ), # non-sandbox (keep_mounts attr absent -> defaults True)
949+ (None , True , None , True ), # sandbox opt-in: inherit cluster mounts
950+ (None , False , [], False ), # sandbox default: empty list, flag propagated
951+ # Command.mounts=[]
952+ ([], None , [], True ),
953+ ([], True , [], True ),
954+ ([], False , [], False ),
955+ # Command.mounts=[/a:/b]
956+ (["/a:/b" ], None , ["/a:/b" ], True ),
957+ (["/a:/b" ], True , ["/a:/b" ], True ),
958+ (["/a:/b" ], False , ["/a:/b" ], False ),
959+ ],
960+ )
961+ def test_stage_a_resolved_mounts_and_keep_mounts (
962+ self , command_mounts , keep_mounts_attr , expected_mounts , expected_keep_mounts
963+ ):
964+ """Stage A must store mounts and the keep_mounts flag in execution_config."""
965+ script = self ._make_script (keep_mounts = keep_mounts_attr )
966+ cmd = Command (script = script , name = "c" , mounts = command_mounts )
967+ cluster_config = {"executor" : "local" , "containers" : {}}
968+
969+ _ , exec_config = cmd .prepare_for_execution (cluster_config )
970+
971+ assert exec_config ["mounts" ] == expected_mounts
972+ assert exec_config ["keep_mounts" ] is expected_keep_mounts
973+
974+ # -------------------- Stage B/C: end-to-end mounts passed to get_executor --------------------
975+
976+ def _run_pipeline_and_capture_mounts (self , command_mounts , keep_mounts_attr ):
977+ """Run a one-command Pipeline with mocks and return the mounts kwarg passed to get_executor."""
978+ captured = {}
979+
980+ def mock_get_executor (** kwargs ):
981+ captured ["mounts" ] = kwargs .get ("mounts" )
982+ executor = MagicMock ()
983+ executor .packager = MagicMock ()
984+ return executor
985+
986+ cluster_config = {
987+ "executor" : "slurm" ,
988+ "containers" : {"nemo-skills" : "test/container" },
989+ "account" : "test" ,
990+ "env_vars" : {"HF_HOME" : "/hf" },
991+ "mounts" : self .CLUSTER_MOUNTS ,
992+ }
993+
994+ script = self ._make_script (keep_mounts = keep_mounts_attr )
995+ cmd = Command (script = script , name = "c" , mounts = command_mounts )
996+ group = CommandGroup (commands = [cmd ], name = "g" , log_dir = "/logs" )
997+
998+ with (
999+ patch ("nemo_skills.pipeline.utils.declarative.get_executor" , side_effect = mock_get_executor ),
1000+ patch (
1001+ "nemo_skills.pipeline.utils.declarative.get_mounts_from_config" ,
1002+ return_value = list (self .CLUSTER_MOUNTS ),
1003+ ),
1004+ patch ("nemo_skills.pipeline.utils.declarative.get_env_variables" , return_value = {"HF_HOME" : "/hf" }),
1005+ patch ("nemo_skills.pipeline.utils.declarative.get_exp" ) as mock_get_exp ,
1006+ patch ("nemo_skills.pipeline.utils.declarative.run_exp" ),
1007+ ):
1008+ mock_exp = MagicMock ()
1009+ mock_exp .__enter__ = MagicMock (return_value = mock_exp )
1010+ mock_exp .__exit__ = MagicMock (return_value = False )
1011+ mock_exp .add = MagicMock (return_value = "handle" )
1012+ mock_get_exp .return_value = mock_exp
1013+
1014+ Pipeline (
1015+ name = "test" ,
1016+ cluster_config = cluster_config ,
1017+ jobs = [{"name" : "j" , "group" : group }],
1018+ skip_hf_home_check = True ,
1019+ reuse_code = False ,
1020+ ).run (dry_run = True )
1021+
1022+ assert "mounts" in captured , "get_executor was not called"
1023+ return captured ["mounts" ]
1024+
1025+ # ---- Non-bug rows: expected pre-fix behavior is preserved ----
1026+
1027+ def test_mounts_none_no_keep_mounts_attr_inherits_cluster (self ):
1028+ """Non-sandbox script with no explicit mounts inherits cluster mounts."""
1029+ mounts = self ._run_pipeline_and_capture_mounts (command_mounts = None , keep_mounts_attr = None )
1030+ # Stage C falls back to cluster mounts when mounts kwarg is None
1031+ assert mounts is None
1032+
1033+ def test_mounts_none_keep_mounts_true_inherits_cluster (self ):
1034+ """keep_mounts=True with no explicit list inherits cluster mounts."""
1035+ mounts = self ._run_pipeline_and_capture_mounts (command_mounts = None , keep_mounts_attr = True )
1036+ assert mounts is None
1037+
1038+ def test_mounts_empty_no_keep_mounts_attr_inherits_cluster (self ):
1039+ """Empty Command.mounts on a non-sandbox script is treated as 'no extras' -> inherit."""
1040+ mounts = self ._run_pipeline_and_capture_mounts (command_mounts = [], keep_mounts_attr = None )
1041+ assert mounts is None
1042+
1043+ def test_mounts_empty_keep_mounts_true_inherits_cluster (self ):
1044+ """Empty Command.mounts with keep_mounts=True also inherits cluster mounts."""
1045+ mounts = self ._run_pipeline_and_capture_mounts (command_mounts = [], keep_mounts_attr = True )
1046+ assert mounts is None
1047+
1048+ def test_mounts_extra_no_keep_mounts_attr_additive_merge (self ):
1049+ """Non-sandbox extras are additively merged with cluster mounts."""
1050+ mounts = self ._run_pipeline_and_capture_mounts (command_mounts = ["/a:/b" ], keep_mounts_attr = None )
1051+ assert mounts == self .CLUSTER_MOUNTS + ["/a:/b" ]
1052+
1053+ def test_mounts_extra_keep_mounts_true_additive_merge (self ):
1054+ """keep_mounts=True with extras: additive merge (opt-in inherit + extras)."""
1055+ mounts = self ._run_pipeline_and_capture_mounts (command_mounts = ["/a:/b" ], keep_mounts_attr = True )
1056+ assert mounts == self .CLUSTER_MOUNTS + ["/a:/b" ]
1057+
1058+ # ---- Bug rows: keep_mounts=False must isolate from cluster mounts ----
1059+
1060+ def test_bug_row_1_mounts_none_keep_mounts_false_no_cluster_leak (self ):
1061+ """Sandbox default (Command.mounts=None, keep_mounts=False): no cluster mounts leak through."""
1062+ mounts = self ._run_pipeline_and_capture_mounts (command_mounts = None , keep_mounts_attr = False )
1063+ # Must be an empty list passed to get_executor so Stage C does NOT fall back to cluster mounts
1064+ assert mounts == [], f"keep_mounts=False leaked cluster mounts: { mounts } "
1065+
1066+ def test_bug_row_2_mounts_empty_keep_mounts_false_no_cluster_leak (self ):
1067+ """Sandbox with explicit empty list (Command.mounts=[], keep_mounts=False): no cluster mounts leak."""
1068+ mounts = self ._run_pipeline_and_capture_mounts (command_mounts = [], keep_mounts_attr = False )
1069+ assert mounts == [], f"keep_mounts=False leaked cluster mounts: { mounts } "
1070+
1071+ def test_bug_row_3_mounts_extra_keep_mounts_false_no_cluster_merge (self ):
1072+ """Sandbox with explicit extras (Command.mounts=[/a:/b], keep_mounts=False): extras verbatim, no cluster merge."""
1073+ mounts = self ._run_pipeline_and_capture_mounts (command_mounts = ["/a:/b" ], keep_mounts_attr = False )
1074+ assert mounts == ["/a:/b" ], f"keep_mounts=False merged cluster mounts into sandbox: { mounts } "
1075+
1076+
9251077if __name__ == "__main__" :
9261078 pytest .main ([__file__ , "-v" ])
0 commit comments