Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mache/parallel/pbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,19 +91,20 @@ def _get_parallel_args(
if nodes is None:
raise ValueError('Node count is not set for the pbs system.')

tasks_per_node = _ceil_division(ntasks, nodes)
max_mpi_tasks_per_node = self.get_config_int('max_mpi_tasks_per_node')
if max_mpi_tasks_per_node is None:
raise ValueError(
'max_mpi_tasks_per_node must be set in the config for the pbs '
'system.'
)
tasks_per_node = _ceil_division(ntasks, nodes)
if tasks_per_node > max_mpi_tasks_per_node:
raise ValueError(
f'Calculated tasks_per_node ({tasks_per_node}) exceeds the '
f'max_mpi_tasks_per_node ({max_mpi_tasks_per_node}). You '
f'likely need to allocate more nodes.'
)
tasks_per_node = min(ntasks, max_mpi_tasks_per_node)

parallel_args = [
'-n',
Expand Down
3 changes: 2 additions & 1 deletion mache/parallel/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,13 @@ def _get_parallel_args(
f'max_mpi_tasks_per_node ({max_mpi_tasks_per_node}). You '
f'likely need to allocate more nodes.'
)
launch_nodes = _ceil_division(ntasks, max_mpi_tasks_per_node)

parallel_args = [
'-c',
f'{cpus_per_task}',
'-N',
f'{nodes}',
f'{launch_nodes}',
'-n',
f'{ntasks}',
]
Expand Down
94 changes: 94 additions & 0 deletions tests/test_parallel_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,52 @@ def test_slurm_custom_gpus_per_task_flag(monkeypatch):
assert args[index + 1] == '1'


def test_slurm_uses_minimum_nodes_for_single_task(monkeypatch):
    """A single MPI task should request exactly one node via ``-N``."""
    options = {
        'parallel_executable': 'srun --label',
        'cores_per_node': '32',
        'max_mpi_tasks_per_node': '4',
    }
    config = _get_config(options)

    # Pretend we are inside a 6-node Slurm allocation.
    monkeypatch.setattr(
        'mache.parallel.slurm._get_subprocess_int', lambda args: 6
    )
    monkeypatch.setenv('SLURM_JOB_ID', '12345')

    parallel_args = SlurmSystem(config)._get_parallel_args(
        cpus_per_task=1, gpus_per_task=1, ntasks=1
    )

    node_flag = parallel_args.index('-N')
    assert parallel_args[node_flag + 1] == '1'


def test_slurm_uses_minimum_nodes_for_task_count(monkeypatch):
    """With 17 tasks and 4 tasks per node, ``-N`` should be ceil(17/4) = 5."""
    options = {
        'parallel_executable': 'srun --label',
        'cores_per_node': '32',
        'max_mpi_tasks_per_node': '4',
    }
    config = _get_config(options)

    # Pretend we are inside a 6-node Slurm allocation.
    monkeypatch.setattr(
        'mache.parallel.slurm._get_subprocess_int', lambda args: 6
    )
    monkeypatch.setenv('SLURM_JOB_ID', '12345')

    parallel_args = SlurmSystem(config)._get_parallel_args(
        cpus_per_task=1, gpus_per_task=0, ntasks=17
    )

    node_flag = parallel_args.index('-N')
    assert parallel_args[node_flag + 1] == '5'


def test_pbs_skips_gpu_flag_when_not_configured(monkeypatch):
config = _get_config(
{
Expand Down Expand Up @@ -110,3 +156,51 @@ def test_pbs_uses_configured_gpu_flag(monkeypatch):
assert '--gpus-per-task' in args
index = args.index('--gpus-per-task')
assert args[index + 1] == '1'


def test_pbs_uses_minimum_nodes_for_single_task(monkeypatch):
    """A single MPI task should map to ``--ppn 1`` on PBS."""
    options = {
        'parallel_executable': 'mpiexec --label',
        'cores_per_node': '32',
        'max_mpi_tasks_per_node': '4',
        'cpus_per_task_flag': '--depth',
    }
    config = _get_config(options)

    # Pretend qstat reports a 6-node PBS allocation.
    monkeypatch.setattr(
        PbsSystem, '_get_node_count_from_qstat', lambda self: 6
    )
    monkeypatch.setenv('PBS_JOBID', '12345.server')

    parallel_args = PbsSystem(config)._get_parallel_args(
        cpus_per_task=1, gpus_per_task=0, ntasks=1
    )

    ppn_flag = parallel_args.index('--ppn')
    assert parallel_args[ppn_flag + 1] == '1'


def test_pbs_uses_minimum_nodes_for_task_count(monkeypatch):
    """With 17 tasks and a cap of 4 per node, ``--ppn`` should be capped at 4."""
    options = {
        'parallel_executable': 'mpiexec --label',
        'cores_per_node': '32',
        'max_mpi_tasks_per_node': '4',
        'cpus_per_task_flag': '--depth',
    }
    config = _get_config(options)

    # Pretend qstat reports a 6-node PBS allocation.
    monkeypatch.setattr(
        PbsSystem, '_get_node_count_from_qstat', lambda self: 6
    )
    monkeypatch.setenv('PBS_JOBID', '12345.server')

    parallel_args = PbsSystem(config)._get_parallel_args(
        cpus_per_task=1, gpus_per_task=0, ntasks=17
    )

    ppn_flag = parallel_args.index('--ppn')
    assert parallel_args[ppn_flag + 1] == '4'
Loading