|
99 | 99 | DIRECT_PROVIDER_PROMOTION_RATE, |
100 | 100 | HeartbeatApplyRequest, |
101 | 101 | ReservationClaim, |
| 102 | + RunningTaskEntry, |
102 | 103 | SchedulingEvent, |
103 | 104 | TaskUpdate, |
104 | 105 | ) |
@@ -1024,6 +1025,13 @@ class Controller: |
1024 | 1025 | the controller will run it in a background thread. |
1025 | 1026 | """ |
1026 | 1027 |
|
| 1028 | + # Grace window during which a freshly-dispatched task (via StartTasks) is |
| 1029 | + # treated as expected when polling its worker. Bridges the StartTasks → |
| 1030 | + # PollTasks race: a poll's DB snapshot taken before the assignment commit |
| 1031 | + # would otherwise omit the task, and the worker would kill it as |
| 1032 | + # unexpected. 30s comfortably exceeds any normal StartTasks RPC latency. |
| 1033 | + _RECENT_DISPATCH_GRACE_SECONDS = 30.0 |
| 1034 | + |
1027 | 1035 | def __init__( |
1028 | 1036 | self, |
1029 | 1037 | config: ControllerConfig, |
@@ -1130,6 +1138,14 @@ def __init__( |
1130 | 1138 | self._poll_thread: ManagedThread | None = None |
1131 | 1139 | self._task_update_queue: queue_mod.Queue[HeartbeatApplyRequest] = queue_mod.Queue() |
1132 | 1140 |
|
| 1141 | + # Track tasks dispatched via StartTasks but whose assignment commit |
| 1142 | + # may not yet be visible to a concurrent PollTasks snapshot. Merged |
| 1143 | + # into expected_tasks in _poll_all_workers so the worker does not |
| 1144 | + # treat a freshly-submitted task as "unexpected" and kill it. See |
| 1145 | + # _RECENT_DISPATCH_GRACE_SECONDS and _dispatch_assignments_direct. |
| 1146 | + self._recent_dispatches: dict[WorkerId, dict[tuple[str, int], float]] = {} |
| 1147 | + self._recent_dispatches_lock = threading.Lock() |
| 1148 | + |
1133 | 1149 | self._autoscaler: Autoscaler | None = autoscaler |
1134 | 1150 |
|
1135 | 1151 | self._heartbeat_iteration = 0 |
@@ -2168,6 +2184,14 @@ def _dispatch_assignments_direct( |
2168 | 2184 | command = [Assignment(task_id=task_id, worker_id=worker_id) for task_id, worker_id in assignments] |
2169 | 2185 | result = self._transitions.queue_assignments(command, direct_dispatch=True) |
2170 | 2186 |
|
| 2187 | + # Register dispatches before issuing StartTasks so a concurrent poll |
| 2188 | + # whose DB snapshot predates the assignment commit still sees these |
| 2189 | + # tasks as expected. Any poll that observes the entries won't kill |
| 2190 | + # the task; any poll that misses them can't yet collide with the |
| 2191 | + # StartTasks RPC (not sent until after this point) on the same |
| 2192 | + # gRPC channel. |
| 2193 | + self._record_recent_dispatches(result.start_requests) |
| 2194 | + |
2171 | 2195 | # Group StartTasks payloads by (worker_id, address) |
2172 | 2196 | by_worker: dict[tuple[WorkerId, str], list[job_pb2.RunTaskRequest]] = {} |
2173 | 2197 | for worker_id, address, run_request in result.start_requests: |
@@ -2318,24 +2342,88 @@ def _run_poll_loop(self, stop_event: threading.Event) -> None: |
2318 | 2342 | except Exception: |
2319 | 2343 | logger.exception("Poll loop iteration failed") |
2320 | 2344 |
|
def _record_recent_dispatches(
    self,
    start_requests: "list[tuple[WorkerId, str, job_pb2.RunTaskRequest]]",
) -> None:
    """Stamp each dispatched (worker, task, attempt) with the current monotonic time.

    Called just before the StartTasks RPC is issued so a concurrent poll
    treats these tasks as expected even if its DB snapshot predates the
    assignment commit. Entries age out via
    _prune_and_snapshot_recent_dispatches.
    """
    if not start_requests:
        return
    stamp = time.monotonic()
    with self._recent_dispatches_lock:
        for worker_id, _addr, request in start_requests:
            per_worker = self._recent_dispatches.setdefault(worker_id, {})
            per_worker[(request.task_id, request.attempt_id)] = stamp
| 2356 | + |
def _prune_and_snapshot_recent_dispatches(self) -> "dict[WorkerId, set[tuple[str, int]]]":
    """Drop entries older than the grace window; return the surviving keys per worker.

    Workers whose every entry has aged out are removed entirely, so the
    tracking dict stays bounded by in-flight dispatch volume.
    """
    cutoff = time.monotonic() - self._RECENT_DISPATCH_GRACE_SECONDS
    snapshot: "dict[WorkerId, set[tuple[str, int]]]" = {}
    with self._recent_dispatches_lock:
        for worker_id, entries in list(self._recent_dispatches.items()):
            fresh = {key: ts for key, ts in entries.items() if ts >= cutoff}
            if fresh:
                self._recent_dispatches[worker_id] = fresh
                snapshot[worker_id] = set(fresh)
            else:
                del self._recent_dispatches[worker_id]
    return snapshot
| 2371 | + |
def _poll_all_workers(self) -> None:
    """Poll every worker for task state and enqueue the resulting updates.

    Recently-dispatched (worker, task, attempt) keys are merged into each
    worker's expected set so a DB snapshot taken before the assignment
    commit cannot make the worker kill a just-submitted task. Updates for
    tasks still inside the dispatch grace window are dropped entirely: the
    StartTasks RPC may not have landed when PollTasks arrived, so the
    worker's report for those tasks is unreliable until a later poll.
    """
    if self._config.dry_run:
        return

    running, addresses = self._transitions.get_running_tasks_for_poll()
    recent = self._prune_and_snapshot_recent_dispatches()

    # Fold fresh dispatches into each pollable worker's expected list, so a
    # worker whose only tasks are not yet visible in the DB snapshot is
    # still polled with the correct expected set.
    for worker_id, dispatched in recent.items():
        if worker_id not in addresses:
            continue
        expected = running.setdefault(worker_id, [])
        already_known = {(entry.task_id.to_wire(), entry.attempt_id) for entry in expected}
        expected.extend(
            RunningTaskEntry(
                task_id=JobName.from_wire(task_wire),
                attempt_id=attempt_id,
            )
            for task_wire, attempt_id in dispatched
            if (task_wire, attempt_id) not in already_known
        )

    if not running:
        return

    for worker_id, updates, error in self._provider.poll_workers(running, addresses):
        if error is not None:
            logger.warning("PollTasks failed for worker %s: %s", worker_id, error)
            continue
        if not updates:
            continue
        # A task inside the grace window may be legitimately unknown to the
        # worker (StartTasks still in flight) or already running; either
        # report is untrustworthy right now, so suppress it and let the
        # next poll settle the question once the window closes.
        in_grace = recent.get(worker_id, set())
        applicable = [u for u in updates if (u.task_id.to_wire(), u.attempt_id) not in in_grace]
        if not applicable:
            continue
        self._task_update_queue.put(
            HeartbeatApplyRequest(
                worker_id=worker_id,
                worker_resource_snapshot=None,
                updates=applicable,
            )
        )
2341 | 2429 |
|
|
0 commit comments