88from typing import Dict , Generator , List , Tuple
99from contextlib import contextmanager
1010import warnings
11+ import shutil
1112
1213from voir .instruments .gpu import get_gpu_info
1314
1617from ..merge import merge
1718from ..utils import select_nodes
1819from .executors import execute_command
19- from ..system import option , DockerConfig
20+ from ..system import option , DockerConfig , SlurmConfig
2021
2122
2223def clone_with (cfg , new_cfg ):
@@ -330,6 +331,18 @@ def __init__(self, cmd: Command, **kwargs):
330331 super ().__init__ (cmd , * args )
331332
332333
class Srun(WrapperCommand):
    """Wrap a command so it is launched through Slurm's ``srun``.

    Args:
        cmd: the command to wrap.
        node_count: number of nodes to allocate (``--nodes``).
        task_per_node: tasks launched on each node (``--ntasks-per-node``).
        **kwargs: accepted for interface compatibility; currently unused.
    """

    def __init__(self, cmd: Command, node_count=1, task_per_node=1, **kwargs):
        # Fixed flag typo: the original spelled "--natasks-per-node",
        # which srun rejects; the correct flag is "--ntasks-per-node".
        args = [
            "srun",
            f"--ntasks-per-node={task_per_node}",
            f"--nodes={node_count}",
        ]
        super().__init__(cmd, *args)
345+
def is_inside_docker():
    """Return the value of the ``MILABENCH_DOCKER`` environment variable.

    A non-``None`` value indicates we are running inside a container
    (presumably the variable is exported by the docker entrypoint —
    confirm against the image definition).
    """
    return os.environ.get("MILABENCH_DOCKER")
335348
@@ -638,6 +651,16 @@ def node_address(node):
638651 return ip or host
639652
640653
def use_slurm_if_available():
    """Decide whether jobs should be launched through Slurm.

    Returns ``True`` only when the configuration enables Slurm *and* both
    the ``srun`` and ``sbatch`` binaries are found on the ``PATH``.

    Raises:
        RuntimeError: when the configuration asks for Slurm but the Slurm
            tools are not installed.
    """
    enabled = SlurmConfig().enabled
    available = all(shutil.which(tool) for tool in ("srun", "sbatch"))

    if enabled and not available:
        raise RuntimeError("Configuration asks for slurm but slurm is not available")

    return enabled and available
662+
663+
641664class ForeachNode (ListCommand ):
642665 def __init__ (self , executor : Command , ** kwargs ) -> None :
643666 super ().__init__ (None , ** kwargs )
@@ -665,6 +688,13 @@ def make_new_node_executor(self, rank, node, base):
665688
    def single_node(self):
        """Return the wrapped executor used for a single-node run."""
        return self.executor
691+
692+ def node_count (self ):
693+ config = self .executor .pack .config
694+ return len (config ["system" ]["nodes" ])
695+
    def task_per_node(self):
        """Number of tasks to launch on each node (fixed at 1 here)."""
        return 1
668698
669699 @property
670700 def executors (self ):
@@ -692,9 +722,7 @@ def executors(self):
692722 )
693723
694724 bench_cmd = self .make_new_node_executor (rank , node , self .executor )
695-
696725 docker_cmd = DockerRunCommand (bench_cmd , DockerConfig (** config ["system" ].get ("docker" , {})))
697-
698726 worker = SSHCommand (
699727 host = node_address (node ),
700728 user = node ["user" ],
@@ -703,6 +731,13 @@ def executors(self):
703731 executor = docker_cmd ,
704732 ** options
705733 )
734+
        #
        # When using slurm, slurm will launch all those jobs for us
        #
738+ if use_slurm_if_available ():
739+ return [Srun (docker_cmd , self .node_count (), self .task_per_node ())]
740+
706741 executors .append (worker )
707742 return executors
708743
0 commit comments