Skip to content

Commit 1e301c8

Browse files
authored
[iris] Fix scheduling loop: filter reservation jobs at SQL level (#4179)
Add has_reservation column to jobs table so _claim_workers_for_reservations can filter at SQL level instead of deserializing all active job protobufs. On production DB (1,340 active jobs, 218MB of request_proto): p50 drops from 156ms to 0.0ms, p95 from 2,669ms to 0.1ms. Adds before/after comparison to the benchmark script.
1 parent 9499121 commit 1e301c8

5 files changed

Lines changed: 81 additions & 7 deletions

File tree

lib/iris/scripts/benchmark_db_queries.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from iris.cluster.controller.controller import (
3131
_building_counts,
3232
_jobs_by_id,
33+
_jobs_with_reservations,
3334
_schedulable_tasks,
3435
)
3536
from iris.cluster.controller.db import (
@@ -112,6 +113,30 @@ def benchmark_scheduling(db: ControllerDB, iterations: int) -> list[tuple[str, f
112113
else:
113114
print(" _jobs_by_id (skipped, no pending jobs)")
114115

116+
# Reservation queries: compare old (fetch all + filter) vs new (SQL filter)
117+
reservable_states = (
118+
cluster_pb2.JOB_STATE_PENDING,
119+
cluster_pb2.JOB_STATE_BUILDING,
120+
cluster_pb2.JOB_STATE_RUNNING,
121+
)
122+
123+
def _reservation_jobs_old():
124+
with db.snapshot() as snapshot:
125+
all_jobs = snapshot.select(JOBS, where=JOBS.c.state.in_(list(reservable_states)))
126+
return [j for j in all_jobs if j.request.HasField("reservation")]
127+
128+
p50, p95 = bench("reservation_jobs (old: full scan)", _reservation_jobs_old, iterations=iterations)
129+
results.append(("reservation_jobs (old: full scan)", p50, p95))
130+
print_result("reservation_jobs (old: full scan)", p50, p95)
131+
132+
p50, p95 = bench(
133+
"reservation_jobs (new: has_reservation)",
134+
lambda: _jobs_with_reservations(db, reservable_states),
135+
iterations=iterations,
136+
)
137+
results.append(("reservation_jobs (new: has_reservation)", p50, p95))
138+
print_result("reservation_jobs (new: has_reservation)", p50, p95)
139+
115140
return results
116141

117142

lib/iris/src/iris/cluster/controller/controller.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
Job,
4343
Task,
4444
Worker,
45+
_decode_row,
4546
_tasks_with_attempts,
4647
healthy_active_workers_with_attributes,
4748
insert_task_profile,
@@ -284,6 +285,21 @@ def _jobs_by_id(queries: ControllerDB, job_ids: set[JobName]) -> dict[JobName, J
284285
return {job.job_id: job for job in jobs}
285286

286287

288+
def _jobs_with_reservations(queries: ControllerDB, states: tuple[int, ...]) -> list[Job]:
    """Return the jobs in ``states`` whose ``has_reservation`` column is 1.

    The reservation filter is part of the SQL WHERE clause, so only matching
    rows are fetched and decoded into ``Job`` objects.
    """
    state_params = list(states)
    in_clause = ",".join(["?"] * len(state_params))
    sql = f"SELECT * FROM jobs WHERE state IN ({in_clause}) AND has_reservation = 1"
    with queries.snapshot() as snap:
        matching_rows = snap._fetchall(sql, state_params)
    reservation_jobs = []
    for record in matching_rows:
        reservation_jobs.append(_decode_row(Job, record))
    return reservation_jobs
301+
302+
287303
def _schedulable_tasks(queries: ControllerDB) -> list[Task]:
288304
# Only PENDING tasks can pass can_be_scheduled(); no need to fetch ASSIGNED/BUILDING/RUNNING.
289305
SCHEDULABLE_STATES = (cluster_pb2.TASK_STATE_PENDING,)
@@ -1181,11 +1197,8 @@ def _claim_workers_for_reservations(self, claims: dict[WorkerId, ReservationClai
11811197
cluster_pb2.JOB_STATE_BUILDING,
11821198
cluster_pb2.JOB_STATE_RUNNING,
11831199
)
1184-
with self._db.snapshot() as snapshot:
1185-
reservable_jobs = snapshot.select(JOBS, where=JOBS.c.state.in_(list(reservable_states)))
1186-
for job in reservable_jobs:
1187-
if not job.request.HasField("reservation"):
1188-
continue
1200+
reservation_jobs = _jobs_with_reservations(self._db, reservable_states)
1201+
for job in reservation_jobs:
11891202

11901203
job_wire = job.job_id.to_wire()
11911204
for idx, res_entry in enumerate(job.request.reservation.entries):

lib/iris/src/iris/cluster/controller/db.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,7 @@ class Job:
579579
exit_code: int | None = db_field("exit_code", _nullable(_decode_int))
580580
num_tasks: int = db_field("num_tasks", _decode_int)
581581
is_reservation_holder: bool = db_field("is_reservation_holder", _decode_bool_int)
582+
has_reservation: bool = db_field("has_reservation", _decode_bool_int, default=False)
582583
name: str = db_field("name", _decode_str, default="")
583584
depth: int = db_field("depth", _decode_int, default=0)
584585

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import sqlite3
5+
6+
7+
def _has_column(conn: sqlite3.Connection, table: str, column: str) -> bool:
8+
columns = {row[1] for row in conn.execute(f"PRAGMA table_info({table})").fetchall()}
9+
return column in columns
10+
11+
12+
def migrate(conn: sqlite3.Connection) -> None:
    """Add and backfill the ``has_reservation`` column on the ``jobs`` table.

    Adds the column (``INTEGER NOT NULL DEFAULT 0``) when it is missing,
    creates the partial index ``idx_jobs_has_reservation`` if it does not
    exist, then sets ``has_reservation = 1`` on every job whose stored
    ``request_proto`` has a ``reservation`` field with at least one entry.

    The backfill is skipped when any row already has ``has_reservation = 1``.

    Args:
        conn: Open SQLite connection holding the ``jobs`` table.
    """
    if not _has_column(conn, "jobs", "has_reservation"):
        conn.execute("ALTER TABLE jobs ADD COLUMN has_reservation INTEGER NOT NULL DEFAULT 0")
    conn.execute(
        "CREATE INDEX IF NOT EXISTS idx_jobs_has_reservation "
        "ON jobs(has_reservation, state) WHERE has_reservation = 1"
    )

    # Backfill: scan all rows once and mark only those with reservations.
    # Skip if any rows already have has_reservation=1 (migration already ran).
    already_backfilled = conn.execute("SELECT 1 FROM jobs WHERE has_reservation = 1 LIMIT 1").fetchone()
    if already_backfilled:
        return

    from iris.rpc import cluster_pb2

    # Stream rows from the cursor instead of fetchall() so the request_proto
    # blobs are not all held in memory together; only matching ids are kept.
    cursor = conn.execute("SELECT job_id, request_proto FROM jobs WHERE request_proto IS NOT NULL")
    reserved_job_ids = []
    for row in cursor:
        proto = cluster_pb2.Controller.LaunchJobRequest()
        proto.ParseFromString(row[1])
        if proto.HasField("reservation") and proto.reservation.entries:
            reserved_job_ids.append((row[0],))

    # Apply the updates after the scan has finished, in a single batch.
    if reserved_job_ids:
        conn.executemany("UPDATE jobs SET has_reservation = 1 WHERE job_id = ?", reserved_job_ids)

lib/iris/src/iris/cluster/controller/transitions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -596,12 +596,13 @@ def submit_job(
596596

597597
state = cluster_pb2.JOB_STATE_PENDING if validation_error is None else cluster_pb2.JOB_STATE_FAILED
598598
finished_ms = None if validation_error is None else effective_submission_ms
599+
has_reservation = 1 if request.HasField("reservation") and request.reservation.entries else 0
599600
cur.execute(
600601
"INSERT INTO jobs("
601602
"job_id, user_id, parent_job_id, root_job_id, depth, request_proto, state, submitted_at_ms, "
602603
"root_submitted_at_ms, started_at_ms, finished_at_ms, scheduling_deadline_epoch_ms, "
603-
"error, exit_code, num_tasks, is_reservation_holder, name"
604-
") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, NULL, ?, ?, ?, NULL, ?, 0, ?)",
604+
"error, exit_code, num_tasks, is_reservation_holder, has_reservation, name"
605+
") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, NULL, ?, ?, ?, NULL, ?, 0, ?, ?)",
605606
(
606607
job_id.to_wire(),
607608
job_id.user,
@@ -616,6 +617,7 @@ def submit_job(
616617
deadline_epoch_ms,
617618
validation_error,
618619
replicas,
620+
has_reservation,
619621
request.name,
620622
),
621623
)

0 commit comments

Comments
 (0)