@@ -249,63 +249,85 @@ def ping_workers(self, workers: list[tuple[WorkerId, str | None]]) -> list[PingR
249249 """Send Ping RPCs to all workers concurrently. Returns per-worker results."""
250250 if not workers :
251251 return []
252- return asyncio .run (self ._ping_all (workers ))
253252
254- async def _ping_all (
255- self ,
256- workers : list [tuple [WorkerId , str | None ]],
257- ) -> list [PingResult ]:
258- sem = asyncio .Semaphore (self .parallelism )
259- return await asyncio .gather (* (self ._ping_one_safe (sem , wid , addr ) for wid , addr in workers ))
async def _one(sem: asyncio.Semaphore, wid: WorkerId, addr: str | None) -> PingResult:
    """Ping a single worker and report the outcome as a ``PingResult``.

    Never raises for ordinary failures: a missing address, an unhealthy
    reply, or any ``Exception`` from the RPC is returned as a result whose
    ``error`` field is a non-empty string.

    Args:
        sem: Semaphore shared by all pings of one batch; it bounds how many
            RPCs are in flight at once.
        wid: Identifier of the worker being pinged.
        addr: Address of the worker, or a falsy value when none is known.
    """
    async with sem:
        if not addr:
            return PingResult(worker_id=wid, worker_address=addr, error=f"Worker {wid} has no address")
        try:
            stub = self.stub_factory.get_stub(addr)
            response = await stub.ping(worker_pb2.Worker.PingRequest())
            if not response.healthy:
                return PingResult(
                    worker_id=wid,
                    worker_address=addr,
                    error=f"worker {wid} reported unhealthy: {response.health_error}",
                )
            return PingResult(
                worker_id=wid,
                worker_address=addr,
                # An empty snapshot message is reported as "no snapshot".
                resource_snapshot=(
                    response.resource_snapshot if response.resource_snapshot.ByteSize() > 0 else None
                ),
                healthy=response.healthy,
                health_error=response.health_error,
            )
        except Exception as exc:
            # Some exceptions (e.g. a bare TimeoutError()) stringify to "".
            # Fall back to the class name so a failure never carries an
            # empty, falsy error message.
            return PingResult(worker_id=wid, worker_address=addr, error=str(exc) or type(exc).__name__)
277+
async def _run() -> list[PingResult]:
    """Ping every worker concurrently, bounded by ``self.parallelism``."""
    limiter = asyncio.Semaphore(self.parallelism)
    pending = [_one(limiter, worker_id, address) for worker_id, address in workers]
    return await asyncio.gather(*pending)
281+
282+ return asyncio .run (_run ())
260283
261- async def _ping_one_safe (
def start_tasks(
    self,
    jobs: list[tuple[WorkerId, str, list[job_pb2.RunTaskRequest]]],
) -> list[tuple[WorkerId, worker_pb2.Worker.StartTasksResponse | None, str | None]]:
    """Send StartTasks RPCs to many workers concurrently.

    Drives the RPCs with ``asyncio.run``, so it must be called from
    synchronous code (not from inside a running event loop).

    Args:
        jobs: One ``(worker_id, address, tasks)`` entry per worker to contact.

    Returns:
        One ``(worker_id, response, error)`` tuple per job, in the same order
        as ``jobs``. On success ``error`` is ``None``; on failure ``response``
        is ``None`` and ``error`` is a non-empty description. An empty
        ``jobs`` list returns ``[]`` without starting an event loop.
    """
    if not jobs:
        return []

    async def _one(
        sem: asyncio.Semaphore, wid: WorkerId, addr: str, tasks: list[job_pb2.RunTaskRequest]
    ) -> tuple[WorkerId, worker_pb2.Worker.StartTasksResponse | None, str | None]:
        # The shared semaphore bounds how many RPCs are in flight at once.
        async with sem:
            try:
                stub = self.stub_factory.get_stub(addr)
                response = await stub.start_tasks(worker_pb2.Worker.StartTasksRequest(tasks=tasks))
            except Exception as exc:
                # Some exceptions (e.g. a bare TimeoutError()) stringify to "".
                # Fall back to the class name so a failure is never reported
                # with an empty, falsy error string.
                return (wid, None, str(exc) or type(exc).__name__)
            return (wid, response, None)

    async def _run() -> list[tuple[WorkerId, worker_pb2.Worker.StartTasksResponse | None, str | None]]:
        sem = asyncio.Semaphore(self.parallelism)
        return await asyncio.gather(*(_one(sem, wid, addr, tasks) for wid, addr, tasks in jobs))

    return asyncio.run(_run())
298308
def stop_tasks(
    self,
    jobs: list[tuple[WorkerId, str, list[str]]],
) -> list[tuple[WorkerId, str | None]]:
    """Send StopTasks RPCs to many workers concurrently.

    Drives the RPCs with ``asyncio.run``, so it must be called from
    synchronous code (not from inside a running event loop).

    Args:
        jobs: One ``(worker_id, address, task_ids)`` entry per worker to
            contact.

    Returns:
        One ``(worker_id, error)`` tuple per job, in the same order as
        ``jobs``. ``error`` is ``None`` when the RPC succeeded and a
        non-empty description otherwise. An empty ``jobs`` list returns
        ``[]`` without starting an event loop.
    """
    if not jobs:
        return []

    async def _one(sem: asyncio.Semaphore, wid: WorkerId, addr: str, ids: list[str]) -> tuple[WorkerId, str | None]:
        # The shared semaphore bounds how many RPCs are in flight at once.
        async with sem:
            try:
                stub = self.stub_factory.get_stub(addr)
                await stub.stop_tasks(worker_pb2.Worker.StopTasksRequest(task_ids=ids))
            except Exception as exc:
                # ``None`` is the only success signal here, so a failure must
                # never yield an empty string: exceptions with no message
                # (e.g. a bare TimeoutError()) fall back to their class name.
                return (wid, str(exc) or type(exc).__name__)
            return (wid, None)

    async def _run() -> list[tuple[WorkerId, str | None]]:
        sem = asyncio.Semaphore(self.parallelism)
        return await asyncio.gather(*(_one(sem, wid, addr, ids) for wid, addr, ids in jobs))

    return asyncio.run(_run())
309331
310332 def poll_workers (
311333 self ,
@@ -318,37 +340,28 @@ def poll_workers(
318340 """
319341 if not running :
320342 return []
321- return asyncio .run (self ._poll_all (running , worker_addresses ))
322-
323- async def _poll_all (
324- self ,
325- running : dict [WorkerId , list [RunningTaskEntry ]],
326- worker_addresses : dict [WorkerId , str ],
327- ) -> list [tuple [WorkerId , list [TaskUpdate ] | None , str | None ]]:
328- sem = asyncio .Semaphore (self .parallelism )
329- return await asyncio .gather (
330- * (self ._poll_one_safe (sem , wid , running [wid ], worker_addresses .get (wid )) for wid in running )
331- )
332343
333- async def _poll_one_safe (
334- self ,
335- sem : asyncio .Semaphore ,
336- worker_id : WorkerId ,
337- entries : list [RunningTaskEntry ],
338- address : str | None ,
339- ) -> tuple [WorkerId , list [TaskUpdate ] | None , str | None ]:
340- async with sem :
341- if not address :
342- return (worker_id , None , f"Worker { worker_id } has no address" )
343- try :
344- expected = [
345- job_pb2 .WorkerTaskStatus (task_id = e .task_id .to_wire (), attempt_id = e .attempt_id ) for e in entries
346- ]
347- stub = self .stub_factory .get_stub (address )
348- response = await stub .poll_tasks (worker_pb2 .Worker .PollTasksRequest (expected_tasks = expected ))
349- return (worker_id , task_updates_from_proto (response .tasks ), None )
350- except Exception as e :
351- return (worker_id , None , str (e ))
async def _one(
    sem: asyncio.Semaphore, wid: WorkerId, entries: list[RunningTaskEntry], addr: str | None
) -> tuple[WorkerId, list[TaskUpdate] | None, str | None]:
    """Poll one worker for the status of its running tasks.

    Returns ``(worker_id, updates, error)``. On success ``error`` is
    ``None``; on failure ``updates`` is ``None`` and ``error`` is a
    non-empty description. Ordinary exceptions are captured, not raised.

    Args:
        sem: Semaphore shared by all polls of one batch; it bounds how many
            RPCs are in flight at once.
        wid: Identifier of the worker being polled.
        entries: Tasks the caller expects this worker to be running.
        addr: Address of the worker, or a falsy value when none is known.
    """
    async with sem:
        if not addr:
            return (wid, None, f"Worker {wid} has no address")
        try:
            expected = [
                job_pb2.WorkerTaskStatus(task_id=entry.task_id.to_wire(), attempt_id=entry.attempt_id)
                for entry in entries
            ]
            stub = self.stub_factory.get_stub(addr)
            response = await stub.poll_tasks(worker_pb2.Worker.PollTasksRequest(expected_tasks=expected))
            # Conversion stays inside the try so a malformed reply is
            # reported as this worker's error instead of aborting the batch.
            return (wid, task_updates_from_proto(response.tasks), None)
        except Exception as exc:
            # Some exceptions (e.g. a bare TimeoutError()) stringify to "".
            # Fall back to the class name so a failure is never reported
            # with an empty, falsy error string.
            return (wid, None, str(exc) or type(exc).__name__)
359+
async def _run() -> list[tuple[WorkerId, list[TaskUpdate] | None, str | None]]:
    """Poll every worker in ``running`` concurrently, bounded by ``self.parallelism``."""
    limiter = asyncio.Semaphore(self.parallelism)
    pending = [
        _one(limiter, worker_id, running[worker_id], worker_addresses.get(worker_id))
        for worker_id in running
    ]
    return await asyncio.gather(*pending)
363+
364+ return asyncio .run (_run ())
352365
def close(self) -> None:
    """Close the stub factory held by this object."""
    factory = self.stub_factory
    factory.close()
0 commit comments