Skip to content

Commit 3075fac

Browse files
Merge pull request #438 from NVIDIA/am/multi-dse
Allow multiple DSE cases in a scenario
2 parents cf0d7dc + 47222ef commit 3075fac

File tree

5 files changed

+56
-33
lines changed

5 files changed

+56
-33
lines changed

conf/common/test_scenario/ucc_test.toml

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2-
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
# SPDX-License-Identifier: Apache-2.0
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,11 +19,13 @@ name = "ucc_test"
1919
[[Tests]]
2020
id = "Tests.1"
2121
test_name = "ucc_test_alltoall"
22+
time_limit = "00:20:00"
2223
num_nodes = "2"
2324

2425
[[Tests]]
2526
id = "Tests.2"
2627
test_name = "ucc_test_alltoall"
28+
time_limit = "00:20:00"
2729
num_nodes = "2"
2830
[[Tests.dependencies]]
2931
type = "start_post_comp"
@@ -32,6 +34,7 @@ num_nodes = "2"
3234
[[Tests]]
3335
id = "Tests.3"
3436
test_name = "ucc_test_alltoall"
37+
time_limit = "00:20:00"
3538
num_nodes = "2"
3639
[[Tests.dependencies]]
3740
type = "start_post_comp"
@@ -40,6 +43,7 @@ num_nodes = "2"
4043
[[Tests]]
4144
id = "Tests.4"
4245
test_name = "ucc_test_alltoall"
46+
time_limit = "00:20:00"
4347
num_nodes = "2"
4448
[[Tests.dependencies]]
4549
type = "start_post_comp"
@@ -48,6 +52,7 @@ num_nodes = "2"
4852
[[Tests]]
4953
id = "Tests.5"
5054
test_name = "ucc_test_alltoall"
55+
time_limit = "00:20:00"
5156
num_nodes = "2"
5257
[[Tests.dependencies]]
5358
type = "start_post_comp"

src/cloudai/_core/base_runner.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -199,6 +199,7 @@ def get_job_output_path(self, tr: TestRun) -> Path:
199199
self.scenario_root.mkdir()
200200

201201
job_output_path = self.scenario_root / tr.name / str(tr.current_iteration)
202+
# A DSE run has to be detected via the step number here, because the test_definition object is no longer a DSE object at this point.
202203
if tr.step > 0:
203204
job_output_path = job_output_path / str(tr.step)
204205

@@ -272,14 +273,17 @@ async def handle_job_completion(self, completed_job: BaseJob):
272273
Args:
273274
completed_job (BaseJob): The job that has just been completed.
274275
"""
275-
logging.info(f"Job completed: {completed_job.test_run.name}")
276+
logging.info(
277+
f"Job completed: {completed_job.test_run.name} "
278+
f"(iteration {completed_job.test_run.current_iteration+1} of {completed_job.test_run.iterations})"
279+
)
276280

277281
self.jobs.remove(completed_job)
278282
del self.testrun_to_job_map[completed_job.test_run]
279283

280284
if completed_job.test_run.step <= 0:
281-
completed_job.test_run.current_iteration += 1
282285
if not completed_job.terminated_by_dependency and completed_job.test_run.has_more_iterations():
286+
completed_job.test_run.current_iteration += 1
283287
msg = f"Re-running job for iteration {completed_job.test_run.current_iteration}"
284288
logging.info(msg)
285289
await self.submit_test(completed_job.test_run)

src/cloudai/_core/configurator/cloudai_gym.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616

1717
import asyncio
18+
import copy
1819
import csv
1920
import logging
2021
from typing import Any, Dict, Optional, Tuple
@@ -43,7 +44,6 @@ def __init__(self, test_run: TestRun, runner: Runner):
4344
"""
4445
self.test_run = test_run
4546
self.runner = runner
46-
self.test_scenario = runner.runner.test_scenario
4747
self.max_steps = test_run.test.test_definition.agent_steps
4848
super().__init__()
4949

@@ -134,15 +134,17 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
134134
if not self.test_run.test.test_definition.constraint_check:
135135
logging.info("Constraint check failed. Skipping step.")
136136
return [-1.0], -1.0, True, {}
137-
logging.info(f"Running step {self.test_run.current_iteration} with action {action}")
137+
138+
logging.info(f"Running step {self.test_run.step} with action {action}")
139+
self.runner.runner.test_scenario.test_runs = [copy.deepcopy(self.test_run)]
138140
asyncio.run(self.runner.run())
139141

140142
observation = self.get_observation(action)
141143
reward = self.compute_reward(observation)
142144
done = False
143145
info = {}
144146

145-
self.write_trajectory(self.test_run.current_iteration, action, reward, observation)
147+
self.write_trajectory(self.test_run.step, action, reward, observation)
146148

