Skip to content

Commit bbaef73

Browse files
rjpower, github-actions[bot], and claude
authored
[iris] Remove ORM query builder, replace with raw SQL (#4181)
Delete Table, Column, Predicate, SelectExpr, Order, Join, JoinedQuery and all 16 module-level table constants from db.py. Replace ~100 callsites across 8 production files and 12 test files with raw SQL + decode_rows/decode_one helpers. Every query is now visible as plain SQL at its callsite. Net -556 lines. Follows #4179 which fixed the immediate scheduling perf issue. --------- Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: Russell Power <rjpower@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 4820a5b commit bbaef73

19 files changed

Lines changed: 430 additions & 987 deletions

lib/iris/scripts/benchmark_db_queries.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@
3535
)
3636
from iris.cluster.controller.db import (
3737
ACTIVE_TASK_STATES,
38-
JOBS,
3938
ControllerDB,
4039
EndpointQuery,
40+
Job,
41+
decode_rows,
4142
healthy_active_workers_with_attributes,
4243
running_tasks_by_worker,
4344
tasks_for_job_with_attempts,
@@ -121,8 +122,11 @@ def benchmark_scheduling(db: ControllerDB, iterations: int) -> list[tuple[str, f
121122
)
122123

123124
def _reservation_jobs_old():
125+
placeholders = ",".join("?" for _ in reservable_states)
124126
with db.snapshot() as snapshot:
125-
all_jobs = snapshot.select(JOBS, where=JOBS.c.state.in_(list(reservable_states)))
127+
all_jobs = decode_rows(
128+
Job, snapshot.fetchall(f"SELECT * FROM jobs WHERE state IN ({placeholders})", reservable_states)
129+
)
126130
return [j for j in all_jobs if j.request.HasField("reservation")]
127131

128132
p50, p95 = bench("reservation_jobs (old: full scan)", _reservation_jobs_old, iterations=iterations)
@@ -145,8 +149,15 @@ def benchmark_dashboard(db: ControllerDB, iterations: int) -> list[tuple[str, fl
145149
results: list[tuple[str, float, float]] = []
146150

147151
def _bench_jobs_in_states(db):
152+
placeholders = ",".join("?" for _ in USER_JOB_STATES)
148153
with db.read_snapshot() as q:
149-
return q.select(JOBS, where=JOBS.c.state.in_(list(USER_JOB_STATES)) & (JOBS.c.depth == 1))
154+
return decode_rows(
155+
Job,
156+
q.fetchall(
157+
f"SELECT * FROM jobs WHERE state IN ({placeholders}) AND depth = 1",
158+
(*USER_JOB_STATES,),
159+
),
160+
)
150161

151162
p50, p95 = bench("jobs_in_states (top-level)", lambda: _bench_jobs_in_states(db), iterations=iterations)
152163
results.append(("jobs_in_states (top-level)", p50, p95))

lib/iris/src/iris/cluster/controller/actor_proxy.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from starlette.requests import Request
2121
from starlette.responses import JSONResponse, Response
2222

23-
from iris.cluster.controller.db import ControllerDB, EndpointQuery, endpoint_query_predicate
23+
from iris.cluster.controller.db import ControllerDB, Endpoint, EndpointQuery, decode_rows, endpoint_query_sql
2424

2525
logger = logging.getLogger(__name__)
2626

@@ -99,11 +99,9 @@ async def handle(self, request: Request) -> Response:
9999
def _resolve_endpoint(self, name: str) -> str | None:
100100
"""Resolve an endpoint name to an address via the controller DB."""
101101
query = EndpointQuery(exact_name=name)
102-
joins, where = endpoint_query_predicate(query)
103-
from iris.cluster.controller.service import ENDPOINTS
104-
102+
sql, params = endpoint_query_sql(query)
105103
with self._db.read_snapshot() as q:
106-
endpoints = q.select(ENDPOINTS, where=where, joins=joins)
104+
endpoints = decode_rows(Endpoint, q.fetchall(sql, tuple(params)))
107105
if not endpoints:
108106
return None
109107
return endpoints[0].address

lib/iris/src/iris/cluster/controller/auth.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import jwt
1919

20-
from iris.cluster.controller.db import API_KEYS, ApiKey, ControllerDB
20+
from iris.cluster.controller.db import ApiKey, ControllerDB, decode_one, decode_rows
2121
from iris.rpc import config_pb2
2222
from iris.rpc.auth import (
2323
GcpAccessTokenVerifier,
@@ -59,9 +59,10 @@ def create_api_key(
5959

6060
def lookup_api_key_by_hash(db: ControllerDB, key_hash: str) -> ApiKey | None:
6161
"""Find an API key by its SHA-256 hash."""
62-
table = dataclasses.replace(API_KEYS, sql_name=db.api_keys_table)
6362
with db.snapshot() as q:
64-
return q.one(table, where=table.c.key_hash == key_hash)
63+
return decode_one(
64+
ApiKey, q.fetchall(f"SELECT * FROM {db.api_keys_table} WHERE key_hash = ? LIMIT 1", (key_hash,))
65+
)
6566

6667

6768
def touch_api_key(db: ControllerDB, key_id: str, now: Timestamp) -> None:
@@ -84,11 +85,10 @@ def revoke_api_key(db: ControllerDB, key_id: str, now: Timestamp) -> bool:
8485

8586
def list_api_keys(db: ControllerDB, user_id: str | None = None) -> list[ApiKey]:
8687
"""List API keys, optionally filtered by user."""
87-
table = dataclasses.replace(API_KEYS, sql_name=db.api_keys_table)
8888
with db.snapshot() as q:
8989
if user_id:
90-
return q.select(table, where=table.c.user_id == user_id)
91-
return q.select(table)
90+
return decode_rows(ApiKey, q.fetchall(f"SELECT * FROM {db.api_keys_table} WHERE user_id = ?", (user_id,)))
91+
return decode_rows(ApiKey, q.fetchall(f"SELECT * FROM {db.api_keys_table}"))
9292

9393

9494
def revoke_login_keys_for_user(db: ControllerDB, user_id: str, now: Timestamp) -> list[str]:

lib/iris/src/iris/cluster/controller/autoscaler.py

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
soft_constraint_score,
4848
split_hard_soft,
4949
)
50-
from iris.cluster.controller.db import SCALING_GROUPS, SLICES, TRACKED_WORKERS, ControllerDB
50+
from iris.cluster.controller.db import ControllerDB, _decode_json_list, _decode_timestamp_ms
5151
from iris.cluster.types import WorkerStatusMap
5252
from iris.cluster.controller.scaling_group import (
5353
GroupAvailability,
@@ -1260,38 +1260,30 @@ def restore_from_db(self, db: ControllerDB, platform: WorkerInfraProvider) -> No
12601260
tracked workers. Call at startup before loops begin.
12611261
"""
12621262
with db.snapshot() as snapshot:
1263-
scaling_rows = snapshot.select(
1264-
SCALING_GROUPS,
1265-
columns=(
1266-
SCALING_GROUPS.c.name,
1267-
SCALING_GROUPS.c.consecutive_failures,
1268-
SCALING_GROUPS.c.backoff_until_ms,
1269-
SCALING_GROUPS.c.last_scale_up_ms,
1270-
SCALING_GROUPS.c.last_scale_down_ms,
1271-
SCALING_GROUPS.c.quota_exceeded_until_ms,
1272-
SCALING_GROUPS.c.quota_reason,
1273-
),
1263+
scaling_rows = snapshot.raw(
1264+
"SELECT name, consecutive_failures, backoff_until_ms, last_scale_up_ms, "
1265+
"last_scale_down_ms, quota_exceeded_until_ms, quota_reason "
1266+
"FROM scaling_groups",
1267+
decoders={
1268+
"consecutive_failures": int,
1269+
"backoff_until_ms": _decode_timestamp_ms,
1270+
"last_scale_up_ms": _decode_timestamp_ms,
1271+
"last_scale_down_ms": _decode_timestamp_ms,
1272+
"quota_exceeded_until_ms": _decode_timestamp_ms,
1273+
},
12741274
)
1275-
slice_rows = snapshot.select(
1276-
SLICES,
1277-
columns=(
1278-
SLICES.c.slice_id,
1279-
SLICES.c.scale_group,
1280-
SLICES.c.lifecycle,
1281-
SLICES.c.worker_ids,
1282-
SLICES.c.created_at_ms,
1283-
SLICES.c.last_active_ms,
1284-
SLICES.c.error_message,
1285-
),
1275+
slice_rows = snapshot.raw(
1276+
"SELECT slice_id, scale_group, lifecycle, worker_ids, "
1277+
"created_at_ms, last_active_ms, error_message "
1278+
"FROM slices",
1279+
decoders={
1280+
"worker_ids": _decode_json_list,
1281+
"created_at_ms": _decode_timestamp_ms,
1282+
"last_active_ms": _decode_timestamp_ms,
1283+
},
12861284
)
1287-
tracked_rows = snapshot.select(
1288-
TRACKED_WORKERS,
1289-
columns=(
1290-
TRACKED_WORKERS.c.worker_id,
1291-
TRACKED_WORKERS.c.slice_id,
1292-
TRACKED_WORKERS.c.scale_group,
1293-
TRACKED_WORKERS.c.internal_address,
1294-
),
1285+
tracked_rows = snapshot.raw(
1286+
"SELECT worker_id, slice_id, scale_group, internal_address FROM tracked_workers",
12951287
)
12961288

12971289
# Build GroupSnapshot objects from DB rows

lib/iris/src/iris/cluster/controller/checkpoint.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import fsspec.core
3434
import zstandard
3535

36-
from iris.cluster.controller.db import JOBS, TASKS, WORKERS, ControllerDB
36+
from iris.cluster.controller.db import ControllerDB
3737
from iris.time_utils import Duration, Timestamp
3838

3939
logger = logging.getLogger(__name__)
@@ -158,9 +158,9 @@ def write_checkpoint(
158158
tmp_zst2.unlink(missing_ok=True)
159159

160160
with db.snapshot() as snapshot:
161-
job_count = snapshot.count(JOBS)
162-
task_count = snapshot.count(TASKS)
163-
worker_count = snapshot.count(WORKERS)
161+
job_count = snapshot.fetchone("SELECT COUNT(*) FROM jobs")[0] # type: ignore[index]
162+
task_count = snapshot.fetchone("SELECT COUNT(*) FROM tasks")[0] # type: ignore[index]
163+
worker_count = snapshot.fetchone("SELECT COUNT(*) FROM workers")[0] # type: ignore[index]
164164
result = CheckpointResult(
165165
created_at=created_at,
166166
job_count=job_count,

lib/iris/src/iris/cluster/controller/controller.py

Lines changed: 51 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,14 @@
3232
write_checkpoint,
3333
)
3434
from iris.cluster.controller.db import (
35-
ATTEMPTS,
36-
JOBS,
37-
RESERVATION_CLAIMS,
38-
TASKS,
39-
WORKERS,
35+
Attempt,
4036
ControllerDB,
41-
Join,
4237
Job,
4338
Task,
4439
Worker,
4540
_decode_row,
4641
_tasks_with_attempts,
42+
decode_rows,
4743
healthy_active_workers_with_attributes,
4844
insert_task_profile,
4945
running_tasks_by_worker,
@@ -260,13 +256,9 @@ def compute_demand_entries(
260256
def _read_reservation_claims(db: ControllerDB) -> dict[WorkerId, ReservationClaim]:
261257
"""Read reservation claims from the canonical DB table."""
262258
with db.snapshot() as snapshot:
263-
rows = snapshot.select(
264-
RESERVATION_CLAIMS,
265-
columns=(
266-
RESERVATION_CLAIMS.c.worker_id,
267-
RESERVATION_CLAIMS.c.job_id,
268-
RESERVATION_CLAIMS.c.entry_idx,
269-
),
259+
rows = snapshot.raw(
260+
"SELECT rc.worker_id, rc.job_id, rc.entry_idx FROM reservation_claims rc",
261+
decoders={"worker_id": WorkerId},
270262
)
271263
return {
272264
row.worker_id: ReservationClaim(
@@ -280,8 +272,12 @@ def _read_reservation_claims(db: ControllerDB) -> dict[WorkerId, ReservationClai
280272
def _jobs_by_id(queries: ControllerDB, job_ids: set[JobName]) -> dict[JobName, Job]:
281273
if not job_ids:
282274
return {}
275+
wires = [job_id.to_wire() for job_id in job_ids]
276+
placeholders = ",".join("?" for _ in wires)
283277
with queries.snapshot() as snapshot:
284-
jobs = snapshot.select(JOBS, where=JOBS.c.job_id.in_([job_id.to_wire() for job_id in job_ids]))
278+
jobs = decode_rows(
279+
Job, snapshot.fetchall(f"SELECT * FROM jobs j WHERE j.job_id IN ({placeholders})", tuple(wires))
280+
)
285281
return {job.job_id: job for job in jobs}
286282

287283

@@ -302,16 +298,14 @@ def _jobs_with_reservations(queries: ControllerDB, states: tuple[int, ...]) -> l
302298

303299
def _schedulable_tasks(queries: ControllerDB) -> list[Task]:
304300
# Only PENDING tasks can pass can_be_scheduled(); no need to fetch ASSIGNED/BUILDING/RUNNING.
305-
SCHEDULABLE_STATES = (cluster_pb2.TASK_STATE_PENDING,)
306301
with queries.snapshot() as snapshot:
307-
tasks = snapshot.select(
308-
TASKS,
309-
where=TASKS.c.state.in_(list(SCHEDULABLE_STATES)),
310-
order_by=(
311-
TASKS.c.priority_neg_depth.asc(),
312-
TASKS.c.priority_root_submitted_ms.asc(),
313-
TASKS.c.submitted_at_ms.asc(),
314-
TASKS.c.task_id.asc(),
302+
tasks = decode_rows(
303+
Task,
304+
snapshot.fetchall(
305+
"SELECT * FROM tasks t WHERE t.state = ? "
306+
"ORDER BY t.priority_neg_depth ASC, t.priority_root_submitted_ms ASC, "
307+
"t.submitted_at_ms ASC, t.task_id ASC",
308+
(cluster_pb2.TASK_STATE_PENDING,),
315309
),
316310
)
317311
return [task for task in tasks if task.can_be_scheduled()]
@@ -321,16 +315,22 @@ def _tasks_by_ids_with_attempts(queries: ControllerDB, task_ids: set[JobName]) -
321315
if not task_ids:
322316
return {}
323317
task_wires = [task_id.to_wire() for task_id in task_ids]
318+
placeholders = ",".join("?" for _ in task_wires)
324319
with queries.snapshot() as snapshot:
325-
tasks = snapshot.select(
326-
TASKS,
327-
where=TASKS.c.task_id.in_(task_wires),
328-
order_by=(TASKS.c.task_id.asc(),),
320+
tasks = decode_rows(
321+
Task,
322+
snapshot.fetchall(
323+
f"SELECT * FROM tasks t WHERE t.task_id IN ({placeholders}) ORDER BY t.task_id ASC",
324+
tuple(task_wires),
325+
),
329326
)
330-
attempts = snapshot.select(
331-
ATTEMPTS,
332-
where=ATTEMPTS.c.task_id.in_(task_wires),
333-
order_by=(ATTEMPTS.c.task_id.asc(), ATTEMPTS.c.attempt_id.asc()),
327+
attempts = decode_rows(
328+
Attempt,
329+
snapshot.fetchall(
330+
f"SELECT * FROM task_attempts a WHERE a.task_id IN ({placeholders}) "
331+
"ORDER BY a.task_id ASC, a.attempt_id ASC",
332+
tuple(task_wires),
333+
),
334334
)
335335
return {task.task_id: task for task in _tasks_with_attempts(tasks, attempts)}
336336

@@ -362,25 +362,27 @@ def _building_counts(queries: ControllerDB, workers: list[Worker]) -> dict[Worke
362362
def _workers_by_id(queries: ControllerDB, worker_ids: set[WorkerId]) -> dict[WorkerId, Worker]:
363363
if not worker_ids:
364364
return {}
365+
wires = [str(wid) for wid in worker_ids]
366+
placeholders = ",".join("?" for _ in wires)
365367
with queries.snapshot() as snapshot:
366-
workers = snapshot.select(
367-
WORKERS,
368-
where=WORKERS.c.worker_id.in_([str(worker_id) for worker_id in worker_ids]),
368+
workers = decode_rows(
369+
Worker, snapshot.fetchall(f"SELECT * FROM workers w WHERE w.worker_id IN ({placeholders})", tuple(wires))
369370
)
370371
return {worker.worker_id: worker for worker in workers}
371372

372373

373374
def _task_worker_mapping(queries: ControllerDB, task_ids: set[JobName]) -> dict[JobName, WorkerId]:
374375
if not task_ids:
375376
return {}
377+
task_wires = [task_id.to_wire() for task_id in task_ids]
378+
placeholders = ",".join("?" for _ in task_wires)
376379
with queries.snapshot() as snapshot:
377-
rows = snapshot.select(
378-
TASKS,
379-
columns=(TASKS.c.task_id, ATTEMPTS.c.worker_id),
380-
joins=(Join(table=ATTEMPTS, on=TASKS.c.task_id == ATTEMPTS.c.task_id),),
381-
where=TASKS.c.task_id.in_([task_id.to_wire() for task_id in task_ids])
382-
& (TASKS.c.current_attempt_id == ATTEMPTS.c.attempt_id)
383-
& ATTEMPTS.c.worker_id.not_null(),
380+
rows = snapshot.raw(
381+
f"SELECT t.task_id, a.worker_id FROM tasks t "
382+
f"JOIN task_attempts a ON t.task_id = a.task_id AND t.current_attempt_id = a.attempt_id "
383+
f"WHERE t.task_id IN ({placeholders}) AND a.worker_id IS NOT NULL",
384+
tuple(task_wires),
385+
decoders={"task_id": JobName.from_wire, "worker_id": WorkerId},
384386
)
385387
return {row.task_id: row.worker_id for row in rows}
386388

@@ -1178,12 +1180,8 @@ def _cleanup_stale_claims(self, claims: dict[WorkerId, ReservationClaim] | None
11781180
persisted = True
11791181
with self._db.snapshot() as snapshot:
11801182
active_worker_ids = {
1181-
row.worker_id
1182-
for row in snapshot.select(
1183-
WORKERS,
1184-
columns=(WORKERS.c.worker_id,),
1185-
where=WORKERS.c.active == 1,
1186-
)
1183+
WorkerId(str(row[0]))
1184+
for row in snapshot.fetchall("SELECT w.worker_id FROM workers w WHERE w.active = 1")
11871185
}
11881186
claimed_job_ids = {JobName.from_wire(claim.job_id) for claim in claims.values()}
11891187
claimed_jobs = list(_jobs_by_id(self._db, claimed_job_ids).values()) if claimed_job_ids else []
@@ -1224,7 +1222,6 @@ def _claim_workers_for_reservations(self, claims: dict[WorkerId, ReservationClai
12241222
)
12251223
reservation_jobs = _jobs_with_reservations(self._db, reservable_states)
12261224
for job in reservation_jobs:
1227-
12281225
job_wire = job.job_id.to_wire()
12291226
for idx, res_entry in enumerate(job.request.reservation.entries):
12301227
if (job_wire, idx) in claimed_entries:
@@ -1610,7 +1607,11 @@ def _sync_all_execution_units(self) -> None:
16101607
if _HEALTH_SUMMARY_INTERVAL.should_run():
16111608
workers = healthy_active_workers_with_attributes(self._db)
16121609
with self._db.snapshot() as snap:
1613-
active = snap.count(JOBS, where=JOBS.c.state == cluster_pb2.JOB_STATE_RUNNING)
1610+
active = snap.fetchone(
1611+
"SELECT COUNT(*) FROM jobs j WHERE j.state = ?", (cluster_pb2.JOB_STATE_RUNNING,)
1612+
)[
1613+
0
1614+
] # type: ignore[index]
16141615
pending = len(_schedulable_tasks(self._db))
16151616
logger.info(
16161617
"Controller status: %d workers (%d failed), %d active jobs, %d pending tasks",
@@ -1647,7 +1648,7 @@ def _build_worker_status_map(self) -> WorkerStatusMap:
16471648
"""Build a map of worker_id to worker status for autoscaler idle tracking."""
16481649
result: WorkerStatusMap = {}
16491650
with self._db.snapshot() as snapshot:
1650-
workers = snapshot.select(WORKERS, where=WORKERS.c.active == 1)
1651+
workers = decode_rows(Worker, snapshot.fetchall("SELECT * FROM workers w WHERE w.active = 1"))
16511652
running_by_worker = running_tasks_by_worker(self._db, {worker.worker_id for worker in workers})
16521653
for worker in workers:
16531654
result[worker.worker_id] = WorkerStatus(

0 commit comments

Comments (0)