Skip to content

Commit 2af8104

Browse files
fix to buggy implemention of PR 589
1 parent 6deb7df commit 2af8104

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

src/cloudai/_core/base_runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ async def monitor_jobs(self) -> int:
259259
f"Job {job.id} for test {job.test_run.name} failed: {job_status_result.error_message}"
260260
)
261261
logging.error(error_message)
262+
await self.handle_job_completion(job)
262263
await self.shutdown()
263264
raise JobFailureError(job.test_run.name, error_message, job_status_result.error_message)
264265
else:

src/cloudai/configurator/cloudai_gym.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from cloudai.core import METRIC_ERROR, Registry, Runner, TestRun
2424
from cloudai.util.lazy_imports import lazy
25+
from cloudai._core.exceptions import JobFailureError
2526

2627
from .base_gym import BaseGym
2728

@@ -42,6 +43,7 @@ def __init__(self, test_run: TestRun, runner: Runner):
4243
runner (Runner): The runner object to execute jobs.
4344
"""
4445
self.test_run = test_run
46+
self.original_test_run = copy.deepcopy(test_run) # Preserve clean state for DSE
4547
self.runner = runner
4648
self.max_steps = test_run.test.test_definition.agent_steps
4749
self.reward_function = Registry().get_reward_function(test_run.test.test_definition.agent_reward_function)
@@ -105,9 +107,21 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
105107

106108
logging.info(f"Running step {self.test_run.step} with action {action}")
107109
new_tr = copy.deepcopy(self.test_run)
110+
new_tr.output_path = self.runner.runner.get_job_output_path(new_tr)
108111
self.runner.runner.test_scenario.test_runs = [new_tr]
112+
113+
self.runner.runner.shutting_down = False
114+
self.runner.runner.jobs.clear()
115+
self.runner.runner.testrun_to_job_map.clear()
116+
109117
asyncio.run(self.runner.run())
110-
self.test_run = self.runner.runner.test_scenario.test_runs[0]
118+
119+
if self.runner.runner.test_scenario.test_runs and self.runner.runner.test_scenario.test_runs[0].output_path.exists():
120+
self.test_run = self.runner.runner.test_scenario.test_runs[0]
121+
else:
122+
self.test_run = copy.deepcopy(self.original_test_run)
123+
self.test_run.step = new_tr.step
124+
self.test_run.output_path = new_tr.output_path
111125

112126
observation = self.get_observation(action)
113127
reward = self.compute_reward(observation)

0 commit comments

Comments
 (0)