Skip to content

Commit 8f348f3

Browse files
scripts: restart and wait on nodes in parallel
For the non-interactive ALL_AT_ONCE strategy, restart every node's pod and then run the post-restart health/metric waits concurrently instead of node-by-node, which was the main source of slow rollouts. Add ServiceRestarter.restart_all driving two parallel phases (restart, then wait) via run_in_parallel, gated by a new --max-parallelism flag (default 16). ONE_BY_ONE and NO_RESTART stay sequential since they prompt the user between nodes. Thread the flag through all four entry scripts. Also fix _wait_for_pod_to_satisfy_condition to return True on success (it returned None, so callers always logged a spurious failure). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 0270f55 commit 8f348f3

7 files changed

Lines changed: 248 additions & 19 deletions

scripts/prod/restarter_lib.py

Lines changed: 87 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/env python3
22

3+
import signal
34
import sys
45
from abc import ABC, abstractmethod
56
from time import sleep
@@ -13,10 +14,11 @@
1314
get_namespace_args,
1415
print_colored,
1516
print_error,
17+
run_in_parallel,
1618
run_kubectl_command,
1719
wait_until_y_or_n,
1820
)
19-
from metrics_lib import MetricConditionGater
21+
from metrics_lib import MetricConditionGater, terminate_all_port_forwards
2022

2123

