@@ -761,70 +761,151 @@ def controller_restart(ctx, skip_checkpoint: bool, checkpoint_timeout: int):
761761
762762
@controller.command("worker-restart")
@click.option("--worker-id", multiple=True, help="Worker(s) to restart (repeatable; default: all)")
@click.option("--timeout", type=int, default=120, help="Max seconds to wait per worker to become healthy")
@click.option(
    "--max-batch",
    # A cap below 1 would shrink the batch to nothing and stall the rollout.
    type=click.IntRange(min=1),
    default=64,
    help="Maximum workers to restart concurrently",
)
@click.option(
    "--observation-window",
    # time.sleep() rejects negative durations, so refuse them up front.
    type=click.IntRange(min=0),
    default=60,
    help="Seconds to observe restarted workers for failures before advancing",
)
@click.pass_context
def worker_restart(
    ctx,
    worker_id: tuple[str, ...],
    timeout: int,
    max_batch: int,
    observation_window: int,
):
    """Rolling restart of workers with adaptive batch sizing.

    Restarts workers in progressively larger batches (1, 2, 4, ... up to
    --max-batch). After each batch, waits for workers to become healthy, then
    observes them for --observation-window seconds to catch post-restart
    failures. Aborts immediately if any worker fails to come back healthy or
    develops failures during observation.

    Running Docker containers are preserved and adopted by the new worker
    process, so tasks are not disrupted.

    --max-batch must be at least 1 and --observation-window must not be
    negative.
    """
    # The option types already enforce these bounds on the command line; the
    # explicit checks also cover callers that invoke the callback directly.
    if max_batch < 1:
        raise click.BadParameter("must be at least 1", param_hint="--max-batch")
    if observation_window < 0:
        raise click.BadParameter("must not be negative", param_hint="--observation-window")

    controller_url = require_controller_url(ctx)

    with rpc_client(controller_url) as client:
        workers_resp = client.list_workers(controller_pb2.Controller.ListWorkersRequest())
        all_workers = workers_resp.workers

        if worker_id:
            # Restrict to the requested workers and refuse to start if any of
            # them is unknown to the controller.
            requested = set(worker_id)
            workers = [w for w in all_workers if w.worker_id in requested]
            missing = requested - {w.worker_id for w in workers}
            if missing:
                click.echo(f"Workers not found: {', '.join(sorted(missing))} ", err=True)
                raise SystemExit(1)
        else:
            workers = list(all_workers)

        if not workers:
            click.echo("No workers to restart")
            return

        worker_ids = [w.worker_id for w in workers]
        total = len(worker_ids)
        click.echo(
            f"Restarting {total} worker(s) "
            f"(timeout={timeout} s, observation={observation_window} s, max_batch={max_batch} )"
        )

        succeeded = 0
        batch_size = 1  # always start with a single worker as a canary
        offset = 0

        while offset < total:
            batch = worker_ids[offset : offset + batch_size]
            click.echo(f"\n --- Batch of {len(batch)} (workers {offset + 1} -{offset + len(batch)} of {total} ) ---")

            # Issue restart RPCs for the batch
            for wid in batch:
                click.echo(f" Restarting {wid} ...")
                resp = client.restart_worker(
                    controller_pb2.Controller.RestartWorkerRequest(worker_id=wid),
                    timeout_ms=timeout * 1000,
                )
                if not resp.accepted:
                    click.echo(f" ABORT: restart rejected for {wid} : {resp.error} ", err=True)
                    _print_summary(succeeded, total - succeeded, offset)
                    raise SystemExit(1)

            # Wait for all workers in the batch to become healthy
            click.echo(f" Waiting for {len(batch)} worker(s) to become healthy...")
            unhealthy = _wait_for_workers_healthy(client, set(batch), timeout)
            if unhealthy:
                click.echo(
                    f" ABORT: workers did not become healthy within {timeout} s: " f"{', '.join(sorted(unhealthy))} ",
                    err=True,
                )
                _print_summary(succeeded, total - succeeded, offset)
                raise SystemExit(1)

            click.echo(f" All {len(batch)} worker(s) healthy. Observing for {observation_window} s...")
            time.sleep(observation_window)

            # Re-check health after observation window
            failed_workers = _check_worker_health(client, set(batch))
            if failed_workers:
                click.echo(
                    f" ABORT: workers developed failures during observation: "
                    f"{', '.join(f'{wid} ({msg} )' for wid, msg in sorted(failed_workers))} ",
                    err=True,
                )
                _print_summary(succeeded, total - succeeded, offset)
                raise SystemExit(1)

            succeeded += len(batch)
            offset += len(batch)
            click.echo(f" Batch OK ({succeeded} /{total} complete)")

            # Double batch size for next round, capped at max_batch (>= 1, so
            # every iteration makes progress and the loop terminates).
            batch_size = min(batch_size * 2, max_batch)

        click.echo(f"\n Done: {succeeded} /{total} workers restarted successfully")
