1414# See the License for the specific language governing permissions and
1515# limitations under the License.
1616
17+ import asyncio
18+ from copy import deepcopy
1719from pathlib import Path
1820from typing import cast
1921
2022import pytest
2123from pydantic import ConfigDict
2224
23- from cloudai ._core .system import System
24- from cloudai .core import BaseJob , BaseRunner , JobStatusResult , TestDefinition , TestRun , TestScenario
25+ from cloudai .core import (
26+ BaseJob ,
27+ BaseRunner ,
28+ JobStatusResult ,
29+ System ,
30+ TestDefinition ,
31+ TestDependency ,
32+ TestRun ,
33+ TestScenario ,
34+ )
2535from cloudai .models .workload import CmdArgs
2636from cloudai .systems .slurm import SlurmSystem
2737
@@ -30,13 +40,23 @@ class MyRunner(BaseRunner):
3040 def __init__ (self , mode : str , system : System , test_scenario : TestScenario , output_path : Path ):
3141 super ().__init__ (mode , system , test_scenario , output_path )
3242 self .runner_job_status_result = JobStatusResult (is_successful = True )
43+ self .submitted_trs : list [TestRun ] = []
44+ self .killed_by_dependency : list [BaseJob ] = []
3345
3446 def get_runner_job_status (self , job : BaseJob ) -> JobStatusResult :
3547 return self .runner_job_status_result
3648
3749 def _submit_test (self , tr : TestRun ) -> BaseJob :
50+ self .submitted_trs .append (tr )
3851 return BaseJob (tr , 0 )
3952
53+ async def delayed_submit_test (self , tr : TestRun , delay : int = 0 ):
54+ await super ().delayed_submit_test (tr , 0 )
55+
56+ async def delayed_kill_job (self , job : BaseJob , delay : int = 0 ):
57+ self .killed_by_dependency .append (job )
58+ await asyncio .sleep (0 )
59+
4060
4161class MyWorkload (TestDefinition ):
4262 model_config = ConfigDict (arbitrary_types_allowed = True )
@@ -93,3 +113,54 @@ def test_both_failed_runner_status_reported(self, runner: MyRunner):
93113 runner .runner_job_status_result = JobStatusResult (is_successful = False , error_message = "runner job failed" )
94114 res = runner .get_job_status (job )
95115 assert res == runner .runner_job_status_result
116+
117+
118+ class TestHandleDependencies :
119+ """
120+ Tests for BaseRunner.handle_dependencies method.
121+
122+ Both main and dependent TestRuns use the same MyWorkload test definition to reproduce an issue when jobs
123+ comparison was done incorrectly.
124+ """
125+
126+ @pytest .fixture
127+ def tr_main (self , runner : MyRunner ) -> TestRun :
128+ return runner .test_scenario .test_runs [0 ]
129+
130+ @pytest .mark .asyncio
131+ async def test_no_dependencies (self , runner : MyRunner , tr_main : TestRun ):
132+ await runner .handle_dependencies (BaseJob (tr_main , 0 ))
133+ assert len (runner .submitted_trs ) == 0
134+
135+ @pytest .mark .asyncio
136+ async def test_start_post_comp (self , runner : MyRunner , tr_main : TestRun ):
137+ tr_dep = deepcopy (tr_main )
138+ tr_dep .dependencies = {"start_post_comp" : TestDependency (tr_main )}
139+ runner .test_scenario .test_runs .append (tr_dep )
140+
141+ await runner .handle_dependencies (BaseJob (tr_dep , 0 )) # self, should not trigger anything
142+ assert len (runner .submitted_trs ) == 0
143+
144+ await runner .handle_dependencies (BaseJob (tr_main , 0 ))
145+ assert len (runner .submitted_trs ) == 1
146+ assert runner .submitted_trs [0 ] == tr_dep
147+
148+ @pytest .mark .asyncio
149+ async def test_end_post_comp (self , runner : MyRunner , tr_main : TestRun ):
150+ tr_dep = deepcopy (tr_main )
151+ tr_dep .dependencies = {"end_post_comp" : TestDependency (tr_main )}
152+ runner .test_scenario .test_runs .append (tr_dep )
153+
154+ # self not running, main completed -> nothing to kill
155+ await runner .handle_dependencies (BaseJob (tr_main , 0 ))
156+ assert len (runner .killed_by_dependency ) == 0
157+
158+ # self is running, main completed -> should kill
159+ await runner .submit_test (tr_dep )
160+
161+ await runner .handle_dependencies (BaseJob (tr_dep , 0 )) # self, should not kill
162+ assert len (runner .killed_by_dependency ) == 0
163+
164+ await runner .handle_dependencies (BaseJob (tr_main , 0 ))
165+ assert len (runner .killed_by_dependency ) == 1
166+ assert runner .killed_by_dependency [0 ].test_run == tr_dep
0 commit comments