Commit 2acba3a
[iris] in-memory worker liveness, slim ListJobs, drop SnapshotView
Move worker last_heartbeat_ms/healthy/active/consecutive_failures/committed_* out of SQLite columns on the workers table and into the in-memory WorkerHealthTracker / WorkerCommitTracker, eliminating the per-heartbeat writer transaction that was bloating the WAL and starving dashboard reads. Migration 0042 drops the now-dormant columns. ListJobs no longer materializes a 26k-row snapshot: JobRow slims from 22 to 13 fields and the RPC serves directly from indexed SQL. Fixes the 212s ListJobs latencies seen on prod.
1 parent d13db14

41 files changed: 981 additions, 991 deletions
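The heart of the WAL fix is a data-structure swap: worker liveness used to be UPDATEd into SQLite on every heartbeat, and now it is a dict write behind a lock. The real WorkerHealthTracker lives in iris/cluster/controller/worker_health.py and is not shown in this diff; the following is a minimal sketch inferred only from its call sites below (ping(), all(), workers_over_threshold(), and the healthy / active / consecutive_ping_failures / last_heartbeat_ms fields), so treat names and defaults as assumptions.

import threading
import time
from dataclasses import dataclass


@dataclass
class WorkerLiveness:
    last_heartbeat_ms: int = 0
    healthy: bool = True
    active: bool = True
    consecutive_ping_failures: int = 0


class WorkerHealthTracker:
    """Per-worker liveness in process memory: a heartbeat is a locked dict
    write instead of a SQLite writer transaction."""

    def __init__(self, failure_threshold: int = 3) -> None:
        self._lock = threading.Lock()
        self._workers: dict[str, WorkerLiveness] = {}
        self._failure_threshold = failure_threshold

    def ping(self, worker_id: str, *, healthy: bool) -> None:
        with self._lock:
            liveness = self._workers.setdefault(worker_id, WorkerLiveness())
            liveness.last_heartbeat_ms = int(time.time() * 1000)
            liveness.healthy = healthy
            liveness.consecutive_ping_failures = (
                0 if healthy else liveness.consecutive_ping_failures + 1
            )

    def all(self) -> dict[str, WorkerLiveness]:
        # Snapshot copy so callers can filter without holding the lock.
        with self._lock:
            return dict(self._workers)

    def workers_over_threshold(self) -> list[str]:
        with self._lock:
            return [
                wid
                for wid, liveness in self._workers.items()
                if liveness.consecutive_ping_failures >= self._failure_threshold
            ]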


lib/iris/src/iris/cli/bug_report.py

Lines changed: 8 additions & 1 deletion
@@ -18,6 +18,7 @@
 from iris.cluster.types import JobName
 from iris.rpc import controller_pb2, job_pb2
 from iris.rpc.auth import AuthTokenInjector, TokenProvider
+from iris.rpc.compression import IRIS_RPC_COMPRESSIONS
 from iris.rpc.controller_connect import ControllerServiceClientSync
 from iris.rpc.proto_utils import format_resources, job_state_friendly, task_state_friendly
 from iris.time_proto import timestamp_from_proto
@@ -119,7 +120,13 @@ def gather_bug_report(
 ) -> BugReport:
     """Gather all diagnostic data for a job into a BugReport."""
     interceptors = [AuthTokenInjector(token_provider)] if token_provider else []
-    client = ControllerServiceClientSync(controller_url, timeout_ms=30000, interceptors=interceptors)
+    client = ControllerServiceClientSync(
+        controller_url,
+        timeout_ms=30000,
+        interceptors=interceptors,
+        accept_compression=IRIS_RPC_COMPRESSIONS,
+        send_compression=IRIS_RPC_COMPRESSIONS[0],
+    )
     log_client = LogClient.connect(controller_url, timeout_ms=30000, interceptors=interceptors)
     try:
         return _gather(client, log_client, job_id, tail=tail)
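Every ControllerServiceClientSync construction in this commit gains the same accept/send pair. iris/rpc/compression.py itself is not part of the diff; a plausible sketch of its shape, inferred only from the call sites (the preferred-codec-first ordering is read off send_compression=IRIS_RPC_COMPRESSIONS[0]; the codec names are placeholders):

# Hypothetical sketch of iris/rpc/compression.py -- not shown in this commit.
# Callers advertise the whole tuple via accept_compression and send with the
# first entry, so the preferred codec must sort first.
IRIS_RPC_COMPRESSIONS: tuple[str, ...] = ("gzip", "identity")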

lib/iris/src/iris/cli/job.py

Lines changed: 8 additions & 10 deletions
@@ -52,7 +52,6 @@
 from iris.rpc.auth import TokenProvider
 from iris.rpc.proto_utils import (
     PRIORITY_BAND_NAMES,
-    format_resources,
     job_state_friendly,
     priority_band_value,
     task_state_friendly,
@@ -1008,31 +1007,30 @@ def list_jobs(ctx, state: str | None, prefix: str | None, json_output: bool) ->
         click.echo("No jobs found.")
         return
 
-    # Build table rows
+    # Build table rows. The ListJobs response no longer includes the resource
+    # spec (it required a per-row proto decode in service of a CLI column most
+    # users skim past); call ``iris job status <id>`` for a single job's
+    # resources.
     rows: list[list[str]] = []
     has_reasons = False
 
     for j in jobs:
         job_id = j.job_id
         state_name = job_state_friendly(j.state)
         submitted = timestamp_from_proto(j.submitted_at).as_formatted_date() if j.submitted_at.epoch_ms else "-"
-        resources = format_resources(j.resources) if j.HasField("resources") else "-"
 
-        # Show error for failed jobs, pending_reason for pending/unschedulable
        reason = j.error or j.pending_reason or ""
        if reason:
            has_reasons = True
-            # Truncate long reasons
            reason = (reason[:60] + "...") if len(reason) > 63 else reason
 
-        rows.append([job_id, state_name, resources, submitted, reason])
+        rows.append([job_id, state_name, submitted, reason])
 
-    # Build headers - only include REASON column if there are any reasons
     if has_reasons:
-        headers = ["JOB ID", "STATE", "RESOURCES", "SUBMITTED", "REASON"]
+        headers = ["JOB ID", "STATE", "SUBMITTED", "REASON"]
     else:
-        headers = ["JOB ID", "STATE", "RESOURCES", "SUBMITTED"]
-        rows = [row[:4] for row in rows]
+        headers = ["JOB ID", "STATE", "SUBMITTED"]
+        rows = [row[:3] for row in rows]
 
     click.echo(tabulate(rows, headers=headers, tablefmt="plain"))

lib/iris/src/iris/cli/main.py

Lines changed: 8 additions & 1 deletion
@@ -18,6 +18,7 @@
 from iris.rpc import config_pb2, job_pb2
 from iris.rpc import controller_pb2 as _controller_pb2
 from iris.rpc.auth import AuthTokenInjector, GcpAccessTokenProvider, StaticTokenProvider, TokenProvider
+from iris.rpc.compression import IRIS_RPC_COMPRESSIONS
 from iris.rpc.controller_connect import ControllerServiceClientSync
 from iris.rpc.proto_utils import PRIORITY_BAND_NAMES, priority_band_name, priority_band_value
 
@@ -124,7 +125,13 @@ def rpc_client(
 ) -> ControllerServiceClientSync:
     """Create an RPC client with optional auth. Use as a context manager: ``with rpc_client(url) as c:``."""
     interceptors = [AuthTokenInjector(token_provider)] if token_provider else []
-    return ControllerServiceClientSync(address, timeout_ms=timeout_ms, interceptors=interceptors)
+    return ControllerServiceClientSync(
+        address,
+        timeout_ms=timeout_ms,
+        interceptors=interceptors,
+        accept_compression=IRIS_RPC_COMPRESSIONS,
+        send_compression=IRIS_RPC_COMPRESSIONS[0],
+    )
 
 
 def require_controller_url(ctx: click.Context) -> str:

lib/iris/src/iris/client/resolver.py

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,7 @@
 from iris.actor.resolver import ResolvedEndpoint, ResolveResult
 from iris.cluster.types import Namespace
 from iris.rpc import controller_pb2
+from iris.rpc.compression import IRIS_RPC_COMPRESSIONS
 from iris.rpc.controller_connect import ControllerServiceClientSync
 
 
@@ -54,6 +55,8 @@ def __init__(
         self._client = ControllerServiceClientSync(
             address=self._address,
             timeout_ms=int(timeout * 1000),
+            accept_compression=IRIS_RPC_COMPRESSIONS,
+            send_compression=IRIS_RPC_COMPRESSIONS[0],
         )
 
     def _namespace_prefix(self) -> str:

lib/iris/src/iris/cluster/client/remote_client.py

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,7 @@
 from iris.cluster.runtime.entrypoint import build_runtime_entrypoint
 from iris.cluster.types import Entrypoint, EnvironmentSpec, JobName, TaskAttempt, adjust_tpu_replicas, is_job_finished
 from iris.rpc import controller_pb2, job_pb2
+from iris.rpc.compression import IRIS_RPC_COMPRESSIONS
 from iris.rpc.controller_connect import ControllerServiceClientSync
 from iris.rpc.errors import call_with_retry, format_connect_error, poll_with_retries
 from iris.time_proto import duration_to_proto
@@ -78,6 +79,8 @@ def __init__(
             address=controller_address,
             timeout_ms=timeout_ms,
             interceptors=interceptors,
+            accept_compression=IRIS_RPC_COMPRESSIONS,
+            send_compression=IRIS_RPC_COMPRESSIONS[0],
         )
         self._log_client = LogClient.connect(
             controller_address,

lib/iris/src/iris/cluster/controller/autoscaler/recovery.py

Lines changed: 3 additions & 1 deletion
@@ -57,8 +57,10 @@ def load_autoscaler_checkpoint(db: ControllerDB) -> AutoscalerCheckpoint:
             "last_active_ms": decode_timestamp_ms,
         },
     )
+    # Failed workers have their DB row deleted (WorkerStore.remove), so
+    # surviving rows with a slice are by definition the live tracked set.
     tracked_rows = snapshot.raw(
-        "SELECT worker_id, slice_id, scale_group, address FROM workers WHERE slice_id != '' AND active = 1",
+        "SELECT worker_id, slice_id, scale_group, address FROM workers WHERE slice_id != ''",
     )
 
     slices_by_group: dict[str, list[SliceSnapshot]] = {}

lib/iris/src/iris/cluster/controller/controller.py

Lines changed: 14 additions & 22 deletions
@@ -114,7 +114,7 @@
     TaskUpdate,
     log_event,
 )
-from iris.cluster.controller.worker_health import WorkerHealthTracker
+from iris.cluster.controller.worker_health import WorkerCommitTracker, WorkerHealthTracker
 from iris.cluster.log_store_helpers import CONTROLLER_LOG_KEY
 from iris.cluster.providers.k8s.tasks import K8sTaskProvider
 from iris.cluster.providers.types import find_free_port, resolve_external_host
@@ -881,6 +881,8 @@ def _reservation_region_constraints(
     job_id_wire: str,
     claims: dict[WorkerId, ReservationClaim],
     queries: ControllerDB,
+    health: WorkerHealthTracker,
+    committed: WorkerCommitTracker,
     existing_constraints: list[Constraint],
 ) -> list[Constraint]:
     """Derive region constraints from claimed reservation workers.
@@ -897,7 +899,7 @@
     claimed_worker_ids = {worker_id for worker_id, claim in claims.items() if claim.job_id == job_id_wire}
     workers_by_id = {
         worker.worker_id: worker
-        for worker in healthy_active_workers_with_attributes(queries)
+        for worker in healthy_active_workers_with_attributes(queries, health, committed)
         if worker.worker_id in claimed_worker_ids
     }
     regions: set[str] = set()
@@ -1153,7 +1155,8 @@ def __init__(
             self._db = db
         else:
             self._db = ControllerDB(db_dir=config.local_state_dir / "db")
-        self._store = ControllerStore(self._db)
+        self._health = WorkerHealthTracker()
+        self._store = ControllerStore(self._db, health=self._health)
 
         # ThreadContainer must be initialized before the log service setup
         # because _start_local_log_server spawns a uvicorn thread.
@@ -1194,7 +1197,6 @@ def __init__(
         self._log_handler.setFormatter(logging.Formatter("%(asctime)s %(name)s %(message)s"))
         logging.getLogger("iris").addHandler(self._log_handler)
 
-        self._health = WorkerHealthTracker()
         self._transitions = ControllerTransitions(
             store=self._store,
             health=self._health,
@@ -1630,7 +1632,7 @@ def _profile_all_running_tasks(self) -> None:
         Memory profiling via memray is currently disabled because memray attach
         has been triggering segfaults in target processes.
         """
-        workers = healthy_active_workers_with_attributes(self._db)
+        workers = healthy_active_workers_with_attributes(self._db, self._health, self._store.committed)
         if not workers:
             return
         workers_by_id = {w.worker_id: w for w in workers}
@@ -1742,11 +1744,7 @@ def _cleanup_stale_claims(self, claims: dict[WorkerId, ReservationClaim] | None
         if claims is None:
             claims = _read_reservation_claims(self._db)
             persisted = True
-        with self._db.read_snapshot() as snapshot:
-            active_worker_ids = {
-                WorkerId(str(row[0]))
-                for row in snapshot.fetchall("SELECT w.worker_id FROM workers w WHERE w.active = 1")
-            }
+        active_worker_ids = {wid for wid, l in self._health.all().items() if l.active}
         claimed_job_ids = {JobName.from_wire(claim.job_id) for claim in claims.values()}
         claimed_jobs = list(_jobs_by_id(self._db, claimed_job_ids).values()) if claimed_job_ids else []
         jobs_by_id = {job.job_id.to_wire(): job for job in claimed_jobs}
@@ -1778,7 +1776,7 @@ def _claim_workers_for_reservations(self, claims: dict[WorkerId, ReservationClai
             persisted = True
         claimed_entries: set[tuple[str, int]] = {(c.job_id, c.entry_idx) for c in claims.values()}
         claimed_worker_ids: set[WorkerId] = set(claims.keys())
-        all_workers = healthy_active_workers_with_attributes(self._db)
+        all_workers = healthy_active_workers_with_attributes(self._db, self._health, self._store.committed)
         changed = False
 
         reservable_states = (
@@ -1916,7 +1914,7 @@ def _read_scheduling_state(self) -> _SchedulingStateRead:
         timer = Timer()
         with slow_log(logger, "scheduling state reads", threshold_ms=50):
             pending_tasks = _schedulable_tasks(self._db)
-            workers = healthy_active_workers_with_attributes(self._db)
+            workers = healthy_active_workers_with_attributes(self._db, self._health, self._store.committed)
         return _SchedulingStateRead(
             pending_tasks=pending_tasks,
             workers=workers,
@@ -2378,7 +2376,7 @@ def _stop_tasks_direct(
 
     def _get_active_worker_addresses(self) -> list[tuple[WorkerId, str | None]]:
         """Get healthy active workers as (worker_id, address) tuples for ping."""
-        workers = healthy_active_workers_with_attributes(self._db)
+        workers = healthy_active_workers_with_attributes(self._db, self._health, self._store.committed)
         return [(w.worker_id, w.address) for w in workers]
 
     def _run_ping_loop(self, stop_event: threading.Event) -> None:
@@ -2406,8 +2404,7 @@ def _run_ping_loop(self, stop_event: threading.Event) -> None:
                 self._health.ping(result.worker_id, healthy=True)
                 live_worker_ids.append(result.worker_id)
 
-        with self._store.transaction() as cur:
-            self._transitions.update_worker_pings(cur, live_worker_ids)
+        self._transitions.update_worker_pings(live_worker_ids)
 
         unhealthy = self._health.workers_over_threshold()
         if unhealthy:
@@ -2534,7 +2531,7 @@ def _run_autoscaler_once(self) -> None:
 
         worker_status_map = self._build_worker_status_map()
         self._autoscaler.refresh(worker_status_map)
-        workers = healthy_active_workers_with_attributes(self._db)
+        workers = healthy_active_workers_with_attributes(self._db, self._health, self._store.committed)
         demand_entries = compute_demand_entries(
             self._db,
             self._scheduler,
@@ -2546,12 +2543,7 @@ def _run_autoscaler_once(self) -> None:
     def _build_worker_status_map(self) -> WorkerStatusMap:
         """Build a map of worker_id to worker status for autoscaler idle tracking."""
         result: WorkerStatusMap = {}
-        with self._db.read_snapshot() as snapshot:
-            rows = snapshot.raw(
-                "SELECT worker_id FROM workers WHERE active = 1",
-                decoders={"worker_id": WorkerId},
-            )
-            worker_ids = {row.worker_id for row in rows}
+        worker_ids = {wid for wid, l in self._health.all().items() if l.active}
         running_by_worker = running_tasks_by_worker(self._db, worker_ids)
         for wid in worker_ids:
             result[wid] = WorkerStatus(

lib/iris/src/iris/cluster/controller/db.py

Lines changed: 39 additions & 13 deletions
@@ -20,6 +20,7 @@
 
 from iris.cluster.constraints import AttributeValue
 from iris.cluster.controller.schema import decode_timestamp_ms, decode_worker_id
+from iris.cluster.controller.worker_health import WorkerCommitTracker, WorkerHealthTracker
 from iris.cluster.types import TERMINAL_TASK_STATES, JobName, WorkerId
 from iris.rpc import job_pb2
 
@@ -919,32 +920,57 @@ def _worker_row_select() -> str:
     return WORKER_ROW_PROJECTION.select_clause()
 
 
-def healthy_active_workers_with_attributes(db: ControllerDB) -> list:
+def healthy_active_workers_with_attributes(
+    db: ControllerDB,
+    health: WorkerHealthTracker,
+    committed: WorkerCommitTracker,
+) -> list:
     """Fetch all healthy, active workers with their attributes populated.
 
     Returns WorkerRow (scalar-only) so the scheduling loop avoids loading metadata columns.
-    Uses the in-memory attribute cache to avoid a per-cycle SQL join.
+    Health/active filtering reads the in-memory tracker; committed-resource
+    arithmetic reads the in-memory commit tracker.
     """
     from iris.cluster.controller.schema import WORKER_ROW_PROJECTION
 
+    liveness = health.all()
+    healthy_active = {wid for wid, l in liveness.items() if l.healthy and l.active}
+    if not healthy_active:
+        return []
+    placeholders = ",".join("?" for _ in healthy_active)
     with db.read_snapshot() as q:
         workers = WORKER_ROW_PROJECTION.decode(
-            q.fetchall(f"SELECT {_worker_row_select()} FROM workers w WHERE w.healthy = 1 AND w.active = 1"),
+            q.fetchall(
+                f"SELECT {_worker_row_select()} FROM workers w WHERE w.worker_id IN ({placeholders})",
+                tuple(str(wid) for wid in healthy_active),
+            ),
         )
     if not workers:
         return []
     attrs_by_worker = db.get_worker_attributes()
-    return [
-        dc_replace(
-            w,
-            attributes=attrs_by_worker.get(w.worker_id, {}),
-            available_cpu_millicores=w.total_cpu_millicores - w.committed_cpu_millicores,
-            available_memory=w.total_memory_bytes - w.committed_mem,
-            available_gpus=w.total_gpu_count - w.committed_gpu,
-            available_tpus=w.total_tpu_count - w.committed_tpu,
+    hydrated = []
+    for w in workers:
+        commit = committed.get(w.worker_id)
+        l = liveness.get(w.worker_id)
+        hydrated.append(
+            dc_replace(
+                w,
+                healthy=True,
+                active=True,
+                consecutive_failures=l.consecutive_ping_failures if l is not None else 0,
+                last_heartbeat=Timestamp.from_ms(l.last_heartbeat_ms) if l is not None else w.last_heartbeat,
+                committed_cpu_millicores=commit.cpu_millicores,
+                committed_mem=commit.memory_bytes,
+                committed_gpu=commit.gpu,
+                committed_tpu=commit.tpu,
+                attributes=attrs_by_worker.get(w.worker_id, {}),
+                available_cpu_millicores=w.total_cpu_millicores - commit.cpu_millicores,
+                available_memory=w.total_memory_bytes - commit.memory_bytes,
+                available_gpus=w.total_gpu_count - commit.gpu,
+                available_tpus=w.total_tpu_count - commit.tpu,
+            )
        )
-        for w in workers
-    ]
+    return hydrated
 
 
 def insert_task_profile(
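The hydration loop above reads committed resources off a tracker entry with cpu_millicores / memory_bytes / gpu / tpu fields. A minimal sketch of the interface those call sites imply (the record's class name is invented here, since the diff never names it, and the real WorkerCommitTracker in worker_health.py is not shown):

from dataclasses import dataclass


@dataclass(frozen=True)
class CommittedResources:
    """Hypothetical record; the diff only ever reads these four fields."""

    cpu_millicores: int = 0
    memory_bytes: int = 0
    gpu: int = 0
    tpu: int = 0


class WorkerCommitTracker:
    """In-memory committed-resource ledger replacing the dropped committed_* columns."""

    def __init__(self) -> None:
        self._committed: dict[str, CommittedResources] = {}

    def get(self, worker_id: str) -> CommittedResources:
        # Workers with nothing scheduled read as all-zero commitments,
        # so available == total for them.
        return self._committed.get(worker_id, CommittedResources())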

lib/iris/src/iris/cluster/controller/migrations/0004_worker_indexes.py

Lines changed: 10 additions & 1 deletion
@@ -4,10 +4,19 @@
 import sqlite3
 
 
+def _has_column(conn: sqlite3.Connection, table: str, column: str) -> bool:
+    return column in {row[1] for row in conn.execute(f"PRAGMA table_info({table})").fetchall()}
+
+
 def migrate(conn: sqlite3.Connection) -> None:
     # Originally this migration also rewrote the `trg_txn_log_retention`
     # trigger; those statements were removed once migration 0037 dropped the
     # `txn_log` / `txn_actions` tables entirely. On DBs that already ran the
     # old form the trigger survives until 0037 executes; 0037 is idempotent
     # (`DROP TRIGGER IF EXISTS`) so no fixup is needed here.
-    conn.execute("CREATE INDEX IF NOT EXISTS idx_workers_healthy_active ON workers(healthy, active)")
+    #
+    # ``healthy`` / ``active`` were workers columns when this migration was
+    # authored. They are dropped in 0042; on a fresh DB the columns are absent
+    # at this point so the index is a no-op.
+    if _has_column(conn, "workers", "healthy") and _has_column(conn, "workers", "active"):
+        conn.execute("CREATE INDEX IF NOT EXISTS idx_workers_healthy_active ON workers(healthy, active)")
