Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion lib/iris/src/iris/cluster/providers/gcp/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@
# gRPC status code 8 (RESOURCE_EXHAUSTED); used to recognize quota errors
# in TPU API responses — presumably mapped to QuotaExhaustedError by callers
# (TODO confirm against the error-handling path, not visible here).
_RPC_CODE_RESOURCE_EXHAUSTED = 8


def _recommended_tpu_operation_timeout(accelerator_type: str) -> float:
    """Return a long-running-operation timeout scaled to the TPU topology.

    Larger pods take longer to provision, so the LRO deadline grows with the
    slice's VM count. Unknown accelerator types fall back to the module-wide
    default timeout.
    """
    matched = None
    for candidate in TPU_TOPOLOGIES:
        if candidate.name == accelerator_type:
            matched = candidate
            break
    if matched is not None:
        if matched.vm_count >= 256:
            return 1800.0  # 30 minutes for the largest pods
        if matched.vm_count >= 64:
            return 900.0  # 15 minutes for mid-size pods
    return _OPERATION_TIMEOUT


# ============================================================================
# Data types
# ============================================================================
Expand Down Expand Up @@ -566,7 +578,7 @@ def tpu_create(self, request: TpuCreateRequest) -> TpuInfo:
data = resp.json()
op_name = data.get("name", "")
if op_name and "/operations/" in op_name:
self._wait_tpu_operation(op_name)
self._wait_tpu_operation(op_name, timeout=_recommended_tpu_operation_timeout(request.accelerator_type))

tpu_data = self._tpu_get(request.name, request.zone)
return _parse_tpu_info(tpu_data, request.zone)
Expand Down
103 changes: 98 additions & 5 deletions lib/iris/src/iris/cluster/providers/gcp/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
generate_slice_suffix,
)
from iris.cluster.service_mode import ServiceMode
from iris.cluster.types import get_tpu_topology
from iris.cluster.worker.env_probe import construct_worker_id
from iris.cluster.providers.remote_exec import GceRemoteExec
from iris.rpc import config_pb2
Expand Down Expand Up @@ -85,8 +86,59 @@ def _run():
# pd-ssd provides ~6000 IOPS vs ~38 on pd-standard, critical for controller DB
DEFAULT_BOOT_DISK_TYPE = "pd-ssd"
# Baseline wait for a TPU slice to reach cloud READY state; scaled up for
# large pods by _recommended_tpu_cloud_ready_timeout.
DEFAULT_TPU_CLOUD_READY_TIMEOUT = 600.0
# Baseline wait for all workers to pass health probes; scaled up for large
# pods by _recommended_tpu_bootstrap_timeout.
DEFAULT_TPU_BOOTSTRAP_TIMEOUT = 600.0
# Reserved (queued-resource) TPUs can sit in QUEUED for hours before capacity
# is assigned (4h) and then provisioned (2h).
RESERVED_TPU_ASSIGN_TIMEOUT = 4 * 60 * 60.0
RESERVED_TPU_PROVISION_TIMEOUT = 2 * 60 * 60.0
# How often (seconds) to emit a "N/M workers healthy" progress line while
# polling worker health endpoints.
TPU_BOOTSTRAP_PROGRESS_LOG_INTERVAL = 60.0
# Max number of missing workers listed individually in log summaries before
# collapsing the rest into a "+N more" suffix.
TPU_BOOTSTRAP_DIAGNOSTIC_LIMIT = 8


def _recommended_tpu_cloud_ready_timeout(worker_count: int) -> float:
"""Return a cloud READY timeout sized for TPU pod startup."""
if worker_count >= 256:
return 1800.0
if worker_count >= 64:
return 900.0
return DEFAULT_TPU_CLOUD_READY_TIMEOUT


def _recommended_tpu_bootstrap_timeout(worker_count: int) -> float:
"""Return a worker health timeout sized for TPU pod bootstrap."""
if worker_count >= 256:
return 1800.0
if worker_count >= 64:
return 900.0
return DEFAULT_TPU_BOOTSTRAP_TIMEOUT


def _format_probe_error(error: BaseException) -> str:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is just str(error) ?

message = str(error).strip()
if message:
return f"{type(error).__name__}: {message}"
return type(error).__name__


def _summarize_missing_workers(
    missing_workers: list[str],
    last_probe_errors: dict[str, str],
    limit: int = TPU_BOOTSTRAP_DIAGNOSTIC_LIMIT,
) -> str:
    """Build a compact, log-friendly summary of workers that never got healthy.

    The first ``limit`` workers are listed individually, each annotated with
    its last recorded probe error (when present); any overflow is collapsed
    into a single ``... +N more`` suffix. Returns ``"none"`` when nothing is
    missing.
    """
    if not missing_workers:
        return "none"

    def _render(worker_id: str) -> str:
        error_text = last_probe_errors.get(worker_id)
        return f"{worker_id} ({error_text})" if error_text else worker_id

    parts = [_render(worker_id) for worker_id in missing_workers[:limit]]
    overflow = len(missing_workers) - limit
    if overflow > 0:
        parts.append(f"... +{overflow} more")
    return ", ".join(parts)


