2727from cloudai .schema .test_template .jax_toolbox .slurm_command_gen_strategy import JaxToolboxSlurmCommandGenStrategy
2828from cloudai .schema .test_template .jax_toolbox .template import JaxToolbox
2929from cloudai .schema .test_template .nccl_test .slurm_command_gen_strategy import NcclTestSlurmCommandGenStrategy
30+ from cloudai .schema .test_template .nemo_launcher .slurm_command_gen_strategy import NeMoLauncherSlurmCommandGenStrategy
31+ from cloudai .schema .test_template .nemo_launcher .template import NeMoLauncher
3032from cloudai .schema .test_template .sleep .slurm_command_gen_strategy import SleepSlurmCommandGenStrategy
3133from cloudai .schema .test_template .sleep .template import Sleep
3234from cloudai .schema .test_template .ucc_test .slurm_command_gen_strategy import UCCTestSlurmCommandGenStrategy
3335from cloudai .systems import SlurmSystem
3436from cloudai .test_definitions .gpt import GPTCmdArgs , GPTTestDefinition
3537from cloudai .test_definitions .grok import GrokCmdArgs , GrokTestDefinition
3638from cloudai .test_definitions .nccl import NCCLCmdArgs , NCCLTestDefinition
39+ from cloudai .test_definitions .nemo_launcher import NeMoLauncherCmdArgs , NeMoLauncherTestDefinition
3740from cloudai .test_definitions .sleep import SleepCmdArgs , SleepTestDefinition
3841from cloudai .test_definitions .ucc import UCCCmdArgs , UCCTestDefinition
3942
@@ -91,7 +94,9 @@ def partial_tr(slurm_system: SlurmSystem) -> partial[TestRun]:
9194 return partial (TestRun , num_nodes = 1 , nodes = [], output_path = slurm_system .output_path )
9295
9396
94- @pytest .fixture (params = ["ucc" , "nccl" , "sleep" , "gpt-pre-test" , "gpt-no-hook" , "grok-pre-test" , "grok-no-hook" ])
97+ @pytest .fixture (
98+ params = ["ucc" , "nccl" , "sleep" , "gpt-pre-test" , "gpt-no-hook" , "grok-pre-test" , "grok-no-hook" , "nemo-launcher" ]
99+ )
95100def test_req (request , slurm_system : SlurmSystem , partial_tr : partial [TestRun ]) -> tuple [TestRun , str , Optional [str ]]:
96101 if request .param == "ucc" :
97102 tr = partial_tr (
@@ -211,6 +216,25 @@ def test_req(request, slurm_system: SlurmSystem, partial_tr: partial[TestRun]) -
211216 tr .pre_test = TestScenario (name = f"{ pre_test_tr .name } NCCL pre-test" , test_runs = [pre_test_tr ])
212217
213218 return (tr , f"{ request .param } .sbatch" , "grok.run" )
219+ elif request .param == "nemo-launcher" :
220+ tr = partial_tr (
221+ name = "nemo-launcher" ,
222+ test = Test (
223+ test_definition = NeMoLauncherTestDefinition (
224+ name = "nemo-launcher" ,
225+ description = "nemo-launcher" ,
226+ test_template_name = "nemo-launcher" ,
227+ cmd_args = NeMoLauncherCmdArgs (),
228+ ),
229+ test_template = NeMoLauncher (slurm_system , name = "nemo-launcher" ),
230+ ),
231+ )
232+ tr .test .test_template .command_gen_strategy = NeMoLauncherSlurmCommandGenStrategy (
233+ slurm_system , tr .test .test_definition .cmd_args_dict
234+ )
235+ tr .test .test_template .command_gen_strategy .job_name = Mock (return_value = "job_name" )
236+
237+ return (tr , "nemo-launcher.sbatch" , None )
214238
215239 raise ValueError (f"Unknown test: { request .param } " )
216240
@@ -221,10 +245,14 @@ def test_sbatch_generation(slurm_system: SlurmSystem, test_req: tuple[TestRun, s
221245 tr = test_req [0 ]
222246
223247 sbatch_script = tr .test .test_template .gen_exec_command (tr ).split ()[- 1 ]
248+ ref = (Path (__file__ ).parent / "ref_data" / test_req [1 ]).read_text ().strip ()
249+ if "nemo-launcher" in test_req [1 ]:
250+ sbatch_script = slurm_system .output_path / "generated_command.sh"
251+ ref = ref .replace ("__OUTPUT_DIR__" , str (slurm_system .output_path .parent ))
252+ else :
253+ ref = ref .replace ("__OUTPUT_DIR__" , str (slurm_system .output_path )).replace ("__JOB_NAME__" , "job_name" )
224254
225255 curr = Path (sbatch_script ).read_text ().strip ()
226- ref = (Path (__file__ ).parent / "ref_data" / test_req [1 ]).read_text ().strip ()
227- ref = ref .replace ("__OUTPUT_DIR__" , str (slurm_system .output_path )).replace ("__JOB_NAME__" , "job_name" )
228256
229257 assert curr == ref
230258
0 commit comments