3232 write_checkpoint ,
3333)
3434from iris .cluster .controller .db import (
35- ATTEMPTS ,
36- JOBS ,
37- RESERVATION_CLAIMS ,
38- TASKS ,
39- WORKERS ,
35+ Attempt ,
4036 ControllerDB ,
41- Join ,
4237 Job ,
4338 Task ,
4439 Worker ,
4540 _decode_row ,
4641 _tasks_with_attempts ,
42+ decode_rows ,
4743 healthy_active_workers_with_attributes ,
4844 insert_task_profile ,
4945 running_tasks_by_worker ,
@@ -260,13 +256,9 @@ def compute_demand_entries(
260256def _read_reservation_claims (db : ControllerDB ) -> dict [WorkerId , ReservationClaim ]:
261257 """Read reservation claims from the canonical DB table."""
262258 with db .snapshot () as snapshot :
263- rows = snapshot .select (
264- RESERVATION_CLAIMS ,
265- columns = (
266- RESERVATION_CLAIMS .c .worker_id ,
267- RESERVATION_CLAIMS .c .job_id ,
268- RESERVATION_CLAIMS .c .entry_idx ,
269- ),
259+ rows = snapshot .raw (
260+ "SELECT rc.worker_id, rc.job_id, rc.entry_idx FROM reservation_claims rc" ,
261+ decoders = {"worker_id" : WorkerId },
270262 )
271263 return {
272264 row .worker_id : ReservationClaim (
@@ -280,8 +272,12 @@ def _read_reservation_claims(db: ControllerDB) -> dict[WorkerId, ReservationClai
280272def _jobs_by_id (queries : ControllerDB , job_ids : set [JobName ]) -> dict [JobName , Job ]:
281273 if not job_ids :
282274 return {}
275+ wires = [job_id .to_wire () for job_id in job_ids ]
276+ placeholders = "," .join ("?" for _ in wires )
283277 with queries .snapshot () as snapshot :
284- jobs = snapshot .select (JOBS , where = JOBS .c .job_id .in_ ([job_id .to_wire () for job_id in job_ids ]))
278+ jobs = decode_rows (
279+ Job , snapshot .fetchall (f"SELECT * FROM jobs j WHERE j.job_id IN ({ placeholders } )" , tuple (wires ))
280+ )
285281 return {job .job_id : job for job in jobs }
286282
287283
@@ -302,16 +298,14 @@ def _jobs_with_reservations(queries: ControllerDB, states: tuple[int, ...]) -> l
302298
303299def _schedulable_tasks (queries : ControllerDB ) -> list [Task ]:
304300 # Only PENDING tasks can pass can_be_scheduled(); no need to fetch ASSIGNED/BUILDING/RUNNING.
305- SCHEDULABLE_STATES = (cluster_pb2 .TASK_STATE_PENDING ,)
306301 with queries .snapshot () as snapshot :
307- tasks = snapshot .select (
308- TASKS ,
309- where = TASKS .c .state .in_ (list (SCHEDULABLE_STATES )),
310- order_by = (
311- TASKS .c .priority_neg_depth .asc (),
312- TASKS .c .priority_root_submitted_ms .asc (),
313- TASKS .c .submitted_at_ms .asc (),
314- TASKS .c .task_id .asc (),
302+ tasks = decode_rows (
303+ Task ,
304+ snapshot .fetchall (
305+ "SELECT * FROM tasks t WHERE t.state = ? "
306+ "ORDER BY t.priority_neg_depth ASC, t.priority_root_submitted_ms ASC, "
307+ "t.submitted_at_ms ASC, t.task_id ASC" ,
308+ (cluster_pb2 .TASK_STATE_PENDING ,),
315309 ),
316310 )
317311 return [task for task in tasks if task .can_be_scheduled ()]
@@ -321,16 +315,22 @@ def _tasks_by_ids_with_attempts(queries: ControllerDB, task_ids: set[JobName]) -
321315 if not task_ids :
322316 return {}
323317 task_wires = [task_id .to_wire () for task_id in task_ids ]
318+ placeholders = "," .join ("?" for _ in task_wires )
324319 with queries .snapshot () as snapshot :
325- tasks = snapshot .select (
326- TASKS ,
327- where = TASKS .c .task_id .in_ (task_wires ),
328- order_by = (TASKS .c .task_id .asc (),),
320+ tasks = decode_rows (
321+ Task ,
322+ snapshot .fetchall (
323+ f"SELECT * FROM tasks t WHERE t.task_id IN ({ placeholders } ) ORDER BY t.task_id ASC" ,
324+ tuple (task_wires ),
325+ ),
329326 )
330- attempts = snapshot .select (
331- ATTEMPTS ,
332- where = ATTEMPTS .c .task_id .in_ (task_wires ),
333- order_by = (ATTEMPTS .c .task_id .asc (), ATTEMPTS .c .attempt_id .asc ()),
327+ attempts = decode_rows (
328+ Attempt ,
329+ snapshot .fetchall (
330+ f"SELECT * FROM task_attempts a WHERE a.task_id IN ({ placeholders } ) "
331+ "ORDER BY a.task_id ASC, a.attempt_id ASC" ,
332+ tuple (task_wires ),
333+ ),
334334 )
335335 return {task .task_id : task for task in _tasks_with_attempts (tasks , attempts )}
336336
@@ -362,25 +362,27 @@ def _building_counts(queries: ControllerDB, workers: list[Worker]) -> dict[Worke
362362def _workers_by_id (queries : ControllerDB , worker_ids : set [WorkerId ]) -> dict [WorkerId , Worker ]:
363363 if not worker_ids :
364364 return {}
365+ wires = [str (wid ) for wid in worker_ids ]
366+ placeholders = "," .join ("?" for _ in wires )
365367 with queries .snapshot () as snapshot :
366- workers = snapshot .select (
367- WORKERS ,
368- where = WORKERS .c .worker_id .in_ ([str (worker_id ) for worker_id in worker_ids ]),
368+ workers = decode_rows (
369+ Worker , snapshot .fetchall (f"SELECT * FROM workers w WHERE w.worker_id IN ({ placeholders } )" , tuple (wires ))
369370 )
370371 return {worker .worker_id : worker for worker in workers }
371372
372373
373374def _task_worker_mapping (queries : ControllerDB , task_ids : set [JobName ]) -> dict [JobName , WorkerId ]:
374375 if not task_ids :
375376 return {}
377+ task_wires = [task_id .to_wire () for task_id in task_ids ]
378+ placeholders = "," .join ("?" for _ in task_wires )
376379 with queries .snapshot () as snapshot :
377- rows = snapshot .select (
378- TASKS ,
379- columns = (TASKS .c .task_id , ATTEMPTS .c .worker_id ),
380- joins = (Join (table = ATTEMPTS , on = TASKS .c .task_id == ATTEMPTS .c .task_id ),),
381- where = TASKS .c .task_id .in_ ([task_id .to_wire () for task_id in task_ids ])
382- & (TASKS .c .current_attempt_id == ATTEMPTS .c .attempt_id )
383- & ATTEMPTS .c .worker_id .not_null (),
380+ rows = snapshot .raw (
381+ f"SELECT t.task_id, a.worker_id FROM tasks t "
382+ f"JOIN task_attempts a ON t.task_id = a.task_id AND t.current_attempt_id = a.attempt_id "
383+ f"WHERE t.task_id IN ({ placeholders } ) AND a.worker_id IS NOT NULL" ,
384+ tuple (task_wires ),
385+ decoders = {"task_id" : JobName .from_wire , "worker_id" : WorkerId },
384386 )
385387 return {row .task_id : row .worker_id for row in rows }
386388
@@ -1178,12 +1180,8 @@ def _cleanup_stale_claims(self, claims: dict[WorkerId, ReservationClaim] | None
11781180 persisted = True
11791181 with self ._db .snapshot () as snapshot :
11801182 active_worker_ids = {
1181- row .worker_id
1182- for row in snapshot .select (
1183- WORKERS ,
1184- columns = (WORKERS .c .worker_id ,),
1185- where = WORKERS .c .active == 1 ,
1186- )
1183+ WorkerId (str (row [0 ]))
1184+ for row in snapshot .fetchall ("SELECT w.worker_id FROM workers w WHERE w.active = 1" )
11871185 }
11881186 claimed_job_ids = {JobName .from_wire (claim .job_id ) for claim in claims .values ()}
11891187 claimed_jobs = list (_jobs_by_id (self ._db , claimed_job_ids ).values ()) if claimed_job_ids else []
@@ -1224,7 +1222,6 @@ def _claim_workers_for_reservations(self, claims: dict[WorkerId, ReservationClai
12241222 )
12251223 reservation_jobs = _jobs_with_reservations (self ._db , reservable_states )
12261224 for job in reservation_jobs :
1227-
12281225 job_wire = job .job_id .to_wire ()
12291226 for idx , res_entry in enumerate (job .request .reservation .entries ):
12301227 if (job_wire , idx ) in claimed_entries :
@@ -1610,7 +1607,11 @@ def _sync_all_execution_units(self) -> None:
16101607 if _HEALTH_SUMMARY_INTERVAL .should_run ():
16111608 workers = healthy_active_workers_with_attributes (self ._db )
16121609 with self ._db .snapshot () as snap :
1613- active = snap .count (JOBS , where = JOBS .c .state == cluster_pb2 .JOB_STATE_RUNNING )
1610+ active = snap .fetchone (
1611+ "SELECT COUNT(*) FROM jobs j WHERE j.state = ?" , (cluster_pb2 .JOB_STATE_RUNNING ,)
1612+ )[
1613+ 0
1614+ ] # type: ignore[index]
16141615 pending = len (_schedulable_tasks (self ._db ))
16151616 logger .info (
16161617 "Controller status: %d workers (%d failed), %d active jobs, %d pending tasks" ,
@@ -1647,7 +1648,7 @@ def _build_worker_status_map(self) -> WorkerStatusMap:
16471648 """Build a map of worker_id to worker status for autoscaler idle tracking."""
16481649 result : WorkerStatusMap = {}
16491650 with self ._db .snapshot () as snapshot :
1650- workers = snapshot .select ( WORKERS , where = WORKERS . c . active == 1 )
1651+ workers = decode_rows ( Worker , snapshot .fetchall ( "SELECT * FROM workers w WHERE w. active = 1" ) )
16511652 running_by_worker = running_tasks_by_worker (self ._db , {worker .worker_id for worker in workers })
16521653 for worker in workers :
16531654 result [worker .worker_id ] = WorkerStatus (
0 commit comments