11#!/usr/bin/env python3
22
3+ import signal
34import sys
45from abc import ABC , abstractmethod
56from time import sleep
1314 get_namespace_args ,
1415 print_colored ,
1516 print_error ,
17+ run_in_parallel ,
1618 run_kubectl_command ,
1719 wait_until_y_or_n ,
1820)
19- from metrics_lib import MetricConditionGater
21+ from metrics_lib import MetricConditionGater , terminate_all_port_forwards
2022
2123
2224def _get_pod_names (
@@ -67,7 +69,11 @@ def _restart_pod(
6769 kubectl_args .extend (get_namespace_args (namespace , cluster ))
6870
6971 try :
70- run_kubectl_command (kubectl_args , capture_output = False )
72+ # Capture (rather than stream) so output stays grouped per node when restarts run
73+ # in parallel; echo it through print_colored which honors the per-node buffer.
74+ result = run_kubectl_command (kubectl_args , capture_output = True )
75+ if result .stdout :
76+ print_colored (result .stdout .rstrip ())
7177 print_colored (f"Restarted { pod } for node { index } " )
7278 except Exception as e :
7379 print_error (f"Failed restarting { pod } for node { index } : { e } " )
@@ -77,6 +83,18 @@ def _restart_pod(
7783 def restart_service (self , instance_index : int ) -> bool :
7884 """Restart service for a specific instance. If returns False, the restart process should be aborted."""
7985
def restart_all(self, max_parallelism: int) -> None:
    """Restart every instance sequentially, stopping at the first refusal.

    Instances are visited in index order, from 0 up to
    `self.namespace_and_instruction_args.size() - 1`, and `restart_service` is called
    once per instance. As soon as one call returns a falsy value, an abort message is
    printed and the process exits with status 1.

    Args:
        max_parallelism: Accepted so that every restarter shares one signature; this
            sequential implementation never reads it.

    Raises:
        SystemExit: With code 1 when `restart_service` returns a falsy value.
    """
    instance_count = self.namespace_and_instruction_args.size()
    for idx in range(instance_count):
        if self.restart_service(idx):
            continue
        # A falsy result means "do not go on": stop before touching the next instance.
        print_colored("\nAborting restart process.")
        sys.exit(1)
97+
8098 # from_restart_strategy is a static method that returns the appropriate ServiceRestarter based on the restart strategy.
8199 @staticmethod
82100 def from_restart_strategy (
@@ -96,13 +114,15 @@ def from_restart_strategy(
96114 service ,
97115 check_between_restarts ,
98116 RestartPodOnlyRestarter (namespace_and_instruction_args , service ),
117+ parallel = False ,
99118 )
100119 elif restart_strategy == RestartStrategy .ALL_AT_ONCE :
101120 return ChecksBetweenRestartsCompositeRestarter (
102121 namespace_and_instruction_args ,
103122 service ,
104123 lambda instance_index : True ,
105124 RestartPodOnlyRestarter (namespace_and_instruction_args , service ),
125+ parallel = True ,
106126 )
107127 elif restart_strategy == RestartStrategy .NO_RESTART :
108128 assert (
@@ -142,10 +162,18 @@ def __init__(
142162 service : Service ,
143163 check_between_restarts : Callable [[int ], bool ],
144164 base_service_restarter : ServiceRestarter ,
165+ parallel : bool = False ,
145166 ):
146167 super ().__init__ (namespace_and_instruction_args , service )
147168 self .check_between_restarts = check_between_restarts
148169 self .base_service_restarter = base_service_restarter
170+ # When True there is no inter-node ordering dependency (e.g. ALL_AT_ONCE), so restart_all
171+ # restarts every node and then runs the post-restart checks concurrently. When False
172+ # (interactive ONE_BY_ONE / NO_RESTART) restart_all stays sequential.
173+ self .parallel = parallel
174+
175+ def _label (self , instance_index : int ) -> str :
176+ return self .namespace_and_instruction_args .get_namespace (instance_index )
149177
150178 def restart_service (self , instance_index : int ) -> bool :
151179 """Call the base restarter on each instance one by one, running the check_between_restarts in between each."""
@@ -156,6 +184,34 @@ def restart_service(self, instance_index: int) -> bool:
156184 print_colored (f"{ instructions } " , Colors .YELLOW )
157185 return self .check_between_restarts (instance_index )
158186
def restart_all(self, max_parallelism: int) -> None:
    """Restart all instances, in parallel when `self.parallel` is set.

    When `self.parallel` is falsy this defers to the parent class's `restart_all`.
    Otherwise it runs three steps:

    1. Pass every instance index to `run_in_parallel` together with the base
       restarter's `restart_service`, `max_parallelism` and `self._label`.
    2. Print each instance's instruction (when it has one), prefixed with its label.
    3. Call `self._wait_all` to run the post-restart checks.

    Args:
        max_parallelism: Forwarded to the parent implementation, to `run_in_parallel`
            and to `_wait_all`; presumably an upper bound on concurrent workers -
            confirm against `run_in_parallel`.
    """
    if not self.parallel:
        super().restart_all(max_parallelism)
        return

    args = self.namespace_and_instruction_args
    node_indices = list(range(args.size()))

    # Step 1: restart every node's pod concurrently (pod deletes have no ordering dependency).
    # NOTE(review): whatever run_in_parallel returns is discarded here, so a restart that
    # reports failure does not abort this path unless run_in_parallel itself does - confirm.
    print_colored(f"\nRestarting {len(node_indices)} node(s) in parallel...", Colors.YELLOW)
    restart_one = self.base_service_restarter.restart_service
    run_in_parallel(node_indices, restart_one, max_parallelism, self._label)

    # Step 2: surface per-node instructions, tagged with the node label.
    for idx in node_indices:
        if (instructions := args.get_instruction(idx)) is not None:
            print_colored(f"[{self._label(idx)}] {instructions}", Colors.YELLOW)

    # Step 3: run post-restart checks (if any) concurrently.
    self._wait_all(node_indices, max_parallelism)
210+
211+ def _wait_all (self , indices : list [int ], max_parallelism : int ) -> None :
212+ """Run post-restart checks for all nodes concurrently. No-op when there is nothing to wait
213+ for (overridden by restarters that gate on metrics)."""
214+
159215
160216class NoOpServiceRestarter (ServiceRestarter ):
161217 """No-op service restarter."""
@@ -177,11 +233,16 @@ def __init__(
177233 ):
178234 self .metrics = metrics
179235 self .metrics_port = metrics_port
236+ # ALL_AT_ONCE has no inter-node ordering dependency: restart every node, then wait for all
237+ # conditions concurrently. ONE_BY_ONE / NO_RESTART stay sequential (they prompt the user
238+ # between nodes).
239+ parallel = restart_strategy == RestartStrategy .ALL_AT_ONCE
180240 if restart_strategy == RestartStrategy .ONE_BY_ONE :
181241 check_function = self ._check_between_each_restart
182242 base_restarter = RestartPodOnlyRestarter (namespace_and_instruction_args , service )
183243 elif restart_strategy == RestartStrategy .ALL_AT_ONCE :
184- check_function = self ._check_all_only_after_last_restart
244+ # check_function is unused in the parallel path (restart_all drives the phases directly).
245+ check_function = lambda instance_index : True
185246 base_restarter = RestartPodOnlyRestarter (namespace_and_instruction_args , service )
186247 elif restart_strategy == RestartStrategy .NO_RESTART :
187248 check_function = self ._check_between_each_restart
@@ -190,7 +251,9 @@ def __init__(
190251 print_error (f"Invalid restart strategy: { restart_strategy } for WaitOnMetricRestarter." )
191252 sys .exit (1 )
192253
193- super ().__init__ (namespace_and_instruction_args , service , check_function , base_restarter )
254+ super ().__init__ (
255+ namespace_and_instruction_args , service , check_function , base_restarter , parallel
256+ )
194257
195258 def _check_between_each_restart (self , instance_index : int ) -> bool :
196259 if not self ._wait_for_pod_to_satisfy_condition (instance_index ):
@@ -200,16 +263,21 @@ def _check_between_each_restart(self, instance_index: int) -> bool:
200263 return True
201264 return wait_until_y_or_n (f"Do you want to restart the next pod?" )
202265
203- def _check_all_only_after_last_restart (self , instance_index : int ) -> bool :
204- # Restart all nodes without waiting for confirmation.
205- if instance_index < self .namespace_and_instruction_args .size () - 1 :
206- return True
def _wait_all(self, indices: list[int], max_parallelism: int) -> None:
    """Wait for every instance's condition(s) via `run_in_parallel`.

    For the duration of the wait, SIGINT and SIGTERM are routed to a handler that calls
    `terminate_all_port_forwards()` and then exits. The handlers that were installed
    before are put back afterwards, even if the wait raises or exits, so this method
    does not permanently change process-wide signal handling.

    Args:
        indices: Instance indices to wait for.
        max_parallelism: Forwarded to `run_in_parallel`.
    """

    # gate() starts a kubectl port-forward per node on a worker thread, which cannot install
    # signal handlers; install one here (main thread) so Ctrl-C tears all of them down.
    def signal_handler(signum, frame):
        terminate_all_port_forwards()
        # NOTE(review): exit status 0 is kept from the original; an interrupted run is
        # therefore reported as success - confirm that this is intended.
        sys.exit(0)

    previous_handlers = {
        sig: signal.signal(sig, signal_handler) for sig in (signal.SIGINT, signal.SIGTERM)
    }
    try:
        run_in_parallel(indices, self._wait_for_index, max_parallelism, self._label)
    finally:
        for sig, previous in previous_handlers.items():
            # signal.signal() reports None for a handler that was not installed from
            # Python; such a handler cannot be re-installed, so leave ours in place.
            if previous is not None:
                signal.signal(sig, previous)
277+
278+ def _wait_for_index (self , instance_index : int ) -> None :
279+ if not self ._wait_for_pod_to_satisfy_condition (instance_index ):
280+ print_error (f"Failed waiting for condition(s) for Pod { instance_index } ." )
213281
214282 def _wait_for_pod_to_satisfy_condition (self , instance_index : int ) -> bool :
215283 # The sleep is to prevent the case where we get the pod name of the old pod we just deleted
@@ -234,6 +302,7 @@ def _wait_for_pod_to_satisfy_condition(self, instance_index: int) -> bool:
234302 self .metrics_port ,
235303 )
236304 metric_condition_gater .gate ()
305+ return True
237306
238307 @staticmethod
239308 def _wait_for_pods_to_be_ready (
@@ -263,7 +332,11 @@ def _wait_for_pods_to_be_ready(
263332 f"{ wait_timeout } s" ,
264333 ]
265334 kubectl_args .extend (get_namespace_args (namespace , cluster ))
266- result = run_kubectl_command (kubectl_args , capture_output = False )
335+ # Capture (rather than stream) so output stays grouped per node under parallel
336+ # waits; progress is surfaced by run_in_parallel's heartbeat instead.
337+ result = run_kubectl_command (kubectl_args , capture_output = True )
338+ if result .stdout :
339+ print_colored (result .stdout .rstrip ())
267340
268341 if result .returncode != 0 :
269342 print_colored (
0 commit comments