Skip to content

Commit ad8a9eb

Browse files
authored
Merge pull request #671 from NVIDIA/am/single-sbatch-trajectory
Write trajectory file for DSE jobs in single-sbatch mode
2 parents 50a8ef3 + e9a295c commit ad8a9eb

File tree

4 files changed

+44
-29
lines changed

4 files changed

+44
-29
lines changed

src/cloudai/cli/handlers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace):
121121

122122
for tr in runner.runner.test_scenario.test_runs:
123123
test_run = copy.deepcopy(tr)
124-
env = CloudAIGymEnv(test_run=test_run, runner=runner)
124+
env = CloudAIGymEnv(test_run=test_run, runner=runner.runner)
125125
agent_type = test_run.test.test_definition.agent
126126

127127
agent_class = registry.agents_map.get(agent_type)

src/cloudai/configurator/cloudai_gym.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import logging
2121
from typing import Any, Dict, Optional, Tuple
2222

23-
from cloudai.core import METRIC_ERROR, Registry, Runner, TestRun
23+
from cloudai.core import METRIC_ERROR, BaseRunner, Registry, TestRun
2424
from cloudai.util.lazy_imports import lazy
2525

2626
from .base_gym import BaseGym
@@ -33,13 +33,13 @@ class CloudAIGymEnv(BaseGym):
3333
Uses the TestRun object and actual runner methods to execute jobs.
3434
"""
3535

36-
def __init__(self, test_run: TestRun, runner: Runner):
36+
def __init__(self, test_run: TestRun, runner: BaseRunner):
3737
"""
3838
Initialize the Gym environment using the TestRun object.
3939
4040
Args:
4141
test_run (TestRun): A test run object that encapsulates cmd_args, extra_cmd_args, etc.
42-
runner (Runner): The runner object to execute jobs.
42+
runner (BaseRunner): The runner object to execute jobs.
4343
"""
4444
self.test_run = test_run
4545
self.original_test_run = copy.deepcopy(test_run) # Preserve clean state for DSE
@@ -106,20 +106,20 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
106106

107107
logging.info(f"Running step {self.test_run.step} with action {action}")
108108
new_tr = copy.deepcopy(self.test_run)
109-
new_tr.output_path = self.runner.runner.get_job_output_path(new_tr)
110-
self.runner.runner.test_scenario.test_runs = [new_tr]
109+
new_tr.output_path = self.runner.get_job_output_path(new_tr)
110+
self.runner.test_scenario.test_runs = [new_tr]
111111

112-
self.runner.runner.shutting_down = False
113-
self.runner.runner.jobs.clear()
114-
self.runner.runner.testrun_to_job_map.clear()
112+
self.runner.shutting_down = False
113+
self.runner.jobs.clear()
114+
self.runner.testrun_to_job_map.clear()
115115

116-
asyncio.run(self.runner.run())
116+
try:
117+
asyncio.run(self.runner.run())
118+
except Exception as e:
119+
logging.error(f"Error running step {self.test_run.step}: {e}")
117120

118-
if (
119-
self.runner.runner.test_scenario.test_runs
120-
and self.runner.runner.test_scenario.test_runs[0].output_path.exists()
121-
):
122-
self.test_run = self.runner.runner.test_scenario.test_runs[0]
121+
if self.runner.test_scenario.test_runs and self.runner.test_scenario.test_runs[0].output_path.exists():
122+
self.test_run = self.runner.test_scenario.test_runs[0]
123123
else:
124124
self.test_run = copy.deepcopy(self.original_test_run)
125125
self.test_run.step = new_tr.step
@@ -179,7 +179,7 @@ def get_observation(self, action: Any) -> list:
179179

180180
observation = []
181181
for metric in all_metrics:
182-
v = self.test_run.get_metric_value(self.runner.runner.system, metric)
182+
v = self.test_run.get_metric_value(self.runner.system, metric)
183183
if v == METRIC_ERROR:
184184
v = -1.0
185185
observation.append(v)
@@ -196,10 +196,7 @@ def write_trajectory(self, step: int, action: Any, reward: float, observation: l
196196
observation (list): The observation after taking the action.
197197
"""
198198
trajectory_file_path = (
199-
self.runner.runner.scenario_root
200-
/ self.test_run.name
201-
/ f"{self.test_run.current_iteration}"
202-
/ "trajectory.csv"
199+
self.runner.scenario_root / self.test_run.name / f"{self.test_run.current_iteration}" / "trajectory.csv"
203200
)
204201

205202
file_exists = trajectory_file_path.exists()

src/cloudai/systems/slurm/single_sbatch_runner.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@
2121
from pathlib import Path
2222
from typing import Generator, Optional, cast
2323

24+
from cloudai.configurator.cloudai_gym import CloudAIGymEnv
2425
from cloudai.core import JobIdRetrievalError, System, TestRun, TestScenario
25-
from cloudai.systems.slurm.slurm_metadata import SlurmJobMetadata, SlurmStepMetadata
2626
from cloudai.util import CommandShell, format_time_limit, parse_time_limit
2727

2828
from .slurm_command_gen_strategy import SlurmCommandGenStrategy
29+
from .slurm_metadata import SlurmJobMetadata, SlurmStepMetadata
2930
from .slurm_runner import SlurmJob, SlurmRunner
3031
from .slurm_system import SlurmSystem
3132

@@ -194,8 +195,25 @@ async def run(self):
194195
is_completed = True if self.mode == "dry-run" else self.system.is_job_completed(job)
195196
await asyncio.sleep(self.system.monitor_interval)
196197