def _wait_for_queued_resource_activation(
Expand Down Expand Up @@ -734,7 +786,7 @@ def _run_tpu_bootstrap(
worker_config: config_pb2.WorkerConfig,
poll_interval: float = 10.0,
cloud_ready_timeout: float | None = None,
bootstrap_timeout: float = 600.0,
bootstrap_timeout: float | None = None,
queued_resource_poll_interval: float = 60.0,
) -> None:
"""Monitor TPU startup-script bootstrap via health endpoint polling.
Expand All @@ -744,9 +796,29 @@ def _run_tpu_bootstrap(
Phase 2: Poll worker health endpoints until all respond healthy.
On timeout: query Cloud Logging for [iris-init] entries for diagnostics.
"""
try:
worker_count = get_tpu_topology(handle._accelerator_variant).vm_count
except ValueError as e:
raise InfraError(
f"Unknown TPU topology '{handle._accelerator_variant}' for slice {handle.slice_id}. "
"Cannot size bootstrap timeouts."
) from e

effective_cloud_ready_timeout = cloud_ready_timeout
if effective_cloud_ready_timeout is None:
effective_cloud_ready_timeout = DEFAULT_TPU_CLOUD_READY_TIMEOUT
effective_cloud_ready_timeout = _recommended_tpu_cloud_ready_timeout(worker_count)

effective_bootstrap_timeout = bootstrap_timeout
if effective_bootstrap_timeout is None:
effective_bootstrap_timeout = _recommended_tpu_bootstrap_timeout(worker_count)

logger.info(
"Using TPU bootstrap timeouts for %s: cloud_ready_timeout=%ss bootstrap_timeout=%ss worker_count=%d",
handle.slice_id,
effective_cloud_ready_timeout,
effective_bootstrap_timeout,
worker_count,
)

# Phase 0: If this is a queued resource (reserved TPU), wait for ACTIVE
# before polling the TPU VM state. The queued resource may sit in QUEUED
Expand Down Expand Up @@ -779,7 +851,9 @@ def _run_tpu_bootstrap(
workers = cloud_status.workers
worker_addrs = [(w.worker_id, w.internal_address) for w in workers]
healthy_workers: set[str] = set()
health_deadline = Deadline.from_now(Duration.from_seconds(bootstrap_timeout))
last_probe_errors: dict[str, str] = {}
health_deadline = Deadline.from_now(Duration.from_seconds(effective_bootstrap_timeout))
next_progress_log = time.monotonic() + TPU_BOOTSTRAP_PROGRESS_LOG_INTERVAL

logger.info(
"Polling health endpoints for %d workers in slice %s",
Expand All @@ -798,14 +872,33 @@ def _run_tpu_bootstrap(
)
if resp.status == 200:
healthy_workers.add(worker_id)
last_probe_errors.pop(worker_id, None)
logger.info("Worker %s is healthy", worker_id)
except (urllib.error.URLError, urllib.error.HTTPError, OSError, TimeoutError):
pass
except (urllib.error.URLError, urllib.error.HTTPError, OSError, TimeoutError) as e:
last_probe_errors[worker_id] = _format_probe_error(e)

if len(healthy_workers) == len(worker_addrs):
break
if time.monotonic() >= next_progress_log:
missing_workers = [worker_id for worker_id, _addr in worker_addrs if worker_id not in healthy_workers]
logger.info(
"TPU bootstrap progress for %s: %d/%d workers healthy; missing=%s",
handle.slice_id,
len(healthy_workers),
len(worker_addrs),
_summarize_missing_workers(missing_workers, last_probe_errors),
)
next_progress_log = time.monotonic() + TPU_BOOTSTRAP_PROGRESS_LOG_INTERVAL
time.sleep(poll_interval)
else:
missing_workers = [worker_id for worker_id, _addr in worker_addrs if worker_id not in healthy_workers]
logger.error(
"TPU bootstrap stalled for %s: %d/%d workers healthy; missing=%s",
handle.slice_id,
len(healthy_workers),
len(worker_addrs),
_summarize_missing_workers(missing_workers, last_probe_errors),
)
_fetch_bootstrap_logs(gcp_service, handle)
raise InfraError(
f"TPU slice {handle.slice_id} bootstrap timed out: "
Expand Down
6 changes: 5 additions & 1 deletion lib/iris/tests/cluster/providers/gcp/test_gcp_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
import pytest

from iris.cluster.providers.gcp.fake import InMemoryGcpService
from iris.cluster.providers.gcp.service import CloudGcpService, TpuCreateRequest, VmCreateRequest
from iris.cluster.providers.gcp.service import (
CloudGcpService,
TpuCreateRequest,
VmCreateRequest,
)
from iris.rpc import config_pb2
from iris.cluster.providers.types import InfraError, QuotaExhaustedError, ResourceNotFoundError
from iris.cluster.service_mode import ServiceMode
Expand Down
17 changes: 15 additions & 2 deletions lib/iris/tests/cluster/providers/gcp/test_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from iris.cluster.providers.gcp.workers import (
GcpWorkerProvider,
_run_vm_slice_bootstrap,
_summarize_missing_workers,
_validate_slice_config,
)
from iris.cluster.providers.manual.provider import ManualControllerProvider, ManualWorkerProvider
Expand Down Expand Up @@ -930,12 +931,24 @@ def test_gcp_tpu_slice_os_login_prefers_external_ip_for_direct_ssh():


# =============================================================================
# Section 6: VM Slice Bootstrap Tests
# Section 6: TPU/VM Bootstrap Tests
#
# Tests for _run_vm_slice_bootstrap with split timeouts and health probing.
# Tests for bootstrap timeout sizing, diagnostics, and VM health probing.
# =============================================================================


def test_summarize_missing_workers_includes_probe_errors():
    """Missing workers appear with their probe errors; overflow is collapsed."""
    missing = ["worker-250", "worker-251", "worker-252"]
    probe_errors = {
        "worker-250": "TimeoutError: timed out",
        "worker-252": "URLError: refused",
    }

    summary = _summarize_missing_workers(missing, probe_errors, limit=2)

    for fragment in (
        "worker-250 (TimeoutError: timed out)",
        "worker-251",
        "+1 more",
    ):
        assert fragment in summary


def _make_vm_slice_for_bootstrap(
gcp_service: InMemoryGcpService,
zone: str = "us-central2-b",
Expand Down
Loading