diff --git a/lib/iris/scripts/benchmark_db_queries.py b/lib/iris/scripts/benchmark_db_queries.py index d87e255612..f5f1441726 100644 --- a/lib/iris/scripts/benchmark_db_queries.py +++ b/lib/iris/scripts/benchmark_db_queries.py @@ -42,18 +42,14 @@ _read_reservation_claims, _schedulable_tasks, ) -from iris.cluster.controller.db import ( - ACTIVE_TASK_STATES, - ControllerDB, - EndpointQuery, - healthy_active_workers_with_attributes, - running_tasks_by_worker, - tasks_for_job_with_attempts, -) +from iris.cluster.controller.db import ControllerDB +from iris.cluster.controller.store import ControllerStores, EndpointQuery from iris.cluster.controller.schema import ( + ACTIVE_TASK_STATES, JOB_CONFIG_JOIN, JOB_DETAIL_PROJECTION, ) +from iris.cluster.types import TERMINAL_JOB_STATES from iris.cluster.controller.service import ( USER_JOB_STATES, _descendant_jobs, @@ -80,7 +76,7 @@ ReservationClaim, TaskUpdate, ) -from iris.cluster.types import TERMINAL_JOB_STATES, JobName, WorkerId +from iris.cluster.types import JobName, WorkerId from iris.rpc import job_pb2 from iris.rpc import controller_pb2 from rigging.timing import Timestamp @@ -195,6 +191,7 @@ def bench( def benchmark_scheduling(db: ControllerDB) -> None: """Benchmark scheduling-loop queries.""" + stores = ControllerStores.from_db(db) # Create pending work so scheduling queries have realistic load. # Pick up to 50 running jobs and revert their first few tasks to PENDING. 
with db.read_snapshot() as snap: @@ -217,18 +214,25 @@ def benchmark_scheduling(db: ControllerDB) -> None: bench("_schedulable_tasks", lambda: _schedulable_tasks(db)) - bench( - "healthy_active_workers_with_attributes", - lambda: healthy_active_workers_with_attributes(db), - ) + def _bench_healthy_active(): + with stores.read() as ctx: + stores.workers.healthy_active_with_attributes(ctx.cur) - workers = healthy_active_workers_with_attributes(db) + bench("healthy_active_workers_with_attributes", _bench_healthy_active) + + with stores.read() as ctx: + workers = stores.workers.healthy_active_with_attributes(ctx.cur) bench("_building_counts", lambda: _building_counts(db, workers)) tasks = _schedulable_tasks(db) job_ids = {t.job_id for t in tasks} + + def _bench_jobs_by_id(): + with stores.read() as ctx: + _jobs_by_id(stores, ctx.cur, job_ids) + if job_ids: - bench("_jobs_by_id", lambda: _jobs_by_id(db, job_ids)) + bench("_jobs_by_id", _bench_jobs_by_id) else: print(" _jobs_by_id (skipped, no pending jobs)") @@ -255,7 +259,8 @@ def benchmark_scheduling(db: ControllerDB) -> None: # --- Write-path benchmarks (use a lightweight clone) --- write_db = clone_db(db) - write_txns = ControllerTransitions(write_db) + write_stores = ControllerStores.from_db(write_db) + write_txns = ControllerTransitions(stores=write_stores) try: # queue_assignments: the main write-lock holder in scheduling. @@ -385,6 +390,7 @@ def _reset_prune(): def benchmark_dashboard(db: ControllerDB) -> None: """Benchmark dashboard/service queries.""" + stores = ControllerStores.from_db(db) def _bench_jobs_in_states(db): placeholders = ",".join("?" 
for _ in USER_JOB_STATES) @@ -445,10 +451,16 @@ def _bench_jobs_in_states(db): bench("_worker_roster", lambda: _worker_roster(db)) - workers = healthy_active_workers_with_attributes(db) + with stores.read() as ctx: + workers = stores.workers.healthy_active_with_attributes(ctx.cur) worker_ids = {w.worker_id for w in workers} + + def _bench_running_by_worker(): + with stores.read() as ctx: + stores.tasks.running_tasks_by_worker(ctx.cur, worker_ids) + if worker_ids: - bench("running_tasks_by_worker", lambda: running_tasks_by_worker(db, worker_ids)) + bench("running_tasks_by_worker", _bench_running_by_worker) else: print(" running_tasks_by_worker (skipped, no workers)") @@ -479,11 +491,12 @@ def _bench_jobs_in_states(db): if sample_job: bench("_read_job", lambda: _read_job(db, sample_job.job_id)) + def _bench_tasks_for_job(): + with stores.read() as ctx: + stores.tasks.tasks_for_job_with_attempts(ctx.cur, sample_job.job_id) + if sample_job: - bench( - "tasks_for_job_with_attempts", - lambda: tasks_for_job_with_attempts(db, sample_job.job_id), - ) + bench("tasks_for_job_with_attempts", _bench_tasks_for_job) if sample_job: sample_tasks_for_read = _tasks_for_listing(db, job_id=sample_job.job_id) @@ -509,7 +522,9 @@ def _list_jobs_full(db): def benchmark_heartbeat(db: ControllerDB) -> None: """Benchmark heartbeat/provider-sync queries.""" - workers = healthy_active_workers_with_attributes(db) + stores = ControllerStores.from_db(db) + with stores.read() as ctx: + workers = stores.workers.healthy_active_with_attributes(ctx.cur) worker_ids = {w.worker_id for w in workers} if not workers: @@ -543,9 +558,13 @@ def _all_workers_running_tasks(): bench(f"drain_dispatch ({len(workers)} workers)", _all_workers_running_tasks) - bench("running_tasks_by_worker", lambda: running_tasks_by_worker(db, worker_ids)) + def _bench_running_by_worker_heartbeat(): + with stores.read() as ctx: + stores.tasks.running_tasks_by_worker(ctx.cur, worker_ids) + + bench("running_tasks_by_worker", 
_bench_running_by_worker_heartbeat) - transitions = ControllerTransitions(db) + transitions = ControllerTransitions(stores=stores) bench( f"drain_dispatch_all ({len(workers)} workers)", lambda: transitions.drain_dispatch_all(), @@ -597,7 +616,8 @@ def _all_workers_running_tasks(): ) hb_db = clone_db(db) - hb_transitions = ControllerTransitions(hb_db) + hb_stores = ControllerStores.from_db(hb_db) + hb_transitions = ControllerTransitions(stores=hb_stores) try: bench( @@ -628,6 +648,13 @@ def _all_workers_running_tasks(): shutil.rmtree(hb_db._db_dir, ignore_errors=True) +def _healthy_workers(db: ControllerDB) -> list[Any]: + """Bench-local shim over the post-refactor store API.""" + stores = ControllerStores.from_db(db) + with stores.read() as ctx: + return list(stores.workers.healthy_active_with_attributes(ctx.cur)) + + def _active_task_sample(db: ControllerDB, limit: int) -> list[tuple[JobName, int]]: """Return up to ``limit`` (task_id, current_attempt_id) pairs for non-terminal tasks. @@ -1006,7 +1033,7 @@ def _reset_fail(saved_w=saved_workers, saved_t=saved_tasks): # Contention: run an add_endpoint burst concurrently with an # apply_heartbeats_batch call on two Python threads sharing the clone # DB. SQLite serializes writers, so this measures write-lock wait. - workers = healthy_active_workers_with_attributes(write_db) + workers = _healthy_workers(write_db) if workers and len(sample) >= 200: active_states = tuple(ACTIVE_TASK_STATES) running_by_worker: dict[str, list[tuple[str, int]]] = {} @@ -1164,7 +1191,7 @@ def _burst_100(): # (c) burst 100 under concurrent apply_heartbeats_batch contention. 
active_states = tuple(ACTIVE_TASK_STATES) - workers = healthy_active_workers_with_attributes(write_db) + workers = _healthy_workers(write_db) running_tasks_per_worker: dict[str, list[tuple[str, int]]] = {} for w in workers: wid = str(w.worker_id) @@ -1231,7 +1258,7 @@ def _build_heartbeat_requests(db: ControllerDB) -> list[HeartbeatApplyRequest]: one HeartbeatApplyRequest per active worker, with one RUNNING resource-usage update per task currently assigned to that worker. """ - workers = healthy_active_workers_with_attributes(db) + workers = _healthy_workers(db) active_states = tuple(ACTIVE_TASK_STATES) snapshot_proto = job_pb2.WorkerResourceSnapshot() usage = job_pb2.ResourceUsage(cpu_millicores=1000, memory_mb=1024) diff --git a/lib/iris/src/iris/cluster/controller/actor_proxy.py b/lib/iris/src/iris/cluster/controller/actor_proxy.py index 4fdae2cfe9..d79414d8f9 100644 --- a/lib/iris/src/iris/cluster/controller/actor_proxy.py +++ b/lib/iris/src/iris/cluster/controller/actor_proxy.py @@ -20,7 +20,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response -from iris.cluster.controller.db import ControllerDB +from iris.cluster.controller.store import ControllerStores logger = logging.getLogger(__name__) @@ -48,8 +48,8 @@ class ActorProxy: """Forwards ActorService RPCs to actors resolved from the endpoint registry.""" - def __init__(self, db: ControllerDB): - self._db = db + def __init__(self, stores: ControllerStores): + self._stores = stores self._client = httpx.AsyncClient(timeout=PROXY_TIMEOUT_SECONDS) async def close(self) -> None: @@ -98,7 +98,7 @@ async def handle(self, request: Request) -> Response: def _resolve_endpoint(self, name: str) -> str | None: """Resolve an endpoint name to an address via the in-memory registry.""" - row = self._db.endpoints.resolve(name) + row = self._stores.endpoints.resolve(name) if row is None: return None return row.address diff --git a/lib/iris/src/iris/cluster/controller/budget.py 
b/lib/iris/src/iris/cluster/controller/budget.py index 214b466da8..092540f7db 100644 --- a/lib/iris/src/iris/cluster/controller/budget.py +++ b/lib/iris/src/iris/cluster/controller/budget.py @@ -7,17 +7,28 @@ from collections import defaultdict from dataclasses import dataclass -from typing import Generic, TypeVar +from typing import Any, Generic, Protocol, TypeVar +from collections.abc import Callable import json - -from iris.cluster.controller.db import ACTIVE_TASK_STATES, QuerySnapshot +from iris.cluster.controller.schema import ACTIVE_TASK_STATES from iris.cluster.types import JobName from iris.rpc import job_pb2 T = TypeVar("T") +class SnapshotReader(Protocol): + """Interface budget functions need from QuerySnapshot.""" + + def raw( + self, + sql: str, + params: tuple = ..., + decoders: dict[str, Callable] | None = None, + ) -> list[Any]: ... + + def _accel_from_device_json(device_json: str | None) -> int: """Count GPU + TPU accelerators from a device JSON column.""" if not device_json: @@ -62,7 +73,7 @@ def resource_value(cpu_millicores: int, memory_bytes: int, accelerator_count: in return 1000 * accelerator_count + ram_gb + 5 * cpu_cores -def compute_user_spend(snapshot: QuerySnapshot) -> dict[str, int]: +def compute_user_spend(snapshot: SnapshotReader) -> dict[str, int]: """Compute per-user budget spend from active tasks. 
Joins tasks (in ASSIGNED/BUILDING/RUNNING states) with job_config to get diff --git a/lib/iris/src/iris/cluster/controller/controller.py b/lib/iris/src/iris/cluster/controller/controller.py index f73378718a..fd38b1d129 100644 --- a/lib/iris/src/iris/cluster/controller/controller.py +++ b/lib/iris/src/iris/cluster/controller/controller.py @@ -43,26 +43,23 @@ upload_checkpoint, write_checkpoint, ) -from iris.cluster.controller.db import ( - ControllerDB, - healthy_active_workers_with_attributes, - insert_task_profile, +from iris.cluster.controller.db import ControllerDB +from iris.cluster.controller.store import ( + ControllerStores, + Cursor, + JobDetailFilter, + WorkerFilter, job_scheduling_deadline, - running_tasks_by_worker, task_row_can_be_scheduled, - timed_out_executing_tasks, ) from iris.cluster.controller.schema import ( ATTEMPT_PROJECTION, - JOB_CONFIG_JOIN, - JOB_DETAIL_PROJECTION, - JOB_SCHEDULING_PROJECTION, TASK_DETAIL_PROJECTION, + TASK_DETAIL_SELECT_T, TASK_ROW_PROJECTION, WORKER_DETAIL_PROJECTION, JobDetailRow, JobRow, - JobSchedulingRow, TaskDetailRow, TaskRow, WorkerDetailRow, @@ -220,14 +217,14 @@ class _SchedulingOrder: user_budget_limits: dict[str, int] -def _resource_spec_from_row(job: JobRow | JobSchedulingRow) -> job_pb2.ResourceSpecProto: - """Reconstruct a ResourceSpecProto from native job columns.""" +def _resource_spec_from_row(job: JobRow | JobDetailRow) -> job_pb2.ResourceSpecProto: + """Reconstruct a ResourceSpecProto from a typed job row.""" return resource_spec_from_scalars( - job.res_cpu_millicores, job.res_memory_bytes, job.res_disk_bytes, job.res_device_json + job.resources.cpu_millicores, job.resources.memory_bytes, job.resources.disk_bytes, job.resources.device_json ) -def job_requirements_from_job(job: JobSchedulingRow) -> JobRequirements: +def job_requirements_from_job(job: JobDetailRow) -> JobRequirements: """Convert a job row to scheduler-compatible JobRequirements.""" return JobRequirements( 
resources=_resource_spec_from_row(job), @@ -238,7 +235,7 @@ def job_requirements_from_job(job: JobSchedulingRow) -> JobRequirements: def compute_demand_entries( - queries: ControllerDB, + stores: ControllerStores, scheduler: Scheduler | None = None, workers: list[WorkerSnapshot] | None = None, reservation_claims: dict[WorkerId, ReservationClaim] | None = None, @@ -265,19 +262,24 @@ def compute_demand_entries( ``max(real_pending, holders)``) should be added here. Args: - queries: Controller DB read surface for pending tasks and jobs. + stores: Controller stores bundle (provides read surface and store access). scheduler: Scheduler for dry-run pass. If None, skips dry-run. workers: Available workers for dry-run. If None, skips dry-run. reservation_claims: Reservation claims to apply taint injection in the dry-run, matching the real scheduling path. If None, no taints applied. """ demand_entries: list[DemandEntry] = [] + queries = stores.db # Collect all schedulable pending tasks, grouped by job. tasks_by_job: dict[JobName, list[TaskRow]] = defaultdict(list) all_schedulable: list[TaskRow] = [] pending = _schedulable_tasks(queries) - job_rows = list(_jobs_by_id(queries, {task.job_id for task in pending}).values()) if pending else [] + if pending: + with stores.read() as _ctx: + job_rows = list(_jobs_by_id(stores, _ctx.cur, {task.job_id for task in pending}).values()) + else: + job_rows = [] jobs_by_id = {job.job_id: job for job in job_rows} for task in pending: if not task_row_can_be_scheduled(task): @@ -403,36 +405,20 @@ def _read_reservation_claims(db: ControllerDB) -> dict[WorkerId, ReservationClai } -def _jobs_by_id(queries: ControllerDB, job_ids: set[JobName]) -> dict[JobName, JobSchedulingRow]: +def _jobs_by_id(stores: ControllerStores, cur: Cursor, job_ids: set[JobName]) -> dict[JobName, JobDetailRow]: if not job_ids: return {} - wires = [job_id.to_wire() for job_id in job_ids] - placeholders = ",".join("?" 
for _ in wires) - with queries.read_snapshot() as snapshot: - jobs = JOB_SCHEDULING_PROJECTION.decode( - snapshot.fetchall( - f"SELECT {JOB_SCHEDULING_PROJECTION.select_clause()} " - f"FROM jobs j {JOB_CONFIG_JOIN} WHERE j.job_id IN ({placeholders})", - tuple(wires), - ), - ) + wires = tuple(job_id.to_wire() for job_id in job_ids) + jobs = stores.jobs.query(cur, JobDetailFilter(job_ids=wires), detail=True) return {job.job_id: job for job in jobs} -def _jobs_with_reservations(queries: ControllerDB, states: tuple[int, ...]) -> list[JobDetailRow]: +def _jobs_with_reservations(stores: ControllerStores, cur: Cursor, states: tuple[int, ...]) -> list[JobDetailRow]: """Fetch only jobs that have reservations, filtering at the SQL level. Uses the has_reservation column on the jobs table to filter without a JOIN. """ - placeholders = ",".join("?" for _ in states) - with queries.read_snapshot() as snapshot: - rows = snapshot._fetchall( - f"SELECT {JOB_DETAIL_PROJECTION.select_clause()} " - f"FROM jobs j {JOB_CONFIG_JOIN} " - f"WHERE j.state IN ({placeholders}) AND j.has_reservation = 1", - list(states), - ) - return JOB_DETAIL_PROJECTION.decode(rows) + return stores.jobs.query(cur, JobDetailFilter(states=frozenset(states), has_reservation=True), detail=True) def _get_running_tasks_with_band_and_value( @@ -579,7 +565,7 @@ def _tasks_by_ids_with_attempts(queries: ControllerDB, task_ids: set[JobName]) - with queries.read_snapshot() as snapshot: tasks = TASK_DETAIL_PROJECTION.decode( snapshot.fetchall( - f"SELECT {TASK_DETAIL_PROJECTION.select_clause()} " + f"SELECT {TASK_DETAIL_SELECT_T} " f"FROM tasks t WHERE t.task_id IN ({placeholders}) ORDER BY t.task_id ASC", tuple(task_wires), ), @@ -754,7 +740,7 @@ def _find_reservation_ancestor(queries: ControllerDB, job_id: JobName) -> JobNam current = job_id.parent with queries.read_snapshot() as q: while current is not None: - row = q.execute_sql( + row = q.execute( "SELECT has_reservation FROM jobs WHERE job_id = ?", 
(current.to_wire(),), ).fetchone() @@ -767,7 +753,7 @@ def _find_reservation_ancestor(queries: ControllerDB, job_id: JobName) -> JobNam def _reservation_region_constraints( job_id_wire: str, claims: dict[WorkerId, ReservationClaim], - queries: ControllerDB, + stores: ControllerStores, existing_constraints: list[Constraint], ) -> list[Constraint]: """Derive region constraints from claimed reservation workers. @@ -782,11 +768,12 @@ def _reservation_region_constraints( return existing_constraints claimed_worker_ids = {worker_id for worker_id, claim in claims.items() if claim.job_id == job_id_wire} - workers_by_id = { - worker.worker_id: worker - for worker in healthy_active_workers_with_attributes(queries) - if worker.worker_id in claimed_worker_ids - } + with stores.read() as ctx: + workers_by_id = { + worker.worker_id: worker + for worker in stores.workers.healthy_active_with_attributes(ctx.cur) + if worker.worker_id in claimed_worker_ids + } regions: set[str] = set() for worker in workers_by_id.values(): if worker is None: @@ -1033,6 +1020,7 @@ def __init__( self._db = db else: self._db = ControllerDB(db_dir=config.local_state_dir / "db") + self._stores = ControllerStores.from_db(self._db) # ThreadContainer must be initialized before the log service setup # because _start_local_log_server spawns a uvicorn thread. 
@@ -1069,7 +1057,7 @@ def __init__( logging.getLogger("iris").addHandler(self._log_handler) self._transitions = ControllerTransitions( - db=self._db, + stores=self._stores, heartbeat_failure_threshold=config.heartbeat_failure_threshold, user_budget_defaults=config.user_budget_defaults, ) @@ -1079,7 +1067,7 @@ def __init__( self._service = ControllerServiceImpl( self._transitions, - self._db, + self._stores, controller=self, bundle_store=self._bundle_store, log_service=self._remote_log_service, @@ -1471,11 +1459,12 @@ def _run_profile_loop(self, stop_event: threading.Event) -> None: def _profile_all_running_tasks(self) -> None: """Capture CPU and memory profiles for every running task and store in the DB.""" - workers = healthy_active_workers_with_attributes(self._db) - if not workers: - return - workers_by_id = {w.worker_id: w for w in workers} - tasks_by_worker = running_tasks_by_worker(self._db, set(workers_by_id.keys())) + with self._stores.read() as ctx: + workers = self._stores.workers.healthy_active_with_attributes(ctx.cur) + if not workers: + return + workers_by_id = {w.worker_id: w for w in workers} + tasks_by_worker = self._stores.tasks.running_tasks_by_worker(ctx.cur, set(workers_by_id.keys())) profile_targets: list[tuple[JobName, WorkerRow]] = [] for worker_id, task_ids in tasks_by_worker.items(): @@ -1538,20 +1527,21 @@ def _capture_one_profile( if not resp.profile_data: logger.debug("Empty %s profile for %s", profile_kind, task_id) return - insert_task_profile( - self._db, - task_id=task_id.to_wire(), - profile_data=resp.profile_data, - captured_at=Timestamp.now(), - profile_kind=profile_kind, - ) + with self._stores.transact() as ctx: + self._stores.tasks.insert_task_profile( + ctx.cur, + task_id=task_id.to_wire(), + profile_data=resp.profile_data, + captured_at=Timestamp.now(), + profile_kind=profile_kind, + ) logger.debug("Stored %d byte %s profile for %s", len(resp.profile_data), profile_kind, task_id) except Exception: logger.debug("Profile 
capture (%s) failed for %s", profile_kind, task_id, exc_info=True) def _is_reservation_satisfied( self, - job: JobSchedulingRow, + job: JobDetailRow, claims: dict[WorkerId, ReservationClaim] | None = None, ) -> bool: """Check if a job's reservation is fully satisfied. @@ -1588,13 +1578,11 @@ def _cleanup_stale_claims(self, claims: dict[WorkerId, ReservationClaim] | None if claims is None: claims = _read_reservation_claims(self._db) persisted = True - with self._db.read_snapshot() as snapshot: - active_worker_ids = { - WorkerId(str(row[0])) - for row in snapshot.fetchall("SELECT w.worker_id FROM workers w WHERE w.active = 1") - } - claimed_job_ids = {JobName.from_wire(claim.job_id) for claim in claims.values()} - claimed_jobs = list(_jobs_by_id(self._db, claimed_job_ids).values()) if claimed_job_ids else [] + with self._stores.read() as ctx: + active_workers = self._stores.workers.query(ctx.cur, WorkerFilter(active=True)) + claimed_job_ids = {JobName.from_wire(claim.job_id) for claim in claims.values()} + claimed_jobs = list(_jobs_by_id(self._stores, ctx.cur, claimed_job_ids).values()) if claimed_job_ids else [] + active_worker_ids = {w.worker_id for w in active_workers} jobs_by_id = {job.job_id.to_wire(): job for job in claimed_jobs} stale: list[WorkerId] = [] for worker_id, claim in claims.items(): @@ -1622,15 +1610,16 @@ def _claim_workers_for_reservations(self, claims: dict[WorkerId, ReservationClai persisted = True claimed_entries: set[tuple[str, int]] = {(c.job_id, c.entry_idx) for c in claims.values()} claimed_worker_ids: set[WorkerId] = set(claims.keys()) - all_workers = healthy_active_workers_with_attributes(self._db) - changed = False - reservable_states = ( job_pb2.JOB_STATE_PENDING, job_pb2.JOB_STATE_BUILDING, job_pb2.JOB_STATE_RUNNING, ) - reservation_jobs = _jobs_with_reservations(self._db, reservable_states) + with self._stores.read() as ctx: + all_workers = self._stores.workers.healthy_active_with_attributes(ctx.cur) + reservation_jobs = 
_jobs_with_reservations(self._stores, ctx.cur, reservable_states) + changed = False + for job in reservation_jobs: job_wire = job.job_id.to_wire() for idx, res_entry in enumerate(reservation_entries_from_json(job.reservation_json)): @@ -1749,7 +1738,8 @@ def _read_scheduling_state(self) -> _SchedulingStateRead: timer = Timer() with slow_log(logger, "scheduling state reads", threshold_ms=50): pending_tasks = _schedulable_tasks(self._db) - workers = healthy_active_workers_with_attributes(self._db) + with self._stores.read() as ctx: + workers = self._stores.workers.healthy_active_with_attributes(ctx.cur) return _SchedulingStateRead( pending_tasks=pending_tasks, workers=workers, @@ -1770,7 +1760,8 @@ def _apply_scheduling_gates( tasks_per_job: dict[JobName, int] = defaultdict(int) cap = self._config.max_tasks_per_job_per_cycle filter_counts: dict[str, int] = defaultdict(int) - jobs_by_id = _jobs_by_id(self._db, {task.job_id for task in pending_tasks}) + with self._stores.read() as ctx: + jobs_by_id = _jobs_by_id(self._stores, ctx.cur, {task.job_id for task in pending_tasks}) for task in pending_tasks: if not task_row_can_be_scheduled(task): filter_counts["task_not_schedulable"] += 1 @@ -2025,7 +2016,8 @@ def _enforce_execution_timeouts(self) -> None: if now_ms - self._last_timeout_check_ms < self._TIMEOUT_CHECK_INTERVAL_MS: return self._last_timeout_check_ms = now_ms - timed_out = timed_out_executing_tasks(self._db, now) + with self._stores.read() as ctx: + timed_out = self._stores.tasks.timed_out_executing_tasks(ctx.cur, now) if not timed_out: return for task in timed_out: @@ -2040,7 +2032,8 @@ def _mark_task_unschedulable(self, task: TaskRow) -> None: if self._config.dry_run: logger.info("[DRY-RUN] Would mark task %s as unschedulable", task.task_id) return - job = _jobs_by_id(self._db, {task.job_id}).get(task.job_id) + with self._stores.read() as ctx: + job = _jobs_by_id(self._stores, ctx.cur, {task.job_id}).get(task.job_id) if job and job.scheduling_timeout_ms is 
not None: timeout = Duration.from_ms(job.scheduling_timeout_ms) else: @@ -2115,7 +2108,8 @@ def _reap_stale_workers(self) -> None: if self._config.dry_run: return threshold_ms = HEARTBEAT_STALENESS_THRESHOLD.to_ms() - workers = healthy_active_workers_with_attributes(self._db) + with self._stores.read() as ctx: + workers = self._stores.workers.healthy_active_with_attributes(ctx.cur) stale = [w for w in workers if w.last_heartbeat.age_ms() > threshold_ms] if not stale: return @@ -2322,7 +2316,8 @@ def _log_sync_health_summary( self._heartbeat_iteration += 1 if _HEALTH_SUMMARY_INTERVAL.should_run(): - workers = healthy_active_workers_with_attributes(self._db) + with self._stores.read() as ctx: + workers = self._stores.workers.healthy_active_with_attributes(ctx.cur) with self._db.read_snapshot() as snap: active = snap.fetchone("SELECT COUNT(*) FROM jobs j WHERE j.state = ?", (job_pb2.JOB_STATE_RUNNING,))[ 0 @@ -2350,9 +2345,10 @@ def _run_autoscaler_once(self) -> None: worker_status_map = self._build_worker_status_map() self._autoscaler.refresh(worker_status_map) - workers = healthy_active_workers_with_attributes(self._db) + with self._stores.read() as ctx: + workers = self._stores.workers.healthy_active_with_attributes(ctx.cur) demand_entries = compute_demand_entries( - self._db, + self._stores, self._scheduler, workers, reservation_claims=_read_reservation_claims(self._db), @@ -2368,7 +2364,8 @@ def _build_worker_status_map(self) -> WorkerStatusMap: decoders={"worker_id": WorkerId}, ) worker_ids = {row.worker_id for row in rows} - running_by_worker = running_tasks_by_worker(self._db, worker_ids) + with self._stores.read() as ctx: + running_by_worker = self._stores.tasks.running_tasks_by_worker(ctx.cur, worker_ids) for wid in worker_ids: result[wid] = WorkerStatus( worker_id=wid, diff --git a/lib/iris/src/iris/cluster/controller/dashboard.py b/lib/iris/src/iris/cluster/controller/dashboard.py index 6f122d94ff..e3da097c0f 100644 --- 
a/lib/iris/src/iris/cluster/controller/dashboard.py +++ b/lib/iris/src/iris/cluster/controller/dashboard.py @@ -286,7 +286,7 @@ def _create_app(self) -> ASGIApp: rpc_app = WSGIMiddleware(rpc_wsgi_app) - self._actor_proxy = ActorProxy(self._service._db) + self._actor_proxy = ActorProxy(self._service._stores) @requires_auth async def _proxy_actor_rpc(request: Request) -> Response: diff --git a/lib/iris/src/iris/cluster/controller/db.py b/lib/iris/src/iris/cluster/controller/db.py index 57984e5050..96d1f570e4 100644 --- a/lib/iris/src/iris/cluster/controller/db.py +++ b/lib/iris/src/iris/cluster/controller/db.py @@ -8,18 +8,16 @@ import logging import queue import sqlite3 +import time from collections.abc import Callable, Iterable, Iterator, Sequence from contextlib import contextmanager -from dataclasses import dataclass, field, replace as dc_replace from pathlib import Path -from threading import Lock, RLock +from threading import RLock from typing import Any -from iris.cluster.constraints import AttributeValue -from iris.cluster.controller.schema import decode_timestamp_ms, decode_worker_id -from iris.cluster.types import TERMINAL_TASK_STATES, JobName, WorkerId +from iris.cluster.controller.store import UserBudget from iris.rpc import job_pb2 -from rigging.timing import Deadline, Duration, Timestamp +from rigging.timing import Timestamp logger = logging.getLogger(__name__) @@ -62,7 +60,7 @@ def __exit__(self, exc_type: object, exc: object, tb: object) -> None: if self._lock is not None: self._lock.release() - def execute_sql(self, sql: str, params: tuple[object, ...] = ()) -> sqlite3.Cursor: + def execute(self, sql: str, params: tuple[object, ...] = ()) -> sqlite3.Cursor: """Execute raw SQL and return the cursor for result inspection.""" return self._conn.execute(sql, params) @@ -101,60 +99,6 @@ def raw( return rows -# --------------------------------------------------------------------------- -# Shared predicate functions for Task/TaskRow and Worker/WorkerRow. 
-# Placed above the class definitions so both full and lightweight models -# can delegate to the same logic without duplication. -# --------------------------------------------------------------------------- - - -def task_is_finished( - state: int, failure_count: int, max_retries_failure: int, preemption_count: int, max_retries_preemption: int -) -> bool: - """Whether a task has reached a terminal state with no remaining retries.""" - if state == job_pb2.TASK_STATE_SUCCEEDED: - return True - if state in (job_pb2.TASK_STATE_KILLED, job_pb2.TASK_STATE_UNSCHEDULABLE): - return True - if state == job_pb2.TASK_STATE_FAILED: - return failure_count > max_retries_failure - if state in (job_pb2.TASK_STATE_WORKER_FAILED, job_pb2.TASK_STATE_PREEMPTED): - return preemption_count > max_retries_preemption - return False - - -def task_row_is_finished(task: Any) -> bool: - return task_is_finished( - task.state, task.failure_count, task.max_retries_failure, task.preemption_count, task.max_retries_preemption - ) - - -def task_row_can_be_scheduled(task: Any) -> bool: - if task.state != job_pb2.TASK_STATE_PENDING: - return False - return task.current_attempt_id < 0 or not task_is_finished( - task.state, task.failure_count, task.max_retries_failure, task.preemption_count, task.max_retries_preemption - ) - - -# TERMINAL_TASK_STATES and TERMINAL_JOB_STATES are imported from iris.cluster.types. - -ACTIVE_TASK_STATES: frozenset[int] = frozenset( - { - job_pb2.TASK_STATE_ASSIGNED, - job_pb2.TASK_STATE_BUILDING, - job_pb2.TASK_STATE_RUNNING, - } -) - -# Tasks executing on a worker (subset of ACTIVE that excludes ASSIGNED). -EXECUTING_TASK_STATES: frozenset[int] = frozenset( - { - job_pb2.TASK_STATE_BUILDING, - job_pb2.TASK_STATE_RUNNING, - } -) - # Failure states that trigger coscheduled sibling cascades. 
FAILURE_TASK_STATES: frozenset[int] = frozenset( { @@ -165,79 +109,12 @@ def task_row_can_be_scheduled(task: Any) -> bool: ) -# job_is_finished is imported from iris.cluster.types (canonical definition). - - -def job_scheduling_deadline(scheduling_deadline_epoch_ms: int | None) -> Deadline | None: - """Compute scheduling deadline from epoch ms.""" - if scheduling_deadline_epoch_ms is None: - return None - return Deadline.after(Timestamp.from_ms(scheduling_deadline_epoch_ms), Duration.from_ms(0)) - - -def attempt_is_terminal(state: int) -> bool: - """Check if an attempt is in a terminal state.""" - return state in TERMINAL_TASK_STATES - - -def attempt_is_worker_failure(state: int) -> bool: - """Check if an attempt is a worker failure or preemption.""" - return state in (job_pb2.TASK_STATE_WORKER_FAILED, job_pb2.TASK_STATE_PREEMPTED) - - -@dataclass(frozen=True) -class UserStats: - user: str - task_state_counts: dict[int, int] = field(default_factory=dict) - job_state_counts: dict[int, int] = field(default_factory=dict) - - -@dataclass(frozen=True) -class TaskJobSummary: - job_id: JobName - task_count: int = 0 - completed_count: int = 0 - failure_count: int = 0 - preemption_count: int = 0 - task_state_counts: dict[int, int] = field(default_factory=dict) - - -@dataclass(frozen=True) -class UserBudget: - user_id: str - budget_limit: int - max_band: int - updated_at: Timestamp - - -@dataclass(frozen=True) -class EndpointQuery: - endpoint_ids: tuple[str, ...] = () - name_prefix: str | None = None - exact_name: str | None = None - task_ids: tuple[JobName, ...] 
= () - limit: int | None = None - - -def _decode_attribute_rows(rows: Sequence[Any]) -> dict[WorkerId, dict[str, AttributeValue]]: - attrs_by_worker: dict[WorkerId, dict[str, AttributeValue]] = {} - for row in rows: - worker_attrs = attrs_by_worker.setdefault(row.worker_id, {}) - if row.value_type == "int": - worker_attrs[row.key] = AttributeValue(int(row.int_value)) - elif row.value_type == "float": - worker_attrs[row.key] = AttributeValue(float(row.float_value)) - else: - worker_attrs[row.key] = AttributeValue(str(row.str_value or "")) - return attrs_by_worker - - class TransactionCursor: """Wraps a raw sqlite3.Cursor for use within controller transactions. Post-commit hooks registered via :meth:`on_commit` run after the wrapping ``ControllerDB.transaction()`` block commits successfully. They are used - by caches (e.g. ``EndpointRegistry``) to update in-memory state atomically + by caches (e.g. ``EndpointStore``) to update in-memory state atomically with the DB write: rollback suppresses the hook so memory never drifts from disk. """ @@ -316,61 +193,6 @@ def __init__(self, db_dir: Path): self._read_pool: queue.Queue[sqlite3.Connection] = queue.Queue() self._init_read_pool() logger.info("Read pool initialized in %.2fs", time.monotonic() - t0) - # Lazily populated cache of worker attributes, keyed by worker_id. - # Eliminates the per-cycle attribute SQL query from the scheduling hot path. - self._attr_cache: dict[WorkerId, dict[str, AttributeValue]] | None = None - self._attr_cache_lock = Lock() - - # Write-through in-memory cache over the ``endpoints`` table. Imported - # locally to break the ``db -> endpoint_registry -> db`` import cycle; - # this is the single exception to "no local imports" (see AGENTS.md). 
- from iris.cluster.controller.endpoint_registry import EndpointRegistry - - t0 = time.monotonic() - self._endpoint_registry = EndpointRegistry(self) - logger.info("EndpointRegistry initialized in %.2fs", time.monotonic() - t0) - - @property - def endpoints(self) -> EndpointRegistry: # noqa: F821 - """Process-local cache for the ``endpoints`` table; authoritative for reads.""" - return self._endpoint_registry - - def _populate_attr_cache(self) -> dict[WorkerId, dict[str, AttributeValue]]: - """Load all worker attributes from the DB into the cache. - - Called once on cold start (first access). The caller must NOT hold - _attr_cache_lock when calling this, because the DB read can be slow. - """ - with self.read_snapshot() as q: - rows = q.raw( - "SELECT worker_id, key, value_type, str_value, int_value, float_value FROM worker_attributes", - ) - return _decode_attribute_rows(rows) - - def get_worker_attributes(self) -> dict[WorkerId, dict[str, AttributeValue]]: - """Return cached worker attributes, populating from DB on first call.""" - cache = self._attr_cache - if cache is not None: - return cache - fresh = self._populate_attr_cache() - with self._attr_cache_lock: - if self._attr_cache is None: - self._attr_cache = fresh - return self._attr_cache - - def set_worker_attributes(self, worker_id: WorkerId, attrs: dict[str, AttributeValue]) -> None: - """Update the cached attributes for a single worker after registration.""" - with self._attr_cache_lock: - if self._attr_cache is None: - return - self._attr_cache[worker_id] = attrs - - def remove_worker_from_attr_cache(self, worker_id: WorkerId) -> None: - """Remove a single worker from the attribute cache.""" - with self._attr_cache_lock: - if self._attr_cache is None: - return - self._attr_cache.pop(worker_id, None) def _init_read_pool(self) -> None: """Create (or recreate) the read-only connection pool.""" @@ -454,7 +276,7 @@ def transaction(self): On successful commit, any hooks registered via 
``TransactionCursor.on_commit`` fire while the write lock is still held — keeping in-memory caches - (e.g. ``EndpointRegistry``) in sync with the DB without exposing a + (e.g. ``EndpointStore``) in sync with the DB without exposing a torn snapshot to concurrent readers. """ with self._lock: @@ -645,6 +467,20 @@ def next_sequence(self, key: str, *, cur: TransactionCursor) -> int: cur.execute("UPDATE meta SET value = ? WHERE key = ?", (value, key)) return value + def get_counter(self, key: str, cur: TransactionCursor) -> int: + """Read an integer counter from meta. Returns 0 if unset.""" + row = cur.execute("SELECT value FROM meta WHERE key = ?", (key,)).fetchone() + if row is None: + return 0 + return int(row[0]) + + def set_counter(self, key: str, value: int, cur: TransactionCursor) -> None: + """Write an integer counter to meta inside the given transaction.""" + cur.execute( + "INSERT INTO meta(key, value) VALUES (?, ?) " "ON CONFLICT(key) DO UPDATE SET value = excluded.value", + (key, value), + ) + def backup_to(self, destination: Path) -> None: """Create a hot backup to ``destination`` using SQLite backup API. @@ -751,7 +587,6 @@ def replace_from(self, source_dir: str | Path) -> None: self._conn.execute("ATTACH DATABASE ? AS profiles", (str(self._profiles_db_path),)) self._init_read_pool() self.apply_migrations() - self._endpoint_registry._load_all() # SQL-canonical read access is exposed through ``snapshot()`` and typed table # metadata at module scope. Legacy list/get/count helper methods were removed @@ -804,170 +639,23 @@ def get_all_user_budget_limits(self) -> dict[str, int]: return {row.user_id: row.budget_limit for row in rows} -# --------------------------------------------------------------------------- -# Shared read-only query helpers -# -# Pure DB reads that are used by both controller.py and service.py. -# Each takes a ControllerDB and returns domain objects. 
-# --------------------------------------------------------------------------- - +def batch_delete( + db: ControllerDB, + sql: str, + params: tuple[object, ...], + stopped: Callable[[], bool], + pause_between_s: float, +) -> int: + """Delete rows in batches, sleeping between transactions. -def running_tasks_by_worker(db: ControllerDB, worker_ids: set[WorkerId]) -> dict[WorkerId, set[JobName]]: - """Return the set of currently-running task IDs for each worker. - - Uses the denormalized current_worker_id column instead of joining task_attempts. + Returns the total number of rows deleted. """ - if not worker_ids: - return {} - placeholders = ",".join("?" for _ in worker_ids) - with db.read_snapshot() as q: - rows = q.raw( - f"SELECT t.current_worker_id AS worker_id, t.task_id FROM tasks t " - f"WHERE t.current_worker_id IN ({placeholders}) AND t.state IN (?, ?, ?)", - (*[str(wid) for wid in worker_ids], *ACTIVE_TASK_STATES), - decoders={"worker_id": decode_worker_id, "task_id": JobName.from_wire}, - ) - running: dict[WorkerId, set[JobName]] = {wid: set() for wid in worker_ids} - for row in rows: - running[row.worker_id].add(row.task_id) - return running - - -@dataclass(frozen=True, slots=True) -class TimedOutTask: - """A running task that has exceeded its execution timeout.""" - - task_id: JobName - worker_id: WorkerId | None - - -def timed_out_executing_tasks(db: ControllerDB, now: Timestamp) -> list[TimedOutTask]: - """Find executing tasks whose current attempt has exceeded the job's execution timeout. - - Reads the timeout from job_config.timeout_ms. Uses the current attempt's - started_at_ms so that retried tasks get a fresh timeout budget per attempt. - """ - now_ms = now.epoch_ms() - executing_states = tuple(sorted(EXECUTING_TASK_STATES)) - placeholders = ",".join("?" 
for _ in executing_states) - with db.read_snapshot() as q: - rows = q.raw( - f"SELECT t.task_id, t.current_worker_id AS worker_id, " - f"ta.started_at_ms AS attempt_started_at_ms, jc.timeout_ms " - f"FROM tasks t " - f"JOIN job_config jc ON jc.job_id = t.job_id " - f"JOIN task_attempts ta ON ta.task_id = t.task_id AND ta.attempt_id = t.current_attempt_id " - f"WHERE t.state IN ({placeholders}) " - f"AND jc.timeout_ms IS NOT NULL AND jc.timeout_ms > 0 " - f"AND ta.started_at_ms IS NOT NULL", - (*executing_states,), - decoders={ - "task_id": JobName.from_wire, - "worker_id": lambda v: WorkerId(v) if v is not None else None, - "attempt_started_at_ms": int, - "timeout_ms": int, - }, - ) - result: list[TimedOutTask] = [] - for row in rows: - if row.attempt_started_at_ms + row.timeout_ms <= now_ms: - result.append(TimedOutTask(task_id=row.task_id, worker_id=row.worker_id)) - return result - - -def tasks_for_job_with_attempts(db: ControllerDB, job_id: JobName) -> list: - """Fetch all tasks for a job with their attempt history.""" - from iris.cluster.controller.schema import ATTEMPT_PROJECTION, TASK_DETAIL_PROJECTION, tasks_with_attempts - - with db.read_snapshot() as q: - tasks = TASK_DETAIL_PROJECTION.decode( - q.fetchall( - "SELECT * FROM tasks WHERE job_id = ? ORDER BY task_index, task_id", - (job_id.to_wire(),), - ), - ) - if not tasks: - return [] - placeholders = ",".join("?" 
for _ in tasks) - attempts = ATTEMPT_PROJECTION.decode( - q.fetchall( - f"SELECT * FROM task_attempts WHERE task_id IN ({placeholders}) ORDER BY task_id, attempt_id", - tuple(t.task_id.to_wire() for t in tasks), - ), - ) - return tasks_with_attempts(tasks, attempts) - - -def _worker_row_select() -> str: - """Lazily resolve WORKER_ROW_PROJECTION.select_clause() to break the db -> schema cycle.""" - from iris.cluster.controller.schema import WORKER_ROW_PROJECTION - - return WORKER_ROW_PROJECTION.select_clause() - - -def healthy_active_workers_with_attributes(db: ControllerDB) -> list: - """Fetch all healthy, active workers with their attributes populated. - - Returns WorkerRow (scalar-only) so the scheduling loop avoids loading metadata columns. - Uses the in-memory attribute cache to avoid a per-cycle SQL join. - """ - from iris.cluster.controller.schema import WORKER_ROW_PROJECTION - - with db.read_snapshot() as q: - workers = WORKER_ROW_PROJECTION.decode( - q.fetchall(f"SELECT {_worker_row_select()} FROM workers w WHERE w.healthy = 1 AND w.active = 1"), - ) - if not workers: - return [] - attrs_by_worker = db.get_worker_attributes() - return [ - dc_replace( - w, - attributes=attrs_by_worker.get(w.worker_id, {}), - available_cpu_millicores=w.total_cpu_millicores - w.committed_cpu_millicores, - available_memory=w.total_memory_bytes - w.committed_mem, - available_gpus=w.total_gpu_count - w.committed_gpu, - available_tpus=w.total_tpu_count - w.committed_tpu, - ) - for w in workers - ] - - -def insert_task_profile( - db: ControllerDB, task_id: str, profile_data: bytes, captured_at: Timestamp, profile_kind: str = "cpu" -) -> None: - """Insert a captured profile snapshot for a task. - - The DB trigger caps profiles at 10 per (task_id, profile_kind), evicting the oldest automatically. 
- """ - db.execute( - "INSERT INTO profiles.task_profiles (task_id, profile_data, captured_at_ms, profile_kind) VALUES (?, ?, ?, ?)", - (task_id, profile_data, captured_at.epoch_ms(), profile_kind), - ) - - -def get_task_profiles( - db: ControllerDB, task_id: str, profile_kind: str | None = None -) -> list[tuple[bytes, Timestamp, str]]: - """Return stored profile snapshots for a task, newest first. - - Args: - db: Controller database. - task_id: Task wire string. - profile_kind: If set, filter to this kind (e.g. "cpu", "memory"). Returns all kinds when None. - """ - if profile_kind is not None: - query = ( - "SELECT profile_data, captured_at_ms, profile_kind FROM profiles.task_profiles" - " WHERE task_id = ? AND profile_kind = ? ORDER BY id DESC" - ) - params: tuple[str, ...] = (task_id, profile_kind) - else: - query = ( - "SELECT profile_data, captured_at_ms, profile_kind FROM profiles.task_profiles" - " WHERE task_id = ? ORDER BY id DESC" - ) - params = (task_id,) - with db.read_snapshot() as q: - rows = q.raw(query, params, decoders={"captured_at_ms": decode_timestamp_ms}) - return [(row.profile_data, row.captured_at_ms, row.profile_kind) for row in rows] + total = 0 + while not stopped(): + with db.transaction() as cur: + batch = cur.execute(sql, params).rowcount + if batch == 0: + break + total += batch + time.sleep(pause_between_s) + return total diff --git a/lib/iris/src/iris/cluster/controller/endpoint_registry.py b/lib/iris/src/iris/cluster/controller/endpoint_registry.py deleted file mode 100644 index 095126f235..0000000000 --- a/lib/iris/src/iris/cluster/controller/endpoint_registry.py +++ /dev/null @@ -1,253 +0,0 @@ -# Copyright The Marin Authors -# SPDX-License-Identifier: Apache-2.0 - -"""Process-local in-memory cache for the ``endpoints`` table. 
- -Profiling showed that ``ListEndpoints`` dominated controller CPU — not because -the SQL was slow per se, but because every call serialized through the -read-connection pool and walked a large WAL to build a snapshot. The endpoints -table is tiny (hundreds of rows) and only changes on explicit register / -unregister, so it is a natural fit for a write-through in-memory cache. - -Design invariants: - -* Reads never touch the DB. All lookups are served from in-memory maps - guarded by an ``RLock`` — readers observe a consistent snapshot of the - indexes, never a torn state mid-update. -* Writes execute the SQL inside the caller's transaction. The in-memory - update is scheduled as a post-commit hook on the cursor so memory only - changes after the DB has committed. If the transaction rolls back, the - hook never fires. -* N is small enough (≈ hundreds) that linear scans for prefix / task / id - lookups are simpler and plenty fast. Extra indexes (by name, by task_id) - speed the two common cases. - -The registry is the sole source of truth for endpoint reads; nothing else in -the controller tree should SELECT from ``endpoints``. -""" - -from __future__ import annotations - -import json -import logging -from collections.abc import Iterable, Sequence -from threading import RLock - -from iris.cluster.controller.db import EndpointQuery, TransactionCursor -from iris.cluster.controller.schema import ENDPOINT_PROJECTION, EndpointRow -from iris.cluster.types import TERMINAL_TASK_STATES, JobName - -logger = logging.getLogger(__name__) - - -class EndpointRegistry: - """In-memory index of endpoint rows, kept in sync with the DB. - - Construct with a ``ControllerDB``; the registry loads all existing rows at - init time. Callers mutate through ``add`` / ``remove*`` methods that take - the open ``TransactionCursor`` so the SQL lands inside the caller's - transaction. Memory is only updated after a successful commit via a - cursor post-commit hook. 
- """ - - def __init__(self, db): - self._db = db - self._lock = RLock() - self._by_id: dict[str, EndpointRow] = {} - # One name can map to multiple endpoint_ids — the schema does not enforce - # uniqueness on ``name``, and ``INSERT OR REPLACE`` keys off endpoint_id. - self._by_name: dict[str, set[str]] = {} - self._by_task: dict[JobName, set[str]] = {} - self._load_all() - - # -- Loading -------------------------------------------------------------- - - def _load_all(self) -> None: - with self._db.read_snapshot() as q: - rows = ENDPOINT_PROJECTION.decode( - q.fetchall(f"SELECT {ENDPOINT_PROJECTION.select_clause()} FROM endpoints e"), - ) - with self._lock: - self._by_id.clear() - self._by_name.clear() - self._by_task.clear() - for row in rows: - self._index(row) - logger.info("EndpointRegistry loaded %d endpoint(s) from DB", len(rows)) - - def _index(self, row: EndpointRow) -> None: - self._by_id[row.endpoint_id] = row - self._by_name.setdefault(row.name, set()).add(row.endpoint_id) - self._by_task.setdefault(row.task_id, set()).add(row.endpoint_id) - - def _unindex(self, endpoint_id: str) -> EndpointRow | None: - row = self._by_id.pop(endpoint_id, None) - if row is None: - return None - name_ids = self._by_name.get(row.name) - if name_ids is not None: - name_ids.discard(endpoint_id) - if not name_ids: - self._by_name.pop(row.name, None) - task_ids = self._by_task.get(row.task_id) - if task_ids is not None: - task_ids.discard(endpoint_id) - if not task_ids: - self._by_task.pop(row.task_id, None) - return row - - # -- Reads ---------------------------------------------------------------- - - def query(self, query: EndpointQuery = EndpointQuery()) -> list[EndpointRow]: - """Return endpoint rows matching ``query``. - - All filters AND together, matching the semantics of the original SQL - in :func:`iris.cluster.controller.db.endpoint_query_sql`. - """ - with self._lock: - # Narrow the candidate set using the most selective index available. 
- if query.endpoint_ids: - candidates: Iterable[EndpointRow] = ( - self._by_id[eid] for eid in query.endpoint_ids if eid in self._by_id - ) - elif query.task_ids: - task_set = set(query.task_ids) - candidates = (self._by_id[eid] for task_id in task_set for eid in self._by_task.get(task_id, ())) - elif query.exact_name is not None: - candidates = (self._by_id[eid] for eid in self._by_name.get(query.exact_name, ())) - else: - candidates = self._by_id.values() - - results: list[EndpointRow] = [] - for row in candidates: - if query.name_prefix is not None and not row.name.startswith(query.name_prefix): - continue - if query.exact_name is not None and row.name != query.exact_name: - continue - if query.task_ids and row.task_id not in query.task_ids: - continue - if query.endpoint_ids and row.endpoint_id not in query.endpoint_ids: - continue - results.append(row) - if query.limit is not None and len(results) >= query.limit: - break - return results - - def resolve(self, name: str) -> EndpointRow | None: - """Return any endpoint with exact ``name``, or None. Used by the actor proxy.""" - with self._lock: - ids = self._by_name.get(name) - if not ids: - return None - # Arbitrary but stable pick — the original SQL did not specify ORDER BY. - return self._by_id[next(iter(ids))] - - def get(self, endpoint_id: str) -> EndpointRow | None: - with self._lock: - return self._by_id.get(endpoint_id) - - def all(self) -> list[EndpointRow]: - with self._lock: - return list(self._by_id.values()) - - # -- Writes --------------------------------------------------------------- - - def add(self, cur: TransactionCursor, endpoint: EndpointRow) -> bool: - """Insert ``endpoint`` into the DB and schedule the memory update. - - Returns False (and writes nothing) if the owning task is already - terminal. Otherwise inserts / replaces and schedules a post-commit - hook that updates the in-memory indexes. 
- """ - task_id = endpoint.task_id - job_id, _ = task_id.require_task() - row = cur.execute("SELECT state FROM tasks WHERE task_id = ?", (task_id.to_wire(),)).fetchone() - if row is not None and int(row["state"]) in TERMINAL_TASK_STATES: - return False - - cur.execute( - "INSERT OR REPLACE INTO endpoints(" - "endpoint_id, name, address, job_id, task_id, metadata_json, registered_at_ms" - ") VALUES (?, ?, ?, ?, ?, ?, ?)", - ( - endpoint.endpoint_id, - endpoint.name, - endpoint.address, - job_id.to_wire(), - task_id.to_wire(), - json.dumps(endpoint.metadata), - endpoint.registered_at.epoch_ms(), - ), - ) - - def apply() -> None: - with self._lock: - # Replace: drop any previous row with this id first so the - # name/task indexes stay consistent on overwrite. - self._unindex(endpoint.endpoint_id) - self._index(endpoint) - - cur.on_commit(apply) - return True - - def remove(self, cur: TransactionCursor, endpoint_id: str) -> EndpointRow | None: - """Remove a single endpoint by id. Returns the removed row snapshot, if any.""" - existing = self.get(endpoint_id) - if existing is None: - return None - cur.execute("DELETE FROM endpoints WHERE endpoint_id = ?", (endpoint_id,)) - - def apply() -> None: - with self._lock: - self._unindex(endpoint_id) - - cur.on_commit(apply) - return existing - - def remove_by_task(self, cur: TransactionCursor, task_id: JobName) -> list[str]: - """Remove all endpoints owned by a task. Returns the removed endpoint_ids.""" - with self._lock: - ids = list(self._by_task.get(task_id, ())) - if not ids: - # Still issue the DELETE to stay consistent with any rows the - # registry might not have observed yet (belt-and-suspenders for - # the unlikely race of an in-flight concurrent writer). This - # costs nothing on the common path. 
- cur.execute("DELETE FROM endpoints WHERE task_id = ?", (task_id.to_wire(),)) - return [] - cur.execute("DELETE FROM endpoints WHERE task_id = ?", (task_id.to_wire(),)) - - def apply() -> None: - with self._lock: - for eid in ids: - self._unindex(eid) - - cur.on_commit(apply) - return ids - - def remove_by_job_ids(self, cur: TransactionCursor, job_ids: Sequence[JobName]) -> list[str]: - """Remove all endpoints owned by any of ``job_ids``. Used by cancel_job and prune.""" - if not job_ids: - return [] - wire_ids = [jid.to_wire() for jid in job_ids] - with self._lock: - to_remove: list[str] = [] - for row in self._by_id.values(): - owning_job, _ = row.task_id.require_task() - if owning_job.to_wire() in wire_ids: - to_remove.append(row.endpoint_id) - placeholders = ",".join("?" for _ in wire_ids) - cur.execute( - f"DELETE FROM endpoints WHERE job_id IN ({placeholders})", - tuple(wire_ids), - ) - if not to_remove: - return [] - - def apply() -> None: - with self._lock: - for eid in to_remove: - self._unindex(eid) - - cur.on_commit(apply) - return to_remove diff --git a/lib/iris/src/iris/cluster/controller/query.py b/lib/iris/src/iris/cluster/controller/query.py index 893d437032..8619c72d13 100644 --- a/lib/iris/src/iris/cluster/controller/query.py +++ b/lib/iris/src/iris/cluster/controller/query.py @@ -54,7 +54,7 @@ def execute_raw_query( raise ValueError(f"Forbidden SQL keyword: {keyword}") with db.read_snapshot() as q: - cursor = q.execute_sql(stripped) + cursor = q.execute(stripped) col_descriptions = cursor.description raw_rows = cursor.fetchall() diff --git a/lib/iris/src/iris/cluster/controller/schema.py b/lib/iris/src/iris/cluster/controller/schema.py index 617e609220..ed674548bf 100644 --- a/lib/iris/src/iris/cluster/controller/schema.py +++ b/lib/iris/src/iris/cluster/controller/schema.py @@ -12,17 +12,36 @@ import dataclasses import json import sqlite3 +import typing from collections.abc import Callable, Iterable, Sequence from dataclasses import dataclass 
from threading import Lock from typing import Any, Generic, TypeVar from iris.cluster.types import JobName, WorkerId +from iris.rpc import job_pb2 from rigging.timing import Timestamp T = TypeVar("T") RowDecoder = Callable[[sqlite3.Row], Any] +# Task states that indicate the task is actively running or assigned to a worker. +ACTIVE_TASK_STATES: frozenset[int] = frozenset( + { + job_pb2.TASK_STATE_ASSIGNED, + job_pb2.TASK_STATE_BUILDING, + job_pb2.TASK_STATE_RUNNING, + } +) + +# Tasks executing on a worker (subset of ACTIVE_TASK_STATES that excludes ASSIGNED). +EXECUTING_TASK_STATES: frozenset[int] = frozenset( + { + job_pb2.TASK_STATE_BUILDING, + job_pb2.TASK_STATE_RUNNING, + } +) + # --------------------------------------------------------------------------- # Decoder functions — canonical home for column-level decoders used by @@ -208,6 +227,7 @@ def projection( *column_names: str, extra_fields: tuple[ExtraField, ...] = (), row_cls: type | None = None, + post_decode: Callable[[dict[str, Any]], dict[str, Any]] | None = None, ) -> Projection: """Create a typed projection over a subset of columns. @@ -228,7 +248,7 @@ def projection( f"Unknown column {cn!r} in table {self.name!r}. " f"Available: {sorted(self._column_map.keys())}" ) cols.append(self._column_map[cn]) - return Projection(self, tuple(cols), extra_fields=extra_fields, row_cls=row_cls) + return Projection(self, tuple(cols), extra_fields=extra_fields, row_cls=row_cls, post_decode=post_decode) def select_clause(self, *column_names: str, prefix: bool = True) -> str: """Generate 'alias.col1, alias.col2, ...' for a SELECT. @@ -327,6 +347,7 @@ def __init__( extra_fields: tuple[ExtraField, ...] = (), row_cls: type | None = None, column_aliases: tuple[str, ...] | None = None, + post_decode: Callable[[dict[str, Any]], dict[str, Any]] | None = None, ): self.table = table self.columns = columns @@ -348,8 +369,13 @@ def __init__( self._required_columns: tuple[str, ...] 
= tuple(c.name for c in columns if c.name not in defaults) + self._post_decode = post_decode + if row_cls is not None: - _validate_row_cls(row_cls, columns, extra_fields) + # Skip column-level validation when post_decode is provided: the transform + # may combine or rename columns, so the 1:1 field-name check does not apply. + if post_decode is None: + _validate_row_cls(row_cls, columns, extra_fields) self._row_cls: type = row_cls else: self._row_cls = _make_row_class(f"{table.name}_projection", columns, extra_fields) @@ -419,11 +445,18 @@ def decode(self, rows: Iterable[sqlite3.Row]) -> list: first_keys = set(first.keys()) all_present = all(col in first_keys for col in columns) + post_decode = self._post_decode + if all_present: # All columns present -- tight loop, no per-row key checks. - result.append(cls(**{name: decoder(first[col]) for name, col, decoder in zipped})) - for row in it: - result.append(cls(**{name: decoder(row[col]) for name, col, decoder in zipped})) + if post_decode is None: + result.append(cls(**{name: decoder(first[col]) for name, col, decoder in zipped})) + for row in it: + result.append(cls(**{name: decoder(row[col]) for name, col, decoder in zipped})) + else: + result.append(cls(**post_decode({name: decoder(first[col]) for name, col, decoder in zipped}))) + for row in it: + result.append(cls(**post_decode({name: decoder(row[col]) for name, col, decoder in zipped}))) else: # Some columns missing -- use default-filling path for every row. 
result.append(self._decode_row(first)) @@ -452,6 +485,8 @@ def _decode_row(self, row: sqlite3.Row) -> Any: else: field_name, default_val, is_factory = self._defaults[col] values[field_name] = default_val() if is_factory else default_val + if self._post_decode is not None: + values = self._post_decode(values) return self._row_cls(**values) @@ -469,6 +504,197 @@ def adhoc_projection(*fields: tuple[str, type]) -> Projection: return Projection(table, columns) +# --------------------------------------------------------------------------- +# @projection decorator — generates Projection from a dataclass definition +# --------------------------------------------------------------------------- + + +_PCOLUMN_KEY = "_iris_pcolumn_spec" + + +@dataclass(frozen=True) +class _PColumnSpec: + """Per-field configuration attached via ``pcolumn`` field metadata.""" + + column: str | None = None # db column name override + prefix: str | None = None # nested-dataclass flatten prefix + nullable: bool = False # nested nullable: None when first required col is NULL + computed: bool = False # non-DB field (populated post-hoc) + + +def pcolumn( + *, + column: str | None = None, + prefix: str | None = None, + nullable: bool = False, + computed: bool = False, + default: Any = _MISSING, + default_factory: Any = _MISSING, +) -> Any: + """Attach column-mapping metadata to a dataclass field. + + Use inside a class decorated with ``@projection(...)``. The returned object + is a ``dataclasses.field`` with extra metadata the decorator reads at class + creation time. + + Args: + column: Explicit DB column name. Defaults to the field's own name. + Use when the dataclass field name should differ from the column, + e.g. ``user_id: str = pcolumn(column="id")`` reads ``users.id``. + prefix: Flatten a nested dataclass across multiple columns sharing this + prefix. The projection looks up ``{prefix}{subfield}`` for each + field of the nested type. 
Example: + ``resources: ResourceSpec = pcolumn(prefix="res_")`` with + ResourceSpec fields ``cpu, mem`` maps to columns ``res_cpu``, + ``res_mem`` and re-packs them into a ``ResourceSpec`` on decode. + nullable: For ``prefix=`` fields only. When the first required sub-column + is NULL, yield ``None`` for the whole nested value instead of + building a partial dataclass. Example: a LEFT JOIN where the nested + record may be absent. + computed: Non-DB field, populated post-decode by caller. The projection + SELECTs nothing for it; the field gets its ``default`` / + ``default_factory`` and the caller fills it in later. Use sparingly: + once a projection relies on ``computed`` fields, callers must know to + fill them. Example: ``attempts: tuple[Attempt, ...] = pcolumn( + computed=True, default_factory=tuple)``. + default / default_factory: Passthrough to ``dataclasses.field``. When + ``default=`` is used the field is filled with that value if the + column is absent from the SELECT (useful for partial projections). 
+ """ + spec = _PColumnSpec(column=column, prefix=prefix, nullable=nullable, computed=computed) + metadata = {_PCOLUMN_KEY: spec} + if default_factory is not _MISSING: + return dataclasses.field(default_factory=default_factory, metadata=metadata) + if default is not _MISSING: + return dataclasses.field(default=default, metadata=metadata) + return dataclasses.field(metadata=metadata) + + +def _spec_for_field(f: dataclasses.Field) -> _PColumnSpec: + return f.metadata.get(_PCOLUMN_KEY, _PColumnSpec()) + + +def _lookup_column(name: str, tables: Sequence[Table]) -> tuple[Column, str]: + """Look up a column by name across tables, returning (column, table_alias).""" + for t in tables: + col = t._column_map.get(name) + if col is not None: + return col, t.alias + searched = ", ".join(t.name for t in tables) + raise KeyError(f"Column {name!r} not found in any of: {searched}") + + +def projection(primary: Table, *, extra_tables: Sequence[Table] = ()) -> Callable[[type], type]: + """Class decorator that generates a ``Projection`` from a dataclass definition. + + Each dataclass field maps to a DB column of the same name in the primary + table (or any table listed in ``extra_tables``). Use ``pcolumn`` to override + the column name, attach a prefix for nested-dataclass flattening, or mark a + field as ``computed=True`` (populated post-decode, not read from the DB). + + The generated ``Projection`` is attached as ``cls.PROJECTION``. The class + itself is returned unchanged; all validation happens at decoration time. + """ + tables: tuple[Table, ...] = (primary, *tuple(extra_tables)) + + def decorate(cls: type) -> type: + if not dataclasses.is_dataclass(cls): + raise TypeError(f"@projection requires a dataclass, got {cls!r}") + + # ``from __future__ import annotations`` stores field.type as strings; resolve. 
+ resolved_hints = typing.get_type_hints(cls, globalns=globals(), localns=None) + + columns: list[Column] = [] + aliases: list[str] = [] + extra_fields: list[ExtraField] = [] + # Per-prefix nested-group metadata: (field_name, nested_cls, prefix, nullable, sub_names) + nested_groups: list[tuple[str, type, str, bool, tuple[str, ...]]] = [] + # Fields that are scalar DB columns (not nested, not extra). + # Used below only for validation; no per-field bookkeeping needed. + + for f in dataclasses.fields(cls): + spec = _spec_for_field(f) + if spec.computed: + ef_default = f.default if f.default is not dataclasses.MISSING else _MISSING + ef_factory = ( + f.default_factory if f.default_factory is not dataclasses.MISSING else None # type: ignore[misc] + ) + extra_fields.append( + ExtraField( + name=f.name, + python_type=f.type if isinstance(f.type, type) else object, + default=ef_default, + default_factory=ef_factory, + ) + ) + continue + + if spec.prefix is not None: + nested_cls = resolved_hints.get(f.name, f.type) + # typing.get_type_hints may return typing.Optional[X] etc. For + # a nullable nested group we expect ``Nested | None`` — pick the + # non-None arg. For plain cases it is the dataclass itself. + origin = typing.get_origin(nested_cls) + if origin is not None: + args = [a for a in typing.get_args(nested_cls) if a is not type(None)] + if len(args) == 1: + nested_cls = args[0] + if not (isinstance(nested_cls, type) and dataclasses.is_dataclass(nested_cls)): + raise TypeError( + f"{cls.__name__}.{f.name}: pcolumn(prefix=...) requires a dataclass annotation, " + f"got {nested_cls!r}" + ) + sub_field_names: list[str] = [] + for sub in dataclasses.fields(nested_cls): + db_name = f"{spec.prefix}{sub.name}" + col, alias = _lookup_column(db_name, tables) + columns.append(col) + aliases.append(alias) + sub_field_names.append(sub.name) + nested_groups.append((f.name, nested_cls, spec.prefix, spec.nullable, tuple(sub_field_names))) + continue + + # Scalar DB column. 
+ db_name = spec.column if spec.column is not None else f.name + col, alias = _lookup_column(db_name, tables) + columns.append(col) + aliases.append(alias) + + column_aliases: tuple[str, ...] | None + if len(tables) > 1: + column_aliases = tuple(aliases) + else: + column_aliases = None + + post_decode: Callable[[dict[str, Any]], dict[str, Any]] | None = None + if nested_groups: + + def post_decode(kw: dict[str, Any]) -> dict[str, Any]: + # Columns land in kw under their python field_name (set by + # Projection). For prefixed columns we assume the Column has no + # python_name override, so the key equals the raw db column name. + for field_name, nested_cls, prefix, nullable, sub_names in nested_groups: + sub_values = {sn: kw.pop(f"{prefix}{sn}") for sn in sub_names} + if nullable and sub_values[sub_names[0]] is None: + kw[field_name] = None + else: + kw[field_name] = nested_cls(**sub_values) + return kw + + proj: Projection = Projection( + primary, + tuple(columns), + extra_fields=tuple(extra_fields), + row_cls=cls, + column_aliases=column_aliases, + post_decode=post_decode, + ) + cls.PROJECTION = proj # type: ignore[attr-defined] + return cls + + return decorate + + def generate_full_ddl(tables: Sequence[Table]) -> str: """Concatenate all table DDLs into a single SQL script.""" return "\n\n".join(t.ddl() for t in tables) @@ -573,13 +799,14 @@ def generate_full_ddl(tables: Sequence[Table]) -> str: Column("scheduling_deadline_epoch_ms", "INTEGER", "", python_type=int | None, decoder=_nullable(int)), Column("error", "TEXT", "", python_type=str | None, decoder=_nullable(str)), Column("exit_code", "INTEGER", "", python_type=int | None, decoder=_nullable(int)), - Column("num_tasks", "INTEGER", "NOT NULL", python_type=int, decoder=int), + Column("num_tasks", "INTEGER", "NOT NULL", python_type=int, decoder=int, default=None), Column( "is_reservation_holder", "INTEGER", "NOT NULL CHECK (is_reservation_holder IN (0, 1))", python_type=bool, decoder=_decode_bool_int, + 
default=None, ), # Kept on jobs (not just job_config) for fast listing/filtering without JOIN. Column("name", "TEXT", "NOT NULL DEFAULT ''", python_type=str, decoder=str, default=""), @@ -623,9 +850,9 @@ def generate_full_ddl(tables: Sequence[Table]) -> str: "has_reservation", "INTEGER", "NOT NULL DEFAULT 0", python_type=bool, decoder=_decode_bool_int, default=False ), # Resource spec (was resources_proto) - Column("res_cpu_millicores", "INTEGER", "NOT NULL DEFAULT 0", python_type=int, decoder=int, default=0), - Column("res_memory_bytes", "INTEGER", "NOT NULL DEFAULT 0", python_type=int, decoder=int, default=0), - Column("res_disk_bytes", "INTEGER", "NOT NULL DEFAULT 0", python_type=int, decoder=int, default=0), + Column("res_cpu_millicores", "INTEGER", "NOT NULL DEFAULT 0", python_type=int, decoder=int, default=None), + Column("res_memory_bytes", "INTEGER", "NOT NULL DEFAULT 0", python_type=int, decoder=int, default=None), + Column("res_disk_bytes", "INTEGER", "NOT NULL DEFAULT 0", python_type=int, decoder=int, default=None), Column("res_device_json", "TEXT", "", python_type=str | None, decoder=_nullable(str), default=None), # Constraints (was constraints_proto) Column("constraints_json", "TEXT", "", python_type=str | None, decoder=_nullable(str), default=None), @@ -636,7 +863,7 @@ def generate_full_ddl(tables: Sequence[Table]) -> str: "NOT NULL DEFAULT 0", python_type=bool, decoder=_decode_bool_int, - default=False, + default=None, ), Column("coscheduling_group_by", "TEXT", "NOT NULL DEFAULT ''", python_type=str, decoder=str, default=""), # Scheduling config @@ -1396,105 +1623,121 @@ def generate_full_ddl(tables: Sequence[Table]) -> str: @dataclass(frozen=True, slots=True) -class JobRow: - """Lightweight job row for listings (jobs JOIN job_config, excludes constraints).""" +class ResourceSpec: + """Normalized resource requirements for a job, derived from job_config columns.""" - job_id: JobName - state: int - submitted_at: Timestamp - root_submitted_at: Timestamp 
- started_at: Timestamp | None - finished_at: Timestamp | None - scheduling_deadline_epoch_ms: int | None - error: str | None - exit_code: int | None - num_tasks: int - is_reservation_holder: bool - has_reservation: bool - name: str - depth: int - res_cpu_millicores: int - res_memory_bytes: int - res_disk_bytes: int - res_device_json: str | None - has_coscheduling: bool - coscheduling_group_by: str - scheduling_timeout_ms: int | None - max_task_failures: int + cpu_millicores: int + memory_bytes: int + disk_bytes: int + device_json: str | None = None +@projection(JOB_CONFIG) @dataclass(frozen=True, slots=True) -class JobSchedulingRow: - """Full job row for scheduling — adds constraints over JobRow.""" +class JobConfigRow: + """Row from the job_config table, returned by JobStore.get_config.""" job_id: JobName - state: int - submitted_at: Timestamp - root_submitted_at: Timestamp - started_at: Timestamp | None - finished_at: Timestamp | None - scheduling_deadline_epoch_ms: int | None - error: str | None - exit_code: int | None - num_tasks: int - is_reservation_holder: bool - has_reservation: bool name: str - depth: int - res_cpu_millicores: int - res_memory_bytes: int - res_disk_bytes: int - res_device_json: str | None - constraints_json: str | None - has_coscheduling: bool - coscheduling_group_by: str - scheduling_timeout_ms: int | None - max_task_failures: int + has_reservation: bool + resources: ResourceSpec = pcolumn(prefix="res_") + constraints_json: str | None = pcolumn() + has_coscheduling: bool = pcolumn() + coscheduling_group_by: str = pcolumn() + scheduling_timeout_ms: int | None = pcolumn() + max_task_failures: int = pcolumn() + entrypoint_json: str = pcolumn() + environment_json: str = pcolumn() + bundle_id: str = pcolumn() + ports_json: str = pcolumn() + max_retries_failure: int = pcolumn() + max_retries_preemption: int = pcolumn() + timeout_ms: int | None = pcolumn() + preemption_policy: int = pcolumn() + existing_job_policy: int = pcolumn() + 
priority_band: int = pcolumn() + task_image: str = pcolumn() + submit_argv_json: str = pcolumn() + reservation_json: str | None = pcolumn() + fail_if_exists: bool = pcolumn() @dataclass(frozen=True, slots=True) -class JobDetailRow: - """Full job detail — superset of JobSchedulingRow, adds dispatch config from job_config.""" +class WorkerActiveRow: + """Minimal worker row for heartbeat-failure handling.""" + + consecutive_failures: int + last_heartbeat_ms: int | None + + +@projection(JOBS, extra_tables=(JOB_CONFIG,)) +@dataclass(frozen=True, slots=True) +class JobRow: + """Lightweight job row for listings (jobs JOIN job_config, excludes constraints).""" job_id: JobName state: int - submitted_at: Timestamp - root_submitted_at: Timestamp - started_at: Timestamp | None - finished_at: Timestamp | None - scheduling_deadline_epoch_ms: int | None - error: str | None - exit_code: int | None - num_tasks: int - is_reservation_holder: bool - has_reservation: bool - name: str - depth: int - res_cpu_millicores: int - res_memory_bytes: int - res_disk_bytes: int - res_device_json: str | None - constraints_json: str | None - has_coscheduling: bool - coscheduling_group_by: str - scheduling_timeout_ms: int | None - max_task_failures: int - entrypoint_json: str - environment_json: str - bundle_id: str - ports_json: str - max_retries_failure: int - max_retries_preemption: int - timeout_ms: int | None - preemption_policy: int - existing_job_policy: int - priority_band: int - task_image: str - submit_argv_json: str - reservation_json: str | None - fail_if_exists: bool - + submitted_at: Timestamp = pcolumn(column="submitted_at_ms") + root_submitted_at: Timestamp = pcolumn(column="root_submitted_at_ms") + started_at: Timestamp | None = pcolumn(column="started_at_ms") + finished_at: Timestamp | None = pcolumn(column="finished_at_ms") + scheduling_deadline_epoch_ms: int | None = pcolumn() + error: str | None = pcolumn() + exit_code: int | None = pcolumn() + num_tasks: int = pcolumn() + 
is_reservation_holder: bool = pcolumn() + has_reservation: bool = pcolumn() + name: str = pcolumn() + depth: int = pcolumn() + resources: ResourceSpec = pcolumn(prefix="res_") + has_coscheduling: bool = pcolumn() + coscheduling_group_by: str = pcolumn() + scheduling_timeout_ms: int | None = pcolumn() + max_task_failures: int = pcolumn() + + +@projection(JOBS, extra_tables=(JOB_CONFIG,)) +@dataclass(frozen=True, slots=True) +class JobDetailRow: + """Full job detail row — includes resources, constraints, and dispatch config.""" + job_id: JobName + state: int + submitted_at: Timestamp = pcolumn(column="submitted_at_ms") + root_submitted_at: Timestamp = pcolumn(column="root_submitted_at_ms") + started_at: Timestamp | None = pcolumn(column="started_at_ms") + finished_at: Timestamp | None = pcolumn(column="finished_at_ms") + scheduling_deadline_epoch_ms: int | None = pcolumn() + error: str | None = pcolumn() + exit_code: int | None = pcolumn() + num_tasks: int = pcolumn() + is_reservation_holder: bool = pcolumn() + has_reservation: bool = pcolumn() + name: str = pcolumn() + depth: int = pcolumn() + resources: ResourceSpec = pcolumn(prefix="res_") + constraints_json: str | None = pcolumn() + has_coscheduling: bool = pcolumn() + coscheduling_group_by: str = pcolumn() + scheduling_timeout_ms: int | None = pcolumn() + max_task_failures: int = pcolumn() + entrypoint_json: str = pcolumn() + environment_json: str = pcolumn() + bundle_id: str = pcolumn() + ports_json: str = pcolumn() + max_retries_failure: int = pcolumn() + max_retries_preemption: int = pcolumn() + timeout_ms: int | None = pcolumn() + preemption_policy: int = pcolumn() + existing_job_policy: int = pcolumn() + priority_band: int = pcolumn() + task_image: str = pcolumn() + submit_argv_json: str = pcolumn() + reservation_json: str | None = pcolumn() + fail_if_exists: bool = pcolumn() + + +@projection(TASKS) @dataclass(frozen=True, slots=True) class TaskRow: """Lightweight task row for scheduling.""" @@ -1507,13 
+1750,19 @@ class TaskRow: preemption_count: int max_retries_failure: int max_retries_preemption: int - submitted_at: Timestamp + submitted_at: Timestamp = pcolumn(column="submitted_at_ms") priority_band: int = 2 +@projection(TASKS, extra_tables=(JOBS, JOB_CONFIG)) @dataclass(frozen=True, slots=True) class TaskDetailRow: - """Full task detail — superset of TaskRow, adds diagnostics and attempts.""" + """Full task detail — superset of TaskRow, adds diagnostics and attempts. + + The optional fields below are populated only when a higher projection is requested: + - WITH_JOB: is_reservation_holder, num_tasks + - WITH_JOB_CONFIG: all of the above plus res_* and timeout_ms/has_coscheduling + """ task_id: JobName job_id: JobName @@ -1523,18 +1772,26 @@ class TaskDetailRow: preemption_count: int max_retries_failure: int max_retries_preemption: int - submitted_at: Timestamp - priority_band: int - error: str | None - exit_code: int | None - started_at: Timestamp | None - finished_at: Timestamp | None - current_worker_id: WorkerId | None - current_worker_address: str | None - container_id: str | None = None - attempts: tuple = dataclasses.field(default_factory=tuple) - - + submitted_at: Timestamp = pcolumn(column="submitted_at_ms") + priority_band: int = pcolumn() + error: str | None = pcolumn() + exit_code: int | None = pcolumn() + started_at: Timestamp | None = pcolumn(column="started_at_ms") + finished_at: Timestamp | None = pcolumn(column="finished_at_ms") + current_worker_id: WorkerId | None = pcolumn() + current_worker_address: str | None = pcolumn() + container_id: str | None = pcolumn(default=None) + # Populated by WITH_JOB projection and higher. + is_reservation_holder: bool | None = pcolumn(default=None) + num_tasks: int | None = pcolumn(default=None) + # Populated by WITH_JOB_CONFIG projection. 
+ resources: ResourceSpec | None = pcolumn(prefix="res_", nullable=True, default=None) + has_coscheduling: bool | None = pcolumn(default=None) + timeout_ms: int | None = pcolumn(default=None) + attempts: tuple = pcolumn(computed=True, default_factory=tuple) + + +@projection(WORKERS) @dataclass(frozen=True, slots=True) class WorkerRow: """Worker row for scheduling and health checks.""" @@ -1544,24 +1801,47 @@ class WorkerRow: healthy: bool active: bool consecutive_failures: int - last_heartbeat: Timestamp - committed_cpu_millicores: int - committed_mem: int - committed_gpu: int - committed_tpu: int - total_cpu_millicores: int - total_memory_bytes: int - total_gpu_count: int - total_tpu_count: int - device_type: str - device_variant: str - attributes: dict = dataclasses.field(default_factory=dict) - available_cpu_millicores: int = 0 - available_memory: int = 0 - available_gpus: int = 0 - available_tpus: int = 0 + last_heartbeat: Timestamp = pcolumn(column="last_heartbeat_ms") + committed_cpu_millicores: int = pcolumn() + committed_mem: int = pcolumn(column="committed_mem_bytes") + committed_gpu: int = pcolumn() + committed_tpu: int = pcolumn() + total_cpu_millicores: int = pcolumn() + total_memory_bytes: int = pcolumn() + total_gpu_count: int = pcolumn() + total_tpu_count: int = pcolumn() + device_type: str = pcolumn() + device_variant: str = pcolumn() + attributes: dict = pcolumn(computed=True, default_factory=dict) + available_cpu_millicores: int = pcolumn(computed=True, default=0) + available_memory: int = pcolumn(computed=True, default=0) + available_gpus: int = pcolumn(computed=True, default=0) + available_tpus: int = pcolumn(computed=True, default=0) +@dataclass(frozen=True, slots=True) +class WorkerMetadataRow: + """Worker environment metadata decoded from the flat md_* columns.""" + + hostname: str + ip_address: str + cpu_count: int + memory_bytes: int + disk_bytes: int + tpu_name: str + tpu_worker_hostnames: str + tpu_worker_id: str + 
tpu_chips_per_host_bounds: str + gpu_count: int + gpu_name: str + gpu_memory_mb: int + gce_instance_name: str + gce_zone: str + git_hash: str + device_json: str + + +@projection(WORKERS) @dataclass(frozen=True, slots=True) class WorkerDetailRow: """Full worker detail — superset of WorkerRow, adds metadata scalar columns.""" @@ -1571,40 +1851,26 @@ class WorkerDetailRow: healthy: bool active: bool consecutive_failures: int - last_heartbeat: Timestamp - committed_cpu_millicores: int - committed_mem: int - committed_gpu: int - committed_tpu: int - total_cpu_millicores: int - total_memory_bytes: int - total_gpu_count: int - total_tpu_count: int - device_type: str - device_variant: str - md_hostname: str - md_ip_address: str - md_cpu_count: int - md_memory_bytes: int - md_disk_bytes: int - md_tpu_name: str - md_tpu_worker_hostnames: str - md_tpu_worker_id: str - md_tpu_chips_per_host_bounds: str - md_gpu_count: int - md_gpu_name: str - md_gpu_memory_mb: int - md_gce_instance_name: str - md_gce_zone: str - md_git_hash: str - md_device_json: str - attributes: dict = dataclasses.field(default_factory=dict) - available_cpu_millicores: int = 0 - available_memory: int = 0 - available_gpus: int = 0 - available_tpus: int = 0 - - + last_heartbeat: Timestamp = pcolumn(column="last_heartbeat_ms") + committed_cpu_millicores: int = pcolumn() + committed_mem: int = pcolumn(column="committed_mem_bytes") + committed_gpu: int = pcolumn() + committed_tpu: int = pcolumn() + total_cpu_millicores: int = pcolumn() + total_memory_bytes: int = pcolumn() + total_gpu_count: int = pcolumn() + total_tpu_count: int = pcolumn() + device_type: str = pcolumn() + device_variant: str = pcolumn() + metadata: WorkerMetadataRow = pcolumn(prefix="md_") + attributes: dict = pcolumn(computed=True, default_factory=dict) + available_cpu_millicores: int = pcolumn(computed=True, default=0) + available_memory: int = pcolumn(computed=True, default=0) + available_gpus: int = pcolumn(computed=True, default=0) + 
available_tpus: int = pcolumn(computed=True, default=0) + + +@projection(TASK_ATTEMPTS) @dataclass(frozen=True, slots=True) class AttemptRow: """Task attempt row.""" @@ -1613,13 +1879,14 @@ class AttemptRow: attempt_id: int worker_id: WorkerId | None state: int - created_at: Timestamp - started_at: Timestamp | None - finished_at: Timestamp | None - exit_code: int | None - error: str | None + created_at: Timestamp = pcolumn(column="created_at_ms") + started_at: Timestamp | None = pcolumn(column="started_at_ms") + finished_at: Timestamp | None = pcolumn(column="finished_at_ms") + exit_code: int | None = pcolumn() + error: str | None = pcolumn() +@projection(ENDPOINTS) @dataclass(frozen=True, slots=True) class EndpointRow: """Registered service endpoint.""" @@ -1628,20 +1895,22 @@ class EndpointRow: name: str address: str task_id: JobName - metadata: dict - registered_at: Timestamp + metadata: dict = pcolumn(column="metadata_json") + registered_at: Timestamp = pcolumn(column="registered_at_ms") +@projection(TXN_ACTIONS) @dataclass(frozen=True, slots=True) class TransactionActionRow: """Transaction action log entry.""" - timestamp: Timestamp - action: str - entity_id: str - details: dict + timestamp: Timestamp = pcolumn(column="created_at_ms") + action: str = pcolumn() + entity_id: str = pcolumn() + details: dict = pcolumn(column="details_json") +@projection(AUTH_API_KEYS) @dataclass(frozen=True, slots=True) class ApiKeyRow: """API key record.""" @@ -1651,12 +1920,13 @@ class ApiKeyRow: key_prefix: str user_id: str name: str - created_at: Timestamp - last_used_at: Timestamp | None = None - expires_at: Timestamp | None = None - revoked_at: Timestamp | None = None + created_at: Timestamp = pcolumn(column="created_at_ms") + last_used_at: Timestamp | None = pcolumn(column="last_used_at_ms", default=None) + expires_at: Timestamp | None = pcolumn(column="expires_at_ms", default=None) + revoked_at: Timestamp | None = pcolumn(column="revoked_at_ms", default=None) 
+@projection(USER_BUDGETS) @dataclass(frozen=True, slots=True) class UserBudgetRow: """User budget record.""" @@ -1664,320 +1934,54 @@ class UserBudgetRow: user_id: str budget_limit: int max_band: int - updated_at: Timestamp + updated_at: Timestamp = pcolumn(column="updated_at_ms") # --------------------------------------------------------------------------- # Projections -- typed column subsets that replace hand-maintained column strings # --------------------------------------------------------------------------- - -def _job_columns(*names: str) -> tuple[tuple[Column, ...], tuple[str, ...]]: - """Look up Column objects from JOBS or JOB_CONFIG by name. - - Returns ``(columns, aliases)`` where each alias is the table alias - (``j`` or ``jc``) that the column belongs to. This lets Projection - generate correct ``alias.col`` qualifiers for cross-table queries. - """ - cols: list[Column] = [] - aliases: list[str] = [] - for n in names: - if n in JOBS._column_map: - cols.append(JOBS._column_map[n]) - aliases.append(JOBS.alias) - elif n in JOB_CONFIG._column_map: - cols.append(JOB_CONFIG._column_map[n]) - aliases.append(JOB_CONFIG.alias) - else: - raise KeyError(f"Unknown job column: {n!r}") - return tuple(cols), tuple(aliases) - - # SQL for job queries that need config: JOIN job_config jc ON jc.job_id = j.job_id JOB_CONFIG_JOIN = "JOIN job_config jc ON jc.job_id = j.job_id" -# Lightweight job row for listings (excludes constraints). 
-_job_row_cols, _job_row_aliases = _job_columns( - "job_id", - "state", - "submitted_at_ms", - "root_submitted_at_ms", - "started_at_ms", - "finished_at_ms", - "scheduling_deadline_epoch_ms", - "error", - "exit_code", - "num_tasks", - "is_reservation_holder", - "has_reservation", - "name", - "depth", - "res_cpu_millicores", - "res_memory_bytes", - "res_disk_bytes", - "res_device_json", - "has_coscheduling", - "coscheduling_group_by", - "scheduling_timeout_ms", - "max_task_failures", -) -JOB_ROW_PROJECTION = Projection( - JOBS, - _job_row_cols, - row_cls=JobRow, - column_aliases=_job_row_aliases, -) - -# Full job row for scheduling (includes constraints). -_job_sched_cols, _job_sched_aliases = _job_columns( - "job_id", - "state", - "submitted_at_ms", - "root_submitted_at_ms", - "started_at_ms", - "finished_at_ms", - "scheduling_deadline_epoch_ms", - "error", - "exit_code", - "num_tasks", - "is_reservation_holder", - "has_reservation", - "name", - "depth", - "res_cpu_millicores", - "res_memory_bytes", - "res_disk_bytes", - "res_device_json", - "constraints_json", - "has_coscheduling", - "coscheduling_group_by", - "scheduling_timeout_ms", - "max_task_failures", -) -JOB_SCHEDULING_PROJECTION = Projection( - JOBS, - _job_sched_cols, - row_cls=JobSchedulingRow, - column_aliases=_job_sched_aliases, -) +JOB_ROW_PROJECTION = JobRow.PROJECTION -# Worker row for scheduling and health checks. 
-WORKER_ROW_PROJECTION = WORKERS.projection( - "worker_id", - "address", - "healthy", - "active", - "consecutive_failures", - "last_heartbeat_ms", - "committed_cpu_millicores", - "committed_mem_bytes", - "committed_gpu", - "committed_tpu", - "total_cpu_millicores", - "total_memory_bytes", - "total_gpu_count", - "total_tpu_count", - "device_type", - "device_variant", - extra_fields=( - ExtraField("attributes", dict, default_factory=dict), - ExtraField("available_cpu_millicores", int, default=0), - ExtraField("available_memory", int, default=0), - ExtraField("available_gpus", int, default=0), - ExtraField("available_tpus", int, default=0), - ), - row_cls=WorkerRow, -) +WORKER_ROW_PROJECTION = WorkerRow.PROJECTION -# Task row for scheduling. -TASK_ROW_PROJECTION = TASKS.projection( - "task_id", - "job_id", - "state", - "current_attempt_id", - "failure_count", - "preemption_count", - "max_retries_failure", - "max_retries_preemption", - "submitted_at_ms", - "priority_band", - row_cls=TaskRow, -) +TASK_ROW_PROJECTION = TaskRow.PROJECTION # --------------------------------------------------------------------------- # Detail / full-entity projections # --------------------------------------------------------------------------- -# Full job detail — superset of JobSchedulingRow, adds dispatch config. 
-_job_detail_cols, _job_detail_aliases = _job_columns( - "job_id", - "state", - "submitted_at_ms", - "root_submitted_at_ms", - "started_at_ms", - "finished_at_ms", - "scheduling_deadline_epoch_ms", - "error", - "exit_code", - "num_tasks", - "is_reservation_holder", - "has_reservation", - "name", - "depth", - "res_cpu_millicores", - "res_memory_bytes", - "res_disk_bytes", - "res_device_json", - "constraints_json", - "has_coscheduling", - "coscheduling_group_by", - "scheduling_timeout_ms", - "max_task_failures", - "entrypoint_json", - "environment_json", - "bundle_id", - "ports_json", - "max_retries_failure", - "max_retries_preemption", - "timeout_ms", - "preemption_policy", - "existing_job_policy", - "priority_band", - "task_image", - "submit_argv_json", - "reservation_json", - "fail_if_exists", -) -JOB_DETAIL_PROJECTION = Projection( - JOBS, - _job_detail_cols, - row_cls=JobDetailRow, - column_aliases=_job_detail_aliases, -) +JOB_DETAIL_PROJECTION = JobDetailRow.PROJECTION -# Full task detail — superset of TaskRow, adds diagnostics and attempts. -TASK_DETAIL_PROJECTION = TASKS.projection( - "task_id", - "job_id", - "state", - "current_attempt_id", - "failure_count", - "preemption_count", - "max_retries_failure", - "max_retries_preemption", - "submitted_at_ms", - "priority_band", - "error", - "exit_code", - "started_at_ms", - "finished_at_ms", - "current_worker_id", - "current_worker_address", - "container_id", - extra_fields=(ExtraField("attempts", tuple, default_factory=tuple),), - row_cls=TaskDetailRow, -) +TASK_DETAIL_PROJECTION = TaskDetailRow.PROJECTION -# Full worker detail — superset of WorkerRow, adds metadata scalar columns. 
-WORKER_DETAIL_PROJECTION = WORKERS.projection( - "worker_id", - "address", - "healthy", - "active", - "consecutive_failures", - "last_heartbeat_ms", - "committed_cpu_millicores", - "committed_mem_bytes", - "committed_gpu", - "committed_tpu", - "total_cpu_millicores", - "total_memory_bytes", - "total_gpu_count", - "total_tpu_count", - "device_type", - "device_variant", - "md_hostname", - "md_ip_address", - "md_cpu_count", - "md_memory_bytes", - "md_disk_bytes", - "md_tpu_name", - "md_tpu_worker_hostnames", - "md_tpu_worker_id", - "md_tpu_chips_per_host_bounds", - "md_gpu_count", - "md_gpu_name", - "md_gpu_memory_mb", - "md_gce_instance_name", - "md_gce_zone", - "md_git_hash", - "md_device_json", - extra_fields=( - ExtraField("attributes", dict, default_factory=dict), - ExtraField("available_cpu_millicores", int, default=0), - ExtraField("available_memory", int, default=0), - ExtraField("available_gpus", int, default=0), - ExtraField("available_tpus", int, default=0), - ), - row_cls=WorkerDetailRow, +# SELECT clause for the TaskDetailRow columns that live on the ``tasks`` table only. +# Used by queries over tasks alone; the projection's own ``select_clause`` spans +# tasks/jobs/job_config and requires JOINs. +TASK_DETAIL_SELECT_T = ( + "t.task_id, t.job_id, t.state, t.current_attempt_id, " + "t.failure_count, t.preemption_count, t.max_retries_failure, t.max_retries_preemption, " + "t.submitted_at_ms, t.priority_band, t.error, t.exit_code, " + "t.started_at_ms, t.finished_at_ms, t.current_worker_id, t.current_worker_address, " + "t.container_id" ) -# Task attempt row. -ATTEMPT_PROJECTION = TASK_ATTEMPTS.projection( - "task_id", - "attempt_id", - "worker_id", - "state", - "created_at_ms", - "started_at_ms", - "finished_at_ms", - "exit_code", - "error", - row_cls=AttemptRow, -) +WORKER_DETAIL_PROJECTION = WorkerDetailRow.PROJECTION -# Endpoint row. 
-ENDPOINT_PROJECTION = ENDPOINTS.projection( - "endpoint_id", - "name", - "address", - "task_id", - "metadata_json", - "registered_at_ms", - row_cls=EndpointRow, -) +ATTEMPT_PROJECTION = AttemptRow.PROJECTION -# Transaction action row. -TXN_ACTION_PROJECTION = TXN_ACTIONS.projection( - "created_at_ms", - "action", - "entity_id", - "details_json", - row_cls=TransactionActionRow, -) +ENDPOINT_PROJECTION = EndpointRow.PROJECTION -# API key row. -API_KEY_PROJECTION = AUTH_API_KEYS.projection( - "key_id", - "key_hash", - "key_prefix", - "user_id", - "name", - "created_at_ms", - "last_used_at_ms", - "expires_at_ms", - "revoked_at_ms", - row_cls=ApiKeyRow, -) +TXN_ACTION_PROJECTION = TransactionActionRow.PROJECTION -# User budget row. -USER_BUDGET_PROJECTION = USER_BUDGETS.projection( - "user_id", - "budget_limit", - "max_band", - "updated_at_ms", - row_cls=UserBudgetRow, -) +API_KEY_PROJECTION = ApiKeyRow.PROJECTION + +USER_BUDGET_PROJECTION = UserBudgetRow.PROJECTION + +JOB_CONFIG_PROJECTION = JobConfigRow.PROJECTION # --------------------------------------------------------------------------- diff --git a/lib/iris/src/iris/cluster/controller/service.py b/lib/iris/src/iris/cluster/controller/service.py index 25de060e24..245a0a458e 100644 --- a/lib/iris/src/iris/cluster/controller/service.py +++ b/lib/iris/src/iris/cluster/controller/service.py @@ -56,23 +56,24 @@ require_identity, ) from iris.cluster.bundle import BundleStore -from iris.cluster.controller.db import ( - ACTIVE_TASK_STATES, - ControllerDB, +from iris.cluster.controller.db import ControllerDB +from iris.cluster.controller.store import ( + ControllerStores, EndpointQuery, TaskJobSummary, UserStats, attempt_is_worker_failure, - running_tasks_by_worker, task_row_can_be_scheduled, ) from iris.cluster.controller.schema import ( + ACTIVE_TASK_STATES, API_KEY_PROJECTION, ATTEMPT_PROJECTION, JOB_CONFIG_JOIN, JOB_DETAIL_PROJECTION, JOB_ROW_PROJECTION, TASK_DETAIL_PROJECTION, + TASK_DETAIL_SELECT_T, 
TASK_ROW_PROJECTION, TXN_ACTION_PROJECTION, WORKER_DETAIL_PROJECTION, @@ -288,7 +289,7 @@ def _read_task_with_attempts(db: ControllerDB, task_id: JobName) -> TaskDetailRo with db.read_snapshot() as q: task = TASK_DETAIL_PROJECTION.decode_one( q.fetchall( - f"SELECT {TASK_DETAIL_PROJECTION.select_clause()} FROM tasks t WHERE t.task_id = ?", + f"SELECT {TASK_DETAIL_SELECT_T} FROM tasks t WHERE t.task_id = ?", (task_wire,), ) ) @@ -328,10 +329,10 @@ def _worker_address(db: ControllerDB, worker_id: WorkerId) -> str | None: return str(row[0]) if row else None -def _resource_spec_from_job_row(job: Any) -> job_pb2.ResourceSpecProto: - """Reconstruct a ResourceSpecProto from native job columns.""" +def _resource_spec_from_job_row(job: JobRow | JobDetailRow) -> job_pb2.ResourceSpecProto: + """Reconstruct a ResourceSpecProto from a typed job row.""" return resource_spec_from_scalars( - job.res_cpu_millicores, job.res_memory_bytes, job.res_disk_bytes, job.res_device_json + job.resources.cpu_millicores, job.resources.memory_bytes, job.resources.disk_bytes, job.resources.device_json ) @@ -378,7 +379,9 @@ def _reconstruct_launch_job_request(job: JobDetailRow) -> controller_pb2.Control req.entrypoint.CopyFrom(proto_from_json(job.entrypoint_json, job_pb2.RuntimeEntrypoint)) req.environment.CopyFrom(proto_from_json(job.environment_json, job_pb2.EnvironmentConfig)) req.resources.CopyFrom( - resource_spec_from_scalars(job.res_cpu_millicores, job.res_memory_bytes, job.res_disk_bytes, job.res_device_json) + resource_spec_from_scalars( + job.resources.cpu_millicores, job.resources.memory_bytes, job.resources.disk_bytes, job.resources.device_json + ) ) for c in constraints_from_json(job.constraints_json): @@ -406,25 +409,26 @@ def _reconstruct_launch_job_request(job: JobDetailRow) -> controller_pb2.Control def _worker_metadata_to_proto(worker: WorkerDetailRow) -> job_pb2.WorkerMetadata: """Reconstruct a WorkerMetadata proto from scalar columns.""" + wmd = worker.metadata md = 
job_pb2.WorkerMetadata( - hostname=worker.md_hostname, - ip_address=worker.md_ip_address, - cpu_count=worker.md_cpu_count, - memory_bytes=worker.md_memory_bytes, - disk_bytes=worker.md_disk_bytes, - tpu_name=worker.md_tpu_name, - tpu_worker_hostnames=worker.md_tpu_worker_hostnames, - tpu_worker_id=worker.md_tpu_worker_id, - tpu_chips_per_host_bounds=worker.md_tpu_chips_per_host_bounds, - gpu_count=worker.md_gpu_count, - gpu_name=worker.md_gpu_name, - gpu_memory_mb=worker.md_gpu_memory_mb, - gce_instance_name=worker.md_gce_instance_name, - gce_zone=worker.md_gce_zone, - git_hash=worker.md_git_hash, + hostname=wmd.hostname, + ip_address=wmd.ip_address, + cpu_count=wmd.cpu_count, + memory_bytes=wmd.memory_bytes, + disk_bytes=wmd.disk_bytes, + tpu_name=wmd.tpu_name, + tpu_worker_hostnames=wmd.tpu_worker_hostnames, + tpu_worker_id=wmd.tpu_worker_id, + tpu_chips_per_host_bounds=wmd.tpu_chips_per_host_bounds, + gpu_count=wmd.gpu_count, + gpu_name=wmd.gpu_name, + gpu_memory_mb=wmd.gpu_memory_mb, + gce_instance_name=wmd.gce_instance_name, + gce_zone=wmd.gce_zone, + git_hash=wmd.git_hash, ) - if worker.md_device_json and worker.md_device_json != "{}": - md.device.CopyFrom(proto_from_json(worker.md_device_json, job_pb2.DeviceConfig)) + if wmd.device_json and wmd.device_json != "{}": + md.device.CopyFrom(proto_from_json(wmd.device_json, job_pb2.DeviceConfig)) # Populate attributes from the worker_attributes table data stored on the row. for key, value in worker.attributes.items(): av = job_pb2.AttributeValue() @@ -506,7 +510,7 @@ def _tasks_for_listing(db: ControllerDB, *, job_id: JobName) -> list[TaskDetailR with db.read_snapshot() as q: tasks = TASK_DETAIL_PROJECTION.decode( q.fetchall( - f"SELECT {TASK_DETAIL_PROJECTION.select_clause()} " + f"SELECT {TASK_DETAIL_SELECT_T} " "FROM tasks t WHERE t.job_id = ? 
ORDER BY t.job_id ASC, t.task_index ASC", (job_id.to_wire(),), ), @@ -669,13 +673,13 @@ def _query_jobs( select_params.extend([limit, offset]) with db.read_snapshot() as q: - rows = q.execute_sql(select_sql, tuple(select_params)).fetchall() + rows = q.execute(select_sql, tuple(select_params)).fetchall() # Skip the COUNT query when we can infer the total from the result set: # first page + short result means we already have everything. if offset == 0 and limit > 0 and len(rows) < limit: total = len(rows) else: - total = q.execute_sql(count_sql, tuple(params)).fetchone()[0] + total = q.execute(count_sql, tuple(params)).fetchone()[0] return JOB_ROW_PROJECTION.decode(rows), total @@ -852,7 +856,7 @@ def _tasks_for_worker(db: ControllerDB, worker_id: WorkerId, limit: int = 50) -> placeholders = ",".join("?" for _ in task_wires) tasks = TASK_DETAIL_PROJECTION.decode( q.fetchall( - f"SELECT {TASK_DETAIL_PROJECTION.select_clause()} " + f"SELECT {TASK_DETAIL_SELECT_T} " f"FROM tasks t WHERE t.task_id IN ({placeholders}) ORDER BY t.task_id ASC", tuple(task_wires), ), @@ -971,7 +975,7 @@ class ControllerServiceImpl: def __init__( self, transitions: ControllerTransitions, - db: ControllerDB, + stores: ControllerStores, controller: ControllerProtocol, bundle_store: BundleStore, log_service: LogServiceImpl | LogServiceProxy, @@ -979,7 +983,8 @@ def __init__( system_endpoints: dict[str, str] | None = None, ): self._transitions = transitions - self._db = db + self._stores = stores + self._db = stores.db self._controller = controller self._bundle_store = bundle_store self._log_service = log_service @@ -1172,7 +1177,7 @@ def launch_job( self._controller.wake() with self._db.read_snapshot() as q: - num_tasks = q.execute_sql("SELECT COUNT(*) FROM tasks WHERE job_id = ?", (job_id.to_wire(),)).fetchone()[0] + num_tasks = q.execute("SELECT COUNT(*) FROM tasks WHERE job_id = ?", (job_id.to_wire(),)).fetchone()[0] logger.info(f"Job {job_id} submitted with {num_tasks} task(s)") return 
controller_pb2.Controller.LaunchJobResponse(job_id=job_id.to_wire()) @@ -1586,7 +1591,10 @@ def list_workers( return controller_pb2.Controller.ListWorkersResponse() workers = [] worker_rows = self._worker_roster_cached() - running_by_worker = running_tasks_by_worker(self._db, {worker.worker_id for worker in worker_rows}) + with self._stores.read() as rctx: + running_by_worker = self._stores.tasks.running_tasks_by_worker( + rctx.cur, {worker.worker_id for worker in worker_rows} + ) for worker in worker_rows: workers.append( controller_pb2.Controller.WorkerHealthStatus( @@ -1679,7 +1687,7 @@ def list_endpoints( if prefix.startswith("/system/"): return self._list_system_endpoints(prefix, exact=request.exact) - endpoints = self._db.endpoints.query( + endpoints = self._stores.endpoints.query( EndpointQuery( exact_name=prefix if request.exact else None, name_prefix=None if request.exact else prefix, @@ -1744,7 +1752,11 @@ def get_autoscaler_status( # Fetch running task counts per worker for dashboard display all_worker_ids = {WorkerId(w.worker_id) for w in workers} - running_by_worker = running_tasks_by_worker(self._db, all_worker_ids) if all_worker_ids else {} + if all_worker_ids: + with self._stores.read() as ctx: + running_by_worker = self._stores.tasks.running_tasks_by_worker(ctx.cur, all_worker_ids) + else: + running_by_worker = {} # Enrich VmInfo objects with worker information by matching vm_id to worker_id for group in status.groups: @@ -2469,7 +2481,9 @@ def get_scheduler_state( decoders={"job_id": JobName.from_wire}, ) for row in rows: - job_resources[row.job_id] = _resource_spec_from_job_row(row) + job_resources[row.job_id] = resource_spec_from_scalars( + row.res_cpu_millicores, row.res_memory_bytes, row.res_disk_bytes, row.res_device_json + ) # Group by effective band, interleaving by user within each band BAND_ORDER = [ @@ -2563,7 +2577,9 @@ def get_scheduler_state( ) running_protos: list[controller_pb2.Controller.SchedulerRunningTask] = [] for row in 
running_rows: - res = _resource_spec_from_job_row(row) + res = resource_spec_from_scalars( + row.res_cpu_millicores, row.res_memory_bytes, row.res_disk_bytes, row.res_device_json + ) eff_band = compute_effective_band(row.priority_band, row.task_id.user, user_spend, budget_limits) accel = get_gpu_count(res.device) + get_tpu_count(res.device) rv = resource_value(res.cpu_millicores, res.memory_bytes, accel) diff --git a/lib/iris/src/iris/cluster/controller/store.py b/lib/iris/src/iris/cluster/controller/store.py new file mode 100644 index 0000000000..b04fa98550 --- /dev/null +++ b/lib/iris/src/iris/cluster/controller/store.py @@ -0,0 +1,2996 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Typed entity operations for task state transitions. + +Sits between transitions.py (state machine logic) and db.py (SQLite logistics). +Methods take and return dataclasses, not raw SQL rows. The store enforces +multi-table invariants (e.g. terminate = attempt terminal + task update + worker +decommit + endpoint delete) in a single method instead of scattering across +inline SQL blocks. + +Stores are process-scoped: a single instance is constructed by +``ControllerDB.__init__`` and reused across transactions. Every method takes +the open ``Cursor`` as its first argument so writes land inside +the caller's transaction. Process-scoped caches (worker attributes, +job_config) survive across transactions and are invalidated via cursor +post-commit hooks. 
+""" + +from __future__ import annotations + +import json +import logging +import sqlite3 +from collections.abc import Callable, Iterable, Iterator, Sequence +from contextlib import AbstractContextManager, contextmanager +from dataclasses import dataclass, field, replace as dc_replace +from enum import StrEnum +from threading import Lock, RLock +from typing import Any, Literal, Protocol, TypeVar, overload +from iris.cluster.constraints import AttributeValue +from iris.cluster.controller.budget import UserBudgetDefaults +from iris.cluster.controller.codec import resource_spec_from_scalars +from iris.cluster.controller.schema import ( + ACTIVE_TASK_STATES, + ATTEMPT_PROJECTION, + ENDPOINT_PROJECTION, + EXECUTING_TASK_STATES, + JOB_CONFIG_JOIN, + JOB_CONFIG_PROJECTION, + JOB_DETAIL_PROJECTION, + JOB_ROW_PROJECTION, + TASK_DETAIL_PROJECTION, + TASK_DETAIL_SELECT_T, + WORKER_DETAIL_PROJECTION, + WORKER_ROW_PROJECTION, + EndpointRow, + JobConfigRow, + JobDetailRow, + JobRow, + ResourceSpec, + TaskDetailRow, + WorkerActiveRow, + WorkerDetailRow, + WorkerRow, + decode_timestamp_ms, + decode_worker_id, + tasks_with_attempts, +) +from iris.cluster.types import ( + TERMINAL_JOB_STATES, + TERMINAL_TASK_STATES, + JobName, + JobState, + TaskState, + WorkerId, + get_gpu_count, + get_tpu_count, +) +from iris.rpc import job_pb2 +from rigging.timing import Deadline, Duration, Timestamp + +logger = logging.getLogger(__name__) + + +def sql_placeholders(n: int) -> str: + """Return ``?,?,?`` style placeholder string for SQL ``IN`` lists.""" + return ",".join("?" * n) + + +class WhereBuilder: + """Accumulate conditional WHERE clauses + params without repetitive boilerplate. 
+ + Usage: + wb = WhereBuilder() + if flt.worker_id is not None: + wb.eq("t.current_worker_id", str(flt.worker_id)) + wb.in_("t.state", states) # skips if states is empty/None + sql, params = wb.build() # ("WHERE ...", (..., ...)) or ("", ()) + """ + + def __init__(self) -> None: + self._clauses: list[str] = [] + self._params: list[object] = [] + + def eq(self, col: str, value: object) -> None: + self._clauses.append(f"{col} = ?") + self._params.append(value) + + def is_null(self, col: str) -> None: + self._clauses.append(f"{col} IS NULL") + + def in_(self, col: str, values: Iterable[object] | None) -> None: + if not values: + return + values = tuple(values) + self._clauses.append(f"{col} IN ({sql_placeholders(len(values))})") + self._params.extend(values) + + def raw(self, clause: str, *params: object) -> None: + """Escape hatch for non-trivial conditions; keeps param alignment.""" + self._clauses.append(clause) + self._params.extend(params) + + def build(self) -> tuple[str, tuple[object, ...]]: + if not self._clauses: + return "", () + return "WHERE " + " AND ".join(self._clauses), tuple(self._params) + + +# --------------------------------------------------------------------------- +# Protocols — structural interfaces used by stores to decouple from db.py. +# ControllerDB and Cursor satisfy these structurally; no explicit +# declaration is needed on those classes. +# --------------------------------------------------------------------------- + + +class Cursor(Protocol): + """Methods that store operations call on a Cursor. + + Historically includes both read and write shape — tightened in a later pass. + A read-only scope (``QuerySnapshot``) satisfies only the ``execute`` slice; + calling write-only members on one fails at runtime, which is intentional. + """ + + def execute(self, sql: str, params: tuple = ...) -> sqlite3.Cursor: ... + + def executemany(self, sql: str, params: Iterable[tuple]) -> sqlite3.Cursor: ... 
+ + def on_commit(self, hook: Callable[[], None]) -> None: ... + + @property + def rowcount(self) -> int: ... + + @property + def lastrowid(self) -> int | None: ... + + +class WriteCursor(Cursor, Protocol): + """Cursor inside an IMMEDIATE transaction — supports commit hooks.""" + + def on_commit(self, hook: Callable[[], None]) -> None: ... + + @property + def lastrowid(self) -> int | None: ... + + @property + def rowcount(self) -> int: ... + + +class DbBackend(Protocol): + """Methods that stores call on a ControllerDB.""" + + def read_snapshot(self) -> AbstractContextManager[Cursor]: ... + + def transaction(self) -> AbstractContextManager[WriteCursor]: ... + + def execute(self, query: str, params: tuple | list = ...) -> None: ... + + +class _DecodedRow: + """Lightweight attribute-access wrapper over a decoded SQL row.""" + + __slots__ = ("_data",) + + def __init__(self, data: dict[str, Any]) -> None: + object.__setattr__(self, "_data", data) + + def __getattr__(self, name: str) -> Any: + try: + return self._data[name] + except KeyError as exc: + raise AttributeError(name) from exc + + +def decoded_rows( + cur: Cursor, + sql: str, + params: tuple = (), + decoders: dict[str, Callable] | None = None, +) -> list[_DecodedRow]: + """Execute ``sql`` on ``cur`` and return decoded rows with attribute access. + + Replaces the ``QuerySnapshot.raw`` helper so read methods can operate on + any ``Cursor`` — including a write cursor when the caller is already + inside a transaction. 
+ """ + cursor = cur.execute(sql, params) + col_names = [desc[0] for desc in cursor.description] + active_decoders = decoders or {} + out: list[_DecodedRow] = [] + for raw_row in cursor.fetchall(): + data = { + name: active_decoders[name](raw_row[name]) if name in active_decoders else raw_row[name] + for name in col_names + } + out.append(_DecodedRow(data)) + return out + + +# --------------------------------------------------------------------------- +# Domain predicates — logic about task/job/attempt states. +# --------------------------------------------------------------------------- + + +def task_is_finished( + state: int, failure_count: int, max_retries_failure: int, preemption_count: int, max_retries_preemption: int +) -> bool: + """Whether a task has reached a terminal state with no remaining retries.""" + if state == job_pb2.TASK_STATE_SUCCEEDED: + return True + if state in (job_pb2.TASK_STATE_KILLED, job_pb2.TASK_STATE_UNSCHEDULABLE): + return True + if state == job_pb2.TASK_STATE_FAILED: + return failure_count > max_retries_failure + if state in (job_pb2.TASK_STATE_WORKER_FAILED, job_pb2.TASK_STATE_PREEMPTED): + return preemption_count > max_retries_preemption + return False + + +class TaskRowLike(Protocol): + """Structural interface for task rows used by scheduling predicates. + + Satisfied by ``TaskRow`` and ``TaskDetailRow`` — any row carrying state + plus retry counters and the current attempt id can be evaluated for + finish/schedulability without coupling to a concrete row type. 
+ """ + + state: int + failure_count: int + max_retries_failure: int + preemption_count: int + max_retries_preemption: int + current_attempt_id: int + + +def task_row_is_finished(task: TaskRowLike) -> bool: + return task_is_finished( + task.state, task.failure_count, task.max_retries_failure, task.preemption_count, task.max_retries_preemption + ) + + +def task_row_can_be_scheduled(task: TaskRowLike) -> bool: + if task.state != job_pb2.TASK_STATE_PENDING: + return False + return task.current_attempt_id < 0 or not task_is_finished( + task.state, task.failure_count, task.max_retries_failure, task.preemption_count, task.max_retries_preemption + ) + + +def job_scheduling_deadline(scheduling_deadline_epoch_ms: int | None) -> Deadline | None: + """Compute scheduling deadline from epoch ms.""" + if scheduling_deadline_epoch_ms is None: + return None + return Deadline.after(Timestamp.from_ms(scheduling_deadline_epoch_ms), Duration.from_ms(0)) + + +def attempt_is_terminal(state: int) -> bool: + """Check if an attempt is in a terminal state.""" + return state in TERMINAL_TASK_STATES + + +def attempt_is_worker_failure(state: int) -> bool: + """Check if an attempt is a worker failure or preemption.""" + return state in (job_pb2.TASK_STATE_WORKER_FAILED, job_pb2.TASK_STATE_PREEMPTED) + + +# --------------------------------------------------------------------------- +# Domain summary dataclasses +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class UserStats: + user: str + task_state_counts: dict[int, int] = field(default_factory=dict) + job_state_counts: dict[int, int] = field(default_factory=dict) + + +@dataclass(frozen=True) +class TaskJobSummary: + job_id: JobName + task_count: int = 0 + completed_count: int = 0 + failure_count: int = 0 + preemption_count: int = 0 + task_state_counts: dict[int, int] = field(default_factory=dict) + + +@dataclass(frozen=True) +class UserBudget: + user_id: str + budget_limit: int + 
max_band: int + updated_at: Timestamp + + +def _decode_attribute_rows(rows: Sequence[Any]) -> dict[WorkerId, dict[str, AttributeValue]]: + attrs_by_worker: dict[WorkerId, dict[str, AttributeValue]] = {} + for row in rows: + worker_attrs = attrs_by_worker.setdefault(row.worker_id, {}) + if row.value_type == "int": + worker_attrs[row.key] = AttributeValue(int(row.int_value)) + elif row.value_type == "float": + worker_attrs[row.key] = AttributeValue(float(row.float_value)) + else: + worker_attrs[row.key] = AttributeValue(str(row.str_value or "")) + return attrs_by_worker + + +# --------------------------------------------------------------------------- +# Shared data types (used by both store.py and transitions.py) +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class TaskUpdate: + """Single task state update applied in a batch.""" + + task_id: JobName + attempt_id: int + new_state: int + error: str | None = None + exit_code: int | None = None + resource_usage: job_pb2.ResourceUsage | None = None + container_id: str | None = None + + +@dataclass(frozen=True) +class HeartbeatApplyRequest: + """Batch of worker heartbeat updates applied atomically.""" + + worker_id: WorkerId + worker_resource_snapshot: job_pb2.WorkerResourceSnapshot | None + updates: list[TaskUpdate] + + +# --------------------------------------------------------------------------- +# Dataclasses — read-only views and write inputs +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True, slots=True) +class TimedOutTask: + """A running task that has exceeded its execution timeout.""" + + task_id: JobName + worker_id: WorkerId | None + + +@dataclass(frozen=True, slots=True) +class TaskSnapshot: + """Read-only view of a task row for the transition planner. + + Built from a task row, its current attempt, and the cached job_config. 
+ Provides everything needed to decide the next transition without going + back to the DB. + """ + + task_id: str + job_id: JobName + state: TaskState + attempt_id: int + attempt_state: TaskState + failure_count: int + preemption_count: int + max_retries_failure: int + max_retries_preemption: int + worker_id: str | None + has_coscheduling: bool + resources: job_pb2.ResourceSpecProto | None + + +@dataclass(frozen=True, slots=True) +class SiblingSnapshot: + """Read-only view of a coscheduled sibling for cascade decisions.""" + + task_id: str + attempt_id: int + max_retries_preemption: int + worker_id: str | None + + +@dataclass(frozen=True, slots=True) +class TaskTermination: + """All inputs needed to terminate a task. + + ``finalize`` may be None when no attempt exists; the attempt UPDATE is skipped. + ``finalize.attempt_state`` overrides the state written to the attempt row when + it differs from the task state (e.g. attempt=WORKER_FAILED while task retries + to PENDING). ``error`` is written to the task row; if the same string should + apply to the attempt row, set ``finalize.error`` too. + """ + + task_id: str + state: TaskState + now_ms: int + error: str | None = None + finalize: AttemptFinalizer | None = None + worker_id: str | None = None + resources: job_pb2.ResourceSpecProto | None = None + failure_count: int | None = None + preemption_count: int | None = None + exit_code: int | None = None + + +@dataclass(frozen=True, slots=True) +class TaskRetry: + """All inputs needed to requeue a task to PENDING. + + Terminates the current attempt but resets the task row to PENDING + so the scheduler can create a new attempt. 
+ """ + + task_id: str + finalize: AttemptFinalizer + worker_id: str | None = None + resources: job_pb2.ResourceSpecProto | None = None + failure_count: int = 0 + preemption_count: int = 0 + + +@dataclass(frozen=True, slots=True) +class ActiveStateUpdate: + """Non-terminal task state update (BUILDING, RUNNING).""" + + task_id: str + attempt_id: int + state: TaskState + error: str | None = None + exit_code: int | None = None + started_ms: int | None = None + failure_count: int = 0 + preemption_count: int = 0 + + +@dataclass(frozen=True, slots=True) +class TaskInsert: + """All columns for an INSERT INTO tasks row.""" + + task_id: str + job_id: str + task_index: int + state: int + submitted_at_ms: int + max_retries_failure: int + max_retries_preemption: int + priority_neg_depth: int + priority_root_submitted_ms: int + priority_insertion: int + priority_band: int + + +@dataclass(frozen=True, slots=True) +class WorkerAssignment: + """Assign a task to a worker-backed slot (worker_id and address known).""" + + task_id: str + attempt_id: int + worker_id: str + worker_address: str + now_ms: int + + +@dataclass(frozen=True, slots=True) +class DirectAssignment: + """Assign a task to a direct-provider slot (no backing worker daemon).""" + + task_id: str + attempt_id: int + now_ms: int + + +@dataclass(frozen=True, slots=True) +class AttemptFinalizer: + """Fields needed to write a terminal row to task_attempts. + + Shared by TaskTermination and TaskRetry so both can delegate to a single + ``_finalize_attempt`` helper. 
+ """ + + task_id: str + attempt_id: int + attempt_state: TaskState + now_ms: int + error: str | None = None + exit_code: int | None = None + + @classmethod + def build( + cls, + task_id: str, + attempt_id: int, + state: TaskState | int, + now_ms: int, + error: str | None = None, + ) -> AttemptFinalizer: + return cls( + task_id=task_id, + attempt_id=attempt_id, + attempt_state=state, + now_ms=now_ms, + error=error, + ) + + +@dataclass(frozen=True, slots=True) +class KillResult: + """Tasks that need kill RPCs after a cascade.""" + + tasks_to_kill: frozenset[JobName] + task_kill_workers: dict[JobName, WorkerId] + + +@dataclass(frozen=True, slots=True) +class JobInsert: + """All columns for an INSERT INTO jobs row. + + Covers both the main job insert and the reservation holder insert. + """ + + job_id: str + user_id: str + parent_job_id: str | None + root_job_id: str + depth: int + state: int + submitted_at_ms: int + root_submitted_at_ms: int + finished_at_ms: int | None + scheduling_deadline_epoch_ms: int | None + error: str | None + num_tasks: int + is_reservation_holder: bool + name: str + has_reservation: bool + + +@dataclass(frozen=True, slots=True) +class JobConfigInsert: + """All columns for an INSERT INTO job_config row.""" + + job_id: str + name: str + has_reservation: bool + resources: ResourceSpec + constraints_json: str | None + has_coscheduling: int + coscheduling_group_by: str + scheduling_timeout_ms: int | None + max_task_failures: int + entrypoint_json: str + environment_json: str + bundle_id: str + ports_json: str + max_retries_failure: int + max_retries_preemption: int + timeout_ms: int | None + preemption_policy: int + existing_job_policy: int + priority_band: int + task_image: str + submit_argv_json: str = "[]" + reservation_json: str | None = None + fail_if_exists: int = 0 + + +@dataclass(frozen=True) +class EndpointQuery: + endpoint_ids: tuple[str, ...] 
= () + name_prefix: str | None = None + exact_name: str | None = None + task_ids: tuple[JobName, ...] = () + limit: int | None = None + + +class TaskProjection(StrEnum): + DETAIL = "detail" + WITH_JOB = "with_job" + WITH_JOB_CONFIG = "with_job_config" + + +@dataclass(frozen=True, slots=True) +class TaskFilter: + """Closed WHERE-clause predicate for the tasks table. + + All set fields AND together; unset fields are not filtered on. Used by + :meth:`TaskStore.query` as a single entry point for simple non-snapshot + reads that differ only in their WHERE clause. + + ``worker_id`` and ``worker_is_null`` are mutually exclusive; setting both + at construction raises ``ValueError``. + """ + + task_ids: tuple[str, ...] | None = None + job_ids: tuple[str, ...] | None = None + worker_id: WorkerId | None = None + worker_is_null: bool = False + states: frozenset[int] | None = None + limit: int | None = None + + def __post_init__(self) -> None: + if self.worker_id is not None and self.worker_is_null: + raise ValueError("TaskFilter: worker_id and worker_is_null are mutually exclusive") + + +@dataclass(frozen=True, slots=True) +class JobDetailFilter: + """Closed WHERE-clause predicate for the jobs table (with config join). + + All set fields AND together; unset fields are not filtered on. Used by + :meth:`JobStore.query` as the single entry point for critical-path reads. + """ + + job_ids: tuple[str, ...] | None = None + states: frozenset[int] | None = None + has_reservation: bool | None = None + limit: int | None = None + + +@dataclass(frozen=True, slots=True) +class WorkerFilter: + """Closed WHERE-clause predicate for the workers table. + + All set fields AND together; unset fields are not filtered on. Used by + :meth:`WorkerStore.query` for scheduling-tick reads that differ only in + their WHERE clause. + """ + + worker_ids: tuple[WorkerId, ...] 
| None = None + active: bool | None = None + healthy: bool | None = None + + +@dataclass(frozen=True, slots=True) +class WorkerMetadata: + """Worker environment metadata extracted from the registration RPC.""" + + hostname: str + ip_address: str + cpu_count: int + memory_bytes: int + disk_bytes: int + tpu_name: str + tpu_worker_hostnames: str + tpu_worker_id: int + tpu_chips_per_host_bounds: str + gpu_count: int + gpu_name: str + gpu_memory_mb: int + gce_instance_name: str + gce_zone: str + git_hash: str + device_json: str + + +@dataclass(frozen=True, slots=True) +class WorkerUpsert: + """All inputs needed to insert or update a worker row.""" + + worker_id: str + address: str + now_ms: int + total_cpu_millicores: int + total_memory_bytes: int + total_gpu_count: int + total_tpu_count: int + device_type: str + device_variant: str + slice_id: str + scale_group: str + metadata: WorkerMetadata + attributes: list[tuple[str, str, str | None, int | None, float | None]] + + +# SQLite caps host parameters at ~999. Leave headroom for fixed-position +# filter params (states, worker_id, limit) by chunking ID IN-lists well below +# that cap. +_ID_IN_CHUNK = 900 + + +_T = TypeVar("_T") +_R = TypeVar("_R") + + +def chunk_ids(ids: Sequence[_T] | None, size: int = _ID_IN_CHUNK) -> list[tuple[_T, ...]] | None: + """Chunk ids for IN-list queries. Returns None when ids is None (no filter applied).""" + if ids is None: + return None + return [tuple(ids[i : i + size]) for i in range(0, len(ids), size)] + + +def run_chunked( + chunks: list[tuple[_T, ...]] | None, + limit: int | None, + fetch: Callable[[tuple[_T, ...] | None, int | None], list[_R]], +) -> list[_R]: + """Run a chunked IN-list query with optional row limit. + + `chunks is None` means no id filter — fetch is called once with (None, limit). + When chunks are present, fetch is called per chunk with the remaining limit, + stopping early once the limit is reached. 
# ---------------------------------------------------------------------------
# EndpointStore
#
# Process-local in-memory cache for the ``endpoints`` table.
#
# Profiling showed that ``ListEndpoints`` dominated controller CPU — not because
# the SQL was slow per se, but because every call serialized through the
# read-connection pool and walked a large WAL to build a snapshot. The endpoints
# table is tiny (hundreds of rows) and only changes on explicit register /
# unregister, so it is a natural fit for a write-through in-memory cache.
#
# Design invariants:
#
# * Reads never touch the DB. All lookups are served from in-memory maps
#   guarded by an ``RLock`` — readers observe a consistent snapshot of the
#   indexes, never a torn state mid-update.
# * Writes execute the SQL inside the caller's transaction. The in-memory
#   update is scheduled as a post-commit hook on the cursor so memory only
#   changes after the DB has committed. If the transaction rolls back, the
#   hook never fires.
# * N is small enough (≈ hundreds) that linear scans for prefix / task / id
#   lookups are simpler and plenty fast. Extra indexes (by name, by task_id)
#   speed the two common cases.
#
# The store is the sole source of truth for endpoint reads; nothing else in
# the controller tree should SELECT from ``endpoints``.
# ---------------------------------------------------------------------------


class EndpointStore:
    """In-memory index of endpoint rows, kept in sync with the DB.

    Construct with a ``ControllerDB``; the store loads all existing rows at
    init time. Callers mutate through ``add`` / ``remove*`` methods that take
    the open ``Cursor`` so the SQL lands inside the caller's
    transaction. Memory is only updated after a successful commit via a
    cursor post-commit hook.
    """

    def __init__(self) -> None:
        # Re-entrant lock guarding all three index maps below.
        self._lock = RLock()
        # Primary index: endpoint_id -> row.
        self._by_id: dict[str, EndpointRow] = {}
        # One name can map to multiple endpoint_ids — the schema does not enforce
        # uniqueness on ``name``, and ``INSERT OR REPLACE`` keys off endpoint_id.
        self._by_name: dict[str, set[str]] = {}
        # Owning task -> endpoint_ids, used for task-teardown cleanup.
        self._by_task: dict[JobName, set[str]] = {}

    # -- Loading --------------------------------------------------------------

    def _load_all(self, cur: Cursor) -> None:
        # Full rebuild: clear then re-index every row under one lock hold so
        # readers never observe a partially-loaded index.
        rows = ENDPOINT_PROJECTION.decode(
            cur.execute(f"SELECT {ENDPOINT_PROJECTION.select_clause()} FROM endpoints e").fetchall(),
        )
        with self._lock:
            self._by_id.clear()
            self._by_name.clear()
            self._by_task.clear()
            for row in rows:
                self._index(row)
        logger.info("EndpointStore loaded %d endpoint(s) from DB", len(rows))

    def _commit_index_update(
        self,
        cur: Cursor,
        *,
        add: EndpointRow | None = None,
        remove_ids: Iterable[str] = (),
    ) -> None:
        """Register index mutations to fire when ``cur``'s transaction commits."""
        # Capture mutable locals into the closure now, before the transaction commits.
        remove_list = list(remove_ids)

        def apply() -> None:
            with self._lock:
                for eid in remove_list:
                    self._unindex(eid)
                if add is not None:
                    # Unindex first so a replaced row's stale name/task index
                    # entries are dropped before re-indexing the new row.
                    self._unindex(add.endpoint_id)
                    self._index(add)

        cur.on_commit(apply)

    def _index(self, row: EndpointRow) -> None:
        # Callers must hold self._lock.
        self._by_id[row.endpoint_id] = row
        self._by_name.setdefault(row.name, set()).add(row.endpoint_id)
        self._by_task.setdefault(row.task_id, set()).add(row.endpoint_id)

    def _unindex(self, endpoint_id: str) -> EndpointRow | None:
        # Callers must hold self._lock. Empty name/task buckets are removed
        # so the secondary indexes don't accumulate dead keys.
        row = self._by_id.pop(endpoint_id, None)
        if row is None:
            return None
        name_ids = self._by_name.get(row.name)
        if name_ids is not None:
            name_ids.discard(endpoint_id)
            if not name_ids:
                self._by_name.pop(row.name, None)
        task_ids = self._by_task.get(row.task_id)
        if task_ids is not None:
            task_ids.discard(endpoint_id)
            if not task_ids:
                self._by_task.pop(row.task_id, None)
        return row

    # -- Reads ----------------------------------------------------------------

    def query(self, query: EndpointQuery = EndpointQuery()) -> list[EndpointRow]:
        """Return endpoint rows matching ``query``.

        All filters AND together, matching the semantics of the original SQL
        in :func:`iris.cluster.controller.db.endpoint_query_sql`.
        """
        # NOTE: the EndpointQuery default is a frozen dataclass, so sharing one
        # instance across calls is safe.
        with self._lock:
            # Narrow the candidate set using the most selective index available.
            if query.endpoint_ids:
                candidates: Iterable[EndpointRow] = (
                    self._by_id[eid] for eid in query.endpoint_ids if eid in self._by_id
                )
            elif query.task_ids:
                task_set = set(query.task_ids)
                candidates = (self._by_id[eid] for task_id in task_set for eid in self._by_task.get(task_id, ()))
            elif query.exact_name is not None:
                candidates = (self._by_id[eid] for eid in self._by_name.get(query.exact_name, ()))
            else:
                candidates = self._by_id.values()

            # Re-check every filter on the candidates; the index pre-narrowing
            # above is an optimization, not a guarantee.
            results: list[EndpointRow] = []
            for row in candidates:
                if query.name_prefix is not None and not row.name.startswith(query.name_prefix):
                    continue
                if query.exact_name is not None and row.name != query.exact_name:
                    continue
                if query.task_ids and row.task_id not in query.task_ids:
                    continue
                if query.endpoint_ids and row.endpoint_id not in query.endpoint_ids:
                    continue
                results.append(row)
                if query.limit is not None and len(results) >= query.limit:
                    break
            return results

    def resolve(self, name: str) -> EndpointRow | None:
        """Return any endpoint with exact ``name``, or None. Used by the actor proxy."""
        with self._lock:
            ids = self._by_name.get(name)
            if not ids:
                return None
            # Arbitrary but stable pick — the original SQL did not specify ORDER BY.
            return self._by_id[next(iter(ids))]

    def get(self, endpoint_id: str) -> EndpointRow | None:
        with self._lock:
            return self._by_id.get(endpoint_id)

    def all(self) -> list[EndpointRow]:
        with self._lock:
            return list(self._by_id.values())

    # -- Writes ---------------------------------------------------------------

    def add(self, cur: Cursor, endpoint: EndpointRow) -> bool:
        """Insert ``endpoint`` into the DB and schedule the memory update.

        Returns False (and writes nothing) if the owning task is already
        terminal. Otherwise inserts / replaces and schedules a post-commit
        hook that updates the in-memory indexes.
        """
        task_id = endpoint.task_id
        job_id, _ = task_id.require_task()
        # Guard: never register an endpoint for a task that already finished.
        row = cur.execute("SELECT state FROM tasks WHERE task_id = ?", (task_id.to_wire(),)).fetchone()
        if row is not None and int(row["state"]) in TERMINAL_TASK_STATES:
            return False

        cur.execute(
            "INSERT OR REPLACE INTO endpoints("
            "endpoint_id, name, address, job_id, task_id, metadata_json, registered_at_ms"
            ") VALUES (?, ?, ?, ?, ?, ?, ?)",
            (
                endpoint.endpoint_id,
                endpoint.name,
                endpoint.address,
                job_id.to_wire(),
                task_id.to_wire(),
                json.dumps(endpoint.metadata),
                endpoint.registered_at.epoch_ms(),
            ),
        )

        self._commit_index_update(cur, add=endpoint)
        return True

    def remove(self, cur: Cursor, endpoint_id: str) -> EndpointRow | None:
        """Remove a single endpoint by id. Returns the removed row snapshot, if any."""
        existing = self.get(endpoint_id)
        if existing is None:
            return None
        cur.execute("DELETE FROM endpoints WHERE endpoint_id = ?", (endpoint_id,))

        self._commit_index_update(cur, remove_ids=[endpoint_id])
        return existing

    def remove_by_task(self, cur: Cursor, task_id: JobName) -> list[str]:
        """Remove all endpoints owned by a task. Returns the removed endpoint_ids."""
        with self._lock:
            ids = list(self._by_task.get(task_id, ()))
        if not ids:
            # Still issue the DELETE to stay consistent with any rows the
            # store might not have observed yet (belt-and-suspenders for
            # the unlikely race of an in-flight concurrent writer). This
            # costs nothing on the common path.
            cur.execute("DELETE FROM endpoints WHERE task_id = ?", (task_id.to_wire(),))
            return []
        cur.execute("DELETE FROM endpoints WHERE task_id = ?", (task_id.to_wire(),))

        self._commit_index_update(cur, remove_ids=ids)
        return ids

    def remove_by_job_ids(self, cur: Cursor, job_ids: Sequence[JobName]) -> list[str]:
        """Remove all endpoints owned by any of ``job_ids``. Used by cancel_job and prune."""
        if not job_ids:
            return []
        wire_ids = [jid.to_wire() for jid in job_ids]
        with self._lock:
            # Linear scan over all rows — acceptable because N is small (see
            # the design notes at the top of this section).
            to_remove: list[str] = []
            for row in self._by_id.values():
                owning_job, _ = row.task_id.require_task()
                if owning_job.to_wire() in wire_ids:
                    to_remove.append(row.endpoint_id)
        placeholders = sql_placeholders(len(wire_ids))
        cur.execute(
            f"DELETE FROM endpoints WHERE job_id IN ({placeholders})",
            tuple(wire_ids),
        )
        if not to_remove:
            return []

        self._commit_index_update(cur, remove_ids=to_remove)
        return to_remove


# ---------------------------------------------------------------------------
# TaskStore
# ---------------------------------------------------------------------------


class TaskStore:
    """Typed read/write operations for task entities.

    Process-scoped: a single instance lives on the ``ControllerDB``. Every
    method takes the open ``Cursor`` as its first argument so
    writes land inside the caller's transaction.
    """

    def __init__(self, endpoints: EndpointStore, jobs: JobStore) -> None:
        # Sibling stores: endpoint cleanup on termination and cached
        # job_config lookups both go through these.
        self._endpoints = endpoints
        self._jobs = jobs
AND attempt_id = ?", + (task_id.to_wire(), attempt_id), + ).fetchone() + if attempt_row is None: + attempt_state = int(task_row["state"]) + worker_id = task_row["current_worker_id"] + worker_id = str(worker_id) if worker_id is not None else None + else: + attempt_state = int(attempt_row["state"]) + worker_id = str(attempt_row["worker_id"]) if attempt_row["worker_id"] is not None else None + + # Fetch job_config via the process-scoped JobStore cache. + jc = self._jobs.get_config(cur, task.job_id.to_wire()) + + has_coscheduling = False + resources: job_pb2.ResourceSpecProto | None = None + if jc is not None: + has_coscheduling = jc.has_coscheduling + resources = resource_spec_from_scalars( + jc.resources.cpu_millicores, + jc.resources.memory_bytes, + jc.resources.disk_bytes, + jc.resources.device_json, + ) + + return TaskSnapshot( + task_id=task_id.to_wire(), + job_id=task.job_id, + state=int(task_row["state"]), + attempt_id=attempt_id, + attempt_state=attempt_state, + failure_count=int(task_row["failure_count"]), + preemption_count=int(task_row["preemption_count"]), + max_retries_failure=int(task_row["max_retries_failure"]), + max_retries_preemption=int(task_row["max_retries_preemption"]), + worker_id=worker_id, + has_coscheduling=has_coscheduling, + resources=resources, + ) + + def get_attempt_state(self, cur: Cursor, task_id: str, attempt_id: int) -> int | None: + """Load the state of a specific attempt. Used for stale-attempt checks.""" + row = cur.execute( + "SELECT state FROM task_attempts WHERE task_id = ? AND attempt_id = ?", + (task_id, attempt_id), + ).fetchone() + if row is None: + return None + return int(row["state"]) + + def find_coscheduled_siblings( + self, + cur: Cursor, + job_id: JobName, + exclude_task_id: JobName, + has_coscheduling: bool, + ) -> list[SiblingSnapshot]: + """Find active siblings in a coscheduled job. + + Returns an empty list when the job has no coscheduling config. + Active means ASSIGNED, BUILDING, or RUNNING. 
+ """ + if not has_coscheduling: + return [] + rows = cur.execute( + "SELECT t.task_id, t.current_attempt_id, t.max_retries_preemption, " + "t.current_worker_id AS worker_id " + "FROM tasks t " + "WHERE t.job_id = ? AND t.task_id != ? AND t.state IN (?, ?, ?)", + ( + job_id.to_wire(), + exclude_task_id.to_wire(), + job_pb2.TASK_STATE_ASSIGNED, + job_pb2.TASK_STATE_BUILDING, + job_pb2.TASK_STATE_RUNNING, + ), + ).fetchall() + return [ + SiblingSnapshot( + task_id=str(r["task_id"]), + attempt_id=int(r["current_attempt_id"]), + max_retries_preemption=int(r["max_retries_preemption"]), + worker_id=str(r["worker_id"]) if r["worker_id"] is not None else None, + ) + for r in rows + ] + + # ── Read-pool helpers ───────────────────────────────────────────── + # These operate on the ControllerDB read pool, not a write transaction. + + def running_tasks_by_worker(self, cur: Cursor, worker_ids: set[WorkerId]) -> dict[WorkerId, set[JobName]]: + """Return the set of currently-running task IDs for each worker. + + Uses the denormalized current_worker_id column instead of joining task_attempts. + """ + if not worker_ids: + return {} + placeholders = sql_placeholders(len(worker_ids)) + rows = decoded_rows( + cur, + f"SELECT t.current_worker_id AS worker_id, t.task_id FROM tasks t " + f"WHERE t.current_worker_id IN ({placeholders}) AND t.state IN (?, ?, ?)", + (*[str(wid) for wid in worker_ids], *ACTIVE_TASK_STATES), + decoders={"worker_id": decode_worker_id, "task_id": JobName.from_wire}, + ) + running: dict[WorkerId, set[JobName]] = {wid: set() for wid in worker_ids} + for row in rows: + running[row.worker_id].add(row.task_id) + return running + + def timed_out_executing_tasks(self, cur: Cursor, now: Timestamp) -> list[TimedOutTask]: + """Find executing tasks whose current attempt has exceeded the job's execution timeout. + + Reads the timeout from job_config.timeout_ms. Uses the current attempt's + started_at_ms so that retried tasks get a fresh timeout budget per attempt. 
+ """ + now_ms = now.epoch_ms() + executing_states = tuple(sorted(EXECUTING_TASK_STATES)) + placeholders = sql_placeholders(len(executing_states)) + rows = decoded_rows( + cur, + f"SELECT t.task_id, t.current_worker_id AS worker_id, " + f"ta.started_at_ms AS attempt_started_at_ms, jc.timeout_ms " + f"FROM tasks t " + f"JOIN job_config jc ON jc.job_id = t.job_id " + f"JOIN task_attempts ta ON ta.task_id = t.task_id AND ta.attempt_id = t.current_attempt_id " + f"WHERE t.state IN ({placeholders}) " + f"AND jc.timeout_ms IS NOT NULL AND jc.timeout_ms > 0 " + f"AND ta.started_at_ms IS NOT NULL", + (*executing_states,), + decoders={ + "task_id": JobName.from_wire, + "worker_id": lambda v: WorkerId(v) if v is not None else None, + "attempt_started_at_ms": int, + "timeout_ms": int, + }, + ) + result: list[TimedOutTask] = [] + for row in rows: + if row.attempt_started_at_ms + row.timeout_ms <= now_ms: + result.append(TimedOutTask(task_id=row.task_id, worker_id=row.worker_id)) + return result + + def tasks_for_job_with_attempts(self, cur: Cursor, job_id: JobName) -> list: + """Fetch all tasks for a job with their attempt history.""" + tasks = TASK_DETAIL_PROJECTION.decode( + cur.execute( + "SELECT * FROM tasks WHERE job_id = ? ORDER BY task_index, task_id", + (job_id.to_wire(),), + ).fetchall(), + ) + if not tasks: + return [] + placeholders = sql_placeholders(len(tasks)) + attempts = ATTEMPT_PROJECTION.decode( + cur.execute( + f"SELECT * FROM task_attempts WHERE task_id IN ({placeholders}) ORDER BY task_id, attempt_id", + tuple(t.task_id.to_wire() for t in tasks), + ).fetchall(), + ) + return tasks_with_attempts(tasks, attempts) + + def insert_task_profile( + self, cur: Cursor, task_id: str, profile_data: bytes, captured_at: Timestamp, profile_kind: str = "cpu" + ) -> None: + """Insert a captured profile snapshot for a task. + + The DB trigger caps profiles at 10 per (task_id, profile_kind), evicting the oldest automatically. 
+ """ + cur.execute( + "INSERT INTO profiles.task_profiles " + "(task_id, profile_data, captured_at_ms, profile_kind) VALUES (?, ?, ?, ?)", + (task_id, profile_data, captured_at.epoch_ms(), profile_kind), + ) + + def get_task_profiles( + self, cur: Cursor, task_id: str, profile_kind: str | None = None + ) -> list[tuple[bytes, Timestamp, str]]: + """Return stored profile snapshots for a task, newest first. + + Args: + task_id: Task wire string. + profile_kind: If set, filter to this kind (e.g. "cpu", "memory"). Returns all kinds when None. + """ + if profile_kind is not None: + query = ( + "SELECT profile_data, captured_at_ms, profile_kind FROM profiles.task_profiles" + " WHERE task_id = ? AND profile_kind = ? ORDER BY id DESC" + ) + params: tuple[str, ...] = (task_id, profile_kind) + else: + query = ( + "SELECT profile_data, captured_at_ms, profile_kind FROM profiles.task_profiles" + " WHERE task_id = ? ORDER BY id DESC" + ) + params = (task_id,) + rows = decoded_rows(cur, query, params, decoders={"captured_at_ms": decode_timestamp_ms}) + return [(row.profile_data, row.captured_at_ms, row.profile_kind) for row in rows] + + # ── Writes ─────────────────────────────────────────────────────── + + def terminate(self, cur: Cursor, t: TaskTermination) -> None: + """Move a task (and its current attempt) to terminal state consistently. + + Enforces the multi-table invariant: attempt marked terminal, task + state/error/finished_at updated, worker columns cleared, endpoints + deleted, worker resources released. + """ + finished_at_ms = None if t.state in ACTIVE_TASK_STATES or t.state == job_pb2.TASK_STATE_PENDING else t.now_ms + + if t.finalize is not None and t.finalize.attempt_id >= 0: + self._finalize_attempt(cur, t.finalize) + + # Build the UPDATE tasks statement dynamically based on optional counters. 
+ if finished_at_ms is not None: + set_clauses = ["state = ?", "error = ?", "finished_at_ms = COALESCE(finished_at_ms, ?)"] + else: + set_clauses = ["state = ?", "error = ?", "finished_at_ms = ?"] + exit_code = t.finalize.exit_code if t.finalize is not None else None + params: list[object] = [int(t.state), t.error, finished_at_ms] + + if t.failure_count is not None: + set_clauses.append("failure_count = ?") + params.append(t.failure_count) + if t.preemption_count is not None: + set_clauses.append("preemption_count = ?") + params.append(t.preemption_count) + if exit_code is not None: + set_clauses.append("exit_code = COALESCE(?, exit_code)") + params.append(exit_code) + + # Clear worker columns when leaving active state. + if t.state not in ACTIVE_TASK_STATES: + set_clauses.append("current_worker_id = NULL") + set_clauses.append("current_worker_address = NULL") + + params.append(t.task_id) + cur.execute( + f"UPDATE tasks SET {', '.join(set_clauses)} WHERE task_id = ?", + tuple(params), + ) + + self._remove_task_endpoints(cur, t.task_id) + + if t.worker_id is not None and t.resources is not None: + self._decommit_worker_resources(cur, t.worker_id, t.resources) + + def requeue(self, cur: Cursor, r: TaskRetry) -> None: + """Terminate the current attempt but reset the task to PENDING. + + The attempt is marked with the given terminal state, but the task + row reverts to PENDING so the scheduler can create a fresh attempt. + """ + self._finalize_attempt(cur, r.finalize) + + # Reset task to PENDING, clear worker columns, update counters, + # and clear finished_at_ms. 
+ cur.execute( + "UPDATE tasks SET state = ?, error = NULL, finished_at_ms = NULL, " + "failure_count = ?, preemption_count = ?, " + "current_worker_id = NULL, current_worker_address = NULL " + "WHERE task_id = ?", + ( + int(job_pb2.TASK_STATE_PENDING), + r.failure_count, + r.preemption_count, + r.task_id, + ), + ) + + self._remove_task_endpoints(cur, r.task_id) + + if r.worker_id is not None and r.resources is not None: + self._decommit_worker_resources(cur, r.worker_id, r.resources) + + def update_active(self, cur: Cursor, u: ActiveStateUpdate) -> None: + """Non-terminal state update (BUILDING, RUNNING). + + Updates both the attempt and task rows. Does not clear worker + columns since the task remains active. + """ + cur.execute( + "UPDATE task_attempts SET state = ?, started_at_ms = COALESCE(started_at_ms, ?), " + "exit_code = COALESCE(?, exit_code), error = COALESCE(?, error) " + "WHERE task_id = ? AND attempt_id = ?", + ( + int(u.state), + u.started_ms, + u.exit_code, + u.error, + u.task_id, + u.attempt_id, + ), + ) + + cur.execute( + "UPDATE tasks SET state = ?, error = COALESCE(?, error), " + "exit_code = COALESCE(?, exit_code), " + "started_at_ms = COALESCE(started_at_ms, ?), finished_at_ms = ?, " + "failure_count = ?, preemption_count = ? 
" + "WHERE task_id = ?", + ( + int(u.state), + u.error, + u.exit_code, + u.started_ms, + None, # finished_at_ms — active tasks are not finished + u.failure_count, + u.preemption_count, + u.task_id, + ), + ) + + def assign_to_worker(self, cur: Cursor, a: WorkerAssignment) -> None: + """Create an attempt bound to a worker and mark the task ASSIGNED.""" + cur.execute( + "INSERT INTO task_attempts(task_id, attempt_id, worker_id, state, created_at_ms) VALUES (?, ?, ?, ?, ?)", + (a.task_id, a.attempt_id, a.worker_id, int(job_pb2.TASK_STATE_ASSIGNED), a.now_ms), + ) + cur.execute( + "UPDATE tasks SET state = ?, current_attempt_id = ?, " + "current_worker_id = ?, current_worker_address = ?, " + "started_at_ms = COALESCE(started_at_ms, ?) WHERE task_id = ?", + ( + int(job_pb2.TASK_STATE_ASSIGNED), + a.attempt_id, + a.worker_id, + a.worker_address, + a.now_ms, + a.task_id, + ), + ) + + def assign_direct(self, cur: Cursor, a: DirectAssignment) -> None: + """Create an attempt with no backing worker and mark the task ASSIGNED.""" + cur.execute( + "INSERT INTO task_attempts(task_id, attempt_id, worker_id, state, created_at_ms) VALUES (?, ?, ?, ?, ?)", + (a.task_id, a.attempt_id, None, int(job_pb2.TASK_STATE_ASSIGNED), a.now_ms), + ) + cur.execute( + "UPDATE tasks SET state = ?, current_attempt_id = ?, " + "started_at_ms = COALESCE(started_at_ms, ?) WHERE task_id = ?", + (int(job_pb2.TASK_STATE_ASSIGNED), a.attempt_id, a.now_ms, a.task_id), + ) + + def terminate_coscheduled_siblings( + self, + cur: Cursor, + siblings: list[SiblingSnapshot], + cause_task_id: JobName, + resources: job_pb2.ResourceSpecProto, + now_ms: int, + ) -> KillResult: + """Terminate coscheduled siblings and decommit their resources. + + Each sibling is marked WORKER_FAILED with exhausted preemption count + so it will not be retried. Returns tasks that need kill RPCs. 
+ """ + tasks_to_kill: set[JobName] = set() + task_kill_workers: dict[JobName, WorkerId] = {} + error = f"Coscheduled sibling {cause_task_id.to_wire()} failed" + + for sib in siblings: + self.terminate( + cur, + TaskTermination( + task_id=sib.task_id, + state=job_pb2.TASK_STATE_WORKER_FAILED, + now_ms=now_ms, + error=error, + finalize=AttemptFinalizer.build( + sib.task_id, sib.attempt_id, job_pb2.TASK_STATE_WORKER_FAILED, now_ms, error=error + ), + worker_id=sib.worker_id, + resources=resources if sib.worker_id is not None else None, + preemption_count=sib.max_retries_preemption + 1, + ), + ) + if sib.worker_id is not None: + task_kill_workers[JobName.from_wire(sib.task_id)] = WorkerId(sib.worker_id) + tasks_to_kill.add(JobName.from_wire(sib.task_id)) + + return KillResult( + tasks_to_kill=frozenset(tasks_to_kill), + task_kill_workers=task_kill_workers, + ) + + def insert_resource_usage( + self, + cur: Cursor, + task_id: str, + attempt_id: int, + usage: job_pb2.ResourceUsage, + now_ms: int, + ) -> None: + """Write a single task_resource_history row.""" + cur.execute( + "INSERT INTO task_resource_history" + "(task_id, attempt_id, cpu_millicores, memory_mb, disk_mb, memory_peak_mb, timestamp_ms) " + "VALUES (?, ?, ?, ?, ?, ?, ?)", + ( + task_id, + attempt_id, + usage.cpu_millicores, + usage.memory_mb, + usage.disk_mb, + usage.memory_peak_mb, + now_ms, + ), + ) + + def insert_resource_usage_batch(self, cur: Cursor, params: list[tuple]) -> None: + """Batch insert task_resource_history via executemany for steady-state updates.""" + cur.executemany( + "INSERT INTO task_resource_history" + "(task_id, attempt_id, cpu_millicores, memory_mb, disk_mb, memory_peak_mb, timestamp_ms) " + "VALUES (?, ?, ?, ?, ?, ?, ?)", + params, + ) + + # ── Internal helpers ───────────────────────────────────────────── + + def _finalize_attempt(self, cur: Cursor, fin: AttemptFinalizer) -> None: + """Write terminal state to a task_attempts row.""" + cur.execute( + "UPDATE task_attempts SET state 
= ?, " + "finished_at_ms = COALESCE(finished_at_ms, ?), error = ?, " + "exit_code = COALESCE(?, exit_code) " + "WHERE task_id = ? AND attempt_id = ?", + (int(fin.attempt_state), fin.now_ms, fin.error, fin.exit_code, fin.task_id, fin.attempt_id), + ) + + def _remove_task_endpoints(self, cur: Cursor, task_id: str) -> None: + """Remove all registered endpoints for a task.""" + self._endpoints.remove_by_task(cur, JobName.from_wire(task_id)) + + def _decommit_worker_resources(self, cur: Cursor, worker_id: str, resources: job_pb2.ResourceSpecProto) -> None: + """Subtract a task's resource reservation from a worker, flooring at zero.""" + cur.execute( + "UPDATE workers SET committed_cpu_millicores = MAX(0, committed_cpu_millicores - ?), " + "committed_mem_bytes = MAX(0, committed_mem_bytes - ?), " + "committed_gpu = MAX(0, committed_gpu - ?), committed_tpu = MAX(0, committed_tpu - ?) " + "WHERE worker_id = ?", + ( + int(resources.cpu_millicores), + int(resources.memory_bytes), + int(get_gpu_count(resources.device)), + int(get_tpu_count(resources.device)), + worker_id, + ), + ) + + # ── Extended reads (migrated from transitions.py inline SQL) ───── + + def get_for_assignment(self, cur: Cursor, task_id: str) -> sqlite3.Row | None: + """Fetch task detail row for assignment validation.""" + return cur.execute( + f"SELECT {TASK_DETAIL_SELECT_T} FROM tasks t WHERE t.task_id = ?", + (task_id,), + ).fetchone() + + def query( + self, cur: Cursor, flt: TaskFilter, *, projection: TaskProjection = TaskProjection.DETAIL + ) -> list[TaskDetailRow]: + """Return rows matching ``flt`` under the requested projection. + + Builds SQL from ``flt`` by AND-ing every set field. Large + ``task_ids`` / ``job_ids`` lists are chunked under SQLite's + host-parameter cap (~999); chunk results are concatenated preserving + ORDER BY task_id ASC within each chunk. + + ``projection=DETAIL`` (default) returns plain task columns. 
+ ``WITH_JOB`` joins ``jobs`` and populates ``is_reservation_holder``/``num_tasks``. + ``WITH_JOB_CONFIG`` additionally joins ``job_config`` and populates resource/timeout fields. + """ + return self._run_query(cur, flt, projection=projection) + + def _run_query(self, cur: Cursor, flt: TaskFilter, *, projection: TaskProjection) -> list: + # Short-circuit empty IN-lists: SQLite forbids the empty literal. + if flt.task_ids is not None and not flt.task_ids: + return [] + if flt.job_ids is not None and not flt.job_ids: + return [] + + task_chunks = chunk_ids(flt.task_ids) + job_chunks = chunk_ids(flt.job_ids) + if task_chunks is None and job_chunks is None: + return self._query_chunk( + cur, flt, chunk_task_ids=None, chunk_job_ids=None, projection=projection, limit=flt.limit + ) + + # Pair each chunk with None for the other field; task_ids takes precedence + # when both are set (callers pre-validate mutual exclusion). + id_chunks: list[tuple[tuple[str, ...] | None, tuple[str, ...] | None]] + if task_chunks is not None: + id_chunks = [(c, None) for c in task_chunks] + else: + assert job_chunks is not None + id_chunks = [(None, c) for c in job_chunks] + + results: list = [] + remaining_limit = flt.limit + for chunk_task_ids, chunk_job_ids in id_chunks: + if remaining_limit is not None and remaining_limit <= 0: + break + chunk_rows = self._query_chunk( + cur, + flt, + chunk_task_ids=chunk_task_ids, + chunk_job_ids=chunk_job_ids, + projection=projection, + limit=remaining_limit, + ) + results.extend(chunk_rows) + if remaining_limit is not None: + remaining_limit -= len(chunk_rows) + return results + + def _query_chunk( + self, + cur: Cursor, + flt: TaskFilter, + *, + chunk_task_ids: tuple[str, ...] | None, + chunk_job_ids: tuple[str, ...] 
| None, + projection: TaskProjection, + limit: int | None, + ) -> list: + if projection == TaskProjection.DETAIL: + sql_parts = [f"SELECT {TASK_DETAIL_SELECT_T} FROM tasks t"] + elif projection == TaskProjection.WITH_JOB: + sql_parts = [ + "SELECT t.task_id, t.job_id, t.state, t.current_attempt_id, " + "t.failure_count, t.preemption_count, t.max_retries_failure, t.max_retries_preemption, " + "t.submitted_at_ms, t.priority_band, t.error, t.exit_code, " + "t.started_at_ms, t.finished_at_ms, t.current_worker_id, t.current_worker_address, " + "t.container_id, j.is_reservation_holder, j.num_tasks " + "FROM tasks t" + ] + sql_parts.append("JOIN jobs j ON j.job_id = t.job_id") + else: # with_job_config + sql_parts = [ + "SELECT t.task_id, t.job_id, t.state, t.current_attempt_id, " + "t.failure_count, t.preemption_count, t.max_retries_failure, t.max_retries_preemption, " + "t.submitted_at_ms, t.priority_band, t.error, t.exit_code, " + "t.started_at_ms, t.finished_at_ms, t.current_worker_id, t.current_worker_address, " + "t.container_id, j.is_reservation_holder, j.num_tasks, " + "jc.res_cpu_millicores, jc.res_memory_bytes, jc.res_disk_bytes, jc.res_device_json, " + "jc.has_coscheduling, jc.timeout_ms " + "FROM tasks t" + ] + sql_parts.append("JOIN jobs j ON j.job_id = t.job_id") + sql_parts.append(JOB_CONFIG_JOIN) + + wb = WhereBuilder() + wb.in_("t.task_id", chunk_task_ids) + wb.in_("t.job_id", chunk_job_ids) + if flt.worker_id is not None: + wb.eq("t.current_worker_id", str(flt.worker_id)) + if flt.worker_is_null: + wb.is_null("t.current_worker_id") + if flt.states is not None: + wb.in_("t.state", tuple(sorted(flt.states))) + + where_sql, where_params = wb.build() + params: list[object] = list(where_params) + if where_sql: + sql_parts.append(where_sql) + sql_parts.append("ORDER BY t.task_id ASC") + if limit is not None: + sql_parts.append("LIMIT ?") + params.append(limit) + + rows = cur.execute(" ".join(sql_parts), tuple(params)).fetchall() + return 
TASK_DETAIL_PROJECTION.decode(rows) + + def update_container_id(self, cur: Cursor, task_id: str, container_id: str) -> None: + """Set container_id on a task row.""" + cur.execute( + "UPDATE tasks SET container_id = ? WHERE task_id = ?", + (container_id, task_id), + ) + + def get_job_id(self, cur: Cursor, task_id: str) -> str | None: + """Read job_id for a task. Returns None if the task does not exist.""" + row = cur.execute("SELECT job_id FROM tasks WHERE task_id = ?", (task_id,)).fetchone() + if row is None: + return None + return str(row["job_id"]) + + def get_attempt_worker(self, cur: Cursor, task_id: str, attempt_id: int) -> str | None: + """Worker_id from a specific attempt. Returns None if missing.""" + row = cur.execute( + "SELECT worker_id FROM task_attempts WHERE task_id = ? AND attempt_id = ?", + (task_id, attempt_id), + ).fetchone() + if row is None or row["worker_id"] is None: + return None + return str(row["worker_id"]) + + def insert_task(self, cur: Cursor, req: TaskInsert) -> None: + """Insert a single task row with all priority columns.""" + cur.execute( + "INSERT INTO tasks(" + "task_id, job_id, task_index, state, error, exit_code, submitted_at_ms, started_at_ms, " + "finished_at_ms, max_retries_failure, max_retries_preemption, failure_count, preemption_count, " + "current_attempt_id, priority_neg_depth, priority_root_submitted_ms, " + "priority_insertion, priority_band" + ") VALUES (?, ?, ?, ?, NULL, NULL, ?, NULL, NULL, ?, ?, 0, 0, -1, ?, ?, ?, ?)", + ( + req.task_id, + req.job_id, + req.task_index, + req.state, + req.submitted_at_ms, + req.max_retries_failure, + req.max_retries_preemption, + req.priority_neg_depth, + req.priority_root_submitted_ms, + req.priority_insertion, + req.priority_band, + ), + ) + + def delete_attempt(self, cur: Cursor, task_id: str, attempt_id: int) -> None: + """Delete a task attempt row (used for reservation holder reset).""" + cur.execute( + "DELETE FROM task_attempts WHERE task_id = ? 
AND attempt_id = ?", + (task_id, attempt_id), + ) + + def reset_reservation_holder(self, cur: Cursor, task_id: str, state: int) -> None: + """Reset a reservation holder task to pristine PENDING state.""" + cur.execute( + "UPDATE tasks SET state = ?, current_attempt_id = -1, started_at_ms = NULL, " + "finished_at_ms = NULL, error = NULL, preemption_count = 0, " + "current_worker_id = NULL, current_worker_address = NULL WHERE task_id = ?", + (state, task_id), + ) + + def bulk_cancel(self, cur: Cursor, job_ids: list[str], reason: str, now_ms: int) -> None: + """Bulk UPDATE tasks to KILLED across multiple job IDs. + + Skips tasks already in terminal states. Clears worker columns. + """ + placeholders = sql_placeholders(len(job_ids)) + task_terminal_placeholders = sql_placeholders(len(TERMINAL_TASK_STATES)) + cur.execute( + f"UPDATE tasks SET state = ?, error = ?, finished_at_ms = COALESCE(finished_at_ms, ?), " + f"current_worker_id = NULL, current_worker_address = NULL " + f"WHERE job_id IN ({placeholders}) AND state NOT IN ({task_terminal_placeholders})", + ( + job_pb2.TASK_STATE_KILLED, + reason, + now_ms, + *job_ids, + *TERMINAL_TASK_STATES, + ), + ) + + def get_pending_for_direct_provider(self, cur: Cursor, limit: int) -> list: + """Pending tasks for direct provider promotion (non-reservation-holder only).""" + return cur.execute( + "SELECT t.task_id, t.job_id, t.current_attempt_id, j.num_tasks, j.is_reservation_holder, " + "jc.res_cpu_millicores, jc.res_memory_bytes, jc.res_disk_bytes, jc.res_device_json, " + "jc.entrypoint_json, jc.environment_json, jc.bundle_id, jc.ports_json, " + "jc.constraints_json, jc.task_image, jc.timeout_ms " + f"FROM tasks t JOIN jobs j ON j.job_id = t.job_id {JOB_CONFIG_JOIN} " + "WHERE t.state = ? 
AND j.is_reservation_holder = 0 " + "LIMIT ?", + (job_pb2.TASK_STATE_PENDING, limit), + ).fetchall() + + def prune_task_resource_history(self, cur: Cursor, retention: int) -> int: + """Logarithmic downsampling: when a (task, attempt) exceeds 2*retention rows, + thin the older half by deleting every other row. + + Over repeated compaction cycles older data becomes exponentially sparser, + preserving long-term trends while bounding total row count. + """ + threshold = retention * 2 + overflows = cur.execute( + "SELECT task_id, attempt_id, COUNT(*) as cnt " + "FROM task_resource_history " + "GROUP BY task_id, attempt_id HAVING cnt > ?", + (threshold,), + ).fetchall() + ids_to_delete: list[int] = [] + for row in overflows: + tid, aid = row["task_id"], row["attempt_id"] + all_ids = [ + r["id"] + for r in cur.execute( + "SELECT id FROM task_resource_history WHERE task_id = ? AND attempt_id = ? ORDER BY id ASC", + (tid, aid), + ).fetchall() + ] + older = all_ids[: len(all_ids) - retention] + ids_to_delete.extend(older[1::2]) + + total_deleted = 0 + for chunk_start in range(0, len(ids_to_delete), 900): + chunk = ids_to_delete[chunk_start : chunk_start + 900] + ph = sql_placeholders(len(chunk)) + cur.execute(f"DELETE FROM task_resource_history WHERE id IN ({ph})", tuple(chunk)) + total_deleted += cur.rowcount + if total_deleted > 0: + logger.info("Pruned %d task_resource_history rows (log downsampling)", total_deleted) + return total_deleted + + +# --------------------------------------------------------------------------- +# Pure job-state derivation +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True, slots=True) +class JobContext: + """Read-only snapshot of a job row for state derivation.""" + + state: int + started_at_ms: int | None + max_task_failures: int + + +def derive_job_state( + current: int, + counts: dict[int, int], + max_task_failures: int, + has_started: bool, +) -> int: + """Pure function: derive job 
state from task state counts. + + Priority ordering: + SUCCEEDED (all tasks) > FAILED (exceeded budget) > UNSCHEDULABLE > KILLED > + WORKER_FAILED/PREEMPTED (all terminal) > RUNNING > PENDING. + + Returns the current state unchanged when no transition is warranted. + """ + total = sum(counts.values()) + + if total > 0 and counts.get(job_pb2.TASK_STATE_SUCCEEDED, 0) == total: + return job_pb2.JOB_STATE_SUCCEEDED + if counts.get(job_pb2.TASK_STATE_FAILED, 0) > max_task_failures: + return job_pb2.JOB_STATE_FAILED + if counts.get(job_pb2.TASK_STATE_UNSCHEDULABLE, 0) > 0: + return job_pb2.JOB_STATE_UNSCHEDULABLE + if counts.get(job_pb2.TASK_STATE_KILLED, 0) > 0: + return job_pb2.JOB_STATE_KILLED + if ( + total > 0 + and (counts.get(job_pb2.TASK_STATE_WORKER_FAILED, 0) + counts.get(job_pb2.TASK_STATE_PREEMPTED, 0)) > 0 + and all(s in TERMINAL_TASK_STATES for s in counts) + ): + return job_pb2.JOB_STATE_WORKER_FAILED + if ( + counts.get(job_pb2.TASK_STATE_ASSIGNED, 0) > 0 + or counts.get(job_pb2.TASK_STATE_BUILDING, 0) > 0 + or counts.get(job_pb2.TASK_STATE_RUNNING, 0) > 0 + ): + return job_pb2.JOB_STATE_RUNNING + if has_started: + # Retries put tasks back into PENDING; keep job running once it has started. + return job_pb2.JOB_STATE_RUNNING + if total > 0: + return job_pb2.JOB_STATE_PENDING + + return current + + +# --------------------------------------------------------------------------- +# JobStore +# --------------------------------------------------------------------------- + + +class JobStore: + """Typed read/write operations for job entities. + + Process-scoped: a single instance lives on the ``ControllerDB``. Every + method takes the open ``Cursor`` as its first argument. + + Owns a process-scoped cache of ``job_config`` rows. The cache is + populated on ``insert_job_config`` and invalidated on ``delete_job`` + via cursor post-commit hooks — memory only diverges from disk on + successful commit. 
The cache lets hot scheduling paths read + resource/coscheduling config without re-hitting SQLite. + """ + + def __init__(self, endpoints: EndpointStore) -> None: + self._endpoints = endpoints + self._job_config_cache: dict[str, JobConfigRow] = {} + self._job_config_lock = Lock() + + # ── Reads ──────────────────────────────────────────────────────── + + def get_config(self, cur: Cursor, job_id_wire: str) -> JobConfigRow | None: + """Fetch a job_config row by job_id, caching process-wide. + + Cache is populated here on hit, on ``insert_job_config``, and invalidated + on ``delete_job`` (FK cascade deletes the config row). Misses are not + cached so a later insert is observed immediately. + """ + with self._job_config_lock: + cached = self._job_config_cache.get(job_id_wire) + if cached is not None: + return cached + row = cur.execute( + f"SELECT {JOB_CONFIG_PROJECTION.select_clause(prefix=False)} FROM job_config WHERE job_id = ?", + (job_id_wire,), + ).fetchone() + if row is None: + return None + value = JOB_CONFIG_PROJECTION.decode_one([row]) + assert value is not None + with self._job_config_lock: + self._job_config_cache.setdefault(job_id_wire, value) + return self._job_config_cache[job_id_wire] + + def get_state(self, cur: Cursor, job_id: JobName) -> int | None: + """Read current job state.""" + row = cur.execute( + "SELECT state FROM jobs WHERE job_id = ?", + (job_id.to_wire(),), + ).fetchone() + if row is None: + return None + return int(row["state"]) + + def get_preemption_policy(self, cur: Cursor, job_id: JobName) -> int: + """Resolve the effective preemption policy for a job. + + Defaults: single-task jobs use TERMINATE_CHILDREN, multi-task use + PRESERVE_CHILDREN. 
+ """ + row = cur.execute( + f"SELECT jc.preemption_policy, j.num_tasks FROM jobs j {JOB_CONFIG_JOIN} WHERE j.job_id = ?", + (job_id.to_wire(),), + ).fetchone() + if row is None: + return job_pb2.JOB_PREEMPTION_POLICY_TERMINATE_CHILDREN + policy = int(row["preemption_policy"]) + if policy != job_pb2.JOB_PREEMPTION_POLICY_UNSPECIFIED: + return policy + if int(row["num_tasks"]) <= 1: + return job_pb2.JOB_PREEMPTION_POLICY_TERMINATE_CHILDREN + return job_pb2.JOB_PREEMPTION_POLICY_PRESERVE_CHILDREN + + # ── Writes ─────────────────────────────────────────────────────── + + def update_state( + self, + cur: Cursor, + job_id: JobName, + state: JobState, + now_ms: int, + error: str | None = None, + ) -> None: + """Direct job state update with COALESCE patterns for timestamps and error.""" + terminal_placeholders = sql_placeholders(len(TERMINAL_JOB_STATES)) + cur.execute( + "UPDATE jobs SET state = ?, " + "started_at_ms = CASE WHEN ? = ? THEN COALESCE(started_at_ms, ?) ELSE started_at_ms END, " + f"finished_at_ms = CASE WHEN ? IN ({terminal_placeholders}) THEN ? ELSE finished_at_ms END, " + "error = CASE WHEN ? IN (?, ?, ?, ?) THEN ? ELSE error END " + "WHERE job_id = ?", + ( + state, + state, + job_pb2.JOB_STATE_RUNNING, + now_ms, + state, + *TERMINAL_JOB_STATES, + now_ms, + state, + job_pb2.JOB_STATE_FAILED, + job_pb2.JOB_STATE_KILLED, + job_pb2.JOB_STATE_UNSCHEDULABLE, + job_pb2.JOB_STATE_WORKER_FAILED, + error, + job_id.to_wire(), + ), + ) + + def get_job_context(self, cur: Cursor, job_id: JobName) -> JobContext | None: + """Read-only: fetch current state, started_at_ms, and max_task_failures. + + Returns None if the job doesn't exist. 
+ """ + row = cur.execute( + f"SELECT j.state, j.started_at_ms, jc.max_task_failures " + f"FROM jobs j {JOB_CONFIG_JOIN} WHERE j.job_id = ?", + (job_id.to_wire(),), + ).fetchone() + if row is None: + return None + return JobContext( + state=int(row["state"]), + started_at_ms=int(row["started_at_ms"]) if row["started_at_ms"] is not None else None, + max_task_failures=int(row["max_task_failures"]), + ) + + def get_task_state_counts(self, cur: Cursor, job_id: JobName) -> dict[int, int]: + """Read-only: GROUP BY state count query for a job's tasks.""" + rows = cur.execute( + "SELECT state, COUNT(*) AS c FROM tasks WHERE job_id = ? GROUP BY state", + (job_id.to_wire(),), + ).fetchall() + return {int(r["state"]): int(r["c"]) for r in rows} + + def get_first_task_error(self, cur: Cursor, job_id: JobName) -> str | None: + """Read-only: fetch the error from the first failing task by task_index.""" + row = cur.execute( + "SELECT error FROM tasks WHERE job_id = ? AND error IS NOT NULL ORDER BY task_index LIMIT 1", + (job_id.to_wire(),), + ).fetchone() + if row is None: + return None + return str(row["error"]) + + def recompute_state(self, cur: Cursor, job_id: JobName) -> int | None: + """Derive job state from task state counts and update the row. + + Uses the pure derive_job_state function for the decision logic, + then writes the result if the state changed. 
+ """ + ctx = self.get_job_context(cur, job_id) + if ctx is None: + return None + if ctx.state in TERMINAL_JOB_STATES: + return ctx.state + + counts = self.get_task_state_counts(cur, job_id) + new_state = derive_job_state( + current=ctx.state, + counts=counts, + max_task_failures=ctx.max_task_failures, + has_started=ctx.started_at_ms is not None, + ) + + if new_state == ctx.state: + return new_state + + error = self.get_first_task_error(cur, job_id) + now_ms = Timestamp.now().epoch_ms() + self.update_state(cur, job_id, new_state, now_ms, error) + return new_state + + def kill_non_terminal_tasks( + self, + cur: Cursor, + tasks: TaskStore, + job_id: str, + reason: str, + now_ms: int, + ) -> KillResult: + """Kill all non-terminal tasks for a job, decommit resources, and delete endpoints.""" + terminal_states = tuple(sorted(TERMINAL_TASK_STATES)) + placeholders = sql_placeholders(len(terminal_states)) + rows = cur.execute( + "SELECT t.task_id, t.current_attempt_id, t.current_worker_id, " + "jc.res_cpu_millicores, jc.res_memory_bytes, jc.res_disk_bytes, jc.res_device_json " + "FROM tasks t " + "JOIN jobs j ON j.job_id = t.job_id " + f"{JOB_CONFIG_JOIN} " + f"WHERE t.job_id = ? 
AND t.state NOT IN ({placeholders})", + (job_id, *terminal_states), + ).fetchall() + + tasks_to_kill: set[JobName] = set() + task_kill_workers: dict[JobName, WorkerId] = {} + + for row in rows: + task_id = str(row["task_id"]) + worker_id = row["current_worker_id"] + task_name = JobName.from_wire(task_id) + resources = None + if worker_id is not None: + resources = resource_spec_from_scalars( + int(row["res_cpu_millicores"]), + int(row["res_memory_bytes"]), + int(row["res_disk_bytes"]), + row["res_device_json"], + ) + task_kill_workers[task_name] = WorkerId(str(worker_id)) + attempt_id = int(row["current_attempt_id"]) + tasks.terminate( + cur, + TaskTermination( + task_id=task_id, + state=job_pb2.TASK_STATE_KILLED, + now_ms=now_ms, + error=reason, + finalize=( + AttemptFinalizer.build(task_id, attempt_id, job_pb2.TASK_STATE_KILLED, now_ms, error=reason) + if attempt_id >= 0 + else None + ), + worker_id=str(worker_id) if worker_id is not None else None, + resources=resources, + ), + ) + tasks_to_kill.add(task_name) + + return KillResult(tasks_to_kill=frozenset(tasks_to_kill), task_kill_workers=task_kill_workers) + + def cascade_children( + self, + cur: Cursor, + tasks: TaskStore, + job_id: JobName, + reason: str, + now_ms: int, + *, + exclude_reservation_holders: bool = False, + ) -> KillResult: + """Kill descendant jobs (not the job itself) when a parent reaches terminal state. + + When exclude_reservation_holders is True, reservation holder jobs and their + descendants are left alive. Used during preemption retry so the parent's + reservation survives for re-scheduling. + """ + tasks_to_kill: set[JobName] = set() + task_kill_workers: dict[JobName, WorkerId] = {} + + if exclude_reservation_holders: + descendants = cur.execute( + "WITH RECURSIVE subtree(job_id) AS (" + " SELECT job_id FROM jobs WHERE parent_job_id = ? 
AND is_reservation_holder = 0 " + " UNION ALL " + " SELECT j.job_id FROM jobs j JOIN subtree s ON j.parent_job_id = s.job_id" + " WHERE j.is_reservation_holder = 0" + ") SELECT job_id FROM subtree", + (job_id.to_wire(),), + ).fetchall() + else: + descendants = cur.execute( + "WITH RECURSIVE subtree(job_id) AS (" + " SELECT job_id FROM jobs WHERE parent_job_id = ? " + " UNION ALL " + " SELECT j.job_id FROM jobs j JOIN subtree s ON j.parent_job_id = s.job_id" + ") SELECT job_id FROM subtree", + (job_id.to_wire(),), + ).fetchall() + + for child_row in descendants: + child_job_id = str(child_row["job_id"]) + child_result = self.kill_non_terminal_tasks(cur, tasks, child_job_id, reason, now_ms) + tasks_to_kill.update(child_result.tasks_to_kill) + task_kill_workers.update(child_result.task_kill_workers) + + terminal_placeholders = sql_placeholders(len(TERMINAL_JOB_STATES)) + cur.execute( + "UPDATE jobs SET state = ?, error = ?, finished_at_ms = COALESCE(finished_at_ms, ?) " + f"WHERE job_id = ? 
AND state NOT IN ({terminal_placeholders})", + ( + job_pb2.JOB_STATE_KILLED, + reason, + now_ms, + child_job_id, + *TERMINAL_JOB_STATES, + ), + ) + + return KillResult(tasks_to_kill=frozenset(tasks_to_kill), task_kill_workers=task_kill_workers) + + # ── Job submission and lifecycle ──────────────────────────────── + + def insert_job(self, cur: Cursor, job: JobInsert) -> None: + """Insert a row into the jobs table.""" + cur.execute( + "INSERT INTO jobs(" + "job_id, user_id, parent_job_id, root_job_id, depth, state, submitted_at_ms, " + "root_submitted_at_ms, started_at_ms, finished_at_ms, scheduling_deadline_epoch_ms, " + "error, exit_code, num_tasks, is_reservation_holder, name, has_reservation" + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, NULL, ?, ?, ?, NULL, ?, ?, ?, ?)", + ( + job.job_id, + job.user_id, + job.parent_job_id, + job.root_job_id, + job.depth, + job.state, + job.submitted_at_ms, + job.root_submitted_at_ms, + job.finished_at_ms, + job.scheduling_deadline_epoch_ms, + job.error, + job.num_tasks, + 1 if job.is_reservation_holder else 0, + job.name, + 1 if job.has_reservation else 0, + ), + ) + + def insert_job_config(self, cur: Cursor, cfg: JobConfigInsert) -> None: + """Insert a row into the job_config table and cache it on commit.""" + cur.execute( + "INSERT INTO job_config(" + "job_id, name, has_reservation, " + "res_cpu_millicores, res_memory_bytes, res_disk_bytes, res_device_json, " + "constraints_json, has_coscheduling, coscheduling_group_by, " + "scheduling_timeout_ms, max_task_failures, " + "entrypoint_json, environment_json, bundle_id, ports_json, " + "max_retries_failure, max_retries_preemption, timeout_ms, " + "preemption_policy, existing_job_policy, priority_band, " + "task_image, submit_argv_json, reservation_json, fail_if_exists" + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + cfg.job_id, + cfg.name, + 1 if cfg.has_reservation else 0, + cfg.resources.cpu_millicores, + cfg.resources.memory_bytes, + 
cfg.resources.disk_bytes, + cfg.resources.device_json, + cfg.constraints_json, + cfg.has_coscheduling, + cfg.coscheduling_group_by, + cfg.scheduling_timeout_ms, + cfg.max_task_failures, + cfg.entrypoint_json, + cfg.environment_json, + cfg.bundle_id, + cfg.ports_json, + cfg.max_retries_failure, + cfg.max_retries_preemption, + cfg.timeout_ms, + cfg.preemption_policy, + cfg.existing_job_policy, + cfg.priority_band, + cfg.task_image, + cfg.submit_argv_json, + cfg.reservation_json, + cfg.fail_if_exists, + ), + ) + # Populate the process-scoped cache on commit, constructing the typed + # row directly from the insert payload so we do not need a re-read. + cached = JobConfigRow( + job_id=JobName.from_wire(cfg.job_id), + name=cfg.name, + has_reservation=bool(cfg.has_reservation), + resources=cfg.resources, + constraints_json=cfg.constraints_json, + has_coscheduling=bool(cfg.has_coscheduling), + coscheduling_group_by=cfg.coscheduling_group_by, + scheduling_timeout_ms=cfg.scheduling_timeout_ms, + max_task_failures=cfg.max_task_failures, + entrypoint_json=cfg.entrypoint_json, + environment_json=cfg.environment_json, + bundle_id=cfg.bundle_id, + ports_json=cfg.ports_json, + max_retries_failure=cfg.max_retries_failure, + max_retries_preemption=cfg.max_retries_preemption, + timeout_ms=cfg.timeout_ms, + preemption_policy=cfg.preemption_policy, + existing_job_policy=cfg.existing_job_policy, + priority_band=cfg.priority_band, + task_image=cfg.task_image, + submit_argv_json=cfg.submit_argv_json, + reservation_json=cfg.reservation_json, + fail_if_exists=bool(cfg.fail_if_exists), + ) + + def apply() -> None: + with self._job_config_lock: + self._job_config_cache[cfg.job_id] = cached + + cur.on_commit(apply) + + def insert_workdir_files(self, cur: Cursor, job_id: str, files: list[tuple[str, bytes]]) -> None: + """Insert workdir file entries for a job.""" + for filename, data in files: + cur.execute( + "INSERT INTO job_workdir_files(job_id, filename, data) VALUES (?, ?, ?)", + (job_id, 
filename, data), + ) + + def get_workdir_files(self, cur: Cursor, job_id: str) -> dict[str, bytes]: + """Fetch workdir files for a job, keyed by filename.""" + rows = cur.execute( + "SELECT filename, data FROM job_workdir_files WHERE job_id = ?", + (job_id,), + ).fetchall() + return {str(row["filename"]): bytes(row["data"]) for row in rows} + + def exists(self, cur: Cursor, job_id: str) -> bool: + """Check whether a job row exists.""" + row = cur.execute("SELECT 1 FROM jobs WHERE job_id = ?", (job_id,)).fetchone() + return row is not None + + def get_root_submitted_ms(self, cur: Cursor, parent_job_id: str) -> int | None: + """Read root_submitted_at_ms for a parent job. Returns None if not found.""" + row = cur.execute( + "SELECT root_submitted_at_ms FROM jobs WHERE job_id = ?", + (parent_job_id,), + ).fetchone() + if row is None: + return None + return int(row["root_submitted_at_ms"]) + + def get_parent_band(self, cur: Cursor, parent_job_id: str) -> int | None: + """Read priority_band from the parent's first task. Returns None if not found.""" + row = cur.execute( + "SELECT priority_band FROM tasks WHERE job_id = ? LIMIT 1", + (parent_job_id,), + ).fetchone() + if row is None: + return None + return int(row["priority_band"]) + + def get_subtree_ids(self, cur: Cursor, job_id: str) -> list[str]: + """Recursive CTE returning all job IDs in the subtree rooted at job_id (inclusive).""" + rows = cur.execute( + "WITH RECURSIVE subtree(job_id) AS (" + " SELECT job_id FROM jobs WHERE job_id = ? " + " UNION ALL " + " SELECT j.job_id FROM jobs j JOIN subtree s ON j.parent_job_id = s.job_id" + ") SELECT job_id FROM subtree", + (job_id,), + ).fetchall() + return [str(row["job_id"]) for row in rows] + + def bulk_cancel(self, cur: Cursor, job_ids: list[str], reason: str, now_ms: int) -> None: + """Bulk UPDATE jobs to KILLED, skipping already-terminal jobs. + + Deliberately excludes JOB_STATE_WORKER_FAILED from the guard set so + worker-failed jobs can still be cancelled. 
+ """ + if not job_ids: + return + placeholders = sql_placeholders(len(job_ids)) + cancel_guard_states = TERMINAL_JOB_STATES - {job_pb2.JOB_STATE_WORKER_FAILED} + guard_placeholders = sql_placeholders(len(cancel_guard_states)) + cur.execute( + f"UPDATE jobs SET state = ?, error = ?, finished_at_ms = COALESCE(finished_at_ms, ?) " + f"WHERE job_id IN ({placeholders}) AND state NOT IN ({guard_placeholders})", + ( + job_pb2.JOB_STATE_KILLED, + reason, + now_ms, + *job_ids, + *cancel_guard_states, + ), + ) + + def start_if_pending(self, cur: Cursor, job_id: str, now_ms: int) -> None: + """Transition a job from PENDING to RUNNING. No-op if already started.""" + cur.execute( + "UPDATE jobs SET state = CASE WHEN state = ? THEN ? ELSE state END, " + "started_at_ms = COALESCE(started_at_ms, ?) WHERE job_id = ?", + (job_pb2.JOB_STATE_PENDING, job_pb2.JOB_STATE_RUNNING, now_ms, job_id), + ) + + def get_job_detail(self, cur: Cursor, job_id: str) -> JobDetailRow | None: + """Fetch full job detail with config join for scheduling/dispatch.""" + row = cur.execute( + f"SELECT {JOB_DETAIL_PROJECTION.select_clause()} " f"FROM jobs j {JOB_CONFIG_JOIN} WHERE j.job_id = ?", + (job_id,), + ).fetchone() + if row is None: + return None + return JOB_DETAIL_PROJECTION.decode_one([row]) + + @overload + def query(self, cur: Cursor, flt: JobDetailFilter, *, detail: Literal[False] = ...) -> list[JobRow]: ... + + @overload + def query(self, cur: Cursor, flt: JobDetailFilter, *, detail: Literal[True]) -> list[JobDetailRow]: ... + + def query( + self, + cur: Cursor, + flt: JobDetailFilter, + *, + detail: bool = False, + ) -> list[JobRow] | list[JobDetailRow]: + """Query jobs matching ``flt``. + + Executes inside the caller's cursor (read snapshot or write + transaction). ``detail=True`` selects the full :class:`JobDetailRow` + projection (with config join); ``False`` selects the lightweight + :class:`JobRow` projection. + + Large ``job_ids`` lists are chunked to respect SQLite's host-parameter cap. 
+ """ + if flt.job_ids is not None and not flt.job_ids: + return [] + projection = JOB_DETAIL_PROJECTION if detail else JOB_ROW_PROJECTION + + def fetch(sql: str, params: tuple[object, ...]) -> list[tuple]: + return cur.execute(sql, params).fetchall() + + chunks = chunk_ids(flt.job_ids) + return run_chunked( # type: ignore[return-value] + chunks, + flt.limit, + lambda chunk, limit: self._query_job_chunk(fetch, flt, projection, chunk_job_ids=chunk, limit=limit), + ) + + def _where_for(self, flt: JobDetailFilter, chunk_job_ids: tuple[str, ...] | None) -> tuple[str, list[object]]: + """Build the WHERE fragment and params for a JobDetailFilter. + + Returns ``("", [])`` when no predicates are set (no WHERE clause). + """ + wb = WhereBuilder() + wb.in_("j.job_id", chunk_job_ids) + if flt.states is not None: + wb.in_("j.state", tuple(sorted(flt.states))) + if flt.has_reservation is not None: + wb.eq("j.has_reservation", 1 if flt.has_reservation else 0) + where_sql, where_params = wb.build() + return where_sql, list(where_params) + + def _query_job_chunk( + self, + fetch: Callable[[str, tuple[object, ...]], list[tuple]], + flt: JobDetailFilter, + projection: Any, + *, + chunk_job_ids: tuple[str, ...] | None, + limit: int | None = None, + ) -> list: + effective_limit = limit if limit is not None else flt.limit + sql_parts = [f"SELECT {projection.select_clause()} FROM jobs j {JOB_CONFIG_JOIN}"] + where_clause, params = self._where_for(flt, chunk_job_ids) + if where_clause: + sql_parts.append(where_clause) + if effective_limit is not None: + sql_parts.append("LIMIT ?") + params.append(effective_limit) + rows = fetch(" ".join(sql_parts), tuple(params)) + return projection.decode(rows) + + def delete_job(self, cur: Cursor, job_id: str) -> None: + """DELETE FROM jobs WHERE job_id = ?. Cascades to tasks, attempts, endpoints, job_config. + + The DELETE cascades via FK to job_config; the in-memory cache entry + is popped on commit. 
+ """ + cur.execute("DELETE FROM jobs WHERE job_id = ?", (job_id,)) + + def apply() -> None: + with self._job_config_lock: + self._job_config_cache.pop(job_id, None) + + cur.on_commit(apply) + + def get_reservation_holder_ids(self, cur: Cursor, job_ids: set[str]) -> set[str]: + """Filter a set of job IDs to those that are reservation holders.""" + if not job_ids: + return set() + placeholders = sql_placeholders(len(job_ids)) + rows = cur.execute( + f"SELECT job_id FROM jobs WHERE job_id IN ({placeholders}) AND is_reservation_holder = 1", + tuple(job_ids), + ).fetchall() + return {str(r["job_id"]) for r in rows} + + def get_finished_jobs_before(self, cur: Cursor, cutoff_ms: int) -> list[str]: + """Return job_ids of terminal jobs finished before the cutoff, one at a time.""" + terminal_states = tuple(TERMINAL_JOB_STATES) + placeholders = sql_placeholders(len(terminal_states)) + row = cur.execute( + f"SELECT job_id FROM jobs WHERE state IN ({placeholders})" + " AND finished_at_ms IS NOT NULL AND finished_at_ms < ? LIMIT 1", + (*terminal_states, cutoff_ms), + ).fetchone() + if row is None: + return [] + return [str(row["job_id"])] + + # --------------------------------------------------------------------------- + + +# WorkerStore +# --------------------------------------------------------------------------- + + +class WorkerStore: + """Typed read/write operations for worker entities. + + Process-scoped: a single instance lives on the ``ControllerDB``. Every + write takes the open ``Cursor`` as its first argument. Owns + the lazy worker-attributes cache used by the scheduling hot path. + """ + + def __init__(self, endpoints: EndpointStore, dispatch: DispatchStore): + self._endpoints = endpoints + self._dispatch = dispatch + + # ── Reads ──────────────────────────────────────────────────────── + + def healthy_active_with_attributes(self, cur: Cursor) -> list[WorkerRow]: + """Fetch all healthy, active workers with their attributes and available resources. 
+ + Both the worker rows and their attributes are read through the caller's + cursor so the result is coherent with any outer snapshot/transaction. + """ + workers = WORKER_ROW_PROJECTION.decode( + cur.execute( + f"SELECT {WORKER_ROW_PROJECTION.select_clause()} " "FROM workers w WHERE w.healthy = 1 AND w.active = 1" + ).fetchall(), + ) + if not workers: + return [] + worker_ids = tuple(str(w.worker_id) for w in workers) + placeholders = sql_placeholders(len(worker_ids)) + attr_rows = decoded_rows( + cur, + f"SELECT worker_id, key, value_type, str_value, int_value, float_value " + f"FROM worker_attributes WHERE worker_id IN ({placeholders})", + worker_ids, + ) + attrs_by_worker = _decode_attribute_rows(attr_rows) + return [ + dc_replace( + w, + attributes=attrs_by_worker.get(w.worker_id, {}), + available_cpu_millicores=w.total_cpu_millicores - w.committed_cpu_millicores, + available_memory=w.total_memory_bytes - w.committed_mem, + available_gpus=w.total_gpu_count - w.committed_gpu, + available_tpus=w.total_tpu_count - w.committed_tpu, + ) + for w in workers + ] + + def query(self, cur: Cursor, flt: WorkerFilter) -> list[WorkerRow]: + """Return :class:`WorkerRow` instances matching ``flt``. + + Executes inside the caller's cursor. Large ``worker_ids`` lists are + chunked under SQLite's host-parameter cap. Worker attributes are NOT + loaded — use :meth:`healthy_active_with_attributes` when attributes + are needed. + """ + if flt.worker_ids is not None and not flt.worker_ids: + return [] + chunks = chunk_ids(flt.worker_ids) + return run_chunked( + chunks, + limit=None, + fetch=lambda chunk, _limit: self._query_worker_chunk(cur, flt, chunk_worker_ids=chunk), + ) + + def _query_worker_chunk( + self, + cur: Cursor, + flt: WorkerFilter, + *, + chunk_worker_ids: tuple[WorkerId, ...] 
| None, + ) -> list[WorkerRow]: + sql_parts = [f"SELECT {WORKER_ROW_PROJECTION.select_clause()} FROM workers w"] + wb = WhereBuilder() + if chunk_worker_ids is not None: + wb.in_("w.worker_id", tuple(str(wid) for wid in chunk_worker_ids)) + if flt.active is not None: + wb.eq("w.active", 1 if flt.active else 0) + if flt.healthy is not None: + wb.eq("w.healthy", 1 if flt.healthy else 0) + where_sql, params = wb.build() + if where_sql: + sql_parts.append(where_sql) + rows = cur.execute(" ".join(sql_parts), params).fetchall() + return WORKER_ROW_PROJECTION.decode(rows) + + # ── Writes ─────────────────────────────────────────────────────── + + def update_health_batch(self, cur: Cursor, requests: list[HeartbeatApplyRequest], now_ms: int) -> set[str]: + """Batch-update worker health, resource snapshots, and history. + + Returns the set of worker IDs that actually exist in the DB so callers + can skip updates from stale/removed workers. + """ + worker_ids = [str(req.worker_id) for req in requests] + if not worker_ids: + return set() + + placeholders = sql_placeholders(len(worker_ids)) + rows = cur.execute( + f"SELECT worker_id FROM workers WHERE worker_id IN ({placeholders})", + tuple(worker_ids), + ).fetchall() + existing = {str(r["worker_id"]) for r in rows} + + health_params_no_snap: list[tuple] = [] + health_params_with_snap: list[tuple] = [] + history_params: list[tuple] = [] + for req in requests: + wid = str(req.worker_id) + if wid not in existing: + continue + snap = req.worker_resource_snapshot + if snap is not None: + snap_fields = ( + snap.host_cpu_percent, + snap.memory_used_bytes, + snap.memory_total_bytes, + snap.disk_used_bytes, + snap.disk_total_bytes, + snap.running_task_count, + snap.total_process_count, + snap.net_recv_bps, + snap.net_sent_bps, + ) + health_params_with_snap.append((now_ms, *snap_fields, wid)) + history_params.append((wid, *snap_fields, now_ms)) + else: + health_params_no_snap.append((now_ms, wid)) + + if health_params_no_snap: + 
cur.executemany( + "UPDATE workers SET healthy = 1, active = 1, consecutive_failures = 0, " + "last_heartbeat_ms = ? WHERE worker_id = ?", + health_params_no_snap, + ) + if health_params_with_snap: + cur.executemany( + "UPDATE workers SET healthy = 1, active = 1, consecutive_failures = 0, " + "last_heartbeat_ms = ?, " + "snapshot_host_cpu_percent = ?, snapshot_memory_used_bytes = ?, " + "snapshot_memory_total_bytes = ?, snapshot_disk_used_bytes = ?, " + "snapshot_disk_total_bytes = ?, snapshot_running_task_count = ?, " + "snapshot_total_process_count = ?, snapshot_net_recv_bps = ?, " + "snapshot_net_sent_bps = ? WHERE worker_id = ?", + health_params_with_snap, + ) + if history_params: + cur.executemany( + "INSERT INTO worker_resource_history(" + "worker_id, snapshot_host_cpu_percent, snapshot_memory_used_bytes, " + "snapshot_memory_total_bytes, snapshot_disk_used_bytes, snapshot_disk_total_bytes, " + "snapshot_running_task_count, snapshot_total_process_count, " + "snapshot_net_recv_bps, snapshot_net_sent_bps, timestamp_ms" + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + history_params, + ) + return existing + + def record_heartbeat_failure(self, cur: Cursor, worker_id: WorkerId, failures: int, threshold: int) -> None: + """Increment consecutive_failures and mark unhealthy if threshold reached. + + The caller is responsible for reading the current failure count and + computing `failures` (old count + 1) before calling this method. + """ + cur.execute( + "UPDATE workers SET consecutive_failures = ?, " + "healthy = CASE WHEN ? >= ? 
THEN 0 ELSE healthy END " + "WHERE worker_id = ?", + (failures, failures, threshold, str(worker_id)), + ) + + def record_worker_task_history(self, cur: Cursor, worker_id: str, task_id: str, now_ms: int) -> None: + """Insert a worker_task_history row recording an assignment.""" + cur.execute( + "INSERT INTO worker_task_history(worker_id, task_id, assigned_at_ms) VALUES (?, ?, ?)", + (worker_id, task_id, now_ms), + ) + + def remove(self, cur: Cursor, worker_id: str) -> None: + """Remove a worker and sever all its foreign-key references. + + Nullifies worker_id in task_attempts and tasks, removes dispatch_queue + entries, and deletes the worker row. + """ + cur.execute( + "UPDATE task_attempts SET worker_id = NULL WHERE worker_id = ?", + (worker_id,), + ) + cur.execute( + "UPDATE tasks SET current_worker_id = NULL WHERE current_worker_id = ?", + (worker_id,), + ) + self._dispatch.delete_for_worker(cur, worker_id) + cur.execute("DELETE FROM workers WHERE worker_id = ?", (worker_id,)) + + def decommit_resources(self, cur: Cursor, worker_id: str, resources: job_pb2.ResourceSpecProto) -> None: + """Subtract a task's resource reservation from a worker, flooring at zero.""" + cur.execute( + "UPDATE workers SET committed_cpu_millicores = MAX(0, committed_cpu_millicores - ?), " + "committed_mem_bytes = MAX(0, committed_mem_bytes - ?), " + "committed_gpu = MAX(0, committed_gpu - ?), " + "committed_tpu = MAX(0, committed_tpu - ?) 
" + "WHERE worker_id = ?", + ( + int(resources.cpu_millicores), + int(resources.memory_bytes), + int(get_gpu_count(resources.device)), + int(get_tpu_count(resources.device)), + worker_id, + ), + ) + + def commit_resources(self, cur: Cursor, worker_id: str, resources: job_pb2.ResourceSpecProto) -> None: + """Add a task's resource reservation to a worker's committed totals.""" + cur.execute( + "UPDATE workers SET committed_cpu_millicores = committed_cpu_millicores + ?, " + "committed_mem_bytes = committed_mem_bytes + ?, " + "committed_gpu = committed_gpu + ?, " + "committed_tpu = committed_tpu + ? " + "WHERE worker_id = ?", + ( + int(resources.cpu_millicores), + int(resources.memory_bytes), + int(get_gpu_count(resources.device)), + int(get_tpu_count(resources.device)), + worker_id, + ), + ) + + def upsert(self, cur: Cursor, req: WorkerUpsert) -> None: + """Insert or update a worker row and replace its attributes. + + Performs the INSERT...ON CONFLICT UPDATE for the workers table, + then deletes and re-inserts all worker_attributes rows. + """ + md = req.metadata + cur.execute( + "INSERT INTO workers(" + "worker_id, address, healthy, active, consecutive_failures, last_heartbeat_ms, " + "committed_cpu_millicores, committed_mem_bytes, committed_gpu, committed_tpu, " + "total_cpu_millicores, total_memory_bytes, total_gpu_count, total_tpu_count, " + "device_type, device_variant, slice_id, scale_group, " + "md_hostname, md_ip_address, md_cpu_count, md_memory_bytes, md_disk_bytes, " + "md_tpu_name, md_tpu_worker_hostnames, md_tpu_worker_id, md_tpu_chips_per_host_bounds, " + "md_gpu_count, md_gpu_name, md_gpu_memory_mb, " + "md_gce_instance_name, md_gce_zone, md_git_hash, md_device_json" + ") VALUES (?, ?, 1, 1, 0, ?, 0, 0, 0, 0, ?, ?, ?, ?, ?, ?, ?, ?, " + "?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
" + "ON CONFLICT(worker_id) DO UPDATE SET " + "address=excluded.address, healthy=1, active=1, " + "consecutive_failures=0, last_heartbeat_ms=excluded.last_heartbeat_ms, " + "total_cpu_millicores=excluded.total_cpu_millicores, total_memory_bytes=excluded.total_memory_bytes, " + "total_gpu_count=excluded.total_gpu_count, total_tpu_count=excluded.total_tpu_count, " + "device_type=excluded.device_type, device_variant=excluded.device_variant, " + "slice_id=excluded.slice_id, scale_group=excluded.scale_group, " + "md_hostname=excluded.md_hostname, md_ip_address=excluded.md_ip_address, " + "md_cpu_count=excluded.md_cpu_count, md_memory_bytes=excluded.md_memory_bytes, " + "md_disk_bytes=excluded.md_disk_bytes, md_tpu_name=excluded.md_tpu_name, " + "md_tpu_worker_hostnames=excluded.md_tpu_worker_hostnames, " + "md_tpu_worker_id=excluded.md_tpu_worker_id, " + "md_tpu_chips_per_host_bounds=excluded.md_tpu_chips_per_host_bounds, " + "md_gpu_count=excluded.md_gpu_count, md_gpu_name=excluded.md_gpu_name, " + "md_gpu_memory_mb=excluded.md_gpu_memory_mb, " + "md_gce_instance_name=excluded.md_gce_instance_name, md_gce_zone=excluded.md_gce_zone, " + "md_git_hash=excluded.md_git_hash, md_device_json=excluded.md_device_json", + ( + req.worker_id, + req.address, + req.now_ms, + req.total_cpu_millicores, + req.total_memory_bytes, + req.total_gpu_count, + req.total_tpu_count, + req.device_type, + req.device_variant, + req.slice_id, + req.scale_group, + md.hostname, + md.ip_address, + md.cpu_count, + md.memory_bytes, + md.disk_bytes, + md.tpu_name, + md.tpu_worker_hostnames, + md.tpu_worker_id, + md.tpu_chips_per_host_bounds, + md.gpu_count, + md.gpu_name, + md.gpu_memory_mb, + md.gce_instance_name, + md.gce_zone, + md.git_hash, + md.device_json, + ), + ) + cur.execute("DELETE FROM worker_attributes WHERE worker_id = ?", (req.worker_id,)) + for key, value_type, str_value, int_value, float_value in req.attributes: + cur.execute( + "INSERT INTO worker_attributes(worker_id, key, value_type, 
str_value, int_value, float_value) " + "VALUES (?, ?, ?, ?, ?, ?)", + (req.worker_id, key, value_type, str_value, int_value, float_value), + ) + + def get_active_row(self, cur: Cursor, worker_id: str) -> WorkerActiveRow | None: + """Fetch consecutive_failures and last_heartbeat_ms for an active worker. + + Returns None if the worker doesn't exist or is inactive. + """ + row = cur.execute( + "SELECT consecutive_failures, last_heartbeat_ms FROM workers WHERE worker_id = ? AND active = 1", + (worker_id,), + ).fetchone() + if row is None: + return None + return WorkerActiveRow( + consecutive_failures=int(row["consecutive_failures"]), + last_heartbeat_ms=int(row["last_heartbeat_ms"]) if row["last_heartbeat_ms"] is not None else None, + ) + + def get_row(self, cur: Cursor, worker_id: str) -> WorkerDetailRow | None: + """Fetch the full worker row. Returns None if not found.""" + row = cur.execute( + f"SELECT {WORKER_DETAIL_PROJECTION.select_clause(prefix=False)} FROM workers WHERE worker_id = ?", + (worker_id,), + ).fetchone() + if row is None: + return None + return WORKER_DETAIL_PROJECTION.decode_one([row]) + + def get_healthy_active(self, cur: Cursor, worker_id: str) -> dict | None: + """Fetch worker_id and address for a healthy active worker. + + Returns None if the worker is missing, inactive, or unhealthy. + """ + row = cur.execute( + "SELECT worker_id, address FROM workers WHERE worker_id = ? 
AND active = 1 AND healthy = 1", + (worker_id,), + ).fetchone() + if row is None: + return None + return row + + def prune_task_history(self, cur: Cursor, retention: int) -> int: + """Trim worker_task_history to *retention* rows per worker.""" + return self._prune_per_worker_history( + cur, "worker_task_history", retention, order_by="assigned_at_ms DESC, id DESC" + ) + + def prune_resource_history(self, cur: Cursor, retention: int) -> int: + """Trim worker_resource_history to *retention* rows per worker.""" + return self._prune_per_worker_history(cur, "worker_resource_history", retention) + + def _prune_per_worker_history(self, cur: Cursor, table: str, retention: int, order_by: str = "id DESC") -> int: + """Trim a per-worker history table to *retention* rows per worker.""" + rows = cur.execute( + f"SELECT worker_id, COUNT(*) as cnt FROM {table} GROUP BY worker_id HAVING cnt > ?", + (retention,), + ).fetchall() + total_deleted = 0 + for row in rows: + wid = row["worker_id"] + cur.execute( + f"DELETE FROM {table} " + "WHERE worker_id = ? " + f"AND id NOT IN (" + f" SELECT id FROM {table} " + " WHERE worker_id = ? " + f" ORDER BY {order_by} LIMIT ?" + ")", + (wid, wid, retention), + ) + total_deleted += cur.rowcount + if total_deleted > 0: + logger.info("Pruned %d %s rows", total_deleted, table) + return total_deleted + + def get_inactive_worker_before(self, cur: Cursor, cutoff_ms: int) -> str | None: + """Return a single inactive/unhealthy worker_id with heartbeat before the cutoff.""" + row = cur.execute( + "SELECT worker_id FROM workers WHERE (active = 0 OR healthy = 0) AND last_heartbeat_ms < ? LIMIT 1", + (cutoff_ms,), + ).fetchone() + if row is None: + return None + return str(row["worker_id"]) + + +# --------------------------------------------------------------------------- +# DispatchStore +# --------------------------------------------------------------------------- + + +class DispatchStore: + """Typed operations for the dispatch_queue table. 
+ + Process-scoped: every method takes the open ``Cursor``. + Encapsulates enqueue, drain, and delete so callers don't scatter raw + dispatch_queue SQL. + """ + + def enqueue_run(self, cur: Cursor, worker_id: str, payload: bytes, now_ms: int) -> None: + """Queue a 'run' dispatch entry for delivery on the next heartbeat.""" + cur.execute( + "INSERT INTO dispatch_queue(worker_id, kind, payload_proto, task_id, created_at_ms) " + "VALUES (?, 'run', ?, NULL, ?)", + (worker_id, payload, now_ms), + ) + + def enqueue_kill(self, cur: Cursor, worker_id: str | None, task_id: str, now_ms: int) -> None: + """Queue a 'kill' dispatch entry for delivery on the next heartbeat.""" + cur.execute( + "INSERT INTO dispatch_queue(worker_id, kind, payload_proto, task_id, created_at_ms) " + "VALUES (?, 'kill', NULL, ?, ?)", + (worker_id, task_id, now_ms), + ) + + def drain_for_worker(self, cur: Cursor, worker_id: str) -> list[tuple[str, bytes | None, str | None]]: + """SELECT and DELETE dispatch rows for one worker. + + Returns list of (kind, payload_proto, task_id) tuples ordered by id ASC. + """ + rows = cur.execute( + "SELECT kind, payload_proto, task_id FROM dispatch_queue WHERE worker_id = ? ORDER BY id ASC", + (worker_id,), + ).fetchall() + if rows: + cur.execute("DELETE FROM dispatch_queue WHERE worker_id = ?", (worker_id,)) + return [(str(r["kind"]), r["payload_proto"], r["task_id"]) for r in rows] + + def drain_for_workers( + self, cur: Cursor, worker_ids: list[str] + ) -> dict[str, list[tuple[str, bytes | None, str | None]]]: + """Batch drain dispatch rows for multiple workers. + + Returns a dict mapping worker_id to list of (kind, payload_proto, task_id) tuples. 
+ """ + if not worker_ids: + return {} + placeholders = sql_placeholders(len(worker_ids)) + rows = cur.execute( + f"SELECT worker_id, kind, payload_proto, task_id FROM dispatch_queue " + f"WHERE worker_id IN ({placeholders}) ORDER BY id ASC", + tuple(worker_ids), + ).fetchall() + if rows: + cur.execute( + f"DELETE FROM dispatch_queue WHERE worker_id IN ({placeholders})", + tuple(worker_ids), + ) + result: dict[str, list[tuple[str, bytes | None, str | None]]] = {} + for r in rows: + wid = str(r["worker_id"]) + if wid not in result: + result[wid] = [] + result[wid].append((str(r["kind"]), r["payload_proto"], r["task_id"])) + return result + + def drain_direct_kills(self, cur: Cursor) -> list[str]: + """Drain NULL-worker kill entries. Returns list of task_ids.""" + rows = cur.execute( + "SELECT task_id FROM dispatch_queue WHERE worker_id IS NULL AND kind = 'kill'", + ).fetchall() + task_ids = [str(r["task_id"]) for r in rows if r["task_id"] is not None] + if rows: + cur.execute("DELETE FROM dispatch_queue WHERE worker_id IS NULL AND kind = 'kill'") + return task_ids + + def delete_for_worker(self, cur: Cursor, worker_id: str) -> None: + """Delete all dispatch entries for a worker.""" + cur.execute("DELETE FROM dispatch_queue WHERE worker_id = ?", (worker_id,)) + + def replace_claims(self, cur: Cursor, claims: dict[WorkerId, tuple[str, int]]) -> None: + """Replace all reservation claims atomically. + + Args: + claims: Mapping of worker_id -> (job_id, entry_idx). 
+ """ + cur.execute("DELETE FROM reservation_claims") + cur.executemany( + "INSERT INTO reservation_claims(worker_id, job_id, entry_idx) VALUES (?, ?, ?)", + [(str(worker_id), job_id, entry_idx) for worker_id, (job_id, entry_idx) in claims.items()], + ) + + +# --------------------------------------------------------------------------- +# UserStore +# --------------------------------------------------------------------------- + + +class UserStore: + """User and budget table operations.""" + + def ensure_user_and_budget( + self, + cur: Cursor, + user: str, + now_ms: int, + budget_defaults: UserBudgetDefaults, + ) -> None: + """Create user and default budget row if they don't already exist.""" + cur.execute( + "INSERT OR IGNORE INTO users(user_id, created_at_ms) VALUES (?, ?)", + (user, now_ms), + ) + cur.execute( + "INSERT OR IGNORE INTO user_budgets(user_id, budget_limit, max_band, updated_at_ms) " "VALUES (?, ?, ?, ?)", + (user, budget_defaults.budget_limit, budget_defaults.max_band, now_ms), + ) + + +# --------------------------------------------------------------------------- +# ControllerStore — transaction-scoped handle bundling cursor + store access. +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True, slots=True) +class ControllerStore: + """Transaction-scoped handle bundling the cursor with typed store access. + + Stores are process-scoped; ControllerStore is the thin per-transaction view + binding them to the active cursor. State-machine code takes ControllerStore + as its single DB argument. + """ + + cur: Cursor + tasks: TaskStore + jobs: JobStore + workers: WorkerStore + endpoints: EndpointStore + dispatch: DispatchStore + users: UserStore + + +# --------------------------------------------------------------------------- +# ControllerStores — process-scoped bundle owning the stores (and the db). 
+# --------------------------------------------------------------------------- + + +@dataclass(frozen=True, slots=True) +class ControllerStores: + """Process-scoped bundle of stores, constructed on top of a ``ControllerDB``. + + This is the single ownership layer for domain access. ``ControllerDB`` knows + nothing about stores; callers construct a ``ControllerDB`` (pure infra) and + then wrap it with ``ControllerStores.from_db(db)``. + """ + + db: DbBackend + endpoints: EndpointStore + dispatch: DispatchStore + users: UserStore + workers: WorkerStore + jobs: JobStore + tasks: TaskStore + + @classmethod + def from_db(cls, db: DbBackend) -> ControllerStores: + # Construction order encodes the store dependency graph: + # endpoints, dispatch, users: no dependencies (leaf caches/stores) + # workers depends on endpoints (for _remove_task_endpoints on dead + # workers) and dispatch (to drain pending dispatches on remove) + # jobs depends on endpoints (to cascade endpoint cleanup on job kill) + # tasks depends on endpoints (same reason) and jobs (to look up the + # job row while transitioning a task) + # Adding a new store? Place it by its inward-edge count — leaves first. 
+ endpoints = EndpointStore() + with db.read_snapshot() as snap: + endpoints._load_all(snap) + dispatch = DispatchStore() + users = UserStore() + workers = WorkerStore(endpoints, dispatch) + jobs = JobStore(endpoints) + tasks = TaskStore(endpoints, jobs) + return cls( + db=db, + endpoints=endpoints, + dispatch=dispatch, + users=users, + workers=workers, + jobs=jobs, + tasks=tasks, + ) + + @contextmanager + def transact(self) -> Iterator[ControllerStore]: + """Open an IMMEDIATE transaction and yield a per-txn ``ControllerStore``.""" + with self.db.transaction() as cur: + yield ControllerStore( + cur=cur, + tasks=self.tasks, + jobs=self.jobs, + workers=self.workers, + endpoints=self.endpoints, + dispatch=self.dispatch, + users=self.users, + ) + + @contextmanager + def read(self) -> Iterator[ControllerStore]: + """Open a read-only snapshot and yield a ControllerStore bound to it. + + Store methods that only read can be called against ``ctx.cur`` inside + this scope. Write methods will fail at runtime (QuerySnapshot lacks + on_commit/rowcount) — that's intentional. 
+ """ + with self.db.read_snapshot() as snap: + yield ControllerStore( + cur=snap, + tasks=self.tasks, + jobs=self.jobs, + workers=self.workers, + endpoints=self.endpoints, + dispatch=self.dispatch, + users=self.users, + ) diff --git a/lib/iris/src/iris/cluster/controller/transitions.py b/lib/iris/src/iris/cluster/controller/transitions.py index c26df8652d..a0adbb8807 100644 --- a/lib/iris/src/iris/cluster/controller/transitions.py +++ b/lib/iris/src/iris/cluster/controller/transitions.py @@ -13,7 +13,6 @@ import json import logging from dataclasses import dataclass, field -from collections.abc import Callable, Iterable from typing import Any, NamedTuple from iris.cluster.constraints import AttributeValue, Constraint, constraints_from_resources, merge_constraints @@ -28,23 +27,44 @@ resource_spec_from_scalars, ) from iris.cluster.controller.db import ( - ACTIVE_TASK_STATES, - EXECUTING_TASK_STATES, FAILURE_TASK_STATES, - ControllerDB, - TransactionCursor, - task_row_can_be_scheduled, - task_row_is_finished, + batch_delete, ) from iris.cluster.controller.schema import ( - JOB_CONFIG_JOIN, - JOB_DETAIL_PROJECTION, + ACTIVE_TASK_STATES, + EXECUTING_TASK_STATES, TASK_DETAIL_PROJECTION, WORKER_DETAIL_PROJECTION, EndpointRow, JobDetailRow, + ResourceSpec, + TaskDetailRow, WorkerDetailRow, ) +from iris.cluster.controller.store import ( + ActiveStateUpdate, + AttemptFinalizer, + ControllerStore, + ControllerStores, + DirectAssignment, + HeartbeatApplyRequest, + JobConfigInsert, + JobInsert, + KillResult, + SiblingSnapshot, + TaskFilter, + TaskInsert, + TaskProjection, + TaskRetry, + TaskTermination, + TaskUpdate, + WorkerAssignment, + WorkerMetadata, + WorkerUpsert, + sql_placeholders, + task_row_can_be_scheduled, + task_row_is_finished, +) from iris.cluster.types import ( TERMINAL_JOB_STATES, TERMINAL_TASK_STATES, @@ -174,28 +194,6 @@ class WorkerConfig: metadata: job_pb2.WorkerMetadata -@dataclass(frozen=True) -class TaskUpdate: - """Single task state update applied in a 
batch.""" - - task_id: JobName - attempt_id: int - new_state: int - error: str | None = None - exit_code: int | None = None - resource_usage: job_pb2.ResourceUsage | None = None - container_id: str | None = None - - -@dataclass(frozen=True) -class HeartbeatApplyRequest: - """Batch of worker heartbeat updates applied atomically.""" - - worker_id: WorkerId - worker_resource_snapshot: job_pb2.WorkerResourceSnapshot | None - updates: list[TaskUpdate] - - @dataclass(frozen=True) class Assignment: """Scheduler assignment decision.""" @@ -322,411 +320,6 @@ def _has_reservation_flag(request: controller_pb2.Controller.LaunchJobRequest) - return 1 if request.HasField("reservation") and request.reservation.entries else 0 -def delete_task_endpoints(cur: TransactionCursor, registry, task_id: str) -> None: - """Remove all registered endpoints for a task through the endpoint registry.""" - registry.remove_by_task(cur, JobName.from_wire(task_id)) - - -def enqueue_run_dispatch( - cur: TransactionCursor, - worker_id: str, - payload_proto: bytes, - now_ms: int, -) -> None: - """Queue a 'run' dispatch entry for delivery on the next heartbeat.""" - cur.execute( - "INSERT INTO dispatch_queue(worker_id, kind, payload_proto, task_id, created_at_ms) " - "VALUES (?, 'run', ?, NULL, ?)", - (worker_id, payload_proto, now_ms), - ) - - -def enqueue_kill_dispatch( - cur: TransactionCursor, - worker_id: str | None, - task_id: str, - now_ms: int, -) -> None: - """Queue a 'kill' dispatch entry for delivery on the next heartbeat.""" - cur.execute( - "INSERT INTO dispatch_queue(worker_id, kind, payload_proto, task_id, created_at_ms) " - "VALUES (?, 'kill', NULL, ?, ?)", - (worker_id, task_id, now_ms), - ) - - -def insert_task_attempt( - cur: TransactionCursor, - task_id: str, - attempt_id: int, - worker_id: str | None, - state: int, - now_ms: int, -) -> None: - """Record a new task attempt row.""" - cur.execute( - "INSERT INTO task_attempts(task_id, attempt_id, worker_id, state, created_at_ms) " 
"VALUES (?, ?, ?, ?, ?)", - (task_id, attempt_id, worker_id, state, now_ms), - ) - - -def _decommit_worker_resources( - cur: TransactionCursor, - worker_id: str, - resources: "job_pb2.ResourceSpecProto", -) -> None: - """Subtract a task's resource reservation from a worker, flooring at zero.""" - cur.execute( - "UPDATE workers SET committed_cpu_millicores = MAX(0, committed_cpu_millicores - ?), " - "committed_mem_bytes = MAX(0, committed_mem_bytes - ?), " - "committed_gpu = MAX(0, committed_gpu - ?), committed_tpu = MAX(0, committed_tpu - ?) " - "WHERE worker_id = ?", - ( - int(resources.cpu_millicores), - int(resources.memory_bytes), - int(get_gpu_count(resources.device)), - int(get_tpu_count(resources.device)), - worker_id, - ), - ) - - -def _remove_worker(cur: TransactionCursor, worker_id: str) -> None: - """Remove a worker and sever all its foreign-key references. - - Must be called inside an existing transaction. The four statements - enforce the multi-table invariant: no dangling worker_id references - remain in task_attempts, tasks, or dispatch_queue after the worker - row is deleted. - """ - cur.execute("UPDATE task_attempts SET worker_id = NULL WHERE worker_id = ?", (worker_id,)) - cur.execute("UPDATE tasks SET current_worker_id = NULL WHERE current_worker_id = ?", (worker_id,)) - cur.execute("DELETE FROM dispatch_queue WHERE worker_id = ?", (worker_id,)) - cur.execute("DELETE FROM workers WHERE worker_id = ?", (worker_id,)) - - -def _assign_task( - cur: TransactionCursor, - task_id: str, - worker_id: str | None, - worker_address: str | None, - attempt_id: int, - now_ms: int, -) -> None: - """Create an attempt and mark a task as ASSIGNED in one consistent step. - - worker_id may be None for direct-provider tasks that have no backing - worker daemon. 
- """ - insert_task_attempt(cur, task_id, attempt_id, worker_id, job_pb2.TASK_STATE_ASSIGNED, now_ms) - if worker_id is not None: - cur.execute( - "UPDATE tasks SET state = ?, current_attempt_id = ?, " - "current_worker_id = ?, current_worker_address = ?, " - "started_at_ms = COALESCE(started_at_ms, ?) WHERE task_id = ?", - (job_pb2.TASK_STATE_ASSIGNED, attempt_id, worker_id, worker_address, now_ms, task_id), - ) - else: - cur.execute( - "UPDATE tasks SET state = ?, current_attempt_id = ?, " - "started_at_ms = COALESCE(started_at_ms, ?) WHERE task_id = ?", - (job_pb2.TASK_STATE_ASSIGNED, attempt_id, now_ms, task_id), - ) - - -def _terminate_task( - cur: TransactionCursor, - registry, - task_id: str, - attempt_id: int | None, - state: int, - error: str | None, - now_ms: int, - *, - attempt_state: int | None = None, - worker_id: str | None = None, - resources: "job_pb2.ResourceSpecProto | None" = None, - failure_count: int | None = None, - preemption_count: int | None = None, -) -> None: - """Move a task (and its current attempt) out of active state consistently. - - Enforces the multi-table invariant: attempt is marked terminal, - task state/error/finished_at are updated, endpoints are deleted, - and worker resources are released. - - ``attempt_state`` overrides the state written to the attempt row when it - differs from the task state (e.g. attempt=WORKER_FAILED while task retries - to PENDING). Defaults to ``state`` when not provided. - - attempt_id < 0 means no attempt exists; the attempt UPDATE is skipped. - """ - finished_at_ms = None if state in ACTIVE_TASK_STATES or state == job_pb2.TASK_STATE_PENDING else now_ms - effective_attempt_state = attempt_state if attempt_state is not None else state - - if attempt_id is not None and attempt_id >= 0: - cur.execute( - "UPDATE task_attempts SET state = ?, " - "finished_at_ms = COALESCE(finished_at_ms, ?), error = ? " - "WHERE task_id = ? 
AND attempt_id = ?", - (effective_attempt_state, now_ms, error, task_id, attempt_id), - ) - - # Build the UPDATE tasks statement dynamically based on optional counters. - # Use COALESCE for finished_at_ms when non-NULL to preserve any existing - # timestamp (defensive against double-termination). When NULL (retrying to - # PENDING), assign directly so the column is cleared. - if finished_at_ms is not None: - set_clauses = ["state = ?", "error = ?", "finished_at_ms = COALESCE(finished_at_ms, ?)"] - else: - set_clauses = ["state = ?", "error = ?", "finished_at_ms = ?"] - params: list[object] = [state, error, finished_at_ms] - - if failure_count is not None: - set_clauses.append("failure_count = ?") - params.append(failure_count) - if preemption_count is not None: - set_clauses.append("preemption_count = ?") - params.append(preemption_count) - - # Always clear worker columns when leaving active state. - if state not in ACTIVE_TASK_STATES: - set_clauses.append("current_worker_id = NULL") - set_clauses.append("current_worker_address = NULL") - - params.append(task_id) - cur.execute( - f"UPDATE tasks SET {', '.join(set_clauses)} WHERE task_id = ?", - tuple(params), - ) - - delete_task_endpoints(cur, registry, task_id) - - if worker_id is not None and resources is not None: - _decommit_worker_resources(cur, worker_id, resources) - - -def _kill_non_terminal_tasks( - cur: Any, - registry, - job_id_wire: str, - reason: str, - now_ms: int, -) -> tuple[set[JobName], dict[JobName, WorkerId]]: - """Kill all non-terminal tasks for a single job, decommit resources, and delete endpoints.""" - terminal_states = tuple(sorted(TERMINAL_TASK_STATES)) - placeholders = ",".join("?" * len(terminal_states)) - rows = cur.execute( - "SELECT t.task_id, t.current_attempt_id, t.current_worker_id, " - "jc.res_cpu_millicores, jc.res_memory_bytes, jc.res_disk_bytes, jc.res_device_json " - "FROM tasks t " - "JOIN jobs j ON j.job_id = t.job_id " - f"{JOB_CONFIG_JOIN} " - f"WHERE t.job_id = ? 
AND t.state NOT IN ({placeholders})", - (job_id_wire, *terminal_states), - ).fetchall() - tasks_to_kill: set[JobName] = set() - task_kill_workers: dict[JobName, WorkerId] = {} - for row in rows: - task_id = str(row["task_id"]) - worker_id = row["current_worker_id"] - task_name = JobName.from_wire(task_id) - resources = None - if worker_id is not None: - resources = resource_spec_from_scalars( - int(row["res_cpu_millicores"]), - int(row["res_memory_bytes"]), - int(row["res_disk_bytes"]), - row["res_device_json"], - ) - task_kill_workers[task_name] = WorkerId(str(worker_id)) - _terminate_task( - cur, - registry, - task_id, - int(row["current_attempt_id"]), - job_pb2.TASK_STATE_KILLED, - reason, - now_ms, - worker_id=str(worker_id) if worker_id is not None else None, - resources=resources, - ) - tasks_to_kill.add(task_name) - return tasks_to_kill, task_kill_workers - - -def _cascade_children( - cur: Any, - registry, - job_id: JobName, - now_ms: int, - reason: str, - exclude_reservation_holders: bool = False, -) -> tuple[set[JobName], dict[JobName, WorkerId]]: - """Kill descendant jobs (not the job itself) when a parent reaches terminal state or is preempted. - - When exclude_reservation_holders is True, reservation holder jobs and their - descendants are left alive. This is used during preemption retry: the parent - goes back to PENDING and needs its reservation to survive so the scheduler - can re-satisfy it. - """ - tasks_to_kill: set[JobName] = set() - task_kill_workers: dict[JobName, WorkerId] = {} - - if exclude_reservation_holders: - # Skip reservation holder jobs and anything below them. - descendants = cur.execute( - "WITH RECURSIVE subtree(job_id) AS (" - " SELECT job_id FROM jobs WHERE parent_job_id = ? 
AND is_reservation_holder = 0 " - " UNION ALL " - " SELECT j.job_id FROM jobs j JOIN subtree s ON j.parent_job_id = s.job_id" - " WHERE j.is_reservation_holder = 0" - ") SELECT job_id FROM subtree", - (job_id.to_wire(),), - ).fetchall() - else: - descendants = cur.execute( - "WITH RECURSIVE subtree(job_id) AS (" - " SELECT job_id FROM jobs WHERE parent_job_id = ? " - " UNION ALL " - " SELECT j.job_id FROM jobs j JOIN subtree s ON j.parent_job_id = s.job_id" - ") SELECT job_id FROM subtree", - (job_id.to_wire(),), - ).fetchall() - for child_row in descendants: - child_job_id = str(child_row["job_id"]) - child_tasks_to_kill, child_task_kill_workers = _kill_non_terminal_tasks( - cur, registry, child_job_id, reason, now_ms - ) - tasks_to_kill.update(child_tasks_to_kill) - task_kill_workers.update(child_task_kill_workers) - terminal_placeholders = ",".join("?" for _ in TERMINAL_JOB_STATES) - cur.execute( - "UPDATE jobs SET state = ?, error = ?, finished_at_ms = COALESCE(finished_at_ms, ?) " - f"WHERE job_id = ? 
AND state NOT IN ({terminal_placeholders})", - ( - job_pb2.JOB_STATE_KILLED, - reason, - now_ms, - child_job_id, - *TERMINAL_JOB_STATES, - ), - ) - return tasks_to_kill, task_kill_workers - - -def _cascade_terminal_job( - cur: Any, - registry, - job_id: JobName, - now_ms: int, - reason: str, -) -> tuple[set[JobName], dict[JobName, WorkerId]]: - """Kill remaining tasks and descendant jobs when a job reaches a terminal state.""" - tasks_to_kill, task_kill_workers = _kill_non_terminal_tasks(cur, registry, job_id.to_wire(), reason, now_ms) - child_tasks_to_kill, child_task_kill_workers = _cascade_children(cur, registry, job_id, now_ms, reason) - tasks_to_kill.update(child_tasks_to_kill) - task_kill_workers.update(child_task_kill_workers) - return tasks_to_kill, task_kill_workers - - -@dataclass(frozen=True, slots=True) -class _CoscheduledSibling: - task_id: str # wire format - attempt_id: int - max_retries_preemption: int - worker_id: str | None - - -def _find_coscheduled_siblings( - cur: Any, - job_id: JobName, - exclude_task_id: JobName, - has_coscheduling: bool, -) -> list[_CoscheduledSibling]: - """Find active siblings in a coscheduled job (read-only).""" - if not has_coscheduling: - return [] - rows = cur.execute( - "SELECT t.task_id, t.current_attempt_id, t.max_retries_preemption, " - "t.current_worker_id AS worker_id " - "FROM tasks t " - "WHERE t.job_id = ? AND t.task_id != ? 
AND t.state IN (?, ?, ?)", - ( - job_id.to_wire(), - exclude_task_id.to_wire(), - job_pb2.TASK_STATE_ASSIGNED, - job_pb2.TASK_STATE_BUILDING, - job_pb2.TASK_STATE_RUNNING, - ), - ).fetchall() - return [ - _CoscheduledSibling( - task_id=str(r["task_id"]), - attempt_id=int(r["current_attempt_id"]), - max_retries_preemption=int(r["max_retries_preemption"]), - worker_id=str(r["worker_id"]) if r["worker_id"] is not None else None, - ) - for r in rows - ] - - -def _terminate_coscheduled_siblings( - cur: Any, - registry, - siblings: Iterable[_CoscheduledSibling], - failed_task_id: JobName, - resources: "job_pb2.ResourceSpecProto", - now_ms: int, -) -> tuple[set[JobName], dict[JobName, WorkerId]]: - """Terminate coscheduled siblings and decommit their resources. - - Each sibling is marked WORKER_FAILED with exhausted preemption count so it - will not be retried. - """ - tasks_to_kill: set[JobName] = set() - task_kill_workers: dict[JobName, WorkerId] = {} - error = f"Coscheduled sibling {failed_task_id.to_wire()} failed" - - for sib in siblings: - _terminate_task( - cur, - registry, - sib.task_id, - sib.attempt_id, - job_pb2.TASK_STATE_WORKER_FAILED, - error, - now_ms, - worker_id=sib.worker_id, - resources=resources if sib.worker_id is not None else None, - preemption_count=sib.max_retries_preemption + 1, - ) - if sib.worker_id is not None: - task_kill_workers[JobName.from_wire(sib.task_id)] = WorkerId(sib.worker_id) - tasks_to_kill.add(JobName.from_wire(sib.task_id)) - - return tasks_to_kill, task_kill_workers - - -def _resolve_preemption_policy(cur: Any, job_id: JobName) -> int: - """Resolve the effective preemption policy for a job. - - Defaults: single-task jobs → TERMINATE_CHILDREN, multi-task → PRESERVE_CHILDREN. 
- """ - row = cur.execute( - f"SELECT jc.preemption_policy, j.num_tasks FROM jobs j {JOB_CONFIG_JOIN} WHERE j.job_id = ?", - (job_id.to_wire(),), - ).fetchone() - if row is None: - return job_pb2.JOB_PREEMPTION_POLICY_TERMINATE_CHILDREN - policy = int(row["preemption_policy"]) - if policy != job_pb2.JOB_PREEMPTION_POLICY_UNSPECIFIED: - return policy - if int(row["num_tasks"]) <= 1: - return job_pb2.JOB_PREEMPTION_POLICY_TERMINATE_CHILDREN - return job_pb2.JOB_PREEMPTION_POLICY_PRESERVE_CHILDREN - - _TERMINAL_STATE_REASONS: dict[int, str] = { job_pb2.JOB_STATE_FAILED: "Job exceeded max_task_failures", job_pb2.JOB_STATE_KILLED: "Job was terminated.", @@ -736,15 +329,14 @@ def _resolve_preemption_policy(cur: Any, job_id: JobName) -> int: def _finalize_terminal_job( - cur: Any, - registry, + ctx: ControllerStore, job_id: JobName, terminal_state: int, now_ms: int, -) -> tuple[set[JobName], dict[JobName, WorkerId]]: +) -> KillResult: """Kill remaining tasks and optionally cascade to children when a job goes terminal. - Called after _recompute_job_state determines a job has reached a terminal + Called after recompute_state determines a job has reached a terminal state. Kills the job's own non-terminal tasks and, depending on preemption policy, cascades to descendant jobs. @@ -752,16 +344,18 @@ def _finalize_terminal_job( Non-succeeded jobs cascade only if the preemption policy is TERMINATE_CHILDREN. 
""" reason = _TERMINAL_STATE_REASONS.get(terminal_state, "Job finalized") - tasks_to_kill, task_kill_workers = _kill_non_terminal_tasks(cur, registry, job_id.to_wire(), reason, now_ms) + result = ctx.jobs.kill_non_terminal_tasks(ctx.cur, ctx.tasks, job_id.to_wire(), reason, now_ms) + tasks_to_kill = set(result.tasks_to_kill) + task_kill_workers = dict(result.task_kill_workers) should_cascade = True if terminal_state != job_pb2.JOB_STATE_SUCCEEDED: - policy = _resolve_preemption_policy(cur, job_id) + policy = ctx.jobs.get_preemption_policy(ctx.cur, job_id) should_cascade = policy == job_pb2.JOB_PREEMPTION_POLICY_TERMINATE_CHILDREN if should_cascade: - child_tasks_to_kill, child_task_kill_workers = _cascade_children(cur, registry, job_id, now_ms, reason) - tasks_to_kill.update(child_tasks_to_kill) - task_kill_workers.update(child_task_kill_workers) - return tasks_to_kill, task_kill_workers + child_result = ctx.jobs.cascade_children(ctx.cur, ctx.tasks, job_id, reason, now_ms) + tasks_to_kill.update(child_result.tasks_to_kill) + task_kill_workers.update(child_result.task_kill_workers) + return KillResult(tasks_to_kill=frozenset(tasks_to_kill), task_kill_workers=task_kill_workers) def _resolve_task_failure_state( @@ -786,102 +380,6 @@ def _resolve_task_failure_state( return terminal_state, preemption_count -# ============================================================================= -# Batch helpers for apply_heartbeats_batch -# ============================================================================= - - -def _batch_worker_health( - cur: TransactionCursor, - requests: list["HeartbeatApplyRequest"], - now_ms: int, -) -> set[str]: - """Batch-update worker health, resource snapshots, and history. - - Returns the set of worker IDs that actually exist in the DB so callers - can skip updates from stale/removed workers. - """ - worker_ids = [str(req.worker_id) for req in requests] - if not worker_ids: - return set() - - placeholders = ",".join("?" 
* len(worker_ids)) - rows = cur.execute( - f"SELECT worker_id FROM workers WHERE worker_id IN ({placeholders})", - tuple(worker_ids), - ).fetchall() - existing = {str(r["worker_id"]) for r in rows} - - health_params_no_snap = [] - health_params_with_snap = [] - history_params = [] - for req in requests: - wid = str(req.worker_id) - if wid not in existing: - continue - snap = req.worker_resource_snapshot - if snap is not None: - snap_fields = ( - snap.host_cpu_percent, - snap.memory_used_bytes, - snap.memory_total_bytes, - snap.disk_used_bytes, - snap.disk_total_bytes, - snap.running_task_count, - snap.total_process_count, - snap.net_recv_bps, - snap.net_sent_bps, - ) - health_params_with_snap.append((now_ms, *snap_fields, wid)) - history_params.append((wid, *snap_fields, now_ms)) - else: - health_params_no_snap.append((now_ms, wid)) - - if health_params_no_snap: - cur.executemany( - "UPDATE workers SET healthy = 1, active = 1, consecutive_failures = 0, " - "last_heartbeat_ms = ? WHERE worker_id = ?", - health_params_no_snap, - ) - if health_params_with_snap: - cur.executemany( - "UPDATE workers SET healthy = 1, active = 1, consecutive_failures = 0, " - "last_heartbeat_ms = ?, " - "snapshot_host_cpu_percent = ?, snapshot_memory_used_bytes = ?, " - "snapshot_memory_total_bytes = ?, snapshot_disk_used_bytes = ?, " - "snapshot_disk_total_bytes = ?, snapshot_running_task_count = ?, " - "snapshot_total_process_count = ?, snapshot_net_recv_bps = ?, " - "snapshot_net_sent_bps = ? 
WHERE worker_id = ?", - health_params_with_snap, - ) - if history_params: - cur.executemany( - "INSERT INTO worker_resource_history(" - "worker_id, snapshot_host_cpu_percent, snapshot_memory_used_bytes, " - "snapshot_memory_total_bytes, snapshot_disk_used_bytes, snapshot_disk_total_bytes, " - "snapshot_running_task_count, snapshot_total_process_count, " - "snapshot_net_recv_bps, snapshot_net_sent_bps, timestamp_ms" - ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - history_params, - ) - return existing - - -def _bulk_fetch_tasks(cur: TransactionCursor, task_ids: list[str]) -> dict[str, Any]: - """Fetch task rows for all given IDs in chunked IN queries.""" - result: dict[str, Any] = {} - for chunk_start in range(0, len(task_ids), 900): - chunk = task_ids[chunk_start : chunk_start + 900] - ph = ",".join("?" * len(chunk)) - rows = cur.execute( - f"SELECT * FROM tasks WHERE task_id IN ({ph})", - tuple(chunk), - ).fetchall() - for r in rows: - result[str(r["task_id"])] = r - return result - - # ============================================================================= # Controller Transitions # ============================================================================= @@ -900,11 +398,12 @@ class ControllerTransitions: def __init__( self, - db: ControllerDB, + stores: ControllerStores, heartbeat_failure_threshold: int = HEARTBEAT_FAILURE_THRESHOLD, user_budget_defaults: UserBudgetDefaults | None = None, ): - self._db = db + self._stores = stores + self._db = stores.db # infra calls (read_snapshot, wal_checkpoint) self._heartbeat_failure_threshold = heartbeat_failure_threshold self._user_budget_defaults = user_budget_defaults or UserBudgetDefaults() @@ -928,93 +427,10 @@ def _record_transaction( (txn_id, action, entity_id, json.dumps(details), created_ms), ) - def _recompute_job_state(self, cur: Any, job_id: JobName) -> int | None: - row = cur.execute( - f"SELECT j.state, j.started_at_ms, jc.max_task_failures " - f"FROM jobs j {JOB_CONFIG_JOIN} WHERE j.job_id = ?", - 
(job_id.to_wire(),), - ).fetchone() - if row is None: - return None - current_state = int(row["state"]) - if current_state in TERMINAL_JOB_STATES: - return current_state - max_task_failures = int(row["max_task_failures"]) - counts_rows = cur.execute( - "SELECT state, COUNT(*) AS c FROM tasks WHERE job_id = ? GROUP BY state", - (job_id.to_wire(),), - ).fetchall() - counts = {int(r["state"]): int(r["c"]) for r in counts_rows} - total = sum(counts.values()) - new_state = current_state - now_ms = Timestamp.now().epoch_ms() - if total > 0 and counts.get(job_pb2.TASK_STATE_SUCCEEDED, 0) == total: - new_state = job_pb2.JOB_STATE_SUCCEEDED - elif counts.get(job_pb2.TASK_STATE_FAILED, 0) > max_task_failures: - new_state = job_pb2.JOB_STATE_FAILED - elif counts.get(job_pb2.TASK_STATE_UNSCHEDULABLE, 0) > 0: - new_state = job_pb2.JOB_STATE_UNSCHEDULABLE - elif counts.get(job_pb2.TASK_STATE_KILLED, 0) > 0: - new_state = job_pb2.JOB_STATE_KILLED - elif ( - total > 0 - and (counts.get(job_pb2.TASK_STATE_WORKER_FAILED, 0) + counts.get(job_pb2.TASK_STATE_PREEMPTED, 0)) > 0 - and all(s in TERMINAL_TASK_STATES for s in counts) - ): - new_state = job_pb2.JOB_STATE_WORKER_FAILED - elif ( - counts.get(job_pb2.TASK_STATE_ASSIGNED, 0) > 0 - or counts.get(job_pb2.TASK_STATE_BUILDING, 0) > 0 - or counts.get(job_pb2.TASK_STATE_RUNNING, 0) > 0 - ): - new_state = job_pb2.JOB_STATE_RUNNING - elif row["started_at_ms"] is not None: - # Retries put tasks back into PENDING; keep job running once it has started. - new_state = job_pb2.JOB_STATE_RUNNING - elif total > 0: - new_state = job_pb2.JOB_STATE_PENDING - if new_state == current_state: - return new_state - terminal_placeholders = ",".join("?" for _ in TERMINAL_JOB_STATES) - error_row = cur.execute( - "SELECT error FROM tasks WHERE job_id = ? 
AND error IS NOT NULL ORDER BY task_index LIMIT 1", - (job_id.to_wire(),), - ).fetchone() - error = str(error_row["error"]) if error_row is not None else None - cur.execute( - "UPDATE jobs SET state = ?, " - "started_at_ms = CASE WHEN ? = ? THEN COALESCE(started_at_ms, ?) ELSE started_at_ms END, " - f"finished_at_ms = CASE WHEN ? IN ({terminal_placeholders}) THEN ? ELSE finished_at_ms END, " - "error = CASE WHEN ? IN (?, ?, ?, ?) THEN ? ELSE error END " - "WHERE job_id = ?", - ( - new_state, - new_state, - job_pb2.JOB_STATE_RUNNING, - now_ms, - new_state, - *TERMINAL_JOB_STATES, - now_ms, - new_state, - job_pb2.JOB_STATE_FAILED, - job_pb2.JOB_STATE_KILLED, - job_pb2.JOB_STATE_UNSCHEDULABLE, - job_pb2.JOB_STATE_WORKER_FAILED, - error, - job_id.to_wire(), - ), - ) - return new_state - def replace_reservation_claims(self, claims: dict[WorkerId, ReservationClaim]) -> None: """Replace all reservation claims atomically.""" - with self._db.transaction() as cur: - cur.execute("DELETE FROM reservation_claims") - for worker_id, claim in claims.items(): - cur.execute( - "INSERT INTO reservation_claims(worker_id, job_id, entry_idx) VALUES (?, ?, ?)", - (str(worker_id), claim.job_id, claim.entry_idx), - ) + with self._stores.transact() as ctx: + ctx.dispatch.replace_claims(ctx.cur, {wid: (claim.job_id, claim.entry_idx) for wid, claim in claims.items()}) # ========================================================================= # Command API @@ -1031,28 +447,20 @@ def submit_job( actions: list[tuple[str, str, dict[str, object]]] = [] created_task_ids: list[JobName] = [] - with self._db.transaction() as cur: - row = cur.execute("SELECT value FROM meta WHERE key = 'last_submission_ms'").fetchone() - last_submission_ms = int(row["value"]) if row is not None else 0 + with self._stores.transact() as ctx: + last_submission_ms = self._db.get_counter("last_submission_ms", ctx.cur) effective_submission_ms = max(submitted_ms, last_submission_ms + 1) - if row is None: - 
cur.execute("INSERT INTO meta(key, value) VALUES ('last_submission_ms', ?)", (effective_submission_ms,)) - else: - cur.execute("UPDATE meta SET value = ? WHERE key = 'last_submission_ms'", (effective_submission_ms,)) + self._db.set_counter("last_submission_ms", effective_submission_ms, ctx.cur) parent_job_id = job_id.parent.to_wire() if job_id.parent is not None else None if parent_job_id is not None: - parent_exists = cur.execute("SELECT 1 FROM jobs WHERE job_id = ?", (parent_job_id,)).fetchone() - if parent_exists is None: + if not ctx.jobs.exists(ctx.cur, parent_job_id): parent_job_id = None root_submitted_ms = effective_submission_ms if parent_job_id is not None: - parent = cur.execute( - "SELECT root_submitted_at_ms FROM jobs WHERE job_id = ?", - (parent_job_id,), - ).fetchone() - if parent is not None: - root_submitted_ms = int(parent["root_submitted_at_ms"]) + parent_root = ctx.jobs.get_root_submitted_ms(ctx.cur, parent_job_id) + if parent_root is not None: + root_submitted_ms = parent_root deadline_epoch_ms: int | None = None if request.HasField("scheduling_timeout") and request.scheduling_timeout.milliseconds > 0: @@ -1062,36 +470,15 @@ def submit_job( .epoch_ms() ) - cur.execute( - "INSERT OR IGNORE INTO users(user_id, created_at_ms) VALUES (?, ?)", - (job_id.user, effective_submission_ms), - ) - # Create default user budget row alongside user creation. - budget_defaults = self._user_budget_defaults - cur.execute( - "INSERT OR IGNORE INTO user_budgets(user_id, budget_limit, max_band, updated_at_ms) " - "VALUES (?, ?, ?, ?)", - ( - job_id.user, - budget_defaults.budget_limit, - budget_defaults.max_band, - effective_submission_ms, - ), - ) + ctx.users.ensure_user_and_budget(ctx.cur, job_id.user, effective_submission_ms, self._user_budget_defaults) # Resolve priority band: use explicit request value, inherit from parent, or default to INTERACTIVE. 
requested_band = int(request.priority_band) if requested_band != job_pb2.PRIORITY_BAND_UNSPECIFIED: band_sort_key = requested_band elif parent_job_id is not None: - parent_band_row = cur.execute( - "SELECT priority_band FROM tasks WHERE job_id = ? LIMIT 1", - (parent_job_id,), - ).fetchone() - if parent_band_row is not None: - band_sort_key = parent_band_row["priority_band"] - else: - band_sort_key = job_pb2.PRIORITY_BAND_INTERACTIVE + parent_band = ctx.jobs.get_parent_band(ctx.cur, parent_job_id) + band_sort_key = parent_band if parent_band is not None else job_pb2.PRIORITY_BAND_INTERACTIVE else: band_sort_key = job_pb2.PRIORITY_BAND_INTERACTIVE @@ -1131,102 +518,86 @@ def submit_job( timeout_ms: int | None = int(request.timeout.milliseconds) if request.timeout.milliseconds > 0 else None job_name_lower = request.name.lower() - cur.execute( - "INSERT INTO jobs(" - "job_id, user_id, parent_job_id, root_job_id, depth, state, submitted_at_ms, " - "root_submitted_at_ms, started_at_ms, finished_at_ms, scheduling_deadline_epoch_ms, " - "error, exit_code, num_tasks, is_reservation_holder, name, has_reservation" - ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, NULL, ?, ?, ?, NULL, ?, 0, ?, ?)", - ( - job_id.to_wire(), - job_id.user, - parent_job_id, - job_id.root_job.to_wire(), - job_id.depth, - state, - effective_submission_ms, - root_submitted_ms, - finished_ms, - deadline_epoch_ms, - validation_error, - replicas, - job_name_lower, - has_reservation, + ctx.jobs.insert_job( + ctx.cur, + JobInsert( + job_id=job_id.to_wire(), + user_id=job_id.user, + parent_job_id=parent_job_id, + root_job_id=job_id.root_job.to_wire(), + depth=job_id.depth, + state=state, + submitted_at_ms=effective_submission_ms, + root_submitted_at_ms=root_submitted_ms, + finished_at_ms=finished_ms, + scheduling_deadline_epoch_ms=deadline_epoch_ms, + error=validation_error, + num_tasks=replicas, + is_reservation_holder=False, + name=job_name_lower, + has_reservation=has_reservation, ), ) - cur.execute( - "INSERT INTO 
job_config(" - "job_id, name, has_reservation, " - "res_cpu_millicores, res_memory_bytes, res_disk_bytes, res_device_json, " - "constraints_json, has_coscheduling, coscheduling_group_by, " - "scheduling_timeout_ms, max_task_failures, " - "entrypoint_json, environment_json, bundle_id, ports_json, " - "max_retries_failure, max_retries_preemption, timeout_ms, " - "preemption_policy, existing_job_policy, priority_band, " - "task_image, submit_argv_json, reservation_json, fail_if_exists" - ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - ( - job_id.to_wire(), - job_name_lower, - has_reservation, - res_cpu, - res_mem, - res_disk, - res_device, - constraints_json, - has_cosched, - cosched_group, - sched_timeout, - max_failures, - entrypoint_json, - environment_json, - request.bundle_id, - ports_json, - int(request.max_retries_failure), - int(request.max_retries_preemption), - timeout_ms, - int(request.preemption_policy), - int(request.existing_job_policy), - int(request.priority_band), - request.task_image, - json.dumps(list(request.submit_argv)), - reservation_json, - 1 if request.fail_if_exists else 0, + ctx.jobs.insert_job_config( + ctx.cur, + JobConfigInsert( + job_id=job_id.to_wire(), + name=job_name_lower, + has_reservation=has_reservation, + resources=ResourceSpec( + cpu_millicores=res_cpu, + memory_bytes=res_mem, + disk_bytes=res_disk, + device_json=res_device, + ), + constraints_json=constraints_json, + has_coscheduling=has_cosched, + coscheduling_group_by=cosched_group, + scheduling_timeout_ms=sched_timeout, + max_task_failures=max_failures, + entrypoint_json=entrypoint_json, + environment_json=environment_json, + bundle_id=request.bundle_id, + ports_json=ports_json, + max_retries_failure=int(request.max_retries_failure), + max_retries_preemption=int(request.max_retries_preemption), + timeout_ms=timeout_ms, + preemption_policy=int(request.preemption_policy), + existing_job_policy=int(request.existing_job_policy), + 
priority_band=int(request.priority_band), + task_image=request.task_image, + submit_argv_json=json.dumps(list(request.submit_argv)), + reservation_json=reservation_json, + fail_if_exists=1 if request.fail_if_exists else 0, ), ) - # Store workdir files in separate table. if request.entrypoint.workdir_files: - for filename, data in request.entrypoint.workdir_files.items(): - cur.execute( - "INSERT INTO job_workdir_files(job_id, filename, data) VALUES (?, ?, ?)", - (job_id.to_wire(), filename, data), - ) + ctx.jobs.insert_workdir_files( + ctx.cur, + job_id.to_wire(), + list(request.entrypoint.workdir_files.items()), + ) if validation_error is None: - insertion_base = self._db.next_sequence("task_priority_insertion", cur=cur) + insertion_base = self._db.next_sequence("task_priority_insertion", cur=ctx.cur) for idx in range(replicas): task_id = job_id.task(idx).to_wire() created_task_ids.append(JobName.from_wire(task_id)) - cur.execute( - "INSERT INTO tasks(" - "task_id, job_id, task_index, state, error, exit_code, submitted_at_ms, started_at_ms, " - "finished_at_ms, max_retries_failure, max_retries_preemption, failure_count, preemption_count, " - "current_attempt_id, priority_neg_depth, priority_root_submitted_ms, " - "priority_insertion, priority_band" - ") VALUES (?, ?, ?, ?, NULL, NULL, ?, NULL, NULL, ?, ?, 0, 0, -1, ?, ?, ?, ?)", - ( - task_id, - job_id.to_wire(), - idx, - job_pb2.TASK_STATE_PENDING, - effective_submission_ms, - int(request.max_retries_failure), - int(request.max_retries_preemption), - -job_id.depth, - root_submitted_ms, - insertion_base + idx, - band_sort_key, + ctx.tasks.insert_task( + ctx.cur, + TaskInsert( + task_id=task_id, + job_id=job_id.to_wire(), + task_index=idx, + state=job_pb2.TASK_STATE_PENDING, + submitted_at_ms=effective_submission_ms, + max_retries_failure=int(request.max_retries_failure), + max_retries_preemption=int(request.max_retries_preemption), + priority_neg_depth=-job_id.depth, + priority_root_submitted_ms=root_submitted_ms, 
+ priority_insertion=insertion_base + idx, + priority_band=band_sort_key, ), ) if request.HasField("reservation") and request.reservation.entries: @@ -1253,122 +624,98 @@ def submit_job( holder_res_device = proto_to_json(holder_res.device) if holder_res else None holder_constraints_json = constraints_to_json(holder_request.constraints) holder_name_lower = holder_request.name.lower() - cur.execute( - "INSERT INTO jobs(" - "job_id, user_id, parent_job_id, root_job_id, depth, state, submitted_at_ms, " - "root_submitted_at_ms, started_at_ms, finished_at_ms, scheduling_deadline_epoch_ms, " - "error, exit_code, num_tasks, is_reservation_holder, name, has_reservation" - ") VALUES (" - "?, ?, ?, ?, ?, ?, ?, ?, NULL, NULL, NULL, NULL, NULL, ?, 1, ?, 0" - ")", - ( - holder_id.to_wire(), - holder_id.user, - job_id.to_wire(), - holder_id.root_job.to_wire(), - holder_id.depth, - job_pb2.JOB_STATE_PENDING, - effective_submission_ms, - root_submitted_ms, - len(request.reservation.entries), - holder_name_lower, + ctx.jobs.insert_job( + ctx.cur, + JobInsert( + job_id=holder_id.to_wire(), + user_id=holder_id.user, + parent_job_id=job_id.to_wire(), + root_job_id=holder_id.root_job.to_wire(), + depth=holder_id.depth, + state=job_pb2.JOB_STATE_PENDING, + submitted_at_ms=effective_submission_ms, + root_submitted_at_ms=root_submitted_ms, + finished_at_ms=None, + scheduling_deadline_epoch_ms=None, + error=None, + num_tasks=len(request.reservation.entries), + is_reservation_holder=True, + name=holder_name_lower, + has_reservation=False, ), ) holder_entrypoint_json = entrypoint_to_json(holder_request.entrypoint) holder_environment_json = proto_to_json(holder_request.environment) - cur.execute( - "INSERT INTO job_config(" - "job_id, name, has_reservation, " - "res_cpu_millicores, res_memory_bytes, res_disk_bytes, res_device_json, " - "constraints_json, has_coscheduling, coscheduling_group_by, " - "scheduling_timeout_ms, max_task_failures, " - "entrypoint_json, environment_json, bundle_id, 
ports_json, " - "max_retries_failure, max_retries_preemption, timeout_ms, " - "preemption_policy, existing_job_policy, priority_band, " - "task_image, reservation_json" - ") VALUES (" - "?, ?, 0, ?, ?, ?, ?, ?, 0, '', NULL, 0, " - "?, ?, '', '[]', 0, ?, NULL, 0, 0, 0, '', NULL" - ")", - ( - holder_id.to_wire(), - holder_name_lower, - holder_res_cpu, - holder_res_mem, - holder_res_disk, - holder_res_device, - holder_constraints_json, - holder_entrypoint_json, - holder_environment_json, - DEFAULT_MAX_RETRIES_PREEMPTION, + ctx.jobs.insert_job_config( + ctx.cur, + JobConfigInsert( + job_id=holder_id.to_wire(), + name=holder_name_lower, + has_reservation=False, + resources=ResourceSpec( + cpu_millicores=holder_res_cpu, + memory_bytes=holder_res_mem, + disk_bytes=holder_res_disk, + device_json=holder_res_device, + ), + constraints_json=holder_constraints_json, + has_coscheduling=0, + coscheduling_group_by="", + scheduling_timeout_ms=None, + max_task_failures=0, + entrypoint_json=holder_entrypoint_json, + environment_json=holder_environment_json, + bundle_id="", + ports_json="[]", + max_retries_failure=0, + max_retries_preemption=DEFAULT_MAX_RETRIES_PREEMPTION, + timeout_ms=None, + preemption_policy=0, + existing_job_policy=0, + priority_band=0, + task_image="", ), ) - holder_base = self._db.next_sequence("task_priority_insertion", cur=cur) + holder_base = self._db.next_sequence("task_priority_insertion", cur=ctx.cur) for idx in range(len(request.reservation.entries)): created_task_ids.append(holder_id.task(idx)) - cur.execute( - "INSERT INTO tasks(" - "task_id, job_id, task_index, state, error, exit_code, submitted_at_ms, started_at_ms, " - "finished_at_ms, max_retries_failure, max_retries_preemption, " - "failure_count, preemption_count, " - "current_attempt_id, priority_neg_depth, priority_root_submitted_ms, " - "priority_insertion, priority_band" - ") VALUES (?, ?, ?, ?, NULL, NULL, ?, NULL, NULL, ?, ?, 0, 0, -1, ?, ?, ?, ?)", - ( - holder_id.task(idx).to_wire(), - 
holder_id.to_wire(), - idx, - job_pb2.TASK_STATE_PENDING, - effective_submission_ms, - 0, - DEFAULT_MAX_RETRIES_PREEMPTION, - -holder_id.depth, - root_submitted_ms, - holder_base + idx, - band_sort_key, + ctx.tasks.insert_task( + ctx.cur, + TaskInsert( + task_id=holder_id.task(idx).to_wire(), + job_id=holder_id.to_wire(), + task_index=idx, + state=job_pb2.TASK_STATE_PENDING, + submitted_at_ms=effective_submission_ms, + max_retries_failure=0, + max_retries_preemption=DEFAULT_MAX_RETRIES_PREEMPTION, + priority_neg_depth=-holder_id.depth, + priority_root_submitted_ms=root_submitted_ms, + priority_insertion=holder_base + idx, + priority_band=band_sort_key, ), ) actions.append(("job_submitted", job_id.to_wire(), {"num_tasks": replicas, "error": validation_error})) - self._record_transaction(cur, "submit_job", actions) + self._record_transaction(ctx.cur, "submit_job", actions) return SubmitJobResult(job_id=job_id, task_ids=created_task_ids) def cancel_job(self, job_id: JobName, reason: str) -> TxResult: """Cancel a job tree and return tasks that need kill RPCs.""" - with self._db.transaction() as cur: - subtree = cur.execute( - "WITH RECURSIVE subtree(job_id) AS (" - " SELECT job_id FROM jobs WHERE job_id = ? " - " UNION ALL " - " SELECT j.job_id FROM jobs j JOIN subtree s ON j.parent_job_id = s.job_id" - ") SELECT job_id FROM subtree", - (job_id.to_wire(),), - ).fetchall() - if not subtree: + with self._stores.transact() as ctx: + subtree_ids = ctx.jobs.get_subtree_ids(ctx.cur, job_id.to_wire()) + if not subtree_ids: return TxResult() - subtree_ids = [str(row["job_id"]) for row in subtree] - placeholders = ",".join("?" 
for _ in subtree_ids) - running_rows = cur.execute( - f"SELECT t.task_id, t.current_worker_id AS worker_id, " - f"j.is_reservation_holder, " - f"jc.res_cpu_millicores, jc.res_memory_bytes, jc.res_disk_bytes, jc.res_device_json " - f"FROM tasks t " - f"JOIN jobs j ON j.job_id = t.job_id " - f"{JOB_CONFIG_JOIN} " - f"WHERE t.job_id IN ({placeholders}) " - "AND t.state IN (?, ?, ?)", - ( - *subtree_ids, - job_pb2.TASK_STATE_ASSIGNED, - job_pb2.TASK_STATE_BUILDING, - job_pb2.TASK_STATE_RUNNING, - ), - ).fetchall() - tasks_to_kill = {JobName.from_wire(str(row["task_id"])) for row in running_rows} + running_rows = ctx.tasks.query( + ctx.cur, + TaskFilter(job_ids=tuple(subtree_ids), states=ACTIVE_TASK_STATES), + projection=TaskProjection.WITH_JOB_CONFIG, + ) + tasks_to_kill = {row.task_id for row in running_rows} task_kill_workers = { - JobName.from_wire(str(row["task_id"])): WorkerId(str(row["worker_id"])) + row.task_id: WorkerId(str(row.current_worker_id)) for row in running_rows - if row["worker_id"] is not None + if row.current_worker_id is not None } # Decommit resources for each active task on its assigned worker. # cancel_job marks tasks as KILLED, but apply_heartbeat skips @@ -1376,45 +723,19 @@ def cancel_job(self, job_id: JobName, reason: str) -> TxResult: # heartbeat decommit path never fires for cancelled tasks. # Direct-provider tasks have NULL worker_id — skip decommit for them. 
for row in running_rows: - if row["worker_id"] is not None and not int(row["is_reservation_holder"]): + if row.current_worker_id is not None and not row.is_reservation_holder: resources = resource_spec_from_scalars( - int(row["res_cpu_millicores"]), - int(row["res_memory_bytes"]), - int(row["res_disk_bytes"]), - row["res_device_json"], + row.resources.cpu_millicores, + row.resources.memory_bytes, + row.resources.disk_bytes, + row.resources.device_json, ) - _decommit_worker_resources(cur, str(row["worker_id"]), resources) + ctx.workers.decommit_resources(ctx.cur, str(row.current_worker_id), resources) now_ms = Timestamp.now().epoch_ms() - task_terminal_placeholders = ",".join("?" for _ in TERMINAL_TASK_STATES) - cur.execute( - f"UPDATE tasks SET state = ?, error = ?, finished_at_ms = COALESCE(finished_at_ms, ?), " - f"current_worker_id = NULL, current_worker_address = NULL " - f"WHERE job_id IN ({placeholders}) AND state NOT IN ({task_terminal_placeholders})", - ( - job_pb2.TASK_STATE_KILLED, - reason, - now_ms, - *subtree_ids, - *TERMINAL_TASK_STATES, - ), - ) - # Deliberately excludes JOB_STATE_WORKER_FAILED from the guard set: - # worker-failed jobs should still be cancellable (transitioned to KILLED). - cancel_guard_states = TERMINAL_JOB_STATES - {job_pb2.JOB_STATE_WORKER_FAILED} - cancel_guard_placeholders = ",".join("?" for _ in cancel_guard_states) - cur.execute( - f"UPDATE jobs SET state = ?, error = ?, finished_at_ms = COALESCE(finished_at_ms, ?) 
" - f"WHERE job_id IN ({placeholders}) AND state NOT IN ({cancel_guard_placeholders})", - ( - job_pb2.JOB_STATE_KILLED, - reason, - now_ms, - *subtree_ids, - *cancel_guard_states, - ), - ) - self._db.endpoints.remove_by_job_ids(cur, [JobName.from_wire(jid) for jid in subtree_ids]) - self._record_transaction(cur, "cancel_job", [("job_cancelled", job_id.to_wire(), {"reason": reason})]) + ctx.tasks.bulk_cancel(ctx.cur, subtree_ids, reason, now_ms) + ctx.jobs.bulk_cancel(ctx.cur, subtree_ids, reason, now_ms) + ctx.endpoints.remove_by_job_ids(ctx.cur, [JobName.from_wire(jid) for jid in subtree_ids]) + self._record_transaction(ctx.cur, "cancel_job", [("job_cancelled", job_id.to_wire(), {"reason": reason})]) return TxResult(tasks_to_kill=tasks_to_kill, task_kill_workers=task_kill_workers) def register_or_refresh_worker( @@ -1448,87 +769,45 @@ def register_or_refresh_worker( else: device_type = "" device_variant = "" - with self._db.transaction() as cur: - md_device_json = proto_to_json(metadata.device) - cur.execute( - "INSERT INTO workers(" - "worker_id, address, healthy, active, consecutive_failures, last_heartbeat_ms, " - "committed_cpu_millicores, committed_mem_bytes, committed_gpu, committed_tpu, " - "total_cpu_millicores, total_memory_bytes, total_gpu_count, total_tpu_count, " - "device_type, device_variant, slice_id, scale_group, " - "md_hostname, md_ip_address, md_cpu_count, md_memory_bytes, md_disk_bytes, " - "md_tpu_name, md_tpu_worker_hostnames, md_tpu_worker_id, md_tpu_chips_per_host_bounds, " - "md_gpu_count, md_gpu_name, md_gpu_memory_mb, " - "md_gce_instance_name, md_gce_zone, md_git_hash, md_device_json" - ") VALUES (?, ?, 1, 1, 0, ?, 0, 0, 0, 0, ?, ?, ?, ?, ?, ?, ?, ?, " - "?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
" - "ON CONFLICT(worker_id) DO UPDATE SET " - "address=excluded.address, healthy=1, active=1, " - "consecutive_failures=0, last_heartbeat_ms=excluded.last_heartbeat_ms, " - "total_cpu_millicores=excluded.total_cpu_millicores, total_memory_bytes=excluded.total_memory_bytes, " - "total_gpu_count=excluded.total_gpu_count, total_tpu_count=excluded.total_tpu_count, " - "device_type=excluded.device_type, device_variant=excluded.device_variant, " - "slice_id=excluded.slice_id, scale_group=excluded.scale_group, " - "md_hostname=excluded.md_hostname, md_ip_address=excluded.md_ip_address, " - "md_cpu_count=excluded.md_cpu_count, md_memory_bytes=excluded.md_memory_bytes, " - "md_disk_bytes=excluded.md_disk_bytes, md_tpu_name=excluded.md_tpu_name, " - "md_tpu_worker_hostnames=excluded.md_tpu_worker_hostnames, " - "md_tpu_worker_id=excluded.md_tpu_worker_id, " - "md_tpu_chips_per_host_bounds=excluded.md_tpu_chips_per_host_bounds, " - "md_gpu_count=excluded.md_gpu_count, md_gpu_name=excluded.md_gpu_name, " - "md_gpu_memory_mb=excluded.md_gpu_memory_mb, " - "md_gce_instance_name=excluded.md_gce_instance_name, md_gce_zone=excluded.md_gce_zone, " - "md_git_hash=excluded.md_git_hash, md_device_json=excluded.md_device_json", - ( - str(worker_id), - address, - now_ms, - metadata.cpu_count * 1000, - metadata.memory_bytes, - gpu_count, - tpu_count, - device_type, - device_variant, - slice_id, - scale_group, - metadata.hostname, - metadata.ip_address, - metadata.cpu_count, - metadata.memory_bytes, - metadata.disk_bytes, - metadata.tpu_name, - metadata.tpu_worker_hostnames, - metadata.tpu_worker_id, - metadata.tpu_chips_per_host_bounds, - metadata.gpu_count, - metadata.gpu_name, - metadata.gpu_memory_mb, - metadata.gce_instance_name, - metadata.gce_zone, - metadata.git_hash, - md_device_json, + with self._stores.transact() as ctx: + ctx.workers.upsert( + ctx.cur, + WorkerUpsert( + worker_id=str(worker_id), + address=address, + now_ms=now_ms, + total_cpu_millicores=metadata.cpu_count * 
1000, + total_memory_bytes=metadata.memory_bytes, + total_gpu_count=gpu_count, + total_tpu_count=tpu_count, + device_type=device_type, + device_variant=device_variant, + slice_id=slice_id, + scale_group=scale_group, + metadata=WorkerMetadata( + hostname=metadata.hostname, + ip_address=metadata.ip_address, + cpu_count=metadata.cpu_count, + memory_bytes=metadata.memory_bytes, + disk_bytes=metadata.disk_bytes, + tpu_name=metadata.tpu_name, + tpu_worker_hostnames=metadata.tpu_worker_hostnames, + tpu_worker_id=metadata.tpu_worker_id, + tpu_chips_per_host_bounds=metadata.tpu_chips_per_host_bounds, + gpu_count=metadata.gpu_count, + gpu_name=metadata.gpu_name, + gpu_memory_mb=metadata.gpu_memory_mb, + gce_instance_name=metadata.gce_instance_name, + gce_zone=metadata.gce_zone, + git_hash=metadata.git_hash, + device_json=proto_to_json(metadata.device), + ), + attributes=attrs, ), ) - cur.execute("DELETE FROM worker_attributes WHERE worker_id = ?", (str(worker_id),)) - for key, value_type, str_value, int_value, float_value in attrs: - cur.execute( - "INSERT INTO worker_attributes(worker_id, key, value_type, str_value, int_value, float_value) " - "VALUES (?, ?, ?, ?, ?, ?)", - (str(worker_id), key, value_type, str_value, int_value, float_value), - ) self._record_transaction( - cur, "register_worker", [("worker_registered", str(worker_id), {"address": address})] + ctx.cur, "register_worker", [("worker_registered", str(worker_id), {"address": address})] ) - # Update in-memory attribute cache so scheduling sees the new worker immediately. 
- attr_dict: dict[str, AttributeValue] = {} - for key, value_type, str_value, int_value, float_value in attrs: - if value_type == "int": - attr_dict[key] = AttributeValue(int(int_value)) - elif value_type == "float": - attr_dict[key] = AttributeValue(float(float_value)) - else: - attr_dict[key] = AttributeValue(str(str_value or "")) - self._db.set_worker_attributes(worker_id, attr_dict) return TxResult() def register_worker( @@ -1555,20 +834,13 @@ def queue_assignments(self, assignments: list[Assignment]) -> AssignmentResult: accepted: list[Assignment] = [] rejected: list[Assignment] = [] has_real_dispatch = False - with self._db.transaction() as cur: + with self._stores.transact() as ctx: now_ms = Timestamp.now().epoch_ms() job_cache: dict[str, JobDetailRow] = {} jobs_to_update: set[str] = set() for assignment in assignments: - task_row = cur.execute( - f"SELECT {TASK_DETAIL_PROJECTION.select_clause()} " "FROM tasks t WHERE t.task_id = ?", - (assignment.task_id.to_wire(),), - ).fetchone() - worker_row = cur.execute( - "SELECT worker_id, address, active, healthy " - "FROM workers WHERE worker_id = ? 
AND active = 1 AND healthy = 1", - (str(assignment.worker_id),), - ).fetchone() + task_row = ctx.tasks.get_for_assignment(ctx.cur, assignment.task_id.to_wire()) + worker_row = ctx.workers.get_healthy_active(ctx.cur, str(assignment.worker_id)) if task_row is None or worker_row is None: rejected.append(assignment) continue @@ -1578,56 +850,34 @@ def queue_assignments(self, assignments: list[Assignment]) -> AssignmentResult: continue job_id_wire = task.job_id.to_wire() if job_id_wire not in job_cache: - job_row = cur.execute( - f"SELECT {JOB_DETAIL_PROJECTION.select_clause()} " - f"FROM jobs j {JOB_CONFIG_JOIN} WHERE j.job_id = ?", - (job_id_wire,), - ).fetchone() - if job_row is None: - rejected.append(assignment) - continue - decoded_job = JOB_DETAIL_PROJECTION.decode_one([job_row]) + decoded_job = ctx.jobs.get_job_detail(ctx.cur, job_id_wire) if decoded_job is None: rejected.append(assignment) continue job_cache[job_id_wire] = decoded_job job = job_cache[job_id_wire] attempt_id = int(task_row["current_attempt_id"]) + 1 - _assign_task( - cur, - assignment.task_id.to_wire(), - str(assignment.worker_id), - str(worker_row["address"]), - attempt_id, - now_ms, + ctx.tasks.assign_to_worker( + ctx.cur, + WorkerAssignment( + task_id=assignment.task_id.to_wire(), + attempt_id=attempt_id, + worker_id=str(assignment.worker_id), + worker_address=str(worker_row["address"]), + now_ms=now_ms, + ), ) if not job.is_reservation_holder: resources = resource_spec_from_scalars( - job.res_cpu_millicores, - job.res_memory_bytes, - job.res_disk_bytes, - job.res_device_json, - ) - cur.execute( - "UPDATE workers SET committed_cpu_millicores = committed_cpu_millicores + ?, " - "committed_mem_bytes = committed_mem_bytes + ?, committed_gpu = committed_gpu + ?, " - "committed_tpu = committed_tpu + ? 
WHERE worker_id = ?", - ( - int(resources.cpu_millicores), - int(resources.memory_bytes), - int(get_gpu_count(resources.device)), - int(get_tpu_count(resources.device)), - str(assignment.worker_id), - ), + job.resources.cpu_millicores, + job.resources.memory_bytes, + job.resources.disk_bytes, + job.resources.device_json, ) + ctx.workers.commit_resources(ctx.cur, str(assignment.worker_id), resources) entrypoint = proto_from_json(job.entrypoint_json, job_pb2.RuntimeEntrypoint) - # Load inline workdir files from the job_workdir_files table. - wf_rows = cur.execute( - "SELECT filename, data FROM job_workdir_files WHERE job_id = ?", - (job_id_wire,), - ).fetchall() - for wf_row in wf_rows: - entrypoint.workdir_files[wf_row["filename"]] = bytes(wf_row["data"]) + for fn, data in ctx.jobs.get_workdir_files(ctx.cur, job_id_wire).items(): + entrypoint.workdir_files[fn] = data run_request = job_pb2.RunTaskRequest( task_id=assignment.task_id.to_wire(), num_tasks=job.num_tasks, @@ -1640,38 +890,33 @@ def queue_assignments(self, assignments: list[Assignment]) -> AssignmentResult: constraints=[c.to_proto() for c in constraints_from_json(job.constraints_json)], task_image=job.task_image, ) - enqueue_run_dispatch(cur, str(assignment.worker_id), run_request.SerializeToString(), now_ms) + ctx.dispatch.enqueue_run(ctx.cur, str(assignment.worker_id), run_request.SerializeToString(), now_ms) has_real_dispatch = True - cur.execute( - "INSERT INTO worker_task_history(worker_id, task_id, assigned_at_ms) VALUES (?, ?, ?)", - (str(assignment.worker_id), assignment.task_id.to_wire(), now_ms), + ctx.workers.record_worker_task_history( + ctx.cur, str(assignment.worker_id), assignment.task_id.to_wire(), now_ms ) jobs_to_update.add(job_id_wire) accepted.append(assignment) for job_id_wire in jobs_to_update: - cur.execute( - "UPDATE jobs SET state = CASE WHEN state = ? THEN ? ELSE state END, " - "started_at_ms = COALESCE(started_at_ms, ?) 
WHERE job_id = ?", - (job_pb2.JOB_STATE_PENDING, job_pb2.JOB_STATE_RUNNING, now_ms, job_id_wire), - ) + ctx.jobs.start_if_pending(ctx.cur, job_id_wire, now_ms) if accepted or rejected: actions = [("assignment_queued", a.task_id.to_wire(), {"worker_id": str(a.worker_id)}) for a in accepted] - self._record_transaction(cur, "queue_assignments", actions) + self._record_transaction(ctx.cur, "queue_assignments", actions) return AssignmentResult( tasks_to_kill=set(), has_real_dispatch=has_real_dispatch, accepted=accepted, rejected=rejected ) - def _update_worker_health(self, cur: TransactionCursor, req: HeartbeatApplyRequest, now_ms: int) -> bool: + def _update_worker_health(self, ctx: ControllerStore, req: HeartbeatApplyRequest, now_ms: int) -> bool: """Update worker health, resource snapshot, and history. Returns False if the worker doesn't exist (caller should bail). """ - existing = _batch_worker_health(cur, [req], now_ms) + existing = ctx.workers.update_health_batch(ctx.cur, [req], now_ms) return str(req.worker_id) in existing def _apply_task_transitions( self, - cur: TransactionCursor, + ctx: ControllerStore, req: HeartbeatApplyRequest, now_ms: int, ) -> TxResult: @@ -1685,84 +930,69 @@ def _apply_task_transitions( task_kill_workers: dict[JobName, WorkerId] = {} cascaded_jobs: set[JobName] = set() jobs_to_recompute: set[JobName] = set() - # Cache job_config rows keyed by job_id wire format. 
- job_config_cache: dict[str, dict | None] = {} for update in req.updates: - task_row = cur.execute("SELECT * FROM tasks WHERE task_id = ?", (update.task_id.to_wire(),)).fetchone() - if task_row is None: + snapshot = ctx.tasks.get_task(ctx.cur, update.task_id) + if snapshot is None: continue - task = TASK_DETAIL_PROJECTION.decode_one([task_row]) - if task_row_is_finished(task) or update.new_state in ( + if snapshot.state in TERMINAL_TASK_STATES or update.new_state in ( job_pb2.TASK_STATE_UNSPECIFIED, job_pb2.TASK_STATE_PENDING, ): continue - if update.attempt_id != int(task_row["current_attempt_id"]): - stale = cur.execute( - "SELECT state FROM task_attempts WHERE task_id = ? AND attempt_id = ?", - (update.task_id.to_wire(), update.attempt_id), - ).fetchone() - if stale is not None and int(stale["state"]) not in TERMINAL_TASK_STATES: + if update.attempt_id != snapshot.attempt_id: + stale_state = ctx.tasks.get_attempt_state(ctx.cur, update.task_id.to_wire(), update.attempt_id) + if stale_state is not None and stale_state not in TERMINAL_TASK_STATES: logger.error( "Stale attempt precondition violation: task=%s reported=%d current=%d stale_state=%s", update.task_id, update.attempt_id, - int(task_row["current_attempt_id"]), - int(stale["state"]), + snapshot.attempt_id, + int(stale_state), ) continue - prior_state = int(task_row["state"]) + prior_state = int(snapshot.state) # Fast path: task already in the reported state with no new data to apply. has_new_data = update.error is not None or update.exit_code is not None or update.resource_usage is not None if update.new_state == prior_state and not has_new_data: continue - attempt_row = cur.execute( - "SELECT * FROM task_attempts WHERE task_id = ? AND attempt_id = ?", - (update.task_id.to_wire(), update.attempt_id), - ).fetchone() - if attempt_row is None: - continue # The attempt is already terminal (e.g. 
preempted, killed) but the task has # been rolled back to PENDING for retry and current_attempt_id still points # at the dead attempt. Reviving it would produce an inconsistent row where # state contradicts finished_at_ms/error. - if int(attempt_row["state"]) in TERMINAL_TASK_STATES: + if snapshot.attempt_state in TERMINAL_TASK_STATES: logger.debug( "Dropping late update for terminal attempt: task=%s attempt=%d attempt_state=%d reported=%d", update.task_id, update.attempt_id, - int(attempt_row["state"]), + int(snapshot.attempt_state), int(update.new_state), ) continue - worker_id = attempt_row["worker_id"] + if update.resource_usage is not None: - ru = update.resource_usage - cur.execute( - "INSERT INTO task_resource_history" - "(task_id, attempt_id, cpu_millicores, memory_mb, disk_mb, memory_peak_mb, timestamp_ms) " - "VALUES (?, ?, ?, ?, ?, ?, ?)", - ( - update.task_id.to_wire(), - update.attempt_id, - ru.cpu_millicores, - ru.memory_mb, - ru.disk_mb, - ru.memory_peak_mb, - now_ms, - ), + ctx.tasks.insert_resource_usage( + ctx.cur, + update.task_id.to_wire(), + update.attempt_id, + update.resource_usage, + now_ms, ) - terminal_ms: int | None = None - started_ms: int | None = None + + if update.container_id is not None: + ctx.tasks.update_container_id(ctx.cur, update.task_id.to_wire(), update.container_id) + + # --- Inline retry logic (PR 2 will extract to resolve_transition) --- task_state = prior_state task_error = update.error task_exit = update.exit_code - failure_count = int(task_row["failure_count"]) - preemption_count = int(task_row["preemption_count"]) + failure_count = snapshot.failure_count + preemption_count = snapshot.preemption_count + started_ms: int | None = None + is_terminal_update = False if update.new_state == job_pb2.TASK_STATE_RUNNING: started_ms = now_ms @@ -1776,7 +1006,7 @@ def _apply_task_transitions( job_pb2.TASK_STATE_UNSCHEDULABLE, job_pb2.TASK_STATE_SUCCEEDED, ): - terminal_ms = now_ms + is_terminal_update = True task_state = 
int(update.new_state) if update.new_state == job_pb2.TASK_STATE_SUCCEEDED and task_exit is None: task_exit = 0 @@ -1788,143 +1018,108 @@ def _apply_task_transitions( preemption_count += 1 if update.new_state == job_pb2.TASK_STATE_WORKER_FAILED and prior_state == job_pb2.TASK_STATE_ASSIGNED: task_state = job_pb2.TASK_STATE_PENDING - terminal_ms = None - if update.new_state == job_pb2.TASK_STATE_FAILED and failure_count <= int( - task_row["max_retries_failure"] - ): + if update.new_state == job_pb2.TASK_STATE_FAILED and failure_count <= snapshot.max_retries_failure: task_state = job_pb2.TASK_STATE_PENDING - terminal_ms = None if ( update.new_state == job_pb2.TASK_STATE_WORKER_FAILED - and preemption_count <= int(task_row["max_retries_preemption"]) + and preemption_count <= snapshot.max_retries_preemption and prior_state in EXECUTING_TASK_STATES ): task_state = job_pb2.TASK_STATE_PENDING - terminal_ms = None - # An attempt is terminal whenever the update itself is terminal, even - # if the TASK rolls back to PENDING for a retry. terminal_ms above - # tracks the task's finished_at_ms; the attempt needs its own stamp. - attempt_terminal_ms = now_ms if int(update.new_state) in TERMINAL_TASK_STATES else None - - cur.execute( - "UPDATE task_attempts SET state = ?, started_at_ms = COALESCE(started_at_ms, ?), " - "finished_at_ms = COALESCE(finished_at_ms, ?), exit_code = COALESCE(?, exit_code), " - "error = COALESCE(?, error) WHERE task_id = ? AND attempt_id = ?", - ( - int(update.new_state), - started_ms, - attempt_terminal_ms, - task_exit, - update.error, - update.task_id.to_wire(), - update.attempt_id, - ), - ) - # Clear denormalized worker columns when task leaves active state. 
+ # --- Apply writes through store methods --- if task_state in ACTIVE_TASK_STATES: - cur.execute( - "UPDATE tasks SET state = ?, error = COALESCE(?, error), exit_code = COALESCE(?, exit_code), " - "started_at_ms = COALESCE(started_at_ms, ?), finished_at_ms = ?, " - "failure_count = ?, preemption_count = ? " - "WHERE task_id = ?", - ( - task_state, - task_error, - task_exit, - started_ms, - terminal_ms, - failure_count, - preemption_count, - update.task_id.to_wire(), + ctx.tasks.update_active( + ctx.cur, + ActiveStateUpdate( + task_id=update.task_id.to_wire(), + attempt_id=update.attempt_id, + state=task_state, + error=task_error, + exit_code=task_exit, + started_ms=started_ms, + failure_count=failure_count, + preemption_count=preemption_count, ), ) - else: - cur.execute( - "UPDATE tasks SET state = ?, error = COALESCE(?, error), exit_code = COALESCE(?, exit_code), " - "started_at_ms = COALESCE(started_at_ms, ?), finished_at_ms = ?, " - "failure_count = ?, preemption_count = ?, " - "current_worker_id = NULL, current_worker_address = NULL " - "WHERE task_id = ?", - ( - task_state, - task_error, - task_exit, - started_ms, - terminal_ms, - failure_count, - preemption_count, - update.task_id.to_wire(), + elif task_state == job_pb2.TASK_STATE_PENDING: + ctx.tasks.requeue( + ctx.cur, + TaskRetry( + task_id=update.task_id.to_wire(), + finalize=AttemptFinalizer.build( + update.task_id.to_wire(), update.attempt_id, int(update.new_state), now_ms + ), + worker_id=snapshot.worker_id, + resources=snapshot.resources, + failure_count=failure_count, + preemption_count=preemption_count, + ), + ) + elif is_terminal_update: + ctx.tasks.terminate( + ctx.cur, + TaskTermination( + task_id=update.task_id.to_wire(), + state=task_state, + now_ms=now_ms, + error=task_error, + finalize=AttemptFinalizer( + task_id=update.task_id.to_wire(), + attempt_id=update.attempt_id, + attempt_state=task_state, + now_ms=now_ms, + error=task_error, + exit_code=task_exit, + ), + 
worker_id=snapshot.worker_id, + resources=snapshot.resources, + failure_count=failure_count, + preemption_count=preemption_count, ), ) - - # Fetch and cache job_config row (avoids re-querying per task in same job). - job_id_wire = task.job_id.to_wire() - if job_id_wire not in job_config_cache: - jc_row = cur.execute("SELECT * FROM job_config WHERE job_id = ?", (job_id_wire,)).fetchone() - job_config_cache[job_id_wire] = dict(jc_row) if jc_row is not None else None - jc = job_config_cache[job_id_wire] - - if worker_id is not None and task_state not in ACTIVE_TASK_STATES: - if jc is not None: - resources = resource_spec_from_scalars( - int(jc["res_cpu_millicores"]), - int(jc["res_memory_bytes"]), - int(jc["res_disk_bytes"]), - jc["res_device_json"], - ) - _decommit_worker_resources(cur, str(worker_id), resources) - - if update.new_state in TERMINAL_TASK_STATES: - delete_task_endpoints(cur, self._db.endpoints, update.task_id.to_wire()) # Coscheduled jobs: a terminal host failure should cascade to siblings. 
- if jc is not None and task_state in FAILURE_TASK_STATES: - has_cosched = bool(int(jc["has_coscheduling"])) - siblings = _find_coscheduled_siblings(cur, task.job_id, update.task_id, has_cosched) - resources = resource_spec_from_scalars( - int(jc["res_cpu_millicores"]), - int(jc["res_memory_bytes"]), - int(jc["res_disk_bytes"]), - jc["res_device_json"], - ) - cascade_kill, cascade_workers = _terminate_coscheduled_siblings( - cur, self._db.endpoints, siblings, update.task_id, resources, now_ms + if task_state in FAILURE_TASK_STATES and snapshot.has_coscheduling: + siblings = ctx.tasks.find_coscheduled_siblings( + ctx.cur, snapshot.job_id, update.task_id, snapshot.has_coscheduling ) - tasks_to_kill.update(cascade_kill) - task_kill_workers.update(cascade_workers) + if siblings and snapshot.resources is not None: + kill_result = ctx.tasks.terminate_coscheduled_siblings( + ctx.cur, siblings, update.task_id, snapshot.resources, now_ms + ) + tasks_to_kill.update(kill_result.tasks_to_kill) + task_kill_workers.update(kill_result.task_kill_workers) # Mark job for recomputation (deduplicated, done after the task loop). if task_state != prior_state: - jobs_to_recompute.add(task.job_id) + jobs_to_recompute.add(snapshot.job_id) # Recompute job states once per job instead of once per task. 
for job_id in jobs_to_recompute: if job_id in cascaded_jobs: continue - new_job_state = self._recompute_job_state(cur, job_id) - if new_job_state in TERMINAL_JOB_STATES: - final_tasks_to_kill, final_task_kill_workers = _finalize_terminal_job( - cur, self._db.endpoints, job_id, new_job_state, now_ms - ) - tasks_to_kill.update(final_tasks_to_kill) - task_kill_workers.update(final_task_kill_workers) + new_job_state = ctx.jobs.recompute_state(ctx.cur, job_id) + if new_job_state is not None and new_job_state in TERMINAL_JOB_STATES: + kill_result = _finalize_terminal_job(ctx, job_id, new_job_state, now_ms) + tasks_to_kill.update(kill_result.tasks_to_kill) + task_kill_workers.update(kill_result.task_kill_workers) cascaded_jobs.add(job_id) if tasks_to_kill or cascaded_jobs: actions: list[tuple[str, str, dict[str, object]]] = [("heartbeat_applied", str(req.worker_id), {})] for job_id in cascaded_jobs: actions.append(("job_terminated", job_id.to_wire(), {})) - self._record_transaction(cur, "apply_task_updates", actions) + self._record_transaction(ctx.cur, "apply_task_updates", actions) return TxResult(tasks_to_kill=tasks_to_kill, task_kill_workers=task_kill_workers) def apply_task_updates(self, req: HeartbeatApplyRequest) -> TxResult: """Apply a batch of worker task updates atomically.""" - with self._db.transaction() as cur: + with self._stores.transact() as ctx: now_ms = Timestamp.now().epoch_ms() - if not self._update_worker_health(cur, req, now_ms): + if not self._update_worker_health(ctx, req, now_ms): return TxResult() - result = self._apply_task_transitions(cur, req, now_ms) + result = self._apply_task_transitions(ctx, req, now_ms) return result @@ -1944,11 +1139,11 @@ def apply_heartbeats_batch(self, requests: list[HeartbeatApplyRequest]) -> list[ _empty = HeartbeatApplyResult(tasks_to_kill=set(), action=HeartbeatAction.OK) results: list[HeartbeatApplyResult] = [_empty] * len(requests) - with self._db.transaction() as cur: + with self._stores.transact() as ctx: 
now_ms = Timestamp.now().epoch_ms() # ── Batch worker health updates ─────────────────────────────── - existing_workers = _batch_worker_health(cur, requests, now_ms) + existing_workers = ctx.workers.update_health_batch(ctx.cur, requests, now_ms) # ── Bulk-fetch task rows for classification ─────────────────── all_task_ids: list[str] = [] @@ -1962,7 +1157,8 @@ def apply_heartbeats_batch(self, requests: list[HeartbeatApplyRequest]) -> list[ ): all_task_ids.append(update.task_id.to_wire()) - task_row_map = _bulk_fetch_tasks(cur, all_task_ids) + task_rows = ctx.tasks.query(ctx.cur, TaskFilter(task_ids=tuple(all_task_ids))) + task_row_map = {t.task_id.to_wire(): t for t in task_rows} # ── Classify and split ──────────────────────────────────────── task_history_params: list[tuple[str, int, int, int, int, int, int]] = [] @@ -1976,22 +1172,20 @@ def apply_heartbeats_batch(self, requests: list[HeartbeatApplyRequest]) -> list[ transition_updates: list[TaskUpdate] = [] for update in req.updates: task_id_wire = update.task_id.to_wire() - task_row = task_row_map.get(task_id_wire) - if task_row is None: + task = task_row_map.get(task_id_wire) + if task is None: continue - prior_state = int(task_row["state"]) - is_state_change = update.new_state != prior_state + is_state_change = update.new_state != task.state has_terminal_data = update.error is not None or update.exit_code is not None if is_state_change or has_terminal_data: transition_updates.append(update) else: # Steady-state: check finished / stale attempt before writing. 
- task = self._db.decode_task(task_row) if task_row_is_finished(task): continue - if update.attempt_id != int(task_row["current_attempt_id"]): + if update.attempt_id != task.current_attempt_id: continue if update.resource_usage is not None: u = update.resource_usage @@ -2021,16 +1215,11 @@ def apply_heartbeats_batch(self, requests: list[HeartbeatApplyRequest]) -> list[ # ── Pass 2a: batch task resource history writes ───────────────── if task_history_params: - cur.executemany( - "INSERT INTO task_resource_history" - "(task_id, attempt_id, cpu_millicores, memory_mb, disk_mb, memory_peak_mb, timestamp_ms) " - "VALUES (?, ?, ?, ?, ?, ?, ?)", - task_history_params, - ) + ctx.tasks.insert_resource_usage_batch(ctx.cur, task_history_params) # ── Pass 2b: transitions via existing state machine ─────────── for req_idx, treq in transition_entries: - tx_result = self._apply_task_transitions(cur, treq, now_ms) + tx_result = self._apply_task_transitions(ctx, treq, now_ms) results[req_idx] = HeartbeatApplyResult( tasks_to_kill=tx_result.tasks_to_kill, action=HeartbeatAction.OK, @@ -2044,7 +1233,7 @@ def apply_heartbeat(self, req: HeartbeatApplyRequest) -> HeartbeatApplyResult: def _remove_failed_worker( self, - cur: TransactionCursor, + ctx: ControllerStore, worker_id: WorkerId, error: str, *, @@ -2053,81 +1242,86 @@ def _remove_failed_worker( """Remove a definitively failed worker and cascade its task state.""" tasks_to_kill: set[JobName] = set() task_kill_workers: dict[JobName, WorkerId] = {} - task_rows = cur.execute( - "SELECT t.task_id, t.current_attempt_id, t.state, t.preemption_count, t.max_retries_preemption, " - "j.is_reservation_holder " - "FROM tasks t " - "JOIN jobs j ON j.job_id = t.job_id " - "WHERE t.current_worker_id = ? 
AND t.state IN (?, ?, ?)", - (str(worker_id), *ACTIVE_TASK_STATES), - ).fetchall() + task_rows = ctx.tasks.query( + ctx.cur, + TaskFilter(worker_id=worker_id, states=ACTIVE_TASK_STATES), + projection=TaskProjection.WITH_JOB, + ) for task_row in task_rows: - tid = str(task_row["task_id"]) - prior_state = int(task_row["state"]) - is_reservation_holder = bool(int(task_row["is_reservation_holder"])) + tid = task_row.task_id.to_wire() + prior_state = task_row.state + is_reservation_holder = task_row.is_reservation_holder if is_reservation_holder: new_task_state = job_pb2.TASK_STATE_PENDING - preemption_count = int(task_row["preemption_count"]) + preemption_count = task_row.preemption_count else: new_task_state, preemption_count = _resolve_task_failure_state( prior_state, - int(task_row["preemption_count"]), - int(task_row["max_retries_preemption"]), + task_row.preemption_count, + task_row.max_retries_preemption, job_pb2.TASK_STATE_WORKER_FAILED, ) if is_reservation_holder: - cur.execute( - "DELETE FROM task_attempts WHERE task_id = ? 
AND attempt_id = ?", - (tid, int(task_row["current_attempt_id"])), - ) - cur.execute( - "UPDATE tasks SET state = ?, current_attempt_id = -1, started_at_ms = NULL, " - "finished_at_ms = NULL, error = NULL, preemption_count = 0, " - "current_worker_id = NULL, current_worker_address = NULL WHERE task_id = ?", - (new_task_state, tid), + ctx.tasks.delete_attempt(ctx.cur, tid, task_row.current_attempt_id) + ctx.tasks.reset_reservation_holder(ctx.cur, tid, new_task_state) + elif new_task_state == job_pb2.TASK_STATE_PENDING: + ctx.tasks.requeue( + ctx.cur, + TaskRetry( + task_id=tid, + finalize=AttemptFinalizer.build( + tid, task_row.current_attempt_id, job_pb2.TASK_STATE_WORKER_FAILED, now_ms + ), + preemption_count=preemption_count, + ), ) else: - _terminate_task( - cur, - self._db.endpoints, - tid, - int(task_row["current_attempt_id"]), - new_task_state, - f"Worker {worker_id} failed: {error}", - now_ms, - attempt_state=job_pb2.TASK_STATE_WORKER_FAILED, - preemption_count=preemption_count, + worker_fail_error = f"Worker {worker_id} failed: {error}" + ctx.tasks.terminate( + ctx.cur, + TaskTermination( + task_id=tid, + state=new_task_state, + now_ms=now_ms, + error=worker_fail_error, + finalize=AttemptFinalizer.build( + tid, + task_row.current_attempt_id, + job_pb2.TASK_STATE_WORKER_FAILED, + now_ms, + error=worker_fail_error, + ), + preemption_count=preemption_count, + ), ) - task_id = JobName.from_wire(tid) + task_id = task_row.task_id parent_job_id, _ = task_id.require_task() - new_job_state = self._recompute_job_state(cur, parent_job_id) + new_job_state = ctx.jobs.recompute_state(ctx.cur, parent_job_id) if new_job_state is not None and new_job_state in TERMINAL_JOB_STATES: - cascaded_tasks_to_kill, cascaded_task_kill_workers = _cascade_terminal_job( - cur, self._db.endpoints, parent_job_id, now_ms, f"Worker {worker_id} failed" - ) - tasks_to_kill.update(cascaded_tasks_to_kill) - task_kill_workers.update(cascaded_task_kill_workers) + kill_result = 
_finalize_terminal_job(ctx, parent_job_id, new_job_state, now_ms) + tasks_to_kill.update(kill_result.tasks_to_kill) + task_kill_workers.update(kill_result.task_kill_workers) elif new_task_state == job_pb2.TASK_STATE_PENDING: - policy = _resolve_preemption_policy(cur, parent_job_id) + policy = ctx.jobs.get_preemption_policy(ctx.cur, parent_job_id) if policy == job_pb2.JOB_PREEMPTION_POLICY_TERMINATE_CHILDREN: - child_tasks_to_kill, child_task_kill_workers = _cascade_children( - cur, - self._db.endpoints, + child_result = ctx.jobs.cascade_children( + ctx.cur, + ctx.tasks, parent_job_id, - now_ms, "Parent task preempted", + now_ms, exclude_reservation_holders=True, ) - tasks_to_kill.update(child_tasks_to_kill) - task_kill_workers.update(child_task_kill_workers) + tasks_to_kill.update(child_result.tasks_to_kill) + task_kill_workers.update(child_result.task_kill_workers) if new_task_state == job_pb2.TASK_STATE_WORKER_FAILED: tasks_to_kill.add(task_id) - _remove_worker(cur, str(worker_id)) + ctx.workers.remove(ctx.cur, str(worker_id)) return TxResult(tasks_to_kill=tasks_to_kill, task_kill_workers=task_kill_workers) def _record_heartbeat_failure( self, - cur: TransactionCursor, + ctx: ControllerStore, worker_id: WorkerId, error: str, drained_dispatch: DispatchBatch, @@ -2138,10 +1332,7 @@ def _record_heartbeat_failure( """Apply a heartbeat failure inside an existing transaction.""" tasks_to_kill: set[JobName] = set() task_kill_workers: dict[JobName, WorkerId] = {} - row = cur.execute( - "SELECT consecutive_failures, last_heartbeat_ms FROM workers WHERE worker_id = ? 
AND active = 1", - (str(worker_id),), - ).fetchone() + row = ctx.workers.get_active_row(ctx.cur, str(worker_id)) if row is None: return HeartbeatFailureResult( worker_removed=True, @@ -2150,24 +1341,20 @@ def _record_heartbeat_failure( ) now_ms = now_ms or Timestamp.now().epoch_ms() - last_heartbeat_ms = row["last_heartbeat_ms"] - last_heartbeat_age_ms = None if last_heartbeat_ms is None else max(0, now_ms - int(last_heartbeat_ms)) - failures = int(row["consecutive_failures"]) + 1 - cur.execute( - "UPDATE workers SET consecutive_failures = ?, healthy = CASE WHEN ? >= ? THEN 0 ELSE healthy END " - "WHERE worker_id = ?", - (failures, failures, self._heartbeat_failure_threshold, str(worker_id)), - ) + last_heartbeat_ms = row.last_heartbeat_ms + last_heartbeat_age_ms = None if last_heartbeat_ms is None else max(0, now_ms - last_heartbeat_ms) + failures = row.consecutive_failures + 1 + ctx.workers.record_heartbeat_failure(ctx.cur, worker_id, failures, self._heartbeat_failure_threshold) should_remove = force_remove or failures >= self._heartbeat_failure_threshold if should_remove: - removal = self._remove_failed_worker(cur, worker_id, error, now_ms=now_ms) + removal = self._remove_failed_worker(ctx, worker_id, error, now_ms=now_ms) tasks_to_kill.update(removal.tasks_to_kill) task_kill_workers.update(removal.task_kill_workers) else: for req in drained_dispatch.tasks_to_run: - enqueue_run_dispatch(cur, str(worker_id), req.SerializeToString(), now_ms) + ctx.dispatch.enqueue_run(ctx.cur, str(worker_id), req.SerializeToString(), now_ms) for task_id in drained_dispatch.tasks_to_kill: - enqueue_kill_dispatch(cur, str(worker_id), task_id, now_ms) + ctx.dispatch.enqueue_kill(ctx.cur, str(worker_id), task_id, now_ms) action = HeartbeatAction.WORKER_FAILED if should_remove else HeartbeatAction.TRANSIENT_FAILURE return HeartbeatFailureResult( tasks_to_kill=tasks_to_kill, @@ -2188,21 +1375,19 @@ def record_heartbeat_failure( force_remove: bool = False, ) -> TxResult: """Record 
heartbeat failure and requeue/flush drained dispatches.""" - with self._db.transaction() as cur: + with self._stores.transact() as ctx: result = self._record_heartbeat_failure( - cur, + ctx, worker_id, error, drained_dispatch, force_remove=force_remove, ) self._record_transaction( - cur, + ctx.cur, "heartbeat_failure", [("worker_heartbeat_failed", str(worker_id), {"error": error})], ) - if result.worker_removed: - self._db.remove_worker_from_attr_cache(worker_id) return TxResult(tasks_to_kill=result.tasks_to_kill, task_kill_workers=result.task_kill_workers) def fail_heartbeat_for_worker( @@ -2213,21 +1398,19 @@ def fail_heartbeat_for_worker( *, force_remove: bool = False, ) -> HeartbeatFailureResult: - with self._db.transaction() as cur: + with self._stores.transact() as ctx: result = self._record_heartbeat_failure( - cur, + ctx, worker_id, error, snapshot, force_remove=force_remove, ) self._record_transaction( - cur, + ctx.cur, "heartbeat_failure", [("worker_heartbeat_failed", str(worker_id), {"error": error})], ) - if result.worker_removed: - self._db.remove_worker_from_attr_cache(worker_id) return result def fail_heartbeats_batch( @@ -2262,11 +1445,11 @@ def fail_heartbeats_batch( for chunk_start in range(0, len(failures), chunk_size): chunk = failures[chunk_start : chunk_start + chunk_size] chunk_actions: list[tuple[str, str, dict[str, object]]] = [] - with self._db.transaction() as cur: + with self._stores.transact() as ctx: now_ms = Timestamp.now().epoch_ms() for snapshot, error in chunk: result = self._record_heartbeat_failure( - cur, + ctx, snapshot.worker_id, error, snapshot, @@ -2280,14 +1463,12 @@ def fail_heartbeats_batch( if result.worker_removed: removed_workers.append((snapshot.worker_id, snapshot.worker_address)) self._record_transaction( - cur, + ctx.cur, "heartbeat_failures_batch", chunk_actions, payload={"count": len(chunk_actions)}, ) - for worker_id, _ in removed_workers: - self._db.remove_worker_from_attr_cache(worker_id) return 
WorkerFailureBatchResult( tasks_to_kill=all_tasks_to_kill, task_kill_workers=all_task_kill_workers, @@ -2297,23 +1478,23 @@ def fail_heartbeats_batch( def mark_task_unschedulable(self, task_id: JobName, reason: str) -> TxResult: """Mark a task as unschedulable using the task transition engine.""" - with self._db.transaction() as cur: - row = cur.execute("SELECT job_id FROM tasks WHERE task_id = ?", (task_id.to_wire(),)).fetchone() - if row is None: + with self._stores.transact() as ctx: + job_id_wire = ctx.tasks.get_job_id(ctx.cur, task_id.to_wire()) + if job_id_wire is None: return TxResult() now_ms = Timestamp.now().epoch_ms() - _terminate_task( - cur, - self._db.endpoints, - task_id.to_wire(), - None, - job_pb2.TASK_STATE_UNSCHEDULABLE, - reason, - now_ms, + ctx.tasks.terminate( + ctx.cur, + TaskTermination( + task_id=task_id.to_wire(), + state=job_pb2.TASK_STATE_UNSCHEDULABLE, + now_ms=now_ms, + error=reason, + ), ) - self._recompute_job_state(cur, JobName.from_wire(str(row["job_id"]))) + ctx.jobs.recompute_state(ctx.cur, JobName.from_wire(job_id_wire)) self._record_transaction( - cur, "mark_task_unschedulable", [("task_unschedulable", task_id.to_wire(), {"reason": reason})] + ctx.cur, "mark_task_unschedulable", [("task_unschedulable", task_id.to_wire(), {"reason": reason})] ) return TxResult() @@ -2325,85 +1506,85 @@ def preempt_task(self, task_id: JobName, reason: str) -> TxResult: """ tasks_to_kill: set[JobName] = set() task_kill_workers: dict[JobName, WorkerId] = {} - with self._db.transaction() as cur: - row = cur.execute( - "SELECT t.task_id, t.job_id, t.state, t.current_attempt_id, " - "t.preemption_count, t.max_retries_preemption, " - "jc.res_cpu_millicores, jc.res_memory_bytes, jc.res_disk_bytes, jc.res_device_json " - f"FROM tasks t JOIN jobs j ON j.job_id = t.job_id {JOB_CONFIG_JOIN} " - "WHERE t.task_id = ?", - (task_id.to_wire(),), - ).fetchone() - if row is None: + with self._stores.transact() as ctx: + preempt_rows = ctx.tasks.query( + ctx.cur, + 
TaskFilter(task_ids=(task_id.to_wire(),)), + projection=TaskProjection.WITH_JOB_CONFIG, + ) + if not preempt_rows: return TxResult() + row = preempt_rows[0] - prior_state = int(row["state"]) + prior_state = row.state if prior_state not in ACTIVE_TASK_STATES: return TxResult() now_ms = Timestamp.now().epoch_ms() new_state, preemption_count = _resolve_task_failure_state( prior_state, - int(row["preemption_count"]), - int(row["max_retries_preemption"]), + row.preemption_count, + row.max_retries_preemption, job_pb2.TASK_STATE_PREEMPTED, ) # Fetch worker_id from the attempt for resource decommit. - attempt_row = cur.execute( - "SELECT worker_id FROM task_attempts WHERE task_id = ? AND attempt_id = ?", - (task_id.to_wire(), int(row["current_attempt_id"])), - ).fetchone() - attempt_worker_id = str(attempt_row["worker_id"]) if attempt_row and attempt_row["worker_id"] else None + attempt_worker_id = ctx.tasks.get_attempt_worker(ctx.cur, task_id.to_wire(), row.current_attempt_id) attempt_resources = None - if attempt_worker_id is not None: + if attempt_worker_id is not None and row.resources is not None: attempt_resources = resource_spec_from_scalars( - int(row["res_cpu_millicores"]), - int(row["res_memory_bytes"]), - int(row["res_disk_bytes"]), - row["res_device_json"], + row.resources.cpu_millicores, + row.resources.memory_bytes, + row.resources.disk_bytes, + row.resources.device_json, ) - _terminate_task( - cur, - self._db.endpoints, - task_id.to_wire(), - int(row["current_attempt_id"]), - new_state, - reason, - now_ms, - attempt_state=job_pb2.TASK_STATE_PREEMPTED, - worker_id=attempt_worker_id, - resources=attempt_resources, - preemption_count=preemption_count, + ctx.tasks.terminate( + ctx.cur, + TaskTermination( + task_id=task_id.to_wire(), + state=new_state, + now_ms=now_ms, + error=reason, + finalize=AttemptFinalizer.build( + task_id.to_wire(), + row.current_attempt_id, + job_pb2.TASK_STATE_PREEMPTED, + now_ms, + error=reason, + ), + worker_id=attempt_worker_id, + 
resources=attempt_resources, + preemption_count=preemption_count, + ), ) # Recompute job state and cascade if terminal - job_id = JobName.from_wire(str(row["job_id"])) - new_job_state = self._recompute_job_state(cur, job_id) + job_id = row.job_id + new_job_state = ctx.jobs.recompute_state(ctx.cur, job_id) if new_job_state is not None and new_job_state in TERMINAL_JOB_STATES: - cascade_kills, cascade_workers = _finalize_terminal_job( - cur, self._db.endpoints, job_id, new_job_state, now_ms - ) - tasks_to_kill.update(cascade_kills) - task_kill_workers.update(cascade_workers) + kill_result = _finalize_terminal_job(ctx, job_id, new_job_state, now_ms) + tasks_to_kill.update(kill_result.tasks_to_kill) + task_kill_workers.update(kill_result.task_kill_workers) elif new_state == job_pb2.TASK_STATE_PENDING: - policy = _resolve_preemption_policy(cur, job_id) + policy = ctx.jobs.get_preemption_policy(ctx.cur, job_id) if policy == job_pb2.JOB_PREEMPTION_POLICY_TERMINATE_CHILDREN: - child_kills, child_workers = _cascade_children( - cur, - self._db.endpoints, + child_result = ctx.jobs.cascade_children( + ctx.cur, + ctx.tasks, job_id, - now_ms, reason, + now_ms, exclude_reservation_holders=True, ) - tasks_to_kill.update(child_kills) - task_kill_workers.update(child_workers) + tasks_to_kill.update(child_result.tasks_to_kill) + task_kill_workers.update(child_result.task_kill_workers) if new_state == job_pb2.TASK_STATE_PREEMPTED: tasks_to_kill.add(task_id) - self._record_transaction(cur, "preempt_task", [("task_preempted", task_id.to_wire(), {"reason": reason})]) + self._record_transaction( + ctx.cur, "preempt_task", [("task_preempted", task_id.to_wire(), {"reason": reason})] + ) return TxResult(tasks_to_kill=tasks_to_kill, task_kill_workers=task_kill_workers) @@ -2419,36 +1600,29 @@ def cancel_tasks_for_timeout(self, task_ids: set[JobName], reason: str) -> TxRes """ if not task_ids: return TxResult() - with self._db.transaction() as cur: + with self._stores.transact() as ctx: wires 
= [tid.to_wire() for tid in task_ids] - placeholders = ",".join("?" for _ in wires) - rows = cur.execute( - f"SELECT t.task_id, t.job_id, t.current_worker_id AS worker_id, t.current_attempt_id, " - f"t.failure_count, j.is_reservation_holder, " - f"jc.res_cpu_millicores, jc.res_memory_bytes, jc.res_disk_bytes, jc.res_device_json, " - f"jc.has_coscheduling " - f"FROM tasks t JOIN jobs j ON j.job_id = t.job_id {JOB_CONFIG_JOIN} " - f"WHERE t.task_id IN ({placeholders}) AND t.state IN (?, ?)", - (*wires, *EXECUTING_TASK_STATES), - ).fetchall() + rows = ctx.tasks.query( + ctx.cur, + TaskFilter(task_ids=tuple(wires), states=EXECUTING_TASK_STATES), + projection=TaskProjection.WITH_JOB_CONFIG, + ) # -- Phase 1: read all state before any mutations. -- now_ms = Timestamp.now().epoch_ms() - job_row_cache: dict[str, dict] = {} - # Collect directly-timed-out task wires for dedup against siblings. + job_row_cache: dict[str, TaskDetailRow] = {} direct_task_wires: set[str] = set() - # Per-job list of siblings to cascade (collected across all timed-out tasks). 
- siblings_by_job: dict[str, list[_CoscheduledSibling]] = {} + siblings_by_job: dict[str, list[SiblingSnapshot]] = {} for row in rows: - task_id_wire = str(row["task_id"]) + task_id_wire = row.task_id.to_wire() direct_task_wires.add(task_id_wire) - job_id_wire = str(row["job_id"]) + job_id_wire = row.job_id.to_wire() if job_id_wire not in job_row_cache: - job_row_cache[job_id_wire] = dict(row) - has_cosched = bool(int(row["has_coscheduling"])) - tid = JobName.from_wire(task_id_wire) - siblings = _find_coscheduled_siblings(cur, JobName.from_wire(job_id_wire), tid, has_cosched) + job_row_cache[job_id_wire] = row + has_cosched = row.has_coscheduling + tid = row.task_id + siblings = ctx.tasks.find_coscheduled_siblings(ctx.cur, row.job_id, tid, has_cosched) if siblings: existing = siblings_by_job.get(job_id_wire, []) existing.extend(siblings) @@ -2459,7 +1633,7 @@ def cancel_tasks_for_timeout(self, task_ids: set[JobName], reason: str) -> TxRes # trigger tasks within the same job. for job_id_wire, siblings in siblings_by_job.items(): seen: set[str] = set() - deduped: list[_CoscheduledSibling] = [] + deduped: list[SiblingSnapshot] = [] for sib in siblings: if sib.task_id not in direct_task_wires and sib.task_id not in seen: seen.add(sib.task_id) @@ -2472,35 +1646,40 @@ def cancel_tasks_for_timeout(self, task_ids: set[JobName], reason: str) -> TxRes jobs_to_update: set[str] = set() for row in rows: - task_id_wire = str(row["task_id"]) - tid = JobName.from_wire(task_id_wire) - job_id_wire = str(row["job_id"]) - worker_id_str = row["worker_id"] + task_id_wire = row.task_id.to_wire() + tid = row.task_id + job_id_wire = row.job_id.to_wire() tasks_to_kill.add(tid) decommit_worker = None decommit_resources = None - if worker_id_str is not None: - task_kill_workers[tid] = WorkerId(str(worker_id_str)) - if not int(row["is_reservation_holder"]): - decommit_worker = str(worker_id_str) + if row.current_worker_id is not None: + task_kill_workers[tid] = 
WorkerId(str(row.current_worker_id)) + if not row.is_reservation_holder and row.resources is not None: + decommit_worker = str(row.current_worker_id) decommit_resources = resource_spec_from_scalars( - int(row["res_cpu_millicores"]), - int(row["res_memory_bytes"]), - int(row["res_disk_bytes"]), - row["res_device_json"], + row.resources.cpu_millicores, + row.resources.memory_bytes, + row.resources.disk_bytes, + row.resources.device_json, ) - attempt_id = row["current_attempt_id"] - _terminate_task( - cur, - self._db.endpoints, - task_id_wire, - int(attempt_id) if attempt_id is not None else None, - job_pb2.TASK_STATE_FAILED, - reason, - now_ms, - worker_id=decommit_worker, - resources=decommit_resources, - failure_count=int(row["failure_count"]) + 1, + ctx.tasks.terminate( + ctx.cur, + TaskTermination( + task_id=task_id_wire, + state=job_pb2.TASK_STATE_FAILED, + now_ms=now_ms, + error=reason, + finalize=AttemptFinalizer.build( + task_id_wire, + row.current_attempt_id, + job_pb2.TASK_STATE_FAILED, + now_ms, + error=reason, + ), + worker_id=decommit_worker, + resources=decommit_resources, + failure_count=row.failure_count + 1, + ), ) jobs_to_update.add(job_id_wire) @@ -2509,31 +1688,29 @@ def cancel_tasks_for_timeout(self, task_ids: set[JobName], reason: str) -> TxRes if not siblings: continue jc_row = job_row_cache[job_id_wire] + assert jc_row.resources is not None job_resources = resource_spec_from_scalars( - int(jc_row["res_cpu_millicores"]), - int(jc_row["res_memory_bytes"]), - int(jc_row["res_disk_bytes"]), - jc_row["res_device_json"], + jc_row.resources.cpu_millicores, + jc_row.resources.memory_bytes, + jc_row.resources.disk_bytes, + jc_row.resources.device_json, ) - # Pick the first direct-timeout task in this job as the "cause" for the error message. 
- cause_tid = next(JobName.from_wire(str(r["task_id"])) for r in rows if str(r["job_id"]) == job_id_wire) - cascade_kill, cascade_workers = _terminate_coscheduled_siblings( - cur, self._db.endpoints, siblings, cause_tid, job_resources, now_ms + cause_tid = next(r.task_id for r in rows if r.job_id.to_wire() == job_id_wire) + cascade_result = ctx.tasks.terminate_coscheduled_siblings( + ctx.cur, siblings, cause_tid, job_resources, now_ms ) - tasks_to_kill.update(cascade_kill) - task_kill_workers.update(cascade_workers) + tasks_to_kill.update(cascade_result.tasks_to_kill) + task_kill_workers.update(cascade_result.task_kill_workers) jobs_to_update.add(job_id_wire) for job_wire in jobs_to_update: - new_job_state = self._recompute_job_state(cur, JobName.from_wire(job_wire)) + new_job_state = ctx.jobs.recompute_state(ctx.cur, JobName.from_wire(job_wire)) if new_job_state in TERMINAL_JOB_STATES: - final_kill, final_workers = _finalize_terminal_job( - cur, self._db.endpoints, JobName.from_wire(job_wire), new_job_state, now_ms - ) - tasks_to_kill.update(final_kill) - task_kill_workers.update(final_workers) + final_result = _finalize_terminal_job(ctx, JobName.from_wire(job_wire), new_job_state, now_ms) + tasks_to_kill.update(final_result.tasks_to_kill) + task_kill_workers.update(final_result.task_kill_workers) self._record_transaction( - cur, + ctx.cur, "cancel_tasks_for_timeout", [("task_timeout", tid.to_wire(), {"reason": reason}) for tid in tasks_to_kill], ) @@ -2541,55 +1718,36 @@ def cancel_tasks_for_timeout(self, task_ids: set[JobName], reason: str) -> TxRes def drain_dispatch(self, worker_id: WorkerId) -> DispatchBatch | None: """Drain buffered dispatches and snapshot worker running tasks.""" - with self._db.transaction() as cur: - worker_row = cur.execute( - "SELECT worker_id, address FROM workers " "WHERE worker_id = ? 
AND active = 1 AND healthy = 1", - (str(worker_id),), - ).fetchone() + with self._stores.transact() as ctx: + worker_row = ctx.workers.get_healthy_active(ctx.cur, str(worker_id)) if worker_row is None: return None - dispatch_rows = cur.execute( - "SELECT id, kind, payload_proto, task_id FROM dispatch_queue WHERE worker_id = ? ORDER BY id ASC", - (str(worker_id),), - ).fetchall() - if dispatch_rows: - cur.execute("DELETE FROM dispatch_queue WHERE worker_id = ?", (str(worker_id),)) - running_rows_raw = cur.execute( - "SELECT t.task_id, t.current_attempt_id, t.job_id " - "FROM tasks t " - "WHERE t.current_worker_id = ? AND t.state IN (?, ?, ?) " - "ORDER BY t.task_id ASC", - (str(worker_id), *ACTIVE_TASK_STATES), - ).fetchall() - running_job_ids = {str(row["job_id"]) for row in running_rows_raw} - if running_job_ids: - holder_placeholders = ",".join("?" for _ in running_job_ids) - holder_rows = cur.execute( - f"SELECT job_id FROM jobs WHERE job_id IN ({holder_placeholders}) AND is_reservation_holder = 1", - tuple(running_job_ids), - ).fetchall() - holder_ids = {str(r["job_id"]) for r in holder_rows} - else: - holder_ids = set() - running_rows = [r for r in running_rows_raw if str(r["job_id"]) not in holder_ids] + drained = ctx.dispatch.drain_for_worker(ctx.cur, str(worker_id)) + running_rows_raw = ctx.tasks.query( + ctx.cur, + TaskFilter(worker_id=worker_id, states=ACTIVE_TASK_STATES), + ) + running_job_ids = {t.job_id.to_wire() for t in running_rows_raw} + holder_ids = ctx.jobs.get_reservation_holder_ids(ctx.cur, running_job_ids) + running_rows = [t for t in running_rows_raw if t.job_id.to_wire() not in holder_ids] tasks_to_run: list[job_pb2.RunTaskRequest] = [] tasks_to_kill: list[str] = [] - for row in dispatch_rows: - if str(row["kind"]) == "run" and row["payload_proto"] is not None: + for kind, payload_proto, task_id in drained: + if kind == "run" and payload_proto is not None: req = job_pb2.RunTaskRequest() - req.ParseFromString(bytes(row["payload_proto"])) + 
req.ParseFromString(bytes(payload_proto)) tasks_to_run.append(req) - elif row["task_id"] is not None: - tasks_to_kill.append(str(row["task_id"])) + elif task_id is not None: + tasks_to_kill.append(str(task_id)) return DispatchBatch( worker_id=WorkerId(str(worker_row["worker_id"])), worker_address=str(worker_row["address"]), running_tasks=[ RunningTaskEntry( - task_id=JobName.from_wire(str(row["task_id"])), - attempt_id=int(row["current_attempt_id"]), + task_id=t.task_id, + attempt_id=t.current_attempt_id, ) - for row in running_rows + for t in running_rows ], tasks_to_run=tasks_to_run, tasks_to_kill=tasks_to_kill, @@ -2632,23 +1790,10 @@ def drain_dispatch_all(self) -> list[DispatchBatch]: running_rows = [row for row in running_rows if str(row["job_id"]) not in reservation_holder_ids] # -- Phase 2: write lock only for dispatch_queue drain -- - placeholders = ",".join("?" for _ in worker_id_set) - with self._db.transaction() as cur: - dispatch_rows = cur.execute( - f"SELECT worker_id, id, kind, payload_proto, task_id FROM dispatch_queue " - f"WHERE worker_id IN ({placeholders}) ORDER BY id ASC", - tuple(worker_id_set), - ).fetchall() - if dispatch_rows: - cur.execute( - f"DELETE FROM dispatch_queue WHERE worker_id IN ({placeholders})", - tuple(worker_id_set), - ) + with self._stores.transact() as ctx: + dispatch_by_worker = ctx.dispatch.drain_for_workers(ctx.cur, list(worker_id_set)) # -- Phase 3: build results (pure Python, no lock) -- - dispatch_by_worker: dict[str, list[Any]] = defaultdict(list) - for row in dispatch_rows: - dispatch_by_worker[str(row["worker_id"])].append(row) running_by_worker: dict[str, list[Any]] = defaultdict(list) for row in running_rows: @@ -2662,13 +1807,13 @@ def drain_dispatch_all(self) -> list[DispatchBatch]: tasks_to_run: list[job_pb2.RunTaskRequest] = [] tasks_to_kill: list[str] = [] - for row in w_dispatch: - if str(row["kind"]) == "run" and row["payload_proto"] is not None: + for kind, payload_proto, task_id in w_dispatch: + if 
kind == "run" and payload_proto is not None: req = job_pb2.RunTaskRequest() - req.ParseFromString(bytes(row["payload_proto"])) + req.ParseFromString(bytes(payload_proto)) tasks_to_run.append(req) - elif row["task_id"] is not None: - tasks_to_kill.append(str(row["task_id"])) + elif task_id is not None: + tasks_to_kill.append(str(task_id)) batches.append( DispatchBatch( @@ -2690,12 +1835,12 @@ def drain_dispatch_all(self) -> list[DispatchBatch]: def requeue_dispatch(self, batch: DispatchBatch) -> None: """Re-queue drained dispatch payloads for later delivery.""" - with self._db.transaction() as cur: + with self._stores.transact() as ctx: now_ms = Timestamp.now().epoch_ms() for req in batch.tasks_to_run: - enqueue_run_dispatch(cur, str(batch.worker_id), req.SerializeToString(), now_ms) + ctx.dispatch.enqueue_run(ctx.cur, str(batch.worker_id), req.SerializeToString(), now_ms) for task_id in batch.tasks_to_kill: - enqueue_kill_dispatch(cur, str(batch.worker_id), task_id, now_ms) + ctx.dispatch.enqueue_kill(ctx.cur, str(batch.worker_id), task_id, now_ms) def remove_finished_job(self, job_id: JobName) -> bool: """Remove a finished job and its tasks from state. 
@@ -2709,11 +1854,10 @@ def remove_finished_job(self, job_id: JobName) -> bool: Returns: True if the job was removed, False if it doesn't exist or is not finished """ - with self._db.transaction() as cur: - row = cur.execute("SELECT state FROM jobs WHERE job_id = ?", (job_id.to_wire(),)).fetchone() - if row is None: + with self._stores.transact() as ctx: + state = ctx.jobs.get_state(ctx.cur, job_id) + if state is None: return False - state = int(row["state"]) if state not in ( job_pb2.JOB_STATE_SUCCEEDED, job_pb2.JOB_STATE_FAILED, @@ -2721,164 +1865,61 @@ def remove_finished_job(self, job_id: JobName) -> bool: job_pb2.JOB_STATE_UNSCHEDULABLE, ): return False - cur.execute("DELETE FROM jobs WHERE job_id = ?", (job_id.to_wire(),)) - self._record_transaction(cur, "remove_finished_job", [("job_removed", job_id.to_wire(), {"state": state})]) + ctx.jobs.delete_job(ctx.cur, job_id.to_wire()) + self._record_transaction( + ctx.cur, "remove_finished_job", [("job_removed", job_id.to_wire(), {"state": state})] + ) return True def remove_worker(self, worker_id: WorkerId) -> WorkerDetailRow | None: - with self._db.transaction() as cur: - row = cur.execute("SELECT * FROM workers WHERE worker_id = ?", (str(worker_id),)).fetchone() + with self._stores.transact() as ctx: + row = ctx.workers.get_row(ctx.cur, str(worker_id)) if row is None: return None - _remove_worker(cur, str(worker_id)) - self._record_transaction(cur, "remove_worker", [("worker_removed", str(worker_id), {})]) - self._db.remove_worker_from_attr_cache(worker_id) - return WORKER_DETAIL_PROJECTION.decode_one([row]) - - def _prune_per_worker_history( - self, - table: str, - retention: int, - order_by: str = "id DESC", - ) -> int: - """Trim a per-worker history table to *retention* rows per worker. - - Used by the background prune thread. The NOT IN subquery per worker is - expensive; batching all workers in a single transaction amortizes the - overhead versus running it on every assign/heartbeat. 
- """ - with self._db.transaction() as cur: - rows = cur.execute( - f"SELECT worker_id, COUNT(*) as cnt FROM {table} GROUP BY worker_id HAVING cnt > ?", - (retention,), - ).fetchall() - total_deleted = 0 - for row in rows: - wid = row["worker_id"] - cur.execute( - f"DELETE FROM {table} " - "WHERE worker_id = ? " - f"AND id NOT IN (" - f" SELECT id FROM {table} " - " WHERE worker_id = ? " - f" ORDER BY {order_by} LIMIT ?" - ")", - (wid, wid, retention), - ) - total_deleted += cur.rowcount - if total_deleted > 0: - logger.info("Pruned %d %s rows", total_deleted, table) - return total_deleted + ctx.workers.remove(ctx.cur, str(worker_id)) + self._record_transaction(ctx.cur, "remove_worker", [("worker_removed", str(worker_id), {})]) + return row def prune_worker_task_history(self) -> int: """Trim worker_task_history to WORKER_TASK_HISTORY_RETENTION rows per worker.""" - return self._prune_per_worker_history( - "worker_task_history", - WORKER_TASK_HISTORY_RETENTION, - order_by="assigned_at_ms DESC, id DESC", - ) + with self._stores.transact() as ctx: + return self._stores.workers.prune_task_history(ctx.cur, WORKER_TASK_HISTORY_RETENTION) def prune_worker_resource_history(self) -> int: """Trim worker_resource_history to WORKER_RESOURCE_HISTORY_RETENTION rows per worker.""" - return self._prune_per_worker_history( - "worker_resource_history", - WORKER_RESOURCE_HISTORY_RETENTION, - ) + with self._stores.transact() as ctx: + return self._stores.workers.prune_resource_history(ctx.cur, WORKER_RESOURCE_HISTORY_RETENTION) def prune_task_resource_history(self) -> int: - """Two-pass prune: - - 1. Evict all history for tasks that have been in a terminal state - longer than TASK_RESOURCE_HISTORY_TERMINAL_TTL. Dashboards read - peak memory from tasks.peak_memory_mb after termination; the - per-sample rows are dead weight and are ~85% of the table on - prod. - 2. 
Logarithmic downsampling for anything that remains: when a - (task, attempt) exceeds 2*N rows, thin the older half by deleting - every other row so older data grows exponentially sparser. - - Deletes are chunked so the writer lock releases between chunks. - """ now_ms = Timestamp.now().epoch_ms() ttl_cutoff_ms = now_ms - TASK_RESOURCE_HISTORY_TERMINAL_TTL.to_ms() - terminal_placeholders = ",".join("?" for _ in TERMINAL_TASK_STATES) + terminal_placeholders = sql_placeholders(len(TERMINAL_TASK_STATES)) - with self._db.read_snapshot() as snap: + with self._stores.read() as rctx: terminal_ids = [ str(r["task_id"]) - for r in snap.fetchall( + for r in rctx.cur.execute( f"SELECT task_id FROM tasks " f"WHERE state IN ({terminal_placeholders}) " f"AND finished_at_ms IS NOT NULL AND finished_at_ms < ?", (*TERMINAL_TASK_STATES, ttl_cutoff_ms), - ) + ).fetchall() ] evicted_terminal = 0 for chunk_start in range(0, len(terminal_ids), TASK_RESOURCE_HISTORY_DELETE_CHUNK): chunk = terminal_ids[chunk_start : chunk_start + TASK_RESOURCE_HISTORY_DELETE_CHUNK] - ph = ",".join("?" * len(chunk)) - with self._db.transaction() as cur: - cur.execute(f"DELETE FROM task_resource_history WHERE task_id IN ({ph})", tuple(chunk)) - evicted_terminal += cur.rowcount - - threshold = TASK_RESOURCE_HISTORY_RETENTION * 2 - with self._db.transaction() as cur: - overflows = cur.execute( - "SELECT task_id, attempt_id, COUNT(*) as cnt " - "FROM task_resource_history " - "GROUP BY task_id, attempt_id HAVING cnt > ?", - (threshold,), - ).fetchall() - ids_to_delete: list[int] = [] - for row in overflows: - tid, aid = row["task_id"], row["attempt_id"] - # Load all IDs into Python for index-based thinning. - # Bounded by 2*N + heartbeats-per-prune-cycle (~160 rows max at N=50). - all_ids = [ - r["id"] - for r in cur.execute( - "SELECT id FROM task_resource_history " "WHERE task_id = ? AND attempt_id = ? 
ORDER BY id ASC", - (tid, aid), - ).fetchall() - ] - # Keep the newest N rows untouched; thin the older portion by 2x. - older = all_ids[: len(all_ids) - TASK_RESOURCE_HISTORY_RETENTION] - ids_to_delete.extend(older[1::2]) - - total_deleted = 0 - for chunk_start in range(0, len(ids_to_delete), 900): - chunk = ids_to_delete[chunk_start : chunk_start + 900] - ph = ",".join("?" * len(chunk)) - cur.execute(f"DELETE FROM task_resource_history WHERE id IN ({ph})", tuple(chunk)) - total_deleted += cur.rowcount + ph = sql_placeholders(len(chunk)) + with self._stores.transact() as ctx: + ctx.cur.execute(f"DELETE FROM task_resource_history WHERE task_id IN ({ph})", tuple(chunk)) + evicted_terminal += ctx.cur.rowcount + + with self._stores.transact() as ctx: + total_deleted = self._stores.tasks.prune_task_resource_history(ctx.cur, TASK_RESOURCE_HISTORY_RETENTION) if evicted_terminal > 0: logger.info("Evicted %d task_resource_history rows (terminal TTL)", evicted_terminal) - if total_deleted > 0: - logger.info("Pruned %d task_resource_history rows (log downsampling)", total_deleted) return evicted_terminal + total_deleted - def _batch_delete( - self, - sql: str, - params: tuple[object, ...], - stopped: Callable[[], bool], - pause_between_s: float, - ) -> int: - """Delete rows in batches, sleeping between transactions. - - Returns the total number of rows deleted. - """ - total = 0 - while not stopped(): - with self._db.transaction() as cur: - batch = cur.execute(sql, params).rowcount - if batch == 0: - break - total += batch - time.sleep(pause_between_s) - return total - def prune_old_data( self, *, @@ -2908,52 +1949,42 @@ def prune_old_data( worker_cutoff_ms = now_ms - worker_retention.to_ms() txn_cutoff_ms = now_ms - txn_action_retention.to_ms() - terminal_states = tuple(TERMINAL_JOB_STATES) - placeholders = ",".join("?" * len(terminal_states)) - def _stopped() -> bool: return stop_event is not None and stop_event.is_set() - # 1. 
Jobs: one at a time (CASCADE to tasks → attempts, endpoints) + # 1. Jobs: one at a time (CASCADE to tasks -> attempts, endpoints) jobs_deleted = 0 while not _stopped(): - with self._db.read_snapshot() as snap: - row = snap.fetchone( - f"SELECT job_id FROM jobs WHERE state IN ({placeholders})" - " AND finished_at_ms IS NOT NULL AND finished_at_ms < ? LIMIT 1", - (*terminal_states, job_cutoff_ms), - ) - if row is None: + with self._stores.read() as read_ctx: + job_ids = self._stores.jobs.get_finished_jobs_before(read_ctx.cur, job_cutoff_ms) + if not job_ids: break - job_id = row["job_id"] - with self._db.transaction() as cur: + job_id = job_ids[0] + with self._stores.transact() as ctx: # Invalidate endpoint cache BEFORE the CASCADE so the registry # drops rows SQLite is about to delete for us. - self._db.endpoints.remove_by_job_ids(cur, [JobName.from_wire(str(job_id))]) - cur.execute("DELETE FROM jobs WHERE job_id = ?", (job_id,)) - self._record_transaction(cur, "prune_old_data", [("job_pruned", str(job_id), {})]) + ctx.endpoints.remove_by_job_ids(ctx.cur, [JobName.from_wire(job_id)]) + ctx.jobs.delete_job(ctx.cur, job_id) + self._record_transaction(ctx.cur, "prune_old_data", [("job_pruned", job_id, {})]) jobs_deleted += 1 time.sleep(pause_between_s) # 2. Workers: one at a time (CASCADE to attributes, task_history, resource_history) workers_deleted = 0 while not _stopped(): - with self._db.read_snapshot() as snap: - row = snap.fetchone( - "SELECT worker_id FROM workers WHERE (active = 0 OR healthy = 0) AND last_heartbeat_ms < ? 
LIMIT 1", - (worker_cutoff_ms,), - ) - if row is None: + with self._stores.read() as read_ctx: + worker_id = self._stores.workers.get_inactive_worker_before(read_ctx.cur, worker_cutoff_ms) + if worker_id is None: break - worker_id = row["worker_id"] - with self._db.transaction() as cur: - _remove_worker(cur, str(worker_id)) - self._record_transaction(cur, "prune_old_data", [("worker_pruned", str(worker_id), {})]) + with self._stores.transact() as ctx: + ctx.workers.remove(ctx.cur, worker_id) + self._record_transaction(ctx.cur, "prune_old_data", [("worker_pruned", worker_id, {})]) workers_deleted += 1 time.sleep(pause_between_s) # 3. txn_actions: batch of 1000 per transaction (no CASCADE) - txn_actions_deleted = self._batch_delete( + txn_actions_deleted = batch_delete( + self._db, "DELETE FROM txn_actions WHERE rowid IN " "(SELECT rowid FROM txn_actions WHERE created_at_ms < ? LIMIT 1000)", (txn_cutoff_ms,), @@ -2964,7 +1995,8 @@ def _stopped() -> bool: # 4. Task profiles: batch of 1000 per transaction profile_cutoff_ms = now_ms - profile_retention.to_ms() # 4a. Delete stale profiles by age. - profiles_deleted = self._batch_delete( + profiles_deleted = batch_delete( + self._db, "DELETE FROM profiles.task_profiles WHERE rowid IN " "(SELECT rowid FROM profiles.task_profiles WHERE captured_at_ms < ? LIMIT 1000)", (profile_cutoff_ms,), @@ -2972,7 +2004,8 @@ def _stopped() -> bool: pause_between_s, ) # 4b. Delete orphan profiles whose task no longer exists. - profiles_deleted += self._batch_delete( + profiles_deleted += batch_delete( + self._db, "DELETE FROM profiles.task_profiles WHERE rowid IN " "(SELECT p.rowid FROM profiles.task_profiles p" " LEFT JOIN tasks t ON p.task_id = t.task_id" @@ -3010,8 +2043,10 @@ def buffer_dispatch(self, worker_id: WorkerId, task_request: job_pb2.RunTaskRequ Called by the scheduling thread after committing resources via TaskAssignedEvent. The dispatch will be delivered when begin_heartbeat() drains the buffer. 
""" - with self._db.transaction() as cur: - enqueue_run_dispatch(cur, str(worker_id), task_request.SerializeToString(), Timestamp.now().epoch_ms()) + with self._stores.transact() as ctx: + ctx.dispatch.enqueue_run( + ctx.cur, str(worker_id), task_request.SerializeToString(), Timestamp.now().epoch_ms() + ) def buffer_kill(self, worker_id: WorkerId, task_id: str) -> None: """Buffer a task kill for the next heartbeat. @@ -3019,8 +2054,8 @@ def buffer_kill(self, worker_id: WorkerId, task_id: str) -> None: Called when a task needs to be terminated on a worker. The kill will be delivered when begin_heartbeat() drains the buffer. """ - with self._db.transaction() as cur: - enqueue_kill_dispatch(cur, str(worker_id), task_id, Timestamp.now().epoch_ms()) + with self._stores.transact() as ctx: + ctx.dispatch.enqueue_kill(ctx.cur, str(worker_id), task_id, Timestamp.now().epoch_ms()) def begin_heartbeat(self, worker_id: WorkerId) -> DispatchBatch | None: """Drain dispatch for a worker and snapshot expected running attempts.""" @@ -3182,47 +2217,12 @@ def add_endpoint(self, endpoint: EndpointRow) -> bool: Returns True if the endpoint was inserted, False if the task is already terminal (to prevent orphaned endpoints that would never be cleaned up). """ - with self._db.transaction() as cur: - return self._db.endpoints.add(cur, endpoint) + with self._stores.transact() as ctx: + return ctx.endpoints.add(ctx.cur, endpoint) def remove_endpoint(self, endpoint_id: str) -> EndpointRow | None: - with self._db.transaction() as cur: - return self._db.endpoints.remove(cur, endpoint_id) - - # --------------------------------------------------------------------- - # Test-only SQL mutation helpers - # --------------------------------------------------------------------- - - def set_worker_health_for_test(self, worker_id: WorkerId, healthy: bool) -> None: - """Test helper: set worker health in DB.""" - self._db.execute( - "UPDATE workers SET healthy = ?, consecutive_failures = ? 
WHERE worker_id = ?", - (1 if healthy else 0, 0 if healthy else 1, str(worker_id)), - ) - - def set_worker_attribute_for_test(self, worker_id: WorkerId, key: str, value: AttributeValue) -> None: - """Test helper: upsert one worker attribute in DB.""" - str_value = int_value = float_value = None - value_type = "str" - if isinstance(value.value, int): - value_type = "int" - int_value = int(value.value) - elif isinstance(value.value, float): - value_type = "float" - float_value = float(value.value) - else: - str_value = str(value.value) - - self._db.execute( - "INSERT INTO worker_attributes(worker_id, key, value_type, str_value, int_value, float_value) " - "VALUES (?, ?, ?, ?, ?, ?) " - "ON CONFLICT(worker_id, key) DO UPDATE SET " - "value_type=excluded.value_type, " - "str_value=excluded.str_value, " - "int_value=excluded.int_value, " - "float_value=excluded.float_value", - (str(worker_id), key, value_type, str_value, int_value, float_value), - ) + with self._stores.transact() as ctx: + return ctx.endpoints.remove(ctx.cur, endpoint_id) # ========================================================================= # Direct provider methods @@ -3240,7 +2240,7 @@ def drain_for_direct_provider( - Already ASSIGNED/BUILDING/RUNNING tasks with NULL worker_id -> running_tasks - Kill entries with NULL worker_id -> tasks_to_kill (deleted from queue) """ - with self._db.transaction() as cur: + with self._stores.transact() as ctx: now_ms = Timestamp.now().epoch_ms() newly_promoted: set[str] = set() @@ -3249,16 +2249,7 @@ def drain_for_direct_provider( if max_promotions <= 0: pending_rows = [] else: - pending_rows = cur.execute( - "SELECT t.task_id, t.job_id, t.current_attempt_id, j.num_tasks, j.is_reservation_holder, " - "jc.res_cpu_millicores, jc.res_memory_bytes, jc.res_disk_bytes, jc.res_device_json, " - "jc.entrypoint_json, jc.environment_json, jc.bundle_id, jc.ports_json, " - "jc.constraints_json, jc.task_image, jc.timeout_ms " - f"FROM tasks t JOIN jobs j ON j.job_id = 
t.job_id {JOB_CONFIG_JOIN} " - "WHERE t.state = ? AND j.is_reservation_holder = 0 " - "LIMIT ?", - (job_pb2.TASK_STATE_PENDING, max_promotions), - ).fetchall() + pending_rows = ctx.tasks.get_pending_for_direct_provider(ctx.cur, max_promotions) for row in pending_rows: task_id = str(row["task_id"]) @@ -3270,17 +2261,19 @@ def drain_for_direct_provider( row["res_device_json"], ) - _assign_task(cur, task_id, None, None, attempt_id, now_ms) + ctx.tasks.assign_direct( + ctx.cur, + DirectAssignment( + task_id=task_id, + attempt_id=attempt_id, + now_ms=now_ms, + ), + ) entrypoint = proto_from_json(str(row["entrypoint_json"]), job_pb2.RuntimeEntrypoint) - # Load inline workdir files from the job_workdir_files table. job_id_wire = str(row["job_id"]) - wf_rows = cur.execute( - "SELECT filename, data FROM job_workdir_files WHERE job_id = ?", - (job_id_wire,), - ).fetchall() - for wf_row in wf_rows: - entrypoint.workdir_files[wf_row["filename"]] = bytes(wf_row["data"]) + for fn, data in ctx.jobs.get_workdir_files(ctx.cur, job_id_wire).items(): + entrypoint.workdir_files[fn] = data run_req = job_pb2.RunTaskRequest( task_id=task_id, @@ -3302,31 +2295,21 @@ def drain_for_direct_provider( newly_promoted.add(task_id) # Snapshot already-running tasks with NULL worker_id (excluding newly promoted). - active_states = tuple(sorted(ACTIVE_TASK_STATES)) - placeholders = ",".join("?" 
* len(active_states)) - running_rows = cur.execute( - "SELECT t.task_id, t.current_attempt_id " - "FROM tasks t " - f"WHERE t.current_worker_id IS NULL AND t.state IN ({placeholders}) " - "ORDER BY t.task_id ASC", - active_states, - ).fetchall() + running_rows = ctx.tasks.query( + ctx.cur, + TaskFilter(worker_is_null=True, states=ACTIVE_TASK_STATES), + ) running_tasks = [ RunningTaskEntry( - task_id=JobName.from_wire(str(row["task_id"])), - attempt_id=int(row["current_attempt_id"]), + task_id=t.task_id, + attempt_id=t.current_attempt_id, ) - for row in running_rows - if str(row["task_id"]) not in newly_promoted + for t in running_rows + if t.task_id.to_wire() not in newly_promoted ] # Drain kill entries with NULL worker_id. - kill_rows = cur.execute( - "SELECT task_id FROM dispatch_queue WHERE worker_id IS NULL AND kind = 'kill'", - ).fetchall() - tasks_to_kill = [str(row["task_id"]) for row in kill_rows if row["task_id"] is not None] - if kill_rows: - cur.execute("DELETE FROM dispatch_queue WHERE worker_id IS NULL AND kind = 'kill'") + tasks_to_kill = ctx.dispatch.drain_direct_kills(ctx.cur) return DirectProviderBatch( tasks_to_run=tasks_to_run, @@ -3337,225 +2320,20 @@ def drain_for_direct_provider( def apply_direct_provider_updates(self, updates: list[TaskUpdate]) -> TxResult: """Apply a batch of task state updates from a KubernetesProvider. - Same state machine as apply_task_updates but without worker lookup, - health updates, or resource decommit (no committed resources tracked). + Delegates to _apply_task_transitions with a synthetic HeartbeatApplyRequest. + Direct-provider tasks have worker_id=None in their snapshots, so + TaskStore.terminate/requeue correctly skip resource decommit. 
""" - tasks_to_kill: set[JobName] = set() - task_kill_workers: dict[JobName, WorkerId] = {} - - with self._db.transaction() as cur: + if not updates: + return TxResult() + with self._stores.transact() as ctx: now_ms = Timestamp.now().epoch_ms() - cascaded_jobs: set[JobName] = set() - - for update in updates: - task_row = cur.execute("SELECT * FROM tasks WHERE task_id = ?", (update.task_id.to_wire(),)).fetchone() - if task_row is None: - continue - task = TASK_DETAIL_PROJECTION.decode_one([task_row]) - if task_row_is_finished(task) or update.new_state in ( - job_pb2.TASK_STATE_UNSPECIFIED, - job_pb2.TASK_STATE_PENDING, - ): - continue - if update.attempt_id != int(task_row["current_attempt_id"]): - stale = cur.execute( - "SELECT state FROM task_attempts WHERE task_id = ? AND attempt_id = ?", - (update.task_id.to_wire(), update.attempt_id), - ).fetchone() - if stale is not None and int(stale["state"]) not in TERMINAL_TASK_STATES: - logger.error( - "Stale attempt precondition violation: task=%s reported=%d current=%d stale_state=%s", - update.task_id, - update.attempt_id, - int(task_row["current_attempt_id"]), - int(stale["state"]), - ) - continue - attempt_row = cur.execute( - "SELECT * FROM task_attempts WHERE task_id = ? AND attempt_id = ?", - (update.task_id.to_wire(), update.attempt_id), - ).fetchone() - if attempt_row is None: - continue - # See _apply_task_transitions for rationale: the current attempt may - # be terminal while the task is retrying in PENDING; late reports - # must not revive it. 
- if int(attempt_row["state"]) in TERMINAL_TASK_STATES: - logger.debug( - "Dropping late update for terminal attempt: task=%s attempt=%d attempt_state=%d reported=%d", - update.task_id, - update.attempt_id, - int(attempt_row["state"]), - int(update.new_state), - ) - continue - - if update.resource_usage is not None: - ru = update.resource_usage - cur.execute( - "INSERT INTO task_resource_history" - "(task_id, attempt_id, cpu_millicores, memory_mb, disk_mb, memory_peak_mb, timestamp_ms) " - "VALUES (?, ?, ?, ?, ?, ?, ?)", - ( - update.task_id.to_wire(), - update.attempt_id, - ru.cpu_millicores, - ru.memory_mb, - ru.disk_mb, - ru.memory_peak_mb, - now_ms, - ), - ) - if update.container_id is not None: - cur.execute( - "UPDATE tasks SET container_id = ? WHERE task_id = ?", - (update.container_id, update.task_id.to_wire()), - ) - - terminal_ms: int | None = None - started_ms: int | None = None - task_state = int(task_row["state"]) - task_error = update.error - task_exit = update.exit_code - failure_count = int(task_row["failure_count"]) - preemption_count = int(task_row["preemption_count"]) - - if update.new_state == job_pb2.TASK_STATE_RUNNING: - started_ms = now_ms - task_state = job_pb2.TASK_STATE_RUNNING - elif update.new_state == job_pb2.TASK_STATE_BUILDING: - task_state = job_pb2.TASK_STATE_BUILDING - elif update.new_state in ( - job_pb2.TASK_STATE_FAILED, - job_pb2.TASK_STATE_WORKER_FAILED, - job_pb2.TASK_STATE_KILLED, - job_pb2.TASK_STATE_UNSCHEDULABLE, - job_pb2.TASK_STATE_SUCCEEDED, - ): - terminal_ms = now_ms - task_state = int(update.new_state) - if update.new_state == job_pb2.TASK_STATE_SUCCEEDED and task_exit is None: - task_exit = 0 - if update.new_state == job_pb2.TASK_STATE_UNSCHEDULABLE and task_error is None: - task_error = "Scheduling timeout exceeded" - if update.new_state == job_pb2.TASK_STATE_FAILED: - failure_count += 1 - if ( - update.new_state == job_pb2.TASK_STATE_WORKER_FAILED - and int(task_row["state"]) in EXECUTING_TASK_STATES - ): - 
preemption_count += 1 - # WORKER_FAILED while still ASSIGNED -> retry immediately as PENDING - if ( - update.new_state == job_pb2.TASK_STATE_WORKER_FAILED - and int(task_row["state"]) == job_pb2.TASK_STATE_ASSIGNED - ): - task_state = job_pb2.TASK_STATE_PENDING - terminal_ms = None - if update.new_state == job_pb2.TASK_STATE_FAILED and failure_count <= int( - task_row["max_retries_failure"] - ): - task_state = job_pb2.TASK_STATE_PENDING - terminal_ms = None - if ( - update.new_state == job_pb2.TASK_STATE_WORKER_FAILED - and preemption_count <= int(task_row["max_retries_preemption"]) - and int(task_row["state"]) in EXECUTING_TASK_STATES - ): - task_state = job_pb2.TASK_STATE_PENDING - terminal_ms = None - - # An attempt is terminal whenever the update itself is terminal, even - # if the TASK rolls back to PENDING for a retry. - attempt_terminal_ms = now_ms if int(update.new_state) in TERMINAL_TASK_STATES else None - - cur.execute( - "UPDATE task_attempts SET state = ?, started_at_ms = COALESCE(started_at_ms, ?), " - "finished_at_ms = COALESCE(finished_at_ms, ?), exit_code = COALESCE(?, exit_code), " - "error = COALESCE(?, error) WHERE task_id = ? AND attempt_id = ?", - ( - int(update.new_state), - started_ms, - attempt_terminal_ms, - task_exit, - update.error, - update.task_id.to_wire(), - update.attempt_id, - ), - ) - if task_state in ACTIVE_TASK_STATES: - cur.execute( - "UPDATE tasks SET state = ?, error = COALESCE(?, error), exit_code = COALESCE(?, exit_code), " - "started_at_ms = COALESCE(started_at_ms, ?), finished_at_ms = ?, " - "failure_count = ?, preemption_count = ? 
" - "WHERE task_id = ?", - ( - task_state, - task_error, - task_exit, - started_ms, - terminal_ms, - failure_count, - preemption_count, - update.task_id.to_wire(), - ), - ) - else: - cur.execute( - "UPDATE tasks SET state = ?, error = COALESCE(?, error), exit_code = COALESCE(?, exit_code), " - "started_at_ms = COALESCE(started_at_ms, ?), finished_at_ms = ?, " - "failure_count = ?, preemption_count = ?, " - "current_worker_id = NULL, current_worker_address = NULL " - "WHERE task_id = ?", - ( - task_state, - task_error, - task_exit, - started_ms, - terminal_ms, - failure_count, - preemption_count, - update.task_id.to_wire(), - ), - ) - jc_row = cur.execute("SELECT * FROM job_config WHERE job_id = ?", (task.job_id.to_wire(),)).fetchone() - - if update.new_state in TERMINAL_TASK_STATES: - delete_task_endpoints(cur, self._db.endpoints, update.task_id.to_wire()) - - # Coscheduled sibling cascade. - if jc_row is not None and task_state in FAILURE_TASK_STATES: - has_cosched = bool(int(jc_row["has_coscheduling"])) - siblings = _find_coscheduled_siblings(cur, task.job_id, update.task_id, has_cosched) - job_resources = resource_spec_from_scalars( - int(jc_row["res_cpu_millicores"]), - int(jc_row["res_memory_bytes"]), - int(jc_row["res_disk_bytes"]), - jc_row["res_device_json"], - ) - cascade_kill, cascade_workers = _terminate_coscheduled_siblings( - cur, self._db.endpoints, siblings, update.task_id, job_resources, now_ms - ) - tasks_to_kill.update(cascade_kill) - task_kill_workers.update(cascade_workers) - - if task.job_id not in cascaded_jobs: - new_job_state = self._recompute_job_state(cur, task.job_id) - if new_job_state in TERMINAL_JOB_STATES: - final_tasks_to_kill, final_task_kill_workers = _finalize_terminal_job( - cur, self._db.endpoints, task.job_id, new_job_state, now_ms - ) - tasks_to_kill.update(final_tasks_to_kill) - task_kill_workers.update(final_task_kill_workers) - cascaded_jobs.add(task.job_id) - - if tasks_to_kill or cascaded_jobs: - actions: list[tuple[str, 
str, dict[str, object]]] = [("direct_provider_updates_applied", "direct", {})] - for job_id in cascaded_jobs: - actions.append(("job_terminated", job_id.to_wire(), {})) - self._record_transaction(cur, "apply_direct_provider_updates", actions) - - return TxResult(tasks_to_kill=tasks_to_kill, task_kill_workers=task_kill_workers) + req = HeartbeatApplyRequest( + worker_id=WorkerId("direct"), + worker_resource_snapshot=None, + updates=updates, + ) + return self._apply_task_transitions(ctx, req, now_ms) def buffer_direct_kill(self, task_id: str) -> None: """Buffer a kill request for a direct-provider task. @@ -3563,50 +2341,5 @@ def buffer_direct_kill(self, task_id: str) -> None: Inserts a kill entry into dispatch_queue with worker_id=NULL. Drained by drain_for_direct_provider(). """ - with self._db.transaction() as cur: - enqueue_kill_dispatch(cur, None, task_id, Timestamp.now().epoch_ms()) - - # ========================================================================= - # Test helpers - # ========================================================================= - - def set_worker_consecutive_failures_for_test(self, worker_id: WorkerId, consecutive_failures: int) -> None: - """Test helper: set worker consecutive failure count in DB.""" - self._db.execute( - "UPDATE workers SET consecutive_failures = ? WHERE worker_id = ?", - (consecutive_failures, str(worker_id)), - ) - - def set_task_state_for_test( - self, - task_id: JobName, - state: int, - *, - error: str | None = None, - exit_code: int | None = None, - ) -> None: - """Test helper: set task state directly in DB.""" - if state in ACTIVE_TASK_STATES: - self._db.execute( - "UPDATE tasks SET state = ?, error = ?, exit_code = ? 
WHERE task_id = ?", - (state, error, exit_code, task_id.to_wire()), - ) - else: - self._db.execute( - "UPDATE tasks SET state = ?, error = ?, exit_code = ?, " - "current_worker_id = NULL, current_worker_address = NULL WHERE task_id = ?", - (state, error, exit_code, task_id.to_wire()), - ) - - def create_attempt_for_test(self, task_id: JobName, worker_id: WorkerId) -> int: - """Test helper: append a new task_attempt without finalizing prior attempt.""" - task = self._db.fetchone("SELECT current_attempt_id FROM tasks WHERE task_id = ?", (task_id.to_wire(),)) - if task is None: - raise ValueError(f"unknown task: {task_id}") - worker_row = self._db.fetchone("SELECT address FROM workers WHERE worker_id = ?", (str(worker_id),)) - worker_address = str(worker_row["address"]) if worker_row is not None else str(worker_id) - next_attempt_id = int(task["current_attempt_id"]) + 1 - now_ms = Timestamp.now().epoch_ms() - with self._db.transaction() as cur: - _assign_task(cur, task_id.to_wire(), str(worker_id), worker_address, next_attempt_id, now_ms) - return next_attempt_id + with self._stores.transact() as ctx: + ctx.dispatch.enqueue_kill(ctx.cur, None, task_id, Timestamp.now().epoch_ms()) diff --git a/lib/iris/tests/cluster/conftest.py b/lib/iris/tests/cluster/conftest.py index 013013e710..327eae1ac8 100644 --- a/lib/iris/tests/cluster/conftest.py +++ b/lib/iris/tests/cluster/conftest.py @@ -23,6 +23,7 @@ WORKER_DETAIL_PROJECTION, ) from iris.cluster.controller.service import ControllerServiceImpl +from iris.cluster.controller.store import ControllerStores from iris.cluster.controller.transitions import ( Assignment, ControllerTransitions, @@ -391,7 +392,8 @@ def _drive_gcp(self, task_id: JobName, new_state: int) -> None: def _make_k8s_harness(tmp_path) -> ServiceTestHarness: db = ControllerDB(db_dir=tmp_path / "k8s_db") - state = ControllerTransitions(db=db) + stores = ControllerStores.from_db(db) + state = ControllerTransitions(stores=stores) k8s = InMemoryK8sService() 
k8s.add_node_pool( @@ -411,9 +413,10 @@ def _make_k8s_harness(tmp_path) -> ServiceTestHarness: ctrl.has_direct_provider = True ctrl.provider = k8s_provider + k8s_stores = ControllerStores.from_db(db) service = ControllerServiceImpl( state, - db, + k8s_stores, controller=ctrl, bundle_store=BundleStore(storage_dir=str(tmp_path / "k8s_bundles")), log_service=LogServiceImpl(), @@ -431,14 +434,15 @@ def _make_k8s_harness(tmp_path) -> ServiceTestHarness: def _make_gcp_harness(tmp_path) -> ServiceTestHarness: db = ControllerDB(db_dir=tmp_path / "gcp_db") - state = ControllerTransitions(db=db) + stores = ControllerStores.from_db(db) + state = ControllerTransitions(stores=stores) ctrl = _HarnessController() ctrl.has_direct_provider = False service = ControllerServiceImpl( state, - db, + stores, controller=ctrl, bundle_store=BundleStore(storage_dir=str(tmp_path / "gcp_bundles")), log_service=LogServiceImpl(), diff --git a/lib/iris/tests/cluster/controller/_testing.py b/lib/iris/tests/cluster/controller/_testing.py new file mode 100644 index 0000000000..cd1079f790 --- /dev/null +++ b/lib/iris/tests/cluster/controller/_testing.py @@ -0,0 +1,103 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Test-only DB mutation helpers for controller tests. + +These functions reach directly into the SQLite DB to set state that is +difficult or impossible to reach through normal controller transitions. +They live here (not in production code) to keep transitions.py clean. 
+""" + +from iris.cluster.constraints import AttributeValue +from iris.cluster.controller.db import ControllerDB +from iris.cluster.controller.schema import ACTIVE_TASK_STATES +from iris.cluster.controller.store import ControllerStores, WorkerAssignment +from iris.cluster.types import JobName, WorkerId +from rigging.timing import Timestamp + + +def set_worker_health(db: ControllerDB, worker_id: WorkerId, *, healthy: bool) -> None: + """Set worker health directly in DB.""" + db.execute( + "UPDATE workers SET healthy = ?, consecutive_failures = ? WHERE worker_id = ?", + (1 if healthy else 0, 0 if healthy else 1, str(worker_id)), + ) + + +def set_worker_attribute(db: ControllerDB, worker_id: WorkerId, key: str, value: AttributeValue) -> None: + """Upsert one worker attribute directly in DB.""" + str_value = int_value = float_value = None + value_type = "str" + if isinstance(value.value, int): + value_type = "int" + int_value = int(value.value) + elif isinstance(value.value, float): + value_type = "float" + float_value = float(value.value) + else: + str_value = str(value.value) + + db.execute( + "INSERT INTO worker_attributes(worker_id, key, value_type, str_value, int_value, float_value) " + "VALUES (?, ?, ?, ?, ?, ?) " + "ON CONFLICT(worker_id, key) DO UPDATE SET " + "value_type=excluded.value_type, " + "str_value=excluded.str_value, " + "int_value=excluded.int_value, " + "float_value=excluded.float_value", + (str(worker_id), key, value_type, str_value, int_value, float_value), + ) + + +def set_worker_consecutive_failures(db: ControllerDB, worker_id: WorkerId, consecutive_failures: int) -> None: + """Set worker consecutive failure count directly in DB.""" + db.execute( + "UPDATE workers SET consecutive_failures = ? 
WHERE worker_id = ?", + (consecutive_failures, str(worker_id)), + ) + + +def set_task_state( + db: ControllerDB, + task_id: JobName, + state: int, + *, + error: str | None = None, + exit_code: int | None = None, +) -> None: + """Set task state directly in DB.""" + if state in ACTIVE_TASK_STATES: + db.execute( + "UPDATE tasks SET state = ?, error = ?, exit_code = ? WHERE task_id = ?", + (state, error, exit_code, task_id.to_wire()), + ) + else: + db.execute( + "UPDATE tasks SET state = ?, error = ?, exit_code = ?, " + "current_worker_id = NULL, current_worker_address = NULL WHERE task_id = ?", + (state, error, exit_code, task_id.to_wire()), + ) + + +def create_attempt(stores: ControllerStores, task_id: JobName, worker_id: WorkerId) -> int: + """Append a new task_attempt without finalizing the prior attempt.""" + db = stores.db + task = db.fetchone("SELECT current_attempt_id FROM tasks WHERE task_id = ?", (task_id.to_wire(),)) + if task is None: + raise ValueError(f"unknown task: {task_id}") + worker_row = db.fetchone("SELECT address FROM workers WHERE worker_id = ?", (str(worker_id),)) + worker_address = str(worker_row["address"]) if worker_row is not None else str(worker_id) + next_attempt_id = int(task["current_attempt_id"]) + 1 + now_ms = Timestamp.now().epoch_ms() + with stores.transact() as ctx: + ctx.tasks.assign_to_worker( + ctx.cur, + WorkerAssignment( + task_id=task_id.to_wire(), + attempt_id=next_attempt_id, + worker_id=str(worker_id), + worker_address=worker_address, + now_ms=now_ms, + ), + ) + return next_attempt_id diff --git a/lib/iris/tests/cluster/controller/conftest.py b/lib/iris/tests/cluster/controller/conftest.py index a06c6190ea..92775dff23 100644 --- a/lib/iris/tests/cluster/controller/conftest.py +++ b/lib/iris/tests/cluster/controller/conftest.py @@ -28,18 +28,18 @@ ) from iris.cluster.controller.autoscaler import Autoscaler from iris.cluster.controller.autoscaler.models import DemandEntry -from iris.cluster.controller.db import ( - 
ACTIVE_TASK_STATES, - ControllerDB, +from iris.cluster.controller.db import ControllerDB +from iris.cluster.controller.store import ( + ControllerStores, _decode_attribute_rows, task_row_can_be_scheduled, task_row_is_finished, ) from iris.cluster.controller.schema import ( + ACTIVE_TASK_STATES, ATTEMPT_PROJECTION, JOB_CONFIG_JOIN, JOB_DETAIL_PROJECTION, - JOB_SCHEDULING_PROJECTION, TASK_DETAIL_PROJECTION, WORKER_ROW_PROJECTION, JobDetailRow, @@ -69,6 +69,7 @@ from iris.rpc import controller_pb2 from iris.time_proto import duration_to_proto from rigging.timing import Duration, Timestamp +from tests.cluster.controller._testing import set_task_state from tests.cluster.providers.conftest import make_mock_platform check_task_can_be_scheduled = task_row_can_be_scheduled @@ -150,7 +151,7 @@ def controller_service(state, log_service, mock_controller, tmp_path) -> Control """ControllerServiceImpl with fresh DB, log service, and mock controller.""" return ControllerServiceImpl( state, - state._db, + state._stores, controller=mock_controller, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), log_service=log_service, @@ -168,7 +169,8 @@ def make_controller_state(**kwargs): tmp = Path(tempfile.mkdtemp(prefix="iris_test_")) try: db = ControllerDB(db_dir=tmp) - yield ControllerTransitions(db=db, **kwargs) + stores = ControllerStores.from_db(db) + yield ControllerTransitions(stores=stores, **kwargs) db.close() finally: shutil.rmtree(tmp, ignore_errors=True) @@ -244,9 +246,9 @@ def query_job(state: ControllerTransitions, job_id: JobName) -> JobDetailRow | N def query_job_row(state: ControllerTransitions, job_id: JobName): - """Query a job as a JobSchedulingRow (scheduling projection with resources/constraints).""" + """Query a job as a JobDetailRow.""" with state._db.snapshot() as q: - return JOB_SCHEDULING_PROJECTION.decode_one( + return JOB_DETAIL_PROJECTION.decode_one( q.fetchall( f"SELECT j.*, jc.* FROM jobs j {JOB_CONFIG_JOIN} WHERE j.job_id = ? 
LIMIT 1", (job_id.to_wire(),), @@ -545,12 +547,7 @@ def transition_task( current_attempt = task.attempts[-1] if task.attempts else None worker_id = current_attempt.worker_id if current_attempt is not None else task.current_worker_id if worker_id is None: - state.set_task_state_for_test( - task_id, - new_state, - error=error, - exit_code=exit_code, - ) + set_task_state(state._db, task_id, new_state, error=error, exit_code=exit_code) return state return state.apply_task_updates( HeartbeatApplyRequest( diff --git a/lib/iris/tests/cluster/controller/test_api_keys.py b/lib/iris/tests/cluster/controller/test_api_keys.py index 8b9586f2c9..f9cfcd18ba 100644 --- a/lib/iris/tests/cluster/controller/test_api_keys.py +++ b/lib/iris/tests/cluster/controller/test_api_keys.py @@ -23,6 +23,7 @@ from rigging.timing import Timestamp from iris.cluster.controller.db import ControllerDB from iris.cluster.controller.service import ControllerServiceImpl +from iris.cluster.controller.store import ControllerStores from iris.cluster.controller.transitions import ControllerTransitions from iris.log_server.server import LogServiceImpl from iris.rpc import config_pb2 @@ -39,7 +40,8 @@ def db(tmp_path): def _make_service(db, auth=None): """Create a ControllerServiceImpl with minimal dependencies for API key tests.""" - state = ControllerTransitions(db=db) + stores = ControllerStores.from_db(db) + state = ControllerTransitions(stores=stores) controller_mock = Mock() controller_mock.wake = Mock() @@ -51,7 +53,7 @@ def _make_service(db, auth=None): return ControllerServiceImpl( state, - db, + stores, controller=controller_mock, bundle_store=BundleStore(storage_dir=str(db.db_path.parent / "bundles")), log_service=LogServiceImpl(), diff --git a/lib/iris/tests/cluster/controller/test_auth.py b/lib/iris/tests/cluster/controller/test_auth.py index 1e40c58d4d..80c1024c8e 100644 --- a/lib/iris/tests/cluster/controller/test_auth.py +++ b/lib/iris/tests/cluster/controller/test_auth.py @@ -23,6 +23,7 @@ 
from iris.cluster.controller.dashboard import ControllerDashboard from iris.cluster.controller.db import ControllerDB from iris.cluster.controller.service import ControllerServiceImpl +from iris.cluster.controller.store import ControllerStores from iris.cluster.controller.transitions import ControllerTransitions from iris.rpc.auth import SESSION_COOKIE, StaticTokenVerifier, hash_token, resolve_auth from rigging.timing import Timestamp @@ -44,7 +45,8 @@ def db(tmp_path): @pytest.fixture def state(db, tmp_path): - s = ControllerTransitions(db=db) + stores = ControllerStores.from_db(db) + s = ControllerTransitions(stores=stores) yield s @@ -57,7 +59,7 @@ def service(state, tmp_path): controller_mock.has_direct_provider = False return ControllerServiceImpl( state, - state._db, + state._stores, controller=controller_mock, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), log_service=LogServiceImpl(), diff --git a/lib/iris/tests/cluster/controller/test_budgets.py b/lib/iris/tests/cluster/controller/test_budgets.py index 7bc332a8d1..ef9836aa34 100644 --- a/lib/iris/tests/cluster/controller/test_budgets.py +++ b/lib/iris/tests/cluster/controller/test_budgets.py @@ -6,7 +6,8 @@ from pathlib import Path import pytest -from iris.cluster.controller.db import ControllerDB, UserBudget +from iris.cluster.controller.db import ControllerDB +from iris.cluster.controller.store import UserBudget from rigging.timing import Timestamp diff --git a/lib/iris/tests/cluster/controller/test_dashboard.py b/lib/iris/tests/cluster/controller/test_dashboard.py index 57fe8b3e6f..1e99681d0a 100644 --- a/lib/iris/tests/cluster/controller/test_dashboard.py +++ b/lib/iris/tests/cluster/controller/test_dashboard.py @@ -18,9 +18,6 @@ from iris.cluster.controller.codec import constraints_from_json, resource_spec_from_scalars from iris.cluster.controller.dashboard import ControllerDashboard from iris.log_server.server import LogServiceImpl -from iris.cluster.controller.db import ( - 
healthy_active_workers_with_attributes, -) from iris.cluster.controller.schema import ( JOB_CONFIG_JOIN, JOB_DETAIL_PROJECTION, @@ -142,7 +139,10 @@ def _get_job_scheduling_diagnostics(job_wire_id): return None req = JobRequirements( resources=resource_spec_from_scalars( - job.res_cpu_millicores, job.res_memory_bytes, job.res_disk_bytes, job.res_device_json + job.resources.cpu_millicores, + job.resources.memory_bytes, + job.resources.disk_bytes, + job.resources.device_json, ), constraints=constraints_from_json(job.constraints_json), is_coscheduled=job.has_coscheduling, @@ -150,7 +150,8 @@ def _get_job_scheduling_diagnostics(job_wire_id): ) tasks = _query_tasks_with_attempts(state, job.job_id) schedulable_task_id = next((t.task_id for t in tasks if check_task_can_be_scheduled(t)), None) - workers = healthy_active_workers_with_attributes(state._db) + with state._stores.read() as ctx: + workers = state._stores.workers.healthy_active_with_attributes(ctx.cur) context = _create_scheduling_context(workers) return scheduler.get_job_scheduling_diagnostics(req, context, schedulable_task_id, num_tasks=len(tasks)) @@ -170,7 +171,7 @@ def service(state, scheduler, tmp_path): log_service = LogServiceImpl() return ControllerServiceImpl( state, - state._db, + state._stores, controller=controller_mock, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), log_service=log_service, @@ -190,7 +191,7 @@ def service_with_autoscaler(state, scheduler, mock_autoscaler, tmp_path): log_service = LogServiceImpl() return ControllerServiceImpl( state, - state._db, + state._stores, controller=controller_mock, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), log_service=log_service, @@ -1046,7 +1047,7 @@ def test_auth_config_kubernetes_provider_kind(state, scheduler, tmp_path): log_service = LogServiceImpl() svc = ControllerServiceImpl( state, - state._db, + state._stores, controller=controller_mock, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), 
log_service=log_service, @@ -1078,7 +1079,7 @@ def _make_k8s_dashboard_client(state, scheduler, tmp_path): log_service = LogServiceImpl() svc = ControllerServiceImpl( state, - state._db, + state._stores, controller=controller_mock, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), log_service=log_service, diff --git a/lib/iris/tests/cluster/controller/test_db.py b/lib/iris/tests/cluster/controller/test_db.py index 130685a93e..00ecccdcd7 100644 --- a/lib/iris/tests/cluster/controller/test_db.py +++ b/lib/iris/tests/cluster/controller/test_db.py @@ -13,6 +13,7 @@ Row, TransactionCursor, ) +from iris.cluster.controller.store import ControllerStores @pytest.fixture @@ -288,10 +289,10 @@ def test_replace_from_reattaches_profiles_db(tmp_path: Path) -> None: """replace_from() must re-attach the profiles DB so profile tables remain accessible.""" from rigging.timing import Timestamp - from iris.cluster.controller.db import get_task_profiles, insert_task_profile - db = ControllerDB(db_dir=tmp_path) - insert_task_profile(db, "task-1", b"profile-data", Timestamp.now()) + stores = ControllerStores.from_db(db) + with stores.transact() as ctx: + stores.tasks.insert_task_profile(ctx.cur, "task-1", b"profile-data", Timestamp.now()) backup_dir = tmp_path / "backup" backup_dir.mkdir() @@ -312,7 +313,8 @@ def test_replace_from_reattaches_profiles_db(tmp_path: Path) -> None: db.replace_from(str(backup_dir)) - profiles = get_task_profiles(db, "task-1") + with stores.read() as ctx: + profiles = stores.tasks.get_task_profiles(ctx.cur, "task-1") assert len(profiles) == 1 db.close() @@ -494,3 +496,42 @@ def test_backfill_attempt_finished_at_migration(tmp_path: Path) -> None: assert out[("/u/E", 0)] == 7200 assert out[("/u/E", 1)] is None conn.close() + + +def test_controller_stores_read_scope(tmp_path: Path) -> None: + db = ControllerDB(db_dir=tmp_path) + stores = ControllerStores.from_db(db) + with stores.read() as ctx: + rows = ctx.cur.execute("SELECT 1 AS one").fetchall() 
+ assert [tuple(r) for r in rows] == [(1,)] + # Read-only store method works inside a read scope. + assert ctx.endpoints.query() == [] + + +def test_controller_stores_read_rejects_writes(tmp_path: Path) -> None: + """A read() scope must refuse any mutation of the underlying database. + + The read pool hands out read-only SQLite connections, so any write issued + through ``ctx.cur`` fails with ``OperationalError`` before it can reach + disk. This is what lets callers treat ``read()`` as a strictly non-mutating + context without auditing every downstream store method they pass the + cursor into. + """ + db = ControllerDB(db_dir=tmp_path) + stores = ControllerStores.from_db(db) + + # Seed a row via a real write transaction so we have something to mutate. + with stores.transact() as ctx: + ctx.cur.execute("CREATE TABLE _rollback_probe(id INTEGER PRIMARY KEY, v TEXT)") + ctx.cur.execute("INSERT INTO _rollback_probe(id, v) VALUES (1, 'committed')") + + with stores.read() as rctx: + with pytest.raises(sqlite3.OperationalError, match="readonly"): + rctx.cur.execute("INSERT INTO _rollback_probe(id, v) VALUES (2, 'scratch')") + with pytest.raises(sqlite3.OperationalError, match="readonly"): + rctx.cur.execute("UPDATE _rollback_probe SET v = 'dirty' WHERE id = 1") + + # Confirm the seed row is untouched. 
+ with stores.read() as rctx: + rows = [tuple(r) for r in rctx.cur.execute("SELECT id, v FROM _rollback_probe ORDER BY id").fetchall()] + assert rows == [(1, "committed")] diff --git a/lib/iris/tests/cluster/controller/test_endpoint_registry.py b/lib/iris/tests/cluster/controller/test_endpoint_registry.py index e3e2b73e0b..469e9e4875 100644 --- a/lib/iris/tests/cluster/controller/test_endpoint_registry.py +++ b/lib/iris/tests/cluster/controller/test_endpoint_registry.py @@ -1,7 +1,7 @@ # Copyright The Marin Authors # SPDX-License-Identifier: Apache-2.0 -"""Tests for EndpointRegistry — the in-memory cache over the ``endpoints`` table.""" +"""Tests for EndpointStore — the in-memory cache over the ``endpoints`` table.""" from __future__ import annotations @@ -9,9 +9,8 @@ import pytest -from iris.cluster.controller.db import EndpointQuery -from iris.cluster.controller.endpoint_registry import EndpointRegistry from iris.cluster.controller.schema import ENDPOINT_PROJECTION, EndpointRow +from iris.cluster.controller.store import EndpointQuery, EndpointStore from iris.cluster.types import JobName from iris.rpc import job_pb2 from rigging.timing import Timestamp @@ -69,13 +68,15 @@ def _make_row(endpoint_id: str, name: str, task_id: JobName, *, address: str = " # --- Load / add / remove ---------------------------------------------------- -def test_registry_loads_existing_rows_on_startup(state): - """On construction, the registry should contain every row in the ``endpoints`` table.""" +def test_store_loads_existing_rows_on_startup(state): + """On construction, the store should contain every row in the ``endpoints`` table.""" tasks = submit_job(state, "j", make_job_request("j")) with state._db.transaction() as cur: - assert state._db.endpoints.add(cur, _make_row("e1", "svc", tasks[0].task_id)) + assert state._stores.endpoints.add(cur, _make_row("e1", "svc", tasks[0].task_id)) - fresh = EndpointRegistry(state._db) + fresh = EndpointStore() + with state._db.read_snapshot() as 
snap: + fresh._load_all(snap) rows = fresh.query() assert [r.endpoint_id for r in rows] == ["e1"] @@ -85,12 +86,12 @@ def test_add_updates_memory_after_commit(state): t = tasks[0].task_id with state._db.transaction() as cur: - assert state._db.endpoints.add(cur, _make_row("e1", "alpha", t)) + assert state._stores.endpoints.add(cur, _make_row("e1", "alpha", t)) # Not yet committed; memory should not reflect the insert. - assert state._db.endpoints.get("e1") is None + assert state._stores.endpoints.get("e1") is None - assert state._db.endpoints.get("e1") is not None - assert [r.endpoint_id for r in state._db.endpoints.query()] == ["e1"] + assert state._stores.endpoints.get("e1") is not None + assert [r.endpoint_id for r in state._stores.endpoints.query()] == ["e1"] def test_rollback_leaves_memory_untouched(state): @@ -102,12 +103,12 @@ class BoomError(RuntimeError): with pytest.raises(BoomError): with state._db.transaction() as cur: - state._db.endpoints.add(cur, _make_row("e1", "alpha", t)) + state._stores.endpoints.add(cur, _make_row("e1", "alpha", t)) raise BoomError # DB rolled back → memory must NOT see the insert. 
- assert state._db.endpoints.get("e1") is None - assert state._db.endpoints.query() == [] + assert state._stores.endpoints.get("e1") is None + assert state._stores.endpoints.query() == [] def test_add_rejects_terminal_task(state): @@ -121,22 +122,22 @@ def test_add_rejects_terminal_task(state): ) with state._db.transaction() as cur: - assert state._db.endpoints.add(cur, _make_row("e1", "alpha", task_id)) is False + assert state._stores.endpoints.add(cur, _make_row("e1", "alpha", task_id)) is False - assert state._db.endpoints.get("e1") is None + assert state._stores.endpoints.get("e1") is None def test_remove_drops_endpoint_by_id(state): tasks = submit_job(state, "j", make_job_request("j")) t = tasks[0].task_id with state._db.transaction() as cur: - state._db.endpoints.add(cur, _make_row("e1", "alpha", t)) - state._db.endpoints.add(cur, _make_row("e2", "beta", t)) + state._stores.endpoints.add(cur, _make_row("e1", "alpha", t)) + state._stores.endpoints.add(cur, _make_row("e2", "beta", t)) with state._db.transaction() as cur: - removed = state._db.endpoints.remove(cur, "e1") + removed = state._stores.endpoints.remove(cur, "e1") assert removed is not None and removed.endpoint_id == "e1" - assert {r.endpoint_id for r in state._db.endpoints.query()} == {"e2"} + assert {r.endpoint_id for r in state._stores.endpoints.query()} == {"e2"} def test_remove_by_task_drops_all_task_endpoints(state): @@ -144,15 +145,15 @@ def test_remove_by_task_drops_all_task_endpoints(state): t1, t2 = tasks[0].task_id, tasks[1].task_id with state._db.transaction() as cur: - state._db.endpoints.add(cur, _make_row("e1", "alpha", t1)) - state._db.endpoints.add(cur, _make_row("e2", "beta", t1)) - state._db.endpoints.add(cur, _make_row("e3", "gamma", t2)) + state._stores.endpoints.add(cur, _make_row("e1", "alpha", t1)) + state._stores.endpoints.add(cur, _make_row("e2", "beta", t1)) + state._stores.endpoints.add(cur, _make_row("e3", "gamma", t2)) with state._db.transaction() as cur: - removed = 
state._db.endpoints.remove_by_task(cur, t1) + removed = state._stores.endpoints.remove_by_task(cur, t1) assert set(removed) == {"e1", "e2"} - assert {r.endpoint_id for r in state._db.endpoints.query()} == {"e3"} + assert {r.endpoint_id for r in state._stores.endpoints.query()} == {"e3"} def test_remove_by_job_ids_drops_subtree(state): @@ -163,14 +164,14 @@ def test_remove_by_job_ids_drops_subtree(state): t2 = tasks_b[0].task_id with state._db.transaction() as cur: - state._db.endpoints.add(cur, _make_row("e1", "alpha", t1)) - state._db.endpoints.add(cur, _make_row("e2", "beta", t2)) + state._stores.endpoints.add(cur, _make_row("e1", "alpha", t1)) + state._stores.endpoints.add(cur, _make_row("e2", "beta", t2)) with state._db.transaction() as cur: - removed = state._db.endpoints.remove_by_job_ids(cur, [ja]) + removed = state._stores.endpoints.remove_by_job_ids(cur, [ja]) assert removed == ["e1"] - assert [r.endpoint_id for r in state._db.endpoints.query()] == ["e2"] + assert [r.endpoint_id for r in state._stores.endpoints.query()] == ["e2"] # --- Query semantics -------------------------------------------------------- @@ -193,51 +194,51 @@ def populated(state): ] with state._db.transaction() as cur: for r in rows: - state._db.endpoints.add(cur, r) + state._stores.endpoints.add(cur, r) return state, rows, (t0, t1, t2) def test_query_by_exact_name(populated): state, _, _ = populated - ids = {r.endpoint_id for r in state._db.endpoints.query(EndpointQuery(exact_name="alpha/svc"))} + ids = {r.endpoint_id for r in state._stores.endpoints.query(EndpointQuery(exact_name="alpha/svc"))} assert ids == {"e1"} def test_query_by_prefix(populated): state, _, _ = populated - ids = {r.endpoint_id for r in state._db.endpoints.query(EndpointQuery(name_prefix="alpha/"))} + ids = {r.endpoint_id for r in state._stores.endpoints.query(EndpointQuery(name_prefix="alpha/"))} assert ids == {"e1", "e2"} def test_query_by_task_ids(populated): state, _, (t0, _, t2) = populated - ids = 
{r.endpoint_id for r in state._db.endpoints.query(EndpointQuery(task_ids=(t0, t2)))} + ids = {r.endpoint_id for r in state._stores.endpoints.query(EndpointQuery(task_ids=(t0, t2)))} assert ids == {"e1", "e2", "e4"} def test_query_by_endpoint_ids(populated): state, _, _ = populated - ids = {r.endpoint_id for r in state._db.endpoints.query(EndpointQuery(endpoint_ids=("e2", "e3")))} + ids = {r.endpoint_id for r in state._stores.endpoints.query(EndpointQuery(endpoint_ids=("e2", "e3")))} assert ids == {"e2", "e3"} def test_query_limit(populated): state, _, _ = populated - rows = state._db.endpoints.query(EndpointQuery(limit=2)) + rows = state._stores.endpoints.query(EndpointQuery(limit=2)) assert len(rows) == 2 def test_query_empty_matches_all(populated): state, rows, _ = populated - assert {r.endpoint_id for r in state._db.endpoints.query()} == {r.endpoint_id for r in rows} + assert {r.endpoint_id for r in state._stores.endpoints.query()} == {r.endpoint_id for r in rows} def test_resolve_returns_address_for_exact_name(populated): state, _, _ = populated - row = state._db.endpoints.resolve("alpha/svc") + row = state._stores.endpoints.resolve("alpha/svc") assert row is not None assert row.endpoint_id == "e1" - assert state._db.endpoints.resolve("nope") is None + assert state._stores.endpoints.resolve("nope") is None # --- Parity with the legacy SQL builder ------------------------------------- @@ -262,7 +263,7 @@ def test_registry_parity_with_legacy_sql(populated, build_query): sql, params = _endpoint_query_sql_legacy(query) with state._db.read_snapshot() as q: expected_ids = sorted(r.endpoint_id for r in ENDPOINT_PROJECTION.decode(q.fetchall(sql, tuple(params)))) - actual_ids = sorted(r.endpoint_id for r in state._db.endpoints.query(query)) + actual_ids = sorted(r.endpoint_id for r in state._stores.endpoints.query(query)) # For LIMIT queries, both sides just need to be a valid subset of matching rows. 
if query.limit is not None: @@ -290,9 +291,9 @@ def writer(): eid = f"e{i % len(task_ids)}" name = f"svc-{i % len(task_ids)}" with state._db.transaction() as cur: - state._db.endpoints.add(cur, _make_row(eid, name, t)) + state._stores.endpoints.add(cur, _make_row(eid, name, t)) with state._db.transaction() as cur: - state._db.endpoints.remove(cur, eid) + state._stores.endpoints.remove(cur, eid) i += 1 except Exception as exc: errors.append(f"writer: {exc!r}") @@ -300,7 +301,7 @@ def writer(): def reader(): try: while not stop.is_set(): - snapshot = state._db.endpoints.query() + snapshot = state._stores.endpoints.query() # Verify the snapshot itself is internally consistent: every # endpoint_id in the result set is unique (no duplicates from # a torn index). @@ -311,11 +312,11 @@ def reader(): # present in a subsequent get() — the writer may remove it # between the two calls (TOCTOU). for row in snapshot: - state._db.endpoints.get(row.endpoint_id) + state._stores.endpoints.get(row.endpoint_id) for i in range(len(task_ids)): - state._db.endpoints.query(EndpointQuery(name_prefix=f"svc-{i}")) - state._db.endpoints.query(EndpointQuery(exact_name=f"svc-{i}")) - state._db.endpoints.query(EndpointQuery(task_ids=(task_ids[i],))) + state._stores.endpoints.query(EndpointQuery(name_prefix=f"svc-{i}")) + state._stores.endpoints.query(EndpointQuery(exact_name=f"svc-{i}")) + state._stores.endpoints.query(EndpointQuery(task_ids=(task_ids[i],))) except Exception as exc: errors.append(f"reader: {exc!r}") diff --git a/lib/iris/tests/cluster/controller/test_heartbeat.py b/lib/iris/tests/cluster/controller/test_heartbeat.py index 5976c30265..38dfe83c2f 100644 --- a/lib/iris/tests/cluster/controller/test_heartbeat.py +++ b/lib/iris/tests/cluster/controller/test_heartbeat.py @@ -14,6 +14,7 @@ TASK_DETAIL_PROJECTION, WORKER_DETAIL_PROJECTION, ) +from iris.cluster.controller.store import ControllerStores from tests.cluster.controller.conftest import FakeProvider from 
iris.cluster.controller.transitions import ( Assignment, @@ -35,7 +36,8 @@ @pytest.fixture def state(tmp_path): db = ControllerDB(db_dir=tmp_path) - s = ControllerTransitions(db=db) + stores = ControllerStores.from_db(db) + s = ControllerTransitions(stores=stores) yield s db.close() @@ -130,7 +132,8 @@ def test_fail_heartbeat_below_threshold(state, worker_metadata): def test_fail_heartbeat_at_threshold(tmp_path, worker_metadata): """RPC failures at threshold return WORKER_FAILED and prune the worker.""" db = ControllerDB(db_dir=tmp_path) - state = ControllerTransitions(db=db, heartbeat_failure_threshold=3) + stores = ControllerStores.from_db(db) + state = ControllerTransitions(stores=stores, heartbeat_failure_threshold=3) _register_worker(state, "worker1", worker_metadata) snapshot = _make_snapshot("worker1") @@ -168,7 +171,8 @@ def test_unhealthy_worker_cascades_to_tasks(tmp_path): use heartbeat_failure_threshold=1 to trigger removal on the first unhealthy report. """ db = ControllerDB(db_dir=tmp_path) - state = ControllerTransitions(db=db, heartbeat_failure_threshold=1) + stores = ControllerStores.from_db(db) + state = ControllerTransitions(stores=stores, heartbeat_failure_threshold=1) worker_metadata = job_pb2.WorkerMetadata( hostname="test-host", ip_address="192.168.1.1", diff --git a/lib/iris/tests/cluster/controller/test_reservation.py b/lib/iris/tests/cluster/controller/test_reservation.py index cc758b2cf0..d151e5890c 100644 --- a/lib/iris/tests/cluster/controller/test_reservation.py +++ b/lib/iris/tests/cluster/controller/test_reservation.py @@ -28,7 +28,7 @@ ) from iris.cluster.controller.scheduler import JobRequirements, Scheduler, SchedulingContext -from iris.cluster.controller.db import task_row_can_be_scheduled +from iris.cluster.controller.store import task_row_can_be_scheduled from iris.cluster.controller.schema import WorkerRow from iris.cluster.controller.transitions import ( HEARTBEAT_FAILURE_THRESHOLD, @@ -51,6 +51,11 @@ from iris.rpc import 
job_pb2 from iris.rpc import controller_pb2 from rigging.timing import Timestamp +from tests.cluster.controller._testing import ( + set_task_state as _set_task_state, + set_worker_attribute as _set_worker_attribute, + set_worker_health as _set_worker_health, +) from tests.cluster.controller.conftest import ( FakeProvider, hydrate_worker_attributes as _with_attrs, @@ -405,7 +410,7 @@ def test_claim_skips_unhealthy_worker(): ctrl = _make_controller() _register_worker(ctrl.state, "w1") # Mark worker unhealthy - ctrl.state.set_worker_health_for_test(WorkerId("w1"), False) + _set_worker_health(ctrl.state._db, WorkerId("w1"), healthy=False) req = _make_job_request_with_reservation( reservation_entries=[_make_reservation_entry()], @@ -903,7 +908,7 @@ def test_region_constraint_injected_from_claimed_workers(): ctrl = _make_controller() w1 = _register_worker(ctrl.state, "w1") # Set region attribute on worker - ctrl.state.set_worker_attribute_for_test(w1, WellKnownAttribute.REGION, AttributeValue("us-central1")) + _set_worker_attribute(ctrl.state._db, w1, WellKnownAttribute.REGION, AttributeValue("us-central1")) req = _make_job_request_with_reservation(reservation_entries=[_make_reservation_entry()]) jid = _submit_job(ctrl.state, "j1", req) @@ -912,7 +917,7 @@ def test_region_constraint_injected_from_claimed_workers(): result = _reservation_region_constraints( jid.to_wire(), ctrl.reservation_claims, - ctrl._db, + ctrl._stores, [], ) @@ -926,7 +931,7 @@ def test_region_constraint_not_injected_when_already_present(): """Existing region constraint prevents injection.""" ctrl = _make_controller() w1 = _register_worker(ctrl.state, "w1") - ctrl.state.set_worker_attribute_for_test(w1, WellKnownAttribute.REGION, AttributeValue("us-central1")) + _set_worker_attribute(ctrl.state._db, w1, WellKnownAttribute.REGION, AttributeValue("us-central1")) req = _make_job_request_with_reservation(reservation_entries=[_make_reservation_entry()]) jid = _submit_job(ctrl.state, "j1", req) @@ -936,7 
+941,7 @@ def test_region_constraint_not_injected_when_already_present(): result = _reservation_region_constraints( jid.to_wire(), ctrl.reservation_claims, - ctrl._db, + ctrl._stores, [existing], ) @@ -956,7 +961,7 @@ def test_region_constraint_not_injected_when_no_region_attr(): result = _reservation_region_constraints( jid.to_wire(), ctrl.reservation_claims, - ctrl._db, + ctrl._stores, [], ) @@ -968,8 +973,8 @@ def test_region_constraint_multiple_regions(): ctrl = _make_controller() w1 = _register_worker(ctrl.state, "w1") w2 = _register_worker(ctrl.state, "w2") - ctrl.state.set_worker_attribute_for_test(w1, WellKnownAttribute.REGION, AttributeValue("us-central1")) - ctrl.state.set_worker_attribute_for_test(w2, WellKnownAttribute.REGION, AttributeValue("us-east1")) + _set_worker_attribute(ctrl.state._db, w1, WellKnownAttribute.REGION, AttributeValue("us-central1")) + _set_worker_attribute(ctrl.state._db, w2, WellKnownAttribute.REGION, AttributeValue("us-east1")) req = _make_job_request_with_reservation( reservation_entries=[_make_reservation_entry(), _make_reservation_entry()], @@ -980,7 +985,7 @@ def test_region_constraint_multiple_regions(): result = _reservation_region_constraints( jid.to_wire(), ctrl.reservation_claims, - ctrl._db, + ctrl._stores, [], ) @@ -994,7 +999,7 @@ def test_no_injection_for_non_reservation_job(): """No claims for this job → constraints returned unchanged.""" ctrl = _make_controller() w1 = _register_worker(ctrl.state, "w1") - ctrl.state.set_worker_attribute_for_test(w1, WellKnownAttribute.REGION, AttributeValue("us-central1")) + _set_worker_attribute(ctrl.state._db, w1, WellKnownAttribute.REGION, AttributeValue("us-central1")) # Claim w1 for a different job req = _make_job_request_with_reservation(reservation_entries=[_make_reservation_entry()]) @@ -1004,7 +1009,7 @@ def test_no_injection_for_non_reservation_job(): result = _reservation_region_constraints( "/test-user/unrelated-job", ctrl.reservation_claims, - ctrl._db, + ctrl._stores, 
[], ) @@ -1525,7 +1530,7 @@ def test_holder_task_removed_from_worker_when_parent_cancelled_all_tasks_already # (loop body never executes, so the old code never reached _cancel_child_jobs). parent_task_ref = _query_task(state, parent_task.task_id) assert parent_task_ref is not None - state.set_task_state_for_test(parent_task.task_id, job_pb2.TASK_STATE_KILLED) + _set_task_state(state._db, parent_task.task_id, job_pb2.TASK_STATE_KILLED) # Fire JobCancelledEvent. All parent tasks are now terminal so the loop # skips them. Only the explicit _cancel_child_jobs call at the end of diff --git a/lib/iris/tests/cluster/controller/test_scheduler.py b/lib/iris/tests/cluster/controller/test_scheduler.py index d6da0ee222..5e7c9dc80f 100644 --- a/lib/iris/tests/cluster/controller/test_scheduler.py +++ b/lib/iris/tests/cluster/controller/test_scheduler.py @@ -12,9 +12,7 @@ from iris.cluster.constraints import WellKnownAttribute from iris.cluster.controller.codec import constraints_from_json, resource_spec_from_scalars -from iris.cluster.controller.db import ( - _decode_attribute_rows, -) +from iris.cluster.controller.store import _decode_attribute_rows from iris.cluster.controller.scheduler import ( JobRequirements, Scheduler, @@ -30,6 +28,7 @@ from rigging.timing import Duration, Timestamp from tests.cluster.conftest import eq_constraint, in_constraint +from ._testing import set_worker_health as _set_worker_health from .conftest import ( building_counts as _building_counts, check_task_can_be_scheduled, @@ -51,7 +50,7 @@ def _job_requirements_from_job(job) -> JobRequirements: return JobRequirements( resources=resource_spec_from_scalars( - job.res_cpu_millicores, job.res_memory_bytes, job.res_disk_bytes, job.res_device_json + job.resources.cpu_millicores, job.resources.memory_bytes, job.resources.disk_bytes, job.resources.device_json ), constraints=constraints_from_json(job.constraints_json), is_coscheduled=job.has_coscheduling, @@ -355,7 +354,7 @@ def 
test_scheduler_skips_unhealthy_workers(scheduler, state): register_worker(state, "w1", "addr1", make_worker_metadata()) register_worker(state, "w2", "addr2", make_worker_metadata()) # Mark second worker as unhealthy - state.set_worker_health_for_test(WorkerId("w2"), False) + _set_worker_health(state._db, WorkerId("w2"), healthy=False) submit_job(state, "j1", make_job_request()) diff --git a/lib/iris/tests/cluster/controller/test_service.py b/lib/iris/tests/cluster/controller/test_service.py index c810259205..50764f3bc7 100644 --- a/lib/iris/tests/cluster/controller/test_service.py +++ b/lib/iris/tests/cluster/controller/test_service.py @@ -650,7 +650,7 @@ def test_terminate_job_rejected_for_non_owner(state, mock_controller, tmp_path): auth_service = ControllerServiceImpl( state, - state._db, + state._stores, controller=mock_controller, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles_owner")), log_service=LogServiceImpl(), @@ -681,7 +681,7 @@ def test_launch_child_job_rejected_for_non_owner(state, mock_controller, tmp_pat auth_service = ControllerServiceImpl( state, - state._db, + state._stores, controller=mock_controller, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles_child")), log_service=LogServiceImpl(), @@ -1120,13 +1120,14 @@ def test_register_requires_worker_role(state, mock_controller, tmp_path): from iris.rpc.auth import _verified_identity, VerifiedIdentity db = state._db + stores = state._stores now = Timestamp.now() db.ensure_user("alice", now, role="user") auth = ControllerAuth(provider="static") service = ControllerServiceImpl( state, - db, + stores, controller=mock_controller, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), log_service=LogServiceImpl(), @@ -1156,13 +1157,14 @@ def test_register_allows_worker_role(state, mock_controller, tmp_path): from iris.rpc.auth import _verified_identity, VerifiedIdentity db = state._db + stores = state._stores now = Timestamp.now() db.ensure_user("system:worker", now, 
role="worker") auth = ControllerAuth(provider="static") service = ControllerServiceImpl( state, - db, + stores, controller=mock_controller, bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), log_service=LogServiceImpl(), diff --git a/lib/iris/tests/cluster/controller/test_task_profiles.py b/lib/iris/tests/cluster/controller/test_task_profiles.py index 336dab204b..12f5cf3ae9 100644 --- a/lib/iris/tests/cluster/controller/test_task_profiles.py +++ b/lib/iris/tests/cluster/controller/test_task_profiles.py @@ -7,18 +7,34 @@ import pytest -from iris.cluster.controller.db import ( - ControllerDB, - get_task_profiles, - insert_task_profile, -) +from iris.cluster.controller.db import ControllerDB +from iris.cluster.controller.store import ControllerStores from iris.rpc import job_pb2 from rigging.timing import Timestamp @pytest.fixture -def db(tmp_path: Path) -> ControllerDB: - return ControllerDB(db_dir=tmp_path) +def db(tmp_path: Path) -> tuple[ControllerDB, ControllerStores]: + db_instance = ControllerDB(db_dir=tmp_path) + stores = ControllerStores.from_db(db_instance) + return db_instance, stores + + +def _insert_profile( + stores: ControllerStores, + task_id: str, + data: bytes, + when: Timestamp, + *, + profile_kind: str = "cpu", +) -> None: + with stores.transact() as ctx: + stores.tasks.insert_task_profile(ctx.cur, task_id, data, when, profile_kind=profile_kind) + + +def _get_profiles(stores: ControllerStores, task_id: str, *, profile_kind: str | None = None): + with stores.read() as ctx: + return stores.tasks.get_task_profiles(ctx.cur, task_id, profile_kind=profile_kind) def _ensure_task(db: ControllerDB, task_id: str) -> None: @@ -50,36 +66,39 @@ def _ensure_task(db: ControllerDB, task_id: str) -> None: ) -def test_insert_and_retrieve_profile(db: ControllerDB) -> None: +def test_insert_and_retrieve_profile(db: tuple[ControllerDB, ControllerStores]) -> None: + db_instance, stores = db now = Timestamp.now() - _ensure_task(db, "task-1") - 
insert_task_profile(db, "task-1", b"profile-data-here", now) + _ensure_task(db_instance, "task-1") + _insert_profile(stores, "task-1", b"profile-data-here", now) - profiles = get_task_profiles(db, "task-1") + profiles = _get_profiles(stores, "task-1") assert len(profiles) == 1 assert profiles[0][0] == b"profile-data-here" assert profiles[0][1].epoch_ms() == now.epoch_ms() assert profiles[0][2] == "cpu" -def test_profiles_ordered_newest_first(db: ControllerDB) -> None: - _ensure_task(db, "task-1") +def test_profiles_ordered_newest_first(db: tuple[ControllerDB, ControllerStores]) -> None: + db_instance, stores = db + _ensure_task(db_instance, "task-1") for i in range(3): - insert_task_profile(db, "task-1", f"profile-{i}".encode(), Timestamp.from_ms(1000 + i)) + _insert_profile(stores, "task-1", f"profile-{i}".encode(), Timestamp.from_ms(1000 + i)) - profiles = get_task_profiles(db, "task-1") + profiles = _get_profiles(stores, "task-1") assert len(profiles) == 3 assert profiles[0][0] == b"profile-2" assert profiles[2][0] == b"profile-0" -def test_cap_at_ten_profiles(db: ControllerDB) -> None: +def test_cap_at_ten_profiles(db: tuple[ControllerDB, ControllerStores]) -> None: """The DB trigger should evict oldest profiles when count exceeds 10.""" - _ensure_task(db, "task-1") + db_instance, stores = db + _ensure_task(db_instance, "task-1") for i in range(15): - insert_task_profile(db, "task-1", f"profile-{i}".encode(), Timestamp.from_ms(1000 + i)) + _insert_profile(stores, "task-1", f"profile-{i}".encode(), Timestamp.from_ms(1000 + i)) - profiles = get_task_profiles(db, "task-1") + profiles = _get_profiles(stores, "task-1") assert len(profiles) == 10 # Newest 10 should be kept (profiles 5..14) data_values = [p[0] for p in profiles] @@ -87,34 +106,37 @@ def test_cap_at_ten_profiles(db: ControllerDB) -> None: assert data_values[-1] == b"profile-5" -def test_cap_is_per_task(db: ControllerDB) -> None: +def test_cap_is_per_task(db: tuple[ControllerDB, ControllerStores]) -> 
None: """Profiles for different tasks are capped independently.""" - _ensure_task(db, "task-a") - _ensure_task(db, "task-b") + db_instance, stores = db + _ensure_task(db_instance, "task-a") + _ensure_task(db_instance, "task-b") for i in range(12): - insert_task_profile(db, "task-a", f"a-{i}".encode(), Timestamp.from_ms(1000 + i)) - insert_task_profile(db, "task-b", f"b-{i}".encode(), Timestamp.from_ms(1000 + i)) + _insert_profile(stores, "task-a", f"a-{i}".encode(), Timestamp.from_ms(1000 + i)) + _insert_profile(stores, "task-b", f"b-{i}".encode(), Timestamp.from_ms(1000 + i)) - assert len(get_task_profiles(db, "task-a")) == 10 - assert len(get_task_profiles(db, "task-b")) == 10 + assert len(_get_profiles(stores, "task-a")) == 10 + assert len(_get_profiles(stores, "task-b")) == 10 -def test_empty_profiles(db: ControllerDB) -> None: - profiles = get_task_profiles(db, "nonexistent") +def test_empty_profiles(db: tuple[ControllerDB, ControllerStores]) -> None: + _db, stores = db + profiles = _get_profiles(stores, "nonexistent") assert profiles == [] -def test_cap_is_per_task_and_kind(db: ControllerDB) -> None: +def test_cap_is_per_task_and_kind(db: tuple[ControllerDB, ControllerStores]) -> None: """Cap trigger retains 10 per (task_id, profile_kind).""" - _ensure_task(db, "task-1") + db_instance, stores = db + _ensure_task(db_instance, "task-1") for i in range(12): - insert_task_profile(db, "task-1", f"cpu-{i}".encode(), Timestamp.from_ms(1000 + i), profile_kind="cpu") - insert_task_profile(db, "task-1", f"mem-{i}".encode(), Timestamp.from_ms(1000 + i), profile_kind="memory") + _insert_profile(stores, "task-1", f"cpu-{i}".encode(), Timestamp.from_ms(1000 + i), profile_kind="cpu") + _insert_profile(stores, "task-1", f"mem-{i}".encode(), Timestamp.from_ms(1000 + i), profile_kind="memory") - cpu_profiles = get_task_profiles(db, "task-1", profile_kind="cpu") - mem_profiles = get_task_profiles(db, "task-1", profile_kind="memory") + cpu_profiles = _get_profiles(stores, 
"task-1", profile_kind="cpu") + mem_profiles = _get_profiles(stores, "task-1", profile_kind="memory") assert len(cpu_profiles) == 10 assert len(mem_profiles) == 10 # Total should be 20 (10 cpu + 10 memory) - all_profiles = get_task_profiles(db, "task-1") + all_profiles = _get_profiles(stores, "task-1") assert len(all_profiles) == 20 diff --git a/lib/iris/tests/cluster/controller/test_task_resource_history.py b/lib/iris/tests/cluster/controller/test_task_resource_history.py index bfb3354eac..4bcd8586ef 100644 --- a/lib/iris/tests/cluster/controller/test_task_resource_history.py +++ b/lib/iris/tests/cluster/controller/test_task_resource_history.py @@ -5,6 +5,7 @@ import pytest from iris.cluster.controller.db import ControllerDB +from iris.cluster.controller.store import ControllerStores from iris.cluster.controller.transitions import ( Assignment, ControllerTransitions, @@ -21,7 +22,8 @@ @pytest.fixture def state(tmp_path): db = ControllerDB(db_dir=tmp_path) - s = ControllerTransitions(db=db) + stores = ControllerStores.from_db(db) + s = ControllerTransitions(stores=stores) yield s db.close() diff --git a/lib/iris/tests/cluster/controller/test_task_store_query.py b/lib/iris/tests/cluster/controller/test_task_store_query.py new file mode 100644 index 0000000000..0afb274752 --- /dev/null +++ b/lib/iris/tests/cluster/controller/test_task_store_query.py @@ -0,0 +1,246 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Focused tests for TaskStore.query() + TaskFilter. + +Exercises the filter combinations actually produced by transitions.py after the +store-layer consolidation: task_ids, job_ids, worker_id / worker_is_null, +states, limit, and the with_job / with_job_config join variants. 
+""" + +from __future__ import annotations + +import pytest +from iris.cluster.controller.schema import ACTIVE_TASK_STATES, EXECUTING_TASK_STATES, TaskDetailRow +from iris.cluster.controller.store import TaskFilter, TaskProjection +from iris.cluster.types import JobName, WorkerId +from iris.rpc import job_pb2 + +from .conftest import ( + ControllerTestHarness, + make_job_request, + submit_job, +) + +# --- Dataclass validation --------------------------------------------------- + + +def test_task_filter_rejects_worker_id_plus_is_null() -> None: + with pytest.raises(ValueError, match="mutually exclusive"): + TaskFilter(worker_id=WorkerId("w1"), worker_is_null=True) + + +# --- Empty-list short-circuits --------------------------------------------- + + +def test_query_empty_task_ids_short_circuits(state) -> None: + """Empty task_ids tuple should return [] without executing SQL.""" + with state._stores.transact() as ctx: + assert ctx.tasks.query(ctx.cur, TaskFilter(task_ids=())) == [] + + +def test_query_empty_job_ids_short_circuits(state) -> None: + with state._stores.transact() as ctx: + assert ctx.tasks.query(ctx.cur, TaskFilter(job_ids=())) == [] + + +# --- No-join path returns TaskDetailRow ------------------------------------ + + +def test_query_no_filter_returns_task_detail_rows(state) -> None: + tasks = submit_job(state, "j", make_job_request("j", replicas=3)) + assert len(tasks) == 3 + + with state._stores.transact() as ctx: + rows = ctx.tasks.query(ctx.cur, TaskFilter()) + + assert len(rows) == 3 + assert all(isinstance(r, TaskDetailRow) for r in rows) + assert {r.task_id for r in rows} == {t.task_id for t in tasks} + + +def test_query_by_task_ids_returns_subset(state) -> None: + tasks = submit_job(state, "j", make_job_request("j", replicas=3)) + wanted = (tasks[0].task_id.to_wire(), tasks[2].task_id.to_wire()) + + with state._stores.transact() as ctx: + rows = ctx.tasks.query(ctx.cur, TaskFilter(task_ids=wanted)) + + assert {r.task_id.to_wire() for r in rows} 
== set(wanted) + + +def test_query_by_job_ids(state) -> None: + ja = submit_job(state, "ja", make_job_request("ja", replicas=2)) + jb = submit_job(state, "jb", make_job_request("jb", replicas=1)) + _ = jb # silence unused warning — we filter to ja only + + with state._stores.transact() as ctx: + rows = ctx.tasks.query( + ctx.cur, + TaskFilter(job_ids=(ja[0].job_id.to_wire(),)), + ) + + assert {r.task_id for r in rows} == {t.task_id for t in ja} + + +def test_query_limit_truncates(state) -> None: + submit_job(state, "j", make_job_request("j", replicas=5)) + + with state._stores.transact() as ctx: + rows = ctx.tasks.query(ctx.cur, TaskFilter(limit=2)) + + assert len(rows) == 2 + + +# --- States filter --------------------------------------------------------- + + +def test_query_states_filter(state) -> None: + """Filter tasks to only the PENDING state — all tasks start PENDING.""" + submit_job(state, "j", make_job_request("j", replicas=2)) + + with state._stores.transact() as ctx: + pending = ctx.tasks.query( + ctx.cur, + TaskFilter(states=frozenset({job_pb2.TASK_STATE_PENDING})), + ) + active = ctx.tasks.query( + ctx.cur, + TaskFilter(states=ACTIVE_TASK_STATES), + ) + + assert len(pending) == 2 + assert active == [] + + +# --- worker_id / worker_is_null -------------------------------------------- + + +def test_query_worker_id_and_worker_is_null(state) -> None: + """After dispatching one task, that task has current_worker_id set; the + remainder are still NULL.""" + harness = ControllerTestHarness(state) + wid = harness.add_worker("w1", cpu=10) + tasks = harness.submit("j", replicas=3) + harness.dispatch(tasks[0], wid) + + with state._stores.transact() as ctx: + on_worker = ctx.tasks.query(ctx.cur, TaskFilter(worker_id=wid)) + unassigned = ctx.tasks.query(ctx.cur, TaskFilter(worker_is_null=True)) + + assert {r.task_id for r in on_worker} == {tasks[0].task_id} + assert {r.task_id for r in unassigned} == {tasks[1].task_id, tasks[2].task_id} + + +def 
test_query_worker_and_state_combination(state) -> None: + """AND of worker_id + EXECUTING_TASK_STATES matches the migrated + get_active_with_resources / cancel_tasks_for_timeout call path.""" + harness = ControllerTestHarness(state) + wid = harness.add_worker("w1", cpu=10) + tasks = harness.submit("j", replicas=2) + harness.dispatch(tasks[0], wid) + # tasks[0] is now RUNNING after dispatch; tasks[1] is still PENDING. + + with state._stores.transact() as ctx: + rows = ctx.tasks.query( + ctx.cur, + TaskFilter(worker_id=wid, states=EXECUTING_TASK_STATES), + ) + + assert {r.task_id for r in rows} == {tasks[0].task_id} + + +# --- Joined variants ---------------------------------------------------------- + + +def test_query_with_job_returns_typed_rows(state) -> None: + """query WITH_JOB returns TaskDetailRow with is_reservation_holder/num_tasks populated.""" + submit_job(state, "j", make_job_request("j", replicas=1)) + + with state._stores.transact() as ctx: + rows = ctx.tasks.query(ctx.cur, TaskFilter(), projection=TaskProjection.WITH_JOB) + + assert len(rows) == 1 + row = rows[0] + assert isinstance(row, TaskDetailRow) + assert isinstance(row.is_reservation_holder, bool) + assert isinstance(row.num_tasks, int) + assert row.current_worker_id is None # task not yet dispatched + # Resource fields are not populated at this projection level. 
+ assert row.resources is None + + +def test_query_with_job_config_exposes_resource_columns(state) -> None: + """query WITH_JOB_CONFIG returns TaskDetailRow with resource columns populated.""" + submit_job(state, "j", make_job_request("j", cpu=2, memory_bytes=2 * 1024**3)) + + with state._stores.transact() as ctx: + rows = ctx.tasks.query(ctx.cur, TaskFilter(), projection=TaskProjection.WITH_JOB_CONFIG) + + assert len(rows) == 1 + row = rows[0] + assert isinstance(row, TaskDetailRow) + assert row.resources is not None + assert row.resources.cpu_millicores == 2000 + assert row.resources.memory_bytes == 2 * 1024**3 + assert isinstance(row.resources.disk_bytes, int) + assert row.timeout_ms is None # not set in make_job_request default + + +# --- Chunking ------------------------------------------------------------- + + +def test_query_chunking_across_id_in_cap(state) -> None: + """Supplying more task_ids than the SQLite host-param chunk size still + returns the full matching set — chunk results are concatenated.""" + # submit 3 real tasks; pad task_ids with 950 bogus wires to force two chunks. + # The chunker splits at 900 per batch (safely under SQLite's 999 cap). 
+ tasks = submit_job(state, "j", make_job_request("j", replicas=3)) + real = [t.task_id.to_wire() for t in tasks] + padding = tuple(f"/bogus/task-{i}" for i in range(950)) + all_ids = tuple(real) + padding + + with state._stores.transact() as ctx: + rows = ctx.tasks.query(ctx.cur, TaskFilter(task_ids=all_ids)) + + assert {r.task_id.to_wire() for r in rows} == set(real) + + +def test_query_chunking_respects_limit(state) -> None: + """When ``limit`` is set, the chunked loop stops once the limit is reached.""" + tasks = submit_job(state, "j", make_job_request("j", replicas=5)) + ids = tuple(t.task_id.to_wire() for t in tasks) + tuple(f"/bogus/{i}" for i in range(1000)) + + with state._stores.transact() as ctx: + rows = ctx.tasks.query(ctx.cur, TaskFilter(task_ids=ids, limit=2)) + + assert len(rows) == 2 + + +# --- Ordering ------------------------------------------------------------- + + +def test_query_orders_by_task_id_ascending(state) -> None: + """ORDER BY t.task_id ASC is stable across call sites that diff results.""" + tasks = submit_job(state, "j", make_job_request("j", replicas=4)) + expected_order = sorted(t.task_id.to_wire() for t in tasks) + + with state._stores.transact() as ctx: + rows = ctx.tasks.query(ctx.cur, TaskFilter()) + + assert [r.task_id.to_wire() for r in rows] == expected_order + + +# --- Parity with JobName-typed row field ----------------------------------- + + +def test_query_rows_decode_job_name_fields(state) -> None: + """Non-joined rows go through TASK_DETAIL_PROJECTION.decode → typed + fields like ``job_id`` come back as JobName, not str.""" + submit_job(state, "j", make_job_request("j", replicas=1)) + + with state._stores.transact() as ctx: + [row] = ctx.tasks.query(ctx.cur, TaskFilter()) + + assert isinstance(row.job_id, JobName) + assert isinstance(row.task_id, JobName) diff --git a/lib/iris/tests/cluster/controller/test_transitions.py b/lib/iris/tests/cluster/controller/test_transitions.py index cd36d74ea0..f59ba1433c 100644 --- 
a/lib/iris/tests/cluster/controller/test_transitions.py +++ b/lib/iris/tests/cluster/controller/test_transitions.py @@ -16,11 +16,8 @@ from iris.cluster.controller.codec import constraints_from_json, resource_spec_from_scalars from iris.cluster.controller.autoscaler.models import DemandEntry from iris.cluster.controller.controller import compute_demand_entries -from iris.cluster.controller.db import ( - ControllerDB, - EndpointQuery, - attempt_is_terminal, -) +from iris.cluster.controller.db import ControllerDB +from iris.cluster.controller.store import attempt_is_terminal from iris.cluster.controller.schema import ( ATTEMPT_PROJECTION, JOB_DETAIL_PROJECTION, @@ -28,6 +25,7 @@ WORKER_DETAIL_PROJECTION, EndpointRow, ) +from iris.cluster.controller.store import EndpointQuery from iris.cluster.controller.scheduler import JobRequirements, Scheduler from iris.cluster.controller.transitions import ( Assignment, @@ -45,6 +43,10 @@ from iris.rpc import logging_pb2 from rigging.timing import Duration, Timestamp +from ._testing import ( + create_attempt as _create_attempt, + set_worker_consecutive_failures as _set_worker_consecutive_failures, +) from .conftest import ( building_counts as _building_counts, check_task_can_be_scheduled, @@ -91,7 +93,7 @@ def _queued_dispatch( def _endpoints(state: ControllerTransitions, query: EndpointQuery = EndpointQuery()) -> list[EndpointRow]: - rows = state._db.endpoints.query(query) + rows = state._stores.endpoints.query(query) # Mirror the original helper's ordering (registered_at DESC, endpoint_id ASC). 
return sorted(rows, key=lambda r: (-r.registered_at.epoch_ms(), r.endpoint_id)) @@ -107,10 +109,10 @@ def _build_scheduling_context(scheduler: Scheduler, state: ControllerTransitions job = _query_job(state, job_id) if job: resources = resource_spec_from_scalars( - job.res_cpu_millicores, - job.res_memory_bytes, - job.res_disk_bytes, - job.res_device_json, + job.resources.cpu_millicores, + job.resources.memory_bytes, + job.resources.disk_bytes, + job.resources.device_json, ) jobs[job_id] = JobRequirements( resources=resources, @@ -361,7 +363,7 @@ def test_cancelled_job_tasks_excluded_from_demand(harness): assert not check_task_can_be_scheduled(harness.query_task(task.task_id)) assert len(_schedulable_tasks(harness.state)) == 0 - assert len(compute_demand_entries(harness.state._db)) == 0 + assert len(compute_demand_entries(harness.state._stores)) == 0 # ============================================================================= @@ -1331,7 +1333,7 @@ def test_stale_attempt_error_log_for_non_terminal(state, caplog): # Manually create a second attempt without properly terminating the first. # This simulates a scenario where the controller created a new attempt # but the old one is still non-terminal (a precondition violation). 
- state.create_attempt_for_test(task.task_id, worker_id) + _create_attempt(state._stores, task.task_id, worker_id) assert _query_task(state, task.task_id).current_attempt_id == 1 # The old attempt (0) is still in RUNNING state (non-terminal) with state._db.snapshot() as q: @@ -1428,7 +1430,7 @@ def test_compute_demand_entries_counts_coscheduled_job_once(state): req.coscheduling.group_by = WellKnownAttribute.TPU_NAME submit_job(state, "j1", req) - demand = compute_demand_entries(state._db) + demand = compute_demand_entries(state._stores) assert len(demand) == 1 assert demand[0].normalized.device_type == DeviceType.TPU assert demand[0].normalized.device_variants == frozenset({"v5litepod-16"}) @@ -1452,7 +1454,7 @@ def test_compute_demand_entries_counts_non_coscheduled_tasks_individually(state) # No coscheduling set submit_job(state, "j1", req) - demand = compute_demand_entries(state._db) + demand = compute_demand_entries(state._stores) assert len(demand) == 4 for entry in demand: assert entry.normalized.device_type == DeviceType.TPU @@ -1493,7 +1495,7 @@ def test_compute_demand_entries_mixed_coscheduled_and_regular(state): ) submit_job(state, "j2", regular_req) - demand = compute_demand_entries(state._db) + demand = compute_demand_entries(state._stores) assert len(demand) == 3 coscheduled = [entry for entry in demand if entry.coschedule_group_id == "/test-user/j1"] regular = [entry for entry in demand if entry.coschedule_group_id is None] @@ -1550,7 +1552,7 @@ def test_compute_demand_entries_separates_by_preemptible_constraint(state): ) submit_job(state, "j2", on_demand_req) - demand = compute_demand_entries(state._db) + demand = compute_demand_entries(state._stores) assert len(demand) == 2 by_preemptible = {d.normalized.preemptible: d for d in demand} @@ -1576,7 +1578,7 @@ def test_compute_demand_entries_no_preemptible_constraint_gives_none(state): ) submit_job(state, "j1", req) - demand = compute_demand_entries(state._db) + demand = 
compute_demand_entries(state._stores) assert len(demand) == 1 assert demand[0].normalized.preemptible is None @@ -1602,7 +1604,7 @@ def test_compute_demand_entries_extracts_required_region(state): ) submit_job(state, "j1", req) - demand = compute_demand_entries(state._db) + demand = compute_demand_entries(state._stores) assert len(demand) == 1 assert demand[0].normalized.required_regions == frozenset({"us-west4"}) assert demand[0].invalid_reason is None @@ -1634,7 +1636,7 @@ def test_compute_demand_entries_marks_invalid_on_conflicting_region_constraints( ) submit_job(state, "j1", req) - demand = compute_demand_entries(state._db) + demand = compute_demand_entries(state._stores) assert len(demand) == 1 assert demand[0].invalid_reason is not None @@ -1717,7 +1719,7 @@ def test_demand_reservation_all_tasks_generate_demand(state): ) submit_job(state, "j1", req) - demand = compute_demand_entries(state._db) + demand = compute_demand_entries(state._stores) synthetic_demand = [d for d in demand if _is_synthetic_demand(state, d)] real_demand = [d for d in demand if not _is_synthetic_demand(state, d)] @@ -1734,7 +1736,7 @@ def test_demand_reservation_excess_tasks(state): ) submit_job(state, "j1", req) - demand = compute_demand_entries(state._db) + demand = compute_demand_entries(state._stores) synthetic_demand = [d for d in demand if _is_synthetic_demand(state, d)] real_demand = [d for d in demand if not _is_synthetic_demand(state, d)] @@ -1758,7 +1760,7 @@ def test_demand_reservation_holder_uses_entry_resources(state): ) submit_job(state, "j1", req) - demand = compute_demand_entries(state._db) + demand = compute_demand_entries(state._stores) synthetic_demand = [d for d in demand if _is_synthetic_demand(state, d)] real_demand = [d for d in demand if not _is_synthetic_demand(state, d)] @@ -1793,7 +1795,7 @@ def test_demand_reservation_mixed_jobs(state): ) submit_job(state, "a100-job", a100_req) - demand = compute_demand_entries(state._db) + demand = 
compute_demand_entries(state._stores) synthetic_demand = [d for d in demand if _is_synthetic_demand(state, d)] real_demand = [d for d in demand if not _is_synthetic_demand(state, d)] @@ -1821,7 +1823,7 @@ def test_demand_no_reservation_passes_all_tasks(state): ) submit_job(state, "j1", req) - demand = compute_demand_entries(state._db) + demand = compute_demand_entries(state._stores) assert len(demand) == 3 for d in demand: assert not _is_synthetic_demand(state, d) @@ -1852,7 +1854,7 @@ def test_demand_reservation_independent_per_job(state): ) submit_job(state, "job-b", job_b_req) - demand = compute_demand_entries(state._db) + demand = compute_demand_entries(state._stores) synthetic_demand = [d for d in demand if _is_synthetic_demand(state, d)] real_demand = [d for d in demand if not _is_synthetic_demand(state, d)] @@ -2039,7 +2041,7 @@ def test_fail_heartbeat_clears_dispatch_when_worker_fails(state): assert not queued_kill # Simulate repeated failures up to threshold - state.set_worker_consecutive_failures_for_test(worker_id, HEARTBEAT_FAILURE_THRESHOLD - 1) + _set_worker_consecutive_failures(state._db, worker_id, HEARTBEAT_FAILURE_THRESHOLD - 1) # This fail_heartbeat should trigger worker failure state.fail_heartbeat(snapshot, "Connection refused") @@ -2442,7 +2444,7 @@ def test_demand_excludes_building_limited_tasks(state): # Now w1 has 2 building tasks (at limit), but has plenty of CPU/memory. # The pending task from j1 should be building-limited, not truly unschedulable. 
workers = healthy_active_workers(state) - demand = compute_demand_entries(state._db, scheduler, workers) + demand = compute_demand_entries(state._stores, scheduler, workers) task_demand = [d for d in demand if not _is_synthetic_demand(state, d)] assert len(task_demand) == 0, "Building-limited task should not generate demand" @@ -2469,7 +2471,7 @@ def test_demand_includes_truly_unschedulable_tasks(state): submit_job(state, "j1", req) workers = healthy_active_workers(state) - demand = compute_demand_entries(state._db, scheduler, workers) + demand = compute_demand_entries(state._stores, scheduler, workers) task_demand = [d for d in demand if not _is_synthetic_demand(state, d)] assert len(task_demand) == 1, "Task with no matching device should generate demand" @@ -2496,7 +2498,7 @@ def test_demand_includes_resource_exhausted_tasks(state): submit_job(state, "j1", req) workers = healthy_active_workers(state) - demand = compute_demand_entries(state._db, scheduler, workers) + demand = compute_demand_entries(state._stores, scheduler, workers) task_demand = [d for d in demand if not _is_synthetic_demand(state, d)] assert len(task_demand) == 1, "Task exceeding worker CPU should generate demand" @@ -2523,7 +2525,7 @@ def test_demand_holders_absorbed_by_dry_run(state): submit_job(state, "j1", req) workers = healthy_active_workers(state) - demand = compute_demand_entries(state._db, scheduler, workers) + demand = compute_demand_entries(state._stores, scheduler, workers) # Worker fits 1 task (holder or real). 3 remaining generate demand. 
assert len(demand) == 3 @@ -2551,7 +2553,7 @@ def test_demand_absorbs_capacity_before_emitting(state): submit_job(state, "j1", req) workers = healthy_active_workers(state) - demand = compute_demand_entries(state._db, scheduler, workers) + demand = compute_demand_entries(state._stores, scheduler, workers) task_demand = [d for d in demand if not _is_synthetic_demand(state, d)] assert len(task_demand) == 1, "Only 1 of 3 tasks should generate demand (2 absorbed)" @@ -2573,7 +2575,7 @@ def test_demand_no_workers_falls_back_to_all_pending(state): submit_job(state, "j1", req) # No scheduler, no workers -> all tasks become demand - demand = compute_demand_entries(state._db) + demand = compute_demand_entries(state._stores) task_demand = [d for d in demand if not _is_synthetic_demand(state, d)] assert len(task_demand) == 3 @@ -2617,7 +2619,7 @@ def test_demand_building_limited_with_multiple_workers(state): submit_job(state, "pending-job", req) workers = healthy_active_workers(state) - demand = compute_demand_entries(state._db, scheduler, workers) + demand = compute_demand_entries(state._stores, scheduler, workers) task_demand = [d for d in demand if not _is_synthetic_demand(state, d)] assert len(task_demand) == 0, "All workers at building limit -> no demand" @@ -2672,7 +2674,7 @@ def test_demand_mixed_building_limited_and_unschedulable(state): submit_job(state, "a100-job", a100_req) workers = healthy_active_workers(state) - demand = compute_demand_entries(state._db, scheduler, workers) + demand = compute_demand_entries(state._stores, scheduler, workers) task_demand = [d for d in demand if not _is_synthetic_demand(state, d)] assert len(task_demand) == 1 @@ -2823,7 +2825,10 @@ def test_snapshot_round_trip_preserves_reservation_holder(state): checkpoint_path = Path(tmpdir) / "controller.sqlite3" state._db.backup_to(checkpoint_path) restored_db = ControllerDB(db_dir=Path(tmpdir)) - restored_state = ControllerTransitions(db=restored_db) + from iris.cluster.controller.store import 
ControllerStores + + restored_stores = ControllerStores.from_db(restored_db) + restored_state = ControllerTransitions(stores=restored_stores) restored_holder = _query_job(restored_state, holder_job_id) assert restored_holder is not None diff --git a/lib/iris/tests/test_budget.py b/lib/iris/tests/test_budget.py index 6032e51fcc..437253b2d9 100644 --- a/lib/iris/tests/test_budget.py +++ b/lib/iris/tests/test_budget.py @@ -257,7 +257,7 @@ def service(state, tmp_path) -> ControllerServiceImpl: priority-band authorization triggers (see launch_job band check).""" return ControllerServiceImpl( state, - state._db, + state._stores, controller=MockController(), bundle_store=BundleStore(storage_dir=str(tmp_path / "bundles")), log_service=LogServiceImpl(), diff --git a/pyproject.toml b/pyproject.toml index ffb64d4be3..3c758da59b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,13 @@ extend-exclude = ["scripts/"] [tool.ruff.lint] select = ["A", "B", "E", "F", "I", "NPY", "RUF", "UP", "W"] ignore = ["F722", "B008", "UP015", "A005", "I001", "E741"] +# Treat field-factory helpers like `dataclasses.field` so RUF009 doesn't flag +# their use as dataclass defaults. +extend-safe-fixes = [] + +[tool.ruff.lint.flake8-bugbear] +# Functions that return a `dataclasses.field(...)` and are safe as defaults. +extend-immutable-calls = ["iris.cluster.controller.schema.pcolumn"] [tool.ruff.lint.per-file-ignores] "__init__.py" = ["E402", "F401"]