Skip to content

Commit d80e439

Browse files
Merge pull request #609 from NVIDIA/andrei-bug
Fix to Buggy Implemention of PR 589
2 parents 520c79f + 2389c40 commit d80e439

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

src/cloudai/_core/base_runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ async def monitor_jobs(self) -> int:
259259
f"Job {job.id} for test {job.test_run.name} failed: {job_status_result.error_message}"
260260
)
261261
logging.error(error_message)
262+
await self.handle_job_completion(job)
262263
await self.shutdown()
263264
raise JobFailureError(job.test_run.name, error_message, job_status_result.error_message)
264265
else:

src/cloudai/configurator/cloudai_gym.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(self, test_run: TestRun, runner: Runner):
4242
runner (Runner): The runner object to execute jobs.
4343
"""
4444
self.test_run = test_run
45+
self.original_test_run = copy.deepcopy(test_run) # Preserve clean state for DSE
4546
self.runner = runner
4647
self.max_steps = test_run.test.test_definition.agent_steps
4748
self.reward_function = Registry().get_reward_function(test_run.test.test_definition.agent_reward_function)
@@ -105,9 +106,24 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
105106

106107
logging.info(f"Running step {self.test_run.step} with action {action}")
107108
new_tr = copy.deepcopy(self.test_run)
109+
new_tr.output_path = self.runner.runner.get_job_output_path(new_tr)
108110
self.runner.runner.test_scenario.test_runs = [new_tr]
111+
112+
self.runner.runner.shutting_down = False
113+
self.runner.runner.jobs.clear()
114+
self.runner.runner.testrun_to_job_map.clear()
115+
109116
asyncio.run(self.runner.run())
110-
self.test_run = self.runner.runner.test_scenario.test_runs[0]
117+
118+
if (
119+
self.runner.runner.test_scenario.test_runs
120+
and self.runner.runner.test_scenario.test_runs[0].output_path.exists()
121+
):
122+
self.test_run = self.runner.runner.test_scenario.test_runs[0]
123+
else:
124+
self.test_run = copy.deepcopy(self.original_test_run)
125+
self.test_run.step = new_tr.step
126+
self.test_run.output_path = new_tr.output_path
111127

112128
observation = self.get_observation(action)
113129
reward = self.compute_reward(observation)

0 commit comments

Comments
 (0)