Skip to content

Commit ca4286e

Browse files
committed
Merge branch 'main' into am/cmd-gen-use-tr-member
2 parents 3553cd5 + b6491d3 commit ca4286e

File tree

3 files changed

+5
-13
lines changed

3 files changed

+5
-13
lines changed

src/cloudai/_core/base_runner.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,7 @@ async def submit_test(self, tr: TestRun):
115115
exit(1)
116116

117117
def on_job_submit(self, tr: TestRun) -> None:
118-
cmd_gen = self.get_cmd_gen_strategy(self.system, tr)
119-
cmd_gen.store_test_run()
118+
return
120119

121120
async def delayed_submit_test(self, tr: TestRun, delay: int = 5):
122121
"""

src/cloudai/systems/slurm/slurm_runner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ def _submit_test(self, tr: TestRun) -> SlurmJob:
7373
logging.info(f"Submitted slurm job: {job_id}")
7474
return SlurmJob(tr, id=job_id)
7575

76+
def on_job_submit(self, tr: TestRun) -> None:
77+
cmd_gen = self.get_cmd_gen_strategy(self.system, tr)
78+
cmd_gen.store_test_run()
79+
7680
def on_job_completion(self, job: BaseJob) -> None:
7781
logging.debug(f"Job completion callback for job {job.id}")
7882
self.system.complete_job(cast(SlurmJob, job))

tests/test_acceptance.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,6 @@ def build_special_test_run(
188188
),
189189
extra_env_vars={"COMBINE_THRESHOLD": "1"},
190190
),
191-
# JaxToolboxSlurmCommandGenStrategy,
192191
)
193192
elif "grok" in param:
194193
test_type = "grok"
@@ -205,7 +204,6 @@ def build_special_test_run(
205204
),
206205
extra_env_vars={"COMBINE_THRESHOLD": "1"},
207206
),
208-
# JaxToolboxSlurmCommandGenStrategy,
209207
)
210208
elif "nemo-run" in param:
211209
test_type = "nemo-run"
@@ -221,7 +219,6 @@ def build_special_test_run(
221219
docker_image_url="nvcr.io/nvidia/nemo:24.09", task="pretrain", recipe_name="llama_3b"
222220
),
223221
),
224-
# NeMoRunSlurmCommandGenStrategy,
225222
)
226223
elif "nemo-launcher" in param:
227224
test_type = "nemo-launcher"
@@ -280,21 +277,18 @@ def test_req(request, slurm_system: SlurmSystem, partial_tr: partial[TestRun]) -
280277
test_template_name="ucc",
281278
cmd_args=UCCCmdArgs(docker_image_url="nvcr.io/nvidia/pytorch:24.02-py3"),
282279
),
283-
# UCCTestSlurmCommandGenStrategy,
284280
),
285281
"nccl": lambda: create_test_run(
286282
partial_tr,
287283
slurm_system,
288284
"nccl",
289285
NCCLTestDefinition(name="nccl", description="nccl", test_template_name="nccl", cmd_args=NCCLCmdArgs()),
290-
# NcclTestSlurmCommandGenStrategy,
291286
),
292287
"sleep": lambda: create_test_run(
293288
partial_tr,
294289
slurm_system,
295290
"sleep",
296291
SleepTestDefinition(name="sleep", description="sleep", test_template_name="sleep", cmd_args=SleepCmdArgs()),
297-
# SleepSlurmCommandGenStrategy,
298292
),
299293
"slurm_container": lambda: create_test_run(
300294
partial_tr,
@@ -306,7 +300,6 @@ def test_req(request, slurm_system: SlurmSystem, partial_tr: partial[TestRun]) -
306300
test_template_name="slurm_container",
307301
cmd_args=SlurmContainerCmdArgs(docker_image_url="https://docker/url", cmd="pwd ; ls"),
308302
),
309-
# SlurmContainerCommandGenStrategy,
310303
),
311304
"megatron-run": lambda: create_test_run(
312305
partial_tr,
@@ -325,7 +318,6 @@ def test_req(request, slurm_system: SlurmSystem, partial_tr: partial[TestRun]) -
325318
),
326319
extra_container_mounts=["$PWD"],
327320
),
328-
# MegatronRunSlurmCommandGenStrategy,
329321
),
330322
"nemo-run": lambda: create_test_run(
331323
partial_tr,
@@ -341,7 +333,6 @@ def test_req(request, slurm_system: SlurmSystem, partial_tr: partial[TestRun]) -
341333
recipe_name="llama_3b",
342334
),
343335
),
344-
# NeMoRunSlurmCommandGenStrategy,
345336
),
346337
"triton-inference": lambda: create_test_run(
347338
partial_tr,
@@ -358,7 +349,6 @@ def test_req(request, slurm_system: SlurmSystem, partial_tr: partial[TestRun]) -
358349
tokenizer="tok",
359350
),
360351
),
361-
# TritonInferenceSlurmCommandGenStrategy,
362352
),
363353
"nixl_bench": lambda: create_test_run(
364354
partial_tr,
@@ -375,7 +365,6 @@ def test_req(request, slurm_system: SlurmSystem, partial_tr: partial[TestRun]) -
375365
path_to_benchmark="./nixlbench",
376366
),
377367
),
378-
# NIXLBenchSlurmCommandGenStrategy,
379368
),
380369
"ai-dynamo": lambda: create_test_run(
381370
partial_tr,

0 commit comments

Comments
 (0)