diff --git a/lib/iris/scripts/benchmark_db_queries.py b/lib/iris/scripts/benchmark_db_queries.py index 450d80cad0..50d2f9e73f 100644 --- a/lib/iris/scripts/benchmark_db_queries.py +++ b/lib/iris/scripts/benchmark_db_queries.py @@ -72,6 +72,7 @@ _worker_roster, ) from iris.cluster.controller.schema import EndpointRow +from iris.cluster.controller.stores import ControllerStore from iris.cluster.controller.transitions import ( Assignment, ControllerTransitions, @@ -252,7 +253,8 @@ def benchmark_scheduling(db: ControllerDB) -> None: # --- Write-path benchmarks (use a lightweight clone) --- write_db = clone_db(db) - write_txns = ControllerTransitions(write_db) + write_store = ControllerStore(write_db) + write_txns = ControllerTransitions(store=write_store) try: # queue_assignments: the main write-lock holder in scheduling. @@ -542,7 +544,7 @@ def _all_workers_running_tasks(): bench("running_tasks_by_worker", lambda: running_tasks_by_worker(db, worker_ids)) - transitions = ControllerTransitions(db) + transitions = ControllerTransitions(store=ControllerStore(db)) bench( f"get_running_tasks_for_poll ({len(workers)} workers)", lambda: transitions.get_running_tasks_for_poll(), @@ -594,7 +596,7 @@ def _all_workers_running_tasks(): ) hb_db = clone_db(db) - hb_transitions = ControllerTransitions(hb_db) + hb_transitions = ControllerTransitions(store=ControllerStore(hb_db)) try: bench( @@ -739,14 +741,16 @@ def benchmark_endpoints(db: ControllerDB) -> None: contention (matches the production Register p95 of 3-4s) """ # Read-path queries run against the source DB (cheap, no clone needed). - bench("endpoint_registry.query (all)", lambda: db.endpoints.query()) + read_store = ControllerStore(db) + bench("endpoint_store.query (all)", lambda: read_store.endpoints.query()) bench( - "endpoint_registry.query (prefix)", - lambda: db.endpoints.query(EndpointQuery(name_prefix="test")), + "endpoint_store.query (prefix)", + lambda: read_store.endpoints.query(EndpointQuery(name_prefix="test")), ) write_db = clone_db(db) - write_txns = ControllerTransitions(write_db) + write_store = ControllerStore(write_db) + write_txns = ControllerTransitions(store=write_store) try: sample = _active_task_sample(write_db, limit=300) @@ -762,7 +766,7 @@ def _do_single(): def _reset_single(): write_db.execute("DELETE FROM endpoints WHERE name LIKE '/bench/endpoint/%'") - write_db.endpoints._load_all() + write_store.endpoints._load_all() bench("add_endpoint (1 write)", _do_single, reset=_reset_single) @@ -793,7 +797,7 @@ def _do_burst_per_txn(tasks=tasks_for_burst): def _do_burst_one_txn(tasks=tasks_for_burst): with write_db.transaction() as cur: for t in tasks: - write_db.endpoints.add(cur, _make_endpoint(t)) + write_store.endpoints.add(cur, _make_endpoint(t)) bench( f"add_endpoint burst x{burst_n} (1 txn)", @@ -1365,7 +1369,7 @@ def benchmark_apply_contention(db: ControllerDB) -> None: ] write_db = clone_db(db) - write_txns = ControllerTransitions(write_db) + write_txns = ControllerTransitions(store=ControllerStore(write_db)) try: for scenario in scenarios: _run_apply_under_contention( diff --git a/lib/iris/src/iris/cluster/controller/actor_proxy.py b/lib/iris/src/iris/cluster/controller/actor_proxy.py index 4fdae2cfe9..3a2cab997f 100644 --- a/lib/iris/src/iris/cluster/controller/actor_proxy.py +++ b/lib/iris/src/iris/cluster/controller/actor_proxy.py @@ -20,7 +20,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response -from iris.cluster.controller.db import ControllerDB +from 
iris.cluster.controller.stores import ControllerStore logger = logging.getLogger(__name__) @@ -46,10 +46,10 @@ class ActorProxy: - """Forwards ActorService RPCs to actors resolved from the endpoint registry.""" + """Forwards ActorService RPCs to actors resolved from the endpoint store.""" - def __init__(self, db: ControllerDB): - self._db = db + def __init__(self, store: ControllerStore): + self._store = store self._client = httpx.AsyncClient(timeout=PROXY_TIMEOUT_SECONDS) async def close(self) -> None: @@ -97,8 +97,8 @@ async def handle(self, request: Request) -> Response: ) def _resolve_endpoint(self, name: str) -> str | None: - """Resolve an endpoint name to an address via the in-memory registry.""" - row = self._db.endpoints.resolve(name) + """Resolve an endpoint name to an address via the in-memory store.""" + row = self._store.endpoints.resolve(name) if row is None: return None return row.address diff --git a/lib/iris/src/iris/cluster/controller/controller.py b/lib/iris/src/iris/cluster/controller/controller.py index a8172eefaa..4082876710 100644 --- a/lib/iris/src/iris/cluster/controller/controller.py +++ b/lib/iris/src/iris/cluster/controller/controller.py @@ -91,6 +91,7 @@ ) from iris.cluster.controller.auth import ControllerAuth from iris.cluster.controller.service import ControllerServiceImpl +from iris.cluster.controller.stores import ControllerStore from iris.cluster.controller.transitions import ( RESERVATION_HOLDER_JOB_NAME, Assignment, @@ -1038,6 +1039,7 @@ def __init__( self._db = db else: self._db = ControllerDB(db_dir=config.local_state_dir / "db") + self._store = ControllerStore(self._db) # ThreadContainer must be initialized before the log service setup # because _start_local_log_server spawns a uvicorn thread. @@ -1075,7 +1077,7 @@ def __init__( self._health = WorkerHealthTracker() self._transitions = ControllerTransitions( - db=self._db, + store=self._store, health=self._health, ) self._scheduler = Scheduler() @@ -1084,7 +1086,7 @@ def __init__( self._service = ControllerServiceImpl( self._transitions, - self._db, + self._store, controller=self, bundle_store=self._bundle_store, log_service=self._remote_log_service, diff --git a/lib/iris/src/iris/cluster/controller/dashboard.py b/lib/iris/src/iris/cluster/controller/dashboard.py index a217dd66b6..d09576a250 100644 --- a/lib/iris/src/iris/cluster/controller/dashboard.py +++ b/lib/iris/src/iris/cluster/controller/dashboard.py @@ -299,7 +299,7 @@ def _create_app(self) -> ASGIApp: rpc_app = WSGIMiddleware(rpc_wsgi_app) - self._actor_proxy = ActorProxy(self._service._db) + self._actor_proxy = ActorProxy(self._service._store) @requires_auth async def _proxy_actor_rpc(request: Request) -> Response: diff --git a/lib/iris/src/iris/cluster/controller/db.py b/lib/iris/src/iris/cluster/controller/db.py index af26d0df27..445c7760b9 100644 --- a/lib/iris/src/iris/cluster/controller/db.py +++ b/lib/iris/src/iris/cluster/controller/db.py @@ -237,7 +237,7 @@ class TransactionCursor: Post-commit hooks registered via :meth:`on_commit` run after the wrapping ``ControllerDB.transaction()`` block commits successfully. They are used - by caches (e.g. ``EndpointRegistry``) to update in-memory state atomically + by caches (e.g. ``EndpointStore``) to update in-memory state atomically with the DB write: rollback suppresses the hook so memory never drifts from disk. 
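+
+    A sketch of the intended pattern (``cache`` here is a hypothetical
+    in-memory dict standing in for a store's indexes, not part of this
+    module)::
+
+        with db.transaction() as cur:
+            cur.execute("DELETE FROM endpoints WHERE endpoint_id = ?", (eid,))
+            # Runs only if the transaction commits; a rollback drops it.
+            cur.on_commit(lambda: cache.pop(eid, None))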
""" @@ -258,6 +258,14 @@ def executescript(self, sql: str) -> sqlite3.Cursor: """Raw SQL script escape hatch.""" return self._cursor.executescript(sql) + def fetchall(self, sql: str, params: tuple = ()) -> list[sqlite3.Row]: + """Execute ``sql`` and return all rows. Mirrors :meth:`QuerySnapshot.fetchall`.""" + return list(self._cursor.execute(sql, params).fetchall()) + + def fetchone(self, sql: str, params: tuple = ()) -> sqlite3.Row | None: + """Execute ``sql`` and return the first row, or None. Mirrors :meth:`QuerySnapshot.fetchone`.""" + return self._cursor.execute(sql, params).fetchone() + def on_commit(self, hook: Callable[[], None]) -> None: """Register ``hook`` to run after the transaction commits successfully.""" self._commit_hooks.append(hook) @@ -321,19 +329,14 @@ def __init__(self, db_dir: Path): self._attr_cache: dict[WorkerId, dict[str, AttributeValue]] | None = None self._attr_cache_lock = Lock() - # Write-through in-memory cache over the ``endpoints`` table. Imported - # locally to break the ``db -> endpoint_registry -> db`` import cycle; - # this is the single exception to "no local imports" (see AGENTS.md). - from iris.cluster.controller.endpoint_registry import EndpointRegistry + # Callables invoked at the end of ``replace_from`` so callers with + # caches over DB contents (e.g. ``ControllerStore``) can reload them + # after a checkpoint restore. Registered via ``register_reopen_hook``. + self._reopen_hooks: list[Callable[[], None]] = [] - t0 = time.monotonic() - self._endpoint_registry = EndpointRegistry(self) - logger.info("EndpointRegistry initialized in %.2fs", time.monotonic() - t0) - - @property - def endpoints(self) -> EndpointRegistry: # noqa: F821 - """Process-local cache for the ``endpoints`` table; authoritative for reads.""" - return self._endpoint_registry + def register_reopen_hook(self, hook: Callable[[], None]) -> None: + """Register a no-arg callable to run at the end of ``replace_from``.""" + self._reopen_hooks.append(hook) def _populate_attr_cache(self) -> dict[WorkerId, dict[str, AttributeValue]]: """Load all worker attributes from the DB into the cache. @@ -454,7 +457,7 @@ def transaction(self): On successful commit, any hooks registered via ``TransactionCursor.on_commit`` fire while the write lock is still held — keeping in-memory caches - (e.g. ``EndpointRegistry``) in sync with the DB without exposing a + (e.g. ``EndpointStore``) in sync with the DB without exposing a torn snapshot to concurrent readers. """ with self._lock: @@ -751,7 +754,8 @@ def replace_from(self, source_dir: str | Path) -> None: self._conn.execute("ATTACH DATABASE ? AS profiles", (str(self._profiles_db_path),)) self._init_read_pool() self.apply_migrations() - self._endpoint_registry._load_all() + for hook in self._reopen_hooks: + hook() # SQL-canonical read access is exposed through ``snapshot()`` and typed table # metadata at module scope. Legacy list/get/count helper methods were removed diff --git a/lib/iris/src/iris/cluster/controller/endpoint_registry.py b/lib/iris/src/iris/cluster/controller/endpoint_registry.py deleted file mode 100644 index 095126f235..0000000000 --- a/lib/iris/src/iris/cluster/controller/endpoint_registry.py +++ /dev/null @@ -1,253 +0,0 @@ -# Copyright The Marin Authors -# SPDX-License-Identifier: Apache-2.0 - -"""Process-local in-memory cache for the ``endpoints`` table. 
- -Profiling showed that ``ListEndpoints`` dominated controller CPU — not because -the SQL was slow per se, but because every call serialized through the -read-connection pool and walked a large WAL to build a snapshot. The endpoints -table is tiny (hundreds of rows) and only changes on explicit register / -unregister, so it is a natural fit for a write-through in-memory cache. - -Design invariants: - -* Reads never touch the DB. All lookups are served from in-memory maps - guarded by an ``RLock`` — readers observe a consistent snapshot of the - indexes, never a torn state mid-update. -* Writes execute the SQL inside the caller's transaction. The in-memory - update is scheduled as a post-commit hook on the cursor so memory only - changes after the DB has committed. If the transaction rolls back, the - hook never fires. -* N is small enough (≈ hundreds) that linear scans for prefix / task / id - lookups are simpler and plenty fast. Extra indexes (by name, by task_id) - speed the two common cases. - -The registry is the sole source of truth for endpoint reads; nothing else in -the controller tree should SELECT from ``endpoints``. -""" - -from __future__ import annotations - -import json -import logging -from collections.abc import Iterable, Sequence -from threading import RLock - -from iris.cluster.controller.db import EndpointQuery, TransactionCursor -from iris.cluster.controller.schema import ENDPOINT_PROJECTION, EndpointRow -from iris.cluster.types import TERMINAL_TASK_STATES, JobName - -logger = logging.getLogger(__name__) - - -class EndpointRegistry: - """In-memory index of endpoint rows, kept in sync with the DB. - - Construct with a ``ControllerDB``; the registry loads all existing rows at - init time. Callers mutate through ``add`` / ``remove*`` methods that take - the open ``TransactionCursor`` so the SQL lands inside the caller's - transaction. Memory is only updated after a successful commit via a - cursor post-commit hook. - """ - - def __init__(self, db): - self._db = db - self._lock = RLock() - self._by_id: dict[str, EndpointRow] = {} - # One name can map to multiple endpoint_ids — the schema does not enforce - # uniqueness on ``name``, and ``INSERT OR REPLACE`` keys off endpoint_id. 
- self._by_name: dict[str, set[str]] = {} - self._by_task: dict[JobName, set[str]] = {} - self._load_all() - - # -- Loading -------------------------------------------------------------- - - def _load_all(self) -> None: - with self._db.read_snapshot() as q: - rows = ENDPOINT_PROJECTION.decode( - q.fetchall(f"SELECT {ENDPOINT_PROJECTION.select_clause()} FROM endpoints e"), - ) - with self._lock: - self._by_id.clear() - self._by_name.clear() - self._by_task.clear() - for row in rows: - self._index(row) - logger.info("EndpointRegistry loaded %d endpoint(s) from DB", len(rows)) - - def _index(self, row: EndpointRow) -> None: - self._by_id[row.endpoint_id] = row - self._by_name.setdefault(row.name, set()).add(row.endpoint_id) - self._by_task.setdefault(row.task_id, set()).add(row.endpoint_id) - - def _unindex(self, endpoint_id: str) -> EndpointRow | None: - row = self._by_id.pop(endpoint_id, None) - if row is None: - return None - name_ids = self._by_name.get(row.name) - if name_ids is not None: - name_ids.discard(endpoint_id) - if not name_ids: - self._by_name.pop(row.name, None) - task_ids = self._by_task.get(row.task_id) - if task_ids is not None: - task_ids.discard(endpoint_id) - if not task_ids: - self._by_task.pop(row.task_id, None) - return row - - # -- Reads ---------------------------------------------------------------- - - def query(self, query: EndpointQuery = EndpointQuery()) -> list[EndpointRow]: - """Return endpoint rows matching ``query``. - - All filters AND together, matching the semantics of the original SQL - in :func:`iris.cluster.controller.db.endpoint_query_sql`. - """ - with self._lock: - # Narrow the candidate set using the most selective index available. - if query.endpoint_ids: - candidates: Iterable[EndpointRow] = ( - self._by_id[eid] for eid in query.endpoint_ids if eid in self._by_id - ) - elif query.task_ids: - task_set = set(query.task_ids) - candidates = (self._by_id[eid] for task_id in task_set for eid in self._by_task.get(task_id, ())) - elif query.exact_name is not None: - candidates = (self._by_id[eid] for eid in self._by_name.get(query.exact_name, ())) - else: - candidates = self._by_id.values() - - results: list[EndpointRow] = [] - for row in candidates: - if query.name_prefix is not None and not row.name.startswith(query.name_prefix): - continue - if query.exact_name is not None and row.name != query.exact_name: - continue - if query.task_ids and row.task_id not in query.task_ids: - continue - if query.endpoint_ids and row.endpoint_id not in query.endpoint_ids: - continue - results.append(row) - if query.limit is not None and len(results) >= query.limit: - break - return results - - def resolve(self, name: str) -> EndpointRow | None: - """Return any endpoint with exact ``name``, or None. Used by the actor proxy.""" - with self._lock: - ids = self._by_name.get(name) - if not ids: - return None - # Arbitrary but stable pick — the original SQL did not specify ORDER BY. - return self._by_id[next(iter(ids))] - - def get(self, endpoint_id: str) -> EndpointRow | None: - with self._lock: - return self._by_id.get(endpoint_id) - - def all(self) -> list[EndpointRow]: - with self._lock: - return list(self._by_id.values()) - - # -- Writes --------------------------------------------------------------- - - def add(self, cur: TransactionCursor, endpoint: EndpointRow) -> bool: - """Insert ``endpoint`` into the DB and schedule the memory update. - - Returns False (and writes nothing) if the owning task is already - terminal. 
Otherwise inserts / replaces and schedules a post-commit - hook that updates the in-memory indexes. - """ - task_id = endpoint.task_id - job_id, _ = task_id.require_task() - row = cur.execute("SELECT state FROM tasks WHERE task_id = ?", (task_id.to_wire(),)).fetchone() - if row is not None and int(row["state"]) in TERMINAL_TASK_STATES: - return False - - cur.execute( - "INSERT OR REPLACE INTO endpoints(" - "endpoint_id, name, address, job_id, task_id, metadata_json, registered_at_ms" - ") VALUES (?, ?, ?, ?, ?, ?, ?)", - ( - endpoint.endpoint_id, - endpoint.name, - endpoint.address, - job_id.to_wire(), - task_id.to_wire(), - json.dumps(endpoint.metadata), - endpoint.registered_at.epoch_ms(), - ), - ) - - def apply() -> None: - with self._lock: - # Replace: drop any previous row with this id first so the - # name/task indexes stay consistent on overwrite. - self._unindex(endpoint.endpoint_id) - self._index(endpoint) - - cur.on_commit(apply) - return True - - def remove(self, cur: TransactionCursor, endpoint_id: str) -> EndpointRow | None: - """Remove a single endpoint by id. Returns the removed row snapshot, if any.""" - existing = self.get(endpoint_id) - if existing is None: - return None - cur.execute("DELETE FROM endpoints WHERE endpoint_id = ?", (endpoint_id,)) - - def apply() -> None: - with self._lock: - self._unindex(endpoint_id) - - cur.on_commit(apply) - return existing - - def remove_by_task(self, cur: TransactionCursor, task_id: JobName) -> list[str]: - """Remove all endpoints owned by a task. Returns the removed endpoint_ids.""" - with self._lock: - ids = list(self._by_task.get(task_id, ())) - if not ids: - # Still issue the DELETE to stay consistent with any rows the - # registry might not have observed yet (belt-and-suspenders for - # the unlikely race of an in-flight concurrent writer). This - # costs nothing on the common path. - cur.execute("DELETE FROM endpoints WHERE task_id = ?", (task_id.to_wire(),)) - return [] - cur.execute("DELETE FROM endpoints WHERE task_id = ?", (task_id.to_wire(),)) - - def apply() -> None: - with self._lock: - for eid in ids: - self._unindex(eid) - - cur.on_commit(apply) - return ids - - def remove_by_job_ids(self, cur: TransactionCursor, job_ids: Sequence[JobName]) -> list[str]: - """Remove all endpoints owned by any of ``job_ids``. Used by cancel_job and prune.""" - if not job_ids: - return [] - wire_ids = [jid.to_wire() for jid in job_ids] - with self._lock: - to_remove: list[str] = [] - for row in self._by_id.values(): - owning_job, _ = row.task_id.require_task() - if owning_job.to_wire() in wire_ids: - to_remove.append(row.endpoint_id) - placeholders = ",".join("?" 
for _ in wire_ids) - cur.execute( - f"DELETE FROM endpoints WHERE job_id IN ({placeholders})", - tuple(wire_ids), - ) - if not to_remove: - return [] - - def apply() -> None: - with self._lock: - for eid in to_remove: - self._unindex(eid) - - cur.on_commit(apply) - return to_remove diff --git a/lib/iris/src/iris/cluster/controller/service.py b/lib/iris/src/iris/cluster/controller/service.py index d2719d03cb..777f632aaf 100644 --- a/lib/iris/src/iris/cluster/controller/service.py +++ b/lib/iris/src/iris/cluster/controller/service.py @@ -90,6 +90,7 @@ from iris.cluster.controller.query import execute_raw_query from iris.rpc import query_pb2 from iris.cluster.controller.scheduler import SchedulingContext +from iris.cluster.controller.stores import ControllerStore from iris.cluster.controller.transitions import ( TASK_RESOURCE_HISTORY_RETENTION, ControllerTransitions, @@ -989,7 +990,7 @@ class ControllerServiceImpl: Args: transitions: State machine for DB mutations (submit, cancel, register, etc.) - db: Query interface for direct DB reads + store: Controller store bundle (per-entity stores + transaction / read_snapshot). controller: Controller runtime for scheduling and worker management bundle_store: Bundle store for zip storage. log_service: LogService for fetching logs (in-process or remote proxy). @@ -998,7 +999,7 @@ class ControllerServiceImpl: def __init__( self, transitions: ControllerTransitions, - db: ControllerDB, + store: ControllerStore, controller: ControllerProtocol, bundle_store: BundleStore, log_service: LogServiceImpl | LogServiceProxy, @@ -1007,7 +1008,8 @@ def __init__( user_budget_defaults: UserBudgetDefaults | None = None, ): self._transitions = transitions - self._db = db + self._store = store + self._db = store._db self._controller = controller self._bundle_store = bundle_store self._log_service = log_service @@ -1742,7 +1744,7 @@ def list_endpoints( if prefix.startswith("/system/"): return self._list_system_endpoints(prefix, exact=request.exact) - endpoints = self._db.endpoints.query( + endpoints = self._store.endpoints.query( EndpointQuery( exact_name=prefix if request.exact else None, name_prefix=None if request.exact else prefix, diff --git a/lib/iris/src/iris/cluster/controller/stores.py b/lib/iris/src/iris/cluster/controller/stores.py new file mode 100644 index 0000000000..bb9e68108c --- /dev/null +++ b/lib/iris/src/iris/cluster/controller/stores.py @@ -0,0 +1,726 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Typed store layer over :mod:`iris.cluster.controller.db`. + +Stores group related SQL against a single entity (jobs, tasks, workers, +endpoints, ...) and expose a typed API that callers invoke inside an open +transaction (read or write). :class:`ControllerStore` bundles every per-entity +store and forwards ``transaction()`` / ``read_snapshot()`` to the underlying +:class:`ControllerDB`. + +Dependency chain (target state):: + + db.py — connections, migrations, transaction context managers + schema.py — table DDL, row dataclasses, projections + stores.py — depends on { db, schema }; per-entity stores + transitions.py — depends on stores; stores own the SQL + +The layer is introduced incrementally. The current state is mid-migration: +``EndpointStore`` and ``JobStore`` are populated, while ``TaskStore``, +``TaskAttemptStore``, ``WorkerStore``, ``DispatchQueueStore`` and +``ReservationStore`` are still empty skeletons. 
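+
+For the populated stores, the intended call pattern looks roughly like this
+(a sketch only; ``db`` is a ``ControllerDB`` and ``row`` an ``EndpointRow``
+built by the caller)::
+
+    store = ControllerStore(db)
+    with store.transaction() as cur:
+        store.endpoints.add(cur, row)          # SQL + post-commit cache update
+    matches = store.endpoints.query()          # served from memory, no DB read
+    txns = ControllerTransitions(store=store)  # transitions consume the bundle
+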
``ControllerTransitions`` +keeps a temporary ``self._db`` backdoor for SQL that has not yet been +moved (tasks, workers, dispatch queue, reservations, the ``meta`` table, +worker-attribute cache). That backdoor is removed in a later phase once +every entity has a store. +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import Iterable, Sequence +from dataclasses import dataclass +from threading import RLock + +from iris.cluster.controller.db import ControllerDB, EndpointQuery, QuerySnapshot, TransactionCursor +from iris.cluster.controller.schema import ( + ENDPOINT_PROJECTION, + JOB_CONFIG_JOIN, + JOB_DETAIL_PROJECTION, + EndpointRow, + JobDetailRow, +) +from iris.cluster.types import TERMINAL_JOB_STATES, TERMINAL_TASK_STATES, JobName +from iris.rpc import job_pb2 + +logger = logging.getLogger(__name__) + + +# Store read methods accept either a write cursor or a read snapshot. Writes +# require ``TransactionCursor`` explicitly so a ``QuerySnapshot`` can't be +# accidentally passed to a mutating API. (This alias does *not* prevent a store +# read method from issuing writes internally — it just polices the caller-side +# direction. A read-only ``Protocol`` would be stricter; not yet worth the +# plumbing.) +Tx = TransactionCursor | QuerySnapshot + + +# ============================================================================= +# EndpointStore +# ============================================================================= + + +class EndpointStore: + """Process-local write-through cache over the ``endpoints`` table. + + Profiling showed ``ListEndpoints`` dominated controller CPU — not because + the SQL was slow per se, but because every call serialized through the + read-connection pool and walked a large WAL to build a snapshot. The + endpoints table is tiny (hundreds of rows) and only changes on explicit + register / unregister, so it is a natural fit for a write-through + in-memory cache. + + Design invariants: + + * Reads never touch the DB. All lookups are served from in-memory maps + guarded by an ``RLock`` — readers observe a consistent snapshot of the + indexes, never a torn state mid-update. + * Writes execute the SQL inside the caller's transaction. The in-memory + update is scheduled as a post-commit hook on the cursor so memory only + changes after the DB has committed. If the transaction rolls back, the + hook never fires. + * N is small enough (≈ hundreds) that linear scans for prefix / task / id + lookups are simpler and plenty fast. Extra indexes (by name, by task_id) + speed the two common cases. + + The store is the sole source of truth for endpoint reads; nothing else in + the controller tree should SELECT from ``endpoints``. + """ + + def __init__(self, db: ControllerDB) -> None: + self._db = db + self._lock = RLock() + self._by_id: dict[str, EndpointRow] = {} + # One name can map to multiple endpoint_ids — the schema does not enforce + # uniqueness on ``name``, and ``INSERT OR REPLACE`` keys off endpoint_id. 
+ self._by_name: dict[str, set[str]] = {} + self._by_task: dict[JobName, set[str]] = {} + self._load_all() + + # -- Loading -------------------------------------------------------------- + + def _load_all(self) -> None: + with self._db.read_snapshot() as q: + rows = ENDPOINT_PROJECTION.decode( + q.fetchall(f"SELECT {ENDPOINT_PROJECTION.select_clause()} FROM endpoints e"), + ) + with self._lock: + self._by_id.clear() + self._by_name.clear() + self._by_task.clear() + for row in rows: + self._index(row) + logger.info("EndpointStore loaded %d endpoint(s) from DB", len(rows)) + + def _index(self, row: EndpointRow) -> None: + self._by_id[row.endpoint_id] = row + self._by_name.setdefault(row.name, set()).add(row.endpoint_id) + self._by_task.setdefault(row.task_id, set()).add(row.endpoint_id) + + def _unindex(self, endpoint_id: str) -> EndpointRow | None: + row = self._by_id.pop(endpoint_id, None) + if row is None: + return None + name_ids = self._by_name.get(row.name) + if name_ids is not None: + name_ids.discard(endpoint_id) + if not name_ids: + self._by_name.pop(row.name, None) + task_ids = self._by_task.get(row.task_id) + if task_ids is not None: + task_ids.discard(endpoint_id) + if not task_ids: + self._by_task.pop(row.task_id, None) + return row + + # -- Reads ---------------------------------------------------------------- + + def query(self, query: EndpointQuery = EndpointQuery()) -> list[EndpointRow]: + """Return endpoint rows matching ``query``; all filters AND together.""" + with self._lock: + # Narrow the candidate set using the most selective index available. + if query.endpoint_ids: + candidates: Iterable[EndpointRow] = ( + self._by_id[eid] for eid in query.endpoint_ids if eid in self._by_id + ) + elif query.task_ids: + task_set = set(query.task_ids) + candidates = (self._by_id[eid] for task_id in task_set for eid in self._by_task.get(task_id, ())) + elif query.exact_name is not None: + candidates = (self._by_id[eid] for eid in self._by_name.get(query.exact_name, ())) + else: + candidates = self._by_id.values() + + results: list[EndpointRow] = [] + for row in candidates: + if query.name_prefix is not None and not row.name.startswith(query.name_prefix): + continue + if query.exact_name is not None and row.name != query.exact_name: + continue + if query.task_ids and row.task_id not in query.task_ids: + continue + if query.endpoint_ids and row.endpoint_id not in query.endpoint_ids: + continue + results.append(row) + if query.limit is not None and len(results) >= query.limit: + break + return results + + def resolve(self, name: str) -> EndpointRow | None: + """Return any endpoint with exact ``name``, or None. Used by the actor proxy.""" + with self._lock: + ids = self._by_name.get(name) + if not ids: + return None + # Arbitrary but stable pick — the original SQL did not specify ORDER BY. + return self._by_id[next(iter(ids))] + + def get(self, endpoint_id: str) -> EndpointRow | None: + with self._lock: + return self._by_id.get(endpoint_id) + + def all(self) -> list[EndpointRow]: + with self._lock: + return list(self._by_id.values()) + + # -- Writes --------------------------------------------------------------- + + def add(self, cur: TransactionCursor, endpoint: EndpointRow) -> bool: + """Insert ``endpoint`` into the DB and schedule the memory update. + + Returns False (and writes nothing) if the owning task is already + terminal. Otherwise inserts / replaces and schedules a post-commit + hook that updates the in-memory indexes. 
+ """ + task_id = endpoint.task_id + job_id, _ = task_id.require_task() + row = cur.execute("SELECT state FROM tasks WHERE task_id = ?", (task_id.to_wire(),)).fetchone() + if row is not None and int(row["state"]) in TERMINAL_TASK_STATES: + return False + + cur.execute( + "INSERT OR REPLACE INTO endpoints(" + "endpoint_id, name, address, job_id, task_id, metadata_json, registered_at_ms" + ") VALUES (?, ?, ?, ?, ?, ?, ?)", + ( + endpoint.endpoint_id, + endpoint.name, + endpoint.address, + job_id.to_wire(), + task_id.to_wire(), + json.dumps(endpoint.metadata), + endpoint.registered_at.epoch_ms(), + ), + ) + + def apply() -> None: + with self._lock: + # Replace: drop any previous row with this id first so the + # name/task indexes stay consistent on overwrite. + self._unindex(endpoint.endpoint_id) + self._index(endpoint) + + cur.on_commit(apply) + return True + + def remove(self, cur: TransactionCursor, endpoint_id: str) -> EndpointRow | None: + """Remove a single endpoint by id. Returns the removed row snapshot, if any.""" + existing = self.get(endpoint_id) + if existing is None: + return None + cur.execute("DELETE FROM endpoints WHERE endpoint_id = ?", (endpoint_id,)) + + def apply() -> None: + with self._lock: + self._unindex(endpoint_id) + + cur.on_commit(apply) + return existing + + def remove_by_task(self, cur: TransactionCursor, task_id: JobName) -> list[str]: + """Remove all endpoints owned by a task. Returns the removed endpoint_ids.""" + with self._lock: + ids = list(self._by_task.get(task_id, ())) + if not ids: + # Still issue the DELETE to stay consistent with any rows the + # store might not have observed yet (belt-and-suspenders for + # the unlikely race of an in-flight concurrent writer). This + # costs nothing on the common path. + cur.execute("DELETE FROM endpoints WHERE task_id = ?", (task_id.to_wire(),)) + return [] + cur.execute("DELETE FROM endpoints WHERE task_id = ?", (task_id.to_wire(),)) + + def apply() -> None: + with self._lock: + for eid in ids: + self._unindex(eid) + + cur.on_commit(apply) + return ids + + def remove_by_job_ids(self, cur: TransactionCursor, job_ids: Sequence[JobName]) -> list[str]: + """Remove all endpoints owned by any of ``job_ids``. Used by cancel_job and prune.""" + if not job_ids: + return [] + wire_ids = [jid.to_wire() for jid in job_ids] + with self._lock: + to_remove: list[str] = [] + for row in self._by_id.values(): + owning_job, _ = row.task_id.require_task() + if owning_job.to_wire() in wire_ids: + to_remove.append(row.endpoint_id) + placeholders = ",".join("?" for _ in wire_ids) + cur.execute( + f"DELETE FROM endpoints WHERE job_id IN ({placeholders})", + tuple(wire_ids), + ) + if not to_remove: + return [] + + def apply() -> None: + with self._lock: + for eid in to_remove: + self._unindex(eid) + + cur.on_commit(apply) + return to_remove + + +# ============================================================================= +# Phase-1 skeletons for the remaining per-entity stores. +# +# These exist so callers can already reference ``store.jobs`` etc. and so that +# subsequent phases (moving SQL out of transitions.py) land as additive +# changes to these classes rather than needing new plumbing each time. +# Methods are added as the corresponding SQL migrates out of transitions.py. +# ============================================================================= + + +@dataclass(frozen=True, slots=True) +class JobInsertParams: + """Fields needed to insert one row into the ``jobs`` table. 
+ + Holder jobs set ``is_reservation_holder=True`` and leave ``error`` / + ``exit_code`` / ``finished_at_ms`` / ``scheduling_deadline_epoch_ms`` None; + the regular path passes the corresponding submit-time values. + """ + + job_id: JobName + user_id: str + parent_job_id: str | None + root_job_id: str + depth: int + state: int + submitted_at_ms: int + root_submitted_at_ms: int + started_at_ms: int | None + finished_at_ms: int | None + scheduling_deadline_epoch_ms: int | None + error: str | None + exit_code: int | None + num_tasks: int + is_reservation_holder: bool + name: str + has_reservation: bool + + +@dataclass(frozen=True, slots=True) +class JobConfigInsertParams: + """Fields needed to insert one row into the ``job_config`` table. + + Holder jobs do not set ``submit_argv`` / ``reservation`` / ``fail_if_exists``; + those have defaults so the holder path can omit them. + """ + + job_id: JobName + name: str + has_reservation: bool + res_cpu_millicores: int + res_memory_bytes: int + res_disk_bytes: int + res_device_json: str | None + constraints_json: str + has_coscheduling: bool + coscheduling_group_by: str + scheduling_timeout_ms: int | None + max_task_failures: int + entrypoint_json: str + environment_json: str + bundle_id: str + ports_json: str + max_retries_failure: int + max_retries_preemption: int + timeout_ms: int | None + preemption_policy: int + existing_job_policy: int + priority_band: int + task_image: str + submit_argv_json: str = "[]" + reservation_json: str | None = None + fail_if_exists: bool = False + + +@dataclass(frozen=True, slots=True) +class JobRecomputeBasis: + state: int + started_at_ms: int | None + max_task_failures: int + + +class JobStore: + """Jobs, job_config, users, user_budgets. + + Holds the SQL for the four tables the controller uses to track a submitted + job's lifecycle. Reads take a ``Tx`` (read snapshot or write cursor); + writes require a ``TransactionCursor`` so static typing rules out + mutations through a read-only snapshot. 
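+
+    Roughly (an illustrative sketch; ``store``, ``job_id`` and ``now_ms`` are
+    assumed to be in scope)::
+
+        with store.read_snapshot() as q:
+            state = store.jobs.get_state(q, job_id)      # read via snapshot
+        with store.transaction() as cur:
+            store.jobs.mark_running_if_pending(cur, job_id, now_ms)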
+ """ + + def __init__(self, db: ControllerDB) -> None: + self._db = db + + # -- Reads --------------------------------------------------------------- + + def get_state(self, tx: Tx, job_id: JobName) -> int | None: + row = tx.fetchone("SELECT state FROM jobs WHERE job_id = ?", (job_id.to_wire(),)) + return int(row["state"]) if row is not None else None + + def get_root_submitted_at_ms(self, tx: Tx, job_id: JobName) -> int | None: + row = tx.fetchone("SELECT root_submitted_at_ms FROM jobs WHERE job_id = ?", (job_id.to_wire(),)) + return int(row["root_submitted_at_ms"]) if row is not None else None + + def get_preemption_info(self, tx: Tx, job_id: JobName) -> tuple[int, int] | None: + """Return ``(preemption_policy, num_tasks)`` or None if the job is gone.""" + row = tx.fetchone( + f"SELECT jc.preemption_policy, j.num_tasks FROM jobs j {JOB_CONFIG_JOIN} WHERE j.job_id = ?", + (job_id.to_wire(),), + ) + if row is None: + return None + return int(row["preemption_policy"]), int(row["num_tasks"]) + + def get_recompute_basis(self, tx: Tx, job_id: JobName) -> JobRecomputeBasis | None: + row = tx.fetchone( + f"SELECT j.state, j.started_at_ms, jc.max_task_failures " + f"FROM jobs j {JOB_CONFIG_JOIN} WHERE j.job_id = ?", + (job_id.to_wire(),), + ) + if row is None: + return None + return JobRecomputeBasis( + state=int(row["state"]), + started_at_ms=int(row["started_at_ms"]) if row["started_at_ms"] is not None else None, + max_task_failures=int(row["max_task_failures"]), + ) + + def get_detail(self, tx: Tx, job_id: JobName) -> JobDetailRow | None: + row = tx.fetchone( + f"SELECT {JOB_DETAIL_PROJECTION.select_clause()} " f"FROM jobs j {JOB_CONFIG_JOIN} WHERE j.job_id = ?", + (job_id.to_wire(),), + ) + if row is None: + return None + return JOB_DETAIL_PROJECTION.decode_one([row]) + + def get_config(self, tx: Tx, job_id: JobName) -> dict | None: + """Return the raw ``job_config`` row as a dict, or None. + + Callers currently access fields by string key (e.g. ``jc["res_cpu_millicores"]``); + returning a dict keeps the existing consumers working while SQL moves + behind the store. + """ + row = tx.fetchone("SELECT * FROM job_config WHERE job_id = ?", (job_id.to_wire(),)) + return dict(row) if row is not None else None + + def list_descendants( + self, + tx: Tx, + parent_id: JobName, + *, + exclude_reservation_holders: bool = False, + ) -> list[JobName]: + """Return all transitive descendants of ``parent_id`` (not ``parent_id`` itself). + + When ``exclude_reservation_holders`` is True, reservation-holder jobs and + anything below them are skipped — used during preemption retry, where the + parent goes back to PENDING and needs its reservation subtree preserved. + """ + if exclude_reservation_holders: + rows = tx.fetchall( + "WITH RECURSIVE subtree(job_id) AS (" + " SELECT job_id FROM jobs WHERE parent_job_id = ? AND is_reservation_holder = 0 " + " UNION ALL " + " SELECT j.job_id FROM jobs j JOIN subtree s ON j.parent_job_id = s.job_id" + " WHERE j.is_reservation_holder = 0" + ") SELECT job_id FROM subtree", + (parent_id.to_wire(),), + ) + else: + rows = tx.fetchall( + "WITH RECURSIVE subtree(job_id) AS (" + " SELECT job_id FROM jobs WHERE parent_job_id = ? 
" + " UNION ALL " + " SELECT j.job_id FROM jobs j JOIN subtree s ON j.parent_job_id = s.job_id" + ") SELECT job_id FROM subtree", + (parent_id.to_wire(),), + ) + return [JobName.from_wire(str(row["job_id"])) for row in rows] + + def list_subtree(self, tx: Tx, root_id: JobName) -> list[JobName]: + """Return ``root_id`` and all its transitive descendants.""" + rows = tx.fetchall( + "WITH RECURSIVE subtree(job_id) AS (" + " SELECT job_id FROM jobs WHERE job_id = ? " + " UNION ALL " + " SELECT j.job_id FROM jobs j JOIN subtree s ON j.parent_job_id = s.job_id" + ") SELECT job_id FROM subtree", + (root_id.to_wire(),), + ) + return [JobName.from_wire(str(row["job_id"])) for row in rows] + + def find_prunable(self, tx: Tx, before_ms: int) -> JobName | None: + """Return one terminal job whose ``finished_at_ms`` predates ``before_ms``, or None.""" + placeholders = ",".join("?" for _ in TERMINAL_JOB_STATES) + row = tx.fetchone( + f"SELECT job_id FROM jobs WHERE state IN ({placeholders})" + " AND finished_at_ms IS NOT NULL AND finished_at_ms < ? LIMIT 1", + (*TERMINAL_JOB_STATES, before_ms), + ) + return JobName.from_wire(str(row["job_id"])) if row is not None else None + + # -- Writes -------------------------------------------------------------- + + def update_state_if_not_terminal( + self, + cur: TransactionCursor, + job_id: JobName, + new_state: int, + error: str | None, + finished_at_ms: int | None, + ) -> None: + """Set a new state on a single job, skipping rows already in a terminal state.""" + placeholders = ",".join("?" for _ in TERMINAL_JOB_STATES) + cur.execute( + "UPDATE jobs SET state = ?, error = ?, finished_at_ms = COALESCE(finished_at_ms, ?) " + f"WHERE job_id = ? AND state NOT IN ({placeholders})", + (new_state, error, finished_at_ms, job_id.to_wire(), *TERMINAL_JOB_STATES), + ) + + def bulk_update_state( + self, + cur: TransactionCursor, + job_ids: Sequence[JobName], + new_state: int, + error: str | None, + finished_at_ms: int | None, + guard_states: Iterable[int], + ) -> None: + """Set state on many jobs; rows in any of ``guard_states`` are skipped.""" + if not job_ids: + return + wire_ids = [jid.to_wire() for jid in job_ids] + guard = tuple(guard_states) + job_placeholders = ",".join("?" for _ in wire_ids) + guard_placeholders = ",".join("?" for _ in guard) + cur.execute( + f"UPDATE jobs SET state = ?, error = ?, finished_at_ms = COALESCE(finished_at_ms, ?) " + f"WHERE job_id IN ({job_placeholders}) AND state NOT IN ({guard_placeholders})", + (new_state, error, finished_at_ms, *wire_ids, *guard), + ) + + def mark_running_if_pending(self, cur: TransactionCursor, job_id: JobName, now_ms: int) -> None: + """Advance PENDING → RUNNING and set ``started_at_ms`` if not already populated.""" + cur.execute( + "UPDATE jobs SET state = CASE WHEN state = ? THEN ? ELSE state END, " + "started_at_ms = COALESCE(started_at_ms, ?) WHERE job_id = ?", + (job_pb2.JOB_STATE_PENDING, job_pb2.JOB_STATE_RUNNING, now_ms, job_id.to_wire()), + ) + + def apply_recomputed_state( + self, + cur: TransactionCursor, + job_id: JobName, + new_state: int, + now_ms: int, + error: str | None, + ) -> None: + """Write the result of ``_recompute_job_state`` back to the row. + + Sets ``started_at_ms`` (if moving to RUNNING), ``finished_at_ms`` (if + moving to a terminal state), and ``error`` (if the terminal reason + warrants one). The caller has already decided ``new_state`` differs + from the current state. + """ + terminal_placeholders = ",".join("?" 
for _ in TERMINAL_JOB_STATES) + cur.execute( + "UPDATE jobs SET state = ?, " + "started_at_ms = CASE WHEN ? = ? THEN COALESCE(started_at_ms, ?) ELSE started_at_ms END, " + f"finished_at_ms = CASE WHEN ? IN ({terminal_placeholders}) THEN ? ELSE finished_at_ms END, " + "error = CASE WHEN ? IN (?, ?, ?, ?) THEN ? ELSE error END " + "WHERE job_id = ?", + ( + new_state, + new_state, + job_pb2.JOB_STATE_RUNNING, + now_ms, + new_state, + *TERMINAL_JOB_STATES, + now_ms, + new_state, + job_pb2.JOB_STATE_FAILED, + job_pb2.JOB_STATE_KILLED, + job_pb2.JOB_STATE_UNSCHEDULABLE, + job_pb2.JOB_STATE_WORKER_FAILED, + error, + job_id.to_wire(), + ), + ) + + def insert(self, cur: TransactionCursor, params: JobInsertParams) -> None: + cur.execute( + "INSERT INTO jobs(" + "job_id, user_id, parent_job_id, root_job_id, depth, state, submitted_at_ms, " + "root_submitted_at_ms, started_at_ms, finished_at_ms, scheduling_deadline_epoch_ms, " + "error, exit_code, num_tasks, is_reservation_holder, name, has_reservation" + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + params.job_id.to_wire(), + params.user_id, + params.parent_job_id, + params.root_job_id, + params.depth, + params.state, + params.submitted_at_ms, + params.root_submitted_at_ms, + params.started_at_ms, + params.finished_at_ms, + params.scheduling_deadline_epoch_ms, + params.error, + params.exit_code, + params.num_tasks, + 1 if params.is_reservation_holder else 0, + params.name, + 1 if params.has_reservation else 0, + ), + ) + + def insert_config(self, cur: TransactionCursor, params: JobConfigInsertParams) -> None: + cur.execute( + "INSERT INTO job_config(" + "job_id, name, has_reservation, " + "res_cpu_millicores, res_memory_bytes, res_disk_bytes, res_device_json, " + "constraints_json, has_coscheduling, coscheduling_group_by, " + "scheduling_timeout_ms, max_task_failures, " + "entrypoint_json, environment_json, bundle_id, ports_json, " + "max_retries_failure, max_retries_preemption, timeout_ms, " + "preemption_policy, existing_job_policy, priority_band, " + "task_image, submit_argv_json, reservation_json, fail_if_exists" + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + params.job_id.to_wire(), + params.name, + 1 if params.has_reservation else 0, + params.res_cpu_millicores, + params.res_memory_bytes, + params.res_disk_bytes, + params.res_device_json, + params.constraints_json, + 1 if params.has_coscheduling else 0, + params.coscheduling_group_by, + params.scheduling_timeout_ms, + params.max_task_failures, + params.entrypoint_json, + params.environment_json, + params.bundle_id, + params.ports_json, + params.max_retries_failure, + params.max_retries_preemption, + params.timeout_ms, + params.preemption_policy, + params.existing_job_policy, + params.priority_band, + params.task_image, + params.submit_argv_json, + params.reservation_json, + 1 if params.fail_if_exists else 0, + ), + ) + + def delete(self, cur: TransactionCursor, job_id: JobName) -> None: + """Delete a job row. 
ON DELETE CASCADE handles tasks, attempts, endpoints.""" + cur.execute("DELETE FROM jobs WHERE job_id = ?", (job_id.to_wire(),)) + + # -- users / user_budgets ------------------------------------------------ + + def ensure_user(self, cur: TransactionCursor, user_id: str, now_ms: int) -> None: + """Idempotently create a ``users`` row at submission time.""" + cur.execute( + "INSERT OR IGNORE INTO users(user_id, created_at_ms) VALUES (?, ?)", + (user_id, now_ms), + ) + + +class TaskStore: + """Tasks and task_resource_history.""" + + def __init__(self, db: ControllerDB) -> None: + self._db = db + + +class TaskAttemptStore: + """Task attempts.""" + + def __init__(self, db: ControllerDB) -> None: + self._db = db + + +class WorkerStore: + """Workers, worker_attributes, worker_task_history, worker_resource_history.""" + + def __init__(self, db: ControllerDB) -> None: + self._db = db + + +class DispatchQueueStore: + """The dispatch_queue table.""" + + def __init__(self, db: ControllerDB) -> None: + self._db = db + + +class ReservationStore: + """Reservation claims and the meta(last_submission_ms) counter.""" + + def __init__(self, db: ControllerDB) -> None: + self._db = db + + +# ============================================================================= +# ControllerStore +# ============================================================================= + + +class ControllerStore: + """Bundle of per-entity stores with direct access to transactions/snapshots.""" + + def __init__(self, db: ControllerDB) -> None: + self._db = db + self.jobs = JobStore(db) + self.tasks = TaskStore(db) + self.attempts = TaskAttemptStore(db) + self.workers = WorkerStore(db) + self.endpoints = EndpointStore(db) + self.dispatch = DispatchQueueStore(db) + self.reservations = ReservationStore(db) + # Caches reload after a checkpoint restore via db.replace_from(). The + # hook fires only in that flow; normal startup loads caches in the + # store constructors above. 
+ db.register_reopen_hook(self.endpoints._load_all) + + def transaction(self): + return self._db.transaction() + + def read_snapshot(self): + return self._db.read_snapshot() diff --git a/lib/iris/src/iris/cluster/controller/transitions.py b/lib/iris/src/iris/cluster/controller/transitions.py index ee64413b20..2f5efbc7df 100644 --- a/lib/iris/src/iris/cluster/controller/transitions.py +++ b/lib/iris/src/iris/cluster/controller/transitions.py @@ -33,9 +33,15 @@ task_row_can_be_scheduled, task_row_is_finished, ) +from iris.cluster.controller.stores import ( + ControllerStore, + EndpointStore, + JobConfigInsertParams, + JobInsertParams, + JobStore, +) from iris.cluster.controller.schema import ( JOB_CONFIG_JOIN, - JOB_DETAIL_PROJECTION, TASK_DETAIL_PROJECTION, WORKER_DETAIL_PROJECTION, EndpointRow, @@ -346,9 +352,9 @@ def _has_reservation_flag(request: controller_pb2.Controller.LaunchJobRequest) - return 1 if request.HasField("reservation") and request.reservation.entries else 0 -def delete_task_endpoints(cur: TransactionCursor, registry, task_id: str) -> None: - """Remove all registered endpoints for a task through the endpoint registry.""" - registry.remove_by_task(cur, JobName.from_wire(task_id)) +def delete_task_endpoints(cur: TransactionCursor, endpoints: EndpointStore, task_id: str) -> None: + """Remove all registered endpoints for a task through the endpoint store.""" + endpoints.remove_by_task(cur, JobName.from_wire(task_id)) def enqueue_run_dispatch( @@ -590,8 +596,8 @@ def _kill_non_terminal_tasks( def _cascade_children( - cur: Any, - registry, + cur: TransactionCursor, + store: ControllerStore, job_id: JobName, now_ms: int, reason: str, @@ -607,58 +613,31 @@ def _cascade_children( tasks_to_kill: set[JobName] = set() task_kill_workers: dict[JobName, WorkerId] = {} - if exclude_reservation_holders: - # Skip reservation holder jobs and anything below them. - descendants = cur.execute( - "WITH RECURSIVE subtree(job_id) AS (" - " SELECT job_id FROM jobs WHERE parent_job_id = ? AND is_reservation_holder = 0 " - " UNION ALL " - " SELECT j.job_id FROM jobs j JOIN subtree s ON j.parent_job_id = s.job_id" - " WHERE j.is_reservation_holder = 0" - ") SELECT job_id FROM subtree", - (job_id.to_wire(),), - ).fetchall() - else: - descendants = cur.execute( - "WITH RECURSIVE subtree(job_id) AS (" - " SELECT job_id FROM jobs WHERE parent_job_id = ? " - " UNION ALL " - " SELECT j.job_id FROM jobs j JOIN subtree s ON j.parent_job_id = s.job_id" - ") SELECT job_id FROM subtree", - (job_id.to_wire(),), - ).fetchall() - for child_row in descendants: - child_job_id = str(child_row["job_id"]) + descendants = store.jobs.list_descendants( + cur, + job_id, + exclude_reservation_holders=exclude_reservation_holders, + ) + for child_job_id in descendants: child_tasks_to_kill, child_task_kill_workers = _kill_non_terminal_tasks( - cur, registry, child_job_id, reason, now_ms + cur, store.endpoints, child_job_id.to_wire(), reason, now_ms ) tasks_to_kill.update(child_tasks_to_kill) task_kill_workers.update(child_task_kill_workers) - terminal_placeholders = ",".join("?" for _ in TERMINAL_JOB_STATES) - cur.execute( - "UPDATE jobs SET state = ?, error = ?, finished_at_ms = COALESCE(finished_at_ms, ?) " - f"WHERE job_id = ? 
AND state NOT IN ({terminal_placeholders})", - ( - job_pb2.JOB_STATE_KILLED, - reason, - now_ms, - child_job_id, - *TERMINAL_JOB_STATES, - ), - ) + store.jobs.update_state_if_not_terminal(cur, child_job_id, job_pb2.JOB_STATE_KILLED, reason, now_ms) return tasks_to_kill, task_kill_workers def _cascade_terminal_job( - cur: Any, - registry, + cur: TransactionCursor, + store: ControllerStore, job_id: JobName, now_ms: int, reason: str, ) -> tuple[set[JobName], dict[JobName, WorkerId]]: """Kill remaining tasks and descendant jobs when a job reaches a terminal state.""" - tasks_to_kill, task_kill_workers = _kill_non_terminal_tasks(cur, registry, job_id.to_wire(), reason, now_ms) - child_tasks_to_kill, child_task_kill_workers = _cascade_children(cur, registry, job_id, now_ms, reason) + tasks_to_kill, task_kill_workers = _kill_non_terminal_tasks(cur, store.endpoints, job_id.to_wire(), reason, now_ms) + child_tasks_to_kill, child_task_kill_workers = _cascade_children(cur, store, job_id, now_ms, reason) tasks_to_kill.update(child_tasks_to_kill) task_kill_workers.update(child_task_kill_workers) return tasks_to_kill, task_kill_workers @@ -742,21 +721,18 @@ def _terminate_coscheduled_siblings( return tasks_to_kill, task_kill_workers -def _resolve_preemption_policy(cur: Any, job_id: JobName) -> int: +def _resolve_preemption_policy(jobs: JobStore, cur: TransactionCursor, job_id: JobName) -> int: """Resolve the effective preemption policy for a job. Defaults: single-task jobs → TERMINATE_CHILDREN, multi-task → PRESERVE_CHILDREN. """ - row = cur.execute( - f"SELECT jc.preemption_policy, j.num_tasks FROM jobs j {JOB_CONFIG_JOIN} WHERE j.job_id = ?", - (job_id.to_wire(),), - ).fetchone() - if row is None: + info = jobs.get_preemption_info(cur, job_id) + if info is None: return job_pb2.JOB_PREEMPTION_POLICY_TERMINATE_CHILDREN - policy = int(row["preemption_policy"]) + policy, num_tasks = info if policy != job_pb2.JOB_PREEMPTION_POLICY_UNSPECIFIED: return policy - if int(row["num_tasks"]) <= 1: + if num_tasks <= 1: return job_pb2.JOB_PREEMPTION_POLICY_TERMINATE_CHILDREN return job_pb2.JOB_PREEMPTION_POLICY_PRESERVE_CHILDREN @@ -770,8 +746,8 @@ def _resolve_preemption_policy(cur: Any, job_id: JobName) -> int: def _finalize_terminal_job( - cur: Any, - registry, + cur: TransactionCursor, + store: ControllerStore, job_id: JobName, terminal_state: int, now_ms: int, @@ -786,13 +762,13 @@ def _finalize_terminal_job( Non-succeeded jobs cascade only if the preemption policy is TERMINATE_CHILDREN. 
""" reason = _TERMINAL_STATE_REASONS.get(terminal_state, "Job finalized") - tasks_to_kill, task_kill_workers = _kill_non_terminal_tasks(cur, registry, job_id.to_wire(), reason, now_ms) + tasks_to_kill, task_kill_workers = _kill_non_terminal_tasks(cur, store.endpoints, job_id.to_wire(), reason, now_ms) should_cascade = True if terminal_state != job_pb2.JOB_STATE_SUCCEEDED: - policy = _resolve_preemption_policy(cur, job_id) + policy = _resolve_preemption_policy(store.jobs, cur, job_id) should_cascade = policy == job_pb2.JOB_PREEMPTION_POLICY_TERMINATE_CHILDREN if should_cascade: - child_tasks_to_kill, child_task_kill_workers = _cascade_children(cur, registry, job_id, now_ms, reason) + child_tasks_to_kill, child_task_kill_workers = _cascade_children(cur, store, job_id, now_ms, reason) tasks_to_kill.update(child_tasks_to_kill) task_kill_workers.update(child_task_kill_workers) return tasks_to_kill, task_kill_workers @@ -969,35 +945,35 @@ class ControllerTransitions: def __init__( self, - db: ControllerDB, + store: ControllerStore, health: WorkerHealthTracker | None = None, ): - self._db = db + self._store = store + # Escape hatch kept only while the phased migration moves SQL out of + # this file. Direct ``self._db`` calls should decrease every phase + # (jobs, tasks, attempts, workers, dispatch) and hit zero at the end; + # new code should go through ``self._store`` instead. + self._db: ControllerDB = store._db self._health = health or WorkerHealthTracker() - def _recompute_job_state(self, cur: Any, job_id: JobName) -> int | None: - row = cur.execute( - f"SELECT j.state, j.started_at_ms, jc.max_task_failures " - f"FROM jobs j {JOB_CONFIG_JOIN} WHERE j.job_id = ?", - (job_id.to_wire(),), - ).fetchone() - if row is None: + def _recompute_job_state(self, cur: TransactionCursor, job_id: JobName) -> int | None: + basis = self._store.jobs.get_recompute_basis(cur, job_id) + if basis is None: return None - current_state = int(row["state"]) + current_state = basis.state if current_state in TERMINAL_JOB_STATES: return current_state - max_task_failures = int(row["max_task_failures"]) - counts_rows = cur.execute( + counts_rows = cur.fetchall( "SELECT state, COUNT(*) AS c FROM tasks WHERE job_id = ? GROUP BY state", (job_id.to_wire(),), - ).fetchall() + ) counts = {int(r["state"]): int(r["c"]) for r in counts_rows} total = sum(counts.values()) new_state = current_state now_ms = Timestamp.now().epoch_ms() if total > 0 and counts.get(job_pb2.TASK_STATE_SUCCEEDED, 0) == total: new_state = job_pb2.JOB_STATE_SUCCEEDED - elif counts.get(job_pb2.TASK_STATE_FAILED, 0) > max_task_failures: + elif counts.get(job_pb2.TASK_STATE_FAILED, 0) > basis.max_task_failures: new_state = job_pb2.JOB_STATE_FAILED elif counts.get(job_pb2.TASK_STATE_UNSCHEDULABLE, 0) > 0: new_state = job_pb2.JOB_STATE_UNSCHEDULABLE @@ -1015,42 +991,19 @@ def _recompute_job_state(self, cur: Any, job_id: JobName) -> int | None: or counts.get(job_pb2.TASK_STATE_RUNNING, 0) > 0 ): new_state = job_pb2.JOB_STATE_RUNNING - elif row["started_at_ms"] is not None: + elif basis.started_at_ms is not None: # Retries put tasks back into PENDING; keep job running once it has started. new_state = job_pb2.JOB_STATE_RUNNING elif total > 0: new_state = job_pb2.JOB_STATE_PENDING if new_state == current_state: return new_state - terminal_placeholders = ",".join("?" for _ in TERMINAL_JOB_STATES) - error_row = cur.execute( + error_row = cur.fetchone( "SELECT error FROM tasks WHERE job_id = ? 
AND error IS NOT NULL ORDER BY task_index LIMIT 1", (job_id.to_wire(),), - ).fetchone() - error = str(error_row["error"]) if error_row is not None else None - cur.execute( - "UPDATE jobs SET state = ?, " - "started_at_ms = CASE WHEN ? = ? THEN COALESCE(started_at_ms, ?) ELSE started_at_ms END, " - f"finished_at_ms = CASE WHEN ? IN ({terminal_placeholders}) THEN ? ELSE finished_at_ms END, " - "error = CASE WHEN ? IN (?, ?, ?, ?) THEN ? ELSE error END " - "WHERE job_id = ?", - ( - new_state, - new_state, - job_pb2.JOB_STATE_RUNNING, - now_ms, - new_state, - *TERMINAL_JOB_STATES, - now_ms, - new_state, - job_pb2.JOB_STATE_FAILED, - job_pb2.JOB_STATE_KILLED, - job_pb2.JOB_STATE_UNSCHEDULABLE, - job_pb2.JOB_STATE_WORKER_FAILED, - error, - job_id.to_wire(), - ), ) + error = str(error_row["error"]) if error_row is not None else None + self._store.jobs.apply_recomputed_state(cur, job_id, new_state, now_ms, error) return new_state def replace_reservation_claims(self, claims: dict[WorkerId, ReservationClaim]) -> None: @@ -1088,16 +1041,13 @@ def submit_job( parent_job_id = job_id.parent.to_wire() if job_id.parent is not None else None root_submitted_ms = effective_submission_ms - if parent_job_id is not None: - parent = cur.execute( - "SELECT root_submitted_at_ms FROM jobs WHERE job_id = ?", - (parent_job_id,), - ).fetchone() + if job_id.parent is not None: # `launch_job` is responsible for rejecting submissions with a # missing parent; if we reach here the parent row must exist. - if parent is None: + parent_root = self._store.jobs.get_root_submitted_at_ms(cur, job_id.parent) + if parent_root is None: raise ValueError(f"Cannot submit job {job_id}: parent {parent_job_id} is absent from the database") - root_submitted_ms = int(parent["root_submitted_at_ms"]) + root_submitted_ms = parent_root deadline_epoch_ms: int | None = None if request.HasField("scheduling_timeout") and request.scheduling_timeout.milliseconds > 0: @@ -1107,10 +1057,7 @@ def submit_job( .epoch_ms() ) - cur.execute( - "INSERT OR IGNORE INTO users(user_id, created_at_ms) VALUES (?, ?)", - (job_id.user, effective_submission_ms), - ) + self._store.jobs.ensure_user(cur, job_id.user, effective_submission_ms) # No user_budgets row is created here: absence means "apply # UserBudgetDefaults". 
Rows exist only for tier seeds from cluster # config (see reconcile_user_budget_tiers) and admin overrides via @@ -1168,67 +1115,57 @@ def submit_job( timeout_ms: int | None = int(request.timeout.milliseconds) if request.timeout.milliseconds > 0 else None job_name_lower = request.name.lower() - cur.execute( - "INSERT INTO jobs(" - "job_id, user_id, parent_job_id, root_job_id, depth, state, submitted_at_ms, " - "root_submitted_at_ms, started_at_ms, finished_at_ms, scheduling_deadline_epoch_ms, " - "error, exit_code, num_tasks, is_reservation_holder, name, has_reservation" - ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, NULL, ?, ?, ?, NULL, ?, 0, ?, ?)", - ( - job_id.to_wire(), - job_id.user, - parent_job_id, - job_id.root_job.to_wire(), - job_id.depth, - state, - effective_submission_ms, - root_submitted_ms, - finished_ms, - deadline_epoch_ms, - validation_error, - replicas, - job_name_lower, - has_reservation, + self._store.jobs.insert( + cur, + JobInsertParams( + job_id=job_id, + user_id=job_id.user, + parent_job_id=parent_job_id, + root_job_id=job_id.root_job.to_wire(), + depth=job_id.depth, + state=state, + submitted_at_ms=effective_submission_ms, + root_submitted_at_ms=root_submitted_ms, + started_at_ms=None, + finished_at_ms=finished_ms, + scheduling_deadline_epoch_ms=deadline_epoch_ms, + error=validation_error, + exit_code=None, + num_tasks=replicas, + is_reservation_holder=False, + name=job_name_lower, + has_reservation=bool(has_reservation), ), ) - cur.execute( - "INSERT INTO job_config(" - "job_id, name, has_reservation, " - "res_cpu_millicores, res_memory_bytes, res_disk_bytes, res_device_json, " - "constraints_json, has_coscheduling, coscheduling_group_by, " - "scheduling_timeout_ms, max_task_failures, " - "entrypoint_json, environment_json, bundle_id, ports_json, " - "max_retries_failure, max_retries_preemption, timeout_ms, " - "preemption_policy, existing_job_policy, priority_band, " - "task_image, submit_argv_json, reservation_json, fail_if_exists" - ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - ( - job_id.to_wire(), - job_name_lower, - has_reservation, - res_cpu, - res_mem, - res_disk, - res_device, - constraints_json, - has_cosched, - cosched_group, - sched_timeout, - max_failures, - entrypoint_json, - environment_json, - request.bundle_id, - ports_json, - int(request.max_retries_failure), - int(request.max_retries_preemption), - timeout_ms, - int(request.preemption_policy), - int(request.existing_job_policy), - int(request.priority_band), - request.task_image, - json.dumps(list(request.submit_argv)), - reservation_json, - 1 if request.fail_if_exists else 0, + self._store.jobs.insert_config( + cur, + JobConfigInsertParams( + job_id=job_id, + name=job_name_lower, + has_reservation=bool(has_reservation), + res_cpu_millicores=res_cpu, + res_memory_bytes=res_mem, + res_disk_bytes=res_disk, + res_device_json=res_device, + constraints_json=constraints_json, + has_coscheduling=bool(has_cosched), + coscheduling_group_by=cosched_group, + scheduling_timeout_ms=sched_timeout, + max_task_failures=max_failures, + entrypoint_json=entrypoint_json, + environment_json=environment_json, + bundle_id=request.bundle_id, + ports_json=ports_json, + max_retries_failure=int(request.max_retries_failure), + max_retries_preemption=int(request.max_retries_preemption), + timeout_ms=timeout_ms, + preemption_policy=int(request.preemption_policy), + existing_job_policy=int(request.existing_job_policy), + priority_band=int(request.priority_band), + 
task_image=request.task_image, + submit_argv_json=json.dumps(list(request.submit_argv)), + reservation_json=reservation_json, + fail_if_exists=bool(request.fail_if_exists), ), ) @@ -1290,54 +1227,56 @@ def submit_job( holder_res_device = proto_to_json(holder_res.device) if holder_res else None holder_constraints_json = constraints_to_json(holder_request.constraints) holder_name_lower = holder_request.name.lower() - cur.execute( - "INSERT INTO jobs(" - "job_id, user_id, parent_job_id, root_job_id, depth, state, submitted_at_ms, " - "root_submitted_at_ms, started_at_ms, finished_at_ms, scheduling_deadline_epoch_ms, " - "error, exit_code, num_tasks, is_reservation_holder, name, has_reservation" - ") VALUES (" - "?, ?, ?, ?, ?, ?, ?, ?, NULL, NULL, NULL, NULL, NULL, ?, 1, ?, 0" - ")", - ( - holder_id.to_wire(), - holder_id.user, - job_id.to_wire(), - holder_id.root_job.to_wire(), - holder_id.depth, - job_pb2.JOB_STATE_PENDING, - effective_submission_ms, - root_submitted_ms, - len(request.reservation.entries), - holder_name_lower, + self._store.jobs.insert( + cur, + JobInsertParams( + job_id=holder_id, + user_id=holder_id.user, + parent_job_id=job_id.to_wire(), + root_job_id=holder_id.root_job.to_wire(), + depth=holder_id.depth, + state=job_pb2.JOB_STATE_PENDING, + submitted_at_ms=effective_submission_ms, + root_submitted_at_ms=root_submitted_ms, + started_at_ms=None, + finished_at_ms=None, + scheduling_deadline_epoch_ms=None, + error=None, + exit_code=None, + num_tasks=len(request.reservation.entries), + is_reservation_holder=True, + name=holder_name_lower, + has_reservation=False, ), ) holder_entrypoint_json = entrypoint_to_json(holder_request.entrypoint) holder_environment_json = proto_to_json(holder_request.environment) - cur.execute( - "INSERT INTO job_config(" - "job_id, name, has_reservation, " - "res_cpu_millicores, res_memory_bytes, res_disk_bytes, res_device_json, " - "constraints_json, has_coscheduling, coscheduling_group_by, " - "scheduling_timeout_ms, max_task_failures, " - "entrypoint_json, environment_json, bundle_id, ports_json, " - "max_retries_failure, max_retries_preemption, timeout_ms, " - "preemption_policy, existing_job_policy, priority_band, " - "task_image, reservation_json" - ") VALUES (" - "?, ?, 0, ?, ?, ?, ?, ?, 0, '', NULL, 0, " - "?, ?, '', '[]', 0, ?, NULL, 0, 0, 0, '', NULL" - ")", - ( - holder_id.to_wire(), - holder_name_lower, - holder_res_cpu, - holder_res_mem, - holder_res_disk, - holder_res_device, - holder_constraints_json, - holder_entrypoint_json, - holder_environment_json, - DEFAULT_MAX_RETRIES_PREEMPTION, + self._store.jobs.insert_config( + cur, + JobConfigInsertParams( + job_id=holder_id, + name=holder_name_lower, + has_reservation=False, + res_cpu_millicores=holder_res_cpu, + res_memory_bytes=holder_res_mem, + res_disk_bytes=holder_res_disk, + res_device_json=holder_res_device, + constraints_json=holder_constraints_json, + has_coscheduling=False, + coscheduling_group_by="", + scheduling_timeout_ms=None, + max_task_failures=0, + entrypoint_json=holder_entrypoint_json, + environment_json=holder_environment_json, + bundle_id="", + ports_json="[]", + max_retries_failure=0, + max_retries_preemption=DEFAULT_MAX_RETRIES_PREEMPTION, + timeout_ms=None, + preemption_policy=0, + existing_job_policy=0, + priority_band=0, + task_image="", ), ) holder_base = self._db.next_sequence("task_priority_insertion", cur=cur) @@ -1372,17 +1311,10 @@ def submit_job( def cancel_job(self, job_id: JobName, reason: str) -> TxResult: """Cancel a job tree and return tasks that need 
kill RPCs.""" with self._db.transaction() as cur: - subtree = cur.execute( - "WITH RECURSIVE subtree(job_id) AS (" - " SELECT job_id FROM jobs WHERE job_id = ? " - " UNION ALL " - " SELECT j.job_id FROM jobs j JOIN subtree s ON j.parent_job_id = s.job_id" - ") SELECT job_id FROM subtree", - (job_id.to_wire(),), - ).fetchall() + subtree = self._store.jobs.list_subtree(cur, job_id) if not subtree: return TxResult() - subtree_ids = [str(row["job_id"]) for row in subtree] + subtree_ids = [jid.to_wire() for jid in subtree] placeholders = ",".join("?" for _ in subtree_ids) running_rows = cur.execute( f"SELECT t.task_id, t.current_worker_id AS worker_id, " @@ -1437,19 +1369,15 @@ def cancel_job(self, job_id: JobName, reason: str) -> TxResult: # Deliberately excludes JOB_STATE_WORKER_FAILED from the guard set: # worker-failed jobs should still be cancellable (transitioned to KILLED). cancel_guard_states = TERMINAL_JOB_STATES - {job_pb2.JOB_STATE_WORKER_FAILED} - cancel_guard_placeholders = ",".join("?" for _ in cancel_guard_states) - cur.execute( - f"UPDATE jobs SET state = ?, error = ?, finished_at_ms = COALESCE(finished_at_ms, ?) " - f"WHERE job_id IN ({placeholders}) AND state NOT IN ({cancel_guard_placeholders})", - ( - job_pb2.JOB_STATE_KILLED, - reason, - now_ms, - *subtree_ids, - *cancel_guard_states, - ), + self._store.jobs.bulk_update_state( + cur, + subtree, + job_pb2.JOB_STATE_KILLED, + reason, + now_ms, + cancel_guard_states, ) - self._db.endpoints.remove_by_job_ids(cur, [JobName.from_wire(jid) for jid in subtree_ids]) + self._store.endpoints.remove_by_job_ids(cur, subtree) log_event("job_cancelled", job_id.to_wire(), reason=reason) return TxResult(tasks_to_kill=tasks_to_kill, task_kill_workers=task_kill_workers) @@ -1618,15 +1546,7 @@ def queue_assignments(self, assignments: list[Assignment], *, direct_dispatch: b continue job_id_wire = task.job_id.to_wire() if job_id_wire not in job_cache: - job_row = cur.execute( - f"SELECT {JOB_DETAIL_PROJECTION.select_clause()} " - f"FROM jobs j {JOB_CONFIG_JOIN} WHERE j.job_id = ?", - (job_id_wire,), - ).fetchone() - if job_row is None: - rejected.append(assignment) - continue - decoded_job = JOB_DETAIL_PROJECTION.decode_one([job_row]) + decoded_job = self._store.jobs.get_detail(cur, task.job_id) if decoded_job is None: rejected.append(assignment) continue @@ -1692,11 +1612,7 @@ def queue_assignments(self, assignments: list[Assignment], *, direct_dispatch: b jobs_to_update.add(job_id_wire) accepted.append(assignment) for job_id_wire in jobs_to_update: - cur.execute( - "UPDATE jobs SET state = CASE WHEN state = ? THEN ? ELSE state END, " - "started_at_ms = COALESCE(started_at_ms, ?) WHERE job_id = ?", - (job_pb2.JOB_STATE_PENDING, job_pb2.JOB_STATE_RUNNING, now_ms, job_id_wire), - ) + self._store.jobs.mark_running_if_pending(cur, JobName.from_wire(job_id_wire), now_ms) for a in accepted: log_event("assignment_queued", a.task_id.to_wire(), worker=str(a.worker_id)) return AssignmentResult( @@ -1924,8 +1840,7 @@ def _apply_task_transitions( # Fetch and cache job_config row (avoids re-querying per task in same job). 
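For illustration, a minimal sketch of the per-job config caching pattern this hunk describes. It assumes `store.jobs.get_config(cur, job_id)` returns the same plain dict (or None) that the old `SELECT * FROM job_config` row produced; the helper name and cache variable are hypothetical, not part of the patch.

    # Hypothetical helper: fetch each job's config at most once per update batch.
    job_config_cache: dict[str, dict | None] = {}

    def cached_job_config(store: ControllerStore, cur: TransactionCursor, job_id: JobName) -> dict | None:
        wire = job_id.to_wire()
        if wire not in job_config_cache:
            # One store lookup per distinct job; later tasks in the batch reuse it.
            job_config_cache[wire] = store.jobs.get_config(cur, job_id)
        return job_config_cache[wire]
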
job_id_wire = task.job_id.to_wire() if job_id_wire not in job_config_cache: - jc_row = cur.execute("SELECT * FROM job_config WHERE job_id = ?", (job_id_wire,)).fetchone() - job_config_cache[job_id_wire] = dict(jc_row) if jc_row is not None else None + job_config_cache[job_id_wire] = self._store.jobs.get_config(cur, task.job_id) jc = job_config_cache[job_id_wire] if worker_id is not None and task_state not in ACTIVE_TASK_STATES: @@ -1939,7 +1854,7 @@ def _apply_task_transitions( _decommit_worker_resources(cur, str(worker_id), resources) if update.new_state in TERMINAL_TASK_STATES: - delete_task_endpoints(cur, self._db.endpoints, update.task_id.to_wire()) + delete_task_endpoints(cur, self._store.endpoints, update.task_id.to_wire()) # Coscheduled jobs: a terminal host failure should cascade to siblings. if jc is not None and task_state in FAILURE_TASK_STATES: @@ -1952,7 +1867,7 @@ def _apply_task_transitions( jc["res_device_json"], ) cascade_kill, cascade_workers = _terminate_coscheduled_siblings( - cur, self._db.endpoints, siblings, update.task_id, resources, now_ms + cur, self._store.endpoints, siblings, update.task_id, resources, now_ms ) tasks_to_kill.update(cascade_kill) task_kill_workers.update(cascade_workers) @@ -1968,7 +1883,7 @@ def _apply_task_transitions( new_job_state = self._recompute_job_state(cur, job_id) if new_job_state in TERMINAL_JOB_STATES: final_tasks_to_kill, final_task_kill_workers = _finalize_terminal_job( - cur, self._db.endpoints, job_id, new_job_state, now_ms + cur, self._store, job_id, new_job_state, now_ms ) tasks_to_kill.update(final_tasks_to_kill) task_kill_workers.update(final_task_kill_workers) @@ -2138,7 +2053,7 @@ def _remove_failed_worker( holder_preemption_count = 0 if is_reservation_holder else preemption_count _terminate_task( cur, - self._db.endpoints, + self._store.endpoints, tid, int(task_row["current_attempt_id"]), new_task_state, @@ -2152,16 +2067,16 @@ def _remove_failed_worker( new_job_state = self._recompute_job_state(cur, parent_job_id) if new_job_state is not None and new_job_state in TERMINAL_JOB_STATES: cascaded_tasks_to_kill, cascaded_task_kill_workers = _cascade_terminal_job( - cur, self._db.endpoints, parent_job_id, now_ms, f"Worker {worker_id} failed" + cur, self._store, parent_job_id, now_ms, f"Worker {worker_id} failed" ) tasks_to_kill.update(cascaded_tasks_to_kill) task_kill_workers.update(cascaded_task_kill_workers) elif new_task_state == job_pb2.TASK_STATE_PENDING: - policy = _resolve_preemption_policy(cur, parent_job_id) + policy = _resolve_preemption_policy(self._store.jobs, cur, parent_job_id) if policy == job_pb2.JOB_PREEMPTION_POLICY_TERMINATE_CHILDREN: child_tasks_to_kill, child_task_kill_workers = _cascade_children( cur, - self._db.endpoints, + self._store, parent_job_id, now_ms, "Parent task preempted", @@ -2268,7 +2183,7 @@ def mark_task_unschedulable(self, task_id: JobName, reason: str) -> TxResult: now_ms = Timestamp.now().epoch_ms() _terminate_task( cur, - self._db.endpoints, + self._store.endpoints, task_id.to_wire(), None, job_pb2.TASK_STATE_UNSCHEDULABLE, @@ -2327,7 +2242,7 @@ def preempt_task(self, task_id: JobName, reason: str) -> TxResult: _terminate_task( cur, - self._db.endpoints, + self._store.endpoints, task_id.to_wire(), int(row["current_attempt_id"]), new_state, @@ -2343,17 +2258,15 @@ def preempt_task(self, task_id: JobName, reason: str) -> TxResult: job_id = JobName.from_wire(str(row["job_id"])) new_job_state = self._recompute_job_state(cur, job_id) if new_job_state is not None and new_job_state in 
TERMINAL_JOB_STATES: - cascade_kills, cascade_workers = _finalize_terminal_job( - cur, self._db.endpoints, job_id, new_job_state, now_ms - ) + cascade_kills, cascade_workers = _finalize_terminal_job(cur, self._store, job_id, new_job_state, now_ms) tasks_to_kill.update(cascade_kills) task_kill_workers.update(cascade_workers) elif new_state == job_pb2.TASK_STATE_PENDING: - policy = _resolve_preemption_policy(cur, job_id) + policy = _resolve_preemption_policy(self._store.jobs, cur, job_id) if policy == job_pb2.JOB_PREEMPTION_POLICY_TERMINATE_CHILDREN: child_kills, child_workers = _cascade_children( cur, - self._db.endpoints, + self._store, job_id, now_ms, reason, @@ -2454,7 +2367,7 @@ def cancel_tasks_for_timeout(self, task_ids: set[JobName], reason: str) -> TxRes attempt_id = row["current_attempt_id"] _terminate_task( cur, - self._db.endpoints, + self._store.endpoints, task_id_wire, int(attempt_id) if attempt_id is not None else None, job_pb2.TASK_STATE_FAILED, @@ -2480,7 +2393,7 @@ def cancel_tasks_for_timeout(self, task_ids: set[JobName], reason: str) -> TxRes # Pick the first direct-timeout task in this job as the "cause" for the error message. cause_tid = next(JobName.from_wire(str(r["task_id"])) for r in rows if str(r["job_id"]) == job_id_wire) cascade_kill, cascade_workers = _terminate_coscheduled_siblings( - cur, self._db.endpoints, siblings, cause_tid, job_resources, now_ms + cur, self._store.endpoints, siblings, cause_tid, job_resources, now_ms ) tasks_to_kill.update(cascade_kill) task_kill_workers.update(cascade_workers) @@ -2490,7 +2403,7 @@ def cancel_tasks_for_timeout(self, task_ids: set[JobName], reason: str) -> TxRes new_job_state = self._recompute_job_state(cur, JobName.from_wire(job_wire)) if new_job_state in TERMINAL_JOB_STATES: final_kill, final_workers = _finalize_terminal_job( - cur, self._db.endpoints, JobName.from_wire(job_wire), new_job_state, now_ms + cur, self._store, JobName.from_wire(job_wire), new_job_state, now_ms ) tasks_to_kill.update(final_kill) task_kill_workers.update(final_workers) @@ -2511,10 +2424,9 @@ def remove_finished_job(self, job_id: JobName) -> bool: True if the job was removed, False if it doesn't exist or is not finished """ with self._db.transaction() as cur: - row = cur.execute("SELECT state FROM jobs WHERE job_id = ?", (job_id.to_wire(),)).fetchone() - if row is None: + state = self._store.jobs.get_state(cur, job_id) + if state is None: return False - state = int(row["state"]) if state not in ( job_pb2.JOB_STATE_SUCCEEDED, job_pb2.JOB_STATE_FAILED, @@ -2522,7 +2434,7 @@ def remove_finished_job(self, job_id: JobName) -> bool: job_pb2.JOB_STATE_UNSCHEDULABLE, ): return False - cur.execute("DELETE FROM jobs WHERE job_id = ?", (job_id.to_wire(),)) + self._store.jobs.delete(cur, job_id) log_event("job_removed", job_id.to_wire(), state=state) return True @@ -2706,9 +2618,6 @@ def prune_old_data( job_cutoff_ms = now_ms - job_retention.to_ms() worker_cutoff_ms = now_ms - worker_retention.to_ms() - terminal_states = tuple(TERMINAL_JOB_STATES) - placeholders = ",".join("?" * len(terminal_states)) - def _stopped() -> bool: return stop_event is not None and stop_event.is_set() @@ -2716,20 +2625,15 @@ def _stopped() -> bool: jobs_deleted = 0 while not _stopped(): with self._db.read_snapshot() as snap: - row = snap.fetchone( - f"SELECT job_id FROM jobs WHERE state IN ({placeholders})" - " AND finished_at_ms IS NOT NULL AND finished_at_ms < ? 
LIMIT 1", - (*terminal_states, job_cutoff_ms), - ) - if row is None: + job_name = self._store.jobs.find_prunable(snap, job_cutoff_ms) + if job_name is None: break - job_id = row["job_id"] with self._db.transaction() as cur: - # Invalidate endpoint cache BEFORE the CASCADE so the registry + # Invalidate endpoint cache BEFORE the CASCADE so the cache # drops rows SQLite is about to delete for us. - self._db.endpoints.remove_by_job_ids(cur, [JobName.from_wire(str(job_id))]) - cur.execute("DELETE FROM jobs WHERE job_id = ?", (job_id,)) - log_event("job_pruned", str(job_id)) + self._store.endpoints.remove_by_job_ids(cur, [job_name]) + self._store.jobs.delete(cur, job_name) + log_event("job_pruned", job_name.to_wire()) jobs_deleted += 1 time.sleep(pause_between_s) @@ -2906,17 +2810,17 @@ def load_workers_from_config(self, configs: list[WorkerConfig]) -> None: # --- Endpoint Management --- def add_endpoint(self, endpoint: EndpointRow) -> bool: - """Add an endpoint row through the endpoint registry. + """Add an endpoint row through the store's endpoint cache. Returns True if the endpoint was inserted, False if the task is already terminal (to prevent orphaned endpoints that would never be cleaned up). """ - with self._db.transaction() as cur: - return self._db.endpoints.add(cur, endpoint) + with self._store.transaction() as cur: + return self._store.endpoints.add(cur, endpoint) def remove_endpoint(self, endpoint_id: str) -> EndpointRow | None: - with self._db.transaction() as cur: - return self._db.endpoints.remove(cur, endpoint_id) + with self._store.transaction() as cur: + return self._store.endpoints.remove(cur, endpoint_id) # --------------------------------------------------------------------- # Test-only SQL mutation helpers @@ -3247,10 +3151,10 @@ def apply_direct_provider_updates(self, updates: list[TaskUpdate]) -> TxResult: update.task_id.to_wire(), ), ) - jc_row = cur.execute("SELECT * FROM job_config WHERE job_id = ?", (task.job_id.to_wire(),)).fetchone() + jc_row = self._store.jobs.get_config(cur, task.job_id) if update.new_state in TERMINAL_TASK_STATES: - delete_task_endpoints(cur, self._db.endpoints, update.task_id.to_wire()) + delete_task_endpoints(cur, self._store.endpoints, update.task_id.to_wire()) # Coscheduled sibling cascade. 
if jc_row is not None and task_state in FAILURE_TASK_STATES: @@ -3263,7 +3167,7 @@ def apply_direct_provider_updates(self, updates: list[TaskUpdate]) -> TxResult: jc_row["res_device_json"], ) cascade_kill, cascade_workers = _terminate_coscheduled_siblings( - cur, self._db.endpoints, siblings, update.task_id, job_resources, now_ms + cur, self._store.endpoints, siblings, update.task_id, job_resources, now_ms ) tasks_to_kill.update(cascade_kill) task_kill_workers.update(cascade_workers) @@ -3272,7 +3176,7 @@ def apply_direct_provider_updates(self, updates: list[TaskUpdate]) -> TxResult: new_job_state = self._recompute_job_state(cur, task.job_id) if new_job_state in TERMINAL_JOB_STATES: final_tasks_to_kill, final_task_kill_workers = _finalize_terminal_job( - cur, self._db.endpoints, task.job_id, new_job_state, now_ms + cur, self._store, task.job_id, new_job_state, now_ms ) tasks_to_kill.update(final_tasks_to_kill) task_kill_workers.update(final_task_kill_workers) diff --git a/lib/iris/tests/cluster/conftest.py b/lib/iris/tests/cluster/conftest.py index 013013e710..ce7b9e8ad4 100644 --- a/lib/iris/tests/cluster/conftest.py +++ b/lib/iris/tests/cluster/conftest.py @@ -23,6 +23,7 @@ WORKER_DETAIL_PROJECTION, ) from iris.cluster.controller.service import ControllerServiceImpl +from iris.cluster.controller.stores import ControllerStore from iris.cluster.controller.transitions import ( Assignment, ControllerTransitions, @@ -391,7 +392,8 @@ def _drive_gcp(self, task_id: JobName, new_state: int) -> None: def _make_k8s_harness(tmp_path) -> ServiceTestHarness: db = ControllerDB(db_dir=tmp_path / "k8s_db") - state = ControllerTransitions(db=db) + store = ControllerStore(db) + state = ControllerTransitions(store=store) k8s = InMemoryK8sService() k8s.add_node_pool( @@ -413,7 +415,7 @@ def _make_k8s_harness(tmp_path) -> ServiceTestHarness: service = ControllerServiceImpl( state, - db, + store, controller=ctrl, bundle_store=BundleStore(storage_dir=str(tmp_path / "k8s_bundles")), log_service=LogServiceImpl(), @@ -431,14 +433,15 @@ def _make_k8s_harness(tmp_path) -> ServiceTestHarness: def _make_gcp_harness(tmp_path) -> ServiceTestHarness: db = ControllerDB(db_dir=tmp_path / "gcp_db") - state = ControllerTransitions(db=db) + store = ControllerStore(db) + state = ControllerTransitions(store=store) ctrl = _HarnessController() ctrl.has_direct_provider = False service = ControllerServiceImpl( state, - db, + store, controller=ctrl, bundle_store=BundleStore(storage_dir=str(tmp_path / "gcp_bundles")), log_service=LogServiceImpl(), diff --git a/lib/iris/tests/cluster/controller/conftest.py b/lib/iris/tests/cluster/controller/conftest.py index 12ef6cef09..4a9a4a217c 100644 --- a/lib/iris/tests/cluster/controller/conftest.py +++ b/lib/iris/tests/cluster/controller/conftest.py @@ -51,6 +51,7 @@ tasks_with_attempts, ) from iris.cluster.controller.service import ControllerServiceImpl +from iris.cluster.controller.stores import ControllerStore from iris.log_server.server import LogServiceImpl from iris.cluster.controller.transitions import ( Assignment, @@ -171,7 +172,7 @@ def controller_service(state, log_service, mock_controller, tmp_path) -> Control """ControllerServiceImpl with fresh DB, log service, and mock controller.""" return ControllerServiceImpl( state, - state._db, + state._store, controller=mock_controller, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), log_service=log_service, @@ -189,7 +190,8 @@ def make_controller_state(**kwargs): tmp = Path(tempfile.mkdtemp(prefix="iris_test_")) try: db = 
ControllerDB(db_dir=tmp) - yield ControllerTransitions(db=db, **kwargs) + store = ControllerStore(db) + yield ControllerTransitions(store=store, **kwargs) db.close() finally: shutil.rmtree(tmp, ignore_errors=True) diff --git a/lib/iris/tests/cluster/controller/test_api_keys.py b/lib/iris/tests/cluster/controller/test_api_keys.py index 8b9586f2c9..6381c6334c 100644 --- a/lib/iris/tests/cluster/controller/test_api_keys.py +++ b/lib/iris/tests/cluster/controller/test_api_keys.py @@ -23,6 +23,7 @@ from rigging.timing import Timestamp from iris.cluster.controller.db import ControllerDB from iris.cluster.controller.service import ControllerServiceImpl +from iris.cluster.controller.stores import ControllerStore from iris.cluster.controller.transitions import ControllerTransitions from iris.log_server.server import LogServiceImpl from iris.rpc import config_pb2 @@ -39,7 +40,8 @@ def db(tmp_path): def _make_service(db, auth=None): """Create a ControllerServiceImpl with minimal dependencies for API key tests.""" - state = ControllerTransitions(db=db) + store = ControllerStore(db) + state = ControllerTransitions(store=store) controller_mock = Mock() controller_mock.wake = Mock() @@ -51,7 +53,7 @@ def _make_service(db, auth=None): return ControllerServiceImpl( state, - db, + store, controller=controller_mock, bundle_store=BundleStore(storage_dir=str(db.db_path.parent / "bundles")), log_service=LogServiceImpl(), diff --git a/lib/iris/tests/cluster/controller/test_auth.py b/lib/iris/tests/cluster/controller/test_auth.py index 1e40c58d4d..50baf9fbe8 100644 --- a/lib/iris/tests/cluster/controller/test_auth.py +++ b/lib/iris/tests/cluster/controller/test_auth.py @@ -23,6 +23,7 @@ from iris.cluster.controller.dashboard import ControllerDashboard from iris.cluster.controller.db import ControllerDB from iris.cluster.controller.service import ControllerServiceImpl +from iris.cluster.controller.stores import ControllerStore from iris.cluster.controller.transitions import ControllerTransitions from iris.rpc.auth import SESSION_COOKIE, StaticTokenVerifier, hash_token, resolve_auth from rigging.timing import Timestamp @@ -44,7 +45,7 @@ def db(tmp_path): @pytest.fixture def state(db, tmp_path): - s = ControllerTransitions(db=db) + s = ControllerTransitions(store=ControllerStore(db)) yield s @@ -57,7 +58,7 @@ def service(state, tmp_path): controller_mock.has_direct_provider = False return ControllerServiceImpl( state, - state._db, + state._store, controller=controller_mock, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), log_service=LogServiceImpl(), diff --git a/lib/iris/tests/cluster/controller/test_dashboard.py b/lib/iris/tests/cluster/controller/test_dashboard.py index 48318235a8..07879eea95 100644 --- a/lib/iris/tests/cluster/controller/test_dashboard.py +++ b/lib/iris/tests/cluster/controller/test_dashboard.py @@ -170,7 +170,7 @@ def service(state, scheduler, tmp_path): log_service = LogServiceImpl() return ControllerServiceImpl( state, - state._db, + state._store, controller=controller_mock, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), log_service=log_service, @@ -190,7 +190,7 @@ def service_with_autoscaler(state, scheduler, mock_autoscaler, tmp_path): log_service = LogServiceImpl() return ControllerServiceImpl( state, - state._db, + state._store, controller=controller_mock, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), log_service=log_service, @@ -1042,7 +1042,7 @@ def test_auth_config_kubernetes_provider_kind(state, scheduler, tmp_path): log_service = 
LogServiceImpl() svc = ControllerServiceImpl( state, - state._db, + state._store, controller=controller_mock, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), log_service=log_service, @@ -1074,7 +1074,7 @@ def _make_k8s_dashboard_client(state, scheduler, tmp_path): log_service = LogServiceImpl() svc = ControllerServiceImpl( state, - state._db, + state._store, controller=controller_mock, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), log_service=log_service, diff --git a/lib/iris/tests/cluster/controller/test_endpoint_registry.py b/lib/iris/tests/cluster/controller/test_endpoint_store.py similarity index 75% rename from lib/iris/tests/cluster/controller/test_endpoint_registry.py rename to lib/iris/tests/cluster/controller/test_endpoint_store.py index e3e2b73e0b..c11f05975c 100644 --- a/lib/iris/tests/cluster/controller/test_endpoint_registry.py +++ b/lib/iris/tests/cluster/controller/test_endpoint_store.py @@ -1,7 +1,7 @@ # Copyright The Marin Authors # SPDX-License-Identifier: Apache-2.0 -"""Tests for EndpointRegistry — the in-memory cache over the ``endpoints`` table.""" +"""Tests for EndpointStore — the in-memory cache over the ``endpoints`` table.""" from __future__ import annotations @@ -10,8 +10,8 @@ import pytest from iris.cluster.controller.db import EndpointQuery -from iris.cluster.controller.endpoint_registry import EndpointRegistry from iris.cluster.controller.schema import ENDPOINT_PROJECTION, EndpointRow +from iris.cluster.controller.stores import EndpointStore from iris.cluster.types import JobName from iris.rpc import job_pb2 from rigging.timing import Timestamp @@ -20,7 +20,7 @@ # --- Parity helper: the legacy SQL builder, preserved solely for parity tests. -# Deleted from production; kept here so a parity test demonstrates the registry +# Deleted from production; kept here so a parity test demonstrates the store # returns an identical row set for representative queries. def _endpoint_query_sql_legacy(query: EndpointQuery) -> tuple[str, list[object]]: from_clause = f"SELECT {ENDPOINT_PROJECTION.select_clause()} FROM endpoints e" @@ -73,9 +73,9 @@ def test_registry_loads_existing_rows_on_startup(state): """On construction, the registry should contain every row in the ``endpoints`` table.""" tasks = submit_job(state, "j", make_job_request("j")) with state._db.transaction() as cur: - assert state._db.endpoints.add(cur, _make_row("e1", "svc", tasks[0].task_id)) + assert state._store.endpoints.add(cur, _make_row("e1", "svc", tasks[0].task_id)) - fresh = EndpointRegistry(state._db) + fresh = EndpointStore(state._db) rows = fresh.query() assert [r.endpoint_id for r in rows] == ["e1"] @@ -85,12 +85,12 @@ def test_add_updates_memory_after_commit(state): t = tasks[0].task_id with state._db.transaction() as cur: - assert state._db.endpoints.add(cur, _make_row("e1", "alpha", t)) + assert state._store.endpoints.add(cur, _make_row("e1", "alpha", t)) # Not yet committed; memory should not reflect the insert. 
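A hedged sketch of the commit-visibility contract these renamed tests exercise, using only the `EndpointStore(state._db)`, `query()`, `get()`, and `_make_row(...)` shapes that appear in this file; `t` is assumed to be a live task id as in the surrounding fixtures.

    # Illustration only: writes reach the in-memory cache after commit, never before,
    # and a fresh store built from the same DB sees exactly the committed rows.
    with state._db.transaction() as cur:
        state._store.endpoints.add(cur, _make_row("e-demo", "demo", t))
        assert state._store.endpoints.get("e-demo") is None       # pre-commit: invisible
    assert state._store.endpoints.get("e-demo") is not None       # post-commit: visible
    fresh = EndpointStore(state._db)
    assert "e-demo" in {r.endpoint_id for r in fresh.query()}
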
- assert state._db.endpoints.get("e1") is None + assert state._store.endpoints.get("e1") is None - assert state._db.endpoints.get("e1") is not None - assert [r.endpoint_id for r in state._db.endpoints.query()] == ["e1"] + assert state._store.endpoints.get("e1") is not None + assert [r.endpoint_id for r in state._store.endpoints.query()] == ["e1"] def test_rollback_leaves_memory_untouched(state): @@ -102,12 +102,12 @@ class BoomError(RuntimeError): with pytest.raises(BoomError): with state._db.transaction() as cur: - state._db.endpoints.add(cur, _make_row("e1", "alpha", t)) + state._store.endpoints.add(cur, _make_row("e1", "alpha", t)) raise BoomError # DB rolled back → memory must NOT see the insert. - assert state._db.endpoints.get("e1") is None - assert state._db.endpoints.query() == [] + assert state._store.endpoints.get("e1") is None + assert state._store.endpoints.query() == [] def test_add_rejects_terminal_task(state): @@ -121,22 +121,22 @@ def test_add_rejects_terminal_task(state): ) with state._db.transaction() as cur: - assert state._db.endpoints.add(cur, _make_row("e1", "alpha", task_id)) is False + assert state._store.endpoints.add(cur, _make_row("e1", "alpha", task_id)) is False - assert state._db.endpoints.get("e1") is None + assert state._store.endpoints.get("e1") is None def test_remove_drops_endpoint_by_id(state): tasks = submit_job(state, "j", make_job_request("j")) t = tasks[0].task_id with state._db.transaction() as cur: - state._db.endpoints.add(cur, _make_row("e1", "alpha", t)) - state._db.endpoints.add(cur, _make_row("e2", "beta", t)) + state._store.endpoints.add(cur, _make_row("e1", "alpha", t)) + state._store.endpoints.add(cur, _make_row("e2", "beta", t)) with state._db.transaction() as cur: - removed = state._db.endpoints.remove(cur, "e1") + removed = state._store.endpoints.remove(cur, "e1") assert removed is not None and removed.endpoint_id == "e1" - assert {r.endpoint_id for r in state._db.endpoints.query()} == {"e2"} + assert {r.endpoint_id for r in state._store.endpoints.query()} == {"e2"} def test_remove_by_task_drops_all_task_endpoints(state): @@ -144,15 +144,15 @@ def test_remove_by_task_drops_all_task_endpoints(state): t1, t2 = tasks[0].task_id, tasks[1].task_id with state._db.transaction() as cur: - state._db.endpoints.add(cur, _make_row("e1", "alpha", t1)) - state._db.endpoints.add(cur, _make_row("e2", "beta", t1)) - state._db.endpoints.add(cur, _make_row("e3", "gamma", t2)) + state._store.endpoints.add(cur, _make_row("e1", "alpha", t1)) + state._store.endpoints.add(cur, _make_row("e2", "beta", t1)) + state._store.endpoints.add(cur, _make_row("e3", "gamma", t2)) with state._db.transaction() as cur: - removed = state._db.endpoints.remove_by_task(cur, t1) + removed = state._store.endpoints.remove_by_task(cur, t1) assert set(removed) == {"e1", "e2"} - assert {r.endpoint_id for r in state._db.endpoints.query()} == {"e3"} + assert {r.endpoint_id for r in state._store.endpoints.query()} == {"e3"} def test_remove_by_job_ids_drops_subtree(state): @@ -163,14 +163,14 @@ def test_remove_by_job_ids_drops_subtree(state): t2 = tasks_b[0].task_id with state._db.transaction() as cur: - state._db.endpoints.add(cur, _make_row("e1", "alpha", t1)) - state._db.endpoints.add(cur, _make_row("e2", "beta", t2)) + state._store.endpoints.add(cur, _make_row("e1", "alpha", t1)) + state._store.endpoints.add(cur, _make_row("e2", "beta", t2)) with state._db.transaction() as cur: - removed = state._db.endpoints.remove_by_job_ids(cur, [ja]) + removed = 
state._store.endpoints.remove_by_job_ids(cur, [ja]) assert removed == ["e1"] - assert [r.endpoint_id for r in state._db.endpoints.query()] == ["e2"] + assert [r.endpoint_id for r in state._store.endpoints.query()] == ["e2"] # --- Query semantics -------------------------------------------------------- @@ -193,51 +193,51 @@ def populated(state): ] with state._db.transaction() as cur: for r in rows: - state._db.endpoints.add(cur, r) + state._store.endpoints.add(cur, r) return state, rows, (t0, t1, t2) def test_query_by_exact_name(populated): state, _, _ = populated - ids = {r.endpoint_id for r in state._db.endpoints.query(EndpointQuery(exact_name="alpha/svc"))} + ids = {r.endpoint_id for r in state._store.endpoints.query(EndpointQuery(exact_name="alpha/svc"))} assert ids == {"e1"} def test_query_by_prefix(populated): state, _, _ = populated - ids = {r.endpoint_id for r in state._db.endpoints.query(EndpointQuery(name_prefix="alpha/"))} + ids = {r.endpoint_id for r in state._store.endpoints.query(EndpointQuery(name_prefix="alpha/"))} assert ids == {"e1", "e2"} def test_query_by_task_ids(populated): state, _, (t0, _, t2) = populated - ids = {r.endpoint_id for r in state._db.endpoints.query(EndpointQuery(task_ids=(t0, t2)))} + ids = {r.endpoint_id for r in state._store.endpoints.query(EndpointQuery(task_ids=(t0, t2)))} assert ids == {"e1", "e2", "e4"} def test_query_by_endpoint_ids(populated): state, _, _ = populated - ids = {r.endpoint_id for r in state._db.endpoints.query(EndpointQuery(endpoint_ids=("e2", "e3")))} + ids = {r.endpoint_id for r in state._store.endpoints.query(EndpointQuery(endpoint_ids=("e2", "e3")))} assert ids == {"e2", "e3"} def test_query_limit(populated): state, _, _ = populated - rows = state._db.endpoints.query(EndpointQuery(limit=2)) + rows = state._store.endpoints.query(EndpointQuery(limit=2)) assert len(rows) == 2 def test_query_empty_matches_all(populated): state, rows, _ = populated - assert {r.endpoint_id for r in state._db.endpoints.query()} == {r.endpoint_id for r in rows} + assert {r.endpoint_id for r in state._store.endpoints.query()} == {r.endpoint_id for r in rows} def test_resolve_returns_address_for_exact_name(populated): state, _, _ = populated - row = state._db.endpoints.resolve("alpha/svc") + row = state._store.endpoints.resolve("alpha/svc") assert row is not None assert row.endpoint_id == "e1" - assert state._db.endpoints.resolve("nope") is None + assert state._store.endpoints.resolve("nope") is None # --- Parity with the legacy SQL builder ------------------------------------- @@ -262,7 +262,7 @@ def test_registry_parity_with_legacy_sql(populated, build_query): sql, params = _endpoint_query_sql_legacy(query) with state._db.read_snapshot() as q: expected_ids = sorted(r.endpoint_id for r in ENDPOINT_PROJECTION.decode(q.fetchall(sql, tuple(params)))) - actual_ids = sorted(r.endpoint_id for r in state._db.endpoints.query(query)) + actual_ids = sorted(r.endpoint_id for r in state._store.endpoints.query(query)) # For LIMIT queries, both sides just need to be a valid subset of matching rows. 
if query.limit is not None: @@ -290,9 +290,9 @@ def writer(): eid = f"e{i % len(task_ids)}" name = f"svc-{i % len(task_ids)}" with state._db.transaction() as cur: - state._db.endpoints.add(cur, _make_row(eid, name, t)) + state._store.endpoints.add(cur, _make_row(eid, name, t)) with state._db.transaction() as cur: - state._db.endpoints.remove(cur, eid) + state._store.endpoints.remove(cur, eid) i += 1 except Exception as exc: errors.append(f"writer: {exc!r}") @@ -300,7 +300,7 @@ def writer(): def reader(): try: while not stop.is_set(): - snapshot = state._db.endpoints.query() + snapshot = state._store.endpoints.query() # Verify the snapshot itself is internally consistent: every # endpoint_id in the result set is unique (no duplicates from # a torn index). @@ -311,11 +311,11 @@ def reader(): # present in a subsequent get() — the writer may remove it # between the two calls (TOCTOU). for row in snapshot: - state._db.endpoints.get(row.endpoint_id) + state._store.endpoints.get(row.endpoint_id) for i in range(len(task_ids)): - state._db.endpoints.query(EndpointQuery(name_prefix=f"svc-{i}")) - state._db.endpoints.query(EndpointQuery(exact_name=f"svc-{i}")) - state._db.endpoints.query(EndpointQuery(task_ids=(task_ids[i],))) + state._store.endpoints.query(EndpointQuery(name_prefix=f"svc-{i}")) + state._store.endpoints.query(EndpointQuery(exact_name=f"svc-{i}")) + state._store.endpoints.query(EndpointQuery(task_ids=(task_ids[i],))) except Exception as exc: errors.append(f"reader: {exc!r}") diff --git a/lib/iris/tests/cluster/controller/test_service.py b/lib/iris/tests/cluster/controller/test_service.py index ffafd9fb7f..1995678fac 100644 --- a/lib/iris/tests/cluster/controller/test_service.py +++ b/lib/iris/tests/cluster/controller/test_service.py @@ -679,7 +679,7 @@ def test_terminate_job_rejected_for_non_owner(state, mock_controller, tmp_path): auth_service = ControllerServiceImpl( state, - state._db, + state._store, controller=mock_controller, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles_owner")), log_service=LogServiceImpl(), @@ -710,7 +710,7 @@ def test_launch_child_job_rejected_for_non_owner(state, mock_controller, tmp_pat auth_service = ControllerServiceImpl( state, - state._db, + state._store, controller=mock_controller, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles_child")), log_service=LogServiceImpl(), @@ -1170,7 +1170,7 @@ def test_register_requires_worker_role(state, mock_controller, tmp_path): auth = ControllerAuth(provider="static") service = ControllerServiceImpl( state, - db, + state._store, controller=mock_controller, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), log_service=LogServiceImpl(), @@ -1206,7 +1206,7 @@ def test_register_allows_worker_role(state, mock_controller, tmp_path): auth = ControllerAuth(provider="static") service = ControllerServiceImpl( state, - db, + state._store, controller=mock_controller, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), log_service=LogServiceImpl(), diff --git a/lib/iris/tests/cluster/controller/test_task_resource_history.py b/lib/iris/tests/cluster/controller/test_task_resource_history.py index bfb3354eac..ce007b22b2 100644 --- a/lib/iris/tests/cluster/controller/test_task_resource_history.py +++ b/lib/iris/tests/cluster/controller/test_task_resource_history.py @@ -5,6 +5,7 @@ import pytest from iris.cluster.controller.db import ControllerDB +from iris.cluster.controller.stores import ControllerStore from iris.cluster.controller.transitions import ( Assignment, 
ControllerTransitions, @@ -21,7 +22,7 @@ @pytest.fixture def state(tmp_path): db = ControllerDB(db_dir=tmp_path) - s = ControllerTransitions(db=db) + s = ControllerTransitions(store=ControllerStore(db)) yield s db.close() diff --git a/lib/iris/tests/cluster/controller/test_transitions.py b/lib/iris/tests/cluster/controller/test_transitions.py index d15cc5ceca..d4639c3729 100644 --- a/lib/iris/tests/cluster/controller/test_transitions.py +++ b/lib/iris/tests/cluster/controller/test_transitions.py @@ -30,6 +30,7 @@ EndpointRow, ) from iris.cluster.controller.scheduler import JobRequirements, Scheduler +from iris.cluster.controller.stores import ControllerStore from iris.cluster.controller.transitions import ( Assignment, ControllerTransitions, @@ -90,7 +91,7 @@ def _queued_dispatch( def _endpoints(state: ControllerTransitions, query: EndpointQuery = EndpointQuery()) -> list[EndpointRow]: - rows = state._db.endpoints.query(query) + rows = state._store.endpoints.query(query) # Mirror the original helper's ordering (registered_at DESC, endpoint_id ASC). return sorted(rows, key=lambda r: (-r.registered_at.epoch_ms(), r.endpoint_id)) @@ -2746,7 +2747,7 @@ def test_snapshot_round_trip_preserves_reservation_holder(state): checkpoint_path = Path(tmpdir) / "controller.sqlite3" state._db.backup_to(checkpoint_path) restored_db = ControllerDB(db_dir=Path(tmpdir)) - restored_state = ControllerTransitions(db=restored_db) + restored_state = ControllerTransitions(store=ControllerStore(restored_db)) restored_holder = _query_job(restored_state, holder_job_id) assert restored_holder is not None @@ -3437,7 +3438,7 @@ def test_kill_non_terminal_reservation_holder_does_not_decommit_co_tenant(harnes with harness.state._db.transaction() as cur: _kill_non_terminal_tasks( cur, - harness.state._db.endpoints, + harness.state._store.endpoints, holder_job_id.to_wire(), "Job finalized", 0, diff --git a/lib/iris/tests/test_budget.py b/lib/iris/tests/test_budget.py index 2402fea6df..a7aec6c98b 100644 --- a/lib/iris/tests/test_budget.py +++ b/lib/iris/tests/test_budget.py @@ -268,7 +268,7 @@ def service(state, tmp_path) -> ControllerServiceImpl: priority-band authorization triggers (see launch_job band check).""" return ControllerServiceImpl( state, - state._db, + state._store, controller=MockController(), bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), log_service=LogServiceImpl(), diff --git a/tests/integration/iris/test_iris_kind.py b/tests/integration/iris/test_iris_kind.py index 8711e1498f..0b81db8176 100644 --- a/tests/integration/iris/test_iris_kind.py +++ b/tests/integration/iris/test_iris_kind.py @@ -31,6 +31,7 @@ from iris.cluster.controller.controller import Controller, ControllerConfig from iris.cluster.controller.db import ControllerDB from iris.cluster.controller.service import ControllerServiceImpl +from iris.cluster.controller.stores import ControllerStore from iris.cluster.controller.transitions import ControllerTransitions from iris.log_server.server import LogServiceImpl from iris.cluster.providers.k8s.fake import FakeNodeResources, InMemoryK8sService @@ -139,7 +140,8 @@ def _get_iris_pods(k8s: InMemoryK8sService) -> list[dict]: def _make_coreweave_harness(tmp_path: Path) -> ServiceTestHarness: db = ControllerDB(db_dir=tmp_path / "cw_db") log_service = LogServiceImpl(log_dir=tmp_path / "cw_logs") - state = ControllerTransitions(db=db) + store = ControllerStore(db) + state = ControllerTransitions(store=store) k8s = InMemoryK8sService() k8s.add_node_pool( @@ -181,7 +183,7 @@ def 
_make_coreweave_harness(tmp_path: Path) -> ServiceTestHarness: service = ControllerServiceImpl( state, - db, + store, controller=ctrl, bundle_store=BundleStore(storage_dir=str(tmp_path / "cw_bundles")), log_service=log_service,