Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 18 additions & 19 deletions src/aiida/engine/processes/calcjobs/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,11 @@ def __init__(self, authinfo: AuthInfo, transport_queue: 'TransportQueue', last_u
self._loop = transport_queue.loop
self._logger = logging.getLogger(__name__)

self._jobs_cache: Dict[Hashable, 'JobInfo'] = {}
self._job_update_requests: Dict[Hashable, asyncio.Future] = {} # Mapping: {job_id: Future}
self._jobs_cache: Dict[str, 'JobInfo'] = {}
self._job_update_requests: Dict[str, asyncio.Future] = {} # Mapping: {job_id: Future}
Comment on lines +60 to +61
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optional:

Suggested change
self._jobs_cache: Dict[str, 'JobInfo'] = {}
self._job_update_requests: Dict[str, asyncio.Future] = {} # Mapping: {job_id: Future}
self._jobs_cache: dict[str, 'JobInfo'] = {}
self._job_update_requests: dict[str, asyncio.Future] = {} # Mapping: {job_id: Future}

self._last_updated = last_updated
self._update_handle: Optional[asyncio.TimerHandle] = None
self._polling_jobs: List[str] = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use the list builtin

Suggested change
self._polling_jobs: List[str] = []
self._polling_jobs: list[str] = []


@property
def logger(self) -> logging.Logger:
Expand All @@ -87,7 +88,7 @@ def last_updated(self) -> Optional[float]:
"""
return self._last_updated

async def _get_jobs_from_scheduler(self) -> Dict[Hashable, 'JobInfo']:
async def _get_jobs_from_scheduler(self) -> Dict[str, 'JobInfo']:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optional, since you're already touching the type:

Suggested change
async def _get_jobs_from_scheduler(self) -> Dict[str, 'JobInfo']:
async def _get_jobs_from_scheduler(self) -> dict[str, 'JobInfo']:

"""Get the current jobs list from the scheduler.

:return: a mapping of job ids to :py:class:`~aiida.schedulers.datastructures.JobInfo` instances
Expand All @@ -100,11 +101,13 @@ async def _get_jobs_from_scheduler(self) -> Dict[Hashable, 'JobInfo']:
scheduler = self._authinfo.computer.get_scheduler()
scheduler.set_transport(transport)

self._polling_jobs = [str(job_id) for job_id, _ in self._job_update_requests.items()]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be slightly simpler

Suggested change
self._polling_jobs = [str(job_id) for job_id, _ in self._job_update_requests.items()]
self._polling_jobs = [str(job_id) for job_id in self._job_update_requests]


kwargs: Dict[str, Any] = {'as_dict': True}
if scheduler.get_feature('can_query_by_user'):
kwargs['user'] = '$USER'
else:
kwargs['jobs'] = self._get_jobs_with_scheduler()
kwargs['jobs'] = self._polling_jobs

scheduler_response = scheduler.get_jobs(**kwargs)

Expand All @@ -119,11 +122,14 @@ async def _get_jobs_from_scheduler(self) -> Dict[Hashable, 'JobInfo']:
return jobs_cache

async def _update_job_info(self) -> None:
"""Update all of the job information objects.
"""Update job information and resolve pending requests.

This will set the futures for all pending update requests where the corresponding job has a new status compared
to the last update.
Note that ``_job_update_requests`` is dynamic and might get new entries while the scheduler is being polled.
Therefore we only resolve the jobs that were actually polled; any new entries will be handled in the next update.
"""

try:
if not self._update_requests_outstanding():
return
Expand All @@ -141,14 +147,15 @@ async def _update_job_info(self) -> None:
# `_ensure_updating` will falsely conclude we are still updating, since the handle is not `None` and so it
# will not schedule the next update, causing the job update futures to never be resolved.
self._update_handle = None
self._job_update_requests = {}

raise
else:
for job_id, future in self._job_update_requests.items():
if not future.done():
future.set_result(self._jobs_cache.get(job_id, None))
finally:
self._job_update_requests = {}
for job_id in self._polling_jobs:
future = self._job_update_requests.pop(job_id)
if future.done():
Comment on lines +155 to +156
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Despite the discussion below, I'd still do this a bit more defensively

Suggested change
future = self._job_update_requests.pop(job_id)
if future.done():
future = self._job_update_requests.pop(job_id, None)
if future is None or future.done():

continue
future.set_result(self._jobs_cache.get(job_id, None))

@contextlib.contextmanager
def request_job_info_update(self, authinfo: AuthInfo, job_id: Hashable) -> Iterator['asyncio.Future[JobInfo]']:
Expand All @@ -161,7 +168,7 @@ def request_job_info_update(self, authinfo: AuthInfo, job_id: Hashable) -> Itera
"""
self._authinfo = authinfo
# Get or create the future
request = self._job_update_requests.setdefault(job_id, asyncio.Future())
request = self._job_update_requests.setdefault(str(job_id), asyncio.Future())
assert not request.done(), 'Expected pending job info future, found in done state.'

try:
Expand Down Expand Up @@ -235,14 +242,6 @@ def _get_next_update_delay(self) -> float:
def _update_requests_outstanding(self) -> bool:
return any(not request.done() for request in self._job_update_requests.values())

def _get_jobs_with_scheduler(self) -> List[str]:
"""Get all the jobs that are currently with scheduler.

:return: the list of jobs with the scheduler
:rtype: list
"""
return [str(job_id) for job_id, _ in self._job_update_requests.items()]


class JobManager:
"""A manager for :py:class:`~aiida.engine.processes.calcjobs.calcjob.CalcJob` submitted to ``Computer`` instances.
Expand Down
64 changes: 64 additions & 0 deletions tests/engine/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def test_request_job_info_update(self):
with self.manager.request_job_info_update(self.auth_info, job_id=1) as request:
assert isinstance(request, asyncio.Future)

# Check if the job_id is properly converted to str in JobsList
self.manager._job_lists[self.auth_info.pk]._job_update_requests[str(1)] == request


class TestJobsList:
"""Test the `aiida.engine.processes.calcjobs.manager.JobsList` class."""
Expand Down Expand Up @@ -71,3 +74,64 @@ def test_last_updated(self):
last_updated = time.time()
jobs_list = JobsList(self.auth_info, self.transport_queue, last_updated=last_updated)
assert jobs_list.last_updated == last_updated

def test_prevent_racing_condition(self):
    """Test that ``JobsList`` avoids the race between polling and new update requests.

    Scenario exercised:

    1. Job ``A`` requests an update.
    2. While the scheduler query is in flight, job ``B`` also requests an update.
    3. Only job ``A`` may be resolved by this update cycle.
    4. Job ``B``'s future must stay pending until the next cycle.
    5. The next cycle then resolves job ``B`` correctly.
    """
    from unittest.mock import patch

    from aiida.schedulers.datastructures import JobInfo, JobState

    jobs_list = self.jobs_list

    # Scheduler info for job ``A``: the only job the first poll knows about.
    info_a = JobInfo()
    info_a.job_id = 'A'
    info_a.job_state = JobState.RUNNING

    # Job ``B`` will sneak into the request map while the scheduler is queried.
    info_b = JobInfo()
    info_b.job_id = 'B'

    def racing_get_jobs(**kwargs):
        # Simulate the race: register ``B`` during the scheduler query, but
        # report back only ``A`` (the scheduler was queried with ``A`` alone).
        jobs_list._job_update_requests.setdefault(info_b.job_id, asyncio.Future())
        return {info_a.job_id: info_a}

    # Register the update request for job ``A`` before polling starts.
    future_a = jobs_list._job_update_requests.setdefault(str(info_a.job_id), asyncio.Future())

    # Patch the scheduler's ``get_jobs`` and run one update cycle.
    scheduler = self.auth_info.computer.get_scheduler()
    with patch.object(scheduler.__class__, 'get_jobs', side_effect=racing_get_jobs):
        self.loop.run_until_complete(jobs_list._update_job_info())

    # Job ``A`` was polled, so its future must be resolved with the scheduler's info.
    assert future_a.done(), 'job_id_a future should be resolved'
    assert future_a.result() == info_a, 'job_id_a should have the correct JobInfo'

    # Job ``B`` was not part of this poll: its request must survive, unresolved.
    assert info_b.job_id in jobs_list._job_update_requests, 'job_id_b should still be in update requests'
    future_b = jobs_list._job_update_requests[info_b.job_id]
    assert not future_b.done(), 'job_id_b future should NOT be resolved yet (prevented racing bug)'
    assert len(jobs_list._job_update_requests) == 1, 'Only job_id_b should remain in update requests'

    def empty_get_jobs(**kwargs):
        # Return no jobs, as if ``B`` finished and already left the scheduler.
        return {}

    # The next cycle polls for job ``B`` and must resolve its future even
    # though the scheduler no longer reports the job.
    with patch.object(scheduler.__class__, 'get_jobs', side_effect=empty_get_jobs):
        self.loop.run_until_complete(jobs_list._update_job_info())

    assert future_b.done(), 'job_id_b future should be resolved'