Skip to content

Commit a4cfe27

Browse files
committed
[iris] Adaptive rolling worker restart with observation window
Replace the simple one-at-a-time worker-restart with progressive batch sizing (1, 2, 4, ... up to --max-batch). Each batch waits for workers to become healthy, then observes for --observation-window seconds checking for heartbeat failures before advancing. Aborts immediately on any failure. --worker-id is now repeatable to target specific workers. Also removes the unrecognized disable-project-excludes-heuristics key from pyrefly config.
1 parent 37cc8a3 commit a4cfe27

3 files changed

Lines changed: 139 additions & 60 deletions

File tree

lib/iris/src/iris/cli/cluster.py

Lines changed: 131 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -761,70 +761,151 @@ def controller_restart(ctx, skip_checkpoint: bool, checkpoint_timeout: int):
761761

762762

763763
@controller.command("worker-restart")
@click.option("--worker-id", multiple=True, help="Worker(s) to restart (repeatable; default: all)")
@click.option("--timeout", type=int, default=120, help="Max seconds to wait per worker to become healthy")
@click.option("--max-batch", type=int, default=64, help="Maximum workers to restart concurrently")
@click.option(
    "--observation-window",
    type=int,
    default=60,
    help="Seconds to observe restarted workers for failures before advancing",
)
@click.pass_context
def worker_restart(
    ctx,
    worker_id: tuple[str, ...],
    timeout: int,
    max_batch: int,
    observation_window: int,
):
    """Rolling restart of workers with adaptive batch sizing.

    Restarts workers in progressively larger batches (1, 2, 4, ... up to
    --max-batch). After each batch, waits for workers to become healthy, then
    observes them for --observation-window seconds to catch post-restart
    failures. Aborts immediately if any worker fails to come back healthy or
    develops failures during observation.

    Running Docker containers are preserved and adopted by the new worker
    process, so tasks are not disrupted.
    """
    # Validate option values up front. max_batch < 1 would otherwise cause an
    # infinite loop below: after the first batch, batch_size = min(2, max_batch)
    # collapses to 0, the batch slice is empty, and offset never advances.
    if max_batch < 1:
        click.echo("--max-batch must be >= 1", err=True)
        raise SystemExit(1)
    # A negative window would raise an unhandled ValueError from time.sleep().
    if observation_window < 0:
        click.echo("--observation-window must be >= 0", err=True)
        raise SystemExit(1)

    controller_url = require_controller_url(ctx)

    with rpc_client(controller_url) as client:
        workers_resp = client.list_workers(controller_pb2.Controller.ListWorkersRequest())
        all_workers = workers_resp.workers

        if worker_id:
            # Filter to the explicitly requested workers; fail fast on unknown IDs
            # rather than silently restarting a subset.
            requested = set(worker_id)
            workers = [w for w in all_workers if w.worker_id in requested]
            missing = requested - {w.worker_id for w in workers}
            if missing:
                click.echo(f"Workers not found: {', '.join(sorted(missing))}", err=True)
                raise SystemExit(1)
        else:
            workers = list(all_workers)

        if not workers:
            click.echo("No workers to restart")
            return

        worker_ids = [w.worker_id for w in workers]
        total = len(worker_ids)
        click.echo(
            f"Restarting {total} worker(s) "
            f"(timeout={timeout}s, observation={observation_window}s, max_batch={max_batch})"
        )

        succeeded = 0
        batch_size = 1
        offset = 0

        while offset < total:
            batch = worker_ids[offset : offset + batch_size]
            click.echo(f"\n--- Batch of {len(batch)} (workers {offset + 1}-{offset + len(batch)} of {total}) ---")

            # Issue restart RPCs for the batch; any rejection aborts the rollout.
            for wid in batch:
                click.echo(f"  Restarting {wid}...")
                resp = client.restart_worker(
                    controller_pb2.Controller.RestartWorkerRequest(worker_id=wid),
                    timeout_ms=timeout * 1000,
                )
                if not resp.accepted:
                    click.echo(f"  ABORT: restart rejected for {wid}: {resp.error}", err=True)
                    _print_summary(succeeded, total - succeeded, offset)
                    raise SystemExit(1)

            # Wait for all workers in the batch to re-register as healthy.
            click.echo(f"  Waiting for {len(batch)} worker(s) to become healthy...")
            unhealthy = _wait_for_workers_healthy(client, set(batch), timeout)
            if unhealthy:
                click.echo(
                    f"  ABORT: workers did not become healthy within {timeout}s: " f"{', '.join(sorted(unhealthy))}",
                    err=True,
                )
                _print_summary(succeeded, total - succeeded, offset)
                raise SystemExit(1)

            click.echo(f"  All {len(batch)} worker(s) healthy. Observing for {observation_window}s...")
            time.sleep(observation_window)

            # Re-check health after the observation window to catch workers that
            # came up healthy but then started failing heartbeats.
            failed_workers = _check_worker_health(client, set(batch))
            if failed_workers:
                click.echo(
                    f"  ABORT: workers developed failures during observation: "
                    f"{', '.join(f'{wid} ({msg})' for wid, msg in sorted(failed_workers))}",
                    err=True,
                )
                _print_summary(succeeded, total - succeeded, offset)
                raise SystemExit(1)

            succeeded += len(batch)
            offset += len(batch)
            click.echo(f"  Batch OK ({succeeded}/{total} complete)")

            # Double batch size for the next round, capped at max_batch.
            batch_size = min(batch_size * 2, max_batch)

        click.echo(f"\nDone: {succeeded}/{total} workers restarted successfully")