147149
return observation, reward, done, info
148150

@@ -222,11 +224,15 @@ def write_trajectory(self, step: int, action: Any, reward: float, observation: l
222224
reward (float): The reward received for the action.
223225
observation (list): The observation after taking the action.
224226
"""
225-
output_path = self.runner.runner.scenario_root
226-
subdir = next(output_path.iterdir())
227-
trajectory_file_path = subdir / f"{self.test_run.current_iteration}" / "trajectory.csv"
227+
trajectory_file_path = (
228+
self.runner.runner.scenario_root
229+
/ self.test_run.name
230+
/ f"{self.test_run.current_iteration}"
231+
/ "trajectory.csv"
232+
)
228233

229234
file_exists = trajectory_file_path.exists()
235+
logging.debug(f"Writing trajectory into {trajectory_file_path} (exists: {file_exists})")
230236

231237
with open(trajectory_file_path, mode="a", newline="") as file:
232238
writer = csv.writer(file)

src/cloudai/_core/test_scenario.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def has_more_iterations(self) -> bool:
7979
Returns
8080
bool: True if more iterations are pending, False otherwise.
8181
"""
82-
return self.current_iteration < self.iterations
82+
return self.current_iteration + 1 < self.iterations
8383

8484
@property
8585
def metric_reporter(self) -> Optional[Type["ReportGenerationStrategy"]]:

src/cloudai/cli/handlers.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import argparse
1818
import asyncio
19+
import copy
1920
import logging
2021
import signal
2122
from pathlib import Path
@@ -86,31 +87,32 @@ def handle_install_and_uninstall(args: argparse.Namespace) -> int:
8687

8788

8889
def handle_dse_job(runner: Runner, args: argparse.Namespace):
89-
test_run = next(iter(runner.runner.test_scenario.test_runs))
90-
env = CloudAIGymEnv(test_run=test_run, runner=runner)
9190
registry = Registry()
9291

93-
agent_type = test_run.test.test_definition.agent
94-
95-
agent_class = registry.agents_map.get(agent_type)
96-
if agent_class is None:
97-
logging.error(
98-
f"No agent available for type: {agent_type}. Please make sure {agent_type} "
99-
f"is a valid agent type. Available agents: {registry.agents_map.keys()}"
100-
)
101-
exit(1)
92+
for tr in runner.runner.test_scenario.test_runs:
93+
test_run = copy.deepcopy(tr)
94+
env = CloudAIGymEnv(test_run=test_run, runner=runner)
95+
agent_type = test_run.test.test_definition.agent
96+
97+
agent_class = registry.agents_map.get(agent_type)
98+
if agent_class is None:
99+
logging.error(
100+
f"No agent available for type: {agent_type}. Please make sure {agent_type} "
101+
f"is a valid agent type. Available agents: {registry.agents_map.keys()}"
102+
)
103+
continue
102104

103-
agent = agent_class(env)
104-
for step in range(agent.max_steps):
105-
result = agent.select_action()
106-
if result is None:
107-
break
108-
step, action = result
109-
test_run.step = step
110-
observation, reward, done, info = env.step(action)
111-
feedback = {"trial_index": step, "value": reward}
112-
agent.update_policy(feedback)
113-
logging.info(f"Step {step}: Observation: {observation}, Reward: {reward}")
105+
agent = agent_class(env)
106+
for step in range(agent.max_steps):
107+
result = agent.select_action()
108+
if result is None:
109+
break
110+
step, action = result
111+
test_run.step = step
112+
observation, reward, done, info = env.step(action)
113+
feedback = {"trial_index": step, "value": reward}
114+
agent.update_policy(feedback)
115+
logging.info(f"Step {step}: Observation: {observation}, Reward: {reward}")
114116

115117

116118
def handle_non_dse_job(runner: Runner, args: argparse.Namespace) -> None:
@@ -187,8 +189,14 @@ def handle_dry_run_and_run(args: argparse.Namespace) -> int:
187189
runner = Runner(args.mode, system, test_scenario)
188190
register_signal_handlers(runner.cancel_on_signal)
189191

192+
all_dse = all(tr.test.test_definition.is_dse_job for tr in test_scenario.test_runs)
193+
190194
if any(tr.test.test_definition.is_dse_job for tr in test_scenario.test_runs):
191-
handle_dse_job(runner, args)
195+
if all_dse:
196+
handle_dse_job(runner, args)
197+
else:
198+
logging.error("Mixing DSE and non-DSE jobs is not allowed.")
199+
return 1
192200
else:
193201
handle_non_dse_job(runner, args)
194202

0 commit comments

Comments
 (0)