Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lib/iris/src/iris/cluster/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2166,7 +2166,8 @@ def _sync_all_execution_units(self) -> None:
return

# Sync with the execution backend (ThreadPoolExecutor inside provider).
results = self._provider.sync(batches)
with slow_log(logger, "provider sync (RPC dispatch)", threshold_ms=1_000):
results = self._provider.sync(batches)

acc = _SyncFailureAccumulator()
with slow_log(logger, "provider sync (apply results)", threshold_ms=500):
Expand Down
221 changes: 136 additions & 85 deletions lib/iris/src/iris/cluster/controller/worker_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@

"""WorkerProvider: TaskProvider backed by worker daemons via heartbeat RPC."""

import asyncio
import logging
import threading
from concurrent.futures import ThreadPoolExecutor
from collections.abc import Coroutine
from dataclasses import dataclass, field
from time import monotonic, sleep
from typing import Protocol
from time import monotonic
from typing import Any, Protocol, TypeVar

from iris.chaos import chaos
from iris.cluster.controller.provider import ProviderError
Expand All @@ -20,13 +21,15 @@
from iris.cluster.types import JobName, WorkerId
from iris.rpc import job_pb2
from iris.rpc import worker_pb2
from iris.rpc.worker_connect import WorkerServiceClientSync
from iris.rpc.worker_connect import WorkerServiceClient
from rigging.timing import Duration

logger = logging.getLogger(__name__)

DEFAULT_WORKER_RPC_TIMEOUT = Duration.from_seconds(30.0)
_SLOW_HEARTBEAT_RPC_LOG_THRESHOLD_MS = 10_000
DEFAULT_WORKER_RPC_TIMEOUT = Duration.from_seconds(10.0)
_SLOW_HEARTBEAT_RPC_LOG_THRESHOLD_MS = 5_000

T = TypeVar("T")