198+
self.handle_dse()
199+
197200
self.on_job_completion(job)
198201

202+
def handle_dse(self):
203+
for tr in self.test_scenario.test_runs:
204+
if not tr.is_dse_job:
205+
continue
206+
207+
for idx, combination in enumerate(tr.all_combinations):
208+
next_tr = tr.apply_params_set(combination)
209+
next_tr.step = idx + 1
210+
next_tr.output_path = self.get_job_output_path(next_tr)
211+
212+
gym = CloudAIGymEnv(next_tr, self)
213+
observation = gym.get_observation({})
214+
reward = gym.compute_reward(observation)
215+
gym.write_trajectory(idx, combination, reward, observation)
216+
199217
def _submit_test(self, tr: TestRun) -> SlurmJob:
200218
with open(self.scenario_root / "cloudai_sbatch_script.sh", "w") as f:
201219
f.write(self.gen_sbatch_content())

tests/test_cloudaigym.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pytest
2222

2323
from cloudai.configurator import CloudAIGymEnv, GridSearchAgent
24-
from cloudai.core import Runner, Test, TestRun, TestScenario, TestTemplateStrategy
24+
from cloudai.core import BaseRunner, Runner, Test, TestRun, TestScenario, TestTemplateStrategy
2525
from cloudai.systems.slurm import SlurmSystem
2626
from cloudai.workloads.nemo_run import (
2727
Data,
@@ -45,7 +45,7 @@ def nemorun() -> NeMoRunTestDefinition:
4545

4646

4747
@pytest.fixture
48-
def setup_env(slurm_system: SlurmSystem, nemorun: NeMoRunTestDefinition) -> tuple[TestRun, Runner]:
48+
def setup_env(slurm_system: SlurmSystem, nemorun: NeMoRunTestDefinition) -> tuple[TestRun, BaseRunner]:
4949
tdef = nemorun.model_copy(deep=True)
5050
tdef.cmd_args.trainer = Trainer(
5151
max_steps=[1000, 2000],
@@ -81,10 +81,10 @@ def setup_env(slurm_system: SlurmSystem, nemorun: NeMoRunTestDefinition) -> tupl
8181

8282
runner = Runner(mode="dry-run", system=slurm_system, test_scenario=test_scenario)
8383

84-
return test_run, runner
84+
return test_run, runner.runner
8585

8686

87-
def test_observation_space(setup_env):
87+
def test_observation_space(setup_env: tuple[TestRun, BaseRunner]):
8888
test_run, runner = setup_env
8989
env = CloudAIGymEnv(test_run=test_run, runner=runner)
9090
observation_space = env.define_observation_space()
@@ -147,7 +147,7 @@ def test_compute_reward_invalid():
147147
assert "Available functions: ['inverse', 'negative', 'identity']" in str(exc_info.value)
148148

149149

150-
def test_tr_output_path(setup_env: tuple[TestRun, Runner]):
150+
def test_tr_output_path(setup_env: tuple[TestRun, BaseRunner]):
151151
test_run, runner = setup_env
152152
test_run.test.test_definition.cmd_args.data.global_batch_size = 8 # avoid constraint check failure
153153
env = CloudAIGymEnv(test_run=test_run, runner=runner)
@@ -160,7 +160,7 @@ def test_tr_output_path(setup_env: tuple[TestRun, Runner]):
160160
assert env.test_run.output_path.name == "42"
161161

162162

163-
def test_action_space(nemorun: NeMoRunTestDefinition, setup_env: tuple[TestRun, Runner]):
163+
def test_action_space(nemorun: NeMoRunTestDefinition, setup_env: tuple[TestRun, BaseRunner]):
164164
tr, _ = setup_env
165165
nemorun.cmd_args.trainer = Trainer(
166166
max_steps=[1000, 2000], strategy=TrainerStrategy(tensor_model_parallel_size=[1, 2])
@@ -185,7 +185,7 @@ def test_action_space(nemorun: NeMoRunTestDefinition, setup_env: tuple[TestRun,
185185

186186

187187
@pytest.mark.parametrize("num_nodes", (1, [1, 2], [3]))
188-
def test_all_combinations(nemorun: NeMoRunTestDefinition, setup_env: tuple[TestRun, Runner], num_nodes: int):
188+
def test_all_combinations(nemorun: NeMoRunTestDefinition, setup_env: tuple[TestRun, BaseRunner], num_nodes: int):
189189
tr, _ = setup_env
190190
nemorun.cmd_args.trainer = Trainer(max_steps=[1000], strategy=TrainerStrategy(tensor_model_parallel_size=[1, 2]))
191191
nemorun.extra_env_vars["DSE_VAR"] = ["1", "2", "3"]
@@ -224,7 +224,7 @@ def test_all_combinations(nemorun: NeMoRunTestDefinition, setup_env: tuple[TestR
224224
assert expected in real_combinations, f"Expected {expected} in all_combinations"
225225

226226

227-
def test_all_combinations_non_dse(nemorun: NeMoRunTestDefinition, setup_env: tuple[TestRun, Runner]):
227+
def test_all_combinations_non_dse(nemorun: NeMoRunTestDefinition, setup_env: tuple[TestRun, BaseRunner]):
228228
tr, _ = setup_env
229229
tr.test.test_definition = nemorun
230230
assert len(tr.all_combinations) == 0

0 commit comments

Comments
 (0)