Skip to content

Commit e55789b

Browse files
authored
[Iris] Handle reserved TPU queue timeouts explicitly (#4764)
Split reserved TPU bootstrap into queued-resource assignment, queued-resource provisioning, cloud readiness, and worker health phases. Increase reserved provisioning tolerance, and cancel queued resources immediately when bootstrap fails so abandoned reservations do not produce zombie workers.
1 parent 61599e3 commit e55789b

File tree

2 files changed

+70
-19
lines changed

2 files changed

+70
-19
lines changed

lib/iris/src/iris/cluster/providers/gcp/handles.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,11 @@ def terminate(self, *, wait: bool = False) -> None:
374374
logger.info("Terminating TPU (async): %s", self._slice_id)
375375
self._gcp_service.tpu_delete(self._slice_id, self._zone)
376376

377+
def cleanup_bootstrap_failure(self) -> None:
    """Release provider-side resources after a failed bootstrap.

    A queued-resource (reserved TPU) reservation would otherwise linger
    after bootstrap gives up, so cancel it via terminate(); non-queued
    slices have nothing to clean up here.
    """
    if not self.is_queued_resource:
        return
    self.terminate()
381+
377382

378383
class GcpVmSliceHandle:
379384
"""Handle to a single-VM GCE-backed slice."""
@@ -481,3 +486,6 @@ def _describe_cloud(self) -> SliceStatus:
481486
def terminate(self, *, wait: bool = False) -> None:
482487
logger.info("Terminating VM slice: %s (vm=%s)", self._slice_id, self._vm_name)
483488
self._gcp_service.vm_delete(self._vm_name, self._zone, wait=wait)
489+
490+
def cleanup_bootstrap_failure(self) -> None:
    """Clean up provider state after bootstrap fails.

    Intentionally a no-op for single-VM GCE slices: unlike the TPU
    handle's implementation above, there is no queued-resource
    reservation to cancel here.
    NOTE(review): presumably the VM itself is reaped by the caller's
    normal termination path — confirm before adding deletion here.
    """

lib/iris/src/iris/cluster/providers/gcp/workers.py

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ def _run():
6666
bootstrap_fn()
6767
except Exception as e:
6868
logger.error("Bootstrap failed for slice %s: %s", handle.slice_id, e)
69+
try:
70+
handle.cleanup_bootstrap_failure()
71+
except Exception as cleanup_error:
72+
logger.warning(
73+
"Failed bootstrap cleanup for slice %s: %s",
74+
handle.slice_id,
75+
cleanup_error,
76+
)
6977
with handle._bootstrap_lock:
7078
handle._bootstrap_state = CloudSliceState.FAILED
7179

@@ -76,6 +84,50 @@ def _run():
7684
DEFAULT_BOOT_DISK_SIZE_GB = 50
7785
# pd-ssd provides ~6000 IOPS vs ~38 on pd-standard, critical for controller DB
7886
DEFAULT_BOOT_DISK_TYPE = "pd-ssd"
87+
DEFAULT_TPU_CLOUD_READY_TIMEOUT = 600.0
88+
RESERVED_TPU_ASSIGN_TIMEOUT = 4 * 60 * 60.0
89+
RESERVED_TPU_PROVISION_TIMEOUT = 2 * 60 * 60.0
90+
91+
92+
def _wait_for_queued_resource_activation(
    gcp_service: GcpService,
    handle: GcpSliceHandle,
    poll_interval: float,
) -> None:
    """Wait for a reserved TPU queued resource to be assigned and provisioned.

    Polls the queued resource until it reaches ACTIVE, raising InfraError on
    terminal states or when either phase deadline lapses:
      * assignment phase — bounded by RESERVED_TPU_ASSIGN_TIMEOUT, covers the
        time before the resource first enters PROVISIONING;
      * provisioning phase — bounded by RESERVED_TPU_PROVISION_TIMEOUT, starts
        the first time PROVISIONING is observed.
    """
    assign_deadline = Deadline.from_now(Duration.from_seconds(RESERVED_TPU_ASSIGN_TIMEOUT))
    # Armed lazily on the first PROVISIONING observation.
    provision_deadline: Deadline | None = None

    while True:
        resource = gcp_service.queued_resource_describe(handle.slice_id, handle.zone)
        if resource is None:
            raise InfraError(f"Queued resource {handle.slice_id} not found")

        state = resource.state
        if state == "ACTIVE":
            logger.info("Queued resource %s is ACTIVE, proceeding to TPU bootstrap", handle.slice_id)
            return
        if state in {"FAILED", "SUSPENDED", "DELETING"}:
            raise InfraError(f"Queued resource {handle.slice_id} entered state {state}")

        if state != "PROVISIONING":
            # Still waiting for capacity assignment; only the assign deadline applies.
            if assign_deadline.expired():
                raise InfraError(
                    f"Queued resource {handle.slice_id} did not enter PROVISIONING "
                    f"within {RESERVED_TPU_ASSIGN_TIMEOUT}s"
                )
        elif provision_deadline is None:
            logger.info(
                "Queued resource %s entered PROVISIONING; allowing up to %ss for ACTIVE",
                handle.slice_id,
                RESERVED_TPU_PROVISION_TIMEOUT,
            )
            provision_deadline = Deadline.from_now(Duration.from_seconds(RESERVED_TPU_PROVISION_TIMEOUT))
        elif provision_deadline.expired():
            raise InfraError(
                f"Queued resource {handle.slice_id} did not become ACTIVE "
                f"within {RESERVED_TPU_PROVISION_TIMEOUT}s after entering PROVISIONING"
            )

        logger.info("Queued resource %s is %s, waiting...", handle.slice_id, state)
        time.sleep(poll_interval)
79131

80132

81133
def _gcp_instance_metadata(
@@ -681,7 +733,7 @@ def _run_tpu_bootstrap(
681733
handle: GcpSliceHandle,
682734
worker_config: config_pb2.WorkerConfig,
683735
poll_interval: float = 10.0,
684-
cloud_ready_timeout: float = 600.0,
736+
cloud_ready_timeout: float | None = None,
685737
bootstrap_timeout: float = 600.0,
686738
queued_resource_poll_interval: float = 60.0,
687739
) -> None:
@@ -692,28 +744,19 @@ def _run_tpu_bootstrap(
692744
Phase 2: Poll worker health endpoints until all respond healthy.
693745
On timeout: query Cloud Logging for [iris-init] entries for diagnostics.
694746
"""
695-
# Single deadline covers Phase 0 (queued resource wait) + Phase 1 (cloud READY).
696-
cloud_deadline = Deadline.from_now(Duration.from_seconds(cloud_ready_timeout))
747+
effective_cloud_ready_timeout = cloud_ready_timeout
748+
if effective_cloud_ready_timeout is None:
749+
effective_cloud_ready_timeout = DEFAULT_TPU_CLOUD_READY_TIMEOUT
697750

698751
# Phase 0: If this is a queued resource (reserved TPU), wait for ACTIVE
699752
# before polling the TPU VM state. The queued resource may sit in QUEUED
700753
# or PROVISIONING for an extended period.
701754
if handle.is_queued_resource:
702-
while not cloud_deadline.expired():
703-
qr = gcp_service.queued_resource_describe(handle.slice_id, handle.zone)
704-
if qr is None:
705-
raise InfraError(f"Queued resource {handle.slice_id} not found")
706-
if qr.state == "ACTIVE":
707-
logger.info("Queued resource %s is ACTIVE, proceeding to TPU bootstrap", handle.slice_id)
708-
break
709-
if qr.state in ("FAILED", "SUSPENDED"):
710-
raise InfraError(f"Queued resource {handle.slice_id} entered state {qr.state}")
711-
logger.info("Queued resource %s is %s, waiting...", handle.slice_id, qr.state)
712-
time.sleep(queued_resource_poll_interval)
713-
else:
714-
raise InfraError(
715-
f"Queued resource {handle.slice_id} did not become ACTIVE " f"within {cloud_ready_timeout}s"
716-
)
755+
_wait_for_queued_resource_activation(gcp_service, handle, queued_resource_poll_interval)
756+
757+
# Phase 1: once the QR is ACTIVE (or immediately for non-queued TPUs),
758+
# wait for the TPU VM to reach READY with all worker IPs.
759+
cloud_deadline = Deadline.from_now(Duration.from_seconds(effective_cloud_ready_timeout))
717760

718761
while not cloud_deadline.expired():
719762
cloud_status = handle._describe_cloud()
@@ -731,7 +774,7 @@ def _run_tpu_bootstrap(
731774
)
732775
time.sleep(poll_interval)
733776
else:
734-
raise InfraError(f"Slice {handle.slice_id} did not reach cloud READY within {cloud_ready_timeout}s")
777+
raise InfraError(f"Slice {handle.slice_id} did not reach cloud READY within {effective_cloud_ready_timeout}s")
735778

736779
workers = cloud_status.workers
737780
worker_addrs = [(w.worker_id, w.internal_address) for w in workers]

0 commit comments

Comments (0)