Skip to content

Commit dcd5b6a

Browse files
rjpower and claude committed
[iris] Bulk start_tasks/stop_tasks; one event loop per cycle
Was: controller looped over per-worker calls, each spinning up a fresh asyncio event loop via asyncio.run. With N workers that's N loop spin-ups serialized on the controller thread.

Now: start_tasks/stop_tasks take a list of (worker, address, payload) jobs and dispatch them concurrently within a single event loop, capped by the same Semaphore(parallelism) the heartbeat path uses.

Also flatten ping_workers/poll_workers helpers into local async closures to remove the layer of _one_safe / _all method ping-pong.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 802673c commit dcd5b6a

3 files changed

Lines changed: 125 additions & 110 deletions

File tree

lib/iris/src/iris/cluster/controller/controller.py

Lines changed: 29 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -2169,29 +2169,32 @@ def _dispatch_assignments_direct(
21692169
for worker_id, address, run_request in result.start_requests:
21702170
by_worker.setdefault((worker_id, address), []).append(run_request)
21712171

2172-
for (worker_id, address), tasks in by_worker.items():
2173-
attempt_by_task = {t.task_id: t.attempt_id for t in tasks}
2174-
try:
2175-
response = self._provider.start_tasks(worker_id, address, tasks)
2176-
for ack in response.acks:
2177-
if not ack.accepted:
2178-
logger.warning("Worker %s rejected task %s: %s", worker_id, ack.task_id, ack.error)
2179-
self._task_update_queue.put(
2180-
HeartbeatApplyRequest(
2181-
worker_id=worker_id,
2182-
worker_resource_snapshot=None,
2183-
updates=[
2184-
TaskUpdate(
2185-
task_id=JobName.from_wire(ack.task_id),
2186-
attempt_id=attempt_by_task.get(ack.task_id, -1),
2187-
new_state=job_pb2.TASK_STATE_WORKER_FAILED,
2188-
error=f"Worker rejected task: {ack.error}",
2189-
)
2190-
],
2191-
)
2172+
attempt_by_worker_task = {
2173+
(worker_id, t.task_id): t.attempt_id for (worker_id, _), tasks in by_worker.items() for t in tasks
2174+
}
2175+
jobs = [(worker_id, address, tasks) for (worker_id, address), tasks in by_worker.items()]
2176+
for worker_id, response, error in self._provider.start_tasks(jobs):
2177+
if error is not None:
2178+
logger.warning("StartTasks RPC failed for worker %s: %s", worker_id, error)
2179+
continue
2180+
assert response is not None
2181+
for ack in response.acks:
2182+
if not ack.accepted:
2183+
logger.warning("Worker %s rejected task %s: %s", worker_id, ack.task_id, ack.error)
2184+
self._task_update_queue.put(
2185+
HeartbeatApplyRequest(
2186+
worker_id=worker_id,
2187+
worker_resource_snapshot=None,
2188+
updates=[
2189+
TaskUpdate(
2190+
task_id=JobName.from_wire(ack.task_id),
2191+
attempt_id=attempt_by_worker_task.get((worker_id, ack.task_id), -1),
2192+
new_state=job_pb2.TASK_STATE_WORKER_FAILED,
2193+
error=f"Worker rejected task: {ack.error}",
2194+
)
2195+
],
21922196
)
2193-
except Exception as e:
2194-
logger.warning("StartTasks RPC failed for worker %s: %s", worker_id, e)
2197+
)
21952198

21962199
def _stop_tasks_direct(
21972200
self,
@@ -2212,11 +2215,10 @@ def _stop_tasks_direct(
22122215
continue
22132216
by_worker.setdefault((worker_id, worker.address), []).append(task_id.to_wire())
22142217

2215-
for (worker_id, address), wids in by_worker.items():
2216-
try:
2217-
self._provider.stop_tasks(worker_id, address, wids)
2218-
except Exception as e:
2219-
logger.warning("StopTasks RPC failed for worker %s: %s", worker_id, e)
2218+
jobs = [(worker_id, address, wids) for (worker_id, address), wids in by_worker.items()]
2219+
for worker_id, error in self._provider.stop_tasks(jobs):
2220+
if error is not None:
2221+
logger.warning("StopTasks RPC failed for worker %s: %s", worker_id, error)
22202222

22212223
def _get_active_worker_addresses(self) -> list[tuple[WorkerId, str | None]]:
22222224
"""Get healthy active workers as (worker_id, address) tuples for ping."""

lib/iris/src/iris/cluster/controller/worker_provider.py

Lines changed: 92 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -249,63 +249,85 @@ def ping_workers(self, workers: list[tuple[WorkerId, str | None]]) -> list[PingR
249249
"""Send Ping RPCs to all workers concurrently. Returns per-worker results."""
250250
if not workers:
251251
return []
252-
return asyncio.run(self._ping_all(workers))
253252

254-
async def _ping_all(
255-
self,
256-
workers: list[tuple[WorkerId, str | None]],
257-
) -> list[PingResult]:
258-
sem = asyncio.Semaphore(self.parallelism)
259-
return await asyncio.gather(*(self._ping_one_safe(sem, wid, addr) for wid, addr in workers))
253+
async def _one(sem: asyncio.Semaphore, wid: WorkerId, addr: str | None) -> PingResult:
254+
async with sem:
255+
if not addr:
256+
return PingResult(worker_id=wid, worker_address=addr, error=f"Worker {wid} has no address")
257+
try:
258+
stub = self.stub_factory.get_stub(addr)
259+
response = await stub.ping(worker_pb2.Worker.PingRequest())
260+
if not response.healthy:
261+
return PingResult(
262+
worker_id=wid,
263+
worker_address=addr,
264+
error=f"worker {wid} reported unhealthy: {response.health_error}",
265+
)
266+
return PingResult(
267+
worker_id=wid,
268+
worker_address=addr,
269+
resource_snapshot=(
270+
response.resource_snapshot if response.resource_snapshot.ByteSize() > 0 else None
271+
),
272+
healthy=response.healthy,
273+
health_error=response.health_error,
274+
)
275+
except Exception as e:
276+
return PingResult(worker_id=wid, worker_address=addr, error=str(e))
277+
278+
async def _run() -> list[PingResult]:
279+
sem = asyncio.Semaphore(self.parallelism)
280+
return await asyncio.gather(*(_one(sem, wid, addr) for wid, addr in workers))
281+
282+
return asyncio.run(_run())
260283

261-
async def _ping_one_safe(
284+
def start_tasks(
262285
self,
263-
sem: asyncio.Semaphore,
264-
worker_id: WorkerId,
265-
address: str | None,
266-
) -> PingResult:
267-
async with sem:
268-
try:
269-
return await self._ping_one(worker_id, address)
270-
except Exception as e:
271-
return PingResult(worker_id=worker_id, worker_address=address, error=str(e))
286+
jobs: list[tuple[WorkerId, str, list[job_pb2.RunTaskRequest]]],
287+
) -> list[tuple[WorkerId, worker_pb2.Worker.StartTasksResponse | None, str | None]]:
288+
"""Send StartTasks RPCs to many workers concurrently."""
289+
if not jobs:
290+
return []
272291

273-
async def _ping_one(self, worker_id: WorkerId, address: str | None) -> PingResult:
274-
if not address:
275-
raise ProviderError(f"Worker {worker_id} has no address")
276-
stub = self.stub_factory.get_stub(address)
277-
response = await stub.ping(worker_pb2.Worker.PingRequest())
278-
if not response.healthy:
279-
raise ProviderError(f"worker {worker_id} reported unhealthy: {response.health_error}")
280-
return PingResult(
281-
worker_id=worker_id,
282-
worker_address=address,
283-
resource_snapshot=response.resource_snapshot if response.resource_snapshot.ByteSize() > 0 else None,
284-
healthy=response.healthy,
285-
health_error=response.health_error,
286-
)
292+
async def _one(
293+
sem: asyncio.Semaphore, wid: WorkerId, addr: str, tasks: list[job_pb2.RunTaskRequest]
294+
) -> tuple[WorkerId, worker_pb2.Worker.StartTasksResponse | None, str | None]:
295+
async with sem:
296+
try:
297+
stub = self.stub_factory.get_stub(addr)
298+
response = await stub.start_tasks(worker_pb2.Worker.StartTasksRequest(tasks=tasks))
299+
return (wid, response, None)
300+
except Exception as e:
301+
return (wid, None, str(e))
287302

288-
def start_tasks(
289-
self,
290-
worker_id: WorkerId,
291-
address: str,
292-
tasks: list[job_pb2.RunTaskRequest],
293-
) -> worker_pb2.Worker.StartTasksResponse:
294-
"""Send StartTasks RPC to a worker."""
295-
stub = self.stub_factory.get_stub(address)
296-
request = worker_pb2.Worker.StartTasksRequest(tasks=tasks)
297-
return asyncio.run(stub.start_tasks(request))
303+
async def _run() -> list[tuple[WorkerId, worker_pb2.Worker.StartTasksResponse | None, str | None]]:
304+
sem = asyncio.Semaphore(self.parallelism)
305+
return await asyncio.gather(*(_one(sem, wid, addr, tasks) for wid, addr, tasks in jobs))
306+
307+
return asyncio.run(_run())
298308

299309
def stop_tasks(
300310
self,
301-
worker_id: WorkerId,
302-
address: str,
303-
task_ids: list[str],
304-
) -> None:
305-
"""Send StopTasks RPC to a worker."""
306-
stub = self.stub_factory.get_stub(address)
307-
request = worker_pb2.Worker.StopTasksRequest(task_ids=task_ids)
308-
asyncio.run(stub.stop_tasks(request))
311+
jobs: list[tuple[WorkerId, str, list[str]]],
312+
) -> list[tuple[WorkerId, str | None]]:
313+
"""Send StopTasks RPCs to many workers concurrently."""
314+
if not jobs:
315+
return []
316+
317+
async def _one(sem: asyncio.Semaphore, wid: WorkerId, addr: str, ids: list[str]) -> tuple[WorkerId, str | None]:
318+
async with sem:
319+
try:
320+
stub = self.stub_factory.get_stub(addr)
321+
await stub.stop_tasks(worker_pb2.Worker.StopTasksRequest(task_ids=ids))
322+
return (wid, None)
323+
except Exception as e:
324+
return (wid, str(e))
325+
326+
async def _run() -> list[tuple[WorkerId, str | None]]:
327+
sem = asyncio.Semaphore(self.parallelism)
328+
return await asyncio.gather(*(_one(sem, wid, addr, ids) for wid, addr, ids in jobs))
329+
330+
return asyncio.run(_run())
309331

310332
def poll_workers(
311333
self,
@@ -318,37 +340,28 @@ def poll_workers(
318340
"""
319341
if not running:
320342
return []
321-
return asyncio.run(self._poll_all(running, worker_addresses))
322-
323-
async def _poll_all(
324-
self,
325-
running: dict[WorkerId, list[RunningTaskEntry]],
326-
worker_addresses: dict[WorkerId, str],
327-
) -> list[tuple[WorkerId, list[TaskUpdate] | None, str | None]]:
328-
sem = asyncio.Semaphore(self.parallelism)
329-
return await asyncio.gather(
330-
*(self._poll_one_safe(sem, wid, running[wid], worker_addresses.get(wid)) for wid in running)
331-
)
332343

333-
async def _poll_one_safe(
334-
self,
335-
sem: asyncio.Semaphore,
336-
worker_id: WorkerId,
337-
entries: list[RunningTaskEntry],
338-
address: str | None,
339-
) -> tuple[WorkerId, list[TaskUpdate] | None, str | None]:
340-
async with sem:
341-
if not address:
342-
return (worker_id, None, f"Worker {worker_id} has no address")
343-
try:
344-
expected = [
345-
job_pb2.WorkerTaskStatus(task_id=e.task_id.to_wire(), attempt_id=e.attempt_id) for e in entries
346-
]
347-
stub = self.stub_factory.get_stub(address)
348-
response = await stub.poll_tasks(worker_pb2.Worker.PollTasksRequest(expected_tasks=expected))
349-
return (worker_id, task_updates_from_proto(response.tasks), None)
350-
except Exception as e:
351-
return (worker_id, None, str(e))
344+
async def _one(
345+
sem: asyncio.Semaphore, wid: WorkerId, entries: list[RunningTaskEntry], addr: str | None
346+
) -> tuple[WorkerId, list[TaskUpdate] | None, str | None]:
347+
async with sem:
348+
if not addr:
349+
return (wid, None, f"Worker {wid} has no address")
350+
try:
351+
expected = [
352+
job_pb2.WorkerTaskStatus(task_id=e.task_id.to_wire(), attempt_id=e.attempt_id) for e in entries
353+
]
354+
stub = self.stub_factory.get_stub(addr)
355+
response = await stub.poll_tasks(worker_pb2.Worker.PollTasksRequest(expected_tasks=expected))
356+
return (wid, task_updates_from_proto(response.tasks), None)
357+
except Exception as e:
358+
return (wid, None, str(e))
359+
360+
async def _run() -> list[tuple[WorkerId, list[TaskUpdate] | None, str | None]]:
361+
sem = asyncio.Semaphore(self.parallelism)
362+
return await asyncio.gather(*(_one(sem, wid, running[wid], worker_addresses.get(wid)) for wid in running))
363+
364+
return asyncio.run(_run())
352365

353366
def close(self) -> None:
354367
self.stub_factory.close()

lib/iris/tests/cluster/controller/conftest.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -113,13 +113,13 @@ def profile_task(
113113
def ping_workers(self, workers):
114114
return []
115115

116-
def start_tasks(self, worker_id, address, tasks):
116+
def start_tasks(self, jobs):
117117
from iris.rpc import worker_pb2
118118

119-
return worker_pb2.Worker.StartTasksResponse()
119+
return [(wid, worker_pb2.Worker.StartTasksResponse(), None) for wid, _, _ in jobs]
120120

121-
def stop_tasks(self, worker_id, address, task_ids):
122-
return None
121+
def stop_tasks(self, jobs):
122+
return [(wid, None) for wid, _, _ in jobs]
123123

124124
def poll_workers(self, running, worker_addresses):
125125
return []

0 commit comments

Comments
 (0)