@@ -1629,15 +1629,25 @@ def drain_dispatch(self, worker_id: WorkerId) -> DispatchBatch | None:
16291629 ).fetchall ()
16301630 if dispatch_rows :
16311631 cur .execute ("DELETE FROM dispatch_queue WHERE worker_id = ?" , (str (worker_id ),))
1632- running_rows = cur .execute (
1633- "SELECT t.task_id, t.current_attempt_id "
1632+ running_rows_raw = cur .execute (
1633+ "SELECT t.task_id, t.current_attempt_id, t.job_id "
16341634 "FROM tasks t "
16351635 "JOIN task_attempts ta ON t.task_id = ta.task_id AND t.current_attempt_id = ta.attempt_id "
1636- "JOIN jobs j ON j.job_id = t.job_id "
1637- "WHERE ta.worker_id = ? AND t.state IN (?, ?, ?) AND j.is_reservation_holder = 0 "
1636+ "WHERE ta.worker_id = ? AND t.state IN (?, ?, ?) "
16381637 "ORDER BY t.task_id ASC" ,
16391638 (str (worker_id ), * ACTIVE_TASK_STATES ),
16401639 ).fetchall ()
1640+ running_job_ids = {str (row ["job_id" ]) for row in running_rows_raw }
1641+ if running_job_ids :
1642+ holder_placeholders = "," .join ("?" for _ in running_job_ids )
1643+ holder_rows = cur .execute (
1644+ f"SELECT job_id FROM jobs WHERE job_id IN ({ holder_placeholders } ) AND is_reservation_holder = 1" ,
1645+ tuple (running_job_ids ),
1646+ ).fetchall ()
1647+ holder_ids = {str (r ["job_id" ]) for r in holder_rows }
1648+ else :
1649+ holder_ids = set ()
1650+ running_rows = [r for r in running_rows_raw if str (r ["job_id" ]) not in holder_ids ]
16411651 tasks_to_run : list [cluster_pb2 .Worker .RunTaskRequest ] = []
16421652 tasks_to_kill : list [str ] = []
16431653 for row in dispatch_rows :
@@ -1662,16 +1672,47 @@ def drain_dispatch(self, worker_id: WorkerId) -> DispatchBatch | None:
16621672 )
16631673
16641674 def drain_dispatch_all (self ) -> list [DispatchBatch ]:
1665- """Drain buffered dispatches and snapshot running tasks for all healthy active workers in one transaction."""
1666- with self ._db .transaction () as cur :
1667- worker_rows = cur .execute (
1675+ """Drain buffered dispatches and snapshot running tasks for all healthy active workers.
1676+
1677+ Reads (workers, running tasks, reservation filter) use a read snapshot
1678+ to avoid holding the write lock. The write lock is only held for the
1679+ dispatch_queue SELECT + DELETE.
1680+ """
1681+ # -- Phase 1: read-only queries (no write lock) --
1682+ with self ._db .read_snapshot () as snap :
1683+ worker_rows = snap .fetchall (
16681684 "SELECT worker_id, address, metadata_proto FROM workers WHERE active = 1 AND healthy = 1"
1669- ). fetchall ()
1685+ )
16701686 if not worker_rows :
16711687 return []
16721688
16731689 worker_id_set = {str (row ["worker_id" ]) for row in worker_rows }
1674- placeholders = "," .join ("?" for _ in worker_id_set )
1690+
1691+ running_rows = snap .fetchall (
1692+ "SELECT ta.worker_id, t.task_id, t.current_attempt_id, t.job_id "
1693+ "FROM tasks t "
1694+ "JOIN task_attempts ta ON t.task_id = ta.task_id AND t.current_attempt_id = ta.attempt_id "
1695+ "WHERE t.state IN (?, ?, ?) "
1696+ "ORDER BY t.task_id ASC" ,
1697+ tuple (ACTIVE_TASK_STATES ),
1698+ )
1699+
1700+ # Batch-check reservation holders instead of joining the jobs table
1701+ running_job_ids = {str (row ["job_id" ]) for row in running_rows }
1702+ reservation_holder_ids : set [str ] = set ()
1703+ if running_job_ids :
1704+ job_placeholders = "," .join ("?" for _ in running_job_ids )
1705+ res_rows = snap .fetchall (
1706+ f"SELECT job_id FROM jobs WHERE job_id IN ({ job_placeholders } ) AND is_reservation_holder = 1" ,
1707+ tuple (running_job_ids ),
1708+ )
1709+ reservation_holder_ids = {str (row ["job_id" ]) for row in res_rows }
1710+
1711+ running_rows = [row for row in running_rows if str (row ["job_id" ]) not in reservation_holder_ids ]
1712+
1713+ # -- Phase 2: write lock only for dispatch_queue drain --
1714+ placeholders = "," .join ("?" for _ in worker_id_set )
1715+ with self ._db .transaction () as cur :
16751716 dispatch_rows = cur .execute (
16761717 f"SELECT worker_id, id, kind, payload_proto, task_id FROM dispatch_queue "
16771718 f"WHERE worker_id IN ({ placeholders } ) ORDER BY id ASC" ,
@@ -1683,57 +1724,48 @@ def drain_dispatch_all(self) -> list[DispatchBatch]:
16831724 tuple (worker_id_set ),
16841725 )
16851726
1686- running_rows = cur .execute (
1687- "SELECT ta.worker_id, t.task_id, t.current_attempt_id "
1688- "FROM tasks t "
1689- "JOIN task_attempts ta ON t.task_id = ta.task_id AND t.current_attempt_id = ta.attempt_id "
1690- "JOIN jobs j ON j.job_id = t.job_id "
1691- "WHERE t.state IN (?, ?, ?) AND j.is_reservation_holder = 0 "
1692- "ORDER BY t.task_id ASC" ,
1693- (* ACTIVE_TASK_STATES ,),
1694- ).fetchall ()
1727+ # -- Phase 3: build results (pure Python, no lock) --
1728+ dispatch_by_worker : dict [str , list [Any ]] = defaultdict (list )
1729+ for row in dispatch_rows :
1730+ dispatch_by_worker [str (row ["worker_id" ])].append (row )
16951731
1696- dispatch_by_worker : dict [str , list [Any ]] = defaultdict (list )
1697- for row in dispatch_rows :
1698- dispatch_by_worker [str (row ["worker_id" ])].append (row )
1732+ running_by_worker : dict [str , list [Any ]] = defaultdict (list )
1733+ for row in running_rows :
1734+ running_by_worker [str (row ["worker_id" ])].append (row )
16991735
1700- running_by_worker : dict [str , list [Any ]] = defaultdict (list )
1701- for row in running_rows :
1702- running_by_worker [str (row ["worker_id" ])].append (row )
1703-
1704- batches : list [DispatchBatch ] = []
1705- for worker_row in worker_rows :
1706- wid = str (worker_row ["worker_id" ])
1707- w_dispatch = dispatch_by_worker .get (wid , [])
1708- w_running = running_by_worker .get (wid , [])
1709-
1710- tasks_to_run : list [cluster_pb2 .Worker .RunTaskRequest ] = []
1711- tasks_to_kill : list [str ] = []
1712- for row in w_dispatch :
1713- if str (row ["kind" ]) == "run" and row ["payload_proto" ] is not None :
1714- req = cluster_pb2 .Worker .RunTaskRequest ()
1715- req .ParseFromString (bytes (row ["payload_proto" ]))
1716- tasks_to_run .append (req )
1717- elif row ["task_id" ] is not None :
1718- tasks_to_kill .append (str (row ["task_id" ]))
1719-
1720- batches .append (
1721- DispatchBatch (
1722- worker_id = WorkerId (wid ),
1723- worker_address = str (worker_row ["address" ]),
1724- running_tasks = [
1725- RunningTaskEntry (
1726- task_id = JobName .from_wire (str (row ["task_id" ])),
1727- attempt_id = int (row ["current_attempt_id" ]),
1728- )
1729- for row in w_running
1730- ],
1731- tasks_to_run = tasks_to_run ,
1732- tasks_to_kill = tasks_to_kill ,
1733- )
1736+ batches : list [DispatchBatch ] = []
1737+ for worker_row in worker_rows :
1738+ wid = str (worker_row ["worker_id" ])
1739+ w_dispatch = dispatch_by_worker .get (wid , [])
1740+ w_running = running_by_worker .get (wid , [])
1741+
1742+ tasks_to_run : list [cluster_pb2 .Worker .RunTaskRequest ] = []
1743+ tasks_to_kill : list [str ] = []
1744+ for row in w_dispatch :
1745+ if str (row ["kind" ]) == "run" and row ["payload_proto" ] is not None :
1746+ req = cluster_pb2 .Worker .RunTaskRequest ()
1747+ req .ParseFromString (bytes (row ["payload_proto" ]))
1748+ tasks_to_run .append (req )
1749+ elif row ["task_id" ] is not None :
1750+ tasks_to_kill .append (str (row ["task_id" ]))
1751+
1752+ batches .append (
1753+ DispatchBatch (
1754+ worker_id = WorkerId (wid ),
1755+ worker_address = str (worker_row ["address" ]),
1756+ running_tasks = [
1757+ RunningTaskEntry (
1758+ task_id = JobName .from_wire (str (row ["task_id" ])),
1759+ attempt_id = int (row ["current_attempt_id" ]),
1760+ )
1761+ for row in w_running
1762+ ],
1763+ tasks_to_run = tasks_to_run ,
1764+ tasks_to_kill = tasks_to_kill ,
17341765 )
1766+ )
17351767
1736- return batches
1768+ return batches
17371769
17381770 def requeue_dispatch (self , batch : DispatchBatch ) -> None :
17391771 """Re-queue drained dispatch payloads for later delivery."""
@@ -1819,11 +1851,37 @@ def prune_old_data(
18191851 txn_cutoff_ms = now_ms - txn_action_retention .to_ms ()
18201852
18211853 terminal_states = tuple (TERMINAL_JOB_STATES )
1854+ placeholders = "," .join ("?" * len (terminal_states ))
1855+
1856+ # Cheap pre-check via read snapshot: skip the write lock when nothing is old enough
1857+ with self ._db .read_snapshot () as snap :
1858+ has_work = (
1859+ snap .fetchone (
1860+ f"SELECT 1 FROM jobs WHERE state IN ({ placeholders } )"
1861+ " AND finished_at_ms IS NOT NULL AND finished_at_ms < ? LIMIT 1" ,
1862+ (* terminal_states , job_cutoff_ms ),
1863+ )
1864+ or snap .fetchone (
1865+ "SELECT 1 FROM workers WHERE (active = 0 OR healthy = 0) AND last_heartbeat_ms < ? LIMIT 1" ,
1866+ (worker_cutoff_ms ,),
1867+ )
1868+ or snap .fetchone (
1869+ "SELECT 1 FROM logs WHERE epoch_ms < ? LIMIT 1" ,
1870+ (log_cutoff_ms ,),
1871+ )
1872+ or snap .fetchone (
1873+ "SELECT 1 FROM txn_actions WHERE created_at_ms < ? LIMIT 1" ,
1874+ (txn_cutoff_ms ,),
1875+ )
1876+ )
1877+
1878+ if not has_work :
1879+ return PruneResult (jobs_deleted = 0 , workers_deleted = 0 , logs_deleted = 0 , txn_actions_deleted = 0 )
1880+
18221881 actions : list [tuple [str , str , dict [str , object ]]] = []
18231882
18241883 with self ._db .transaction () as cur :
18251884 # 1. Terminal jobs finished before the cutoff
1826- placeholders = "," .join ("?" * len (terminal_states ))
18271885 job_rows = cur .execute (
18281886 f"SELECT job_id FROM jobs WHERE state IN ({ placeholders } )"
18291887 " AND finished_at_ms IS NOT NULL AND finished_at_ms < ?" ,
0 commit comments