2224
def _get_pod_names(
@@ -67,7 +69,11 @@ def _restart_pod(
6769
kubectl_args.extend(get_namespace_args(namespace, cluster))
6870

6971
try:
70-
run_kubectl_command(kubectl_args, capture_output=False)
72+
# Capture (rather than stream) so output stays grouped per node when restarts run
73+
# in parallel; echo it through print_colored which honors the per-node buffer.
74+
result = run_kubectl_command(kubectl_args, capture_output=True)
75+
if result.stdout:
76+
print_colored(result.stdout.rstrip())
7177
print_colored(f"Restarted {pod} for node {index}")
7278
except Exception as e:
7379
print_error(f"Failed restarting {pod} for node {index}: {e}")
@@ -77,6 +83,18 @@ def _restart_pod(
7783
def restart_service(self, instance_index: int) -> bool:
7884
"""Restart service for a specific instance. If returns False, the restart process should be aborted."""
7985

86+
def restart_all(self, max_parallelism: int) -> None:
87+
"""Restart all instances.
88+
89+
Default: sequential, one instance at a time, aborting if any `restart_service` returns
90+
False. Subclasses that have no inter-node ordering dependency override this to run in
91+
parallel. `max_parallelism` is ignored by this sequential default.
92+
"""
93+
for instance_index in range(self.namespace_and_instruction_args.size()):
94+
if not self.restart_service(instance_index):
95+
print_colored("\nAborting restart process.")
96+
sys.exit(1)
97+
8098
# from_restart_strategy is a static method that returns the appropriate ServiceRestarter based on the restart strategy.
8199
@staticmethod
82100
def from_restart_strategy(
@@ -96,13 +114,15 @@ def from_restart_strategy(
96114
service,
97115
check_between_restarts,
98116
RestartPodOnlyRestarter(namespace_and_instruction_args, service),
117+
parallel=False,
99118
)
100119
elif restart_strategy == RestartStrategy.ALL_AT_ONCE:
101120
return ChecksBetweenRestartsCompositeRestarter(
102121
namespace_and_instruction_args,
103122
service,
104123
lambda instance_index: True,
105124
RestartPodOnlyRestarter(namespace_and_instruction_args, service),
125+
parallel=True,
106126
)
107127
elif restart_strategy == RestartStrategy.NO_RESTART:
108128
assert (
@@ -142,10 +162,18 @@ def __init__(
142162
service: Service,
143163
check_between_restarts: Callable[[int], bool],
144164
base_service_restarter: ServiceRestarter,
165+
parallel: bool = False,
145166
):
146167
super().__init__(namespace_and_instruction_args, service)
147168
self.check_between_restarts = check_between_restarts
148169
self.base_service_restarter = base_service_restarter
170+
# When True there is no inter-node ordering dependency (e.g. ALL_AT_ONCE), so restart_all
171+
# restarts every node and then runs the post-restart checks concurrently. When False
172+
# (interactive ONE_BY_ONE / NO_RESTART) restart_all stays sequential.
173+
self.parallel = parallel
174+
175+
def _label(self, instance_index: int) -> str:
176+
return self.namespace_and_instruction_args.get_namespace(instance_index)
149177

150178
def restart_service(self, instance_index: int) -> bool:
151179
"""Call the base restarter on each instance one by one, running the check_between_restarts in between each."""
@@ -156,6 +184,34 @@ def restart_service(self, instance_index: int) -> bool:
156184
print_colored(f"{instructions} ", Colors.YELLOW)
157185
return self.check_between_restarts(instance_index)
158186

187+
def restart_all(self, max_parallelism: int) -> None:
188+
if not self.parallel:
189+
super().restart_all(max_parallelism)
190+
return
191+
192+
indices = list(range(self.namespace_and_instruction_args.size()))
193+
194+
# Phase 1: restart every node's pod concurrently (pod deletes have no ordering dependency).
195+
print_colored(f"\nRestarting {len(indices)} node(s) in parallel...", Colors.YELLOW)
196+
run_in_parallel(
197+
indices,
198+
self.base_service_restarter.restart_service,
199+
max_parallelism,
200+
self._label,
201+
)
202+
203+
for instance_index in indices:
204+
instructions = self.namespace_and_instruction_args.get_instruction(instance_index)
205+
if instructions is not None:
206+
print_colored(f"[{self._label(instance_index)}] {instructions}", Colors.YELLOW)
207+
208+
# Phase 2: run post-restart checks (if any) concurrently.
209+
self._wait_all(indices, max_parallelism)
210+
211+
def _wait_all(self, indices: list[int], max_parallelism: int) -> None:
212+
"""Run post-restart checks for all nodes concurrently. No-op when there is nothing to wait
213+
for (overridden by restarters that gate on metrics)."""
214+
159215

160216
class NoOpServiceRestarter(ServiceRestarter):
161217
"""No-op service restarter."""
@@ -177,11 +233,16 @@ def __init__(
177233
):
178234
self.metrics = metrics
179235
self.metrics_port = metrics_port
236+
# ALL_AT_ONCE has no inter-node ordering dependency: restart every node, then wait for all
237+
# conditions concurrently. ONE_BY_ONE / NO_RESTART stay sequential (they prompt the user
238+
# between nodes).
239+
parallel = restart_strategy == RestartStrategy.ALL_AT_ONCE
180240
if restart_strategy == RestartStrategy.ONE_BY_ONE:
181241
check_function = self._check_between_each_restart
182242
base_restarter = RestartPodOnlyRestarter(namespace_and_instruction_args, service)
183243
elif restart_strategy == RestartStrategy.ALL_AT_ONCE:
184-
check_function = self._check_all_only_after_last_restart
244+
# check_function is unused in the parallel path (restart_all drives the phases directly).
245+
check_function = lambda instance_index: True
185246
base_restarter = RestartPodOnlyRestarter(namespace_and_instruction_args, service)
186247
elif restart_strategy == RestartStrategy.NO_RESTART:
187248
check_function = self._check_between_each_restart
@@ -190,7 +251,9 @@ def __init__(
190251
print_error(f"Invalid restart strategy: {restart_strategy} for WaitOnMetricRestarter.")
191252
sys.exit(1)
192253

193-
super().__init__(namespace_and_instruction_args, service, check_function, base_restarter)
254+
super().__init__(
255+
namespace_and_instruction_args, service, check_function, base_restarter, parallel
256+
)
194257

195258
def _check_between_each_restart(self, instance_index: int) -> bool:
196259
if not self._wait_for_pod_to_satisfy_condition(instance_index):
@@ -200,16 +263,21 @@ def _check_between_each_restart(self, instance_index: int) -> bool:
200263
return True
201264
return wait_until_y_or_n(f"Do you want to restart the next pod?")
202265

203-
def _check_all_only_after_last_restart(self, instance_index: int) -> bool:
204-
# Restart all nodes without waiting for confirmation.
205-
if instance_index < self.namespace_and_instruction_args.size() - 1:
206-
return True
266+
def _wait_all(self, indices: list[int], max_parallelism: int) -> None:
267+
# gate() starts a kubectl port-forward per node on a worker thread, which cannot install
268+
# signal handlers; install one here (main thread) so Ctrl-C tears all of them down.
269+
def signal_handler(signum, frame):
270+
terminate_all_port_forwards()
271+
sys.exit(0)
207272

208-
# After the last node has been restarted, wait for all pods to satisfy the condition.
209-
for instance_index in range(self.namespace_and_instruction_args.size()):
210-
if not self._wait_for_pod_to_satisfy_condition(instance_index):
211-
print_error(f"Failed waiting for condition(s) for Pod {instance_index}.")
212-
return True
273+
signal.signal(signal.SIGINT, signal_handler)
274+
signal.signal(signal.SIGTERM, signal_handler)
275+
276+
run_in_parallel(indices, self._wait_for_index, max_parallelism, self._label)
277+
278+
def _wait_for_index(self, instance_index: int) -> None:
279+
if not self._wait_for_pod_to_satisfy_condition(instance_index):
280+
print_error(f"Failed waiting for condition(s) for Pod {instance_index}.")
213281

214282
def _wait_for_pod_to_satisfy_condition(self, instance_index: int) -> bool:
215283
# The sleep is to prevent the case where we get the pod name of the old pod we just deleted
@@ -234,6 +302,7 @@ def _wait_for_pod_to_satisfy_condition(self, instance_index: int) -> bool:
234302
self.metrics_port,
235303
)
236304
metric_condition_gater.gate()
305+
return True
237306

238307
@staticmethod
239308
def _wait_for_pods_to_be_ready(
@@ -263,7 +332,11 @@ def _wait_for_pods_to_be_ready(
263332
f"{wait_timeout}s",
264333
]
265334
kubectl_args.extend(get_namespace_args(namespace, cluster))
266-
result = run_kubectl_command(kubectl_args, capture_output=False)
335+
# Capture (rather than stream) so output stays grouped per node under parallel
336+
# waits; progress is surfaced by run_in_parallel's heartbeat instead.
337+
result = run_kubectl_command(kubectl_args, capture_output=True)
338+
if result.stdout:
339+
print_colored(result.stdout.rstrip())
267340

268341
if result.returncode != 0:
269342
print_colored(

scripts/prod/set_node_revert_mode.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@ def set_revert_mode(
3939
restarter: ServiceRestarter,
4040
should_revert: bool,
4141
revert_up_to_block: int,
42+
max_parallelism: int,
4243
):
4344
config_overrides = {
4445
"revert_config.should_revert": should_revert,
@@ -49,6 +50,7 @@ def set_revert_mode(
4950
namespace_and_instruction_args,
5051
Service.Core,
5152
restarter,
53+
max_parallelism,
5254
)
5355

5456

@@ -57,6 +59,7 @@ def enable_revert_mode(
5759
context_list: Optional[list[str]],
5860
project_name: Optional[str],
5961
revert_up_to_block: int,
62+
max_parallelism: int,
6063
):
6164
print_colored(
6265
f"Enabling revert mode (reverting up to and including block {revert_up_to_block})",
@@ -93,12 +96,15 @@ def enable_revert_mode(
9396
8082,
9497
RestartStrategy.ALL_AT_ONCE,
9598
)
96-
set_revert_mode(namespace_and_instruction_args, restarter, True, revert_up_to_block)
99+
set_revert_mode(
100+
namespace_and_instruction_args, restarter, True, revert_up_to_block, max_parallelism
101+
)
97102

98103

99104
def disable_revert_mode(
100105
namespace_list: list[str],
101106
context_list: Optional[list[str]],
107+
max_parallelism: int,
102108
):
103109
print_colored("Disabling revert mode", Colors.YELLOW)
104110
namespace_and_instruction_args = NamespaceAndInstructionArgs(namespace_list, context_list)
@@ -111,6 +117,7 @@ def disable_revert_mode(
111117
False,
112118
# Set revert_up_to_block to max u64 to disable revert.
113119
2**64 - 1,
120+
max_parallelism,
114121
)
115122

116123

@@ -210,12 +217,14 @@ def main():
210217
context_list,
211218
args.project_name,
212219
revert_up_to_block,
220+
args.max_parallelism,
213221
)
214222

215223
if should_disable_revert:
216224
disable_revert_mode(
217225
namespace_list,
218226
context_list,
227+
args.max_parallelism,
219228
)
220229

221230

scripts/prod/take_nodes_out_of_observer_mode.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -133,6 +133,7 @@ def main():
133133
namespace_and_instruction_args,
134134
Service.Core,
135135
restarter,
136+
args.max_parallelism,
136137
)
137138

138139

0 commit comments

Comments
 (0)