@@ -906,6 +906,12 @@ class ControllerConfig:
906906 heartbeat_interval : Duration = field (default_factory = lambda : Duration .from_seconds (5.0 ))
907907 """How often to send heartbeats to workers."""
908908
909+ poll_interval : Duration = field (default_factory = lambda : Duration .from_seconds (60.0 ))
910+ """Rate-limit interval for reconciling worker task state via PollTasks. When
911+ use_split_heartbeat is set, the poll runs inline at the end of a scheduling
912+ iteration (at most once per interval) so it observes a post-commit DB view,
913+ avoiding the StartTasks/PollTasks race of a separate poll thread (issue #5041)."""
914+
909915 max_dispatch_parallelism : int = 32
910916 """Maximum number of concurrent RPC dispatch operations."""
911917
@@ -1125,7 +1131,6 @@ def __init__(
11251131 self ._prune_thread : ManagedThread | None = None
11261132 self ._task_updater_thread : ManagedThread | None = None
11271133 self ._ping_thread : ManagedThread | None = None
1128- self ._poll_thread : ManagedThread | None = None
11291134 self ._task_update_queue : queue_mod .Queue [HeartbeatApplyRequest ] = queue_mod .Queue ()
11301135
11311136 self ._autoscaler : Autoscaler | None = autoscaler
@@ -1227,7 +1232,6 @@ def start(self) -> None:
12271232 self ._scheduling_thread = self ._threads .spawn (self ._run_scheduling_loop , name = "scheduling-loop" )
12281233 self ._ping_thread = self ._threads .spawn (self ._run_ping_loop , name = "ping-loop" )
12291234 self ._task_updater_thread = self ._threads .spawn (self ._run_task_updater_loop , name = "task-updater-loop" )
1230- self ._poll_thread = self ._threads .spawn (self ._run_poll_loop , name = "poll-loop" )
12311235 if not self ._config .dry_run :
12321236 self ._profile_thread = self ._threads .spawn (self ._run_profile_loop , name = "profile-loop" )
12331237 self ._prune_thread = self ._threads .spawn (self ._run_prune_loop , name = "prune-loop" )
@@ -1304,9 +1308,6 @@ def stop(self) -> None:
13041308 if self ._task_updater_thread :
13051309 self ._task_updater_thread .stop ()
13061310 self ._task_updater_thread .join (timeout = join_timeout )
1307- if self ._poll_thread :
1308- self ._poll_thread .stop ()
1309- self ._poll_thread .join (timeout = join_timeout )
13101311 if self ._prune_thread :
13111312 self ._prune_thread .stop ()
13121313 self ._prune_thread .join (timeout = join_timeout )
@@ -1347,13 +1348,19 @@ def _run_scheduling_loop(self, stop_event: threading.Event) -> None:
13471348 Backs off from min to max interval when idle (no pending tasks or no
13481349 assignments possible). Resets to min interval when woken by a new job
13491350 submission or when assignments are made.
1351+
1352+ Reconciliation (PollTasks) runs inline at the end of an iteration, but only
1353+ when use_split_heartbeat is set and the poll_interval rate limiter permits it.
1354+ Sharing this thread with scheduling guarantees the poll's expected_tasks snapshot
1355+ is taken after the same iteration's StartTasks commits — see issue #5041.
13501356 """
13511357 backoff = ExponentialBackoff (
13521358 initial = self ._config .scheduler_min_interval .to_seconds (),
13531359 maximum = self ._config .scheduler_max_interval .to_seconds (),
13541360 factor = 2.0 ,
13551361 jitter = 0.1 ,
13561362 )
1363+ poll_limiter = RateLimiter (interval_seconds = self ._config .poll_interval .to_seconds ())
13571364 while not stop_event .is_set ():
13581365 interval = backoff .next_interval ()
13591366 woken = self ._wake_event .wait (timeout = interval )
@@ -1371,6 +1378,12 @@ def _run_scheduling_loop(self, stop_event: threading.Event) -> None:
13711378 if outcome == SchedulingOutcome .ASSIGNMENTS_MADE :
13721379 backoff .reset ()
13731380
1381+ if self ._config .use_split_heartbeat and poll_limiter .should_run ():
1382+ try :
1383+ self ._poll_all_workers ()
1384+ except Exception :
1385+ logger .exception ("Inline poll reconciliation failed" )
1386+
13741387 def _run_prune_loop (self , stop_event : threading .Event ) -> None :
13751388 """Background pruning loop: history cleanup every 60s, full data prune on the configured interval."""
13761389 last_full_prune = 0.0
@@ -2321,21 +2334,6 @@ def _run_ping_loop(self, stop_event: threading.Event) -> None:
23212334 except Exception :
23222335 logger .exception ("Ping loop iteration failed" )
23232336
2324- def _run_poll_loop (self , stop_event : threading .Event ) -> None :
2325- """Periodic full-state reconciliation for split heartbeat mode.
2326-
2327- Polls all workers via PollTasks every 60s and feeds results into the
2328- task-updater queue for batched application.
2329- """
2330- limiter = RateLimiter (interval_seconds = 60.0 )
2331- while not stop_event .is_set ():
2332- if not limiter .wait (cancel = stop_event ):
2333- break
2334- try :
2335- self ._poll_all_workers ()
2336- except Exception :
2337- logger .exception ("Poll loop iteration failed" )
2338-
23392337 def _poll_all_workers (self ) -> None :
23402338 """Poll all workers for task state and feed results into the updater queue."""
23412339 if self ._config .dry_run :
0 commit comments