114114 TaskUpdate ,
115115 log_event ,
116116)
117- from iris .cluster .controller .worker_health import WorkerHealthTracker
117+ from iris .cluster .controller .worker_health import WorkerCommitTracker , WorkerHealthTracker
118118from iris .cluster .log_store_helpers import CONTROLLER_LOG_KEY
119119from iris .cluster .providers .k8s .tasks import K8sTaskProvider
120120from iris .cluster .providers .types import find_free_port , resolve_external_host
@@ -881,6 +881,8 @@ def _reservation_region_constraints(
881881 job_id_wire : str ,
882882 claims : dict [WorkerId , ReservationClaim ],
883883 queries : ControllerDB ,
884+ health : WorkerHealthTracker ,
885+ committed : WorkerCommitTracker ,
884886 existing_constraints : list [Constraint ],
885887) -> list [Constraint ]:
886888 """Derive region constraints from claimed reservation workers.
@@ -897,7 +899,7 @@ def _reservation_region_constraints(
897899 claimed_worker_ids = {worker_id for worker_id , claim in claims .items () if claim .job_id == job_id_wire }
898900 workers_by_id = {
899901 worker .worker_id : worker
900- for worker in healthy_active_workers_with_attributes (queries )
902+ for worker in healthy_active_workers_with_attributes (queries , health , committed )
901903 if worker .worker_id in claimed_worker_ids
902904 }
903905 regions : set [str ] = set ()
@@ -1153,7 +1155,8 @@ def __init__(
11531155 self ._db = db
11541156 else :
11551157 self ._db = ControllerDB (db_dir = config .local_state_dir / "db" )
1156- self ._store = ControllerStore (self ._db )
1158+ self ._health = WorkerHealthTracker ()
1159+ self ._store = ControllerStore (self ._db , health = self ._health )
11571160
11581161 # ThreadContainer must be initialized before the log service setup
11591162 # because _start_local_log_server spawns a uvicorn thread.
@@ -1194,7 +1197,6 @@ def __init__(
11941197 self ._log_handler .setFormatter (logging .Formatter ("%(asctime)s %(name)s %(message)s" ))
11951198 logging .getLogger ("iris" ).addHandler (self ._log_handler )
11961199
1197- self ._health = WorkerHealthTracker ()
11981200 self ._transitions = ControllerTransitions (
11991201 store = self ._store ,
12001202 health = self ._health ,
@@ -1630,7 +1632,7 @@ def _profile_all_running_tasks(self) -> None:
16301632 Memory profiling via memray is currently disabled because memray attach
16311633 has been triggering segfaults in target processes.
16321634 """
1633- workers = healthy_active_workers_with_attributes (self ._db )
1635+ workers = healthy_active_workers_with_attributes (self ._db , self . _health , self . _store . committed )
16341636 if not workers :
16351637 return
16361638 workers_by_id = {w .worker_id : w for w in workers }
@@ -1742,11 +1744,7 @@ def _cleanup_stale_claims(self, claims: dict[WorkerId, ReservationClaim] | None
17421744 if claims is None :
17431745 claims = _read_reservation_claims (self ._db )
17441746 persisted = True
1745- with self ._db .read_snapshot () as snapshot :
1746- active_worker_ids = {
1747- WorkerId (str (row [0 ]))
1748- for row in snapshot .fetchall ("SELECT w.worker_id FROM workers w WHERE w.active = 1" )
1749- }
1747+ active_worker_ids = {wid for wid , l in self ._health .all ().items () if l .active }
17501748 claimed_job_ids = {JobName .from_wire (claim .job_id ) for claim in claims .values ()}
17511749 claimed_jobs = list (_jobs_by_id (self ._db , claimed_job_ids ).values ()) if claimed_job_ids else []
17521750 jobs_by_id = {job .job_id .to_wire (): job for job in claimed_jobs }
@@ -1778,7 +1776,7 @@ def _claim_workers_for_reservations(self, claims: dict[WorkerId, ReservationClai
17781776 persisted = True
17791777 claimed_entries : set [tuple [str , int ]] = {(c .job_id , c .entry_idx ) for c in claims .values ()}
17801778 claimed_worker_ids : set [WorkerId ] = set (claims .keys ())
1781- all_workers = healthy_active_workers_with_attributes (self ._db )
1779+ all_workers = healthy_active_workers_with_attributes (self ._db , self . _health , self . _store . committed )
17821780 changed = False
17831781
17841782 reservable_states = (
@@ -1916,7 +1914,7 @@ def _read_scheduling_state(self) -> _SchedulingStateRead:
19161914 timer = Timer ()
19171915 with slow_log (logger , "scheduling state reads" , threshold_ms = 50 ):
19181916 pending_tasks = _schedulable_tasks (self ._db )
1919- workers = healthy_active_workers_with_attributes (self ._db )
1917+ workers = healthy_active_workers_with_attributes (self ._db , self . _health , self . _store . committed )
19201918 return _SchedulingStateRead (
19211919 pending_tasks = pending_tasks ,
19221920 workers = workers ,
@@ -2378,7 +2376,7 @@ def _stop_tasks_direct(
23782376
23792377 def _get_active_worker_addresses (self ) -> list [tuple [WorkerId , str | None ]]:
23802378 """Get healthy active workers as (worker_id, address) tuples for ping."""
2381- workers = healthy_active_workers_with_attributes (self ._db )
2379+ workers = healthy_active_workers_with_attributes (self ._db , self . _health , self . _store . committed )
23822380 return [(w .worker_id , w .address ) for w in workers ]
23832381
23842382 def _run_ping_loop (self , stop_event : threading .Event ) -> None :
@@ -2406,8 +2404,7 @@ def _run_ping_loop(self, stop_event: threading.Event) -> None:
24062404 self ._health .ping (result .worker_id , healthy = True )
24072405 live_worker_ids .append (result .worker_id )
24082406
2409- with self ._store .transaction () as cur :
2410- self ._transitions .update_worker_pings (cur , live_worker_ids )
2407+ self ._transitions .update_worker_pings (live_worker_ids )
24112408
24122409 unhealthy = self ._health .workers_over_threshold ()
24132410 if unhealthy :
@@ -2534,7 +2531,7 @@ def _run_autoscaler_once(self) -> None:
25342531
25352532 worker_status_map = self ._build_worker_status_map ()
25362533 self ._autoscaler .refresh (worker_status_map )
2537- workers = healthy_active_workers_with_attributes (self ._db )
2534+ workers = healthy_active_workers_with_attributes (self ._db , self . _health , self . _store . committed )
25382535 demand_entries = compute_demand_entries (
25392536 self ._db ,
25402537 self ._scheduler ,
@@ -2546,12 +2543,7 @@ def _run_autoscaler_once(self) -> None:
25462543 def _build_worker_status_map (self ) -> WorkerStatusMap :
25472544 """Build a map of worker_id to worker status for autoscaler idle tracking."""
25482545 result : WorkerStatusMap = {}
2549- with self ._db .read_snapshot () as snapshot :
2550- rows = snapshot .raw (
2551- "SELECT worker_id FROM workers WHERE active = 1" ,
2552- decoders = {"worker_id" : WorkerId },
2553- )
2554- worker_ids = {row .worker_id for row in rows }
2546+ worker_ids = {wid for wid , l in self ._health .all ().items () if l .active }
25552547 running_by_worker = running_tasks_by_worker (self ._db , worker_ids )
25562548 for wid in worker_ids :
25572549 result [wid ] = WorkerStatus (
0 commit comments