Skip to content

Commit d70d1db

Browse files
committed
Srun support for milabench
1 parent 33d3401 commit d70d1db

File tree

7 files changed

+53
-13
lines changed

7 files changed

+53
-13
lines changed

benchmarks/fno_benchmark/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ all:
1313
install prepare single gpus nodes
1414

1515
install:
16-
milabench install $(MILABENCH_ARGS) --force
16+
milabench install $(MILABENCH_ARGS) --use-current-env
1717

1818
prepare:
1919
milabench prepare $(MILABENCH_ARGS)

benchmarks/fno_benchmark/dev.yaml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
_fno_benchmark:
32
inherits: _defaults
43
definition: .
@@ -20,25 +19,25 @@ _fno_benchmark:
2019
pic1d:
2120
inherits: _fno_benchmark
2221
argv:
23-
--configf: "{benchmark_folder}/src/config/electrostatic_plasma1d_pic.yaml"
22+
--config: "{benchmark_folder}/src/config/electrostatic_plasma1d_pic.yaml"
2423
--dataFile: "{milabench_data}/FNOBenchmark/PIC1D_electrostatic.h5"
2524

2625
pic2d:
2726
inherits: _fno_benchmark
2827
argv:
29-
--configf: "{benchmark_folder}/src/config/electrostatic_plasma2d_pic.yaml"
28+
--config: "{benchmark_folder}/src/config/electrostatic_plasma2d_pic.yaml"
3029
--dataFile: "{milabench_data}/FNOBenchmark/PIC2D_electrostatic.h5"
3130

3231

3332
rbc2d:
3433
inherits: _fno_benchmark
3534
argv:
36-
--configf: "{benchmark_folder}/src/config/rbc2d_dedalus.yaml"
35+
--config: "{benchmark_folder}/src/config/rbc2d_dedalus.yaml"
3736
--dataFile: "{milabench_data}/FNOBenchmark/RBC2D_256x64_Ra1e7_dt1e-3_update.h5"
3837

3938
rbc3d:
4039
inherits: _fno_benchmark
4140
argv:
42-
--configf: "{benchmark_folder}/src/config/rbc3d_pysdc.yaml"
41+
--config: "{benchmark_folder}/src/config/rbc3d_pysdc.yaml"
4342
--dataFile: "{milabench_data}/FNOBenchmark/RBC3D_64x64x32_Ra1e5_dt0_5_solution.h5"
4443

benchmarks/fno_benchmark/milabench-14181232.out

Lines changed: 0 additions & 1 deletion
This file was deleted.

benchmarks/llm/benchfile.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def make_new_node_executor(self, rank, node, base):
3737
return executor
3838

3939

40+
4041
class Llm(Package):
4142
# Requirements file installed by install(). It can be empty or absent.
4243
base_requirements = "requirements.in"

milabench/_version.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""This file is generated, do not modify"""
22

3-
__tag__ = "v1.2.2-8-gaab9090"
4-
__commit__ = "aab909089a077526e8658a8843ac62bfbf2a4c22"
5-
__date__ = "2025-09-19 17:36:10 +0000"
3+
__tag__ = "33d3401"
4+
__commit__ = "33d34013ba58882ed17ae89aba8210a2ad3872f3"
5+
__date__ = "2025-11-04 18:16:26 +0100"

milabench/commands/__init__.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Dict, Generator, List, Tuple
99
from contextlib import contextmanager
1010
import warnings
11+
import shutil
1112

1213
from voir.instruments.gpu import get_gpu_info
1314

@@ -16,7 +17,7 @@
1617
from ..merge import merge
1718
from ..utils import select_nodes
1819
from .executors import execute_command
19-
from ..system import option, DockerConfig
20+
from ..system import option, DockerConfig, SlurmConfig
2021

2122

2223
def clone_with(cfg, new_cfg):
@@ -330,6 +331,18 @@ def __init__(self, cmd: Command, **kwargs):
330331
super().__init__(cmd, *args)
331332

332333

334+
class Srun(WrapperCommand):
335+
"""Wrap a command so that it is launched through slurm's srun"""
336+
337+
def __init__(self, cmd: Command, node_count=1, task_per_node=1, **kwargs):
338+
args = [
339+
"srun",
340+
f"--ntasks-per-node={task_per_node}",
341+
f"--nodes={node_count}"
342+
]
343+
super().__init__(cmd, *args)
344+
345+
333346
def is_inside_docker():
334347
return os.environ.get("MILABENCH_DOCKER", None)
335348

@@ -638,6 +651,16 @@ def node_address(node):
638651
return ip or host
639652

640653

654+
def use_slurm_if_available():
655+
enabled = SlurmConfig().enabled
656+
available = shutil.which("srun") is not None and shutil.which("sbatch") is not None
657+
658+
if enabled and not available:
659+
raise RuntimeError("Configuration asks for slurm but slurm is not available")
660+
661+
return enabled and available
662+
663+
641664
class ForeachNode(ListCommand):
642665
def __init__(self, executor: Command, **kwargs) -> None:
643666
super().__init__(None, **kwargs)
@@ -665,6 +688,13 @@ def make_new_node_executor(self, rank, node, base):
665688

666689
def single_node(self):
667690
return self.executor
691+
692+
def node_count(self):
693+
config = self.executor.pack.config
694+
return len(config["system"]["nodes"])
695+
696+
def task_per_node(self):
697+
return 1
668698

669699
@property
670700
def executors(self):
@@ -692,9 +722,7 @@ def executors(self):
692722
)
693723

694724
bench_cmd = self.make_new_node_executor(rank, node, self.executor)
695-
696725
docker_cmd = DockerRunCommand(bench_cmd, DockerConfig(**config["system"].get("docker", {})))
697-
698726
worker = SSHCommand(
699727
host=node_address(node),
700728
user=node["user"],
@@ -703,6 +731,13 @@ def executors(self):
703731
executor=docker_cmd,
704732
**options
705733
)
734+
735+
#
736+
# When using slurm, slurm will launch all those jobs for us
737+
#
738+
if use_slurm_if_available():
739+
return [Srun(docker_cmd, self.node_count(), self.task_per_node())]
740+
706741
executors.append(worker)
707742
return executors
708743

milabench/system.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,11 @@ def default_device():
409409
return "cpu"
410410

411411

412+
@dataclass
413+
class SlurmConfig:
414+
enabled: int = defaultfield("slurm.enabled", bool, 0)
415+
416+
412417
@dataclass
413418
class SystemConfig:
414419
"""This is meant to be an exhaustive list of all the environment overrides"""
@@ -425,6 +430,7 @@ class SystemConfig:
425430
dash: bool = defaultfield("dash", bool, 1)
426431
noterm: bool = defaultfield("noterm", bool, 0)
427432
github: Github = field(default_factory=Github)
433+
slurm: SlurmConfig = field(default_factory=SlurmConfig)
428434

429435
use_uv: bool = defaultfield("use_uv", bool, 0)
430436

0 commit comments

Comments
 (0)