@@ -42,7 +42,7 @@ def write_job_script(config, machine, target_cores, min_cores, work_dir,
     cores = np.sqrt(target_cores * min_cores)
     nodes = int(np.ceil(cores / cores_per_node))
 
-    partition, qos, constraint, wall_time = get_slurm_options(
+    partition, qos, constraint, gpus_per_node, wall_time = get_slurm_options(
         config, machine, nodes)
 
     job_name = config.get('job', 'job_name')
@@ -58,7 +58,7 @@ def write_job_script(config, machine, target_cores, min_cores, work_dir,
     text = template.render(job_name=job_name, account=account,
                            nodes=f'{nodes}', wall_time=wall_time, qos=qos,
                            partition=partition, constraint=constraint,
-                           suite=suite)
+                           gpus_per_node=gpus_per_node, suite=suite)
     text = clean_up_whitespace(text)
     if suite == '':
         script_filename = 'job_script.sh'
@@ -95,6 +95,9 @@ def get_slurm_options(config, machine, nodes):
     constraint : str
         Slurm constraint
 
+    gpus_per_node : str
+        The number of GPUs per node (if any)
+
     wall_time : str
         Slurm wall time
     """
@@ -131,9 +134,14 @@ def get_slurm_options(config, machine, nodes):
131134 else :
132135 constraint = ''
133136
137+ if config .has_option ('parallel' , 'gpus_per_node' ):
138+ gpus_per_node = config .get ('parallel' , 'gpus_per_node' )
139+ else :
140+ gpus_per_node = ''
141+
134142 wall_time = config .get ('job' , 'wall_time' )
135143
136- return partition , qos , constraint , wall_time
144+ return partition , qos , constraint , gpus_per_node , wall_time
137145
138146
139147def clean_up_whitespace (text ):
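For context, the new gpus_per_node value is read from the [parallel] section of the config and passed through to template.render(), so it only takes effect if the batch-script template references it. Below is a minimal, self-contained sketch of how a Jinja2 template could consume the rendered variable; the template text and the --gpus-per-node directive are illustrative assumptions, not the repository's actual template.

    from jinja2 import Template

    # Hypothetical batch-script template fragment (not the repository's real
    # template).  An empty gpus_per_node string is falsy, so the GPU directive
    # is simply omitted on CPU-only machines.
    template = Template(
        '#!/bin/bash\n'
        '#SBATCH --job-name={{ job_name }}\n'
        '#SBATCH --nodes={{ nodes }}\n'
        '{% if gpus_per_node %}'
        '#SBATCH --gpus-per-node={{ gpus_per_node }}\n'
        '{% endif %}'
        '#SBATCH --time={{ wall_time }}\n'
    )

    # Mirrors the template.render(...) call in write_job_script()
    print(template.render(job_name='example_job', nodes='4',
                          wall_time='01:00:00', gpus_per_node='4'))

Rendering with gpus_per_node='' drops the GPU line entirely, which matches the fallback of an empty string when the config option is absent.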