def _heartbeat_rpc_context(
Expand All @@ -43,32 +46,70 @@ def _heartbeat_rpc_context(
)


class _AsyncLoopThread:
"""Runs an asyncio event loop on a dedicated daemon thread.

Sync callers submit coroutines via `run()` / `submit()`; the loop hosts
the long-lived async httpx clients so their connection pools survive
across heartbeat rounds.
"""

def __init__(self, name: str = "worker-provider-asyncio") -> None:
self._loop = asyncio.new_event_loop()
self._ready = threading.Event()
self._thread = threading.Thread(target=self._run, name=name, daemon=True)
self._thread.start()
self._ready.wait()

def _run(self) -> None:
asyncio.set_event_loop(self._loop)
self._ready.set()
try:
self._loop.run_forever()
finally:
self._loop.close()

@property
def loop(self) -> asyncio.AbstractEventLoop:
return self._loop

def run(self, coro: Coroutine[Any, Any, T], timeout: float | None = None) -> T:
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
return future.result(timeout=timeout)

def close(self) -> None:
if not self._loop.is_running():
return
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join(timeout=5.0)


class WorkerStubFactory(Protocol):
"""Factory for getting cached worker RPC stubs."""
"""Factory for getting cached async worker RPC stubs."""

def get_stub(self, address: str) -> WorkerServiceClientSync: ...
def get_stub(self, address: str) -> WorkerServiceClient: ...
def evict(self, address: str) -> None: ...
def close(self) -> None: ...


class RpcWorkerStubFactory:
"""Caches WorkerServiceClientSync stubs by address so each worker gets
one persistent httpx.Client instead of a new one per RPC."""
"""Caches async WorkerServiceClient stubs by address so each worker gets
one persistent async HTTP client instead of a new one per RPC."""

def __init__(self, timeout: Duration = DEFAULT_WORKER_RPC_TIMEOUT) -> None:
self._timeout = timeout
self._stubs: dict[str, WorkerServiceClientSync] = {}
self._stubs: dict[str, WorkerServiceClient] = {}
self._lock = threading.Lock()

@property
def timeout_ms(self) -> int:
return self._timeout.to_ms()

def get_stub(self, address: str) -> WorkerServiceClientSync:
def get_stub(self, address: str) -> WorkerServiceClient:
with self._lock:
stub = self._stubs.get(address)
if stub is None:
stub = WorkerServiceClientSync(
stub = WorkerServiceClient(
address=f"http://{address}",
timeout_ms=self._timeout.to_ms(),
)
Expand All @@ -77,16 +118,11 @@ def get_stub(self, address: str) -> WorkerServiceClientSync:

def evict(self, address: str) -> None:
with self._lock:
stub = self._stubs.pop(address, None)
if stub is not None:
stub.close()
self._stubs.pop(address, None)

def close(self) -> None:
with self._lock:
stubs = list(self._stubs.values())
self._stubs.clear()
Comment on lines 82 to 84
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Close cached worker clients before dropping stub map

RpcWorkerStubFactory.close() now only clears the _stubs dict, so every cached WorkerServiceClient is discarded without calling close(). In a long-running controller, worker churn/failover creates many distinct addresses, and those async clients can retain open HTTP connections/file descriptors until process exit, causing resource leaks and eventually destabilizing heartbeat/control-plane RPCs. The previous implementation explicitly closed each stub, so this is a regression introduced in this commit.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bot comment is not relevant. Evidence:

  ConnectClient.close at connectrpc/_client_async.py:178-181:

  async def close(self) -> None:
      """Close the HTTP client. After closing, the client cannot be used to make requests."""
      if not self._closed:
          self._closed = True

  It only flips a boolean flag — does not close the underlying pyqwest HTTPClient. The pyqwest client is Rust-owned; its connections/sockets are released via Rust's Drop when the Python reference is GC'd, which happens the moment we clear the dict.

  In the sync version, ConnectClientSync.close() did close the httpx sync client. But we've switched to the async variant, whose close() is a no-op.

  So calling it would (a) require spinning up an event loop from sync close(), and (b) do nothing.

for stub in stubs:
stub.close()


def _apply_request_from_response(
Expand Down Expand Up @@ -118,85 +154,100 @@ def _apply_request_from_response(

@dataclass
class WorkerProvider:
"""TaskProvider backed by worker daemons via heartbeat RPC.
"""TaskProvider backed by worker daemons via async heartbeat RPC.

Drop-in replacement for the controller's _do_heartbeat_rpc path. Uses a
persistent ThreadPoolExecutor for parallel heartbeat dispatch.
Runs an asyncio event loop on a dedicated thread and dispatches
per-worker heartbeat RPCs concurrently via `asyncio.gather`, capped at
`parallelism` concurrent in-flight requests by a semaphore.
"""

stub_factory: WorkerStubFactory
parallelism: int = 32
_pool: ThreadPoolExecutor = field(init=False)
parallelism: int = 128
_loop_thread: _AsyncLoopThread = field(init=False)
_semaphore: asyncio.Semaphore = field(init=False)

def __post_init__(self) -> None:
self._pool = ThreadPoolExecutor(max_workers=self.parallelism)
self._loop_thread = _AsyncLoopThread()
self._semaphore = self._loop_thread.run(self._make_semaphore())

async def _make_semaphore(self) -> asyncio.Semaphore:
    """Build the concurrency-cap semaphore.

    Coroutine wrapper so `__post_init__` can construct the semaphore on
    the provider's dedicated loop thread via `_loop_thread.run(...)`.
    """
    limit = self.parallelism
    return asyncio.Semaphore(limit)

def sync(
self,
batches: list[DispatchBatch],
) -> list[tuple[DispatchBatch, HeartbeatApplyRequest | None, str | None]]:
if not batches:
return []
results: list[tuple[DispatchBatch, HeartbeatApplyRequest | None, str | None]] = []
futures = {self._pool.submit(self._heartbeat_one, b): b for b in batches}
for future in futures:
batch = futures[future]
try:
apply_req = future.result()
results.append((batch, apply_req, None))
except Exception as e:
results.append((batch, None, str(e)))
return results
return self._loop_thread.run(self._sync_all(batches))

def _heartbeat_one(self, batch: DispatchBatch) -> HeartbeatApplyRequest:
"""Send heartbeat RPC to one worker and return the apply request."""
started = monotonic()
timeout_ms = getattr(self.stub_factory, "timeout_ms", None)

if rule := chaos("controller.heartbeat"):
sleep(rule.delay_seconds)
raise ProviderError("chaos: heartbeat unavailable")

if not batch.worker_address:
raise ProviderError(f"Worker {batch.worker_id} has no address for heartbeat")

stub = self.stub_factory.get_stub(batch.worker_address)

expected_tasks = []
for entry in batch.running_tasks:
if rule := chaos("controller.heartbeat.iteration"):
sleep(rule.delay_seconds)
expected_tasks.append(
job_pb2.WorkerTaskStatus(
task_id=entry.task_id.to_wire(),
attempt_id=entry.attempt_id,
)
)
request = job_pb2.HeartbeatRequest(
tasks_to_run=batch.tasks_to_run,
tasks_to_kill=batch.tasks_to_kill,
expected_tasks=expected_tasks,
)
try:
response = stub.heartbeat(request)
async def _sync_all(
    self,
    batches: list[DispatchBatch],
) -> list[tuple[DispatchBatch, HeartbeatApplyRequest | None, str | None]]:
    """Dispatch one heartbeat per batch concurrently; results keep order.

    Failures are encoded per-tuple by `_heartbeat_one_safe`, so gather
    itself never observes an exception.
    """
    return await asyncio.gather(
        *(self._heartbeat_one_safe(batch) for batch in batches)
    )

if not response.worker_healthy:
health_error = response.health_error or "worker reported unhealthy"
raise ProviderError(f"worker {batch.worker_id} reported unhealthy: {health_error}")
async def _heartbeat_one_safe(
    self,
    batch: DispatchBatch,
) -> tuple[DispatchBatch, HeartbeatApplyRequest | None, str | None]:
    """Run one heartbeat, converting any exception into an error tuple.

    Never raises: on failure the stringified error is returned as the
    third element so one bad worker cannot abort the whole gather.
    """
    try:
        apply_request = await self._heartbeat_one(batch)
    except Exception as exc:
        return (batch, None, str(exc))
    return (batch, apply_request, None)

elapsed_ms = int((monotonic() - started) * 1000)
if elapsed_ms >= _SLOW_HEARTBEAT_RPC_LOG_THRESHOLD_MS:
logger.warning(
"Slow heartbeat RPC succeeded: %s",
_heartbeat_rpc_context(batch, elapsed_ms=elapsed_ms, timeout_ms=timeout_ms),
async def _heartbeat_one(self, batch: DispatchBatch) -> HeartbeatApplyRequest:
"""Send heartbeat RPC to one worker and return the apply request."""
async with self._semaphore:
started = monotonic()
timeout_ms = getattr(self.stub_factory, "timeout_ms", None)

if rule := chaos("controller.heartbeat"):
await asyncio.sleep(rule.delay_seconds)
raise ProviderError("chaos: heartbeat unavailable")

if not batch.worker_address:
raise ProviderError(f"Worker {batch.worker_id} has no address for heartbeat")

stub = self.stub_factory.get_stub(batch.worker_address)

expected_tasks = []
for entry in batch.running_tasks:
if rule := chaos("controller.heartbeat.iteration"):
await asyncio.sleep(rule.delay_seconds)
expected_tasks.append(
job_pb2.WorkerTaskStatus(
task_id=entry.task_id.to_wire(),
attempt_id=entry.attempt_id,
)
)
return _apply_request_from_response(batch.worker_id, response)
except Exception as e:
elapsed_ms = int((monotonic() - started) * 1000)
context = _heartbeat_rpc_context(batch, elapsed_ms=elapsed_ms, timeout_ms=timeout_ms)
if isinstance(e, ProviderError):
raise ProviderError(f"{e}; {context}") from e
raise ProviderError(f"heartbeat RPC failed: {context}; error={e}") from e
request = job_pb2.HeartbeatRequest(
tasks_to_run=batch.tasks_to_run,
tasks_to_kill=batch.tasks_to_kill,
expected_tasks=expected_tasks,
)
try:
response = await stub.heartbeat(request)

if not response.worker_healthy:
health_error = response.health_error or "worker reported unhealthy"
raise ProviderError(f"worker {batch.worker_id} reported unhealthy: {health_error}")

elapsed_ms = int((monotonic() - started) * 1000)
if elapsed_ms >= _SLOW_HEARTBEAT_RPC_LOG_THRESHOLD_MS:
logger.warning(
"Slow heartbeat RPC succeeded: %s",
_heartbeat_rpc_context(batch, elapsed_ms=elapsed_ms, timeout_ms=timeout_ms),
)
return _apply_request_from_response(batch.worker_id, response)
except Exception as e:
elapsed_ms = int((monotonic() - started) * 1000)
context = _heartbeat_rpc_context(batch, elapsed_ms=elapsed_ms, timeout_ms=timeout_ms)
if isinstance(e, ProviderError):
raise ProviderError(f"{e}; {context}") from e
raise ProviderError(f"heartbeat RPC failed: {context}; error={e}") from e

def get_process_status(
self,
Expand All @@ -213,7 +264,7 @@ def get_process_status(
log_substring=request.log_substring,
min_log_level=request.min_log_level,
)
return stub.get_process_status(forwarded, timeout_ms=10000)
return self._loop_thread.run(stub.get_process_status(forwarded, timeout_ms=10000))

def on_worker_failed(self, worker_id: WorkerId, address: str | None) -> None:
if address:
Expand All @@ -226,7 +277,7 @@ def profile_task(
timeout_ms: int,
) -> job_pb2.ProfileTaskResponse:
stub = self.stub_factory.get_stub(address)
return stub.profile_task(request, timeout_ms=timeout_ms)
return self._loop_thread.run(stub.profile_task(request, timeout_ms=timeout_ms))

def exec_in_container(
self,
Expand All @@ -240,8 +291,8 @@ def exec_in_container(
rpc_timeout_ms = 3_600_000
else:
rpc_timeout_ms = (timeout_seconds + 5) * 1000
return stub.exec_in_container(request, timeout_ms=rpc_timeout_ms)
return self._loop_thread.run(stub.exec_in_container(request, timeout_ms=rpc_timeout_ms))

def close(self) -> None:
self._pool.shutdown(wait=False)
self.stub_factory.close()
self._loop_thread.close()
20 changes: 15 additions & 5 deletions lib/iris/tests/cluster/controller/test_heartbeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,15 +294,25 @@ class _FakeStub:
def __init__(self, response: job_pb2.HeartbeatResponse):
self._response = response

def heartbeat(self, request: job_pb2.HeartbeatRequest) -> job_pb2.HeartbeatResponse:
async def heartbeat(
self,
request: job_pb2.HeartbeatRequest,
*,
timeout_ms: int | None = None,
) -> job_pb2.HeartbeatResponse:
return self._response


class _RaisingStub:
def __init__(self, exc: Exception):
self._exc = exc

def heartbeat(self, request: job_pb2.HeartbeatRequest) -> job_pb2.HeartbeatResponse:
async def heartbeat(
self,
request: job_pb2.HeartbeatRequest,
*,
timeout_ms: int | None = None,
) -> job_pb2.HeartbeatResponse:
raise self._exc


Expand Down Expand Up @@ -353,7 +363,7 @@ def test_handle_failed_heartbeats_logs_diagnostics(tmp_path, worker_metadata, ca
controller.stop()


def test_rpc_worker_stub_factory_uses_longer_default_timeout(monkeypatch):
def test_rpc_worker_stub_factory_default_timeout(monkeypatch):
captured: dict[str, object] = {}

class _RecordingClient:
Expand All @@ -364,13 +374,13 @@ def __init__(self, address: str, timeout_ms: int):
def close(self) -> None:
pass

monkeypatch.setattr(worker_provider_module, "WorkerServiceClientSync", _RecordingClient)
monkeypatch.setattr(worker_provider_module, "WorkerServiceClient", _RecordingClient)

factory = RpcWorkerStubFactory()
factory.get_stub("host:8080")

assert captured["address"] == "http://host:8080"
assert captured["timeout_ms"] == 30_000
assert captured["timeout_ms"] == 10_000

factory.close()

Expand Down
Loading