Skip to content

Commit 641d625

Browse files
authored
perf(iris): reduce drain_dispatch_all lock hold time from 80ms to <5ms (#4222)
## Summary

- `drain_dispatch_all`: move the running-tasks 3-way JOIN out of the write lock into a `read_snapshot()`, and drop the `JOIN jobs` (filter `is_reservation_holder` in Python instead). The write lock now only covers the dispatch_queue SELECT + DELETE.
- `prune_old_data`: add a read-snapshot pre-check to skip the write lock when nothing needs pruning.
- `db.py`: add the missing `decode_task` method to `ControllerDB`.

Benchmark on a production checkpoint (4208 jobs, 148K tasks, 225 workers):

- `drain_dispatch_all` lock hold: **86ms → <5ms**
- `prune_old_data` (0 deletions): **427ms → <1ms**

Fixes #4220
1 parent 58149ad commit 641d625

File tree

4 files changed

+193
-65
lines changed

4 files changed

+193
-65
lines changed

lib/iris/scripts/benchmark_db_queries.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -373,15 +373,14 @@ def benchmark_heartbeat(db: ControllerDB, iterations: int) -> list[tuple[str, fl
373373
sample_worker_id = str(workers[0].worker_id)
374374
active_states = tuple(ACTIVE_TASK_STATES)
375375

376-
# Single-worker running tasks query (simulates drain_dispatch inner query)
376+
# Single-worker running tasks query (simulates drain_dispatch inner query, 2-way JOIN)
377377
def _single_worker_running_tasks():
378378
with db.read_snapshot() as q:
379379
q.raw(
380-
"SELECT t.task_id, t.current_attempt_id "
380+
"SELECT t.task_id, t.current_attempt_id, t.job_id "
381381
"FROM tasks t "
382382
"JOIN task_attempts ta ON t.task_id = ta.task_id AND t.current_attempt_id = ta.attempt_id "
383-
"JOIN jobs j ON j.job_id = t.job_id "
384-
"WHERE ta.worker_id = ? AND t.state IN (?, ?, ?) AND j.is_reservation_holder = 0 "
383+
"WHERE ta.worker_id = ? AND t.state IN (?, ?, ?) "
385384
"ORDER BY t.task_id ASC",
386385
(sample_worker_id, *active_states),
387386
)
@@ -390,16 +389,15 @@ def _single_worker_running_tasks():
390389
results.append(("drain_dispatch running_tasks (1 worker)", p50, p95))
391390
print_result("drain_dispatch running_tasks (1 worker)", p50, p95)
392391

393-
# Full loop: running tasks for ALL workers (simulates phase 1)
392+
# Full loop: running tasks for ALL workers (simulates phase 1, 2-way JOIN)
394393
def _all_workers_running_tasks():
395394
for w in workers:
396395
with db.read_snapshot() as q:
397396
q.raw(
398-
"SELECT t.task_id, t.current_attempt_id "
397+
"SELECT t.task_id, t.current_attempt_id, t.job_id "
399398
"FROM tasks t "
400399
"JOIN task_attempts ta ON t.task_id = ta.task_id AND t.current_attempt_id = ta.attempt_id "
401-
"JOIN jobs j ON j.job_id = t.job_id "
402-
"WHERE ta.worker_id = ? AND t.state IN (?, ?, ?) AND j.is_reservation_holder = 0 "
400+
"WHERE ta.worker_id = ? AND t.state IN (?, ?, ?) "
403401
"ORDER BY t.task_id ASC",
404402
(str(w.worker_id), *active_states),
405403
)

lib/iris/src/iris/cluster/controller/db.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,10 @@ def read_snapshot(self) -> Iterator[QuerySnapshot]:
706706
logging.getLogger(__name__).warning("read_snapshot rollback failed", exc_info=True)
707707
self._read_pool.put(conn)
708708

709+
@staticmethod
710+
def decode_task(row: sqlite3.Row) -> Task:
711+
return _decode_row(Task, row)
712+
709713
def apply_migrations(self) -> None:
710714
"""Apply pending migrations from the migrations/ directory.
711715

lib/iris/src/iris/cluster/controller/transitions.py

Lines changed: 115 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1629,15 +1629,25 @@ def drain_dispatch(self, worker_id: WorkerId) -> DispatchBatch | None:
16291629
).fetchall()
16301630
if dispatch_rows:
16311631
cur.execute("DELETE FROM dispatch_queue WHERE worker_id = ?", (str(worker_id),))
1632-
running_rows = cur.execute(
1633-
"SELECT t.task_id, t.current_attempt_id "
1632+
running_rows_raw = cur.execute(
1633+
"SELECT t.task_id, t.current_attempt_id, t.job_id "
16341634
"FROM tasks t "
16351635
"JOIN task_attempts ta ON t.task_id = ta.task_id AND t.current_attempt_id = ta.attempt_id "
1636-
"JOIN jobs j ON j.job_id = t.job_id "
1637-
"WHERE ta.worker_id = ? AND t.state IN (?, ?, ?) AND j.is_reservation_holder = 0 "
1636+
"WHERE ta.worker_id = ? AND t.state IN (?, ?, ?) "
16381637
"ORDER BY t.task_id ASC",
16391638
(str(worker_id), *ACTIVE_TASK_STATES),
16401639
).fetchall()
1640+
running_job_ids = {str(row["job_id"]) for row in running_rows_raw}
1641+
if running_job_ids:
1642+
holder_placeholders = ",".join("?" for _ in running_job_ids)
1643+
holder_rows = cur.execute(
1644+
f"SELECT job_id FROM jobs WHERE job_id IN ({holder_placeholders}) AND is_reservation_holder = 1",
1645+
tuple(running_job_ids),
1646+
).fetchall()
1647+
holder_ids = {str(r["job_id"]) for r in holder_rows}
1648+
else:
1649+
holder_ids = set()
1650+
running_rows = [r for r in running_rows_raw if str(r["job_id"]) not in holder_ids]
16411651
tasks_to_run: list[cluster_pb2.Worker.RunTaskRequest] = []
16421652
tasks_to_kill: list[str] = []
16431653
for row in dispatch_rows:
@@ -1662,16 +1672,47 @@ def drain_dispatch(self, worker_id: WorkerId) -> DispatchBatch | None:
16621672
)
16631673

16641674
def drain_dispatch_all(self) -> list[DispatchBatch]:
1665-
"""Drain buffered dispatches and snapshot running tasks for all healthy active workers in one transaction."""
1666-
with self._db.transaction() as cur:
1667-
worker_rows = cur.execute(
1675+
"""Drain buffered dispatches and snapshot running tasks for all healthy active workers.
1676+
1677+
Reads (workers, running tasks, reservation filter) use a read snapshot
1678+
to avoid holding the write lock. The write lock is only held for the
1679+
dispatch_queue SELECT + DELETE.
1680+
"""
1681+
# -- Phase 1: read-only queries (no write lock) --
1682+
with self._db.read_snapshot() as snap:
1683+
worker_rows = snap.fetchall(
16681684
"SELECT worker_id, address, metadata_proto FROM workers WHERE active = 1 AND healthy = 1"
1669-
).fetchall()
1685+
)
16701686
if not worker_rows:
16711687
return []
16721688

16731689
worker_id_set = {str(row["worker_id"]) for row in worker_rows}
1674-
placeholders = ",".join("?" for _ in worker_id_set)
1690+
1691+
running_rows = snap.fetchall(
1692+
"SELECT ta.worker_id, t.task_id, t.current_attempt_id, t.job_id "
1693+
"FROM tasks t "
1694+
"JOIN task_attempts ta ON t.task_id = ta.task_id AND t.current_attempt_id = ta.attempt_id "
1695+
"WHERE t.state IN (?, ?, ?) "
1696+
"ORDER BY t.task_id ASC",
1697+
tuple(ACTIVE_TASK_STATES),
1698+
)
1699+
1700+
# Batch-check reservation holders instead of joining the jobs table
1701+
running_job_ids = {str(row["job_id"]) for row in running_rows}
1702+
reservation_holder_ids: set[str] = set()
1703+
if running_job_ids:
1704+
job_placeholders = ",".join("?" for _ in running_job_ids)
1705+
res_rows = snap.fetchall(
1706+
f"SELECT job_id FROM jobs WHERE job_id IN ({job_placeholders}) AND is_reservation_holder = 1",
1707+
tuple(running_job_ids),
1708+
)
1709+
reservation_holder_ids = {str(row["job_id"]) for row in res_rows}
1710+
1711+
running_rows = [row for row in running_rows if str(row["job_id"]) not in reservation_holder_ids]
1712+
1713+
# -- Phase 2: write lock only for dispatch_queue drain --
1714+
placeholders = ",".join("?" for _ in worker_id_set)
1715+
with self._db.transaction() as cur:
16751716
dispatch_rows = cur.execute(
16761717
f"SELECT worker_id, id, kind, payload_proto, task_id FROM dispatch_queue "
16771718
f"WHERE worker_id IN ({placeholders}) ORDER BY id ASC",
@@ -1683,57 +1724,48 @@ def drain_dispatch_all(self) -> list[DispatchBatch]:
16831724
tuple(worker_id_set),
16841725
)
16851726

1686-
running_rows = cur.execute(
1687-
"SELECT ta.worker_id, t.task_id, t.current_attempt_id "
1688-
"FROM tasks t "
1689-
"JOIN task_attempts ta ON t.task_id = ta.task_id AND t.current_attempt_id = ta.attempt_id "
1690-
"JOIN jobs j ON j.job_id = t.job_id "
1691-
"WHERE t.state IN (?, ?, ?) AND j.is_reservation_holder = 0 "
1692-
"ORDER BY t.task_id ASC",
1693-
(*ACTIVE_TASK_STATES,),
1694-
).fetchall()
1727+
# -- Phase 3: build results (pure Python, no lock) --
1728+
dispatch_by_worker: dict[str, list[Any]] = defaultdict(list)
1729+
for row in dispatch_rows:
1730+
dispatch_by_worker[str(row["worker_id"])].append(row)
16951731

1696-
dispatch_by_worker: dict[str, list[Any]] = defaultdict(list)
1697-
for row in dispatch_rows:
1698-
dispatch_by_worker[str(row["worker_id"])].append(row)
1732+
running_by_worker: dict[str, list[Any]] = defaultdict(list)
1733+
for row in running_rows:
1734+
running_by_worker[str(row["worker_id"])].append(row)
16991735

1700-
running_by_worker: dict[str, list[Any]] = defaultdict(list)
1701-
for row in running_rows:
1702-
running_by_worker[str(row["worker_id"])].append(row)
1703-
1704-
batches: list[DispatchBatch] = []
1705-
for worker_row in worker_rows:
1706-
wid = str(worker_row["worker_id"])
1707-
w_dispatch = dispatch_by_worker.get(wid, [])
1708-
w_running = running_by_worker.get(wid, [])
1709-
1710-
tasks_to_run: list[cluster_pb2.Worker.RunTaskRequest] = []
1711-
tasks_to_kill: list[str] = []
1712-
for row in w_dispatch:
1713-
if str(row["kind"]) == "run" and row["payload_proto"] is not None:
1714-
req = cluster_pb2.Worker.RunTaskRequest()
1715-
req.ParseFromString(bytes(row["payload_proto"]))
1716-
tasks_to_run.append(req)
1717-
elif row["task_id"] is not None:
1718-
tasks_to_kill.append(str(row["task_id"]))
1719-
1720-
batches.append(
1721-
DispatchBatch(
1722-
worker_id=WorkerId(wid),
1723-
worker_address=str(worker_row["address"]),
1724-
running_tasks=[
1725-
RunningTaskEntry(
1726-
task_id=JobName.from_wire(str(row["task_id"])),
1727-
attempt_id=int(row["current_attempt_id"]),
1728-
)
1729-
for row in w_running
1730-
],
1731-
tasks_to_run=tasks_to_run,
1732-
tasks_to_kill=tasks_to_kill,
1733-
)
1736+
batches: list[DispatchBatch] = []
1737+
for worker_row in worker_rows:
1738+
wid = str(worker_row["worker_id"])
1739+
w_dispatch = dispatch_by_worker.get(wid, [])
1740+
w_running = running_by_worker.get(wid, [])
1741+
1742+
tasks_to_run: list[cluster_pb2.Worker.RunTaskRequest] = []
1743+
tasks_to_kill: list[str] = []
1744+
for row in w_dispatch:
1745+
if str(row["kind"]) == "run" and row["payload_proto"] is not None:
1746+
req = cluster_pb2.Worker.RunTaskRequest()
1747+
req.ParseFromString(bytes(row["payload_proto"]))
1748+
tasks_to_run.append(req)
1749+
elif row["task_id"] is not None:
1750+
tasks_to_kill.append(str(row["task_id"]))
1751+
1752+
batches.append(
1753+
DispatchBatch(
1754+
worker_id=WorkerId(wid),
1755+
worker_address=str(worker_row["address"]),
1756+
running_tasks=[
1757+
RunningTaskEntry(
1758+
task_id=JobName.from_wire(str(row["task_id"])),
1759+
attempt_id=int(row["current_attempt_id"]),
1760+
)
1761+
for row in w_running
1762+
],
1763+
tasks_to_run=tasks_to_run,
1764+
tasks_to_kill=tasks_to_kill,
17341765
)
1766+
)
17351767

1736-
return batches
1768+
return batches
17371769

17381770
def requeue_dispatch(self, batch: DispatchBatch) -> None:
17391771
"""Re-queue drained dispatch payloads for later delivery."""
@@ -1819,11 +1851,37 @@ def prune_old_data(
18191851
txn_cutoff_ms = now_ms - txn_action_retention.to_ms()
18201852

18211853
terminal_states = tuple(TERMINAL_JOB_STATES)
1854+
placeholders = ",".join("?" * len(terminal_states))
1855+
1856+
# Cheap pre-check via read snapshot: skip the write lock when nothing is old enough
1857+
with self._db.read_snapshot() as snap:
1858+
has_work = (
1859+
snap.fetchone(
1860+
f"SELECT 1 FROM jobs WHERE state IN ({placeholders})"
1861+
" AND finished_at_ms IS NOT NULL AND finished_at_ms < ? LIMIT 1",
1862+
(*terminal_states, job_cutoff_ms),
1863+
)
1864+
or snap.fetchone(
1865+
"SELECT 1 FROM workers WHERE (active = 0 OR healthy = 0) AND last_heartbeat_ms < ? LIMIT 1",
1866+
(worker_cutoff_ms,),
1867+
)
1868+
or snap.fetchone(
1869+
"SELECT 1 FROM logs WHERE epoch_ms < ? LIMIT 1",
1870+
(log_cutoff_ms,),
1871+
)
1872+
or snap.fetchone(
1873+
"SELECT 1 FROM txn_actions WHERE created_at_ms < ? LIMIT 1",
1874+
(txn_cutoff_ms,),
1875+
)
1876+
)
1877+
1878+
if not has_work:
1879+
return PruneResult(jobs_deleted=0, workers_deleted=0, logs_deleted=0, txn_actions_deleted=0)
1880+
18221881
actions: list[tuple[str, str, dict[str, object]]] = []
18231882

18241883
with self._db.transaction() as cur:
18251884
# 1. Terminal jobs finished before the cutoff
1826-
placeholders = ",".join("?" * len(terminal_states))
18271885
job_rows = cur.execute(
18281886
f"SELECT job_id FROM jobs WHERE state IN ({placeholders})"
18291887
" AND finished_at_ms IS NOT NULL AND finished_at_ms < ?",

lib/iris/tests/cluster/controller/test_transitions.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3144,6 +3144,74 @@ def test_prune_noop_when_nothing_old(state):
31443144
assert result.total == 0
31453145

31463146

3147+
# =============================================================================
3148+
# drain_dispatch_all Tests
3149+
# =============================================================================
3150+
3151+
3152+
def test_drain_dispatch_all_excludes_reservation_holders(state):
3153+
"""drain_dispatch_all returns running tasks but filters out reservation-holder tasks."""
3154+
wid = register_worker(state, "w1", "host:8080", make_worker_metadata())
3155+
3156+
normal_req = make_job_request("normal-job")
3157+
normal_tasks = submit_job(state, "normal-job", normal_req)
3158+
dispatch_task(state, normal_tasks[0], wid)
3159+
3160+
holder_req = make_job_request("holder-job")
3161+
holder_tasks = submit_job(state, "holder-job", holder_req)
3162+
holder_job_id = JobName.root("test-user", "holder-job")
3163+
state._db.execute(
3164+
"UPDATE jobs SET is_reservation_holder = 1 WHERE job_id = ?",
3165+
(holder_job_id.to_wire(),),
3166+
)
3167+
dispatch_task(state, holder_tasks[0], wid)
3168+
3169+
batches = state.drain_dispatch_all()
3170+
assert len(batches) == 1
3171+
batch = batches[0]
3172+
running_task_ids = {entry.task_id for entry in batch.running_tasks}
3173+
3174+
assert normal_tasks[0].task_id in running_task_ids
3175+
assert holder_tasks[0].task_id not in running_task_ids
3176+
3177+
3178+
def test_drain_dispatch_all_drains_dispatch_queue(state):
3179+
"""drain_dispatch_all drains queued dispatches and deletes them from the queue."""
3180+
wid = register_worker(state, "w1", "host:8080", make_worker_metadata())
3181+
3182+
req = make_job_request("j1")
3183+
tasks = submit_job(state, "j1", req)
3184+
state.queue_assignments([Assignment(task_id=tasks[0].task_id, worker_id=wid)])
3185+
3186+
rows_before = state._db.fetchall("SELECT * FROM dispatch_queue WHERE worker_id = ?", (str(wid),))
3187+
assert len(rows_before) > 0
3188+
3189+
batches = state.drain_dispatch_all()
3190+
assert len(batches) == 1
3191+
assert len(batches[0].tasks_to_run) > 0
3192+
3193+
rows_after = state._db.fetchall("SELECT * FROM dispatch_queue WHERE worker_id = ?", (str(wid),))
3194+
assert len(rows_after) == 0
3195+
3196+
3197+
def test_prune_old_data_short_circuits_when_nothing_prunable(state):
3198+
"""prune_old_data skips the write lock when a read_snapshot shows nothing to prune."""
3199+
wid = register_worker(state, "w1", "host:8080", make_worker_metadata())
3200+
req = make_job_request("active-job")
3201+
tasks = submit_job(state, "active-job", req)
3202+
dispatch_task(state, tasks[0], wid)
3203+
3204+
result = state.prune_old_data(
3205+
job_retention=Duration.from_seconds(86400),
3206+
worker_retention=Duration.from_seconds(86400),
3207+
log_retention=Duration.from_seconds(86400),
3208+
txn_action_retention=Duration.from_seconds(86400),
3209+
)
3210+
3211+
assert result == PruneResult()
3212+
assert result.total == 0
3213+
3214+
31473215
# =============================================================================
31483216
# Direct Provider Transition Tests
31493217
# =============================================================================

0 commit comments

Comments (0)