Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions lib/iris/scripts/benchmark_db_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
_worker_roster,
)
from iris.cluster.controller.schema import EndpointRow
from iris.cluster.controller.stores import ControllerStore
from iris.cluster.controller.transitions import (
Assignment,
ControllerTransitions,
Expand Down Expand Up @@ -252,7 +253,8 @@ def benchmark_scheduling(db: ControllerDB) -> None:

# --- Write-path benchmarks (use a lightweight clone) ---
write_db = clone_db(db)
write_txns = ControllerTransitions(write_db)
write_store = ControllerStore(write_db)
write_txns = ControllerTransitions(store=write_store)

try:
# queue_assignments: the main write-lock holder in scheduling.
Expand Down Expand Up @@ -542,7 +544,7 @@ def _all_workers_running_tasks():

bench("running_tasks_by_worker", lambda: running_tasks_by_worker(db, worker_ids))

transitions = ControllerTransitions(db)
transitions = ControllerTransitions(store=ControllerStore(db))
bench(
f"get_running_tasks_for_poll ({len(workers)} workers)",
lambda: transitions.get_running_tasks_for_poll(),
Expand Down Expand Up @@ -594,7 +596,7 @@ def _all_workers_running_tasks():
)

hb_db = clone_db(db)
hb_transitions = ControllerTransitions(hb_db)
hb_transitions = ControllerTransitions(store=ControllerStore(hb_db))

try:
bench(
Expand Down Expand Up @@ -739,14 +741,16 @@ def benchmark_endpoints(db: ControllerDB) -> None:
contention (matches the production Register p95 of 3-4s)
"""
# Read-path queries run against the source DB (cheap, no clone needed).
bench("endpoint_registry.query (all)", lambda: db.endpoints.query())
read_store = ControllerStore(db)
bench("endpoint_store.query (all)", lambda: read_store.endpoints.query())
bench(
"endpoint_registry.query (prefix)",
lambda: db.endpoints.query(EndpointQuery(name_prefix="test")),
"endpoint_store.query (prefix)",
lambda: read_store.endpoints.query(EndpointQuery(name_prefix="test")),
)

write_db = clone_db(db)
write_txns = ControllerTransitions(write_db)
write_store = ControllerStore(write_db)
write_txns = ControllerTransitions(store=write_store)

try:
sample = _active_task_sample(write_db, limit=300)
Expand All @@ -762,7 +766,7 @@ def _do_single():

def _reset_single():
write_db.execute("DELETE FROM endpoints WHERE name LIKE '/bench/endpoint/%'")
write_db.endpoints._load_all()
write_store.endpoints._load_all()

bench("add_endpoint (1 write)", _do_single, reset=_reset_single)

Expand Down Expand Up @@ -793,7 +797,7 @@ def _do_burst_per_txn(tasks=tasks_for_burst):
def _do_burst_one_txn(tasks=tasks_for_burst):
with write_db.transaction() as cur:
for t in tasks:
write_db.endpoints.add(cur, _make_endpoint(t))
write_store.endpoints.add(cur, _make_endpoint(t))

bench(
f"add_endpoint burst x{burst_n} (1 txn)",
Expand Down Expand Up @@ -1365,7 +1369,7 @@ def benchmark_apply_contention(db: ControllerDB) -> None:
]

write_db = clone_db(db)
write_txns = ControllerTransitions(write_db)
write_txns = ControllerTransitions(store=ControllerStore(write_db))
try:
for scenario in scenarios:
_run_apply_under_contention(
Expand Down
12 changes: 6 additions & 6 deletions lib/iris/src/iris/cluster/controller/actor_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

from iris.cluster.controller.db import ControllerDB
from iris.cluster.controller.stores import ControllerStore

logger = logging.getLogger(__name__)

Expand All @@ -46,10 +46,10 @@


class ActorProxy:
"""Forwards ActorService RPCs to actors resolved from the endpoint registry."""
"""Forwards ActorService RPCs to actors resolved from the endpoint store."""

def __init__(self, db: ControllerDB):
self._db = db
def __init__(self, store: ControllerStore):
self._store = store
self._client = httpx.AsyncClient(timeout=PROXY_TIMEOUT_SECONDS)

async def close(self) -> None:
Expand Down Expand Up @@ -97,8 +97,8 @@ async def handle(self, request: Request) -> Response:
)

def _resolve_endpoint(self, name: str) -> str | None:
"""Resolve an endpoint name to an address via the in-memory registry."""
row = self._db.endpoints.resolve(name)
"""Resolve an endpoint name to an address via the in-memory store."""
row = self._store.endpoints.resolve(name)
if row is None:
return None
return row.address
6 changes: 4 additions & 2 deletions lib/iris/src/iris/cluster/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
)
from iris.cluster.controller.auth import ControllerAuth
from iris.cluster.controller.service import ControllerServiceImpl
from iris.cluster.controller.stores import ControllerStore
from iris.cluster.controller.transitions import (
RESERVATION_HOLDER_JOB_NAME,
Assignment,
Expand Down Expand Up @@ -1038,6 +1039,7 @@ def __init__(
self._db = db
else:
self._db = ControllerDB(db_dir=config.local_state_dir / "db")
self._store = ControllerStore(self._db)

# ThreadContainer must be initialized before the log service setup
# because _start_local_log_server spawns a uvicorn thread.
Expand Down Expand Up @@ -1075,7 +1077,7 @@ def __init__(

self._health = WorkerHealthTracker()
self._transitions = ControllerTransitions(
db=self._db,
store=self._store,
health=self._health,
)
self._scheduler = Scheduler()
Expand All @@ -1084,7 +1086,7 @@ def __init__(

self._service = ControllerServiceImpl(
self._transitions,
self._db,
self._store,
controller=self,
bundle_store=self._bundle_store,
log_service=self._remote_log_service,
Expand Down
2 changes: 1 addition & 1 deletion lib/iris/src/iris/cluster/controller/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def _create_app(self) -> ASGIApp:

rpc_app = WSGIMiddleware(rpc_wsgi_app)

self._actor_proxy = ActorProxy(self._service._db)
self._actor_proxy = ActorProxy(self._service._store)

@requires_auth
async def _proxy_actor_rpc(request: Request) -> Response:
Expand Down
34 changes: 19 additions & 15 deletions lib/iris/src/iris/cluster/controller/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ class TransactionCursor:

Post-commit hooks registered via :meth:`on_commit` run after the wrapping
``ControllerDB.transaction()`` block commits successfully. They are used
by caches (e.g. ``EndpointRegistry``) to update in-memory state atomically
by caches (e.g. ``EndpointStore``) to update in-memory state atomically
with the DB write: rollback suppresses the hook so memory never drifts
from disk.
"""
Expand All @@ -258,6 +258,14 @@ def executescript(self, sql: str) -> sqlite3.Cursor:
"""Raw SQL script escape hatch."""
return self._cursor.executescript(sql)

def fetchall(self, sql: str, params: tuple = ()) -> list[sqlite3.Row]:
    """Run ``sql`` with ``params`` on the underlying cursor and return every row.

    Offers the same call shape as :meth:`QuerySnapshot.fetchall`. The result
    is always a new ``list``, which is empty when the statement yields no rows.
    """
    executed = self._cursor.execute(sql, params)
    rows = executed.fetchall()
    return list(rows)

def fetchone(self, sql: str, params: tuple = ()) -> sqlite3.Row | None:
"""Execute ``sql`` and return the first row, or None. Mirrors :meth:`QuerySnapshot.fetchone`."""
return self._cursor.execute(sql, params).fetchone()

def on_commit(self, hook: Callable[[], None]) -> None:
"""Register ``hook`` to run after the transaction commits successfully."""
self._commit_hooks.append(hook)
Expand Down Expand Up @@ -321,19 +329,14 @@ def __init__(self, db_dir: Path):
self._attr_cache: dict[WorkerId, dict[str, AttributeValue]] | None = None
self._attr_cache_lock = Lock()

# Write-through in-memory cache over the ``endpoints`` table. Imported
# locally to break the ``db -> endpoint_registry -> db`` import cycle;
# this is the single exception to "no local imports" (see AGENTS.md).
from iris.cluster.controller.endpoint_registry import EndpointRegistry
# Callables invoked at the end of ``replace_from`` so callers with
# caches over DB contents (e.g. ``ControllerStore``) can reload them
# after a checkpoint restore. Registered via ``register_reopen_hook``.
self._reopen_hooks: list[Callable[[], None]] = []

t0 = time.monotonic()
self._endpoint_registry = EndpointRegistry(self)
logger.info("EndpointRegistry initialized in %.2fs", time.monotonic() - t0)

@property
def endpoints(self) -> EndpointRegistry: # noqa: F821
"""Process-local cache for the ``endpoints`` table; authoritative for reads."""
return self._endpoint_registry
def register_reopen_hook(self, hook: Callable[[], None]) -> None:
    """Queue ``hook``, a zero-argument callable, to be invoked when ``replace_from`` finishes.

    Hooks are kept in registration order; the same callable may be
    registered more than once.
    """
    pending = self._reopen_hooks
    pending.append(hook)

def _populate_attr_cache(self) -> dict[WorkerId, dict[str, AttributeValue]]:
"""Load all worker attributes from the DB into the cache.
Expand Down Expand Up @@ -454,7 +457,7 @@ def transaction(self):

On successful commit, any hooks registered via ``TransactionCursor.on_commit``
fire while the write lock is still held — keeping in-memory caches
(e.g. ``EndpointRegistry``) in sync with the DB without exposing a
(e.g. ``EndpointStore``) in sync with the DB without exposing a
torn snapshot to concurrent readers.
"""
with self._lock:
Expand Down Expand Up @@ -751,7 +754,8 @@ def replace_from(self, source_dir: str | Path) -> None:
self._conn.execute("ATTACH DATABASE ? AS profiles", (str(self._profiles_db_path),))
self._init_read_pool()
self.apply_migrations()
self._endpoint_registry._load_all()
for hook in self._reopen_hooks:
hook()

# SQL-canonical read access is exposed through ``snapshot()`` and typed table
# metadata at module scope. Legacy list/get/count helper methods were removed
Expand Down
Loading
Loading