Skip to content

Commit d3763fe

Browse files
authored
Merge pull request #725 from NVIDIA/am/bug-4744288
Fixed an issue where using dependencies could result in an infinite loop
2 parents dd70526 + 3e88391 commit d3763fe

File tree

2 files changed

+75
-4
lines changed

2 files changed

+75
-4
lines changed

src/cloudai/_core/base_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,15 +348,15 @@ async def handle_dependencies(self, completed_job: BaseJob) -> List[Task]:
348348
for tr in self.test_scenario.test_runs:
349349
if tr not in self.testrun_to_job_map:
350350
for dep_type, dep in tr.dependencies.items():
351-
if dep_type == "start_post_comp" and dep.test_run.test == completed_job.test_run.test:
351+
if dep_type == "start_post_comp" and dep.test_run == completed_job.test_run:
352352
task = await self.delayed_submit_test(tr)
353353
if task:
354354
tasks.append(task)
355355

356356
# Handling end_post_comp dependencies
357357
for test, dependent_job in self.testrun_to_job_map.items():
358358
for dep_type, dep in test.dependencies.items():
359-
if dep_type == "end_post_comp" and dep.test_run.test == completed_job.test_run.test:
359+
if dep_type == "end_post_comp" and dep.test_run == completed_job.test_run:
360360
task = await self.delayed_kill_job(dependent_job)
361361
tasks.append(task)
362362

tests/test_base_runner.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,24 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import asyncio
18+
from copy import deepcopy
1719
from pathlib import Path
1820
from typing import cast
1921

2022
import pytest
2123
from pydantic import ConfigDict
2224

23-
from cloudai._core.system import System
24-
from cloudai.core import BaseJob, BaseRunner, JobStatusResult, TestDefinition, TestRun, TestScenario
25+
from cloudai.core import (
26+
BaseJob,
27+
BaseRunner,
28+
JobStatusResult,
29+
System,
30+
TestDefinition,
31+
TestDependency,
32+
TestRun,
33+
TestScenario,
34+
)
2535
from cloudai.models.workload import CmdArgs
2636
from cloudai.systems.slurm import SlurmSystem
2737

@@ -30,13 +40,23 @@ class MyRunner(BaseRunner):
3040
def __init__(self, mode: str, system: System, test_scenario: TestScenario, output_path: Path):
3141
super().__init__(mode, system, test_scenario, output_path)
3242
self.runner_job_status_result = JobStatusResult(is_successful=True)
43+
self.submitted_trs: list[TestRun] = []
44+
self.killed_by_dependency: list[BaseJob] = []
3345

3446
def get_runner_job_status(self, job: BaseJob) -> JobStatusResult:
3547
return self.runner_job_status_result
3648

3749
def _submit_test(self, tr: TestRun) -> BaseJob:
50+
self.submitted_trs.append(tr)
3851
return BaseJob(tr, 0)
3952

53+
async def delayed_submit_test(self, tr: TestRun, delay: int = 0):
54+
await super().delayed_submit_test(tr, 0)
55+
56+
async def delayed_kill_job(self, job: BaseJob, delay: int = 0):
57+
self.killed_by_dependency.append(job)
58+
await asyncio.sleep(0)
59+
4060

4161
class MyWorkload(TestDefinition):
4262
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -93,3 +113,54 @@ def test_both_failed_runner_status_reported(self, runner: MyRunner):
93113
runner.runner_job_status_result = JobStatusResult(is_successful=False, error_message="runner job failed")
94114
res = runner.get_job_status(job)
95115
assert res == runner.runner_job_status_result
116+
117+
118+
class TestHandleDependencies:
119+
"""
120+
Tests for BaseRunner.handle_dependencies method.
121+
122+
Both main and dependent TestRuns use the same MyWorkload test definition to reproduce an issue when jobs
123+
comparison was done incorrectly.
124+
"""
125+
126+
@pytest.fixture
127+
def tr_main(self, runner: MyRunner) -> TestRun:
128+
return runner.test_scenario.test_runs[0]
129+
130+
@pytest.mark.asyncio
131+
async def test_no_dependencies(self, runner: MyRunner, tr_main: TestRun):
132+
await runner.handle_dependencies(BaseJob(tr_main, 0))
133+
assert len(runner.submitted_trs) == 0
134+
135+
@pytest.mark.asyncio
136+
async def test_start_post_comp(self, runner: MyRunner, tr_main: TestRun):
137+
tr_dep = deepcopy(tr_main)
138+
tr_dep.dependencies = {"start_post_comp": TestDependency(tr_main)}
139+
runner.test_scenario.test_runs.append(tr_dep)
140+
141+
await runner.handle_dependencies(BaseJob(tr_dep, 0)) # self, should not trigger anything
142+
assert len(runner.submitted_trs) == 0
143+
144+
await runner.handle_dependencies(BaseJob(tr_main, 0))
145+
assert len(runner.submitted_trs) == 1
146+
assert runner.submitted_trs[0] == tr_dep
147+
148+
@pytest.mark.asyncio
149+
async def test_end_post_comp(self, runner: MyRunner, tr_main: TestRun):
150+
tr_dep = deepcopy(tr_main)
151+
tr_dep.dependencies = {"end_post_comp": TestDependency(tr_main)}
152+
runner.test_scenario.test_runs.append(tr_dep)
153+
154+
# self not running, main completed -> nothing to kill
155+
await runner.handle_dependencies(BaseJob(tr_main, 0))
156+
assert len(runner.killed_by_dependency) == 0
157+
158+
# self is running, main completed -> should kill
159+
await runner.submit_test(tr_dep)
160+
161+
await runner.handle_dependencies(BaseJob(tr_dep, 0)) # self, should not kill
162+
assert len(runner.killed_by_dependency) == 0
163+
164+
await runner.handle_dependencies(BaseJob(tr_main, 0))
165+
assert len(runner.killed_by_dependency) == 1
166+
assert runner.killed_by_dependency[0].test_run == tr_dep

0 commit comments

Comments
 (0)