Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 127 additions & 4 deletions scripts/prod/common_lib.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
#!/usr/bin/env python3

import argparse
import concurrent.futures
import subprocess
import sys
import threading
import time
from enum import Enum
from typing import Optional
from typing import Callable, Optional, TypeVar


class Colors(Enum):
Expand All @@ -17,15 +20,135 @@ class Colors(Enum):
RESET = "\033[0m"


def print_colored(message: str, color: Colors = Colors.RESET, file=sys.stdout) -> None:
"""Print message with color"""
print(f"{color.value}{message}{Colors.RESET.value}", file=file)
# Thread-local output sink. When `run_in_parallel` runs a worker, it sets `buffer` on this so that
# the worker's log lines are captured per-thread instead of interleaving on the shared stdout/stderr.
# When no buffer is set (the common, single-threaded case), logging prints immediately as before.
_output_sink = threading.local()

# Serializes writes to the real stdout/stderr so buffered blocks and heartbeats don't interleave.
_print_lock = threading.Lock()


def print_colored(message: str, color: Colors = Colors.RESET, file=None) -> None:
"""Print message with color.

`file` is resolved to the current `sys.stdout` when None (resolving at call time rather than
binding a default at definition time, so output redirection is honored).

If the current thread has a buffer set on `_output_sink` (i.e. it is a `run_in_parallel`
worker), the formatted line is appended to that buffer instead of being printed, so it can be
flushed as one grouped block when the worker finishes.
"""
if file is None:
file = sys.stdout
formatted = f"{color.value}{message}{Colors.RESET.value}"
buffer = getattr(_output_sink, "buffer", None)
if buffer is not None:
buffer.append((formatted, file))
else:
print(formatted, file=file)


def print_error(message: str) -> None:
print_colored(message, color=Colors.RED, file=sys.stderr)


T = TypeVar("T")
R = TypeVar("R")


def run_in_parallel(
items: list[T],
worker: Callable[[T], R],
max_parallelism: int,
label: Callable[[T], str],
heartbeat_interval_seconds: int = 5,
) -> list[R]:
"""Run `worker(item)` for each item concurrently, capped at `max_parallelism` threads.

Threads (not processes) are used because the work is I/O-bound (kubectl/urllib calls that
release the GIL).

Output: each worker's log lines (emitted via `print_colored`/`print_error`) are buffered and
flushed as one block, prefixed with `label(item)`, when that item finishes — so concurrent
output stays readable. While items are still running, a heartbeat naming the not-yet-done items
is printed every `heartbeat_interval_seconds`.

Errors: a worker that raises (or calls `sys.exit()`, which raises `SystemExit`) is recorded as
a failure for its item; remaining items still run, and once all have settled a summary is
printed and the process exits with code 1. `KeyboardInterrupt` is not treated as an item
failure — it propagates so Ctrl-C aborts the whole run.

Returns the per-item results in the same order as `items`.
"""
if not items:
return []

num_items = len(items)
results: list[Optional[R]] = [None] * num_items
errors: dict[int, BaseException] = {}

def run_one(item: T) -> R:
buffer: list[tuple[str, object]] = []
_output_sink.buffer = buffer
try:
return worker(item)
finally:
# Stop capturing before flushing so the header itself prints to the real stdout.
_output_sink.buffer = None
with _print_lock:
print_colored(f"===== {label(item)} =====", Colors.BLUE)
for text, file in buffer:
print(text, file=file)

with concurrent.futures.ThreadPoolExecutor(
max_workers=min(max_parallelism, num_items)
) as executor:
future_to_index = {
executor.submit(run_one, item): index for index, item in enumerate(items)
}
pending_futures = set(future_to_index.keys())
last_heartbeat = time.monotonic()

