Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 56 additions & 29 deletions lib/iris/scripts/benchmark_db_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,14 @@
_read_reservation_claims,
_schedulable_tasks,
)
from iris.cluster.controller.db import (
ACTIVE_TASK_STATES,
ControllerDB,
EndpointQuery,
healthy_active_workers_with_attributes,
running_tasks_by_worker,
tasks_for_job_with_attempts,
)
from iris.cluster.controller.db import ControllerDB
from iris.cluster.controller.store import ControllerStores, EndpointQuery
from iris.cluster.controller.schema import (
ACTIVE_TASK_STATES,
JOB_CONFIG_JOIN,
JOB_DETAIL_PROJECTION,
)
from iris.cluster.types import TERMINAL_JOB_STATES
from iris.cluster.controller.service import (
USER_JOB_STATES,
_descendant_jobs,
Expand All @@ -80,7 +76,7 @@
ReservationClaim,
TaskUpdate,
)
from iris.cluster.types import TERMINAL_JOB_STATES, JobName, WorkerId
from iris.cluster.types import JobName, WorkerId
from iris.rpc import job_pb2
from iris.rpc import controller_pb2
from rigging.timing import Timestamp
Expand Down Expand Up @@ -195,6 +191,7 @@ def bench(

def benchmark_scheduling(db: ControllerDB) -> None:
"""Benchmark scheduling-loop queries."""
stores = ControllerStores.from_db(db)
# Create pending work so scheduling queries have realistic load.
# Pick up to 50 running jobs and revert their first few tasks to PENDING.
with db.read_snapshot() as snap:
Expand All @@ -217,18 +214,25 @@ def benchmark_scheduling(db: ControllerDB) -> None:

bench("_schedulable_tasks", lambda: _schedulable_tasks(db))

bench(
"healthy_active_workers_with_attributes",
lambda: healthy_active_workers_with_attributes(db),
)
def _bench_healthy_active():
with stores.read() as ctx:
stores.workers.healthy_active_with_attributes(ctx.cur)

workers = healthy_active_workers_with_attributes(db)
bench("healthy_active_workers_with_attributes", _bench_healthy_active)

with stores.read() as ctx:
workers = stores.workers.healthy_active_with_attributes(ctx.cur)
bench("_building_counts", lambda: _building_counts(db, workers))

tasks = _schedulable_tasks(db)
job_ids = {t.job_id for t in tasks}

def _bench_jobs_by_id():
with stores.read() as ctx:
_jobs_by_id(stores, ctx.cur, job_ids)

if job_ids:
bench("_jobs_by_id", lambda: _jobs_by_id(db, job_ids))
bench("_jobs_by_id", _bench_jobs_by_id)
else:
print(" _jobs_by_id (skipped, no pending jobs)")

Expand All @@ -255,7 +259,8 @@ def benchmark_scheduling(db: ControllerDB) -> None:

# --- Write-path benchmarks (use a lightweight clone) ---
write_db = clone_db(db)
write_txns = ControllerTransitions(write_db)
write_stores = ControllerStores.from_db(write_db)
write_txns = ControllerTransitions(stores=write_stores)

try:
# queue_assignments: the main write-lock holder in scheduling.
Expand Down Expand Up @@ -385,6 +390,7 @@ def _reset_prune():

def benchmark_dashboard(db: ControllerDB) -> None:
"""Benchmark dashboard/service queries."""
stores = ControllerStores.from_db(db)

def _bench_jobs_in_states(db):
placeholders = ",".join("?" for _ in USER_JOB_STATES)
Expand Down Expand Up @@ -445,10 +451,16 @@ def _bench_jobs_in_states(db):

bench("_worker_roster", lambda: _worker_roster(db))

workers = healthy_active_workers_with_attributes(db)
with stores.read() as ctx:
workers = stores.workers.healthy_active_with_attributes(ctx.cur)
worker_ids = {w.worker_id for w in workers}

def _bench_running_by_worker():
with stores.read() as ctx:
stores.tasks.running_tasks_by_worker(ctx.cur, worker_ids)

if worker_ids:
bench("running_tasks_by_worker", lambda: running_tasks_by_worker(db, worker_ids))
bench("running_tasks_by_worker", _bench_running_by_worker)
else:
print(" running_tasks_by_worker (skipped, no workers)")

Expand Down Expand Up @@ -479,11 +491,12 @@ def _bench_jobs_in_states(db):
if sample_job:
bench("_read_job", lambda: _read_job(db, sample_job.job_id))

def _bench_tasks_for_job():
with stores.read() as ctx:
stores.tasks.tasks_for_job_with_attempts(ctx.cur, sample_job.job_id)

if sample_job:
bench(
"tasks_for_job_with_attempts",
lambda: tasks_for_job_with_attempts(db, sample_job.job_id),
)
bench("tasks_for_job_with_attempts", _bench_tasks_for_job)

if sample_job:
sample_tasks_for_read = _tasks_for_listing(db, job_id=sample_job.job_id)
Expand All @@ -509,7 +522,9 @@ def _list_jobs_full(db):

def benchmark_heartbeat(db: ControllerDB) -> None:
"""Benchmark heartbeat/provider-sync queries."""
workers = healthy_active_workers_with_attributes(db)
stores = ControllerStores.from_db(db)
with stores.read() as ctx:
workers = stores.workers.healthy_active_with_attributes(ctx.cur)
worker_ids = {w.worker_id for w in workers}

if not workers:
Expand Down Expand Up @@ -543,9 +558,13 @@ def _all_workers_running_tasks():

bench(f"drain_dispatch ({len(workers)} workers)", _all_workers_running_tasks)

bench("running_tasks_by_worker", lambda: running_tasks_by_worker(db, worker_ids))
def _bench_running_by_worker_heartbeat():
with stores.read() as ctx:
stores.tasks.running_tasks_by_worker(ctx.cur, worker_ids)

bench("running_tasks_by_worker", _bench_running_by_worker_heartbeat)

transitions = ControllerTransitions(db)
transitions = ControllerTransitions(stores=stores)
bench(
f"drain_dispatch_all ({len(workers)} workers)",
lambda: transitions.drain_dispatch_all(),
Expand Down Expand Up @@ -597,7 +616,8 @@ def _all_workers_running_tasks():
)

hb_db = clone_db(db)
hb_transitions = ControllerTransitions(hb_db)
hb_stores = ControllerStores.from_db(hb_db)
hb_transitions = ControllerTransitions(stores=hb_stores)

try:
bench(
Expand Down Expand Up @@ -628,6 +648,13 @@ def _all_workers_running_tasks():
shutil.rmtree(hb_db._db_dir, ignore_errors=True)


def _healthy_workers(db: ControllerDB) -> list[Any]:
    """Bench-local shim over the post-refactor store API.

    Opens a short-lived read context on a fresh ``ControllerStores`` built
    from *db* and materializes the healthy/active worker rows as a list.
    """
    store_bundle = ControllerStores.from_db(db)
    with store_bundle.read() as read_ctx:
        rows = store_bundle.workers.healthy_active_with_attributes(read_ctx.cur)
        return list(rows)


def _active_task_sample(db: ControllerDB, limit: int) -> list[tuple[JobName, int]]:
"""Return up to ``limit`` (task_id, current_attempt_id) pairs for non-terminal tasks.

Expand Down Expand Up @@ -1006,7 +1033,7 @@ def _reset_fail(saved_w=saved_workers, saved_t=saved_tasks):
# Contention: run an add_endpoint burst concurrently with an
# apply_heartbeats_batch call on two Python threads sharing the clone
# DB. SQLite serializes writers, so this measures write-lock wait.
workers = healthy_active_workers_with_attributes(write_db)
workers = _healthy_workers(write_db)
if workers and len(sample) >= 200:
active_states = tuple(ACTIVE_TASK_STATES)
running_by_worker: dict[str, list[tuple[str, int]]] = {}
Expand Down Expand Up @@ -1164,7 +1191,7 @@ def _burst_100():

# (c) burst 100 under concurrent apply_heartbeats_batch contention.
active_states = tuple(ACTIVE_TASK_STATES)
workers = healthy_active_workers_with_attributes(write_db)
workers = _healthy_workers(write_db)
running_tasks_per_worker: dict[str, list[tuple[str, int]]] = {}
for w in workers:
wid = str(w.worker_id)
Expand Down Expand Up @@ -1231,7 +1258,7 @@ def _build_heartbeat_requests(db: ControllerDB) -> list[HeartbeatApplyRequest]:
one HeartbeatApplyRequest per active worker, with one RUNNING
resource-usage update per task currently assigned to that worker.
"""
workers = healthy_active_workers_with_attributes(db)
workers = _healthy_workers(db)
active_states = tuple(ACTIVE_TASK_STATES)
snapshot_proto = job_pb2.WorkerResourceSnapshot()
usage = job_pb2.ResourceUsage(cpu_millicores=1000, memory_mb=1024)
Expand Down
8 changes: 4 additions & 4 deletions lib/iris/src/iris/cluster/controller/actor_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

from iris.cluster.controller.db import ControllerDB
from iris.cluster.controller.store import ControllerStores

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -48,8 +48,8 @@
class ActorProxy:
"""Forwards ActorService RPCs to actors resolved from the endpoint registry."""

def __init__(self, db: ControllerDB):
self._db = db
def __init__(self, stores: ControllerStores):
self._stores = stores
self._client = httpx.AsyncClient(timeout=PROXY_TIMEOUT_SECONDS)

async def close(self) -> None:
Expand Down Expand Up @@ -98,7 +98,7 @@ async def handle(self, request: Request) -> Response:

def _resolve_endpoint(self, name: str) -> str | None:
"""Resolve an endpoint name to an address via the in-memory registry."""
row = self._db.endpoints.resolve(name)
row = self._stores.endpoints.resolve(name)
if row is None:
return None
return row.address
19 changes: 15 additions & 4 deletions lib/iris/src/iris/cluster/controller/budget.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,28 @@

from collections import defaultdict
from dataclasses import dataclass
from typing import Generic, TypeVar
from typing import Any, Generic, Protocol, TypeVar
from collections.abc import Callable

import json

from iris.cluster.controller.db import ACTIVE_TASK_STATES, QuerySnapshot
from iris.cluster.controller.schema import ACTIVE_TASK_STATES
from iris.cluster.types import JobName
from iris.rpc import job_pb2

T = TypeVar("T")


class SnapshotReader(Protocol):
    """Interface budget functions need from QuerySnapshot.

    Structural protocol: any object exposing a compatible ``raw`` method
    satisfies it, which decouples ``compute_user_spend`` from the concrete
    ``QuerySnapshot`` class after the db/store refactor.
    """

    # NOTE(review): signature mirrors QuerySnapshot.raw — keep the two in
    # sync if the concrete method changes. The ``= ...`` defaults mark that
    # implementations provide a default; the Ellipsis itself is never passed.
    def raw(
        self,
        sql: str,
        params: tuple = ...,
        # decoders presumably maps column name -> value decoder; confirm
        # against QuerySnapshot.raw before relying on key semantics.
        decoders: dict[str, Callable] | None = None,
    ) -> list[Any]: ...


def _accel_from_device_json(device_json: str | None) -> int:
"""Count GPU + TPU accelerators from a device JSON column."""
if not device_json:
Expand Down Expand Up @@ -62,7 +73,7 @@ def resource_value(cpu_millicores: int, memory_bytes: int, accelerator_count: in
return 1000 * accelerator_count + ram_gb + 5 * cpu_cores


def compute_user_spend(snapshot: QuerySnapshot) -> dict[str, int]:
def compute_user_spend(snapshot: SnapshotReader) -> dict[str, int]:
"""Compute per-user budget spend from active tasks.

Joins tasks (in ASSIGNED/BUILDING/RUNNING states) with job_config to get
Expand Down
Loading
Loading