
Commit 9f8ddd1

non sleep solution, with test
1 parent 37b4ade commit 9f8ddd1

5 files changed: +59, -18 lines


src/aiida/engine/processes/calcjobs/manager.py

Lines changed: 9 additions & 3 deletions

@@ -61,6 +61,7 @@ def __init__(self, authinfo: AuthInfo, transport_queue: 'TransportQueue', last_u
         self._job_update_requests: Dict[Hashable, asyncio.Future] = {}  # Mapping: {job_id: Future}
         self._last_updated = last_updated
         self._update_handle: Optional[asyncio.TimerHandle] = None
+        self._inspecting_jobs: List[str] = []
 
     @property
     def logger(self) -> logging.Logger:
@@ -101,10 +102,11 @@ async def _get_jobs_from_scheduler(self) -> Dict[Hashable, 'JobInfo']:
         scheduler.set_transport(transport)
 
         kwargs: Dict[str, Any] = {'as_dict': True}
+        self._inspecting_jobs = self._get_jobs_with_scheduler()
         if scheduler.get_feature('can_query_by_user'):
             kwargs['user'] = '$USER'
         else:
-            kwargs['jobs'] = self._get_jobs_with_scheduler()
+            kwargs['jobs'] = self._inspecting_jobs
 
         scheduler_response = scheduler.get_jobs(**kwargs)
 
@@ -124,6 +126,7 @@ async def _update_job_info(self) -> None:
         This will set the futures for all pending update requests where the corresponding job has a new status compared
         to the last update.
         """
+        racing_requests = {}
         try:
             if not self._update_requests_outstanding():
                 return
@@ -146,9 +149,12 @@
             else:
                 for job_id, future in self._job_update_requests.items():
                     if not future.done():
-                        future.set_result(self._jobs_cache.get(job_id, None))
+                        if str(job_id) in self._inspecting_jobs:
+                            future.set_result(self._jobs_cache.get(job_id, None))
+                        else:
+                            racing_requests[job_id] = future
         finally:
-            self._job_update_requests = {}
+            self._job_update_requests = racing_requests
 
     @contextlib.contextmanager
     def request_job_info_update(self, authinfo: AuthInfo, job_id: Hashable) -> Iterator['asyncio.Future[JobInfo]']:
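The change above records the job ids that were actually sent to the scheduler in _inspecting_jobs, resolves only the futures covered by that query, and carries every request that raced in mid-query over to the next update via racing_requests. The following is a minimal, self-contained sketch of that deferral pattern under hypothetical names (UpdateCycle, query_scheduler); it is not AiiDA's JobsList, only an illustration of how a future that arrives while the scheduler query is in flight stays pending until the next cycle.

# Minimal, self-contained sketch of the deferral pattern introduced above.
# Names (UpdateCycle, query_scheduler) are hypothetical stand-ins for AiiDA's
# JobsList and its scheduler query; only the carry-over logic is the point.
import asyncio
from typing import Dict, List


class UpdateCycle:
    def __init__(self) -> None:
        self.requests: Dict[str, asyncio.Future] = {}
        self.inspecting: List[str] = []

    async def query_scheduler(self) -> Dict[str, str]:
        # Snapshot the job ids we are actually asking the scheduler about.
        self.inspecting = list(self.requests)
        await asyncio.sleep(0)  # stand-in for the real transport round trip
        return {job_id: 'RUNNING' for job_id in self.inspecting}

    async def update(self) -> None:
        deferred: Dict[str, asyncio.Future] = {}
        results = await self.query_scheduler()
        for job_id, future in self.requests.items():
            if future.done():
                continue
            if job_id in self.inspecting:
                future.set_result(results.get(job_id))
            else:
                # Request arrived mid-query: keep it pending for the next
                # cycle instead of resolving it with data we never fetched.
                deferred[job_id] = future
        self.requests = deferred


async def main() -> None:
    loop = asyncio.get_running_loop()
    cycle = UpdateCycle()
    early = cycle.requests.setdefault('job1', loop.create_future())
    task = asyncio.create_task(cycle.update())
    await asyncio.sleep(0)  # let update() snapshot its job list ('job1' only)
    late = cycle.requests.setdefault('job2', loop.create_future())  # races in mid-query
    await task
    print(early.done(), late.done())   # True False: 'job2' is deferred, not resolved
    await cycle.update()               # the next cycle queries 'job2' as well
    print(late.done(), late.result())  # True RUNNING


asyncio.run(main())

This is the same behaviour the new test in tests/engine/test_manager.py asserts: the racing request ('job2') keeps its unresolved future in _job_update_requests and is only resolved by a later _update_job_info call.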

src/aiida/engine/processes/calcjobs/tasks.py

Lines changed: 0 additions & 9 deletions

@@ -209,15 +209,6 @@ async def do_update():
     try:
         logger.info(f'scheduled request to update CalcJob<{node.pk}>')
         ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
-
-        if not node.get_last_job_info():
-            # One can pass this value from the scheduler,
-            # so we sleep only for slurm. And only for the first time?
-            # scheduler = node.computer.get_scheduler()
-            # sleeptime = scheduler._FIRST_FETCH_SLEEP
-            # Some schedulers take some time to show the job status
-            await asyncio.sleep(5)
-
         job_done = await utils.exponential_backoff_retry(
             do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
         )

tests/engine/processes/calcjobs/test_calc_job.py

Lines changed: 4 additions & 4 deletions

@@ -1225,7 +1225,7 @@ def test_monitor_result_parse(get_calcjob_builder, entry_points):
     entry_points.add(monitor_skip_parse, group='aiida.calculations.monitors', name='core.skip_parse')
 
     builder = get_calcjob_builder()
-    builder.metadata.options.sleep = 8
+    builder.metadata.options.sleep = 3
     builder.monitors = {'monitor': orm.Dict({'entry_point': 'core.skip_parse'})}
     _, node = launch.run_get_node(builder)
     assert sorted(node.outputs) == ['remote_folder', 'retrieved']
@@ -1245,7 +1245,7 @@ def test_monitor_result_retrieve(get_calcjob_builder, entry_points):
     entry_points.add(monitor_skip_retrieve, group='aiida.calculations.monitors', name='core.skip_retrieval')
 
     builder = get_calcjob_builder()
-    builder.metadata.options.sleep = 8
+    builder.metadata.options.sleep = 3
     builder.monitors = {'monitor': orm.Dict({'entry_point': 'core.skip_retrieval'})}
     _, node = launch.run_get_node(builder)
     assert 'retrieved' not in node.outputs
@@ -1265,7 +1265,7 @@ def test_monitor_result_override_exit_code(get_calcjob_builder, entry_points):
     entry_points.add(monitor_override_exit_code, group='aiida.calculations.monitors', name='core.override_exit_code')
 
     builder = get_calcjob_builder()
-    builder.metadata.options.sleep = 8
+    builder.metadata.options.sleep = 3
     builder.monitors = {'monitor': orm.Dict({'entry_point': 'core.override_exit_code'})}
     _, node = launch.run_get_node(builder)
     assert sorted(node.outputs) == ['remote_folder', 'retrieved']
@@ -1315,7 +1315,7 @@ def test_monitor_result_action_disable_self(get_calcjob_builder, entry_points, c
     entry_points.add(monitor_disable_self, group='aiida.calculations.monitors', name='core.disable_self')
 
     builder = get_calcjob_builder()
-    builder.metadata.options.sleep = 6
+    builder.metadata.options.sleep = 1
     builder.monitors = {'disable_self': orm.Dict({'entry_point': 'core.disable_self'})}
     _, node = launch.run_get_node(builder)
     assert node.is_finished_ok

tests/engine/processes/calcjobs/test_monitors.py

Lines changed: 2 additions & 2 deletions

@@ -198,7 +198,7 @@ def test_calc_job_monitors_process_poll_interval_integrated(entry_points, aiida_
     builder.x = Int(1)
     builder.y = Int(1)
     builder.monitors = {'always_kill': Dict({'entry_point': 'core.emit_warning', 'minimum_poll_interval': 5})}
-    builder.metadata = {'options': {'sleep': 6, 'resources': {'num_machines': 1}}}
+    builder.metadata = {'options': {'sleep': 1, 'resources': {'num_machines': 1}}}
 
     _, node = run_get_node(builder)
     assert node.is_finished_ok
@@ -218,7 +218,7 @@ def test_calc_job_monitors_outputs(entry_points, aiida_code_installed):
     builder.x = Int(1)
     builder.y = Int(1)
     builder.monitors = {'store_message': Dict({'entry_point': 'core.store_message', 'minimum_poll_interval': 1})}
-    builder.metadata = {'options': {'sleep': 8, 'resources': {'num_machines': 1}}}
+    builder.metadata = {'options': {'sleep': 3, 'resources': {'num_machines': 1}}}
 
     _, node = run_get_node(builder)
     assert node.is_finished_ok

tests/engine/test_manager.py

Lines changed: 44 additions & 0 deletions

@@ -71,3 +71,47 @@ def test_last_updated(self):
         last_updated = time.time()
         jobs_list = JobsList(self.auth_info, self.transport_queue, last_updated=last_updated)
         assert jobs_list.last_updated == last_updated
+
+    def test_prevent_racing_condition(self):
+        """Test that the `JobsList` prevents a race condition when updating job info.
+
+        This test simulates a race condition where:
+        1. Job 'job1' requests an update
+        2. During the scheduler query, a new job 'job2' also requests an update
+        3. `JobsList` must only update 'job1'
+        4. The 'job2' future should be kept pending for the next update cycle
+        """
+        from unittest.mock import patch
+
+        from aiida.schedulers.datastructures import JobInfo, JobState
+
+        jobs_list = self.jobs_list
+
+        mock_job_info = JobInfo()
+        mock_job_info.job_id = 'job1'
+        mock_job_info.job_state = JobState.RUNNING
+
+        def mock_get_jobs(**kwargs):
+            # Simulate the race: job2 is added to _job_update_requests while we are querying the scheduler
+            jobs_list._job_update_requests.setdefault('job2', asyncio.Future())
+
+            # Return only job1 (the scheduler was queried with job1 only)
+            return {'job1': mock_job_info}
+
+        # Request an update for job1
+        future1 = jobs_list._job_update_requests.setdefault('job1', asyncio.Future())
+
+        # Patch the scheduler's get_jobs
+        scheduler = self.auth_info.computer.get_scheduler()
+        with patch.object(scheduler.__class__, 'get_jobs', side_effect=mock_get_jobs):
+            self.loop.run_until_complete(jobs_list._update_job_info())
+
+        # Verify that job1 was resolved correctly
+        assert future1.done(), 'job1 future should be resolved'
+        assert future1.result() == mock_job_info, 'job1 should have the correct JobInfo'
+
+        # Verify that job2 was NOT resolved and remains in _job_update_requests for the next cycle
+        assert 'job2' in jobs_list._job_update_requests, 'job2 should still be in update requests'
+        future2 = jobs_list._job_update_requests['job2']
+        assert not future2.done(), 'job2 future should NOT be resolved yet (prevented racing bug)'
+        assert len(jobs_list._job_update_requests) == 1, 'only job2 should remain in update requests'