874+
def _wait_for_workers_healthy(client, worker_ids: set[str], timeout: int) -> set[str]:
    """Block until every worker in *worker_ids* reports healthy or *timeout* expires.

    Polls the controller on a fixed 5-second cadence. Returns the subset of
    IDs that never became healthy (empty set on full success).
    """
    pending = set(worker_ids)

    def _poll() -> bool:
        # Best-effort probe: a transient RPC failure just means we try again
        # on the next tick rather than aborting the wait.
        try:
            listing = client.list_workers(controller_pb2.Controller.ListWorkersRequest())
            healthy_now = {w.worker_id for w in listing.workers if w.healthy}
            pending.difference_update(healthy_now)
        except Exception:
            pass
        return not pending

    ExponentialBackoff(initial=5.0, maximum=5.0, jitter=0.0).wait_until(
        _poll, timeout=Duration.from_seconds(timeout)
    )
    return pending
893+
def _check_worker_health(client, worker_ids: set[str]) -> list[tuple[str, str]]:
894+
"""Check that all workers are still healthy. Returns list of (worker_id, problem) for failures."""
895+
failures: list[tuple[str, str]] = []
896+
try:
897+
resp = client.list_workers(controller_pb2.Controller.ListWorkersRequest())
898+
by_id = {w.worker_id: w for w in resp.workers}
899+
for wid in worker_ids:
900+
w = by_id.get(wid)
901+
if w is None:
902+
failures.append((wid, "disappeared"))
903+
elif not w.healthy:
904+
failures.append((wid, w.status_message or f"{w.consecutive_failures} consecutive failures"))
905+
except Exception as e:
906+
failures.append(("(rpc)", str(e)))
907+
return failures
908+
829909

830-
click.echo(f"\nDone: {succeeded} succeeded, {failed} failed")
910+
def _print_summary(succeeded: int, remaining: int, offset: int):
    """Emit the final progress line when a rolling restart aborts early."""
    message = f"\nSummary: {succeeded} succeeded, {remaining} remaining (aborted at worker {offset + 1})"
    click.echo(message)

pyproject.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,6 @@ disable-search-path-heuristics = true
127127
# which would auto-detect workspace members as site packages and exclude them
128128
skip-interpreter-query = true
129129
use-ignore-files = false
130-
# Worktrees often live under hidden parent directories (e.g. `.codex/worktrees/*`).
131-
# Disable pyrefly's built-in exclude heuristics so hidden ancestors do not exclude
132-
# the entire project when running from those paths.
133-
disable-project-excludes-heuristics = true
134130

135131
# Exclude non-production code from type checking
136132
project-excludes = [

uv.lock

Lines changed: 8 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)