Skip to content

Commit 49cef1d

Browse files
committed
SlurmSweep should spawn as many jobs as there are workers; the output folders of SlurmJobs and SlurmSweeps should be collected in distinct folders
1 parent 36c78f7 commit 49cef1d

1 file changed

Lines changed: 24 additions & 6 deletions

File tree

ml_project_template/runs.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import math
21
import subprocess
32
import sys
43
import tempfile
@@ -70,7 +69,14 @@ def python_command(self) -> str:
7069

7170
def run(self) -> None:
7271
"""Run the job on the cluster."""
73-
command = ["python", *self.filter_args(sys.argv), "cfg/wandb=base"]
72+
hydra_run_dir = "./outputs/runs/${now:%Y-%m-%d}/${now:%H-%M-%S-%f}"
73+
74+
command = [
75+
"python",
76+
*self.filter_args(sys.argv),
77+
"cfg/wandb=base",
78+
f"hydra.run.dir={hydra_run_dir}",
79+
]
7480

7581
function = CommandFunction(command)
7682
executor = AutoExecutor(
@@ -127,7 +133,20 @@ def run(self) -> None:
127133
parameters = {cfg_key: {"values": list(values)} for cfg_key, values in self.parameters.items()}
128134
metric = {"goal": self.metric_goal, "name": self.metric_name}
129135
program, args = sys.argv[0], self.filter_args(sys.argv[1:])
130-
command = ["${env}", "${interpreter}", "${program}", *args, "cfg/wandb=base", "${args_no_hyphens}"]
136+
137+
folder_path = get_hydra_output_dir()
138+
dummy_sweep_id = "sweep_started_" + Path(folder_path).parts[-2] + "_" + Path(folder_path).parts[-1]
139+
hydra_run_dir = "./outputs/sweeps/" + dummy_sweep_id + "/${now:%H-%M-%S-%f}"
140+
141+
command = [
142+
"${env}",
143+
"${interpreter}",
144+
"${program}",
145+
*args,
146+
"cfg/wandb=base",
147+
f"hydra.run.dir={hydra_run_dir}",
148+
"${args_no_hyphens}",
149+
]
131150

132151
sweep_config = {
133152
"program": program,
@@ -141,7 +160,7 @@ def run(self) -> None:
141160

142161
function = CommandFunction(["wandb", "agent"])
143162
executor = AutoExecutor(
144-
folder=get_hydra_output_dir(),
163+
folder=folder_path,
145164
cluster=self.cluster,
146165
slurm_python=self.python_command,
147166
)
@@ -150,8 +169,7 @@ def run(self) -> None:
150169
**self.slurm_params.to_submitit_params(),
151170
)
152171

153-
num_jobs = math.prod([len(v) for v in self.parameters.values()])
154-
jobs = executor.map_array(function, [sweep_id] * num_jobs)
172+
jobs = executor.map_array(function, [sweep_id] * self.num_workers)
155173

156174
for job in jobs:
157175
logger.info(f"Submitted job {job.job_id}")

0 commit comments

Comments
 (0)