872+
873+
def _wait_for_workers_healthy(client, worker_ids: set[str], timeout: int) -> set[str]:
    """Wait for every listed worker to report healthy.

    Repeatedly lists workers through ``client`` and ticks off each requested
    ID the first time it is seen with ``healthy`` set. Polling stops once all
    IDs are ticked off or ``timeout`` seconds have been handed to the backoff
    helper as its limit.

    Returns the IDs that were never observed healthy (empty set on success).
    """
    pending = set(worker_ids)  # copy so the caller's set is left untouched

    def _poll_once() -> bool:
        try:
            listing = client.list_workers(controller_pb2.Controller.ListWorkersRequest())
            for worker in listing.workers:
                if worker.healthy:
                    # discard() is a no-op for workers outside this batch.
                    pending.discard(worker.worker_id)
        except Exception:
            # Best effort: a failed poll just means we try again next round.
            pass
        return not pending

    # initial == maximum with no jitter -- presumably a fixed 5s poll interval;
    # confirm against ExponentialBackoff.
    ExponentialBackoff(initial=5.0, maximum=5.0, jitter=0.0).wait_until(
        _poll_once,
        timeout=Duration.from_seconds(timeout),
    )
    return pending
891+
892+
def _check_worker_health(client, worker_ids: set[str]) -> list[tuple[str, str]]:
    """Take a single health snapshot of the given workers.

    Returns one ``(worker_id, problem)`` pair per worker that is missing from
    the listing or not marked healthy. If the listing call itself raises, the
    result carries a single ``("(rpc)", <error text>)`` entry instead of
    propagating the exception.
    """
    problems: list[tuple[str, str]] = []
    try:
        listing = client.list_workers(controller_pb2.Controller.ListWorkersRequest())
        known = {entry.worker_id: entry for entry in listing.workers}
        for wid in worker_ids:
            if wid not in known:
                problems.append((wid, "disappeared"))
                continue
            entry = known[wid]
            if entry.healthy:
                continue
            # Prefer the worker's own status text; fall back to its failure count.
            reason = entry.status_message or f"{entry.consecutive_failures} consecutive failures"
            problems.append((wid, reason))
    except Exception as exc:
        # Report RPC trouble as a failure so the caller aborts the rollout.
        problems.append(("(rpc)", str(exc)))
    return problems
908+
829909
830- click .echo (f"\n Done: { succeeded } succeeded, { failed } failed" )
def _print_summary(succeeded: int, remaining: int, offset: int):
    """Echo a one-line abort summary.

    ``offset`` is a zero-based index; it is reported one-based as the worker
    position at which the rollout stopped.
    """
    abort_position = offset + 1
    click.echo(f"\n Summary: {succeeded} succeeded, {remaining} remaining (aborted at worker {abort_position} )")
0 commit comments