diff --git a/lib/iris/src/iris/cluster/controller/controller.py b/lib/iris/src/iris/cluster/controller/controller.py index 6cf4e85789..99c51eb203 100644 --- a/lib/iris/src/iris/cluster/controller/controller.py +++ b/lib/iris/src/iris/cluster/controller/controller.py @@ -2166,7 +2166,8 @@ def _sync_all_execution_units(self) -> None: return # Sync with the execution backend (ThreadPoolExecutor inside provider). - results = self._provider.sync(batches) + with slow_log(logger, "provider sync (RPC dispatch)", threshold_ms=5_000): + results = self._provider.sync(batches) acc = _SyncFailureAccumulator() with slow_log(logger, "provider sync (apply results)", threshold_ms=500): diff --git a/lib/iris/src/iris/cluster/controller/worker_provider.py b/lib/iris/src/iris/cluster/controller/worker_provider.py index 1f4afe9d70..84b0647d48 100644 --- a/lib/iris/src/iris/cluster/controller/worker_provider.py +++ b/lib/iris/src/iris/cluster/controller/worker_provider.py @@ -3,11 +3,11 @@ """WorkerProvider: TaskProvider backed by worker daemons via heartbeat RPC.""" +import asyncio import logging import threading -from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass, field -from time import monotonic, sleep +from dataclasses import dataclass +from time import monotonic from typing import Protocol from iris.chaos import chaos @@ -20,13 +20,13 @@ from iris.cluster.types import JobName, WorkerId from iris.rpc import job_pb2 from iris.rpc import worker_pb2 -from iris.rpc.worker_connect import WorkerServiceClientSync +from iris.rpc.worker_connect import WorkerServiceClient from rigging.timing import Duration logger = logging.getLogger(__name__) -DEFAULT_WORKER_RPC_TIMEOUT = Duration.from_seconds(30.0) -_SLOW_HEARTBEAT_RPC_LOG_THRESHOLD_MS = 10_000 +DEFAULT_WORKER_RPC_TIMEOUT = Duration.from_seconds(10.0) +_SLOW_HEARTBEAT_RPC_LOG_THRESHOLD_MS = 5_000 def _heartbeat_rpc_context( @@ -44,31 +44,31 @@ def _heartbeat_rpc_context( class WorkerStubFactory(Protocol): - """Factory for getting cached worker RPC stubs.""" + """Factory for getting cached async worker RPC stubs.""" - def get_stub(self, address: str) -> WorkerServiceClientSync: ... + def get_stub(self, address: str) -> WorkerServiceClient: ... def evict(self, address: str) -> None: ... def close(self) -> None: ... class RpcWorkerStubFactory: - """Caches WorkerServiceClientSync stubs by address so each worker gets - one persistent httpx.Client instead of a new one per RPC.""" + """Caches async WorkerServiceClient stubs by address so each worker gets + one persistent async HTTP client instead of a new one per RPC.""" def __init__(self, timeout: Duration = DEFAULT_WORKER_RPC_TIMEOUT) -> None: self._timeout = timeout - self._stubs: dict[str, WorkerServiceClientSync] = {} + self._stubs: dict[str, WorkerServiceClient] = {} self._lock = threading.Lock() @property def timeout_ms(self) -> int: return self._timeout.to_ms() - def get_stub(self, address: str) -> WorkerServiceClientSync: + def get_stub(self, address: str) -> WorkerServiceClient: with self._lock: stub = self._stubs.get(address) if stub is None: - stub = WorkerServiceClientSync( + stub = WorkerServiceClient( address=f"http://{address}", timeout_ms=self._timeout.to_ms(), ) @@ -77,16 +77,11 @@ def get_stub(self, address: str) -> WorkerServiceClientSync: def evict(self, address: str) -> None: with self._lock: - stub = self._stubs.pop(address, None) - if stub is not None: - stub.close() + self._stubs.pop(address, None) def close(self) -> None: with self._lock: - stubs = list(self._stubs.values()) self._stubs.clear() - for stub in stubs: - stub.close() def _apply_request_from_response( @@ -118,18 +113,17 @@ def _apply_request_from_response( @dataclass class WorkerProvider: - """TaskProvider backed by worker daemons via heartbeat RPC. + """TaskProvider backed by worker daemons via async heartbeat RPC. - Drop-in replacement for the controller's _do_heartbeat_rpc path. Uses a - persistent ThreadPoolExecutor for parallel heartbeat dispatch. + Per round, `sync()` spins up an asyncio event loop via `asyncio.run` + and dispatches per-worker heartbeat RPCs concurrently via + `asyncio.gather`, capped at `parallelism` in-flight requests by a + local semaphore. Cached stubs in the factory keep their pyqwest + connection pools across rounds independently of the Python loop. """ stub_factory: WorkerStubFactory - parallelism: int = 32 - _pool: ThreadPoolExecutor = field(init=False) - - def __post_init__(self) -> None: - self._pool = ThreadPoolExecutor(max_workers=self.parallelism) + parallelism: int = 128 def sync( self, @@ -137,24 +131,34 @@ def sync( ) -> list[tuple[DispatchBatch, HeartbeatApplyRequest | None, str | None]]: if not batches: return [] - results: list[tuple[DispatchBatch, HeartbeatApplyRequest | None, str | None]] = [] - futures = {self._pool.submit(self._heartbeat_one, b): b for b in batches} - for future in futures: - batch = futures[future] + return asyncio.run(self._sync_all(batches)) + + async def _sync_all( + self, + batches: list[DispatchBatch], + ) -> list[tuple[DispatchBatch, HeartbeatApplyRequest | None, str | None]]: + sem = asyncio.Semaphore(self.parallelism) + return await asyncio.gather(*(self._heartbeat_one_safe(sem, b) for b in batches)) + + async def _heartbeat_one_safe( + self, + sem: asyncio.Semaphore, + batch: DispatchBatch, + ) -> tuple[DispatchBatch, HeartbeatApplyRequest | None, str | None]: + async with sem: try: - apply_req = future.result() - results.append((batch, apply_req, None)) + apply_req = await self._heartbeat_one(batch) + return (batch, apply_req, None) except Exception as e: - results.append((batch, None, str(e))) - return results + return (batch, None, str(e)) - def _heartbeat_one(self, batch: DispatchBatch) -> HeartbeatApplyRequest: + async def _heartbeat_one(self, batch: DispatchBatch) -> HeartbeatApplyRequest: """Send heartbeat RPC to one worker and return the apply request.""" started = monotonic() timeout_ms = getattr(self.stub_factory, "timeout_ms", None) if rule := chaos("controller.heartbeat"): - sleep(rule.delay_seconds) + await asyncio.sleep(rule.delay_seconds) raise ProviderError("chaos: heartbeat unavailable") if not batch.worker_address: @@ -165,7 +169,7 @@ def _heartbeat_one(self, batch: DispatchBatch) -> HeartbeatApplyRequest: expected_tasks = [] for entry in batch.running_tasks: if rule := chaos("controller.heartbeat.iteration"): - sleep(rule.delay_seconds) + await asyncio.sleep(rule.delay_seconds) expected_tasks.append( job_pb2.WorkerTaskStatus( task_id=entry.task_id.to_wire(), @@ -178,7 +182,7 @@ def _heartbeat_one(self, batch: DispatchBatch) -> HeartbeatApplyRequest: expected_tasks=expected_tasks, ) try: - response = stub.heartbeat(request) + response = await stub.heartbeat(request) if not response.worker_healthy: health_error = response.health_error or "worker reported unhealthy" @@ -213,7 +217,7 @@ def get_process_status( log_substring=request.log_substring, min_log_level=request.min_log_level, ) - return stub.get_process_status(forwarded, timeout_ms=10000) + return asyncio.run(stub.get_process_status(forwarded, timeout_ms=10000)) def on_worker_failed(self, worker_id: WorkerId, address: str | None) -> None: if address: @@ -226,7 +230,7 @@ def profile_task( timeout_ms: int, ) -> job_pb2.ProfileTaskResponse: stub = self.stub_factory.get_stub(address) - return stub.profile_task(request, timeout_ms=timeout_ms) + return asyncio.run(stub.profile_task(request, timeout_ms=timeout_ms)) def exec_in_container( self, @@ -240,8 +244,7 @@ def exec_in_container( rpc_timeout_ms = 3_600_000 else: rpc_timeout_ms = (timeout_seconds + 5) * 1000 - return stub.exec_in_container(request, timeout_ms=rpc_timeout_ms) + return asyncio.run(stub.exec_in_container(request, timeout_ms=rpc_timeout_ms)) def close(self) -> None: - self._pool.shutdown(wait=False) self.stub_factory.close() diff --git a/lib/iris/tests/cluster/controller/test_heartbeat.py b/lib/iris/tests/cluster/controller/test_heartbeat.py index 4c8a69c21e..5976c30265 100644 --- a/lib/iris/tests/cluster/controller/test_heartbeat.py +++ b/lib/iris/tests/cluster/controller/test_heartbeat.py @@ -294,7 +294,12 @@ class _FakeStub: def __init__(self, response: job_pb2.HeartbeatResponse): self._response = response - def heartbeat(self, request: job_pb2.HeartbeatRequest) -> job_pb2.HeartbeatResponse: + async def heartbeat( + self, + request: job_pb2.HeartbeatRequest, + *, + timeout_ms: int | None = None, + ) -> job_pb2.HeartbeatResponse: return self._response @@ -302,7 +307,12 @@ class _RaisingStub: def __init__(self, exc: Exception): self._exc = exc - def heartbeat(self, request: job_pb2.HeartbeatRequest) -> job_pb2.HeartbeatResponse: + async def heartbeat( + self, + request: job_pb2.HeartbeatRequest, + *, + timeout_ms: int | None = None, + ) -> job_pb2.HeartbeatResponse: raise self._exc @@ -353,7 +363,7 @@ def test_handle_failed_heartbeats_logs_diagnostics(tmp_path, worker_metadata, ca controller.stop() -def test_rpc_worker_stub_factory_uses_longer_default_timeout(monkeypatch): +def test_rpc_worker_stub_factory_default_timeout(monkeypatch): captured: dict[str, object] = {} class _RecordingClient: @@ -364,13 +374,13 @@ def __init__(self, address: str, timeout_ms: int): def close(self) -> None: pass - monkeypatch.setattr(worker_provider_module, "WorkerServiceClientSync", _RecordingClient) + monkeypatch.setattr(worker_provider_module, "WorkerServiceClient", _RecordingClient) factory = RpcWorkerStubFactory() factory.get_stub("host:8080") assert captured["address"] == "http://host:8080" - assert captured["timeout_ms"] == 30_000 + assert captured["timeout_ms"] == 10_000 factory.close()