diff --git a/Makefile b/Makefile index cbaccceb2c..f81248e7bd 100644 --- a/Makefile +++ b/Makefile @@ -51,16 +51,6 @@ test: export HF_HUB_TOKEN=$HF_TOKEN RAY_ADDRESS= PYTHONPATH=tests:. pytest tests --durations=0 -n 4 --tb=no -v -# Target to configure GCP registry cleanup policy for all standard regions -CLUSTER_REPOS = us-central2 us-central1 europe-west4 us-west4 us-east5 us-east1 -default_registry_name = marin -configure_gcp_registry_all: - @echo "Configuring GCP registry cleanup policy for all standard regions..." - $(foreach region,$(CLUSTER_REPOS), \ - python infra/configure_gcp_registry.py $(default_registry_name) --region=$(region) ; \ - ) - @echo "Cleanup policy configured for all regions." - # stuff for setting up locally install_uv: diff --git a/lib/iris/src/iris/cli/bug_report.py b/lib/iris/src/iris/cli/bug_report.py index ec334fbacb..9fed580187 100644 --- a/lib/iris/src/iris/cli/bug_report.py +++ b/lib/iris/src/iris/cli/bug_report.py @@ -18,6 +18,7 @@ from iris.cluster.types import JobName from iris.rpc import controller_pb2, job_pb2 from iris.rpc.auth import AuthTokenInjector, TokenProvider +from iris.rpc.compression import IRIS_RPC_COMPRESSIONS from iris.rpc.controller_connect import ControllerServiceClientSync from iris.rpc.proto_utils import format_resources, job_state_friendly, task_state_friendly from iris.time_proto import timestamp_from_proto @@ -119,7 +120,13 @@ def gather_bug_report( ) -> BugReport: """Gather all diagnostic data for a job into a BugReport.""" interceptors = [AuthTokenInjector(token_provider)] if token_provider else [] - client = ControllerServiceClientSync(controller_url, timeout_ms=30000, interceptors=interceptors) + client = ControllerServiceClientSync( + controller_url, + timeout_ms=30000, + interceptors=interceptors, + accept_compression=IRIS_RPC_COMPRESSIONS, + send_compression=None, + ) log_client = LogClient.connect(controller_url, timeout_ms=30000, interceptors=interceptors) try: return _gather(client, log_client, 
job_id, tail=tail) diff --git a/lib/iris/src/iris/cli/job.py b/lib/iris/src/iris/cli/job.py index bb3c7e2f4a..a28641d093 100644 --- a/lib/iris/src/iris/cli/job.py +++ b/lib/iris/src/iris/cli/job.py @@ -52,7 +52,6 @@ from iris.rpc.auth import TokenProvider from iris.rpc.proto_utils import ( PRIORITY_BAND_NAMES, - format_resources, job_state_friendly, priority_band_value, task_state_friendly, @@ -1008,7 +1007,6 @@ def list_jobs(ctx, state: str | None, prefix: str | None, json_output: bool) -> click.echo("No jobs found.") return - # Build table rows rows: list[list[str]] = [] has_reasons = False @@ -1016,23 +1014,19 @@ def list_jobs(ctx, state: str | None, prefix: str | None, json_output: bool) -> job_id = j.job_id state_name = job_state_friendly(j.state) submitted = timestamp_from_proto(j.submitted_at).as_formatted_date() if j.submitted_at.epoch_ms else "-" - resources = format_resources(j.resources) if j.HasField("resources") else "-" - # Show error for failed jobs, pending_reason for pending/unschedulable reason = j.error or j.pending_reason or "" if reason: has_reasons = True - # Truncate long reasons reason = (reason[:60] + "...") if len(reason) > 63 else reason - rows.append([job_id, state_name, resources, submitted, reason]) + rows.append([job_id, state_name, submitted, reason]) - # Build headers - only include REASON column if there are any reasons if has_reasons: - headers = ["JOB ID", "STATE", "RESOURCES", "SUBMITTED", "REASON"] + headers = ["JOB ID", "STATE", "SUBMITTED", "REASON"] else: - headers = ["JOB ID", "STATE", "RESOURCES", "SUBMITTED"] - rows = [row[:4] for row in rows] + headers = ["JOB ID", "STATE", "SUBMITTED"] + rows = [row[:3] for row in rows] click.echo(tabulate(rows, headers=headers, tablefmt="plain")) diff --git a/lib/iris/src/iris/cli/main.py b/lib/iris/src/iris/cli/main.py index e23d790347..bf8fcdaa7b 100644 --- a/lib/iris/src/iris/cli/main.py +++ b/lib/iris/src/iris/cli/main.py @@ -18,6 +18,7 @@ from iris.rpc import config_pb2, job_pb2 
from iris.rpc import controller_pb2 as _controller_pb2 from iris.rpc.auth import AuthTokenInjector, GcpAccessTokenProvider, StaticTokenProvider, TokenProvider +from iris.rpc.compression import IRIS_RPC_COMPRESSIONS from iris.rpc.controller_connect import ControllerServiceClientSync from iris.rpc.proto_utils import PRIORITY_BAND_NAMES, priority_band_name, priority_band_value @@ -124,7 +125,13 @@ def rpc_client( ) -> ControllerServiceClientSync: """Create an RPC client with optional auth. Use as a context manager: ``with rpc_client(url) as c:``.""" interceptors = [AuthTokenInjector(token_provider)] if token_provider else [] - return ControllerServiceClientSync(address, timeout_ms=timeout_ms, interceptors=interceptors) + return ControllerServiceClientSync( + address, + timeout_ms=timeout_ms, + interceptors=interceptors, + accept_compression=IRIS_RPC_COMPRESSIONS, + send_compression=None, + ) def require_controller_url(ctx: click.Context) -> str: diff --git a/lib/iris/src/iris/client/resolver.py b/lib/iris/src/iris/client/resolver.py index c123fa0ac9..16d95697ce 100644 --- a/lib/iris/src/iris/client/resolver.py +++ b/lib/iris/src/iris/client/resolver.py @@ -8,6 +8,7 @@ from iris.actor.resolver import ResolvedEndpoint, ResolveResult from iris.cluster.types import Namespace from iris.rpc import controller_pb2 +from iris.rpc.compression import IRIS_RPC_COMPRESSIONS from iris.rpc.controller_connect import ControllerServiceClientSync @@ -54,6 +55,8 @@ def __init__( self._client = ControllerServiceClientSync( address=self._address, timeout_ms=int(timeout * 1000), + accept_compression=IRIS_RPC_COMPRESSIONS, + send_compression=None, ) def _namespace_prefix(self) -> str: diff --git a/lib/iris/src/iris/cluster/client/remote_client.py b/lib/iris/src/iris/cluster/client/remote_client.py index ed24ce856b..a2ed4b6bc0 100644 --- a/lib/iris/src/iris/cluster/client/remote_client.py +++ b/lib/iris/src/iris/cluster/client/remote_client.py @@ -21,6 +21,7 @@ from 
iris.cluster.runtime.entrypoint import build_runtime_entrypoint from iris.cluster.types import Entrypoint, EnvironmentSpec, JobName, TaskAttempt, adjust_tpu_replicas, is_job_finished from iris.rpc import controller_pb2, job_pb2 +from iris.rpc.compression import IRIS_RPC_COMPRESSIONS from iris.rpc.controller_connect import ControllerServiceClientSync from iris.rpc.errors import call_with_retry, format_connect_error, poll_with_retries from iris.time_proto import duration_to_proto @@ -35,14 +36,12 @@ # Upper bound on GetJobState polling cadence for long-running jobs. The loop # ramps 100ms -> 1s within a handful of polls (factor=1.5 in ExponentialBackoff) -# and then caps here, so long jobs cost ~1 state RPC / 30s instead of hammering -# the controller at the old ~2s ceiling. +# and then caps here, so long jobs cost ~1 state RPC / 30s. MAX_STATE_POLL_INTERVAL = 30.0 # Floor on the backoff cap. ``ExponentialBackoff`` requires ``maximum >= initial`` -# (currently 100ms), so we clamp the caller-supplied ``poll_interval`` up to this -# value before handing it to the backoff. Callers asking for a sub-100ms cap end -# up polling at 100ms instead of crashing with ValueError. +# (currently 100ms), so callers asking for a sub-100ms cap are clamped to this +# value before being handed to the backoff. 
MIN_STATE_POLL_INTERVAL = 0.1 @@ -78,6 +77,8 @@ def __init__( address=controller_address, timeout_ms=timeout_ms, interceptors=interceptors, + accept_compression=IRIS_RPC_COMPRESSIONS, + send_compression=None, ) self._log_client = LogClient.connect( controller_address, diff --git a/lib/iris/src/iris/cluster/controller/autoscaler/recovery.py b/lib/iris/src/iris/cluster/controller/autoscaler/recovery.py index 348ec27a93..3bf32af175 100644 --- a/lib/iris/src/iris/cluster/controller/autoscaler/recovery.py +++ b/lib/iris/src/iris/cluster/controller/autoscaler/recovery.py @@ -57,8 +57,10 @@ def load_autoscaler_checkpoint(db: ControllerDB) -> AutoscalerCheckpoint: "last_active_ms": decode_timestamp_ms, }, ) + # Failed workers have their DB row deleted (WorkerStore.remove), so + # surviving rows with a slice are by definition the live tracked set. tracked_rows = snapshot.raw( - "SELECT worker_id, slice_id, scale_group, address FROM workers WHERE slice_id != '' AND active = 1", + "SELECT worker_id, slice_id, scale_group, address FROM workers WHERE slice_id != ''", ) slices_by_group: dict[str, list[SliceSnapshot]] = {} diff --git a/lib/iris/src/iris/cluster/controller/controller.py b/lib/iris/src/iris/cluster/controller/controller.py index 2003c28b29..d4f2df5ba6 100644 --- a/lib/iris/src/iris/cluster/controller/controller.py +++ b/lib/iris/src/iris/cluster/controller/controller.py @@ -67,6 +67,7 @@ from iris.cluster.controller.dashboard import ControllerDashboard from iris.cluster.controller.db import ( ControllerDB, + SchedulableWorker, healthy_active_workers_with_attributes, insert_task_profile, job_scheduling_deadline, @@ -80,7 +81,6 @@ Scheduler, SchedulingContext, WorkerCapacity, - WorkerSnapshot, ) from iris.cluster.controller.schema import ( ATTEMPT_PROJECTION, @@ -96,7 +96,6 @@ TaskDetailRow, TaskRow, WorkerDetailRow, - WorkerRow, proto_decoder, tasks_with_attempts, ) @@ -201,7 +200,7 @@ class _SchedulingStateRead: """Snapshot of pending tasks and workers read at the 
start of a scheduling cycle.""" pending_tasks: list[TaskRow] - workers: list[WorkerRow] + workers: list[SchedulableWorker] state_read_ms: int @@ -245,7 +244,7 @@ def job_requirements_from_job(job: JobSchedulingRow) -> JobRequirements: def compute_demand_entries( queries: ControllerDB, scheduler: Scheduler | None = None, - workers: list[WorkerSnapshot] | None = None, + workers: list[SchedulableWorker] | None = None, reservation_claims: dict[WorkerId, ReservationClaim] | None = None, ) -> list[DemandEntry]: """Compute demand entries for the autoscaler from controller state. @@ -708,7 +707,7 @@ def _tasks_by_ids_with_attempts(queries: ControllerDB, task_ids: set[JobName]) - return {task.task_id: task for task in tasks_with_attempts(tasks, attempts)} -def _building_counts(queries: ControllerDB, workers: list[WorkerRow]) -> dict[WorkerId, int]: +def _building_counts(queries: ControllerDB, workers: list[SchedulableWorker]) -> dict[WorkerId, int]: """Count tasks in BUILDING or ASSIGNED state per worker, excluding reservation-holder jobs.""" if not workers: return {} @@ -763,7 +762,7 @@ def _task_worker_mapping(queries: ControllerDB, task_ids: set[JobName]) -> dict[ def _worker_matches_reservation_entry( - worker: WorkerRow, + worker: SchedulableWorker, res_entry: job_pb2.ReservationEntry, ) -> bool: """Check if a worker is eligible for a reservation entry. @@ -785,9 +784,9 @@ def _worker_matches_reservation_entry( def _inject_reservation_taints( - workers: list[WorkerRow], + workers: list[SchedulableWorker], claims: dict[WorkerId, ReservationClaim], -) -> list[WorkerRow]: +) -> list[SchedulableWorker]: """Create modified worker copies with reservation taints and prioritization. 
Claimed workers receive a ``reservation-job`` attribute set to the claiming @@ -800,8 +799,8 @@ def _inject_reservation_taints( if not claims: return workers - claimed: list[WorkerRow] = [] - unclaimed: list[WorkerRow] = [] + claimed: list[SchedulableWorker] = [] + unclaimed: list[SchedulableWorker] = [] for worker in workers: claim = claims.get(worker.worker_id) if claim is not None: @@ -881,6 +880,7 @@ def _reservation_region_constraints( job_id_wire: str, claims: dict[WorkerId, ReservationClaim], queries: ControllerDB, + health: WorkerHealthTracker, existing_constraints: list[Constraint], ) -> list[Constraint]: """Derive region constraints from claimed reservation workers. @@ -897,7 +897,7 @@ def _reservation_region_constraints( claimed_worker_ids = {worker_id for worker_id, claim in claims.items() if claim.job_id == job_id_wire} workers_by_id = { worker.worker_id: worker - for worker in healthy_active_workers_with_attributes(queries) + for worker in healthy_active_workers_with_attributes(queries, health) if worker.worker_id in claimed_worker_ids } regions: set[str] = set() @@ -1153,7 +1153,8 @@ def __init__( self._db = db else: self._db = ControllerDB(db_dir=config.local_state_dir / "db") - self._store = ControllerStore(self._db) + self._health = WorkerHealthTracker() + self._store = ControllerStore(self._db, health=self._health) # ThreadContainer must be initialized before the log service setup # because _start_local_log_server spawns a uvicorn thread. @@ -1194,7 +1195,6 @@ def __init__( self._log_handler.setFormatter(logging.Formatter("%(asctime)s %(name)s %(message)s")) logging.getLogger("iris").addHandler(self._log_handler) - self._health = WorkerHealthTracker() self._transitions = ControllerTransitions( store=self._store, health=self._health, @@ -1630,13 +1630,13 @@ def _profile_all_running_tasks(self) -> None: Memory profiling via memray is currently disabled because memray attach has been triggering segfaults in target processes. 
""" - workers = healthy_active_workers_with_attributes(self._db) + workers = healthy_active_workers_with_attributes(self._db, self._health) if not workers: return workers_by_id = {w.worker_id: w for w in workers} tasks_by_worker = running_tasks_by_worker(self._db, set(workers_by_id.keys())) - profile_targets: list[tuple[JobName, WorkerRow]] = [] + profile_targets: list[tuple[JobName, SchedulableWorker]] = [] for worker_id, task_ids in tasks_by_worker.items(): worker = workers_by_id[worker_id] for task_id in task_ids: @@ -1654,7 +1654,7 @@ def _profile_all_running_tasks(self) -> None: def _dispatch_profiles( self, - targets: list[tuple[JobName, WorkerRow]], + targets: list[tuple[JobName, SchedulableWorker]], profile_type: job_pb2.ProfileType, profile_kind: str, duration: int, @@ -1672,7 +1672,7 @@ def _dispatch_profiles( def _capture_one_profile( self, task_id: JobName, - worker: WorkerRow, + worker: SchedulableWorker, profile_type: job_pb2.ProfileType, profile_kind: str, duration: int, @@ -1742,11 +1742,7 @@ def _cleanup_stale_claims(self, claims: dict[WorkerId, ReservationClaim] | None if claims is None: claims = _read_reservation_claims(self._db) persisted = True - with self._db.read_snapshot() as snapshot: - active_worker_ids = { - WorkerId(str(row[0])) - for row in snapshot.fetchall("SELECT w.worker_id FROM workers w WHERE w.active = 1") - } + active_worker_ids = {wid for wid, l in self._health.all().items() if l.active} claimed_job_ids = {JobName.from_wire(claim.job_id) for claim in claims.values()} claimed_jobs = list(_jobs_by_id(self._db, claimed_job_ids).values()) if claimed_job_ids else [] jobs_by_id = {job.job_id.to_wire(): job for job in claimed_jobs} @@ -1778,7 +1774,7 @@ def _claim_workers_for_reservations(self, claims: dict[WorkerId, ReservationClai persisted = True claimed_entries: set[tuple[str, int]] = {(c.job_id, c.entry_idx) for c in claims.values()} claimed_worker_ids: set[WorkerId] = set(claims.keys()) - all_workers = 
healthy_active_workers_with_attributes(self._db) + all_workers = healthy_active_workers_with_attributes(self._db, self._health) changed = False reservable_states = ( @@ -1796,8 +1792,6 @@ def _claim_workers_for_reservations(self, claims: dict[WorkerId, ReservationClai for worker in all_workers: if worker.worker_id in claimed_worker_ids: continue - if not worker.healthy: - continue if not _worker_matches_reservation_entry(worker, res_entry): continue @@ -1916,7 +1910,7 @@ def _read_scheduling_state(self) -> _SchedulingStateRead: timer = Timer() with slow_log(logger, "scheduling state reads", threshold_ms=50): pending_tasks = _schedulable_tasks(self._db) - workers = healthy_active_workers_with_attributes(self._db) + workers = healthy_active_workers_with_attributes(self._db, self._health) return _SchedulingStateRead( pending_tasks=pending_tasks, workers=workers, @@ -2240,7 +2234,7 @@ def _mark_task_unschedulable(self, task: TaskRow) -> None: if result.tasks_to_kill: self.kill_tasks_on_workers(result.tasks_to_kill, result.task_kill_workers) - def create_scheduling_context(self, workers: list[WorkerRow]) -> SchedulingContext: + def create_scheduling_context(self, workers: list[SchedulableWorker]) -> SchedulingContext: """Create a scheduling context for the given workers.""" building_counts = _building_counts(self._db, workers) return self._scheduler.create_scheduling_context( @@ -2378,7 +2372,7 @@ def _stop_tasks_direct( def _get_active_worker_addresses(self) -> list[tuple[WorkerId, str | None]]: """Get healthy active workers as (worker_id, address) tuples for ping.""" - workers = healthy_active_workers_with_attributes(self._db) + workers = healthy_active_workers_with_attributes(self._db, self._health) return [(w.worker_id, w.address) for w in workers] def _run_ping_loop(self, stop_event: threading.Event) -> None: @@ -2406,8 +2400,7 @@ def _run_ping_loop(self, stop_event: threading.Event) -> None: self._health.ping(result.worker_id, healthy=True) 
live_worker_ids.append(result.worker_id) - with self._store.transaction() as cur: - self._transitions.update_worker_pings(cur, live_worker_ids) + self._transitions.update_worker_pings(live_worker_ids) unhealthy = self._health.workers_over_threshold() if unhealthy: @@ -2534,7 +2527,7 @@ def _run_autoscaler_once(self) -> None: worker_status_map = self._build_worker_status_map() self._autoscaler.refresh(worker_status_map) - workers = healthy_active_workers_with_attributes(self._db) + workers = healthy_active_workers_with_attributes(self._db, self._health) demand_entries = compute_demand_entries( self._db, self._scheduler, @@ -2546,12 +2539,7 @@ def _run_autoscaler_once(self) -> None: def _build_worker_status_map(self) -> WorkerStatusMap: """Build a map of worker_id to worker status for autoscaler idle tracking.""" result: WorkerStatusMap = {} - with self._db.read_snapshot() as snapshot: - rows = snapshot.raw( - "SELECT worker_id FROM workers WHERE active = 1", - decoders={"worker_id": WorkerId}, - ) - worker_ids = {row.worker_id for row in rows} + worker_ids = {wid for wid, l in self._health.all().items() if l.active} running_by_worker = running_tasks_by_worker(self._db, worker_ids) for wid in worker_ids: result[wid] = WorkerStatus( diff --git a/lib/iris/src/iris/cluster/controller/db.py b/lib/iris/src/iris/cluster/controller/db.py index ebbf70ecf9..9b9e169798 100644 --- a/lib/iris/src/iris/cluster/controller/db.py +++ b/lib/iris/src/iris/cluster/controller/db.py @@ -11,7 +11,6 @@ from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from contextlib import contextmanager from dataclasses import dataclass, field -from dataclasses import replace as dc_replace from pathlib import Path from threading import Lock, RLock from typing import Any @@ -20,6 +19,7 @@ from iris.cluster.constraints import AttributeValue from iris.cluster.controller.schema import decode_timestamp_ms, decode_worker_id +from iris.cluster.controller.worker_health import 
WorkerHealthTracker from iris.cluster.types import TERMINAL_TASK_STATES, JobName, WorkerId from iris.rpc import job_pb2 @@ -827,10 +827,7 @@ def get_all_user_budget_limits(self) -> dict[str, int]: def running_tasks_by_worker(db: ControllerDB, worker_ids: set[WorkerId]) -> dict[WorkerId, set[JobName]]: - """Return the set of currently-running task IDs for each worker. - - Uses the denormalized current_worker_id column instead of joining task_attempts. - """ + """Return the set of currently-running task IDs for each worker.""" if not worker_ids: return {} placeholders = ",".join("?" for _ in worker_ids) @@ -919,32 +916,72 @@ def _worker_row_select() -> str: return WORKER_ROW_PROJECTION.select_clause() -def healthy_active_workers_with_attributes(db: ControllerDB) -> list: - """Fetch all healthy, active workers with their attributes populated. +@dataclass(frozen=True, slots=True) +class SchedulableWorker: + """Worker shape consumed by the scheduler. - Returns WorkerRow (scalar-only) so the scheduling loop avoids loading metadata columns. - Uses the in-memory attribute cache to avoid a per-cycle SQL join. + Field names mirror the :class:`scheduler.WorkerSnapshot` protocol so + instances flow into ``Scheduler.create_scheduling_context`` without + an adapter. 
""" + + worker_id: WorkerId + address: str + total_cpu_millicores: int + total_memory_bytes: int + total_gpu_count: int + total_tpu_count: int + device_type: str + device_variant: str + attributes: dict[str, AttributeValue] + committed_cpu_millicores: int + committed_mem: int + committed_gpu: int + committed_tpu: int + + +def healthy_active_workers_with_attributes( + db: ControllerDB, + health: WorkerHealthTracker, +) -> list[SchedulableWorker]: + """Return healthy + active workers with attributes and committed totals.""" from iris.cluster.controller.schema import WORKER_ROW_PROJECTION + liveness = health.all() + healthy_active = {wid for wid, l in liveness.items() if l.healthy and l.active} + if not healthy_active: + return [] + placeholders = ",".join("?" for _ in healthy_active) with db.read_snapshot() as q: - workers = WORKER_ROW_PROJECTION.decode( - q.fetchall(f"SELECT {_worker_row_select()} FROM workers w WHERE w.healthy = 1 AND w.active = 1"), + rows = WORKER_ROW_PROJECTION.decode( + q.fetchall( + f"SELECT {_worker_row_select()} FROM workers w WHERE w.worker_id IN ({placeholders})", + tuple(str(wid) for wid in healthy_active), + ), ) - if not workers: + if not rows: return [] attrs_by_worker = db.get_worker_attributes() - return [ - dc_replace( - w, - attributes=attrs_by_worker.get(w.worker_id, {}), - available_cpu_millicores=w.total_cpu_millicores - w.committed_cpu_millicores, - available_memory=w.total_memory_bytes - w.committed_mem, - available_gpus=w.total_gpu_count - w.committed_gpu, - available_tpus=w.total_tpu_count - w.committed_tpu, + out: list[SchedulableWorker] = [] + for w in rows: + out.append( + SchedulableWorker( + worker_id=w.worker_id, + address=w.address, + total_cpu_millicores=w.total_cpu_millicores, + total_memory_bytes=w.total_memory_bytes, + total_gpu_count=w.total_gpu_count, + total_tpu_count=w.total_tpu_count, + device_type=w.device_type, + device_variant=w.device_variant, + attributes=attrs_by_worker.get(w.worker_id, {}), + 
committed_cpu_millicores=w.committed_cpu_millicores, + committed_mem=w.committed_mem, + committed_gpu=w.committed_gpu, + committed_tpu=w.committed_tpu, + ) ) - for w in workers - ] + return out def insert_task_profile( diff --git a/lib/iris/src/iris/cluster/controller/migrations/0004_worker_indexes.py b/lib/iris/src/iris/cluster/controller/migrations/0004_worker_indexes.py index 4f77b0acbf..4a907b0cd2 100644 --- a/lib/iris/src/iris/cluster/controller/migrations/0004_worker_indexes.py +++ b/lib/iris/src/iris/cluster/controller/migrations/0004_worker_indexes.py @@ -4,10 +4,12 @@ import sqlite3 +def _has_column(conn: sqlite3.Connection, table: str, column: str) -> bool: + return column in {row[1] for row in conn.execute(f"PRAGMA table_info({table})").fetchall()} + + def migrate(conn: sqlite3.Connection) -> None: - # Originally this migration also rewrote the `trg_txn_log_retention` - # trigger; those statements were removed once migration 0037 dropped the - # `txn_log` / `txn_actions` tables entirely. On DBs that already ran the - # old form the trigger survives until 0037 executes; 0037 is idempotent - # (`DROP TRIGGER IF EXISTS`) so no fixup is needed here. - conn.execute("CREATE INDEX IF NOT EXISTS idx_workers_healthy_active ON workers(healthy, active)") + # ``healthy`` / ``active`` are dropped from the workers table by 0042; on a + # fresh DB the columns are absent at this point so the index is a no-op. 
+ if _has_column(conn, "workers", "healthy") and _has_column(conn, "workers", "active"): + conn.execute("CREATE INDEX IF NOT EXISTS idx_workers_healthy_active ON workers(healthy, active)") diff --git a/lib/iris/src/iris/cluster/controller/migrations/0019_worker_fk_cascade.py b/lib/iris/src/iris/cluster/controller/migrations/0019_worker_fk_cascade.py index 8317685485..d9acae4081 100644 --- a/lib/iris/src/iris/cluster/controller/migrations/0019_worker_fk_cascade.py +++ b/lib/iris/src/iris/cluster/controller/migrations/0019_worker_fk_cascade.py @@ -45,24 +45,29 @@ def migrate(conn: sqlite3.Connection) -> None: "CREATE INDEX IF NOT EXISTS idx_task_attempts_worker_task " "ON task_attempts(worker_id, task_id, attempt_id)" ) - # Recreate the trigger from 0001_init (dropped when the table was rebuilt) - conn.execute( - """ - CREATE TRIGGER IF NOT EXISTS trg_task_attempt_active_worker - BEFORE INSERT ON task_attempts - FOR EACH ROW - WHEN NEW.worker_id IS NOT NULL - BEGIN - SELECT - CASE - WHEN NOT EXISTS( - SELECT 1 FROM workers w - WHERE w.worker_id = NEW.worker_id - AND w.active = 1 - AND w.healthy = 1 - ) - THEN RAISE(ABORT, 'task attempt worker must be active and healthy') + # Recreate the trigger from 0001_init (dropped when the table was rebuilt). + # The trigger references workers.active / workers.healthy, which are dropped + # in 0042. On a fresh DB those columns are absent at this point, so skip + # the trigger entirely; existing DBs created the trigger before 0042 ran. 
+ cols = {row[1] for row in conn.execute("PRAGMA table_info(workers)").fetchall()} + if "active" in cols and "healthy" in cols: + conn.execute( + """ + CREATE TRIGGER IF NOT EXISTS trg_task_attempt_active_worker + BEFORE INSERT ON task_attempts + FOR EACH ROW + WHEN NEW.worker_id IS NOT NULL + BEGIN + SELECT + CASE + WHEN NOT EXISTS( + SELECT 1 FROM workers w + WHERE w.worker_id = NEW.worker_id + AND w.active = 1 + AND w.healthy = 1 + ) + THEN RAISE(ABORT, 'task attempt worker must be active and healthy') + END; END; - END; - """ - ) + """ + ) diff --git a/lib/iris/src/iris/cluster/controller/migrations/0030_backfill_worker_region.py b/lib/iris/src/iris/cluster/controller/migrations/0030_backfill_worker_region.py index ceef6cb6c3..9f35645e64 100644 --- a/lib/iris/src/iris/cluster/controller/migrations/0030_backfill_worker_region.py +++ b/lib/iris/src/iris/cluster/controller/migrations/0030_backfill_worker_region.py @@ -27,10 +27,17 @@ def _zone_of(scale_group: str) -> str | None: return None +def _has_column(conn: sqlite3.Connection, table: str, column: str) -> bool: + return column in {row[1] for row in conn.execute(f"PRAGMA table_info({table})").fetchall()} + + def migrate(conn: sqlite3.Connection) -> None: + # ``active`` was a workers column at the time this migration was authored. + # It is dropped in 0042; on fresh DBs the column is absent at this point. 
+ active_predicate = "w.active=1 AND " if _has_column(conn, "workers", "active") else "" rows = conn.execute( "SELECT w.worker_id, w.scale_group FROM workers w " - "WHERE w.active=1 AND w.scale_group != '' " + f"WHERE {active_predicate}w.scale_group != '' " "AND NOT EXISTS (" " SELECT 1 FROM worker_attributes wa " " WHERE wa.worker_id = w.worker_id AND wa.key = 'region'" diff --git a/lib/iris/src/iris/cluster/controller/migrations/0042_drop_workers_dormant_columns.py b/lib/iris/src/iris/cluster/controller/migrations/0042_drop_workers_dormant_columns.py new file mode 100644 index 0000000000..5c55686f14 --- /dev/null +++ b/lib/iris/src/iris/cluster/controller/migrations/0042_drop_workers_dormant_columns.py @@ -0,0 +1,31 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Drop the transient-liveness columns on ``workers`` (now in-memory only). + +Removes ``last_heartbeat_ms``, ``healthy``, ``active``, and +``consecutive_failures`` along with ``idx_workers_healthy_active`` and the +``trg_task_attempt_active_worker`` trigger that referenced them. 
+""" + +import sqlite3 + + +def _has_column(conn: sqlite3.Connection, table: str, column: str) -> bool: + return column in {row[1] for row in conn.execute(f"PRAGMA table_info({table})").fetchall()} + + +_COLUMNS_TO_DROP = ( + "last_heartbeat_ms", + "healthy", + "active", + "consecutive_failures", +) + + +def migrate(conn: sqlite3.Connection) -> None: + conn.execute("DROP INDEX IF EXISTS idx_workers_healthy_active") + conn.execute("DROP TRIGGER IF EXISTS trg_task_attempt_active_worker") + for col in _COLUMNS_TO_DROP: + if _has_column(conn, "workers", col): + conn.execute(f"ALTER TABLE workers DROP COLUMN {col}") diff --git a/lib/iris/src/iris/cluster/controller/scheduler.py b/lib/iris/src/iris/cluster/controller/scheduler.py index 89f389a0be..f0c8e0e3e7 100644 --- a/lib/iris/src/iris/cluster/controller/scheduler.py +++ b/lib/iris/src/iris/cluster/controller/scheduler.py @@ -75,7 +75,6 @@ class WorkerSnapshot(Protocol): total_tpu_count: int committed_tpu: int attributes: dict[str, AttributeValue] - healthy: bool class RejectionKind(StrEnum): @@ -338,7 +337,6 @@ def from_workers( max_building_tasks=max_building_tasks, ) for w in workers - if w.healthy } str_to_wid: dict[str, WorkerId] = {} diff --git a/lib/iris/src/iris/cluster/controller/schema.py b/lib/iris/src/iris/cluster/controller/schema.py index 2bddd2a4ff..7aab54ea51 100644 --- a/lib/iris/src/iris/cluster/controller/schema.py +++ b/lib/iris/src/iris/cluster/controller/schema.py @@ -828,25 +828,6 @@ def generate_full_ddl(tables: Sequence[Table]) -> str: "CREATE INDEX IF NOT EXISTS idx_task_attempts_worker_task" " ON task_attempts(worker_id, task_id, attempt_id)", ), - triggers=( - # From 0001_init - """CREATE TRIGGER IF NOT EXISTS trg_task_attempt_active_worker -BEFORE INSERT ON task_attempts -FOR EACH ROW -WHEN NEW.worker_id IS NOT NULL -BEGIN - SELECT - CASE - WHEN NOT EXISTS( - SELECT 1 FROM workers w - WHERE w.worker_id = NEW.worker_id - AND w.active = 1 - AND w.healthy = 1 - ) - THEN RAISE(ABORT, 'task 
attempt worker must be active and healthy') - END; -END;""", - ), ) WORKERS = Table( @@ -871,35 +852,6 @@ def generate_full_ddl(tables: Sequence[Table]) -> str: Column("md_gce_zone", "TEXT", "NOT NULL DEFAULT ''", python_type=str, decoder=str, default=""), Column("md_git_hash", "TEXT", "NOT NULL DEFAULT ''", python_type=str, decoder=str, default=""), Column("md_device_json", "TEXT", "NOT NULL DEFAULT '{}'", python_type=str, decoder=str, default="{}"), - Column("healthy", "INTEGER", "NOT NULL CHECK (healthy IN (0, 1))", python_type=bool, decoder=_decode_bool_int), - Column( - "active", - "INTEGER", - "NOT NULL CHECK (active IN (0, 1))", - python_type=bool, - decoder=_decode_bool_int, - default=True, - ), - Column("consecutive_failures", "INTEGER", "NOT NULL", python_type=int, decoder=int), - Column( - "last_heartbeat_ms", - "INTEGER", - "NOT NULL", - python_name="last_heartbeat", - python_type=Timestamp, - decoder=decode_timestamp_ms, - ), - Column("committed_cpu_millicores", "INTEGER", "NOT NULL", python_type=int, decoder=int), - Column( - "committed_mem_bytes", - "INTEGER", - "NOT NULL", - python_name="committed_mem", - python_type=int, - decoder=int, - ), - Column("committed_gpu", "INTEGER", "NOT NULL", python_type=int, decoder=int), - Column("committed_tpu", "INTEGER", "NOT NULL", python_type=int, decoder=int), # Migration 0016 Column("total_cpu_millicores", "INTEGER", "NOT NULL DEFAULT 0", python_type=int, decoder=int, default=0), Column("total_memory_bytes", "INTEGER", "NOT NULL DEFAULT 0", python_type=int, decoder=int, default=0), @@ -910,10 +862,26 @@ def generate_full_ddl(tables: Sequence[Table]) -> str: # Migration 0022 Column("slice_id", "TEXT", "NOT NULL DEFAULT ''", python_type=str, decoder=str, default=""), Column("scale_group", "TEXT", "NOT NULL DEFAULT ''", python_type=str, decoder=str, default=""), - ), - indexes=( - # Migration 0004_worker_indexes - "CREATE INDEX IF NOT EXISTS idx_workers_healthy_active ON workers(healthy, active)", + # 
Committed-resource totals — only the scheduler writes these. + Column( + "committed_cpu_millicores", + "INTEGER", + "NOT NULL DEFAULT 0", + python_type=int, + decoder=int, + default=0, + ), + Column( + "committed_mem_bytes", + "INTEGER", + "NOT NULL DEFAULT 0", + python_name="committed_mem", + python_type=int, + decoder=int, + default=0, + ), + Column("committed_gpu", "INTEGER", "NOT NULL DEFAULT 0", python_type=int, decoder=int, default=0), + Column("committed_tpu", "INTEGER", "NOT NULL DEFAULT 0", python_type=int, decoder=int, default=0), ), ) @@ -1239,25 +1207,16 @@ class JobRow: job_id: JobName state: int submitted_at: Timestamp - root_submitted_at: Timestamp started_at: Timestamp | None finished_at: Timestamp | None - scheduling_deadline_epoch_ms: int | None error: str | None exit_code: int | None - num_tasks: int - is_reservation_holder: bool - has_reservation: bool name: str depth: int res_cpu_millicores: int res_memory_bytes: int res_disk_bytes: int res_device_json: str | None - has_coscheduling: bool - coscheduling_group_by: str - scheduling_timeout_ms: int | None - max_task_failures: int @dataclass(frozen=True, slots=True) @@ -1374,29 +1333,21 @@ class TaskDetailRow: @dataclass(frozen=True, slots=True) class WorkerRow: - """Worker row for scheduling and health checks.""" + """Durable worker columns: identity, capability, and committed scheduling totals.""" worker_id: WorkerId address: str - healthy: bool - active: bool - consecutive_failures: int - last_heartbeat: Timestamp - committed_cpu_millicores: int - committed_mem: int - committed_gpu: int - committed_tpu: int total_cpu_millicores: int total_memory_bytes: int total_gpu_count: int total_tpu_count: int device_type: str device_variant: str + committed_cpu_millicores: int + committed_mem: int + committed_gpu: int + committed_tpu: int attributes: dict = dataclasses.field(default_factory=dict) - available_cpu_millicores: int = 0 - available_memory: int = 0 - available_gpus: int = 0 - available_tpus: int 
= 0 @dataclass(frozen=True, slots=True) @@ -1405,20 +1356,16 @@ class WorkerDetailRow: worker_id: WorkerId address: str - healthy: bool - active: bool - consecutive_failures: int - last_heartbeat: Timestamp - committed_cpu_millicores: int - committed_mem: int - committed_gpu: int - committed_tpu: int total_cpu_millicores: int total_memory_bytes: int total_gpu_count: int total_tpu_count: int device_type: str device_variant: str + committed_cpu_millicores: int + committed_mem: int + committed_gpu: int + committed_tpu: int md_hostname: str md_ip_address: str md_cpu_count: int @@ -1436,10 +1383,6 @@ class WorkerDetailRow: md_git_hash: str md_device_json: str attributes: dict = dataclasses.field(default_factory=dict) - available_cpu_millicores: int = 0 - available_memory: int = 0 - available_gpus: int = 0 - available_tpus: int = 0 @dataclass(frozen=True, slots=True) @@ -1528,25 +1471,16 @@ def _job_columns(*names: str) -> tuple[tuple[Column, ...], tuple[str, ...]]: "job_id", "state", "submitted_at_ms", - "root_submitted_at_ms", "started_at_ms", "finished_at_ms", - "scheduling_deadline_epoch_ms", "error", "exit_code", - "num_tasks", - "is_reservation_holder", - "has_reservation", "name", "depth", "res_cpu_millicores", "res_memory_bytes", "res_disk_bytes", "res_device_json", - "has_coscheduling", - "coscheduling_group_by", - "scheduling_timeout_ms", - "max_task_failures", ) JOB_ROW_PROJECTION = Projection( JOBS, @@ -1588,31 +1522,21 @@ def _job_columns(*names: str) -> tuple[tuple[Column, ...], tuple[str, ...]]: column_aliases=_job_sched_aliases, ) -# Worker row for scheduling and health checks. +# ``attributes`` is hydrated post-decode from the ``worker_attributes`` table. 
WORKER_ROW_PROJECTION = WORKERS.projection( "worker_id", "address", - "healthy", - "active", - "consecutive_failures", - "last_heartbeat_ms", - "committed_cpu_millicores", - "committed_mem_bytes", - "committed_gpu", - "committed_tpu", "total_cpu_millicores", "total_memory_bytes", "total_gpu_count", "total_tpu_count", "device_type", "device_variant", - extra_fields=( - ExtraField("attributes", dict, default_factory=dict), - ExtraField("available_cpu_millicores", int, default=0), - ExtraField("available_memory", int, default=0), - ExtraField("available_gpus", int, default=0), - ExtraField("available_tpus", int, default=0), - ), + "committed_cpu_millicores", + "committed_mem_bytes", + "committed_gpu", + "committed_tpu", + extra_fields=(ExtraField("attributes", dict, default_factory=dict),), row_cls=WorkerRow, ) @@ -1709,20 +1633,16 @@ def _job_columns(*names: str) -> tuple[tuple[Column, ...], tuple[str, ...]]: WORKER_DETAIL_PROJECTION = WORKERS.projection( "worker_id", "address", - "healthy", - "active", - "consecutive_failures", - "last_heartbeat_ms", - "committed_cpu_millicores", - "committed_mem_bytes", - "committed_gpu", - "committed_tpu", "total_cpu_millicores", "total_memory_bytes", "total_gpu_count", "total_tpu_count", "device_type", "device_variant", + "committed_cpu_millicores", + "committed_mem_bytes", + "committed_gpu", + "committed_tpu", "md_hostname", "md_ip_address", "md_cpu_count", @@ -1739,13 +1659,7 @@ def _job_columns(*names: str) -> tuple[tuple[Column, ...], tuple[str, ...]]: "md_gce_zone", "md_git_hash", "md_device_json", - extra_fields=( - ExtraField("attributes", dict, default_factory=dict), - ExtraField("available_cpu_millicores", int, default=0), - ExtraField("available_memory", int, default=0), - ExtraField("available_gpus", int, default=0), - ExtraField("available_tpus", int, default=0), - ), + extra_fields=(ExtraField("attributes", dict, default_factory=dict),), row_cls=WorkerDetailRow, ) diff --git 
a/lib/iris/src/iris/cluster/controller/service.py b/lib/iris/src/iris/cluster/controller/service.py index fcd038e9bc..4b66f0236c 100644 --- a/lib/iris/src/iris/cluster/controller/service.py +++ b/lib/iris/src/iris/cluster/controller/service.py @@ -14,7 +14,6 @@ import re import secrets import uuid -from collections.abc import Callable, Mapping from dataclasses import dataclass from datetime import date, timedelta from typing import Any, Protocol @@ -53,6 +52,7 @@ ControllerDB, EndpointQuery, QuerySnapshot, + SchedulableWorker, TaskJobSummary, UserStats, attempt_is_worker_failure, @@ -77,15 +77,15 @@ JobRow, TaskDetailRow, WorkerDetailRow, - WorkerRow, tasks_with_attempts, ) -from iris.cluster.controller.stores import ControllerStore, SnapshotView +from iris.cluster.controller.stores import AddEndpointOutcome, ControllerStore from iris.cluster.controller.transitions import ( ControllerTransitions, HeartbeatApplyRequest, task_updates_from_proto, ) +from iris.cluster.controller.worker_health import WorkerLiveness from iris.cluster.log_store_helpers import build_log_source from iris.cluster.process_status import get_process_status from iris.cluster.redaction import redact_request_env_vars @@ -219,10 +219,10 @@ def _active_worker_id(task: TaskDetailRow) -> WorkerId | None: def task_to_proto(task: TaskDetailRow, worker_address: str = "") -> job_pb2.TaskStatus: """Convert a task row to a TaskStatus proto. - Handles attempt conversion and timestamps. ``resource_usage`` is no longer - populated by the controller — per-attempt samples live in the ``iris.task`` - stats namespace. The caller is responsible for resolving worker_address - from worker_id if needed. + Handles attempt conversion and timestamps. Per-attempt resource samples + live in the ``iris.task`` stats namespace and are not populated here. The + caller is responsible for resolving worker_address from worker_id if + needed. 
""" current_attempt = _current_attempt(task) @@ -269,12 +269,12 @@ def task_to_proto(task: TaskDetailRow, worker_address: str = "") -> job_pb2.Task return proto -def worker_status_message(w: WorkerDetailRow) -> str: +def worker_status_message(liveness: WorkerLiveness) -> str: """Build a human-readable status message for unhealthy workers.""" - if w.healthy: + if liveness.healthy: return "" - age = w.last_heartbeat.age_ms() - return f"Unhealthy (last seen {age // 1000}s ago)" + age_ms = max(0, Timestamp.now().epoch_ms() - liveness.last_heartbeat_ms) + return f"Unhealthy (last seen {age_ms // 1000}s ago)" _WORKER_TARGET_PREFIX = "/system/worker/" @@ -349,14 +349,9 @@ def _read_task_with_attempts(db: ControllerDB, task_id: JobName) -> TaskDetailRo return tasks_with_attempts([task], attempts)[0] -def _read_worker(db: ControllerDB, worker_id: WorkerId) -> WorkerDetailRow | None: - with db.read_snapshot() as q: - return WORKER_DETAIL_PROJECTION.decode_one( - q.fetchall( - f"SELECT {WORKER_DETAIL_PROJECTION.select_clause()} FROM workers w WHERE w.worker_id = ?", - (str(worker_id),), - ) - ) +def _read_worker(store: ControllerStore, worker_id: WorkerId) -> WorkerDetailRow | None: + with store.read_snapshot() as q: + return store.workers.get_detail(q, worker_id) def _job_state(db: ControllerDB, job_id: JobName) -> int | None: @@ -477,14 +472,9 @@ class _WorkerDetail: running_tasks: frozenset[JobName] -def _read_worker_detail(db: ControllerDB, worker_id: WorkerId) -> _WorkerDetail | None: - with db.read_snapshot() as q: - worker = WORKER_DETAIL_PROJECTION.decode_one( - q.fetchall( - f"SELECT {WORKER_DETAIL_PROJECTION.select_clause()} FROM workers w WHERE w.worker_id = ?", - (str(worker_id),), - ), - ) +def _read_worker_detail(store: ControllerStore, worker_id: WorkerId) -> _WorkerDetail | None: + with store.read_snapshot() as q: + worker = store.workers.get_detail(q, worker_id) if worker is None: return None attr_rows = q.fetchall( @@ -508,24 +498,29 @@ def 
_read_worker_detail(db: ControllerDB, worker_id: WorkerId) -> _WorkerDetail def _tasks_for_listing(db: ControllerDB, *, job_id: JobName) -> list[TaskDetailRow]: + """Load tasks for the list view, attaching only the current attempt. + + The list UI only needs the current attempt's ``started_at`` / + ``finished_at`` and a single ``proto.attempts`` entry. Full history is + fetched separately by ``get_task_status``. + """ + job_wire = job_id.to_wire() with db.read_snapshot() as q: tasks = TASK_DETAIL_PROJECTION.decode( q.fetchall( f"SELECT {TASK_DETAIL_PROJECTION.select_clause()} " "FROM tasks t WHERE t.job_id = ? ORDER BY t.job_id ASC, t.task_index ASC", - (job_id.to_wire(),), + (job_wire,), ), ) - if not tasks: - return [] - task_wires = [t.task_id.to_wire() for t in tasks] - placeholders = ",".join("?" for _ in task_wires) attempts = ATTEMPT_PROJECTION.decode( q.fetchall( f"SELECT {ATTEMPT_PROJECTION.select_clause()} FROM task_attempts ta " - f"WHERE ta.task_id IN ({placeholders}) " - "ORDER BY ta.task_id ASC, ta.attempt_id ASC", - tuple(task_wires), + "WHERE (ta.task_id, ta.attempt_id) IN (" + " SELECT t.task_id, t.current_attempt_id FROM tasks t " + " WHERE t.job_id = ? AND t.current_attempt_id >= 0" + ")", + (job_wire,), ), ) return tasks_with_attempts(tasks, attempts) @@ -546,9 +541,8 @@ def _worker_addresses_for_tasks(db: ControllerDB, tasks: list[TaskDetailRow]) -> return {WorkerId(str(row.worker_id)): row.address for row in rows} -# State display order for sorting (active states first). Used both by the SQL -# path (rendered as a CASE expression) and the Python snapshot path (used as a -# direct dict lookup). Keep the two in sync. +# State display order for sorting (active states first). Rendered into +# ``_STATE_SORT_EXPR`` as a CASE expression for the JOB_SORT_FIELD_STATE path. 
_STATE_SORT_ORDER: dict[int, int] = { job_pb2.JOB_STATE_RUNNING: 0, job_pb2.JOB_STATE_BUILDING: 1, @@ -580,6 +574,7 @@ def _worker_addresses_for_tasks(db: ControllerDB, tasks: list[TaskDetailRow]) -> def _filter_and_sort_workers( workers: list[WorkerDetailRow], + liveness_by_id: dict[WorkerId, WorkerLiveness], query: controller_pb2.Controller.WorkerQuery, ) -> list[WorkerDetailRow]: """Apply the ``WorkerQuery`` contains filter and sort the cached roster. @@ -599,7 +594,7 @@ def _filter_and_sort_workers( sort_field = query.sort_field or controller_pb2.Controller.WORKER_SORT_FIELD_WORKER_ID descending = query.sort_direction == controller_pb2.Controller.SORT_DIRECTION_DESC if sort_field == controller_pb2.Controller.WORKER_SORT_FIELD_LAST_HEARTBEAT: - workers = sorted(workers, key=lambda w: w.last_heartbeat.epoch_ms(), reverse=descending) + workers = sorted(workers, key=lambda w: liveness_by_id[w.worker_id].last_heartbeat_ms, reverse=descending) elif sort_field == controller_pb2.Controller.WORKER_SORT_FIELD_DEVICE_TYPE: # CPU workers persist with ``device_type == ""``; under ascending sort # they group first (treating CPU as the no-accelerator baseline). @@ -720,65 +715,18 @@ def _query_jobs( return JOB_ROW_PROJECTION.decode(rows), total -def _job_matches_query( - job: JobRow, - query: controller_pb2.Controller.JobQuery, - state_ids: tuple[int, ...], -) -> bool: - """Python equivalent of the WHERE clause in :func:`_query_jobs`. - - Used by the snapshot path; the SQL path uses ``_query_jobs`` directly. 
- """ - if job.state not in state_ids: - return False - scope = query.scope or controller_pb2.Controller.JOB_QUERY_SCOPE_ALL - if scope == controller_pb2.Controller.JOB_QUERY_SCOPE_ROOTS: - if job.depth != 1: - return False - elif scope == controller_pb2.Controller.JOB_QUERY_SCOPE_CHILDREN: - parent = job.job_id.parent - if parent is None or parent.to_wire() != query.parent_job_id: - return False - if query.name_filter and query.name_filter.lower() not in job.name.lower(): - return False - return True - - -def _job_sort_key( - sort_field: int, - summaries: Mapping[JobName, TaskJobSummary], -) -> Callable[[JobRow], object]: - """Return a key function for Python-side sorting equivalent to ``_SORT_FIELD_TO_SQL``.""" - if sort_field == controller_pb2.Controller.JOB_SORT_FIELD_NAME: - return lambda j: j.name - if sort_field == controller_pb2.Controller.JOB_SORT_FIELD_STATE: - # Same priority order as ``_STATE_SORT_EXPR``; unknown states sink to 99. - return lambda j: _STATE_SORT_ORDER.get(j.state, 99) - if sort_field == controller_pb2.Controller.JOB_SORT_FIELD_FAILURES: - return lambda j: summaries.get(j.job_id, TaskJobSummary(job_id=j.job_id)).failure_count - if sort_field == controller_pb2.Controller.JOB_SORT_FIELD_PREEMPTIONS: - return lambda j: summaries.get(j.job_id, TaskJobSummary(job_id=j.job_id)).preemption_count - # JOB_SORT_FIELD_DATE (default). - return lambda j: j.submitted_at.epoch_ms() - - def _query_from_list_jobs_request( request: controller_pb2.Controller.ListJobsRequest, ) -> controller_pb2.Controller.JobQuery: - """Return the request's ``JobQuery`` with paging clamped to safe bounds. - - The legacy flat fields on ``ListJobsRequest`` were removed in #4573; - callers must now always submit a ``JobQuery``. 
- """ + """Return the request's ``JobQuery`` with paging clamped to safe bounds.""" query = controller_pb2.Controller.JobQuery() if request.HasField("query"): query.CopyFrom(request.query) - # Clamp paging: 0 (unset) defaults to MAX; explicit values are capped at MAX. - # We no longer support unbounded listing — callers that previously relied on - # limit=0 must paginate. Unbounded queries scale poorly because downstream - # per-page work (_task_summaries_for_jobs, _parent_ids_with_children) grows - # an IN-clause with one placeholder per returned row. + # Clamp paging: 0 (unset) or out-of-range values default to MAX. Unbounded + # listing is not supported because downstream per-page work + # (_task_summaries_for_jobs, _parent_ids_with_children) grows an IN-clause + # with one placeholder per returned row. if query.limit <= 0 or query.limit > MAX_LIST_JOBS_LIMIT: query.limit = MAX_LIST_JOBS_LIMIT if query.offset < 0: @@ -801,7 +749,7 @@ def _parent_ids_with_children(q: QuerySnapshot, job_ids: list[JobName]) -> set[J def _task_summaries_for_jobs(q: QuerySnapshot, job_ids: set[JobName] | None = None) -> dict[JobName, TaskJobSummary]: - """Aggregate task counts per job using SQL GROUP BY instead of Python-side iteration.""" + """Aggregate task counts per job via a SQL GROUP BY.""" if job_ids is not None: placeholders = ",".join("?" for _ in job_ids) where = f"WHERE t.job_id IN ({placeholders})" @@ -837,27 +785,26 @@ def _task_summaries_for_jobs(q: QuerySnapshot, job_ids: set[JobName] | None = No return summaries -def _worker_roster(db: ControllerDB) -> list[WorkerDetailRow]: - with db.read_snapshot() as q: - workers = WORKER_DETAIL_PROJECTION.decode( +def _worker_roster(store: ControllerStore) -> list[WorkerDetailRow]: + with store.read_snapshot() as q: + decoded = WORKER_DETAIL_PROJECTION.decode( q.fetchall(f"SELECT {WORKER_DETAIL_PROJECTION.select_clause()} FROM workers w") ) - # Populate attributes from worker_attributes table. 
- if workers: - worker_ids = tuple(str(w.worker_id) for w in workers) - placeholders = ",".join("?" for _ in worker_ids) - attr_rows = q.fetchall( - f"SELECT worker_id, key, value_type, str_value, int_value, float_value " - f"FROM worker_attributes WHERE worker_id IN ({placeholders})", - worker_ids, - ) - attrs_by_worker: dict[str, dict[str, str | int | float]] = {} - for row in attr_rows: - wid = str(row["worker_id"]) - key, value = _decode_attribute_value(row) - attrs_by_worker.setdefault(wid, {})[key] = value - workers = [dataclasses.replace(w, attributes=attrs_by_worker.get(str(w.worker_id), {})) for w in workers] - return workers + if not decoded: + return [] + worker_ids = tuple(str(w.worker_id) for w in decoded) + placeholders = ",".join("?" for _ in worker_ids) + attr_rows = q.fetchall( + f"SELECT worker_id, key, value_type, str_value, int_value, float_value " + f"FROM worker_attributes WHERE worker_id IN ({placeholders})", + worker_ids, + ) + attrs_by_worker: dict[str, dict[str, str | int | float]] = {} + for row in attr_rows: + wid = str(row["worker_id"]) + key, value = _decode_attribute_value(row) + attrs_by_worker.setdefault(wid, {})[key] = value + return [dataclasses.replace(w, attributes=attrs_by_worker.get(str(w.worker_id), {})) for w in decoded] def _descendant_jobs(db: ControllerDB, job_id: JobName) -> list[JobDetailRow]: @@ -984,7 +931,7 @@ def kill_tasks_on_workers( task_kill_workers: dict[JobName, WorkerId] | None = None, ) -> None: ... - def create_scheduling_context(self, workers: list[WorkerRow]) -> SchedulingContext: ... + def create_scheduling_context(self, workers: list[SchedulableWorker]) -> SchedulingContext: ... def get_job_scheduling_diagnostics(self, job_wire_id: str) -> str | None: ... @@ -1032,39 +979,6 @@ def _inject_resource_constraints( return new_request -# Dashboard list/aggregate RPCs (ListJobs, ListWorkers) are dominated by polling -# traffic. 
The data they read is already a few seconds stale by the time the -# browser renders it, so we serve them from periodic in-memory snapshots -# instead of a per-request DB fan-out. Each snapshot is rebuilt at most once -# per ``SnapshotView`` TTL and shared across concurrent readers. - -# How long ListJobs/ListWorkers may serve stale rows. Picked to be short enough -# that a user clicking "kill" then refreshing sees the change within a poll or -# two, and long enough to absorb dashboard fan-out. -DASHBOARD_SNAPSHOT_TTL_S = 2.0 - - -@dataclass(frozen=True, slots=True) -class JobsSnapshot: - """Full job set with per-job task summaries and parent-child flags. - - ListJobs filters/sorts/paginates ``rows`` in Python and looks up the - accompanying ``summaries`` / ``child_parent_ids`` per page. - """ - - rows: tuple[JobRow, ...] - summaries: Mapping[JobName, TaskJobSummary] - child_parent_ids: frozenset[JobName] - - -@dataclass(frozen=True, slots=True) -class WorkersSnapshot: - """Worker roster plus running-task assignments per worker.""" - - workers: tuple[WorkerDetailRow, ...] - running_by_worker: Mapping[WorkerId, frozenset[JobName]] - - class ControllerServiceImpl: """ControllerService RPC implementation. @@ -1097,45 +1011,6 @@ def __init__( self._auth = auth or ControllerAuth() self._system_endpoints: dict[str, str] = system_endpoints or {} self._user_budget_defaults = user_budget_defaults or UserBudgetDefaults() - # Dashboard polling RPCs are served from periodic snapshots instead of - # per-request SQL. ListJobs filters its snapshot in Python; ListWorkers - # and GetAutoscalerStatus share one workers snapshot so back-to-back - # dashboard polls only do one SELECT/JOIN. 
- self._jobs_snapshot = SnapshotView[JobsSnapshot]( - name="jobs", ttl_s=DASHBOARD_SNAPSHOT_TTL_S, build=self._build_jobs_snapshot - ) - self._workers_snapshot = SnapshotView[WorkersSnapshot]( - name="workers", ttl_s=DASHBOARD_SNAPSHOT_TTL_S, build=self._build_workers_snapshot - ) - - def _build_jobs_snapshot(self) -> JobsSnapshot: - """Materialize every job row plus its task summary and child-presence flag. - - One read snapshot, three queries: SELECT jobs (with the standard - ``JOB_CONFIG_JOIN``), GROUP BY tasks for state counts, DISTINCT on - parent_job_id for parent→child membership. - """ - with self._db.read_snapshot() as q: - rows = JOB_ROW_PROJECTION.decode( - q.fetchall(f"SELECT {JOB_ROW_PROJECTION.select_clause()} FROM jobs j {JOB_CONFIG_JOIN}") - ) - ids = {j.job_id for j in rows} - summaries = _task_summaries_for_jobs(q, ids) - children = _parent_ids_with_children(q, list(ids)) - return JobsSnapshot( - rows=tuple(rows), - summaries=summaries, - child_parent_ids=frozenset(children), - ) - - def _build_workers_snapshot(self) -> WorkersSnapshot: - """Materialize the worker roster plus running-task assignments per worker.""" - workers = _worker_roster(self._db) - running = running_tasks_by_worker(self._db, {w.worker_id for w in workers}) if workers else {} - return WorkersSnapshot( - workers=tuple(workers), - running_by_worker={wid: frozenset(tasks) for wid, tasks in running.items()}, - ) def bundle_zip(self, bundle_id: str) -> bytes: return self._bundle_store.get_zip(bundle_id) @@ -1406,9 +1281,6 @@ def get_job_status( if job.submitted_at: proto_job_status.submitted_at.CopyFrom(timestamp_to_proto(job.submitted_at)) - # Per-task resource samples now live in the ``iris.task`` stats - # namespace; the controller no longer aggregates min/max from a - # local table. Dashboard panels that need this should query stats. 
reconstructed_request = _reconstruct_launch_job_request(job) return controller_pb2.Controller.GetJobStatusResponse( job=proto_job_status, @@ -1488,8 +1360,6 @@ def _job_to_proto( scaling_prefix = "(scaling up) " if hint.is_scaling_up else "" pending_reason = f"Scheduler: {pending_reason}\n\nAutoscaler: {scaling_prefix}{hint.message}" - resources = _resource_spec_from_job_row(j) - proto_job = job_pb2.JobStatus( job_id=j.job_id.to_wire(), state=j.state, @@ -1498,7 +1368,6 @@ def _job_to_proto( failure_count=task_summary.failure_count if task_summary else 0, preemption_count=task_summary.preemption_count if task_summary else 0, name=job_name, - resources=resources, task_state_counts=task_state_counts, task_count=task_summary.task_count if task_summary else 0, completed_count=task_summary.completed_count if task_summary else 0, @@ -1536,12 +1405,12 @@ def list_jobs( request: controller_pb2.Controller.ListJobsRequest, ctx: Any, ) -> controller_pb2.Controller.ListJobsResponse: - """List jobs with filtering, sorting, and pagination served from snapshot. + """List jobs with filtering, sorting, and pagination. - The dashboard polls this on every refresh cycle; serving it from the - per-process ``_jobs_snapshot`` cuts the per-request DB fan-out (SELECT - + COUNT + GROUP BY tasks + DISTINCT parents) down to one in-memory - scan plus a Python sort. + Served directly from indexed SQL via ``_query_jobs``. Per-page task + summaries and parent->child flags are looked up against the same read + snapshot so the whole RPC observes a single transactionally-consistent + view. 
""" query = _query_from_list_jobs_request(request) @@ -1554,34 +1423,17 @@ def list_jobs( "query.parent_job_id is required for JOB_QUERY_SCOPE_CHILDREN", ) - snap = self._jobs_snapshot.read() - matched = [j for j in snap.rows if _job_matches_query(j, query, state_ids)] - - sort_field = query.sort_field or controller_pb2.Controller.JOB_SORT_FIELD_DATE - sort_direction = query.sort_direction - if sort_direction == controller_pb2.Controller.SORT_DIRECTION_UNSPECIFIED: - sort_direction = ( - controller_pb2.Controller.SORT_DIRECTION_DESC - if sort_field == controller_pb2.Controller.JOB_SORT_FIELD_DATE - else controller_pb2.Controller.SORT_DIRECTION_ASC - ) - descending = sort_direction == controller_pb2.Controller.SORT_DIRECTION_DESC - matched.sort(key=_job_sort_key(sort_field, snap.summaries), reverse=descending) - - total_count = len(matched) - offset = query.offset - limit = query.limit - page = matched[offset : offset + limit] if limit > 0 else matched[offset:] + with self._db.read_snapshot() as q: + page, total_count = _query_jobs(q, query, state_ids) + page_ids = [j.job_id for j in page] + summaries = _task_summaries_for_jobs(q, set(page_ids)) if page_ids else {} + children = _parent_ids_with_children(q, page_ids) if page_ids else set() has_pending = any(j.state == job_pb2.JOB_STATE_PENDING for j in page) autoscaler_pending_hints = self._get_autoscaler_pending_hints() if has_pending else {} - page_children = {j.job_id for j in page if j.job_id in snap.child_parent_ids} - all_jobs = self._jobs_to_protos( - page, - {jid: snap.summaries[jid] for jid in (j.job_id for j in page) if jid in snap.summaries}, - autoscaler_pending_hints, - has_children=page_children, - ) + all_jobs = self._jobs_to_protos(page, summaries, autoscaler_pending_hints, has_children=children) + limit = query.limit + offset = query.offset has_more = limit > 0 and offset + limit < total_count return controller_pb2.Controller.ListJobsResponse( jobs=all_jobs, @@ -1646,9 +1498,6 @@ def list_tasks( 
tasks = _tasks_for_listing(self._db, job_id=job_id) worker_addr_by_id = _worker_addresses_for_tasks(self._db, tasks) - # Per-task latest resource usage now lives in the ``iris.task`` stats - # namespace; dashboard list views should query it there instead of - # the controller attaching it to every TaskStatus row. task_statuses = [] for task in tasks: twid = _task_worker_id(task) @@ -1697,7 +1546,6 @@ def register( slice_id=request.slice_id, scale_group=request.scale_group, ) - logger.info("Worker registered: %s at %s", worker_id, request.address) return controller_pb2.Controller.RegisterResponse( worker_id=str(worker_id), @@ -1711,13 +1559,12 @@ def list_workers( ) -> controller_pb2.Controller.ListWorkersResponse: """List workers with their running task counts. - Served from ``_workers_snapshot``. The dashboard polls this together - with ``GetAutoscalerStatus`` (which reads the same snapshot), so - adjacent calls are fused into one rebuild per TTL window. ``running_tasks_by_worker`` - is captured into the snapshot so the per-page fan-out is replaced by - an in-memory dict lookup. ``query.limit == 0`` disables paging - (preserves CLI callers that fetch the whole roster); ``limit > 0`` is - clamped to ``MAX_LIST_WORKERS_LIMIT``. + Served directly from the workers table (cluster size is in the low + thousands at most), with liveness queried from + :class:`~iris.cluster.controller.worker_health.WorkerHealthTracker` and + a single per-page running-task lookup. ``query.limit == 0`` disables + paging (preserves CLI callers that fetch the whole roster); ``limit > 0`` + is clamped to ``MAX_LIST_WORKERS_LIMIT``. 
""" if self._controller.has_direct_provider: return controller_pb2.Controller.ListWorkersResponse() @@ -1726,8 +1573,9 @@ def list_workers( if request.HasField("query"): query.CopyFrom(request.query) - snap = self._workers_snapshot.read() - filtered = _filter_and_sort_workers(list(snap.workers), query) + workers_all = _worker_roster(self._store) + liveness_by_id = self._store.health.liveness_many(w.worker_id for w in workers_all) + filtered = _filter_and_sort_workers(workers_all, liveness_by_id, query) total_count = len(filtered) offset = max(query.offset, 0) @@ -1741,21 +1589,22 @@ def list_workers( page_rows = filtered[offset:] if offset else filtered has_more = False - workers = [ - controller_pb2.Controller.WorkerHealthStatus( - worker_id=worker.worker_id, - healthy=worker.healthy, - consecutive_failures=worker.consecutive_failures, - last_heartbeat=timestamp_to_proto(worker.last_heartbeat), - running_job_ids=[ - task_id.to_wire() for task_id in snap.running_by_worker.get(worker.worker_id, frozenset()) - ], - address=worker.address, - metadata=_worker_metadata_to_proto(worker), - status_message=worker_status_message(worker), + running = running_tasks_by_worker(self._db, {w.worker_id for w in page_rows}) if page_rows else {} + workers = [] + for worker in page_rows: + liveness = liveness_by_id[worker.worker_id] + workers.append( + controller_pb2.Controller.WorkerHealthStatus( + worker_id=worker.worker_id, + healthy=liveness.healthy, + consecutive_failures=liveness.consecutive_failures, + last_heartbeat=timestamp_to_proto(Timestamp.from_ms(liveness.last_heartbeat_ms)), + running_job_ids=[task_id.to_wire() for task_id in running.get(worker.worker_id, set())], + address=worker.address, + metadata=_worker_metadata_to_proto(worker), + status_message=worker_status_message(liveness), + ) ) - for worker in page_rows - ] return controller_pb2.Controller.ListWorkersResponse( workers=workers, total_count=total_count, @@ -1783,20 +1632,7 @@ def register_endpoint( 
endpoint_id = request.endpoint_id or str(uuid.uuid4()) task_id = JobName.from_wire(request.task_id) - job_id, _task_index = task_id.require_task() - - if _job_state(self._db, job_id) is None: - raise ConnectError(Code.NOT_FOUND, f"Job {request.task_id} not found") - - task = _read_task_with_attempts(self._db, task_id) - if not task: - raise ConnectError(Code.NOT_FOUND, f"Task {request.task_id} not found") - if request.attempt_id != task.current_attempt_id: - raise ConnectError( - Code.FAILED_PRECONDITION, - f"Stale attempt: task {request.task_id} attempt {request.attempt_id} " - f"!= current {task.current_attempt_id}", - ) + task_id.require_task() endpoint = EndpointRow( endpoint_id=endpoint_id, @@ -1807,9 +1643,19 @@ def register_endpoint( registered_at=Timestamp.now(), ) + # Validation runs inside the writer transaction in + # :meth:`EndpointStore.add`: NOT_FOUND if the task row is missing, + # FAILED_PRECONDITION if the task is terminal or the attempt is stale. with self._store.transaction() as cur: - added = self._transitions.add_endpoint(cur, endpoint) - if not added: + outcome = self._transitions.add_endpoint(cur, endpoint, expected_attempt_id=request.attempt_id) + if outcome is AddEndpointOutcome.NOT_FOUND: + raise ConnectError(Code.NOT_FOUND, f"Task {request.task_id} not found") + if outcome is AddEndpointOutcome.STALE_ATTEMPT: + raise ConnectError( + Code.FAILED_PRECONDITION, + f"Stale attempt for task {request.task_id} (attempt {request.attempt_id})", + ) + if outcome is AddEndpointOutcome.TERMINAL: raise ConnectError( Code.FAILED_PRECONDITION, f"Task {request.task_id} is already terminal; endpoint not registered", @@ -1899,21 +1745,26 @@ def get_autoscaler_status( status = autoscaler.get_status() - # ListWorkers and GetAutoscalerStatus share the workers snapshot — the - # roster + running-task fan-out is built once per TTL window and reused. 
- snap = self._workers_snapshot.read() - worker_id_to_info: dict[str, tuple[str, bool]] = {w.worker_id: (w.worker_id, w.healthy) for w in snap.workers} + workers = _worker_roster(self._store) + liveness_by_id = self._store.health.liveness_many(w.worker_id for w in workers) + worker_id_to_health: dict[str, bool] = {str(w.worker_id): liveness_by_id[w.worker_id].healthy for w in workers} + + # The vm_ids appearing in the autoscaler status are the only candidates + # for the running-task lookup; restrict to those known to be in the + # roster to keep the IN-clause bounded by visible VMs, not roster size. + vm_ids = {vm.vm_id for group in status.groups for slice_info in group.slices for vm in slice_info.vms} + candidate_ids = {WorkerId(vid) for vid in vm_ids if vid in worker_id_to_health} + running = running_tasks_by_worker(self._db, candidate_ids) if candidate_ids else {} - # Enrich VmInfo objects with worker information by matching vm_id to worker_id for group in status.groups: for slice_info in group.slices: for vm in slice_info.vms: - worker_info = worker_id_to_info.get(vm.vm_id) - if worker_info: - vm.worker_id = worker_info[0] - vm.worker_healthy = worker_info[1] - wid = WorkerId(vm.worker_id) - vm.running_task_count = len(snap.running_by_worker.get(wid, frozenset())) + healthy = worker_id_to_health.get(vm.vm_id) + if healthy is None: + continue + vm.worker_id = vm.vm_id + vm.worker_healthy = healthy + vm.running_task_count = len(running.get(WorkerId(vm.vm_id), set())) return controller_pb2.Controller.GetAutoscalerStatusResponse(status=status) @@ -2052,13 +1903,13 @@ def profile_task( return job_pb2.ProfileTaskResponse(error=str(e)) # /system/worker/: proxy profile to the worker's own process - worker_id = _parse_worker_target(request.target) - if worker_id is not None: - worker = self._transitions.get_worker(WorkerId(worker_id)) + worker_id_str = _parse_worker_target(request.target) + if worker_id_str is not None: + worker = _read_worker(self._store, 
WorkerId(worker_id_str)) if not worker: - raise ConnectError(Code.NOT_FOUND, f"Worker {worker_id} not found") - if not worker.healthy: - raise ConnectError(Code.UNAVAILABLE, f"Worker {worker_id} is unavailable") + raise ConnectError(Code.NOT_FOUND, f"Worker {worker_id_str} not found") + if not self._store.health.liveness(worker.worker_id).healthy: + raise ConnectError(Code.UNAVAILABLE, f"Worker {worker_id_str} is unavailable") forwarded = job_pb2.ProfileTaskRequest( target="/system/process", duration_seconds=request.duration_seconds, @@ -2093,8 +1944,8 @@ def profile_task( ) raise ConnectError(Code.FAILED_PRECONDITION, f"Task {request.target} not yet assigned to a worker") - worker = _read_worker(self._db, task_worker_id) - if not worker or not worker.healthy: + worker = _read_worker(self._store, task_worker_id) + if not worker or not self._store.health.liveness(task_worker_id).healthy: raise ConnectError(Code.UNAVAILABLE, f"Worker {task_worker_id} is unavailable") timeout_ms = (request.duration_seconds or 10) * 1000 + 30000 @@ -2148,20 +1999,21 @@ def get_worker_status( if not request.id: raise ConnectError(Code.INVALID_ARGUMENT, "id is required") - detail = _read_worker_detail(self._db, WorkerId(str(request.id))) + detail = _read_worker_detail(self._store, WorkerId(str(request.id))) if not detail: raise ConnectError(Code.NOT_FOUND, f"No worker found for '{request.id}'") worker = detail.worker + liveness = self._store.health.liveness(worker.worker_id) worker_health = controller_pb2.Controller.WorkerHealthStatus( worker_id=worker.worker_id, - healthy=worker.healthy, - consecutive_failures=worker.consecutive_failures, - last_heartbeat=timestamp_to_proto(worker.last_heartbeat), + healthy=liveness.healthy, + consecutive_failures=liveness.consecutive_failures, + last_heartbeat=timestamp_to_proto(Timestamp.from_ms(liveness.last_heartbeat_ms)), running_job_ids=[tid.to_wire() for tid in detail.running_tasks], address=worker.address, 
metadata=_worker_metadata_to_proto(worker), - status_message=worker_status_message(worker), + status_message=worker_status_message(liveness), ) # Worker daemon logs are NOT inlined here — when the worker is @@ -2212,10 +2064,10 @@ def get_process_status( if worker_id is None: raise ConnectError(Code.INVALID_ARGUMENT, f"Invalid target: {target}") - worker = self._transitions.get_worker(WorkerId(worker_id)) + worker = _read_worker(self._store, WorkerId(worker_id)) if not worker: raise ConnectError(Code.NOT_FOUND, f"Worker {worker_id} not found") - if not worker.healthy: + if not self._store.health.liveness(worker.worker_id).healthy: raise ConnectError(Code.UNAVAILABLE, f"Worker {worker_id} is unavailable") try: @@ -2414,8 +2266,8 @@ def exec_in_container( ) raise ConnectError(Code.FAILED_PRECONDITION, f"Task {request.task_id} not assigned to a worker") - worker = _read_worker(self._db, task_worker_id) - if not worker or not worker.healthy: + worker = _read_worker(self._store, task_worker_id) + if not worker or not self._store.health.liveness(task_worker_id).healthy: raise ConnectError(Code.UNAVAILABLE, f"Worker {task_worker_id} is unavailable") # Proxy to worker @@ -2554,12 +2406,10 @@ def get_scheduler_state( ) -> controller_pb2.Controller.GetSchedulerStateResponse: """Return aggregated scheduler state for the dashboard. - The dashboard SchedulerTab + AutoscalerTab only consume rolled-up - counts: per-(band, user, job) for pending and per-(band, user, worker, - job) for running. We aggregate server-side instead of streaming one - proto entry per task. This drops the per-task ``job_config`` lookup, - the running-tasks ``JOIN job_config``, and the resource-value / - preemption / interleave computations the old shape required. + The dashboard SchedulerTab + AutoscalerTab consume rolled-up counts: + per-(band, user, job) for pending and per-(band, user, worker, job) + for running. Aggregation runs server-side and emits one proto entry + per bucket rather than per task. 
""" require_identity() @@ -2570,9 +2420,8 @@ def get_scheduler_state( user_spend = compute_user_spend(snap) # Pending tasks: full TASK_ROW_PROJECTION so ``task_row_can_be_scheduled`` - # can match the previous handler's "schedulable" filter (excludes - # retry-exhausted PENDING tasks). No ORDER BY — we aggregate, - # not display. + # can apply the schedulable filter (excludes retry-exhausted + # PENDING tasks). No ORDER BY — we aggregate, not display. pending_rows = TASK_ROW_PROJECTION.decode( snap.fetchall( f"SELECT {TASK_ROW_PROJECTION.select_clause()} FROM tasks t WHERE t.state = ?", @@ -2581,7 +2430,7 @@ def get_scheduler_state( ) # Running tasks: only task_id, priority_band, and worker — no - # job_config join (resource_value / coscheduling were dropped). + # job_config join is needed for the rolled-up counts below. running_rows = snap.raw( "SELECT t.task_id, t.priority_band, t.current_worker_id AS worker_id " "FROM tasks t " @@ -2716,10 +2565,12 @@ def set_task_status_text( request: job_pb2.SetTaskStatusTextRequest, _ctx: Any, ) -> job_pb2.SetTaskStatusTextResponse: - """Task pushes a markdown status string to the coordinator.""" + """Task pushes a markdown status string to the coordinator. + + Status text lives entirely in the in-memory ``TaskStore`` dict; the + write is idempotent and stale task IDs are evicted by + ``remove_status_text_by_job_ids`` during pruning. 
+ """ task_id = JobName.from_wire(request.task_id) - task = _read_task_with_attempts(self._db, task_id) - if task is None: - raise ConnectError(Code.NOT_FOUND, f"Task {request.task_id} not found") self._transitions.record_task_status_text(task_id, request.status_text_detail_md, request.status_text_summary_md) return job_pb2.SetTaskStatusTextResponse() diff --git a/lib/iris/src/iris/cluster/controller/stores.py b/lib/iris/src/iris/cluster/controller/stores.py index 9c44ae7822..7c26855ba3 100644 --- a/lib/iris/src/iris/cluster/controller/stores.py +++ b/lib/iris/src/iris/cluster/controller/stores.py @@ -33,8 +33,10 @@ import time from collections.abc import Callable, Iterable, Mapping, Sequence from dataclasses import dataclass -from threading import Lock, RLock -from typing import Generic, TypeVar +from enum import StrEnum +from threading import RLock + +from rigging.timing import Timestamp from iris.cluster.constraints import AttributeValue from iris.cluster.controller.codec import resource_spec_from_scalars @@ -58,6 +60,7 @@ TaskDetailRow, WorkerDetailRow, ) +from iris.cluster.controller.worker_health import WorkerHealthTracker, WorkerLiveness from iris.cluster.types import TERMINAL_JOB_STATES, TERMINAL_TASK_STATES, JobName, WorkerId, get_gpu_count, get_tpu_count from iris.rpc import job_pb2 @@ -74,80 +77,21 @@ # ============================================================================= -# SnapshotView — TTL-cached, singleflight-rebuilt dataset +# EndpointStore # ============================================================================= -T = TypeVar("T") - - -class SnapshotView(Generic[T]): - """Periodically-refreshed in-memory view of a dataset. - - The dashboard polls a small set of list/aggregate RPCs continuously - (ListJobs, ListWorkers, GetSchedulerState). A per-request DB fan-out on - these is wasted work — the data they read is already a few seconds stale - by the time the browser renders it. 
``SnapshotView`` lets callers expose - "the latest known set of jobs / workers / etc." as one Python object that - handlers filter, sort, and paginate locally. +class AddEndpointOutcome(StrEnum): + """Result of :meth:`EndpointStore.add`. - Semantics: - - * ``read()`` returns the cached value if it was built within ``ttl_s``, - otherwise rebuilds and returns the new one. - * Reads are serialized on a single lock. Concurrent callers that arrive - during a rebuild wait on the lock and observe the freshly-built value; - ``build`` runs at most once per TTL window even under contention. - * If ``build`` raises, the exception propagates to the caller and the - cached value is left unchanged. The next reader retries. - - Write-driven invalidation is intentionally not built in. Callers that need - read-your-writes consistency should not use snapshots — they should issue - a live read. + The string values are stable for logging; callers should compare against + the enum members rather than the literal strings. """ - def __init__( - self, - name: str, - ttl_s: float, - build: Callable[[], T], - clock: Callable[[], float] = time.monotonic, - ) -> None: - self._name = name - self._ttl_s = ttl_s - self._build = build - # ``clock`` is injectable so tests can drive TTL expiry deterministically - # without ``time.sleep``. Production uses ``time.monotonic``. - self._clock = clock - self._lock = Lock() - self._value: T | None = None - self._built_at: float = 0.0 - self._force_rebuild: bool = True - - def read(self) -> T: - """Return the latest snapshot, rebuilding if older than TTL.""" - with self._lock: - if self._force_rebuild or (self._clock() - self._built_at) >= self._ttl_s: - self._value = self._build() - self._built_at = self._clock() - self._force_rebuild = False - assert self._value is not None - return self._value - - def invalidate(self) -> None: - """Force the next ``read()`` to rebuild. 
- - Uses an explicit flag rather than backdating ``_built_at`` so the - contract holds regardless of the clock's origin (e.g. a freshly-booted - host whose ``monotonic()`` is still less than ``ttl_s``). - """ - with self._lock: - self._force_rebuild = True - - -# ============================================================================= -# EndpointStore -# ============================================================================= + OK = "ok" + NOT_FOUND = "not_found" + STALE_ATTEMPT = "stale_attempt" + TERMINAL = "terminal" class EndpointStore: @@ -275,18 +219,36 @@ def all(self) -> list[EndpointRow]: # -- Writes --------------------------------------------------------------- - def add(self, cur: TransactionCursor, endpoint: EndpointRow) -> bool: + def add( + self, + cur: TransactionCursor, + endpoint: EndpointRow, + *, + expected_attempt_id: int | None = None, + ) -> AddEndpointOutcome: """Insert ``endpoint`` into the DB and schedule the memory update. - Returns False (and writes nothing) if the owning task is already - terminal. Otherwise inserts / replaces and schedules a post-commit - hook that updates the in-memory indexes. + All task validation runs inside this transaction so the RPC handler + does not need a separate read snapshot. Returns: + + - ``NOT_FOUND`` if the task row does not exist. + - ``TERMINAL`` if the task is in a terminal state. + - ``STALE_ATTEMPT`` if ``expected_attempt_id`` doesn't match the + task's current attempt. + - ``OK`` after a successful insert; the in-memory index is updated + via a post-commit hook. 
""" task_id = endpoint.task_id job_id, _ = task_id.require_task() - row = cur.execute("SELECT state FROM tasks WHERE task_id = ?", (task_id.to_wire(),)).fetchone() - if row is not None and int(row["state"]) in TERMINAL_TASK_STATES: - return False + row = cur.execute( + "SELECT state, current_attempt_id FROM tasks WHERE task_id = ?", (task_id.to_wire(),) + ).fetchone() + if row is None: + return AddEndpointOutcome.NOT_FOUND + if int(row["state"]) in TERMINAL_TASK_STATES: + return AddEndpointOutcome.TERMINAL + if expected_attempt_id is not None and int(row["current_attempt_id"]) != int(expected_attempt_id): + return AddEndpointOutcome.STALE_ATTEMPT cur.execute( "INSERT OR REPLACE INTO endpoints(" @@ -311,7 +273,7 @@ def apply() -> None: self._index(endpoint) cur.on_commit(apply) - return True + return AddEndpointOutcome.OK def remove(self, cur: TransactionCursor, endpoint_id: str) -> EndpointRow | None: """Remove a single endpoint by id. Returns the removed row snapshot, if any.""" @@ -525,15 +487,13 @@ class WorkerAttributeParams: class WorkerUpsertParams: """All scalar columns written by a worker registration/refresh. - The upsert leaves ``committed_*`` counters and attributes untouched — - attributes are replaced via :meth:`WorkerStore.replace_attributes` and - resource commitment is tracked incrementally via - :meth:`WorkerStore.add_committed_resources` / ``decommit_resources``. + Liveness state lives in :class:`WorkerHealthTracker`; ``committed_*`` + counters stay on the ``workers`` row. Attributes are replaced via + :meth:`WorkerStore.replace_attributes`. """ worker_id: WorkerId address: str - last_heartbeat_ms: int total_cpu_millicores: int total_memory_bytes: int total_gpu_count: int @@ -560,15 +520,6 @@ class WorkerUpsertParams: md_device_json: str -@dataclass(frozen=True, slots=True) -class ActiveWorkerStatus: - """Minimal row used by the worker-failure path: confirms the worker is - active (non-None return) and reports its last heartbeat timestamp.
- """ - - last_heartbeat_ms: int | None - - @dataclass(frozen=True, slots=True) class TaskScope: """Scope predicate for :meth:`TaskStore.list_active`. @@ -1617,17 +1568,27 @@ def bulk_finalize_active( class WorkerStore: - """Workers and worker_attributes.""" + """Workers and worker_attributes. - def __init__(self, db: ControllerDB) -> None: + The ``workers`` row holds durable identity, capability, and committed + scheduling totals. Transient liveness (heartbeat / health / failure + counters) lives in :class:`WorkerHealthTracker` to avoid funneling every + ping through the writer connection. + """ + + def __init__(self, db: ControllerDB, health: WorkerHealthTracker) -> None: self._db = db + self._health = health + + @property + def health(self) -> WorkerHealthTracker: + return self._health def active_healthy_address(self, tx: Tx, worker_id: WorkerId) -> str | None: - row = tx.fetchone( - "SELECT address FROM workers WHERE worker_id = ? AND active = 1 AND healthy = 1", - (str(worker_id),), - ) - return str(row["address"]) if row is not None else None + liveness = self._health.liveness(worker_id) + if not (liveness.healthy and liveness.active): + return None + return self.address(tx, worker_id) def address(self, tx: Tx, worker_id: WorkerId) -> str | None: row = tx.fetchone("SELECT address FROM workers WHERE worker_id = ?", (str(worker_id),)) @@ -1638,33 +1599,42 @@ def get_detail(self, tx: Tx, worker_id: WorkerId) -> WorkerDetailRow | None: f"SELECT {WORKER_DETAIL_PROJECTION.select_clause()} FROM workers w WHERE w.worker_id = ?", (str(worker_id),), ) - return WORKER_DETAIL_PROJECTION.decode_one([row]) if row is not None else None - - def get_active_status(self, tx: Tx, worker_id: WorkerId) -> ActiveWorkerStatus | None: - """Return heartbeat info for an active worker, or None if missing/inactive.""" - row = tx.fetchone( - "SELECT last_heartbeat_ms FROM workers WHERE worker_id = ? 
AND active = 1", - (str(worker_id),), - ) if row is None: return None - hb = row["last_heartbeat_ms"] - return ActiveWorkerStatus(last_heartbeat_ms=int(hb) if hb is not None else None) + return WORKER_DETAIL_PROJECTION.decode_one([row]) + + def liveness(self, worker_id: WorkerId) -> WorkerLiveness: + return self._health.liveness(worker_id) def list_active_healthy(self, tx: Tx) -> dict[WorkerId, str]: """Return ``{worker_id: address}`` for all active+healthy workers.""" - rows = tx.fetchall("SELECT worker_id, address FROM workers WHERE active = 1 AND healthy = 1") + liveness = self._health.all() + live_ids = [wid for wid, l in liveness.items() if l.healthy and l.active] + if not live_ids: + return {} + placeholders = ",".join("?" for _ in live_ids) + rows = tx.fetchall( + f"SELECT worker_id, address FROM workers WHERE worker_id IN ({placeholders})", + tuple(str(wid) for wid in live_ids), + ) return {WorkerId(str(row["worker_id"])): str(row["address"]) for row in rows} def list_active_by_ids(self, tx: Tx, worker_ids: Iterable[str]) -> list[WorkerDetailRow]: """Return :class:`WorkerDetailRow` for all active workers whose id is in ``worker_ids``.""" - ids = sorted({str(wid) for wid in worker_ids}) + liveness = self._health.all() + ids = sorted( + { + str(wid) + for wid in worker_ids + if (liveness_entry := liveness.get(WorkerId(str(wid)))) is not None and liveness_entry.active + } + ) if not ids: return [] placeholders = ",".join("?" 
for _ in ids) rows = tx.fetchall( f"SELECT {WORKER_DETAIL_PROJECTION.select_clause()} " - f"FROM workers w WHERE w.active = 1 AND w.worker_id IN ({placeholders})", + f"FROM workers w WHERE w.worker_id IN ({placeholders})", tuple(ids), ) return WORKER_DETAIL_PROJECTION.decode(rows) @@ -1681,29 +1651,26 @@ def filter_existing(self, tx: Tx, worker_ids: Iterable[WorkerId]) -> set[str]: ) return {str(r["worker_id"]) for r in rows} - def upsert(self, cur: TransactionCursor, params: WorkerUpsertParams) -> None: - """Insert a new worker row or refresh every field of an existing one. + def upsert(self, cur: TransactionCursor, params: WorkerUpsertParams, now_ms: int) -> None: + """Insert or refresh durable identity/capability metadata for a worker. - On conflict the row is reset to healthy/active with zero - consecutive_failures (registration re-establishes a worker as good). - ``committed_*`` counters are left untouched because they reflect - concurrent scheduling decisions, not registration metadata. + ``committed_*`` columns are written by the scheduler, not here. A + post-commit hook registers the worker in the liveness tracker so + memory state advances with the DB row. """ cur.execute( "INSERT INTO workers(" - "worker_id, address, healthy, active, consecutive_failures, last_heartbeat_ms, " - "committed_cpu_millicores, committed_mem_bytes, committed_gpu, committed_tpu, " + "worker_id, address, " "total_cpu_millicores, total_memory_bytes, total_gpu_count, total_tpu_count, " "device_type, device_variant, slice_id, scale_group, " "md_hostname, md_ip_address, md_cpu_count, md_memory_bytes, md_disk_bytes, " "md_tpu_name, md_tpu_worker_hostnames, md_tpu_worker_id, md_tpu_chips_per_host_bounds, " "md_gpu_count, md_gpu_name, md_gpu_memory_mb, " "md_gce_instance_name, md_gce_zone, md_git_hash, md_device_json" - ") VALUES (?, ?, 1, 1, 0, ?, 0, 0, 0, 0, ?, ?, ?, ?, ?, ?, ?, ?, " + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, " "?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
" "ON CONFLICT(worker_id) DO UPDATE SET " - "address=excluded.address, healthy=1, active=1, " - "consecutive_failures=0, last_heartbeat_ms=excluded.last_heartbeat_ms, " + "address=excluded.address, " "total_cpu_millicores=excluded.total_cpu_millicores, total_memory_bytes=excluded.total_memory_bytes, " "total_gpu_count=excluded.total_gpu_count, total_tpu_count=excluded.total_tpu_count, " "device_type=excluded.device_type, device_variant=excluded.device_variant, " @@ -1721,7 +1688,6 @@ def upsert(self, cur: TransactionCursor, params: WorkerUpsertParams) -> None: ( str(params.worker_id), params.address, - params.last_heartbeat_ms, params.total_cpu_millicores, params.total_memory_bytes, params.total_gpu_count, @@ -1749,58 +1715,41 @@ def upsert(self, cur: TransactionCursor, params: WorkerUpsertParams) -> None: ), ) - def mark_unhealthy(self, cur: TransactionCursor, worker_id: WorkerId) -> None: - cur.execute("UPDATE workers SET healthy = 0 WHERE worker_id = ?", (str(worker_id),)) + def _register() -> None: + self._health.register(params.worker_id, now_ms=now_ms) - def find_prunable(self, tx: Tx, before_ms: int) -> WorkerId | None: - """Return one inactive-or-unhealthy worker whose heartbeat predates ``before_ms``.""" - row = tx.fetchone( - "SELECT worker_id FROM workers " "WHERE (active = 0 OR healthy = 0) AND last_heartbeat_ms < ? LIMIT 1", - (before_ms,), - ) - return WorkerId(str(row["worker_id"])) if row is not None else None + cur.on_commit(_register) - def set_health_for_test(self, cur: TransactionCursor, worker_id: WorkerId, healthy: bool) -> None: - """Test helper: overwrite ``healthy`` and reset/raise ``consecutive_failures``.""" - cur.execute( - "UPDATE workers SET healthy = ?, consecutive_failures = ? 
WHERE worker_id = ?", - (1 if healthy else 0, 0 if healthy else 1, str(worker_id)), - ) + def mark_unhealthy(self, worker_id: WorkerId) -> None: + """Flip the worker's in-memory health verdict to false.""" + self._health.mark_unhealthy(worker_id) - def set_consecutive_failures_for_test(self, cur: TransactionCursor, worker_id: WorkerId, count: int) -> None: - """Test helper: overwrite ``consecutive_failures`` directly.""" - cur.execute( - "UPDATE workers SET consecutive_failures = ? WHERE worker_id = ?", - (count, str(worker_id)), - ) + def find_prunable(self, before_ms: int) -> WorkerId | None: + """Return one tracker-known worker that is unhealthy/inactive with a stale heartbeat. - def apply_snapshots( - self, - cur: TransactionCursor, - worker_ids: Sequence[WorkerId], - now_ms: int, - *, - reset_health: bool, - ) -> None: - """Bump ``last_heartbeat_ms`` for every worker. - - Per-tick host utilization is no longer cached on the ``workers`` row — - workers emit those samples directly to the ``iris.worker`` stats - namespace. - - ``reset_health=True`` also clears ``healthy``/``active``/ - ``consecutive_failures`` because a successful heartbeat proves - recovery. Ping path passes ``False`` — the ping loop tracks failures - in-memory and removes workers via ``fail_workers_batch``. + Every persisted ``workers`` row has a tracker entry by construction + (seeded at boot/restore, registered on commit of ``upsert``, removed + on commit of :meth:`remove`), so scanning the tracker is sufficient. + """ + for worker_id, l in self._health.all().items(): + if (not l.healthy or not l.active) and l.last_heartbeat_ms < before_ms: + return worker_id + return None + + def heartbeat(self, worker_ids: Sequence[WorkerId], now_ms: int, *, reset_health: bool) -> None: + """Record a heartbeat / ping batch in the in-memory tracker. 
+ + ``reset_health=True`` is the heartbeat path: a successful heartbeat + proves the worker recovered, so ``healthy``/``active`` flip back on + and the consecutive failure counter resets. ``reset_health=False`` is + the ping success path, which only bumps ``last_heartbeat_ms``. """ if not worker_ids: return - - health_prefix = "healthy = 1, active = 1, consecutive_failures = 0, " if reset_health else "" - cur.executemany( - f"UPDATE workers SET {health_prefix}last_heartbeat_ms = ? WHERE worker_id = ?", - [(now_ms, str(wid)) for wid in worker_ids], - ) + if reset_health: + self._health.heartbeat(worker_ids, now_ms) + else: + self._health.bump_heartbeat(worker_ids, now_ms) def add_committed_resources( self, @@ -1809,9 +1758,12 @@ def add_committed_resources( resources: job_pb2.ResourceSpecProto, ) -> None: cur.execute( - "UPDATE workers SET committed_cpu_millicores = committed_cpu_millicores + ?, " - "committed_mem_bytes = committed_mem_bytes + ?, committed_gpu = committed_gpu + ?, " - "committed_tpu = committed_tpu + ? WHERE worker_id = ?", + "UPDATE workers SET " + "committed_cpu_millicores = committed_cpu_millicores + ?, " + "committed_mem_bytes = committed_mem_bytes + ?, " + "committed_gpu = committed_gpu + ?, " + "committed_tpu = committed_tpu + ? " + "WHERE worker_id = ?", ( int(resources.cpu_millicores), int(resources.memory_bytes), @@ -1828,9 +1780,11 @@ def decommit_resources( resources: job_pb2.ResourceSpecProto, ) -> None: cur.execute( - "UPDATE workers SET committed_cpu_millicores = MAX(0, committed_cpu_millicores - ?), " - "committed_mem_bytes = MAX(0, committed_mem_bytes - ?), " - "committed_gpu = MAX(0, committed_gpu - ?), committed_tpu = MAX(0, committed_tpu - ?) " + "UPDATE workers SET " + "committed_cpu_millicores = MAX(0, committed_cpu_millicores - ?), " + "committed_mem_bytes = MAX(0, committed_mem_bytes - ?), " + "committed_gpu = MAX(0, committed_gpu - ?), " + "committed_tpu = MAX(0, committed_tpu - ?) 
" "WHERE worker_id = ?", ( int(resources.cpu_millicores), @@ -1841,6 +1795,14 @@ def decommit_resources( ), ) + def set_health_for_test(self, worker_id: WorkerId, healthy: bool) -> None: + """Test helper: overwrite the in-memory health verdict.""" + self._health.set_health_for_test(worker_id, healthy) + + def set_consecutive_failures_for_test(self, worker_id: WorkerId, count: int) -> None: + """Test helper: overwrite the in-memory consecutive failure count.""" + self._health.set_consecutive_failures_for_test(worker_id, count) + def replace_attributes( self, cur: TransactionCursor, @@ -1884,6 +1846,8 @@ def remove(self, cur: TransactionCursor, worker_id: WorkerId) -> None: cur.execute("DELETE FROM dispatch_queue WHERE worker_id = ?", (str(worker_id),)) cur.execute("DELETE FROM workers WHERE worker_id = ?", (str(worker_id),)) + cur.on_commit(lambda: self._health.forget(worker_id)) + class DispatchQueueStore: """The dispatch_queue table.""" @@ -1951,19 +1915,40 @@ def next_submission_ms(self, cur: TransactionCursor, submitted_ms: int) -> int: class ControllerStore: """Bundle of per-entity stores with direct access to transactions/snapshots.""" - def __init__(self, db: ControllerDB) -> None: + def __init__(self, db: ControllerDB, health: WorkerHealthTracker | None = None) -> None: self._db = db + self._health = health or WorkerHealthTracker() self.jobs = JobStore(db) self.tasks = TaskStore(db) self.attempts = TaskAttemptStore(db) - self.workers = WorkerStore(db) + self.workers = WorkerStore(db, self._health) self.endpoints = EndpointStore(db) self.dispatch = DispatchQueueStore(db) self.reservations = ReservationStore(db) + self._seed_liveness_from_workers() # Caches reload after a checkpoint restore via db.replace_from(). The # hook fires only in that flow; normal startup loads caches in the # store constructors above. 
db.register_reopen_hook(self.endpoints._load_all) + db.register_reopen_hook(self._seed_liveness_from_workers) + + def _seed_liveness_from_workers(self) -> None: + """Mark every persisted worker healthy so the scheduler sees them before they ping back. + + Workers that fail to ping within the heartbeat window are timed out + by the ping loop. ``find_prunable`` relies on this seed to maintain + the invariant that every ``workers`` row has a tracker entry. + """ + now_ms = Timestamp.now().epoch_ms() + with self._db.read_snapshot() as q: + rows = q.fetchall("SELECT worker_id FROM workers") + worker_ids = [WorkerId(str(row["worker_id"])) for row in rows] + if worker_ids: + self._health.heartbeat(worker_ids, now_ms) + + @property + def health(self) -> WorkerHealthTracker: + return self._health def transaction(self): return self._db.transaction() diff --git a/lib/iris/src/iris/cluster/controller/transitions.py b/lib/iris/src/iris/cluster/controller/transitions.py index cdbf1515f8..97773a4208 100644 --- a/lib/iris/src/iris/cluster/controller/transitions.py +++ b/lib/iris/src/iris/cluster/controller/transitions.py @@ -44,6 +44,7 @@ ) from iris.cluster.controller.stores import ( ActiveTaskRow, + AddEndpointOutcome, ControllerStore, EndpointStore, JobConfigInsertParams, @@ -1165,7 +1166,6 @@ def register_or_refresh_worker( WorkerUpsertParams( worker_id=worker_id, address=address, - last_heartbeat_ms=now_ms, total_cpu_millicores=metadata.cpu_count * 1000, total_memory_bytes=metadata.memory_bytes, total_gpu_count=gpu_count, @@ -1191,6 +1191,7 @@ def register_or_refresh_worker( md_git_hash=metadata.git_hash, md_device_json=proto_to_json(metadata.device), ), + now_ms=now_ms, ) self._store.workers.replace_attributes(cur, worker_id, attrs) # Update in-memory attribute cache only after commit so a rolled-back tx @@ -1319,19 +1320,14 @@ def queue_assignments( ) def _update_worker_health(self, cur: TransactionCursor, req: HeartbeatApplyRequest, now_ms: int) -> bool: - """Update worker 
health, resource snapshot, and history. + """Update worker health in the in-memory tracker. Returns False if the worker doesn't exist (caller should bail). """ existing = self._store.workers.filter_existing(cur, [req.worker_id]) if str(req.worker_id) not in existing: return False - self._store.workers.apply_snapshots( - cur, - [req.worker_id], - now_ms, - reset_health=True, - ) + self._store.workers.heartbeat([req.worker_id], now_ms, reset_health=True) return True def _apply_task_transitions( @@ -1554,7 +1550,7 @@ def _apply_task_transitions( if task_state != prior_state: jobs_to_recompute.add(task.job_id) - # Recompute job states once per job instead of once per task. + # Recompute job states once per job (deduplicated above). for job_id in jobs_to_recompute: if job_id in cascaded_jobs: continue @@ -1601,8 +1597,7 @@ def apply_heartbeats_batch(self, cur: TransactionCursor, requests: list[Heartbea # ── Batch worker health updates ─────────────────────────────── existing_workers = self._store.workers.filter_existing(cur, [req.worker_id for req in requests]) - self._store.workers.apply_snapshots( - cur, + self._store.workers.heartbeat( [req.worker_id for req in requests if str(req.worker_id) in existing_workers], now_ms, reset_health=True, @@ -1774,13 +1769,14 @@ def _record_worker_failure( now_ms: int | None = None, ) -> WorkerFailureResult: """Remove a failed worker inside an existing transaction.""" - status = self._store.workers.get_active_status(cur, worker_id) - if status is None: + liveness = self._store.workers.liveness(worker_id) + if not liveness.active: return WorkerFailureResult(worker_removed=True) now_ms = now_ms or Timestamp.now().epoch_ms() - last_contact_age_ms = None if status.last_heartbeat_ms is None else max(0, now_ms - status.last_heartbeat_ms) - self._store.workers.mark_unhealthy(cur, worker_id) + last_hb = liveness.last_heartbeat_ms + last_contact_age_ms = None if not last_hb else max(0, now_ms - last_hb) + 
self._store.workers.mark_unhealthy(worker_id) removal = self._remove_failed_worker(cur, worker_id, error, now_ms=now_ms) return WorkerFailureResult( tasks_to_kill=removal.tasks_to_kill, @@ -1800,7 +1796,7 @@ def fail_workers( Each ``(worker_id, worker_address, reason)`` tuple triggers a worker-removal transaction. Chunks commit between themselves so the SQLite writer is released and other RPCs (register, apply_heartbeats_batch, - ...) can interleave instead of stalling behind a zone-wide failure. + ...) can interleave during a zone-wide failure. """ if not failures: return WorkerFailureBatchResult() @@ -2163,11 +2159,10 @@ def _stopped() -> bool: jobs_deleted += 1 time.sleep(pause_between_s) - # 2. Workers: one at a time (CASCADE to attributes) + # 2. Workers: one at a time (CASCADE to attributes). workers_deleted = 0 while not _stopped(): - with self._store.read_snapshot() as snap: - worker_id = self._store.workers.find_prunable(snap, worker_cutoff_ms) + worker_id = self._store.workers.find_prunable(worker_cutoff_ms) if worker_id is None: break with self._store.transaction() as cur: @@ -2208,25 +2203,19 @@ def _stopped() -> bool: # Split Heartbeat Helpers # ========================================================================= - def update_worker_pings( - self, - cur: TransactionCursor, - worker_ids: Iterable[WorkerId], - ) -> None: - """Apply a batch of Ping RPC results within the caller's transaction. - - Bumps ``last_heartbeat_ms`` for each successfully-pinged worker. - Per-tick host utilization is no longer persisted in the controller DB — - workers emit it directly to the ``iris.worker`` stats namespace. - Does not touch healthy/active/consecutive_failures — the ping loop - tracks failures in-memory and uses ``fail_workers_batch`` to remove - workers past threshold. + def update_worker_pings(self, worker_ids: Iterable[WorkerId]) -> None: + """Apply a batch of Ping RPC results into the in-memory health tracker. 
+ + Bumps ``last_heartbeat_ms`` for each successfully-pinged worker. Does + not touch healthy/active/consecutive_failures — the ping loop tracks + failures via :meth:`WorkerHealthTracker.ping` and reaps workers via + :meth:`fail_workers_batch`. """ ids = list(worker_ids) if not ids: return now_ms = Timestamp.now().epoch_ms() - self._store.workers.apply_snapshots(cur, ids, now_ms, reset_health=False) + self._store.workers.heartbeat(ids, now_ms, reset_health=False) def get_running_tasks_for_poll( self, @@ -2317,13 +2306,19 @@ def record_task_status_text(self, task_id: JobName, detail_md: str, summary_md: # --- Endpoint Management --- - def add_endpoint(self, cur: TransactionCursor, endpoint: EndpointRow) -> bool: + def add_endpoint( + self, + cur: TransactionCursor, + endpoint: EndpointRow, + *, + expected_attempt_id: int | None = None, + ) -> AddEndpointOutcome: """Add an endpoint row through the store's endpoint cache. - Returns True if the endpoint was inserted, False if the task is already - terminal (to prevent orphaned endpoints that would never be cleaned up). + Validation (existence, terminal-state, stale-attempt) runs inside the + write transaction so the RPC handler can drop its precheck reads. 
""" - return self._store.endpoints.add(cur, endpoint) + return self._store.endpoints.add(cur, endpoint, expected_attempt_id=expected_attempt_id) def remove_endpoint(self, cur: TransactionCursor, endpoint_id: str) -> EndpointRow | None: return self._store.endpoints.remove(cur, endpoint_id) @@ -2333,9 +2328,8 @@ def remove_endpoint(self, cur: TransactionCursor, endpoint_id: str) -> EndpointR # --------------------------------------------------------------------- def set_worker_health_for_test(self, worker_id: WorkerId, healthy: bool) -> None: - """Test helper: set worker health in DB.""" - with self._store.transaction() as cur: - self._store.workers.set_health_for_test(cur, worker_id, healthy) + """Test helper: set worker health in the in-memory tracker.""" + self._store.workers.set_health_for_test(worker_id, healthy) def set_worker_attribute_for_test(self, worker_id: WorkerId, key: str, value: AttributeValue) -> None: """Test helper: upsert one worker attribute in DB.""" @@ -2642,9 +2636,8 @@ def buffer_direct_kill(self, cur: TransactionCursor, task_id: str) -> None: # ========================================================================= def set_worker_consecutive_failures_for_test(self, worker_id: WorkerId, consecutive_failures: int) -> None: - """Test helper: set worker consecutive failure count in DB.""" - with self._store.transaction() as cur: - self._store.workers.set_consecutive_failures_for_test(cur, worker_id, consecutive_failures) + """Test helper: set worker consecutive failure count in the in-memory tracker.""" + self._store.workers.set_consecutive_failures_for_test(worker_id, consecutive_failures) def set_task_state_for_test( self, diff --git a/lib/iris/src/iris/cluster/controller/worker_health.py b/lib/iris/src/iris/cluster/controller/worker_health.py index 0e48955e4b..4b11d239e3 100644 --- a/lib/iris/src/iris/cluster/controller/worker_health.py +++ b/lib/iris/src/iris/cluster/controller/worker_health.py @@ -1,23 +1,24 @@ # Copyright The Marin 
Authors # SPDX-License-Identifier: Apache-2.0 -"""In-memory worker health tracking with ping-based decay. +"""In-memory worker liveness tracking. -Tracks two independent failure modes per worker: +Per-worker signals: -- Consecutive ping failures: incremented by any failed ping or heartbeat RPC, - reset to zero by any successful ping. Ten consecutive failures trip the - termination threshold. -- Build failures: monotonic counter for BUILDING→FAILED transitions. Ten build - failures trip the termination threshold independently. +- ``last_heartbeat_ms``: bumped on each successful heartbeat / ping. +- ``healthy`` / ``active``: liveness verdict; flipped to false when the worker + is marked unhealthy or removed. +- ``consecutive_failures``: incremented by failed ping/heartbeat RPCs, reset + on success. ``PING_FAILURE_THRESHOLD`` consecutive failures trip + termination. +- ``build_failures``: monotonic counter for BUILDING→FAILED transitions. + ``BUILD_FAILURE_THRESHOLD`` build failures trip termination independently. -The ping-based decay means no clock management is needed: healthy pings -naturally reset the failure count, and the tracker needs no time parameters. - -Lives entirely in memory. A failing worker recurs within one ping cycle, -so losing evidence on controller restart doesn't meaningfully delay termination. +Thread-safe: written from ping/heartbeat threads, read from the reaper, +scheduler, and RPC handler threads. """ +import dataclasses import logging import threading from collections.abc import Iterable @@ -32,17 +33,22 @@ @dataclass(slots=True) -class _WorkerState: - consecutive_ping_failures: int = 0 +class WorkerLiveness: + """Public snapshot of a worker's transient liveness state. + + Mutated in place by the tracker under its lock during heartbeat/ping + updates. Readers receive copies via :meth:`WorkerHealthTracker.liveness`. 
+ """ + + healthy: bool = False + active: bool = False + consecutive_failures: int = 0 + last_heartbeat_ms: int = 0 build_failures: int = 0 class WorkerHealthTracker: - """Tracks per-worker failure counts for termination decisions. - - Thread-safe: written from ping/heartbeat and task-update threads, - read from the reaper thread. - """ + """In-memory source of truth for worker liveness.""" def __init__( self, @@ -55,19 +61,51 @@ def __init__( self._ping_threshold = ping_threshold self._build_threshold = build_threshold self._lock = threading.Lock() - self._states: dict[WorkerId, _WorkerState] = {} + self._states: dict[WorkerId, WorkerLiveness] = {} + + # -- Registration / heartbeat ------------------------------------------- + + def register(self, worker_id: WorkerId, *, now_ms: int) -> None: + """Mark a worker as live with a fresh heartbeat. Resets failure counters.""" + with self._lock: + state = self._states.setdefault(worker_id, WorkerLiveness()) + state.last_heartbeat_ms = now_ms + state.healthy = True + state.active = True + state.consecutive_failures = 0 + + def heartbeat(self, worker_ids: Iterable[WorkerId], now_ms: int) -> None: + """Record a successful heartbeat batch — bumps last_heartbeat_ms and resets health.""" + with self._lock: + for wid in worker_ids: + state = self._states.setdefault(wid, WorkerLiveness()) + state.last_heartbeat_ms = now_ms + state.healthy = True + state.active = True + state.consecutive_failures = 0 + + def bump_heartbeat(self, worker_ids: Iterable[WorkerId], now_ms: int) -> None: + """Record a successful ping batch — bumps last_heartbeat_ms only. + + Does not reset healthy/active/consecutive_failures. The ping path + records failures separately via :meth:`ping`. + """ + with self._lock: + for wid in worker_ids: + state = self._states.setdefault(wid, WorkerLiveness()) + state.last_heartbeat_ms = now_ms def ping(self, worker_id: WorkerId, *, healthy: bool) -> None: """Record a ping outcome. 
A healthy ping resets the consecutive failure count.""" with self._lock: - state = self._states.setdefault(worker_id, _WorkerState()) + state = self._states.setdefault(worker_id, WorkerLiveness()) if healthy: - state.consecutive_ping_failures = 0 + state.consecutive_failures = 0 else: - state.consecutive_ping_failures += 1 - failures = state.consecutive_ping_failures + state.consecutive_failures += 1 + failures = state.consecutive_failures logger.debug( - "Worker %s ping=%s consecutive_ping_failures=%d", + "Worker %s ping=%s consecutive_failures=%d", worker_id, "ok" if healthy else "fail", failures, @@ -76,20 +114,52 @@ def ping(self, worker_id: WorkerId, *, healthy: bool) -> None: def build_failed(self, worker_id: WorkerId) -> None: """Record a BUILDING→FAILED transition.""" with self._lock: - state = self._states.setdefault(worker_id, _WorkerState()) + state = self._states.setdefault(worker_id, WorkerLiveness()) state.build_failures += 1 failures = state.build_failures logger.debug("Worker %s build_failures=%d", worker_id, failures) + def mark_unhealthy(self, worker_id: WorkerId) -> None: + """Force the worker into the unhealthy verdict (used by failure cascade).""" + with self._lock: + state = self._states.get(worker_id) + if state is None: + return + state.healthy = False + + # -- Reads -------------------------------------------------------------- + + def liveness(self, worker_id: WorkerId) -> WorkerLiveness: + """Return a copy of the worker's current liveness snapshot. + + Returns a default-constructed ``WorkerLiveness`` if the worker isn't + tracked yet. The returned dataclass is a copy — callers may read but + should not mutate. 
+ """ + with self._lock: + state = self._states.get(worker_id) + return WorkerLiveness() if state is None else dataclasses.replace(state) + + def liveness_many(self, worker_ids: Iterable[WorkerId]) -> dict[WorkerId, WorkerLiveness]: + """Return a copy of liveness for each requested worker.""" + with self._lock: + return {wid: dataclasses.replace(self._states.get(wid, WorkerLiveness())) for wid in worker_ids} + + def all(self) -> dict[WorkerId, WorkerLiveness]: + with self._lock: + return {wid: dataclasses.replace(state) for wid, state in self._states.items()} + def workers_over_threshold(self) -> list[WorkerId]: """Return IDs of workers that have exceeded a termination threshold.""" with self._lock: return [ wid for wid, s in self._states.items() - if s.consecutive_ping_failures >= self._ping_threshold or s.build_failures >= self._build_threshold + if s.consecutive_failures >= self._ping_threshold or s.build_failures >= self._build_threshold ] + # -- Eviction ----------------------------------------------------------- + def forget(self, worker_id: WorkerId) -> None: with self._lock: self._states.pop(worker_id, None) @@ -100,6 +170,30 @@ def forget_many(self, worker_ids: Iterable[WorkerId]) -> None: self._states.pop(wid, None) def snapshot(self) -> dict[WorkerId, tuple[int, int]]: - """Current (consecutive_ping_failures, build_failures) per worker (for diagnostics).""" + """Current ``(consecutive_failures, build_failures)`` per worker (for diagnostics).""" + with self._lock: + return {wid: (s.consecutive_failures, s.build_failures) for wid, s in self._states.items()} + + # -- Test helpers ------------------------------------------------------- + + def set_health_for_test(self, worker_id: WorkerId, healthy: bool) -> None: + """Test helper: overwrite the healthy verdict.""" + with self._lock: + state = self._states.setdefault(worker_id, WorkerLiveness()) + state.healthy = healthy + if healthy: + state.consecutive_failures = 0 + else: + state.consecutive_failures = 
max(state.consecutive_failures, 1) + + def set_consecutive_failures_for_test(self, worker_id: WorkerId, count: int) -> None: + """Test helper: overwrite consecutive_failures directly.""" + with self._lock: + state = self._states.setdefault(worker_id, WorkerLiveness()) + state.consecutive_failures = count + + def set_last_heartbeat_for_test(self, worker_id: WorkerId, last_heartbeat_ms: int) -> None: + """Test helper: backdate the last heartbeat for prune-window tests.""" with self._lock: - return {wid: (s.consecutive_ping_failures, s.build_failures) for wid, s in self._states.items()} + state = self._states.setdefault(worker_id, WorkerLiveness()) + state.last_heartbeat_ms = last_heartbeat_ms diff --git a/lib/iris/src/iris/cluster/controller/worker_provider.py b/lib/iris/src/iris/cluster/controller/worker_provider.py index a794f20907..655c26c9bf 100644 --- a/lib/iris/src/iris/cluster/controller/worker_provider.py +++ b/lib/iris/src/iris/cluster/controller/worker_provider.py @@ -20,6 +20,7 @@ ) from iris.cluster.types import WorkerId from iris.rpc import job_pb2, worker_pb2 +from iris.rpc.compression import IRIS_RPC_COMPRESSIONS from iris.rpc.worker_connect import WorkerServiceClient logger = logging.getLogger(__name__) @@ -48,7 +49,7 @@ def close(self) -> None: ... 
class RpcWorkerStubFactory: """Caches async WorkerServiceClient stubs by address so each worker gets - one persistent async HTTP client instead of a new one per RPC.""" + one persistent async HTTP client across RPCs.""" def __init__(self, timeout: Duration = DEFAULT_WORKER_RPC_TIMEOUT) -> None: self._timeout = timeout @@ -66,6 +67,8 @@ def get_stub(self, address: str) -> WorkerServiceClient: stub = WorkerServiceClient( address=f"http://{address}", timeout_ms=self._timeout.to_ms(), + accept_compression=IRIS_RPC_COMPRESSIONS, + send_compression=None, ) self._stubs[address] = stub return stub diff --git a/lib/iris/src/iris/cluster/providers/gcp/service.py b/lib/iris/src/iris/cluster/providers/gcp/service.py index 88a85fc717..fd43b285b2 100644 --- a/lib/iris/src/iris/cluster/providers/gcp/service.py +++ b/lib/iris/src/iris/cluster/providers/gcp/service.py @@ -17,7 +17,7 @@ import google.auth.transport.requests import httpx from google.cloud import tpu_v2alpha1 -from rigging.timing import Timestamp +from rigging.timing import ExponentialBackoff, Timestamp from iris.cluster.providers.gcp.local import LocalSliceHandle from iris.cluster.providers.types import ( @@ -77,7 +77,6 @@ # HTTP/auth constants _REFRESH_MARGIN = 300 # seconds before expiry to refresh token _DEFAULT_TIMEOUT = 120 # seconds -_OPERATION_POLL_INTERVAL = 2 # seconds between operation status polls _OPERATION_TIMEOUT = 600 # seconds to wait for an operation to complete # google.rpc.Code value for RESOURCE_EXHAUSTED. 
Used to classify LRO failures @@ -493,6 +492,7 @@ def _paginate_raw(self, url: str, params: dict[str, str] | None = None) -> list[ def _wait_zone_operation(self, zone: str, operation_name: str, timeout: float = _OPERATION_TIMEOUT) -> dict: url = f"{_COMPUTE_BASE}/projects/{self._project_id}/zones/{zone}/operations/{operation_name}" deadline = time.monotonic() + timeout + backoff = ExponentialBackoff(initial=1.0, maximum=30.0, factor=1.5) while True: resp = self._client.get(url, headers=self._headers()) self._classify_response(resp) @@ -505,11 +505,12 @@ def _wait_zone_operation(self, zone: str, operation_name: str, timeout: float = return data if time.monotonic() >= deadline: raise InfraError(f"Operation {operation_name} timed out after {timeout}s") - time.sleep(_OPERATION_POLL_INTERVAL) + time.sleep(backoff.next_interval()) def _wait_tpu_operation(self, operation_name: str, timeout: float = _OPERATION_TIMEOUT) -> dict: url = f"{_TPU_BASE}/{operation_name}" deadline = time.monotonic() + timeout + backoff = ExponentialBackoff(initial=1.0, maximum=30.0, factor=1.5) while True: resp = self._client.get(url, headers=self._headers()) self._classify_response(resp) @@ -529,7 +530,7 @@ def _wait_tpu_operation(self, operation_name: str, timeout: float = _OPERATION_T return data if time.monotonic() >= deadline: raise InfraError(f"TPU operation {operation_name} timed out after {timeout}s") - time.sleep(_OPERATION_POLL_INTERVAL) + time.sleep(backoff.next_interval()) # ======================================================================== # Low-level REST helpers diff --git a/lib/iris/src/iris/cluster/providers/gcp/workers.py b/lib/iris/src/iris/cluster/providers/gcp/workers.py index dee2162047..c74b6d9161 100644 --- a/lib/iris/src/iris/cluster/providers/gcp/workers.py +++ b/lib/iris/src/iris/cluster/providers/gcp/workers.py @@ -18,7 +18,7 @@ from collections.abc import Callable from datetime import datetime, timedelta, timezone -from rigging.timing import Deadline, Duration, 
Timestamp +from rigging.timing import Deadline, Duration, ExponentialBackoff, Timestamp from iris.cluster.providers.gcp.bootstrap import ( build_worker_bootstrap_script, @@ -845,6 +845,7 @@ def _run_tpu_bootstrap( # Phase 1: once the QR is ACTIVE (or immediately for non-queued TPUs), # wait for the TPU VM to reach READY with all worker IPs. cloud_deadline = Deadline.from_now(Duration.from_seconds(effective_cloud_ready_timeout)) + cloud_backoff = ExponentialBackoff(initial=1.0, maximum=30.0, factor=1.5) while not cloud_deadline.expired(): cloud_status = handle._describe_cloud() @@ -860,7 +861,7 @@ def _run_tpu_bootstrap( sum(1 for w in cloud_status.workers if w.internal_address), cloud_status.worker_count, ) - time.sleep(poll_interval) + time.sleep(cloud_backoff.next_interval()) else: raise InfraError(f"Slice {handle.slice_id} did not reach cloud READY within {effective_cloud_ready_timeout}s") diff --git a/lib/iris/src/iris/cluster/worker/worker.py b/lib/iris/src/iris/cluster/worker/worker.py index 972af8edf1..61f6b2475a 100644 --- a/lib/iris/src/iris/cluster/worker/worker.py +++ b/lib/iris/src/iris/cluster/worker/worker.py @@ -49,6 +49,7 @@ from iris.managed_thread import ThreadContainer, get_thread_container from iris.rpc import config_pb2, controller_pb2, job_pb2, worker_pb2 from iris.rpc.auth import AuthTokenInjector, StaticTokenProvider +from iris.rpc.compression import IRIS_RPC_COMPRESSIONS from iris.rpc.controller_connect import ControllerServiceClientSync from iris.time_proto import timestamp_to_proto @@ -271,6 +272,8 @@ def start(self) -> None: address=self._config.controller_address, timeout_ms=10_000, interceptors=interceptors, + accept_compression=IRIS_RPC_COMPRESSIONS, + send_compression=None, ) # Register stats namespaces eagerly. Schema bugs surface here at # startup rather than silently producing empty namespaces. 
diff --git a/lib/iris/src/iris/rpc/compression.py b/lib/iris/src/iris/rpc/compression.py index 5a92051b72..9d522ea9c9 100644 --- a/lib/iris/src/iris/rpc/compression.py +++ b/lib/iris/src/iris/rpc/compression.py @@ -3,11 +3,11 @@ """Shared compression configuration for iris RPC servers and clients. -zstd is preferred and gzip is the fallback for older peers. Iris RPC traffic -is dominated by log payloads (FetchLogs responses, PushLogs requests); the -gzip path was the top allocator on prod (memray showed gzip.compress at ~66% -of allocated bytes in the finelog server alone). zstd cuts that meaningfully -without giving up interop with gzip-only clients. +Iris RPC traffic is response-dominated (FetchLogs / list RPCs); requests are +small in practice, so clients pass ``send_compression=None`` and only +advertise ``Accept-Encoding`` via this list. Servers negotiate against it. +zstd is listed first as the preferred response encoding; gzip is kept for +interop with older peers. """ from __future__ import annotations @@ -20,6 +20,4 @@ # without the entry points having to remember to import it themselves. from iris.rpc import codecs as _codecs # noqa: F401 -# Order matters only on the client side (the negotiator walks the client's -# Accept-Encoding in order); we keep zstd first here for readability. 
IRIS_RPC_COMPRESSIONS = (ZstdCompression(), GzipCompression()) diff --git a/lib/iris/tests/cluster/controller/conftest.py b/lib/iris/tests/cluster/controller/conftest.py index cbd7f84e5a..859f6db664 100644 --- a/lib/iris/tests/cluster/controller/conftest.py +++ b/lib/iris/tests/cluster/controller/conftest.py @@ -6,6 +6,7 @@ import shutil import tempfile from contextlib import contextmanager +from dataclasses import dataclass from dataclasses import replace as _replace from pathlib import Path from unittest.mock import MagicMock, Mock @@ -33,10 +34,12 @@ from iris.cluster.controller.db import ( ACTIVE_TASK_STATES, ControllerDB, + SchedulableWorker, _decode_attribute_rows, task_row_can_be_scheduled, task_row_is_finished, ) +from iris.cluster.controller.db import healthy_active_workers_with_attributes as _healthy_active_workers_with_attributes from iris.cluster.controller.provider import ProviderUnsupportedError from iris.cluster.controller.schema import ( ATTEMPT_PROJECTION, @@ -340,11 +343,59 @@ def query_job_row(state: ControllerTransitions, job_id: JobName): ) -def query_worker(state: ControllerTransitions, worker_id: WorkerId) -> WorkerRow | None: +@dataclass(frozen=True, slots=True) +class WorkerView: + """Combined snapshot for tests that read DB row data + liveness in one call.""" + + worker_id: WorkerId + address: str + total_cpu_millicores: int + total_memory_bytes: int + total_gpu_count: int + total_tpu_count: int + device_type: str + device_variant: str + attributes: dict + healthy: bool + active: bool + consecutive_failures: int + last_heartbeat_ms: int + committed_cpu_millicores: int + committed_mem: int + committed_gpu: int + committed_tpu: int + + +def _worker_view(row: WorkerRow, liveness) -> WorkerView: + return WorkerView( + worker_id=row.worker_id, + address=row.address, + total_cpu_millicores=row.total_cpu_millicores, + total_memory_bytes=row.total_memory_bytes, + total_gpu_count=row.total_gpu_count, + total_tpu_count=row.total_tpu_count, + 
device_type=row.device_type, + device_variant=row.device_variant, + attributes=row.attributes, + healthy=liveness.healthy, + active=liveness.active, + consecutive_failures=liveness.consecutive_failures, + last_heartbeat_ms=liveness.last_heartbeat_ms, + committed_cpu_millicores=row.committed_cpu_millicores, + committed_mem=row.committed_mem, + committed_gpu=row.committed_gpu, + committed_tpu=row.committed_tpu, + ) + + +def query_worker(state: ControllerTransitions, worker_id: WorkerId) -> WorkerView | None: with state._db.read_snapshot() as q: - return WORKER_ROW_PROJECTION.decode_one( + decoded = WORKER_ROW_PROJECTION.decode_one( q.fetchall("SELECT * FROM workers WHERE worker_id = ? LIMIT 1", (str(worker_id),)), ) + if decoded is None: + return None + return _worker_view(decoded, state._store.health.liveness(decoded.worker_id)) def query_tasks_for_job(state: ControllerTransitions, job_id: JobName) -> list[TaskDetailRow]: @@ -404,8 +455,8 @@ def register_worker( slice_id=slice_id, scale_group=scale_group, ) - if not healthy: - state._store.workers.set_health_for_test(cur, wid, healthy=False) + if not healthy: + state._store.workers.set_health_for_test(wid, healthy=False) return wid @@ -583,23 +634,11 @@ def hydrate_worker_attributes(state: ControllerTransitions, workers: list) -> li tuple(worker_ids), ) attrs_by_worker = _decode_attribute_rows(attrs) - return [ - _replace( - w, - attributes=attrs_by_worker.get(w.worker_id, {}), - available_cpu_millicores=w.total_cpu_millicores - w.committed_cpu_millicores, - available_memory=w.total_memory_bytes - w.committed_mem, - available_gpus=w.total_gpu_count - w.committed_gpu, - available_tpus=w.total_tpu_count - w.committed_tpu, - ) - for w in workers - ] + return [_replace(w, attributes=attrs_by_worker.get(w.worker_id, {})) for w in workers] -def healthy_active_workers(state: ControllerTransitions) -> list[WorkerRow]: - with state._db.read_snapshot() as q: - workers = WORKER_ROW_PROJECTION.decode(q.fetchall("SELECT * FROM 
workers WHERE healthy = 1 AND active = 1")) - return hydrate_worker_attributes(state, workers) +def healthy_active_workers(state: ControllerTransitions) -> list[SchedulableWorker]: + return _healthy_active_workers_with_attributes(state._db, state._store.health) def dispatch_task(state: ControllerTransitions, task: TaskDetailRow, worker_id: WorkerId) -> None: diff --git a/lib/iris/tests/cluster/controller/replay/events.py b/lib/iris/tests/cluster/controller/replay/events.py index 7686a4c9ae..124d5e170e 100644 --- a/lib/iris/tests/cluster/controller/replay/events.py +++ b/lib/iris/tests/cluster/controller/replay/events.py @@ -195,7 +195,7 @@ def apply_event(transitions: ControllerTransitions, event: IrisEvent) -> Any: case RemoveWorker(worker_id): return transitions.remove_worker(cur, worker_id) case UpdateWorkerPings(worker_ids): - return transitions.update_worker_pings(cur, worker_ids) + return transitions.update_worker_pings(worker_ids) case DrainForDirectProvider(max_promotions): return transitions.drain_for_direct_provider(cur, max_promotions) case ApplyDirectProviderUpdates(updates): diff --git a/lib/iris/tests/cluster/controller/replay/golden/cancel_running_job.json b/lib/iris/tests/cluster/controller/replay/golden/cancel_running_job.json index 304bdf1567..82b52e59cc 100644 --- a/lib/iris/tests/cluster/controller/replay/golden/cancel_running_job.json +++ b/lib/iris/tests/cluster/controller/replay/golden/cancel_running_job.json @@ -176,17 +176,13 @@ ], "workers": [ { - "active": 1, "address": "w-cancel:8080", "committed_cpu_millicores": 0, "committed_gpu": 0, "committed_mem_bytes": 0, "committed_tpu": 0, - "consecutive_failures": 0, "device_type": "", "device_variant": "", - "healthy": 1, - "last_heartbeat_ms": 1704067200000, "md_cpu_count": 8, "md_device_json": "{\"cpu\": {\"variant\": \"cpu\"}}", "md_disk_bytes": 17179869184, diff --git a/lib/iris/tests/cluster/controller/replay/golden/coscheduled_failure_retry_bounces_siblings.json 
b/lib/iris/tests/cluster/controller/replay/golden/coscheduled_failure_retry_bounces_siblings.json index 1f0366a88f..55472b02cc 100644 --- a/lib/iris/tests/cluster/controller/replay/golden/coscheduled_failure_retry_bounces_siblings.json +++ b/lib/iris/tests/cluster/controller/replay/golden/coscheduled_failure_retry_bounces_siblings.json @@ -184,17 +184,13 @@ ], "workers": [ { - "active": 1, "address": "w-cosched-fail-a:8080", "committed_cpu_millicores": 0, "committed_gpu": 0, "committed_mem_bytes": 0, "committed_tpu": 0, - "consecutive_failures": 0, "device_type": "", "device_variant": "", - "healthy": 1, - "last_heartbeat_ms": 1704067200005, "md_cpu_count": 8, "md_device_json": "{\"cpu\": {\"variant\": \"cpu\"}}", "md_disk_bytes": 17179869184, @@ -220,17 +216,13 @@ "worker_id": "w-cosched-fail-a" }, { - "active": 1, "address": "w-cosched-fail-b:8080", "committed_cpu_millicores": 0, "committed_gpu": 0, "committed_mem_bytes": 0, "committed_tpu": 0, - "consecutive_failures": 0, "device_type": "", "device_variant": "", - "healthy": 1, - "last_heartbeat_ms": 1704067200003, "md_cpu_count": 8, "md_device_json": "{\"cpu\": {\"variant\": \"cpu\"}}", "md_disk_bytes": 17179869184, diff --git a/lib/iris/tests/cluster/controller/replay/golden/coscheduled_preempt_retry_bounces_siblings.json b/lib/iris/tests/cluster/controller/replay/golden/coscheduled_preempt_retry_bounces_siblings.json index 4c52b438a4..b950b083c8 100644 --- a/lib/iris/tests/cluster/controller/replay/golden/coscheduled_preempt_retry_bounces_siblings.json +++ b/lib/iris/tests/cluster/controller/replay/golden/coscheduled_preempt_retry_bounces_siblings.json @@ -184,17 +184,13 @@ ], "workers": [ { - "active": 1, "address": "w-cosched-preempt-a:8080", "committed_cpu_millicores": 0, "committed_gpu": 0, "committed_mem_bytes": 0, "committed_tpu": 0, - "consecutive_failures": 0, "device_type": "", "device_variant": "", - "healthy": 1, - "last_heartbeat_ms": 1704067200000, "md_cpu_count": 8, "md_device_json": "{\"cpu\": 
{\"variant\": \"cpu\"}}", "md_disk_bytes": 17179869184, @@ -220,17 +216,13 @@ "worker_id": "w-cosched-preempt-a" }, { - "active": 1, "address": "w-cosched-preempt-b:8080", "committed_cpu_millicores": 0, "committed_gpu": 0, "committed_mem_bytes": 0, "committed_tpu": 0, - "consecutive_failures": 0, "device_type": "", "device_variant": "", - "healthy": 1, - "last_heartbeat_ms": 1704067200000, "md_cpu_count": 8, "md_device_json": "{\"cpu\": {\"variant\": \"cpu\"}}", "md_disk_bytes": 17179869184, diff --git a/lib/iris/tests/cluster/controller/replay/golden/coscheduled_timeout.json b/lib/iris/tests/cluster/controller/replay/golden/coscheduled_timeout.json index d3cdd5682f..ff0b20df6e 100644 --- a/lib/iris/tests/cluster/controller/replay/golden/coscheduled_timeout.json +++ b/lib/iris/tests/cluster/controller/replay/golden/coscheduled_timeout.json @@ -184,17 +184,13 @@ ], "workers": [ { - "active": 1, "address": "w-cosched-a:8080", "committed_cpu_millicores": 1000, "committed_gpu": 0, "committed_mem_bytes": 1073741824, "committed_tpu": 0, - "consecutive_failures": 0, "device_type": "", "device_variant": "", - "healthy": 1, - "last_heartbeat_ms": 1704067200000, "md_cpu_count": 8, "md_device_json": "{\"cpu\": {\"variant\": \"cpu\"}}", "md_disk_bytes": 17179869184, @@ -220,17 +216,13 @@ "worker_id": "w-cosched-a" }, { - "active": 1, "address": "w-cosched-b:8080", "committed_cpu_millicores": 1000, "committed_gpu": 0, "committed_mem_bytes": 1073741824, "committed_tpu": 0, - "consecutive_failures": 0, "device_type": "", "device_variant": "", - "healthy": 1, - "last_heartbeat_ms": 1704067200000, "md_cpu_count": 8, "md_device_json": "{\"cpu\": {\"variant\": \"cpu\"}}", "md_disk_bytes": 17179869184, diff --git a/lib/iris/tests/cluster/controller/replay/golden/endpoint_register_remove.json b/lib/iris/tests/cluster/controller/replay/golden/endpoint_register_remove.json index b27184b725..082b459acc 100644 --- 
a/lib/iris/tests/cluster/controller/replay/golden/endpoint_register_remove.json +++ b/lib/iris/tests/cluster/controller/replay/golden/endpoint_register_remove.json @@ -134,17 +134,13 @@ ], "workers": [ { - "active": 1, "address": "w-endpoint:8080", "committed_cpu_millicores": 1000, "committed_gpu": 0, "committed_mem_bytes": 1073741824, "committed_tpu": 0, - "consecutive_failures": 0, "device_type": "", "device_variant": "", - "healthy": 1, - "last_heartbeat_ms": 1704067200000, "md_cpu_count": 8, "md_device_json": "{\"cpu\": {\"variant\": \"cpu\"}}", "md_disk_bytes": 17179869184, diff --git a/lib/iris/tests/cluster/controller/replay/golden/preempt_task.json b/lib/iris/tests/cluster/controller/replay/golden/preempt_task.json index abf4f6615f..b53dd339c2 100644 --- a/lib/iris/tests/cluster/controller/replay/golden/preempt_task.json +++ b/lib/iris/tests/cluster/controller/replay/golden/preempt_task.json @@ -134,17 +134,13 @@ ], "workers": [ { - "active": 1, "address": "w-preempt:8080", "committed_cpu_millicores": 0, "committed_gpu": 0, "committed_mem_bytes": 0, "committed_tpu": 0, - "consecutive_failures": 0, "device_type": "", "device_variant": "", - "healthy": 1, - "last_heartbeat_ms": 1704067200000, "md_cpu_count": 8, "md_device_json": "{\"cpu\": {\"variant\": \"cpu\"}}", "md_disk_bytes": 17179869184, diff --git a/lib/iris/tests/cluster/controller/replay/golden/prune_old_data.json b/lib/iris/tests/cluster/controller/replay/golden/prune_old_data.json index 20c85fefe9..0f58a2dd8c 100644 --- a/lib/iris/tests/cluster/controller/replay/golden/prune_old_data.json +++ b/lib/iris/tests/cluster/controller/replay/golden/prune_old_data.json @@ -49,17 +49,13 @@ ], "workers": [ { - "active": 1, "address": "w-prune:8080", "committed_cpu_millicores": 0, "committed_gpu": 0, "committed_mem_bytes": 0, "committed_tpu": 0, - "consecutive_failures": 0, "device_type": "", "device_variant": "", - "healthy": 1, - "last_heartbeat_ms": 1704067200001, "md_cpu_count": 8, "md_device_json": 
"{\"cpu\": {\"variant\": \"cpu\"}}", "md_disk_bytes": 17179869184, diff --git a/lib/iris/tests/cluster/controller/replay/golden/register_assign_run_succeed.json b/lib/iris/tests/cluster/controller/replay/golden/register_assign_run_succeed.json index 4bc836750b..2af35d18d9 100644 --- a/lib/iris/tests/cluster/controller/replay/golden/register_assign_run_succeed.json +++ b/lib/iris/tests/cluster/controller/replay/golden/register_assign_run_succeed.json @@ -134,17 +134,13 @@ ], "workers": [ { - "active": 1, "address": "w-happy:8080", "committed_cpu_millicores": 0, "committed_gpu": 0, "committed_mem_bytes": 0, "committed_tpu": 0, - "consecutive_failures": 0, "device_type": "", "device_variant": "", - "healthy": 1, - "last_heartbeat_ms": 1704067200003, "md_cpu_count": 8, "md_device_json": "{\"cpu\": {\"variant\": \"cpu\"}}", "md_disk_bytes": 17179869184, diff --git a/lib/iris/tests/cluster/controller/replay/golden/replace_reservation_claims.json b/lib/iris/tests/cluster/controller/replay/golden/replace_reservation_claims.json index 3b80f01db1..64152e4efb 100644 --- a/lib/iris/tests/cluster/controller/replay/golden/replace_reservation_claims.json +++ b/lib/iris/tests/cluster/controller/replay/golden/replace_reservation_claims.json @@ -214,17 +214,13 @@ ], "workers": [ { - "active": 1, "address": "w-claim-a:8080", "committed_cpu_millicores": 0, "committed_gpu": 0, "committed_mem_bytes": 0, "committed_tpu": 0, - "consecutive_failures": 0, "device_type": "", "device_variant": "", - "healthy": 1, - "last_heartbeat_ms": 1704067200000, "md_cpu_count": 8, "md_device_json": "{\"cpu\": {\"variant\": \"cpu\"}}", "md_disk_bytes": 17179869184, @@ -250,17 +246,13 @@ "worker_id": "w-claim-a" }, { - "active": 1, "address": "w-claim-b:8080", "committed_cpu_millicores": 0, "committed_gpu": 0, "committed_mem_bytes": 0, "committed_tpu": 0, - "consecutive_failures": 0, "device_type": "", "device_variant": "", - "healthy": 1, - "last_heartbeat_ms": 1704067200000, "md_cpu_count": 8, 
"md_device_json": "{\"cpu\": {\"variant\": \"cpu\"}}", "md_disk_bytes": 17179869184, diff --git a/lib/iris/tests/cluster/controller/replay/golden/task_failure_with_retry.json b/lib/iris/tests/cluster/controller/replay/golden/task_failure_with_retry.json index 4f6891c134..4bfb5831c1 100644 --- a/lib/iris/tests/cluster/controller/replay/golden/task_failure_with_retry.json +++ b/lib/iris/tests/cluster/controller/replay/golden/task_failure_with_retry.json @@ -153,17 +153,13 @@ ], "workers": [ { - "active": 1, "address": "w-retry:8080", "committed_cpu_millicores": 0, "committed_gpu": 0, "committed_mem_bytes": 0, "committed_tpu": 0, - "consecutive_failures": 0, "device_type": "", "device_variant": "", - "healthy": 1, - "last_heartbeat_ms": 1704067200004, "md_cpu_count": 8, "md_device_json": "{\"cpu\": {\"variant\": \"cpu\"}}", "md_disk_bytes": 17179869184, diff --git a/lib/iris/tests/cluster/controller/test_5470_preemption_reassignment.py b/lib/iris/tests/cluster/controller/test_5470_preemption_reassignment.py index 70dae8a6f3..2dbb78539d 100644 --- a/lib/iris/tests/cluster/controller/test_5470_preemption_reassignment.py +++ b/lib/iris/tests/cluster/controller/test_5470_preemption_reassignment.py @@ -103,7 +103,7 @@ def _job_requirements_from_job(job): def _build_context(scheduler, state): pending = _schedulable_tasks(state) - workers = [w for w in healthy_active_workers(state) if w.healthy] + workers = list(healthy_active_workers(state)) bc = _building_counts(state) task_ids = [] jobs = {} @@ -177,8 +177,7 @@ def _assigned_workers_by_job(assignments): def _mark_slice_unhealthy(state, prefix): for i in range(VMS_PER_SLICE): wid = WorkerId(f"{prefix}-w{i}") - with state._store.transaction() as cur: - state._store.workers.set_health_for_test(cur, wid, healthy=False) + state._store.workers.set_health_for_test(wid, healthy=False) @pytest.fixture diff --git a/lib/iris/tests/cluster/controller/test_dashboard.py b/lib/iris/tests/cluster/controller/test_dashboard.py index 
dfd2465c8a..b8ef852afd 100644 --- a/lib/iris/tests/cluster/controller/test_dashboard.py +++ b/lib/iris/tests/cluster/controller/test_dashboard.py @@ -152,7 +152,7 @@ def _get_job_scheduling_diagnostics(job_wire_id): ) tasks = _query_tasks_with_attempts(state, job.job_id) schedulable_task_id = next((t.task_id for t in tasks if check_task_can_be_scheduled(t)), None) - workers = healthy_active_workers_with_attributes(state._db) + workers = healthy_active_workers_with_attributes(state._db, state._store.health) context = _create_scheduling_context(workers) return scheduler.get_job_scheduling_diagnostics(req, context, schedulable_task_id, num_tasks=len(tasks)) diff --git a/lib/iris/tests/cluster/controller/test_endpoint_store.py b/lib/iris/tests/cluster/controller/test_endpoint_store.py index 9c577a40e1..1df835237a 100644 --- a/lib/iris/tests/cluster/controller/test_endpoint_store.py +++ b/lib/iris/tests/cluster/controller/test_endpoint_store.py @@ -10,7 +10,7 @@ import pytest from iris.cluster.controller.db import EndpointQuery from iris.cluster.controller.schema import ENDPOINT_PROJECTION, EndpointRow -from iris.cluster.controller.stores import EndpointStore +from iris.cluster.controller.stores import AddEndpointOutcome, EndpointStore from iris.cluster.types import JobName from iris.rpc import job_pb2 from rigging.timing import Timestamp @@ -110,7 +110,7 @@ class BoomError(RuntimeError): def test_add_rejects_terminal_task(state): - """Writing an endpoint for a terminal task should return False and not mutate memory.""" + """Writing an endpoint for a terminal task should return TERMINAL and not mutate memory.""" tasks = submit_job(state, "j", make_job_request("j")) task_id = tasks[0].task_id # Drive the task to SUCCEEDED to mark it terminal. 
@@ -120,7 +120,8 @@ def test_add_rejects_terminal_task(state): ) with state._db.transaction() as cur: - assert state._store.endpoints.add(cur, _make_row("e1", "alpha", task_id)) is False + outcome = state._store.endpoints.add(cur, _make_row("e1", "alpha", task_id)) + assert outcome is AddEndpointOutcome.TERMINAL assert state._store.endpoints.get("e1") is None diff --git a/lib/iris/tests/cluster/controller/test_reservation.py b/lib/iris/tests/cluster/controller/test_reservation.py index 9e7e115ef8..65acc0931e 100644 --- a/lib/iris/tests/cluster/controller/test_reservation.py +++ b/lib/iris/tests/cluster/controller/test_reservation.py @@ -35,9 +35,8 @@ _worker_matches_reservation_entry, job_requirements_from_job, ) -from iris.cluster.controller.db import task_row_can_be_scheduled +from iris.cluster.controller.db import SchedulableWorker, task_row_can_be_scheduled from iris.cluster.controller.scheduler import JobRequirements, Scheduler, SchedulingContext -from iris.cluster.controller.schema import WorkerRow from iris.cluster.controller.transitions import ( RESERVATION_HOLDER_JOB_NAME, Assignment, @@ -134,8 +133,7 @@ def _make_worker( worker_id: str, metadata: job_pb2.WorkerMetadata | None = None, attributes: dict[str, AttributeValue] | None = None, - healthy: bool = True, -) -> WorkerRow: +) -> SchedulableWorker: meta = metadata or _cpu_metadata() # Workers always have device attributes from config (Stage 3). # Merge explicit attributes on top of the device-derived defaults. 
@@ -148,17 +146,9 @@ def _make_worker( total_mem = meta.memory_bytes total_gpu = meta.gpu_count total_tpu = 1 if meta.tpu_name else 0 - return WorkerRow( + return SchedulableWorker( worker_id=WorkerId(worker_id), address=f"{worker_id}:8080", - healthy=healthy, - active=True, - consecutive_failures=0, - last_heartbeat=Timestamp.now(), - committed_cpu_millicores=0, - committed_mem=0, - committed_gpu=0, - committed_tpu=0, total_cpu_millicores=total_cpu, total_memory_bytes=total_mem, total_gpu_count=total_gpu, @@ -166,10 +156,10 @@ def _make_worker( device_type=dt, device_variant=dv, attributes=default_attrs, - available_cpu_millicores=total_cpu, - available_memory=total_mem, - available_gpus=total_gpu, - available_tpus=total_tpu, + committed_cpu_millicores=0, + committed_mem=0, + committed_gpu=0, + committed_tpu=0, ) @@ -746,7 +736,7 @@ def test_taint_constraint_preserves_existing_constraints(): def _build_context_with_workers( - workers: list[WorkerRow], + workers: list[SchedulableWorker], pending_tasks: list[JobName], jobs: dict[JobName, JobRequirements], ) -> SchedulingContext: @@ -925,6 +915,7 @@ def test_region_constraint_injected_from_claimed_workers(ctrl): jid.to_wire(), ctrl.reservation_claims, ctrl._db, + ctrl.state._store.health, [], ) @@ -948,6 +939,7 @@ def test_region_constraint_not_injected_when_already_present(ctrl): jid.to_wire(), ctrl.reservation_claims, ctrl._db, + ctrl.state._store.health, [existing], ) @@ -967,6 +959,7 @@ def test_region_constraint_not_injected_when_no_region_attr(ctrl): jid.to_wire(), ctrl.reservation_claims, ctrl._db, + ctrl.state._store.health, [], ) @@ -990,6 +983,7 @@ def test_region_constraint_multiple_regions(ctrl): jid.to_wire(), ctrl.reservation_claims, ctrl._db, + ctrl.state._store.health, [], ) @@ -1013,6 +1007,7 @@ def test_no_injection_for_non_reservation_job(ctrl): "/test-user/unrelated-job", ctrl.reservation_claims, ctrl._db, + ctrl.state._store.health, [], ) diff --git 
a/lib/iris/tests/cluster/controller/test_scheduler.py b/lib/iris/tests/cluster/controller/test_scheduler.py index 4f5da53440..be2dfac796 100644 --- a/lib/iris/tests/cluster/controller/test_scheduler.py +++ b/lib/iris/tests/cluster/controller/test_scheduler.py @@ -125,7 +125,7 @@ def transition_task_to_state(state: ControllerTransitions, task, new_state: int) def _build_context(scheduler, state): pending_tasks = _schedulable_tasks(state) - workers = [w for w in healthy_active_workers(state) if w.healthy] + workers = list(healthy_active_workers(state)) building_counts = _building_counts(state) task_ids = [] diff --git a/lib/iris/tests/cluster/controller/test_service.py b/lib/iris/tests/cluster/controller/test_service.py index 48f79061d0..a8d7e3e9c3 100644 --- a/lib/iris/tests/cluster/controller/test_service.py +++ b/lib/iris/tests/cluster/controller/test_service.py @@ -7,6 +7,8 @@ State changes are verified via RPC calls rather than internal state inspection. """ +import concurrent.futures +import time from datetime import date, timedelta import pytest @@ -1495,3 +1497,86 @@ def test_set_task_status_text_persists_via_store(service): service.set_task_status_text(request, None) assert service._store.tasks.get_status_text_detail(task_id.to_wire()) == detail_text assert service._store.tasks.get_status_text_summary(task_id.to_wire()) == summary_text + + +def test_list_tasks_returns_current_attempt_timing(service, state): + """ListTasks must surface the current attempt's started_at and exactly one attempt entry. + + Regression target: ``_tasks_for_listing`` previously returned tasks with + empty ``attempts``, so the proto's ``started_at`` (read from the current + attempt) was never populated and retry status text was missing on the + dashboard. The fixed version JOINs the current attempt only — at most one + row per task — to avoid the IN-clause blowup on long histories. 
+ """ + request = make_job_request("list-tasks-timing") + service.launch_job(request, None) + job_id = JobName.root("test-user", "list-tasks-timing") + task_id = job_id.task(0) + worker_id = WorkerId("w-list-timing") + _register_worker(state, worker_id) + _assign_and_transition(state, task_id, worker_id, job_pb2.TASK_STATE_RUNNING) + + response = service.list_tasks( + controller_pb2.Controller.ListTasksRequest(job_id=job_id.to_wire()), + None, + ) + + assert len(response.tasks) == 1 + proto = response.tasks[0] + assert proto.task_id == task_id.to_wire() + assert proto.state == job_pb2.TASK_STATE_RUNNING + # current attempt is loaded -> started_at on the proto is populated + assert proto.started_at.epoch_ms > 0 + # exactly one attempt entry — the current one — even if more existed + assert len(proto.attempts) == 1 + assert proto.attempts[0].attempt_id == proto.current_attempt_id + + +# ============================================================================= +# Direct-SQL load: ensure list_jobs scales under concurrent dashboard polling. +# ============================================================================= + + +def test_list_jobs_concurrent_load_p99_under_threshold(service): + """100 concurrent ``list_jobs`` calls against ~1k jobs must hit p99 < 500ms. + + The previous implementation amortized this fan-out behind ``SnapshotView``; + direct SQL replaces that with per-request reads from the 32-slot reader + pool. This test would have failed if the reader pool were starving on the + GIL-unfriendly path or if the EXPLAIN-verified indexes were not being + used. The threshold is generous on purpose — we want to detect regressions + on the order of seconds, not micro-jitter. + """ + for i in range(1000): + service.launch_job(make_job_request(f"load-job-{i:04d}"), None) + + request = controller_pb2.Controller.ListJobsRequest( + query=controller_pb2.Controller.JobQuery(limit=50), + ) + + # Warm one call so the SQLite page cache is populated. 
+ service.list_jobs(request, None) + + n_requests = 100 + latencies_ms: list[float] = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=16) as pool: + + def call() -> float: + start = time.perf_counter() + response = service.list_jobs(request, None) + assert response.total_count == 1000 + return (time.perf_counter() - start) * 1000.0 + + futures = [pool.submit(call) for _ in range(n_requests)] + for f in concurrent.futures.as_completed(futures): + latencies_ms.append(f.result()) + + latencies_ms.sort() + p50 = latencies_ms[len(latencies_ms) // 2] + p99 = latencies_ms[int(len(latencies_ms) * 0.99)] + # 500ms is well above expected steady-state (in-process measurement was + # ~100-120ms; production should be cheaper since the test fixture's DB has + # transaction overhead from thousands of launch_job inserts) but tight + # enough that read-pool starvation or accidental N+1 query growth would + # trip it. + assert p99 < 500.0, f"list_jobs concurrent p99 too high: p50={p50:.1f}ms p99={p99:.1f}ms" diff --git a/lib/iris/tests/cluster/controller/test_snapshot_view.py b/lib/iris/tests/cluster/controller/test_snapshot_view.py deleted file mode 100644 index 10f09356de..0000000000 --- a/lib/iris/tests/cluster/controller/test_snapshot_view.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright The Marin Authors -# SPDX-License-Identifier: Apache-2.0 - -"""Tests for :class:`iris.cluster.controller.stores.SnapshotView`.""" - -from __future__ import annotations - -import threading - -import pytest -from iris.cluster.controller.stores import SnapshotView - - -class FakeClock: - """Deterministic monotonic clock for tests. - - Drives ``SnapshotView`` TTL expiry without ``time.sleep``: tests advance - the clock manually, so behavior is independent of wall time and CI load. 
- """ - - def __init__(self, start: float = 0.0) -> None: - self.now = start - - def __call__(self) -> float: - return self.now - - def advance(self, dt: float) -> None: - self.now += dt - - -def test_first_read_calls_build_and_returns_value() -> None: - calls = 0 - - def build() -> str: - nonlocal calls - calls += 1 - return f"v{calls}" - - view = SnapshotView[str](name="t", ttl_s=60.0, build=build, clock=FakeClock()) - assert view.read() == "v1" - assert calls == 1 - - -def test_read_within_ttl_returns_cached_value() -> None: - calls = 0 - - def build() -> int: - nonlocal calls - calls += 1 - return calls - - clock = FakeClock() - view = SnapshotView[int](name="t", ttl_s=60.0, build=build, clock=clock) - assert view.read() == 1 - clock.advance(10.0) - assert view.read() == 1 - clock.advance(10.0) - assert view.read() == 1 - assert calls == 1 - - -def test_read_past_ttl_rebuilds() -> None: - calls = 0 - - def build() -> int: - nonlocal calls - calls += 1 - return calls - - clock = FakeClock() - view = SnapshotView[int](name="t", ttl_s=60.0, build=build, clock=clock) - assert view.read() == 1 - clock.advance(60.0) - assert view.read() == 2 - assert calls == 2 - - -def test_invalidate_forces_rebuild() -> None: - calls = 0 - - def build() -> int: - nonlocal calls - calls += 1 - return calls - - # Clock starts at 0, so a backdate-based invalidate would not force a - # rebuild within the TTL window. The flag-based implementation must. - view = SnapshotView[int](name="t", ttl_s=60.0, build=build, clock=FakeClock()) - assert view.read() == 1 - view.invalidate() - assert view.read() == 2 - - -def test_concurrent_readers_share_one_rebuild() -> None: - """Concurrent past-TTL reads must call ``build`` exactly once. - - The view's lock serializes rebuilds: the first reader runs ``build``; - later readers wait for the lock, then observe the freshly-built value - within TTL and skip rebuild. 
We assert on call count rather than - instantaneous concurrency so the test does not depend on timing. - """ - calls = 0 - call_lock = threading.Lock() - # Gate the first build so the other 7 threads pile up on the view's lock - # before it completes. This makes the "many threads enter read() - # simultaneously" condition deterministic without sleeping. - release_build = threading.Event() - first_in_build = threading.Event() - - def build() -> int: - nonlocal calls - with call_lock: - calls += 1 - my_call = calls - if my_call == 1: - first_in_build.set() - release_build.wait() - return my_call - - view = SnapshotView[int](name="t", ttl_s=60.0, build=build, clock=FakeClock()) - results: list[int] = [0] * 8 - barrier = threading.Barrier(8) - - def worker(i: int) -> None: - barrier.wait() - results[i] = view.read() - - threads = [threading.Thread(target=worker, args=(i,)) for i in range(8)] - for t in threads: - t.start() - # Block until the first build is in flight. The other 7 threads pile up - # on the view's lock; we then release the build so it returns and they - # observe the freshly-cached value. Timeouts guard against deadlock if - # the view stops serializing rebuilds. - assert first_in_build.wait(timeout=5.0), "first build never entered" - release_build.set() - for t in threads: - t.join(timeout=5.0) - assert not t.is_alive(), "worker thread did not finish" - - # Only the first thread ran ``build``; the rest got the cached value. - assert calls == 1 - assert results == [1] * 8 - - -def test_build_error_propagates_and_next_read_retries() -> None: - calls = 0 - - def build() -> int: - nonlocal calls - calls += 1 - if calls == 1: - raise RuntimeError("transient") - return calls - - view = SnapshotView[int](name="t", ttl_s=60.0, build=build, clock=FakeClock()) - with pytest.raises(RuntimeError, match="transient"): - view.read() - # Cached value is still None, so the next read retries instead of returning stale. 
- assert view.read() == 2 diff --git a/lib/iris/tests/cluster/controller/test_transitions.py b/lib/iris/tests/cluster/controller/test_transitions.py index d06aad9ea6..fef75fe94d 100644 --- a/lib/iris/tests/cluster/controller/test_transitions.py +++ b/lib/iris/tests/cluster/controller/test_transitions.py @@ -3268,17 +3268,18 @@ def test_prune_evicts_status_text_cache(state): def test_prune_old_inactive_workers(state): - """Inactive workers with stale heartbeats are pruned; active workers are kept.""" + """Inactive workers with stale heartbeats are pruned; active workers are kept. - # Register two workers: one healthy, one that we'll make inactive + Liveness state lives in :class:`WorkerHealthTracker` rather than the + SQLite ``workers`` row, so the test mutates the tracker directly to age + out the stale worker. + """ active_wid = register_worker(state, "active-w", "host:8080", make_worker_metadata()) stale_wid = register_worker(state, "stale-w", "host:8081", make_worker_metadata()) - # Mark the stale worker as unhealthy with an old heartbeat - state._db.execute( - "UPDATE workers SET healthy = 0, last_heartbeat_ms = ? WHERE worker_id = ?", - (1000, str(stale_wid)), - ) + # Mark the stale worker as unhealthy with an old heartbeat in the tracker. 
+ state._store.health.set_health_for_test(stale_wid, healthy=False) + state._store.health.set_last_heartbeat_for_test(stale_wid, last_heartbeat_ms=1000) assert _query_worker(state, active_wid) is not None assert _query_worker(state, stale_wid) is not None diff --git a/lib/iris/tests/cluster/controller/test_worker_health.py b/lib/iris/tests/cluster/controller/test_worker_health.py index c4b9c00881..43a33dad69 100644 --- a/lib/iris/tests/cluster/controller/test_worker_health.py +++ b/lib/iris/tests/cluster/controller/test_worker_health.py @@ -8,7 +8,11 @@ - Build failures (monotonic counter, independent of pings) """ +from pathlib import Path + import pytest +from iris.cluster.controller.db import ControllerDB, healthy_active_workers_with_attributes +from iris.cluster.controller.stores import ControllerStore from iris.cluster.controller.worker_health import WorkerHealthTracker from iris.cluster.types import WorkerId @@ -96,3 +100,32 @@ def test_snapshot_reports_both_counters(tracker: WorkerHealthTracker) -> None: for _ in range(2): tracker.build_failed(wid) assert tracker.snapshot() == {wid: (3, 2)} + + +def test_controller_store_seeds_liveness_from_persisted_workers(tmp_path: Path) -> None: + """A fresh ControllerStore must mark every persisted worker healthy on boot. + + Without this seed (regression target), a controller restart hides every + pre-existing worker from ``healthy_active_workers_with_attributes`` until + the next ping cycle — the scheduler then makes no assignments. 
+ """ + db = ControllerDB(db_dir=tmp_path) + try: + with db.transaction() as cur: + cur.execute("INSERT INTO workers (worker_id, address) VALUES (?, ?)", ("w-seed-1", "10.0.0.1:8080")) + cur.execute("INSERT INTO workers (worker_id, address) VALUES (?, ?)", ("w-seed-2", "10.0.0.2:8080")) + + store = ControllerStore(db) + + liveness_one = store.health.liveness(WorkerId("w-seed-1")) + liveness_two = store.health.liveness(WorkerId("w-seed-2")) + assert liveness_one.healthy and liveness_one.active + assert liveness_two.healthy and liveness_two.active + assert liveness_one.last_heartbeat_ms > 0 + assert liveness_two.last_heartbeat_ms > 0 + + schedulable = healthy_active_workers_with_attributes(db, store.health) + ids = {str(w.worker_id) for w in schedulable} + assert ids == {"w-seed-1", "w-seed-2"} + finally: + db.close() diff --git a/lib/marin/src/marin/mcp/babysitter.py b/lib/marin/src/marin/mcp/babysitter.py index 2391dbe102..73929819d9 100644 --- a/lib/marin/src/marin/mcp/babysitter.py +++ b/lib/marin/src/marin/mcp/babysitter.py @@ -22,6 +22,7 @@ from iris.cluster.types import JobName from iris.rpc import controller_pb2, job_pb2 from iris.rpc.auth import AuthTokenInjector, StaticTokenProvider, TokenProvider +from iris.rpc.compression import IRIS_RPC_COMPRESSIONS from iris.rpc.controller_connect import ControllerServiceClientSync from iris.rpc.proto_utils import job_state_friendly, task_state_friendly from mcp.server.fastmcp import FastMCP @@ -162,7 +163,12 @@ def task_status_to_json(task: job_pb2.TaskStatus) -> dict[str, Any]: def job_status_to_json(job: job_pb2.JobStatus, tasks: Iterable[job_pb2.TaskStatus] = ()) -> dict[str, Any]: - """Serialize Iris job status into stable JSON.""" + """Serialize Iris job status into stable JSON. + + Callers that need per-job ``resources`` / ``ports`` / ``tasks`` / + ``status_message`` should hit ``GetJobStatus`` and use + :func:`_job_summary_payload`. 
+ """ task_payloads = [task_status_to_json(task) for task in tasks] return { "job_id": job.job_id, @@ -174,7 +180,6 @@ def job_status_to_json(job: job_pb2.JobStatus, tasks: Iterable[job_pb2.TaskStatu "started_at_ms": _timestamp_ms(job.started_at), "finished_at_ms": _timestamp_ms(job.finished_at), "duration_ms": _duration_ms(job.started_at, job.finished_at), - "status_message": job.status_message, "pending_reason": job.pending_reason, "failure_count": int(job.failure_count), "preemption_count": int(job.preemption_count), @@ -182,8 +187,6 @@ def job_status_to_json(job: job_pb2.JobStatus, tasks: Iterable[job_pb2.TaskStatu "completed_count": int(job.completed_count), "task_state_counts": dict(job.task_state_counts), "has_children": bool(job.has_children), - "resource_requests": _resource_spec_to_json(job.resources), - "ports": dict(job.ports), "tasks": task_payloads, } @@ -405,6 +408,8 @@ def __init__(self, config: IrisConnectionConfig): config.controller_url, timeout_ms=config.timeout_ms, interceptors=interceptors, + accept_compression=IRIS_RPC_COMPRESSIONS, + send_compression=IRIS_RPC_COMPRESSIONS[0], ) self.logs = LogServiceClientSync( config.controller_url,