2222
2323from cloudai .core import METRIC_ERROR , Registry , Runner , TestRun
2424from cloudai .util .lazy_imports import lazy
25+ from cloudai ._core .exceptions import JobFailureError
2526
2627from .base_gym import BaseGym
2728
@@ -42,6 +43,7 @@ def __init__(self, test_run: TestRun, runner: Runner):
4243 runner (Runner): The runner object to execute jobs.
4344 """
4445 self .test_run = test_run
46+ self .original_test_run = copy .deepcopy (test_run ) # Preserve clean state for DSE
4547 self .runner = runner
4648 self .max_steps = test_run .test .test_definition .agent_steps
4749 self .reward_function = Registry ().get_reward_function (test_run .test .test_definition .agent_reward_function )
@@ -105,9 +107,21 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
105107
106108 logging .info (f"Running step { self .test_run .step } with action { action } " )
107109 new_tr = copy .deepcopy (self .test_run )
110+ new_tr .output_path = self .runner .runner .get_job_output_path (new_tr )
108111 self .runner .runner .test_scenario .test_runs = [new_tr ]
112+
113+ self .runner .runner .shutting_down = False
114+ self .runner .runner .jobs .clear ()
115+ self .runner .runner .testrun_to_job_map .clear ()
116+
109117 asyncio .run (self .runner .run ())
110- self .test_run = self .runner .runner .test_scenario .test_runs [0 ]
118+
119+ if self .runner .runner .test_scenario .test_runs and self .runner .runner .test_scenario .test_runs [0 ].output_path .exists ():
120+ self .test_run = self .runner .runner .test_scenario .test_runs [0 ]
121+ else :
122+ self .test_run = copy .deepcopy (self .original_test_run )
123+ self .test_run .step = new_tr .step
124+ self .test_run .output_path = new_tr .output_path
111125
112126 observation = self .get_observation (action )
113127 reward = self .compute_reward (observation )
0 commit comments