
Commit 2f52e4c

rjpower and claude authored
[iris] Reclaim dead cloud slices and purge orphan slice rows at boot (#5720)
list_all_slices now returns (handle, state) pairs across every cloud state, and restore_autoscaler_state partitions on that state: live slices feed the autoscaler, dead ones are terminated asynchronously. Discarded checkpoint slices are also deleted from the slices table, so SQLite no longer accumulates ghost rows that the autoscaler cannot see.

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>

1 parent 1f577fa commit 2f52e4c

9 files changed

Lines changed: 164 additions & 98 deletions
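
The diffs below construct and consume a ListedSlice value imported from iris.cluster.providers.types; that module is among the nine changed files but is not reproduced on this page. A minimal sketch of the shape the new provider API appears to assume (the committed definition may differ):

# Hypothetical sketch, not the committed code: inferred from how the diffs on
# this page build and consume ListedSlice.
from dataclasses import dataclass
from enum import Enum, auto


class CloudSliceState(Enum):
    CREATING = auto()
    READY = auto()
    REPAIRING = auto()
    FAILED = auto()
    DELETING = auto()
    UNKNOWN = auto()


@dataclass(frozen=True)
class ListedSlice:
    handle: object            # provider slice handle (slice_id, zone, scale_group, terminate(), ...)
    state: CloudSliceState    # best-effort cloud state observed at list time

With this pairing, restore_autoscaler_state keeps handles whose state is in _LIVE_CLOUD_STATES (CREATING/READY/REPAIRING) and hands everything else to _reclaim_dead_slice for asynchronous, best-effort termination.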


lib/iris/src/iris/cluster/controller/autoscaler/recovery.py

Lines changed: 28 additions & 4 deletions
@@ -6,6 +6,7 @@
 from __future__ import annotations
 
 import logging
+import threading
 from dataclasses import dataclass
 
 from sqlalchemy import select
@@ -20,7 +21,9 @@
 from iris.cluster.controller.db import ControllerDB
 from iris.cluster.controller.schema import scaling_groups_table, slices_table, workers_table
 from iris.cluster.providers.protocols import WorkerInfraProvider
-from iris.cluster.providers.types import SliceHandle
+from iris.cluster.providers.types import CloudSliceState, SliceHandle
+
+_LIVE_CLOUD_STATES = frozenset({CloudSliceState.CREATING, CloudSliceState.READY, CloudSliceState.REPAIRING})
 
 logger = logging.getLogger(__name__)
 
@@ -112,10 +115,12 @@ def restore_autoscaler_state(
 ) -> dict[str, TrackedWorker]:
     """Restore scaling groups and tracked workers from a checkpoint."""
 
-    all_cloud_slices = platform.list_all_slices()
     cloud_by_group: dict[str, list[SliceHandle]] = {}
-    for handle in all_cloud_slices:
-        cloud_by_group.setdefault(handle.scale_group, []).append(handle)
+    for listed in platform.list_all_slices():
+        if listed.state not in _LIVE_CLOUD_STATES:
+            _reclaim_dead_slice(listed.handle, listed.state)
+            continue
+        cloud_by_group.setdefault(listed.handle.scale_group, []).append(listed.handle)
 
     for group_snapshot in checkpoint.group_snapshots.values():
         group = groups.get(group_snapshot.name)
@@ -135,5 +140,24 @@ def restore_autoscaler_state(
             last_scale_up=restore_result.last_scale_up,
             last_scale_down=restore_result.last_scale_down,
         )
+        group.purge_persisted_slice_rows(restore_result.discarded_slice_ids)
 
     return restore_tracked_workers(checkpoint.tracked_worker_rows)
+
+
+def _reclaim_dead_slice(handle: SliceHandle, state: CloudSliceState) -> None:
+    """Best-effort terminate of a dead slice in a daemon thread.
+
+    Boot recovery must not block on or fail because of a stale cloud resource:
+    terminate() can hit transient API errors and is not guaranteed to be fast.
+    Errors are logged; on the next restart the slice will surface again.
+    """
+    logger.info("Reclaiming dead slice %s (state=%s, zone=%s)", handle.slice_id, state, handle.zone)
+
+    def _run() -> None:
+        try:
+            handle.terminate()
+        except Exception as e:
+            logger.warning("Failed to terminate dead slice %s: %s", handle.slice_id, e)
+
+    threading.Thread(target=_run, name=f"reclaim-{handle.slice_id}", daemon=True).start()

lib/iris/src/iris/cluster/controller/autoscaler/scaling_group.py

Lines changed: 10 additions & 3 deletions
@@ -409,6 +409,13 @@ def _db_clear_slices(self) -> None:
         with self._db.transaction() as cur:
             cur.execute(delete(slices_table).where(slices_table.c.scale_group == self.name))
 
+    def purge_persisted_slice_rows(self, slice_ids: Sequence[str]) -> None:
+        """Delete the named slice rows from the slices table in a single transaction."""
+        if self._db is None or not slice_ids:
+            return
+        with self._db.transaction() as cur:
+            cur.execute(delete(slices_table).where(slices_table.c.slice_id.in_(list(slice_ids))))
+
     @property
     def platform(self) -> WorkerInfraProvider:
         """Worker infrastructure provider for this scale group."""
@@ -1213,7 +1220,7 @@ class ScalingGroupRestoreResult:
     """Result of restoring a single scaling group from checkpoint metadata."""
 
     slices: dict[str, SliceState] = field(default_factory=dict)
-    discarded_count: int = 0
+    discarded_slice_ids: list[str] = field(default_factory=list)
     adopted_count: int = 0
     last_scale_up: Timestamp = field(default_factory=lambda: Timestamp.from_ms(0))
     last_scale_down: Timestamp = field(default_factory=lambda: Timestamp.from_ms(0))
@@ -1234,7 +1241,7 @@ def restore_scaling_group(
         cloud_handle = cloud_by_id.get(slice_id)
         if cloud_handle is None:
             logger.info("Scaling group %s: discarding slice %s (missing from cloud)", group_snapshot.name, slice_id)
-            result.discarded_count += 1
+            result.discarded_slice_ids.append(slice_id)
             continue
 
         try:
@@ -1274,7 +1281,7 @@ def restore_scaling_group(
         "Restored scaling group %s: %d slices (%d discarded, %d adopted)",
         group_snapshot.name,
         len(result.slices),
-        result.discarded_count,
+        len(result.discarded_slice_ids),
         result.adopted_count,
     )
     return result

lib/iris/src/iris/cluster/providers/gcp/handles.py

Lines changed: 16 additions & 3 deletions
@@ -41,7 +41,9 @@
 
 logger = logging.getLogger(__name__)
 
-# GCP TPU state mapping
+# GCP TPU state mapping. States not in this map collapse to UNKNOWN; the boot
+# reconciler treats anything outside the alive set (CREATING/READY/REPAIRING)
+# as a candidate for reclaim.
 _TPU_STATE_MAP: dict[str, CloudSliceState] = {
     "CREATING": CloudSliceState.CREATING,
     "READY": CloudSliceState.READY,
@@ -58,6 +60,19 @@
 }
 
 _ACTIVE_VM_SLICE_STATES = frozenset({"PROVISIONING", "STAGING", "RUNNING"})
+
+# Queued-resource (reserved TPU) state mapping. Non-live states must surface so
+# the boot reconciler can reclaim them; ACTIVE is transient (the matching TPU
+# VM is what list_all_slices normally returns) and only appears here briefly.
+_QR_STATE_MAP: dict[str, CloudSliceState] = {
+    "QUEUED": CloudSliceState.CREATING,
+    "WAITING_FOR_RESOURCES": CloudSliceState.CREATING,
+    "PROVISIONING": CloudSliceState.CREATING,
+    "ACTIVE": CloudSliceState.READY,
+    "FAILED": CloudSliceState.FAILED,
+    "SUSPENDED": CloudSliceState.FAILED,
+    "DELETING": CloudSliceState.DELETING,
+}
 _GCE_NAME_MAX_LEN = 63
 _GCE_NAME_RE = re.compile(r"[^a-z0-9-]+")
 _GCE_NAME_EDGE_RE = re.compile(r"^-+|-+$")
@@ -222,7 +237,6 @@ def __init__(
         _gcp_service: GcpService,
         _ssh_config: config_pb2.SshConfig | None = None,
         _service_account: str | None = None,
-        _state: str = "READY",
         _bootstrapping: bool = False,
         _is_queued_resource: bool = False,
     ):
@@ -237,7 +251,6 @@ def __init__(
        self._accelerator_variant = _accelerator_variant
        self._ssh_config = _ssh_config
        self._service_account = _service_account
-        self._state = _state
        self.is_queued_resource: bool = _is_queued_resource
        self._bootstrap_state: CloudSliceState | None = None if _bootstrapping else CloudSliceState.READY
        self._bootstrap_lock = threading.Lock()

lib/iris/src/iris/cluster/providers/gcp/workers.py

Lines changed: 63 additions & 63 deletions
@@ -27,6 +27,9 @@
 )
 from iris.cluster.providers.gcp.handles import (
     _ACTIVE_VM_SLICE_STATES,
+    _QR_STATE_MAP,
+    _TPU_STATE_MAP,
+    _VM_STATE_MAP,
     CloudSliceState,
     GcpSliceHandle,
     GcpStandaloneWorkerHandle,
@@ -46,6 +49,7 @@
 from iris.cluster.providers.types import (
     InfraError,
     Labels,
+    ListedSlice,
     SliceHandle,
     generate_slice_suffix,
 )
@@ -628,7 +632,6 @@ def list_slices(
                     _gcp_service=self._gcp,
                     _ssh_config=self._ssh_config,
                     _service_account=tpu.service_account,
-                    _state=tpu.state,
                 )
             )
 
@@ -657,102 +660,99 @@ def list_slices(
 
         return handles
 
-    def list_all_slices(self) -> list[GcpSliceHandle | GcpVmSliceHandle]:
-        """List all autoscaler-managed slices for this cluster.
+    def list_all_slices(self) -> list[ListedSlice]:
+        """List every autoscaler-managed slice for this cluster, regardless of cloud state.
 
         Uses project-wide queries (empty zones = all zones) via GcpService,
         filtered by iris-{prefix}-managed=true. Slices tagged
-        iris-{prefix}-manual=true (operator-created via `iris cluster
-        create-slice`) are excluded: the autoscaler and `cluster stop` must
-        not see or terminate them.
+        iris-{prefix}-manual=true are excluded — those are operator-created
+        and never autoscaler-owned.
         """
         managed_labels = {self._iris_labels.iris_managed: "true"}
         manual_label = self._iris_labels.iris_manual
 
         if self._gcp.mode == ServiceMode.LOCAL:
             local_handles = self._gcp.get_local_slices(managed_labels)
-            return [h for h in local_handles if h.labels.get(manual_label) != "true"]  # type: ignore[return-value]
+            return [
+                ListedSlice(handle=h, state=CloudSliceState.READY)
+                for h in local_handles
+                if h.labels.get(manual_label) != "true"
+            ]
 
         tpu_infos = self._gcp.tpu_list(zones=[], labels=managed_labels)
         vm_infos = self._gcp.vm_list(zones=[], labels=managed_labels)
 
-        handles: list[GcpSliceHandle | GcpVmSliceHandle] = []
+        listed: list[ListedSlice] = []
 
         for tpu in tpu_infos:
-            if tpu.state not in ("READY", "CREATING"):
-                continue
             if tpu.labels.get(manual_label) == "true":
                 continue
-            handles.append(
-                GcpSliceHandle(
-                    _slice_id=tpu.name,
-                    _zone=tpu.zone,
-                    _project_id=self._project_id,
-                    _labels=tpu.labels,
-                    _created_at=tpu.created_at,
-                    _label_prefix=self._label_prefix,
-                    _accelerator_variant=tpu.accelerator_type,
-                    _gcp_service=self._gcp,
-                    _ssh_config=self._ssh_config,
-                    _service_account=tpu.service_account,
-                    _state=tpu.state,
-                    _is_queued_resource=tpu.labels.get(CAPACITY_TYPE_LABEL) == CAPACITY_TYPE_RESERVED_VALUE,
-                )
+            handle = GcpSliceHandle(
+                _slice_id=tpu.name,
+                _zone=tpu.zone,
+                _project_id=self._project_id,
+                _labels=tpu.labels,
+                _created_at=tpu.created_at,
+                _label_prefix=self._label_prefix,
+                _accelerator_variant=tpu.accelerator_type,
+                _gcp_service=self._gcp,
+                _ssh_config=self._ssh_config,
+                _service_account=tpu.service_account,
+                _is_queued_resource=tpu.labels.get(CAPACITY_TYPE_LABEL) == CAPACITY_TYPE_RESERVED_VALUE,
            )
+            listed.append(ListedSlice(handle=handle, state=_TPU_STATE_MAP.get(tpu.state, CloudSliceState.UNKNOWN)))
 
-        # Discover queued resources (reserved TPUs) not yet visible as TPU VMs.
-        # These are in QUEUED/PROVISIONING/WAITING_FOR_RESOURCES and need handles
-        # so the controller doesn't orphan them on restart.
-        tpu_names = {h.slice_id for h in handles}
+        # Discover queued resources (reserved TPUs) not already represented by a
+        # TPU VM. We surface every state — including FAILED/SUSPENDED/DELETING —
+        # so the boot reconciler can reclaim dead reservations instead of
+        # orphaning them in GCP.
+        tpu_names = {item.handle.slice_id for item in listed}
         qr_infos = self._gcp.queued_resource_list(zones=[], labels=managed_labels)
         for qr in qr_infos:
             if qr.name in tpu_names:
                 continue
-            if qr.state in ("FAILED", "SUSPENDED", "DELETING"):
-                continue
-            if qr.labels.get(manual_label) == "true":
+            if qr.labels and qr.labels.get(manual_label) == "true":
                 continue
-            handles.append(
-                GcpSliceHandle(
-                    _slice_id=qr.name,
-                    _zone=qr.zone,
-                    _project_id=self._project_id,
-                    _labels=qr.labels
-                    or {CAPACITY_TYPE_LABEL: CAPACITY_TYPE_RESERVED_VALUE, self._iris_labels.iris_managed: "true"},
-                    _created_at=Timestamp.now(),
-                    _label_prefix=self._label_prefix,
-                    _accelerator_variant="",
-                    _gcp_service=self._gcp,
-                    _ssh_config=self._ssh_config,
-                    _is_queued_resource=True,
-                )
+            handle = GcpSliceHandle(
+                _slice_id=qr.name,
+                _zone=qr.zone,
+                _project_id=self._project_id,
+                _labels=qr.labels
+                or {CAPACITY_TYPE_LABEL: CAPACITY_TYPE_RESERVED_VALUE, self._iris_labels.iris_managed: "true"},
+                _created_at=Timestamp.now(),
+                _label_prefix=self._label_prefix,
+                _accelerator_variant="",
+                _gcp_service=self._gcp,
+                _ssh_config=self._ssh_config,
+                _is_queued_resource=True,
            )
+            listed.append(ListedSlice(handle=handle, state=_QR_STATE_MAP.get(qr.state, CloudSliceState.UNKNOWN)))
 
+        # Surface every managed VM regardless of cloud state. Stopped/terminated
+        # instances are exactly what the boot reconciler needs to reclaim; the
+        # active-only filter belongs in list_slices(), used for live discovery.
         for vm in vm_infos:
-            if vm.status not in _ACTIVE_VM_SLICE_STATES:
-                continue
             slice_id = vm.labels.get(self._iris_labels.iris_slice_id, "")
             if not slice_id:
                 continue
             if vm.labels.get(manual_label) == "true":
                 continue
-            handles.append(
-                GcpVmSliceHandle(
-                    _slice_id=slice_id,
-                    _vm_name=vm.name,
-                    _zone=vm.zone,
-                    _project_id=self._project_id,
-                    _gcp_service=self._gcp,
-                    _labels=vm.labels,
-                    _created_at=vm.created_at,
-                    _label_prefix=self._label_prefix,
-                    _ssh_config=self._ssh_config,
-                    _service_account=vm.service_account,
-                )
+            handle = GcpVmSliceHandle(
+                _slice_id=slice_id,
+                _vm_name=vm.name,
+                _zone=vm.zone,
+                _project_id=self._project_id,
+                _gcp_service=self._gcp,
+                _labels=vm.labels,
+                _created_at=vm.created_at,
+                _label_prefix=self._label_prefix,
+                _ssh_config=self._ssh_config,
+                _service_account=vm.service_account,
            )
+            listed.append(ListedSlice(handle=handle, state=_VM_STATE_MAP.get(vm.status, CloudSliceState.UNKNOWN)))
 
-        logger.info("list_all_slices: found %d managed slices", len(handles))
-        return handles
+        logger.info("list_all_slices: found %d managed slices", len(listed))
+        return listed
 
     def list_vms(
         self,

lib/iris/src/iris/cluster/providers/manual/provider.py

Lines changed: 13 additions & 5 deletions
@@ -27,6 +27,7 @@
     CloudWorkerState,
     InfraError,
     Labels,
+    ListedSlice,
     SliceStatus,
     WorkerStatus,
     default_stop_all,
@@ -334,16 +335,23 @@ def list_slices(
             results = [s for s in results if all(s.labels.get(k) == v for k, v in labels.items())]
         return results
 
-    def list_all_slices(self) -> list[ManualSliceHandle]:
-        """List autoscaler-managed slices.
+    def list_all_slices(self) -> list[ListedSlice]:
+        """List autoscaler-managed slices paired with cloud state.
 
         Excludes slices tagged iris_manual=true (operator-created via
-        `iris cluster create-slice`), which the autoscaler and
-        `iris cluster stop` must not see or terminate.
+        `iris cluster create-slice`). Manual slices have no real cloud
+        lifecycle; non-terminated ones report READY.
         """
         all_managed = self.list_slices(zones=[], labels={self._iris_labels.iris_managed: "true"})
         manual_label = self._iris_labels.iris_manual
-        return [s for s in all_managed if s.labels.get(manual_label) != "true"]
+        return [
+            ListedSlice(
+                handle=s,
+                state=CloudSliceState.DELETING if s._terminated else CloudSliceState.READY,
+            )
+            for s in all_managed
+            if s.labels.get(manual_label) != "true"
+        ]
 
     def list_vms(
         self,

lib/iris/src/iris/cluster/providers/protocols.py

Lines changed: 3 additions & 3 deletions
@@ -16,7 +16,7 @@
 from contextlib import AbstractContextManager
 from typing import Protocol
 
-from iris.cluster.providers.types import SliceHandle, StandaloneWorkerHandle
+from iris.cluster.providers.types import ListedSlice, SliceHandle, StandaloneWorkerHandle
 from iris.rpc import config_pb2
 
 
@@ -123,8 +123,8 @@ def list_slices(
         """List existing slices, filtered by zone and optionally by labels."""
         ...
 
-    def list_all_slices(self) -> list[SliceHandle]:
-        """List all slices managed by this cluster across all zones."""
+    def list_all_slices(self) -> list[ListedSlice]:
+        """List every iris-managed slice across all zones, paired with its cloud state."""
         ...
 
     def list_vms(