while pending_futures:
done_futures, pending_futures = concurrent.futures.wait(
pending_futures,
timeout=heartbeat_interval_seconds,
return_when=concurrent.futures.FIRST_COMPLETED,
)
for future in done_futures:
index = future_to_index[future]
try:
results[index] = future.result()
except KeyboardInterrupt:
# Ctrl-C is not an item failure; let it abort the whole run.
raise
Comment thread
matanl-starkware marked this conversation as resolved.
except BaseException as error:
errors[index] = error
Comment thread
cursor[bot] marked this conversation as resolved.

now = time.monotonic()
if pending_futures and now - last_heartbeat >= heartbeat_interval_seconds:
running_labels = ", ".join(
label(items[future_to_index[future]]) for future in pending_futures
)
num_done = num_items - len(pending_futures)
with _print_lock:
print_colored(
f"[{num_done}/{num_items} done] still waiting on: {running_labels}",
Colors.YELLOW,
)
last_heartbeat = now
Comment thread
matanl-starkware marked this conversation as resolved.

if errors:
with _print_lock:
print_error(f"{len(errors)} of {num_items} parallel operation(s) failed:")
for index in sorted(errors):
print_error(f" - {label(items[index])}: {errors[index]}")
sys.exit(1)

return results


class RestartStrategy(Enum):
"""Strategy for restarting nodes."""

Expand Down
80 changes: 80 additions & 0 deletions scripts/prod/test_run_in_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#!/usr/bin/env python3
"""Unit tests for `run_in_parallel` in common_lib (no kubectl / cluster needed)."""

import time

import pytest
from common_lib import print_colored, run_in_parallel


def _label(item) -> str:
return f"node-{item}"


def test_results_match_input_order_regardless_of_completion_order():
# Earlier items sleep longer, so they finish last — results must still be in input order.
def worker(item: int) -> int:
time.sleep((5 - item) * 0.02)
return item * 10

results = run_in_parallel([0, 1, 2, 3, 4], worker, max_parallelism=4, label=_label)
assert results == [0, 10, 20, 30, 40]


def test_empty_items_returns_empty_list():
calls = []
results = run_in_parallel([], lambda item: calls.append(item), max_parallelism=4, label=_label)
assert results == []
assert calls == []


def test_worker_output_is_buffered_grouped_and_labeled(capsys):
def worker(item: int) -> None:
print_colored(f"line-a from {item}")
print_colored(f"line-b from {item}")

run_in_parallel([0, 1], worker, max_parallelism=2, label=_label)
out = capsys.readouterr().out

# Each node's lines are flushed contiguously after its own header (no interleaving between
# nodes), even though both ran concurrently.
for item in (0, 1):
header_pos = out.index(f"node-{item}")
line_a_pos = out.index(f"line-a from {item}")
line_b_pos = out.index(f"line-b from {item}")
assert header_pos < line_a_pos < line_b_pos
# Nothing from the other node appears between this node's two lines.
other = 1 - item
assert f"from {other}" not in out[line_a_pos:line_b_pos]


def test_heartbeat_lists_still_running_items(capsys):
# One slow item keeps the pool busy long enough for at least one heartbeat (interval 1s).
def worker(item: int) -> int:
if item == 0:
time.sleep(2.5)
return item

run_in_parallel([0, 1], worker, max_parallelism=2, label=_label, heartbeat_interval_seconds=1)
out = capsys.readouterr().out
assert "still waiting on: node-0" in out
assert "done]" in out


def test_failing_worker_is_reported_and_exits_nonzero(capsys):
def worker(item: int) -> int:
if item == 1:
raise ValueError("boom from 1")
return item

with pytest.raises(SystemExit) as exit_info:
run_in_parallel([0, 1, 2], worker, max_parallelism=3, label=_label)

assert exit_info.value.code == 1
err = capsys.readouterr().err
assert "1 of 3 parallel operation(s) failed" in err
assert "node-1: boom from 1" in err


if __name__ == "__main__":
raise SystemExit(pytest.main([__file__, "-v"]))
Loading