From 3e510fdc172cc9e697804598bed71781ed6c7c7e Mon Sep 17 00:00:00 2001 From: Russell Power Date: Fri, 3 Apr 2026 09:23:20 -0700 Subject: [PATCH 1/8] [iris] Replace gcloud CLI with REST API client in CloudGcpService Extract GCPApi class (httpx + google.auth ADC) that handles auth, pagination, and error mapping for TPU v2, Compute v1, and Cloud Logging APIs. Rewrite CloudGcpService to delegate to GCPApi instead of subprocess gcloud calls. This eliminates the gcloud CLI dependency for resource management, fixing CI failures from gcloud alpha not being installed. Add logging_read to the GcpService Protocol so bootstrap log fetching goes through the same boundary. --- lib/iris/pyproject.toml | 1 + .../src/iris/cluster/providers/gcp/api.py | 271 ++++++++ .../src/iris/cluster/providers/gcp/fake.py | 3 + .../src/iris/cluster/providers/gcp/service.py | 583 +++++------------- .../src/iris/cluster/providers/gcp/workers.py | 36 +- .../cluster/providers/gcp/test_gcp_api.py | 432 +++++++++++++ uv.lock | 2 + 7 files changed, 886 insertions(+), 442 deletions(-) create mode 100644 lib/iris/src/iris/cluster/providers/gcp/api.py create mode 100644 lib/iris/tests/cluster/providers/gcp/test_gcp_api.py diff --git a/lib/iris/pyproject.toml b/lib/iris/pyproject.toml index 978d0dac01..ef404acf75 100644 --- a/lib/iris/pyproject.toml +++ b/lib/iris/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "connect-python>=0.9.0", "fsspec>=2024.0.0", "gcsfs>=2024.0.0", + "google-auth>=2.0", "s3fs>=2024.0.0", "grpcio>=1.76.0", "httpx>=0.28.1", diff --git a/lib/iris/src/iris/cluster/providers/gcp/api.py b/lib/iris/src/iris/cluster/providers/gcp/api.py new file mode 100644 index 0000000000..7700e8fec0 --- /dev/null +++ b/lib/iris/src/iris/cluster/providers/gcp/api.py @@ -0,0 +1,271 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Low-level HTTP client for GCP REST APIs (TPU v2, Compute v1, Cloud Logging). + +Handles authentication (Application Default Credentials), token caching, +pagination, and error mapping to domain exceptions. Used by CloudGcpService +as a replacement for gcloud CLI subprocess calls. 
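+
+A minimal usage sketch (assumes Application Default Credentials are
+configured in the environment; the project and zone names are placeholders):
+
+    api = GCPApi("my-project")
+    try:
+        for node in api.tpu_list("us-central1-a"):
+            print(node["name"], node.get("state"))
+    finally:
+        api.close()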
+""" + +from __future__ import annotations + +import json +import logging +import time + +import google.auth +import google.auth.credentials +import google.auth.transport.requests +import httpx + +from iris.cluster.providers.types import ( + InfraError, + QuotaExhaustedError, + ResourceNotFoundError, +) + +logger = logging.getLogger(__name__) + +TPU_BASE = "https://tpu.googleapis.com/v2" +COMPUTE_BASE = "https://compute.googleapis.com/compute/v1" +LOGGING_BASE = "https://logging.googleapis.com/v2" + +_REFRESH_MARGIN = 300 # seconds before expiry to refresh token +_DEFAULT_TIMEOUT = 120 # seconds + + +class GCPApi: + """Low-level HTTP client for GCP REST APIs with ADC auth and token caching.""" + + def __init__(self, project_id: str) -> None: + self._project_id = project_id + self._client = httpx.Client(timeout=_DEFAULT_TIMEOUT) + self._creds: google.auth.credentials.Credentials | None = None + self._token: str | None = None + self._expires_at: float = 0.0 + + def close(self) -> None: + self._client.close() + + # -- Auth --------------------------------------------------------------- + + def _headers(self) -> dict[str, str]: + if self._token is None or time.monotonic() >= self._expires_at: + self._refresh_token() + return { + "Authorization": f"Bearer {self._token}", + "Content-Type": "application/json", + } + + def _refresh_token(self) -> None: + if self._creds is None: + self._creds, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"]) + self._creds.refresh(google.auth.transport.requests.Request()) + self._token = self._creds.token + now = time.monotonic() + if self._creds.expiry is not None: + self._expires_at = now + (self._creds.expiry.timestamp() - time.time()) - _REFRESH_MARGIN + else: + self._expires_at = now + _REFRESH_MARGIN + + # -- Error mapping ------------------------------------------------------ + + def _classify_response(self, resp: httpx.Response) -> None: + """Raise a domain exception for non-2xx responses.""" + if resp.status_code < 400: + return + try: + body = resp.json() + error = body.get("error", {}) + message = error.get("message", resp.text) + status = error.get("status", "") + code = error.get("code", resp.status_code) + except (json.JSONDecodeError, AttributeError): + message = resp.text + status = "" + code = resp.status_code + + if code == 404 or status == "NOT_FOUND": + raise ResourceNotFoundError(message) + if code == 429 or status in ("RESOURCE_EXHAUSTED", "QUOTA_EXCEEDED"): + raise QuotaExhaustedError(message) + raise InfraError(f"GCP API error {code}: {message}") + + # -- Pagination --------------------------------------------------------- + + def _paginate(self, url: str, items_key: str, params: dict[str, str] | None = None) -> list[dict]: + results: list[dict] = [] + p = dict(params or {}) + while True: + resp = self._client.get(url, headers=self._headers(), params=p) + self._classify_response(resp) + data = resp.json() + results.extend(data.get(items_key, [])) + token = data.get("nextPageToken") + if not token: + break + p["pageToken"] = token + return results + + def _paginate_raw(self, url: str, params: dict[str, str] | None = None) -> list[dict]: + """Return raw page bodies (for aggregatedList where items_key varies).""" + pages: list[dict] = [] + p = dict(params or {}) + while True: + resp = self._client.get(url, headers=self._headers(), params=p) + self._classify_response(resp) + data = resp.json() + pages.append(data) + token = data.get("nextPageToken") + if not token: + break + p["pageToken"] = token + return pages 
+ + # ====================================================================== + # TPU v2 + # ====================================================================== + + def _tpu_parent(self, zone: str) -> str: + return f"projects/{self._project_id}/locations/{zone}" + + def tpu_create(self, name: str, zone: str, body: dict) -> dict | None: + url = f"{TPU_BASE}/{self._tpu_parent(zone)}/nodes" + resp = self._client.post(url, params={"nodeId": name}, headers=self._headers(), json=body) + self._classify_response(resp) + data = resp.json() + # REST create returns a long-running operation, not the node itself. + return data if data.get("name", "").endswith(f"/nodes/{name}") else None + + def tpu_get(self, name: str, zone: str) -> dict: + url = f"{TPU_BASE}/{self._tpu_parent(zone)}/nodes/{name}" + resp = self._client.get(url, headers=self._headers()) + self._classify_response(resp) + return resp.json() + + def tpu_delete(self, name: str, zone: str) -> None: + url = f"{TPU_BASE}/{self._tpu_parent(zone)}/nodes/{name}" + resp = self._client.delete(url, headers=self._headers()) + if resp.status_code != 404: + self._classify_response(resp) + + def tpu_list(self, zone: str) -> list[dict]: + return self._paginate(f"{TPU_BASE}/{self._tpu_parent(zone)}/nodes", "nodes") + + # ====================================================================== + # TPU v2 — Queued Resources + # ====================================================================== + + def queued_resource_create(self, name: str, zone: str, body: dict) -> None: + url = f"{TPU_BASE}/{self._tpu_parent(zone)}/queuedResources" + resp = self._client.post( + url, + params={"queuedResourceId": name}, + headers=self._headers(), + json=body, + ) + self._classify_response(resp) + + def queued_resource_get(self, name: str, zone: str) -> dict: + url = f"{TPU_BASE}/{self._tpu_parent(zone)}/queuedResources/{name}" + resp = self._client.get(url, headers=self._headers()) + self._classify_response(resp) + return resp.json() + + def queued_resource_delete(self, name: str, zone: str) -> None: + url = f"{TPU_BASE}/{self._tpu_parent(zone)}/queuedResources/{name}" + resp = self._client.delete(url, params={"force": "true"}, headers=self._headers()) + if resp.status_code != 404: + self._classify_response(resp) + + def queued_resource_list(self, zone: str) -> list[dict]: + return self._paginate( + f"{TPU_BASE}/{self._tpu_parent(zone)}/queuedResources", + "queuedResources", + ) + + # ====================================================================== + # Compute Engine v1 — Instances + # ====================================================================== + + def _instance_url(self, zone: str, name: str = "") -> str: + path = f"{COMPUTE_BASE}/projects/{self._project_id}/zones/{zone}/instances" + if name: + path += f"/{name}" + return path + + def instance_insert(self, zone: str, body: dict) -> dict: + url = self._instance_url(zone) + resp = self._client.post(url, headers=self._headers(), json=body) + self._classify_response(resp) + return resp.json() + + def instance_get(self, name: str, zone: str) -> dict: + url = self._instance_url(zone, name) + resp = self._client.get(url, headers=self._headers()) + self._classify_response(resp) + return resp.json() + + def instance_delete(self, name: str, zone: str) -> None: + url = self._instance_url(zone, name) + resp = self._client.delete(url, headers=self._headers()) + if resp.status_code != 404: + self._classify_response(resp) + + def instance_list(self, zone: str | None = None, filter_str: str = "") -> list[dict]: + 
params: dict[str, str] = {} + if filter_str: + params["filter"] = filter_str + + if zone: + return self._paginate(self._instance_url(zone), "items", params) + + # Project-wide: aggregatedList, flatten across zones + url = f"{COMPUTE_BASE}/projects/{self._project_id}/aggregated/instances" + results: list[dict] = [] + for page in self._paginate_raw(url, params): + for scope in page.get("items", {}).values(): + results.extend(scope.get("instances", [])) + return results + + def instance_reset(self, name: str, zone: str) -> None: + url = self._instance_url(zone, name) + "/reset" + resp = self._client.post(url, headers=self._headers()) + self._classify_response(resp) + + def instance_set_labels(self, name: str, zone: str, labels: dict[str, str], fingerprint: str) -> None: + url = self._instance_url(zone, name) + "/setLabels" + resp = self._client.post( + url, + headers=self._headers(), + json={"labels": labels, "labelFingerprint": fingerprint}, + ) + self._classify_response(resp) + + def instance_set_metadata(self, name: str, zone: str, metadata_body: dict) -> None: + url = self._instance_url(zone, name) + "/setMetadata" + resp = self._client.post(url, headers=self._headers(), json=metadata_body) + self._classify_response(resp) + + def instance_get_serial_port_output(self, name: str, zone: str, start: int = 0) -> dict: + url = self._instance_url(zone, name) + "/serialPort" + resp = self._client.get(url, headers=self._headers(), params={"start": str(start)}) + self._classify_response(resp) + return resp.json() + + # ====================================================================== + # Cloud Logging v2 + # ====================================================================== + + def logging_list_entries(self, filter_str: str, limit: int = 200) -> list[dict]: + url = f"{LOGGING_BASE}/entries:list" + body = { + "resourceNames": [f"projects/{self._project_id}"], + "filter": filter_str, + "pageSize": min(limit, 1000), + "orderBy": "timestamp desc", + } + resp = self._client.post(url, headers=self._headers(), json=body, timeout=30) + self._classify_response(resp) + return resp.json().get("entries", []) diff --git a/lib/iris/src/iris/cluster/providers/gcp/fake.py b/lib/iris/src/iris/cluster/providers/gcp/fake.py index 9d8f9d3277..0566e7661a 100644 --- a/lib/iris/src/iris/cluster/providers/gcp/fake.py +++ b/lib/iris/src/iris/cluster/providers/gcp/fake.py @@ -375,6 +375,9 @@ def vm_get_serial_port_output(self, name: str, zone: str, start: int = 0) -> str full_output = self._serial_port_output.get((name, zone), "") return full_output[start:] + def logging_read(self, filter_str: str, limit: int = 200) -> list[str]: + return [] + # ======================================================================== # LOCAL mode: worker spawning # ======================================================================== diff --git a/lib/iris/src/iris/cluster/providers/gcp/service.py b/lib/iris/src/iris/cluster/providers/gcp/service.py index 3a000630fb..2bc305ea5c 100644 --- a/lib/iris/src/iris/cluster/providers/gcp/service.py +++ b/lib/iris/src/iris/cluster/providers/gcp/service.py @@ -3,19 +3,15 @@ from __future__ import annotations -import json import logging -import os import re -import subprocess -import tempfile from dataclasses import dataclass, field from datetime import datetime from typing import Protocol +from iris.cluster.providers.gcp.api import GCPApi from iris.cluster.providers.types import ( InfraError, - QuotaExhaustedError, ResourceNotFoundError, ) from iris.cluster.providers.gcp.local import 
LocalSliceHandle @@ -226,6 +222,10 @@ def vm_update_labels(self, name: str, zone: str, labels: dict[str, str]) -> None def vm_set_metadata(self, name: str, zone: str, metadata: dict[str, str]) -> None: ... def vm_get_serial_port_output(self, name: str, zone: str, start: int = 0) -> str: ... + def logging_read(self, filter_str: str, limit: int = 200) -> list[str]: + """Return matching Cloud Logging textPayload entries (newest first).""" + ... + def create_local_slice( self, slice_id: str, @@ -245,24 +245,17 @@ def shutdown(self) -> None: # ============================================================================ -# CloudGcpService — gcloud CLI implementation +# CloudGcpService — REST API implementation via GCPApi # ============================================================================ -def _format_labels(labels: dict[str, str]) -> str: - return ",".join(f"{k}={v}" for k, v in labels.items()) - - def _build_label_filter(labels: dict[str, str]) -> str: parts = [f"labels.{k}={v}" for k, v in labels.items()] return " AND ".join(parts) -def _classify_gcloud_error(stderr: str) -> InfraError: - lower = stderr.lower() - if "quota" in lower or "insufficient" in lower or "resource_exhausted" in lower: - return QuotaExhaustedError(stderr) - return InfraError(stderr) +def _labels_match(resource_labels: dict[str, str], required: dict[str, str]) -> bool: + return all(resource_labels.get(k) == v for k, v in required.items()) def _extract_node_name(resource_name: str) -> str: @@ -360,10 +353,11 @@ def _parse_vm_info(vm_data: dict, fallback_zone: str = "") -> VmInfo: class CloudGcpService: - """GcpService backed by gcloud CLI. Used in CLOUD mode.""" + """GcpService backed by GCP REST APIs via GCPApi. Used in CLOUD mode.""" - def __init__(self, project_id: str) -> None: + def __init__(self, project_id: str, api: GCPApi | None = None) -> None: self._project_id = project_id + self._api = api if api is not None else GCPApi(project_id) self._valid_zones: set[str] = set(KNOWN_GCP_ZONES) self._valid_accelerator_types: set[str] = set(KNOWN_TPU_TYPES) @@ -375,6 +369,10 @@ def mode(self) -> ServiceMode: def project_id(self) -> str: return self._project_id + @property + def api(self) -> GCPApi: + return self._api + # ======================================================================== # TPU operations # ======================================================================== @@ -382,144 +380,67 @@ def project_id(self) -> str: def tpu_create(self, request: TpuCreateRequest) -> TpuInfo: validate_tpu_create(request, self._valid_zones, self._valid_accelerator_types) - cmd = [ - "gcloud", - "compute", - "tpus", - "tpu-vm", - "create", - request.name, - f"--zone={request.zone}", - f"--project={self._project_id}", - f"--accelerator-type={request.accelerator_type}", - f"--version={request.runtime_version}", - "--format=json", - ] - + body: dict = { + "acceleratorType": request.accelerator_type, + "runtimeVersion": request.runtime_version, + } if request.labels: - cmd.extend(["--labels", _format_labels(request.labels)]) + body["labels"] = request.labels + if request.metadata: + body["metadata"] = request.metadata if request.capacity_type == config_pb2.CAPACITY_TYPE_PREEMPTIBLE: - cmd.append("--preemptible") + body["schedulingConfig"] = {"preemptible": True} if request.service_account: - cmd.append(f"--service-account={request.service_account}") - if request.network: - cmd.append(f"--network={request.network}") - if request.subnetwork: - cmd.append(f"--subnetwork={request.subnetwork}") - - # Large metadata values 
(e.g. startup-script) are written to temp files - # to avoid shell-escaping issues with --metadata inline. - metadata_files: dict[str, str] = {} - inline_metadata: dict[str, str] = {} - for k, v in request.metadata.items(): - if len(v) > 256: - f = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) - f.write(v) - f.close() - metadata_files[k] = f.name - else: - inline_metadata[k] = v - - if inline_metadata: - metadata_str = ",".join(f"{k}={v}" for k, v in inline_metadata.items()) - cmd.append(f"--metadata={metadata_str}") - if metadata_files: - file_str = ",".join(f"{k}={path}" for k, path in metadata_files.items()) - cmd.append(f"--metadata-from-file={file_str}") + body["serviceAccount"] = {"email": request.service_account} + if request.network or request.subnetwork: + network_config: dict = {} + if request.network: + network_config["network"] = request.network + if request.subnetwork: + network_config["subnetwork"] = request.subnetwork + body["networkConfig"] = network_config logger.info("Creating TPU: %s (type=%s, zone=%s)", request.name, request.accelerator_type, request.zone) - try: - result = subprocess.run(cmd, capture_output=True, text=True) - finally: - for path in metadata_files.values(): - os.unlink(path) - if result.returncode != 0: - raise _classify_gcloud_error(result.stderr.strip()) - - if result.stdout.strip(): - tpu_data = json.loads(result.stdout) - return _parse_tpu_info(tpu_data, request.zone) + self._api.tpu_create(request.name, request.zone, body) + # REST create returns an operation, not the node — fetch it explicitly. info = self.tpu_describe(request.name, request.zone) if info is None: raise InfraError(f"TPU {request.name} created but could not be described") return info def tpu_delete(self, name: str, zone: str) -> None: - cmd = [ - "gcloud", - "compute", - "tpus", - "tpu-vm", - "delete", - name, - f"--zone={zone}", - f"--project={self._project_id}", - "--quiet", - "--async", - ] logger.info("Deleting TPU (async): %s", name) - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - error = result.stderr.strip() - if "not found" not in error.lower(): - raise _classify_gcloud_error(error) + self._api.tpu_delete(name, zone) def tpu_describe(self, name: str, zone: str) -> TpuInfo | None: - cmd = [ - "gcloud", - "compute", - "tpus", - "tpu-vm", - "describe", - name, - f"--zone={zone}", - f"--project={self._project_id}", - "--format=json", - ] - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - error = result.stderr.strip().lower() - if "not found" in error or "could not be found" in error: - return None - logger.warning("Failed to describe TPU %s: %s", name, result.stderr.strip()) + try: + tpu_data = self._api.tpu_get(name, zone) + except ResourceNotFoundError: + return None + except InfraError: + logger.warning("Failed to describe TPU %s", name, exc_info=True) return None - - tpu_data = json.loads(result.stdout) return _parse_tpu_info(tpu_data, zone) def tpu_list(self, zones: list[str], labels: dict[str, str] | None = None) -> list[TpuInfo]: results: list[TpuInfo] = [] - - # Empty zones = project-wide search using --zone=- - zone_list = zones if zones else ["-"] + # TPU v2 API requires a real zone; empty zones = scan all known zones. 
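+        # (One list call per zone in KNOWN_GCP_ZONES, so project-wide scans
+        # cost more requests than the old --zone=- shortcut but need no gcloud.)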
+ zone_list = zones if zones else list(self._valid_zones) for zone in zone_list: - cmd = [ - "gcloud", - "compute", - "tpus", - "tpu-vm", - "list", - f"--zone={zone}", - f"--project={self._project_id}", - "--format=json", - ] - if labels: - cmd.append(f"--filter={_build_label_filter(labels)}") - - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - logger.warning("Failed to list TPUs in zone %s: %s", zone, result.stderr.strip()) - continue - if not result.stdout.strip(): + try: + items = self._api.tpu_list(zone) + except InfraError: + logger.warning("Failed to list TPUs in zone %s", zone, exc_info=True) continue - - for tpu_data in json.loads(result.stdout): - # With --zone=-, name is a full resource path; extract zone from it + for tpu_data in items: + if labels and not _labels_match(tpu_data.get("labels", {}), labels): + continue + # Extract zone from resource name if present tpu_zone = zone raw_name = tpu_data.get("name", "") - if zone == "-" and "/" in raw_name: + if "/" in raw_name: parts = raw_name.split("/") if len(parts) >= 4: tpu_zone = parts[3] @@ -534,51 +455,29 @@ def tpu_list(self, zones: list[str], labels: dict[str, str] | None = None) -> li def queued_resource_create(self, request: TpuCreateRequest) -> None: validate_tpu_create(request, self._valid_zones, self._valid_accelerator_types) - cmd = [ - "gcloud", - "alpha", - "compute", - "tpus", - "queued-resources", - "create", - request.name, - f"--zone={request.zone}", - f"--project={self._project_id}", - f"--accelerator-type={request.accelerator_type}", - f"--runtime-version={request.runtime_version}", - f"--node-id={request.name}", - "--reserved", - "--quiet", - ] - - if request.labels: - cmd.extend(["--labels", _format_labels(request.labels)]) + node_spec: dict = { + "node": { + "acceleratorType": request.accelerator_type, + "runtimeVersion": request.runtime_version, + "labels": request.labels or {}, + "metadata": request.metadata or {}, + }, + "nodeId": request.name, + } if request.service_account: - cmd.append(f"--service-account={request.service_account}") - if request.network: - cmd.append(f"--network={request.network}") - if request.subnetwork: - cmd.append(f"--subnetwork={request.subnetwork}") - - # Queued resources don't support --metadata directly; metadata is - # applied to the TPU node via --metadata/--metadata-from-file. 
- metadata_files: dict[str, str] = {} - inline_metadata: dict[str, str] = {} - for k, v in request.metadata.items(): - if len(v) > 256: - f = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) - f.write(v) - f.close() - metadata_files[k] = f.name - else: - inline_metadata[k] = v - - if inline_metadata: - metadata_str = ",".join(f"{k}={v}" for k, v in inline_metadata.items()) - cmd.append(f"--metadata={metadata_str}") - if metadata_files: - file_str = ",".join(f"{k}={path}" for k, path in metadata_files.items()) - cmd.append(f"--metadata-from-file={file_str}") + node_spec["node"]["serviceAccount"] = {"email": request.service_account} + if request.network or request.subnetwork: + network_config: dict = {} + if request.network: + network_config["network"] = request.network + if request.subnetwork: + network_config["subnetwork"] = request.subnetwork + node_spec["node"]["networkConfig"] = network_config + + body = { + "tpu": {"nodeSpec": [node_spec]}, + "guaranteed": {"reserved": True}, + } logger.info( "Creating queued resource: %s (type=%s, zone=%s)", @@ -586,89 +485,36 @@ def queued_resource_create(self, request: TpuCreateRequest) -> None: request.accelerator_type, request.zone, ) - try: - result = subprocess.run(cmd, capture_output=True, text=True) - finally: - for path in metadata_files.values(): - os.unlink(path) - if result.returncode != 0: - raise _classify_gcloud_error(result.stderr.strip()) + self._api.queued_resource_create(request.name, request.zone, body) def queued_resource_describe(self, name: str, zone: str) -> QueuedResourceInfo | None: - cmd = [ - "gcloud", - "alpha", - "compute", - "tpus", - "queued-resources", - "describe", - name, - f"--zone={zone}", - f"--project={self._project_id}", - "--format=json", - "--quiet", - ] - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - error = result.stderr.strip().lower() - if "not found" in error or "could not be found" in error: - return None - raise _classify_gcloud_error(result.stderr.strip()) - - data = json.loads(result.stdout) + try: + data = self._api.queued_resource_get(name, zone) + except ResourceNotFoundError: + return None state = data.get("state", {}).get("state", "UNKNOWN") return QueuedResourceInfo(name=name, state=state, zone=zone) def queued_resource_delete(self, name: str, zone: str) -> None: - cmd = [ - "gcloud", - "alpha", - "compute", - "tpus", - "queued-resources", - "delete", - name, - f"--zone={zone}", - f"--project={self._project_id}", - "--force", - "--quiet", - ] logger.info("Deleting queued resource (force): %s", name) - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - error = result.stderr.strip() - if "not found" not in error.lower(): - raise _classify_gcloud_error(error) + self._api.queued_resource_delete(name, zone) def queued_resource_list(self, zones: list[str], labels: dict[str, str] | None = None) -> list[QueuedResourceInfo]: - # Empty zones = project-wide search using --zone=- - zone_list = zones if zones else ["-"] + zone_list = zones if zones else list(self._valid_zones) results: list[QueuedResourceInfo] = [] for zone in zone_list: - cmd = [ - "gcloud", - "alpha", - "compute", - "tpus", - "queued-resources", - "list", - f"--zone={zone}", - f"--project={self._project_id}", - "--format=json", - "--quiet", - ] - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - logger.warning("Failed to list queued resources in %s: %s", zone, result.stderr.strip()) + try: + 
items = self._api.queued_resource_list(zone) + except InfraError: + logger.warning("Failed to list queued resources in %s", zone, exc_info=True) continue - data = json.loads(result.stdout or "[]") - for item in data: - name = item.get("name", "").rsplit("/", 1)[-1] + for item in items: + qr_name = item.get("name", "").rsplit("/", 1)[-1] state = item.get("state", {}).get("state", "UNKNOWN") item_labels = item.get("tpu", {}).get("nodeSpec", [{}])[0].get("node", {}).get("labels", {}) if labels and not all(item_labels.get(k) == v for k, v in labels.items()): continue - results.append(QueuedResourceInfo(name=name, state=state, zone=zone, labels=item_labels)) + results.append(QueuedResourceInfo(name=qr_name, state=state, zone=zone, labels=item_labels)) return results # ======================================================================== @@ -678,62 +524,43 @@ def queued_resource_list(self, zones: list[str], labels: dict[str, str] | None = def vm_create(self, request: VmCreateRequest) -> VmInfo: validate_vm_create(request, self._valid_zones) - cmd = [ - "gcloud", - "compute", - "instances", - "create", - request.name, - f"--project={self._project_id}", - f"--zone={request.zone}", - f"--machine-type={request.machine_type}", - f"--boot-disk-size={request.disk_size_gb}GB", - f"--boot-disk-type={request.boot_disk_type}", - f"--image-family={request.image_family}", - f"--image-project={request.image_project}", - "--scopes=cloud-platform", - "--format=json", - ] - - if request.labels: - cmd.append(f"--labels={_format_labels(request.labels)}") - - # Large metadata values (e.g. startup-script) are written to temp files. - metadata_files: dict[str, str] = {} all_metadata = dict(request.metadata) if request.startup_script: all_metadata["startup-script"] = request.startup_script - inline_metadata: dict[str, str] = {} - for k, v in all_metadata.items(): - if len(v) > 256: - f = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) - f.write(v) - f.close() - metadata_files[k] = f.name - else: - inline_metadata[k] = v - - if inline_metadata: - metadata_str = ",".join(f"{k}={v}" for k, v in inline_metadata.items()) - cmd.append(f"--metadata={metadata_str}") - if metadata_files: - file_str = ",".join(f"{k}={path}" for k, path in metadata_files.items()) - cmd.append(f"--metadata-from-file={file_str}") - - if request.service_account: - cmd.append(f"--service-account={request.service_account}") + body: dict = { + "name": request.name, + "machineType": f"zones/{request.zone}/machineTypes/{request.machine_type}", + "disks": [ + { + "boot": True, + "autoDelete": True, + "initializeParams": { + "diskSizeGb": str(request.disk_size_gb), + "diskType": f"zones/{request.zone}/diskTypes/{request.boot_disk_type}", + "sourceImage": f"projects/{request.image_project}/global/images/family/{request.image_family}", + }, + } + ], + "networkInterfaces": [{"accessConfigs": [{"type": "ONE_TO_ONE_NAT"}]}], + "serviceAccounts": [ + { + "email": request.service_account or "default", + "scopes": ["https://www.googleapis.com/auth/cloud-platform"], + } + ], + } + if request.labels: + body["labels"] = request.labels + if all_metadata: + body["metadata"] = {"items": [{"key": k, "value": v} for k, v in all_metadata.items()]} logger.info("Creating VM: %s (zone=%s, type=%s)", request.name, request.zone, request.machine_type) try: - result = subprocess.run(cmd, capture_output=True, text=True) - finally: - for path in metadata_files.values(): - os.unlink(path) - if result.returncode != 0: - error_msg = result.stderr.strip() - if 
"already exists" not in error_msg.lower(): - raise _classify_gcloud_error(error_msg) + self._api.instance_insert(request.zone, body) + except InfraError as e: + if "already exists" not in str(e).lower(): + raise info = self.vm_describe(request.name, request.zone) if info is None: @@ -741,161 +568,89 @@ def vm_create(self, request: VmCreateRequest) -> VmInfo: return info def vm_delete(self, name: str, zone: str) -> None: - cmd = [ - "gcloud", - "compute", - "instances", - "delete", - name, - f"--project={self._project_id}", - f"--zone={zone}", - "--quiet", - ] logger.info("Deleting VM: %s", name) - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - error = result.stderr.strip() - if "not found" not in error.lower(): - raise _classify_gcloud_error(error) + self._api.instance_delete(name, zone) def vm_reset(self, name: str, zone: str) -> None: - cmd = [ - "gcloud", - "compute", - "instances", - "reset", - name, - f"--project={self._project_id}", - f"--zone={zone}", - "--quiet", - ] logger.info("Resetting VM: %s", name) - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - raise _classify_gcloud_error(result.stderr.strip()) + self._api.instance_reset(name, zone) def vm_describe(self, name: str, zone: str) -> VmInfo | None: - cmd = [ - "gcloud", - "compute", - "instances", - "describe", - name, - f"--project={self._project_id}", - f"--zone={zone}", - "--format=json", - ] - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - error = result.stderr.strip().lower() - if "not found" in error or "could not be found" in error: - return None - logger.warning("Failed to describe VM %s: %s", name, result.stderr.strip()) + try: + data = self._api.instance_get(name, zone) + except ResourceNotFoundError: + return None + except InfraError: + logger.warning("Failed to describe VM %s", name, exc_info=True) return None - - data = json.loads(result.stdout) return _parse_vm_info(data, fallback_zone=zone) def vm_list(self, zones: list[str], labels: dict[str, str] | None = None) -> list[VmInfo]: results: list[VmInfo] = [] + filter_str = _build_label_filter(labels) if labels else "" if not zones: - # Project-wide search (no --zones flag) - cmd = [ - "gcloud", - "compute", - "instances", - "list", - f"--project={self._project_id}", - "--format=json", - ] - if labels: - cmd.append(f"--filter={_build_label_filter(labels)}") - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - logger.warning("Failed to list instances: %s", result.stderr.strip()) + try: + items = self._api.instance_list(zone=None, filter_str=filter_str) + except InfraError: + logger.warning("Failed to list instances", exc_info=True) return [] - if not result.stdout.strip(): - return [] - for vm_data in json.loads(result.stdout): + for vm_data in items: results.append(_parse_vm_info(vm_data)) return results for zone in zones: - cmd = [ - "gcloud", - "compute", - "instances", - "list", - f"--project={self._project_id}", - f"--zones={zone}", - "--format=json", - ] - if labels: - cmd.append(f"--filter={_build_label_filter(labels)}") - - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - logger.warning("Failed to list instances in zone %s: %s", zone, result.stderr.strip()) - continue - if not result.stdout.strip(): + try: + items = self._api.instance_list(zone=zone, filter_str=filter_str) + except InfraError: + logger.warning("Failed to list instances in zone 
%s", zone, exc_info=True) continue - - for vm_data in json.loads(result.stdout): + for vm_data in items: results.append(_parse_vm_info(vm_data, fallback_zone=zone)) return results def vm_update_labels(self, name: str, zone: str, labels: dict[str, str]) -> None: validate_labels(labels) - cmd = [ - "gcloud", - "compute", - "instances", - "update", - name, - f"--project={self._project_id}", - f"--zone={zone}", - f"--update-labels={_format_labels(labels)}", - ] logger.info("Updating labels on VM %s", name) - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - raise _classify_gcloud_error(result.stderr.strip()) + # Read-modify-write: GET the current labelFingerprint, merge, then POST. + data = self._api.instance_get(name, zone) + current_labels = data.get("labels", {}) + current_labels.update(labels) + fingerprint = data.get("labelFingerprint", "") + self._api.instance_set_labels(name, zone, current_labels, fingerprint) def vm_set_metadata(self, name: str, zone: str, metadata: dict[str, str]) -> None: - metadata_str = ",".join(f"{k}={v}" for k, v in metadata.items()) - cmd = [ - "gcloud", - "compute", - "instances", - "add-metadata", - name, - f"--project={self._project_id}", - f"--zone={zone}", - f"--metadata={metadata_str}", - ] logger.info("Setting metadata on VM %s", name) - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - raise _classify_gcloud_error(result.stderr.strip()) + # Read-modify-write: GET current metadata with fingerprint, merge items, POST. + data = self._api.instance_get(name, zone) + raw_metadata = data.get("metadata", {}) + fingerprint = raw_metadata.get("fingerprint", "") + existing_items: dict[str, str] = {} + for item in raw_metadata.get("items", []): + existing_items[item["key"]] = item.get("value", "") + existing_items.update(metadata) + body = { + "fingerprint": fingerprint, + "items": [{"key": k, "value": v} for k, v in existing_items.items()], + } + self._api.instance_set_metadata(name, zone, body) def vm_get_serial_port_output(self, name: str, zone: str, start: int = 0) -> str: - cmd = [ - "gcloud", - "compute", - "instances", - "get-serial-port-output", - name, - f"--project={self._project_id}", - f"--zone={zone}", - f"--start={start}", - ] - result = subprocess.run(cmd, capture_output=True, text=True) - if result.returncode != 0: - logger.warning("Failed to get serial port output for %s: %s", name, result.stderr.strip()) + try: + data = self._api.instance_get_serial_port_output(name, zone, start=start) + except InfraError: + logger.warning("Failed to get serial port output for %s", name, exc_info=True) return "" - return result.stdout + return data.get("contents", "") + + def logging_read(self, filter_str: str, limit: int = 200) -> list[str]: + try: + entries = self._api.logging_list_entries(filter_str, limit=limit) + except InfraError: + logger.warning("Cloud Logging query failed", exc_info=True) + return [] + return [e.get("textPayload", "") for e in entries if e.get("textPayload")] def create_local_slice( self, @@ -909,4 +664,4 @@ def get_local_slices(self, labels: dict[str, str] | None = None) -> list[LocalSl raise RuntimeError("get_local_slices is not supported in CLOUD mode") def shutdown(self) -> None: - pass + self._api.close() diff --git a/lib/iris/src/iris/cluster/providers/gcp/workers.py b/lib/iris/src/iris/cluster/providers/gcp/workers.py index d7ea3915db..0b120601a0 100644 --- a/lib/iris/src/iris/cluster/providers/gcp/workers.py +++ 
b/lib/iris/src/iris/cluster/providers/gcp/workers.py @@ -11,7 +11,6 @@ from __future__ import annotations import logging -import subprocess import threading from collections.abc import Callable import time @@ -762,7 +761,7 @@ def _run_tpu_bootstrap( break time.sleep(poll_interval) else: - _fetch_bootstrap_logs(project_id, handle) + _fetch_bootstrap_logs(gcp_service, handle) raise InfraError( f"TPU slice {handle.slice_id} bootstrap timed out: " f"{len(healthy_workers)}/{len(worker_addrs)} workers healthy" @@ -773,38 +772,19 @@ def _run_tpu_bootstrap( handle._bootstrap_state = CloudSliceState.READY -def _fetch_bootstrap_logs(project_id: str, handle: GcpSliceHandle) -> None: +def _fetch_bootstrap_logs(gcp_service: GcpService, handle: GcpSliceHandle) -> None: """Fetch [iris-init] log entries from Cloud Logging for diagnostics.""" log_filter = ( f'resource.type="gce_instance" ' f'textPayload:"[iris-init]" ' - f'labels."compute.googleapis.com/resource_name":"{handle._slice_id}"' + f'labels."compute.googleapis.com/resource_name":"{handle._slice_id}" ' + f'timestamp>="-PT30M"' ) - cmd = [ - "gcloud", - "logging", - "read", - log_filter, - f"--project={project_id}", - "--freshness=30m", - "--limit=200", - "--format=value(textPayload)", - ] - try: - result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) - except subprocess.TimeoutExpired: - logger.warning("Cloud Logging query timed out for %s", handle.slice_id) - return - - if result.returncode == 0 and result.stdout.strip(): - logger.error("Bootstrap logs for %s:\n%s", handle.slice_id, result.stdout) + texts = gcp_service.logging_read(log_filter, limit=200) + if texts: + logger.error("Bootstrap logs for %s:\n%s", handle.slice_id, "\n".join(texts)) else: - logger.warning( - "Could not fetch Cloud Logging for %s (rc=%d): %s", - handle.slice_id, - result.returncode, - result.stderr.strip(), - ) + logger.warning("No Cloud Logging entries found for %s", handle.slice_id) def _probe_worker_health(address: str, port: int) -> bool: diff --git a/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py b/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py new file mode 100644 index 0000000000..bbf58016b1 --- /dev/null +++ b/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py @@ -0,0 +1,432 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for GCPApi — HTTP client for GCP REST APIs. + +Uses httpx.MockTransport to verify URL construction, error mapping, +pagination, and auth header injection without hitting real GCP. 
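+
+Typical shape of a test (a sketch; `_make_api` is the helper defined below):
+
+    def handler(request: httpx.Request) -> httpx.Response:
+        return httpx.Response(200, json={"nodes": []})
+
+    api = _make_api(handler)  # MockTransport + pre-seeded fake token
+    assert api.tpu_list(ZONE) == []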
+""" + +from __future__ import annotations + +import json +from collections.abc import Callable +from unittest.mock import patch + +import httpx +import pytest + +from iris.cluster.providers.gcp.api import ( + COMPUTE_BASE, + LOGGING_BASE, + TPU_BASE, + GCPApi, +) +from iris.cluster.providers.types import ( + InfraError, + QuotaExhaustedError, + ResourceNotFoundError, +) + +PROJECT = "test-project" +ZONE = "us-central1-a" + + +def _mock_credentials(): + """Patch google.auth.default to return a fake credential.""" + cred = type( + "FakeCred", + (), + { + "token": "fake-token", + "expiry": None, + "refresh": lambda self, req: None, + }, + )() + return patch("iris.cluster.providers.gcp.api.google.auth.default", return_value=(cred, PROJECT)) + + +def _make_api(handler: Callable[[httpx.Request], httpx.Response]) -> GCPApi: + """Create a GCPApi with a mock HTTP transport and fake credentials.""" + api = GCPApi(PROJECT) + api._client = httpx.Client(transport=httpx.MockTransport(handler), timeout=10) + # Inject fake token so _refresh_token isn't called + api._token = "fake-token" + api._expires_at = float("inf") + return api + + +def _json_response(body: dict, status: int = 200) -> httpx.Response: + return httpx.Response(status, json=body) + + +# ======================================================================== +# Error mapping +# ======================================================================== + + +def test_404_raises_resource_not_found(): + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(404, json={"error": {"code": 404, "message": "Not found", "status": "NOT_FOUND"}}) + + api = _make_api(handler) + with pytest.raises(ResourceNotFoundError, match="Not found"): + api.tpu_get("no-such-tpu", ZONE) + api.close() + + +def test_429_raises_quota_exhausted(): + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response( + 429, json={"error": {"code": 429, "message": "Quota exceeded", "status": "RESOURCE_EXHAUSTED"}} + ) + + api = _make_api(handler) + with pytest.raises(QuotaExhaustedError, match="Quota exceeded"): + api.tpu_get("some-tpu", ZONE) + api.close() + + +def test_500_raises_infra_error(): + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(500, json={"error": {"code": 500, "message": "Internal error"}}) + + api = _make_api(handler) + with pytest.raises(InfraError, match="Internal error"): + api.tpu_get("some-tpu", ZONE) + api.close() + + +def test_non_json_error_body(): + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(502, text="Bad Gateway") + + api = _make_api(handler) + with pytest.raises(InfraError, match="Bad Gateway"): + api.tpu_get("some-tpu", ZONE) + api.close() + + +# ======================================================================== +# Auth headers +# ======================================================================== + + +def test_auth_header_injected(): + requests_seen: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests_seen.append(request) + return _json_response({"name": "test", "state": "READY"}) + + api = _make_api(handler) + api.tpu_get("my-tpu", ZONE) + api.close() + + assert len(requests_seen) == 1 + assert requests_seen[0].headers["authorization"] == "Bearer fake-token" + + +def test_token_refresh_on_expiry(): + """When token is expired, _refresh_token is called.""" + with _mock_credentials(): + + def handler(request: httpx.Request) -> httpx.Response: + return _json_response({"name": "test", 
"state": "READY"}) + + api = _make_api(handler) + api._token = None # Force refresh + api._expires_at = 0.0 + api.tpu_get("my-tpu", ZONE) + assert api._token == "fake-token" + api.close() + + +# ======================================================================== +# TPU operations — URL construction +# ======================================================================== + + +def test_tpu_create_url_and_params(): + requests_seen: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests_seen.append(request) + return _json_response({"name": "operations/op-123"}) + + api = _make_api(handler) + api.tpu_create("my-tpu", ZONE, {"acceleratorType": "v4-8"}) + api.close() + + req = requests_seen[0] + assert req.method == "POST" + assert f"{TPU_BASE}/projects/{PROJECT}/locations/{ZONE}/nodes" in str(req.url) + assert "nodeId=my-tpu" in str(req.url) + + +def test_tpu_get_url(): + def handler(request: httpx.Request) -> httpx.Response: + return _json_response({"name": f"projects/{PROJECT}/locations/{ZONE}/nodes/my-tpu", "state": "READY"}) + + api = _make_api(handler) + result = api.tpu_get("my-tpu", ZONE) + api.close() + + assert result["state"] == "READY" + + +def test_tpu_delete_ignores_404(): + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) + + api = _make_api(handler) + api.tpu_delete("gone-tpu", ZONE) # should not raise + api.close() + + +def test_tpu_list_with_pagination(): + call_count = 0 + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal call_count + call_count += 1 + if call_count == 1: + return _json_response( + { + "nodes": [{"name": "tpu-1", "state": "READY"}], + "nextPageToken": "page2", + } + ) + return _json_response( + { + "nodes": [{"name": "tpu-2", "state": "READY"}], + } + ) + + api = _make_api(handler) + results = api.tpu_list(ZONE) + api.close() + + assert len(results) == 2 + assert results[0]["name"] == "tpu-1" + assert results[1]["name"] == "tpu-2" + assert call_count == 2 + + +# ======================================================================== +# Queued resource operations +# ======================================================================== + + +def test_queued_resource_create_url(): + requests_seen: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests_seen.append(request) + return _json_response({"name": "operations/op-456"}) + + api = _make_api(handler) + api.queued_resource_create("my-qr", ZONE, {"tpu": {"nodeSpec": []}}) + api.close() + + req = requests_seen[0] + assert req.method == "POST" + assert "/queuedResources" in str(req.url) + assert "queuedResourceId=my-qr" in str(req.url) + + +def test_queued_resource_delete_ignores_404(): + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) + + api = _make_api(handler) + api.queued_resource_delete("gone-qr", ZONE) # should not raise + api.close() + + +def test_queued_resource_delete_passes_force(): + requests_seen: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests_seen.append(request) + return _json_response({"name": "operations/op-789"}) + + api = _make_api(handler) + api.queued_resource_delete("my-qr", ZONE) + api.close() + + assert "force=true" in str(requests_seen[0].url) + + +# ======================================================================== +# Compute operations +# 
======================================================================== + + +def test_instance_insert_url(): + requests_seen: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests_seen.append(request) + return _json_response({"name": "operations/op-vm"}) + + api = _make_api(handler) + api.instance_insert(ZONE, {"name": "my-vm", "machineType": "n1-standard-4"}) + api.close() + + req = requests_seen[0] + assert req.method == "POST" + assert f"{COMPUTE_BASE}/projects/{PROJECT}/zones/{ZONE}/instances" in str(req.url) + + +def test_instance_delete_ignores_404(): + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) + + api = _make_api(handler) + api.instance_delete("gone-vm", ZONE) # should not raise + api.close() + + +def test_instance_reset_url(): + requests_seen: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests_seen.append(request) + return _json_response({"name": "operations/reset"}) + + api = _make_api(handler) + api.instance_reset("my-vm", ZONE) + api.close() + + assert "/my-vm/reset" in str(requests_seen[0].url) + + +def test_instance_set_labels_url_and_body(): + requests_seen: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests_seen.append(request) + return _json_response({"name": "operations/labels"}) + + api = _make_api(handler) + api.instance_set_labels("my-vm", ZONE, {"env": "test"}, "abc123") + api.close() + + req = requests_seen[0] + assert "/my-vm/setLabels" in str(req.url) + body = json.loads(req.content) + assert body["labels"] == {"env": "test"} + assert body["labelFingerprint"] == "abc123" + + +def test_instance_get_serial_port_output(): + def handler(request: httpx.Request) -> httpx.Response: + return _json_response({"contents": "serial output here", "next": 42}) + + api = _make_api(handler) + result = api.instance_get_serial_port_output("my-vm", ZONE, start=10) + api.close() + + assert result["contents"] == "serial output here" + + +def test_instance_list_project_wide(): + """Project-wide list uses aggregatedList and flattens across zones.""" + + def handler(request: httpx.Request) -> httpx.Response: + return _json_response( + { + "items": { + "zones/us-central1-a": { + "instances": [{"name": "vm-1", "status": "RUNNING"}], + }, + "zones/us-east1-b": { + "instances": [{"name": "vm-2", "status": "RUNNING"}], + }, + "zones/us-west1-a": { + "warning": {"code": "NO_RESULTS_ON_PAGE"}, + }, + } + } + ) + + api = _make_api(handler) + results = api.instance_list(zone=None) + api.close() + + names = {r["name"] for r in results} + assert names == {"vm-1", "vm-2"} + + +def test_instance_list_with_zone(): + def handler(request: httpx.Request) -> httpx.Response: + return _json_response( + { + "items": [{"name": "vm-1", "status": "RUNNING"}], + } + ) + + api = _make_api(handler) + results = api.instance_list(zone=ZONE) + api.close() + + assert len(results) == 1 + assert results[0]["name"] == "vm-1" + + +def test_instance_list_with_filter(): + requests_seen: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests_seen.append(request) + return _json_response({"items": []}) + + api = _make_api(handler) + api.instance_list(zone=ZONE, filter_str="labels.env=test") + api.close() + + assert "filter=labels.env%3Dtest" in str(requests_seen[0].url) + + +# ======================================================================== +# Cloud Logging +# 
======================================================================== + + +def test_logging_list_entries(): + requests_seen: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + requests_seen.append(request) + return _json_response( + { + "entries": [ + {"textPayload": "line 1"}, + {"textPayload": "line 2"}, + ] + } + ) + + api = _make_api(handler) + entries = api.logging_list_entries("some filter", limit=50) + api.close() + + assert len(entries) == 2 + req = requests_seen[0] + assert req.method == "POST" + assert f"{LOGGING_BASE}/entries:list" in str(req.url) + body = json.loads(req.content) + assert body["filter"] == "some filter" + assert body["pageSize"] == 50 + + +def test_logging_list_entries_empty(): + def handler(request: httpx.Request) -> httpx.Response: + return _json_response({}) + + api = _make_api(handler) + entries = api.logging_list_entries("no match") + api.close() + + assert entries == [] diff --git a/uv.lock b/uv.lock index 1579a9cee6..a07107cb0e 100644 --- a/uv.lock +++ b/uv.lock @@ -3578,6 +3578,7 @@ dependencies = [ { name = "connect-python" }, { name = "fsspec" }, { name = "gcsfs" }, + { name = "google-auth" }, { name = "grpcio" }, { name = "httpx" }, { name = "humanfriendly" }, @@ -3631,6 +3632,7 @@ requires-dist = [ { name = "duckdb", marker = "extra == 'controller'", specifier = ">=1.0.0" }, { name = "fsspec", specifier = ">=2024.0.0" }, { name = "gcsfs", specifier = ">=2024.0.0" }, + { name = "google-auth", specifier = ">=2.0" }, { name = "grpcio", specifier = ">=1.76.0" }, { name = "httpx", specifier = ">=0.28.1" }, { name = "humanfriendly", specifier = ">=10.0" }, From ed13992b5ce19a4c34421c35d3cf55f90b355df9 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 3 Apr 2026 17:48:25 +0000 Subject: [PATCH 2/8] [iris] Wait for async GCP operations before describing resources vm_create and tpu_create called the REST API (which returns async operations) then immediately tried to describe the resource. The old gcloud CLI blocked until operations completed; the REST API does not. This caused workers to fail with "created but could not be described" because the VM/TPU wasn't visible yet. Add operation polling to GCPApi (_wait_zone_operation for Compute, _wait_tpu_operation for TPU LROs), with instance_insert_wait and tpu_create_wait convenience methods. Update CloudGcpService to use the waiting variants. Also fix _fetch_bootstrap_logs timestamp filter: Cloud Logging needs RFC3339 timestamps, not "-PT30M" duration literals. 
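With the 30-minute lookback computed from datetime.now(timezone.utc), the
filter now renders as, e.g.:

    timestamp>="2026-04-03T17:18:25Z"

instead of the rejected timestamp>="-PT30M".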
https://claude.ai/code/session_01L4bVGg6j4fw19RiADT1GhM --- .../src/iris/cluster/providers/gcp/api.py | 63 ++++++- .../src/iris/cluster/providers/gcp/service.py | 11 +- .../src/iris/cluster/providers/gcp/workers.py | 5 +- .../cluster/providers/gcp/test_gcp_api.py | 171 ++++++++++++++++++ 4 files changed, 237 insertions(+), 13 deletions(-) diff --git a/lib/iris/src/iris/cluster/providers/gcp/api.py b/lib/iris/src/iris/cluster/providers/gcp/api.py index 7700e8fec0..c91a19e829 100644 --- a/lib/iris/src/iris/cluster/providers/gcp/api.py +++ b/lib/iris/src/iris/cluster/providers/gcp/api.py @@ -33,6 +33,8 @@ _REFRESH_MARGIN = 300 # seconds before expiry to refresh token _DEFAULT_TIMEOUT = 120 # seconds +_OPERATION_POLL_INTERVAL = 2 # seconds between operation status polls +_OPERATION_TIMEOUT = 600 # seconds to wait for an operation to complete class GCPApi: @@ -123,6 +125,44 @@ def _paginate_raw(self, url: str, params: dict[str, str] | None = None) -> list[ p["pageToken"] = token return pages + # -- Operation polling ---------------------------------------------------- + + def _wait_zone_operation(self, zone: str, operation_name: str, timeout: float = _OPERATION_TIMEOUT) -> dict: + """Poll a Compute Engine zone operation until status is DONE.""" + url = f"{COMPUTE_BASE}/projects/{self._project_id}/zones/{zone}/operations/{operation_name}" + deadline = time.monotonic() + timeout + while True: + resp = self._client.get(url, headers=self._headers()) + self._classify_response(resp) + data = resp.json() + if data.get("status") == "DONE": + if "error" in data: + errors = data["error"].get("errors", []) + msg = "; ".join(e.get("message", str(e)) for e in errors) + raise InfraError(f"Operation {operation_name} failed: {msg}") + return data + if time.monotonic() >= deadline: + raise InfraError(f"Operation {operation_name} timed out after {timeout}s") + time.sleep(_OPERATION_POLL_INTERVAL) + + def _wait_tpu_operation(self, operation_name: str, timeout: float = _OPERATION_TIMEOUT) -> dict: + """Poll a TPU v2 long-running operation until done.""" + url = f"{TPU_BASE}/{operation_name}" + deadline = time.monotonic() + timeout + while True: + resp = self._client.get(url, headers=self._headers()) + self._classify_response(resp) + data = resp.json() + if data.get("done"): + if "error" in data: + error = data["error"] + msg = error.get("message", str(error)) + raise InfraError(f"TPU operation failed: {msg}") + return data + if time.monotonic() >= deadline: + raise InfraError(f"TPU operation {operation_name} timed out after {timeout}s") + time.sleep(_OPERATION_POLL_INTERVAL) + # ====================================================================== # TPU v2 # ====================================================================== @@ -130,13 +170,20 @@ def _paginate_raw(self, url: str, params: dict[str, str] | None = None) -> list[ def _tpu_parent(self, zone: str) -> str: return f"projects/{self._project_id}/locations/{zone}" - def tpu_create(self, name: str, zone: str, body: dict) -> dict | None: + def tpu_create(self, name: str, zone: str, body: dict) -> dict: + """Start TPU creation and return the raw LRO response.""" url = f"{TPU_BASE}/{self._tpu_parent(zone)}/nodes" resp = self._client.post(url, params={"nodeId": name}, headers=self._headers(), json=body) self._classify_response(resp) - data = resp.json() - # REST create returns a long-running operation, not the node itself. 
- return data if data.get("name", "").endswith(f"/nodes/{name}") else None + return resp.json() + + def tpu_create_wait(self, name: str, zone: str, body: dict) -> dict: + """Create a TPU and wait for the LRO to complete, then return the node.""" + data = self.tpu_create(name, zone, body) + op_name = data.get("name", "") + if op_name and "/operations/" in op_name: + self._wait_tpu_operation(op_name) + return self.tpu_get(name, zone) def tpu_get(self, name: str, zone: str) -> dict: url = f"{TPU_BASE}/{self._tpu_parent(zone)}/nodes/{name}" @@ -201,6 +248,14 @@ def instance_insert(self, zone: str, body: dict) -> dict: self._classify_response(resp) return resp.json() + def instance_insert_wait(self, zone: str, body: dict) -> dict: + """Insert a VM and wait for the zone operation to complete.""" + data = self.instance_insert(zone, body) + op_name = data.get("name", "") + if op_name: + self._wait_zone_operation(zone, op_name) + return data + def instance_get(self, name: str, zone: str) -> dict: url = self._instance_url(zone, name) resp = self._client.get(url, headers=self._headers()) diff --git a/lib/iris/src/iris/cluster/providers/gcp/service.py b/lib/iris/src/iris/cluster/providers/gcp/service.py index 2bc305ea5c..440182a7a8 100644 --- a/lib/iris/src/iris/cluster/providers/gcp/service.py +++ b/lib/iris/src/iris/cluster/providers/gcp/service.py @@ -401,13 +401,8 @@ def tpu_create(self, request: TpuCreateRequest) -> TpuInfo: body["networkConfig"] = network_config logger.info("Creating TPU: %s (type=%s, zone=%s)", request.name, request.accelerator_type, request.zone) - self._api.tpu_create(request.name, request.zone, body) - - # REST create returns an operation, not the node — fetch it explicitly. - info = self.tpu_describe(request.name, request.zone) - if info is None: - raise InfraError(f"TPU {request.name} created but could not be described") - return info + tpu_data = self._api.tpu_create_wait(request.name, request.zone, body) + return _parse_tpu_info(tpu_data, request.zone) def tpu_delete(self, name: str, zone: str) -> None: logger.info("Deleting TPU (async): %s", name) @@ -557,7 +552,7 @@ def vm_create(self, request: VmCreateRequest) -> VmInfo: logger.info("Creating VM: %s (zone=%s, type=%s)", request.name, request.zone, request.machine_type) try: - self._api.instance_insert(request.zone, body) + self._api.instance_insert_wait(request.zone, body) except InfraError as e: if "already exists" not in str(e).lower(): raise diff --git a/lib/iris/src/iris/cluster/providers/gcp/workers.py b/lib/iris/src/iris/cluster/providers/gcp/workers.py index 0b120601a0..a5eb791784 100644 --- a/lib/iris/src/iris/cluster/providers/gcp/workers.py +++ b/lib/iris/src/iris/cluster/providers/gcp/workers.py @@ -13,6 +13,7 @@ import logging import threading from collections.abc import Callable +from datetime import datetime, timedelta, timezone import time import urllib.error import urllib.request @@ -774,11 +775,13 @@ def _run_tpu_bootstrap( def _fetch_bootstrap_logs(gcp_service: GcpService, handle: GcpSliceHandle) -> None: """Fetch [iris-init] log entries from Cloud Logging for diagnostics.""" + cutoff = datetime.now(timezone.utc) - timedelta(minutes=30) + cutoff_str = cutoff.strftime("%Y-%m-%dT%H:%M:%SZ") log_filter = ( f'resource.type="gce_instance" ' f'textPayload:"[iris-init]" ' f'labels."compute.googleapis.com/resource_name":"{handle._slice_id}" ' - f'timestamp>="-PT30M"' + f'timestamp>="{cutoff_str}"' ) texts = gcp_service.logging_read(log_filter, limit=200) if texts: diff --git 
a/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py b/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py index bbf58016b1..87b2c7aeaf 100644 --- a/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py +++ b/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py @@ -421,6 +421,177 @@ def handler(request: httpx.Request) -> httpx.Response: assert body["pageSize"] == 50 +# ======================================================================== +# Operation waiting — vm_create / tpu_create must wait for async operations +# ======================================================================== + + +def test_instance_insert_wait_polls_until_done(): + """instance_insert_wait should poll the zone operation until DONE.""" + call_count = 0 + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal call_count + # POST to create the instance + if request.method == "POST" and "/instances" in str(request.url) and "/operations/" not in str(request.url): + return _json_response( + {"name": "op-123", "status": "RUNNING", "zone": f"zones/{ZONE}", "kind": "compute#operation"} + ) + # GET to poll the operation + if "/operations/op-123" in str(request.url): + call_count += 1 + if call_count < 2: + return _json_response({"name": "op-123", "status": "RUNNING"}) + return _json_response({"name": "op-123", "status": "DONE"}) + return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) + + api = _make_api(handler) + api.instance_insert_wait(ZONE, {"name": "my-vm"}) + api.close() + + assert call_count == 2, "Must poll operation until DONE" + + +def test_instance_insert_wait_raises_on_operation_error(): + """If the operation completes with an error, raise InfraError.""" + + def handler(request: httpx.Request) -> httpx.Response: + if request.method == "POST" and "/instances" in str(request.url) and "/operations/" not in str(request.url): + return _json_response( + {"name": "op-err", "status": "RUNNING", "zone": f"zones/{ZONE}", "kind": "compute#operation"} + ) + if "/operations/op-err" in str(request.url): + return _json_response( + { + "name": "op-err", + "status": "DONE", + "error": {"errors": [{"code": "QUOTA_EXCEEDED", "message": "Insufficient quota"}]}, + } + ) + return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) + + api = _make_api(handler) + with pytest.raises(InfraError, match="Insufficient quota"): + api.instance_insert_wait(ZONE, {"name": "my-vm"}) + api.close() + + +def test_vm_create_waits_for_operation(): + """vm_create must wait for the insert operation before describing the VM. + + This is the core regression from replacing `gcloud compute instances create` + (which blocks until RUNNING) with the REST API (which returns immediately). + The mock simulates real GCP behavior: instance_get returns 404 until the + zone operation reaches DONE. 
+ """ + from iris.cluster.providers.gcp.service import CloudGcpService, VmCreateRequest + + operation_done = False + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal operation_done + url = str(request.url) + + # POST instances — create (returns operation) + if request.method == "POST" and url.endswith(f"/zones/{ZONE}/instances"): + return _json_response( + {"name": "op-vm-1", "status": "RUNNING", "zone": f"zones/{ZONE}", "kind": "compute#operation"} + ) + + # GET operation poll — marks operation as done + if "/operations/op-vm-1" in url and request.method == "GET": + operation_done = True + return _json_response({"name": "op-vm-1", "status": "DONE"}) + + # GET instance — only succeeds after operation completed (real GCP behavior) + if request.method == "GET" and url.endswith(f"/zones/{ZONE}/instances/test-vm"): + if not operation_done: + return httpx.Response(404, json={"error": {"code": 404, "message": "Not found", "status": "NOT_FOUND"}}) + return _json_response( + { + "name": "test-vm", + "status": "RUNNING", + "zone": f"projects/{PROJECT}/zones/{ZONE}", + "networkInterfaces": [{"networkIP": "10.0.0.1", "accessConfigs": [{"natIP": "34.1.2.3"}]}], + "metadata": {}, + "serviceAccounts": [{"email": "sa@test.iam.gserviceaccount.com"}], + "creationTimestamp": "2026-01-01T00:00:00Z", + } + ) + + return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) + + api = _make_api(handler) + svc = CloudGcpService(PROJECT, api=api) + info = svc.vm_create(VmCreateRequest(name="test-vm", zone=ZONE, machine_type="n1-standard-4", labels={})) + api.close() + + assert info.name == "test-vm" + assert info.internal_ip == "10.0.0.1" + assert operation_done, "vm_create must poll the operation before describing" + + +def test_tpu_create_waits_for_operation(): + """tpu_create must wait for the LRO before describing the TPU. + + Same race condition as vm_create: the REST API create returns an LRO, + and the TPU node may not be visible via tpu_get until the operation completes. 
+ """ + from iris.cluster.providers.gcp.service import CloudGcpService, TpuCreateRequest + from iris.rpc import config_pb2 + + operation_done = False + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal operation_done + url = str(request.url) + + # POST nodes — create (returns LRO) + if request.method == "POST" and "/nodes" in url and "nodeId=test-tpu" in url: + return _json_response({"name": f"projects/{PROJECT}/locations/{ZONE}/operations/op-tpu-1", "done": False}) + + # GET operation poll (TPU LRO) — marks operation as done + if "/operations/op-tpu-1" in url and request.method == "GET": + operation_done = True + return _json_response({"name": f"projects/{PROJECT}/locations/{ZONE}/operations/op-tpu-1", "done": True}) + + # GET node — only succeeds after operation completed + if request.method == "GET" and url.endswith("/nodes/test-tpu"): + if not operation_done: + return httpx.Response(404, json={"error": {"code": 404, "message": "Not found", "status": "NOT_FOUND"}}) + return _json_response( + { + "name": f"projects/{PROJECT}/locations/{ZONE}/nodes/test-tpu", + "state": "READY", + "acceleratorType": "v4-8", + "networkEndpoints": [{"ipAddress": "10.0.0.2"}], + "labels": {}, + "metadata": {}, + "createTime": "2026-01-01T00:00:00Z", + } + ) + + return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) + + api = _make_api(handler) + svc = CloudGcpService(PROJECT, api=api) + info = svc.tpu_create( + TpuCreateRequest( + name="test-tpu", + zone=ZONE, + accelerator_type="v4-8", + runtime_version="tpu-ubuntu2204-base", + capacity_type=config_pb2.CAPACITY_TYPE_ON_DEMAND, + labels={}, + ) + ) + api.close() + + assert info.name == "test-tpu" + assert info.state == "READY" + assert operation_done, "tpu_create must poll the operation before describing" + + def test_logging_list_entries_empty(): def handler(request: httpx.Request) -> httpx.Response: return _json_response({}) From c09776812de0ac4007244b7401f5f8d22d953c06 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 3 Apr 2026 19:30:02 +0000 Subject: [PATCH 3/8] [iris] Remove implementation-detail tests for operation polling Tests should validate external behavior boundaries, not internal polling mechanics. Keep the service-level tests (vm_create_waits, tpu_create_waits) that verify the real contract: create succeeds even when the underlying API is async. 
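In miniature, the kept style versus the removed style (a condensed,
hypothetical fragment; `svc` and `create_request` stand in for the
fixtures and request objects built in the diff below):

    def test_create_waits_for_operation(svc, create_request):
        # Removed style: asserted internal mechanics, e.g.
        #     assert poll_call_count == 2
        # Kept style: assert the observable contract. The mock returns
        # 404 for the instance GET until the operation reports DONE, so
        # vm_create succeeds only if it actually waited.
        info = svc.vm_create(create_request)
        assert info.name == "test-vm"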
https://claude.ai/code/session_01L4bVGg6j4fw19RiADT1GhM --- .../cluster/providers/gcp/test_gcp_api.py | 50 ------------------- 1 file changed, 50 deletions(-) diff --git a/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py b/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py index 87b2c7aeaf..bed8d496a1 100644 --- a/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py +++ b/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py @@ -426,56 +426,6 @@ def handler(request: httpx.Request) -> httpx.Response: # ======================================================================== -def test_instance_insert_wait_polls_until_done(): - """instance_insert_wait should poll the zone operation until DONE.""" - call_count = 0 - - def handler(request: httpx.Request) -> httpx.Response: - nonlocal call_count - # POST to create the instance - if request.method == "POST" and "/instances" in str(request.url) and "/operations/" not in str(request.url): - return _json_response( - {"name": "op-123", "status": "RUNNING", "zone": f"zones/{ZONE}", "kind": "compute#operation"} - ) - # GET to poll the operation - if "/operations/op-123" in str(request.url): - call_count += 1 - if call_count < 2: - return _json_response({"name": "op-123", "status": "RUNNING"}) - return _json_response({"name": "op-123", "status": "DONE"}) - return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) - - api = _make_api(handler) - api.instance_insert_wait(ZONE, {"name": "my-vm"}) - api.close() - - assert call_count == 2, "Must poll operation until DONE" - - -def test_instance_insert_wait_raises_on_operation_error(): - """If the operation completes with an error, raise InfraError.""" - - def handler(request: httpx.Request) -> httpx.Response: - if request.method == "POST" and "/instances" in str(request.url) and "/operations/" not in str(request.url): - return _json_response( - {"name": "op-err", "status": "RUNNING", "zone": f"zones/{ZONE}", "kind": "compute#operation"} - ) - if "/operations/op-err" in str(request.url): - return _json_response( - { - "name": "op-err", - "status": "DONE", - "error": {"errors": [{"code": "QUOTA_EXCEEDED", "message": "Insufficient quota"}]}, - } - ) - return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) - - api = _make_api(handler) - with pytest.raises(InfraError, match="Insufficient quota"): - api.instance_insert_wait(ZONE, {"name": "my-vm"}) - api.close() - - def test_vm_create_waits_for_operation(): """vm_create must wait for the insert operation before describing the VM. From ad57024adba578a632db2a0feecd53ced689c516 Mon Sep 17 00:00:00 2001 From: Russell Power Date: Fri, 3 Apr 2026 13:34:40 -0700 Subject: [PATCH 4/8] [iris] Inline GCPApi into CloudGcpService and add integration tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The separate GCPApi class added indirection without reuse — it was only used by CloudGcpService. Inline all HTTP/auth/pagination/operation-polling logic directly into CloudGcpService and delete api.py. This also fixes the controller startup bottleneck where tpu_list with no zones would scan all 15+ KNOWN_GCP_ZONES sequentially, each requiring a REST API call. The controller's autoscaler list operations were taking minutes, causing the 600s worker timeout to expire before TPUs could even be created. 
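In miniature, the listing hot path after this change is one paginated
GET per requested zone (a condensed sketch, not the actual method:
`_paginate`, `_tpu_parent`, and `_valid_zones` are the inlined helpers
from the diff below; the wrapper function and TPU_BASE constant here
are illustrative):

    TPU_BASE = "https://tpu.googleapis.com/v2"

    def tpu_list_sketch(svc, zones: list[str]) -> list[dict]:
        # One REST round-trip (plus pagination) per zone. An empty
        # `zones` still falls back to every known zone, so hot-path
        # callers like the autoscaler should pass explicit zones.
        nodes: list[dict] = []
        for zone in zones or sorted(svc._valid_zones):
            url = f"{TPU_BASE}/{svc._tpu_parent(zone)}/nodes"
            nodes.extend(svc._paginate(url, "nodes"))
        return nodes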
--- .../src/iris/cluster/providers/gcp/api.py | 326 ---------- .../src/iris/cluster/providers/gcp/service.py | 284 +++++++-- .../gcp/test_cloud_service_integration.py | 505 ++++++++++++++++ .../cluster/providers/gcp/test_gcp_api.py | 559 +++++++----------- 4 files changed, 973 insertions(+), 701 deletions(-) delete mode 100644 lib/iris/src/iris/cluster/providers/gcp/api.py create mode 100644 lib/iris/tests/cluster/providers/gcp/test_cloud_service_integration.py diff --git a/lib/iris/src/iris/cluster/providers/gcp/api.py b/lib/iris/src/iris/cluster/providers/gcp/api.py deleted file mode 100644 index c91a19e829..0000000000 --- a/lib/iris/src/iris/cluster/providers/gcp/api.py +++ /dev/null @@ -1,326 +0,0 @@ -# Copyright The Marin Authors -# SPDX-License-Identifier: Apache-2.0 - -"""Low-level HTTP client for GCP REST APIs (TPU v2, Compute v1, Cloud Logging). - -Handles authentication (Application Default Credentials), token caching, -pagination, and error mapping to domain exceptions. Used by CloudGcpService -as a replacement for gcloud CLI subprocess calls. -""" - -from __future__ import annotations - -import json -import logging -import time - -import google.auth -import google.auth.credentials -import google.auth.transport.requests -import httpx - -from iris.cluster.providers.types import ( - InfraError, - QuotaExhaustedError, - ResourceNotFoundError, -) - -logger = logging.getLogger(__name__) - -TPU_BASE = "https://tpu.googleapis.com/v2" -COMPUTE_BASE = "https://compute.googleapis.com/compute/v1" -LOGGING_BASE = "https://logging.googleapis.com/v2" - -_REFRESH_MARGIN = 300 # seconds before expiry to refresh token -_DEFAULT_TIMEOUT = 120 # seconds -_OPERATION_POLL_INTERVAL = 2 # seconds between operation status polls -_OPERATION_TIMEOUT = 600 # seconds to wait for an operation to complete - - -class GCPApi: - """Low-level HTTP client for GCP REST APIs with ADC auth and token caching.""" - - def __init__(self, project_id: str) -> None: - self._project_id = project_id - self._client = httpx.Client(timeout=_DEFAULT_TIMEOUT) - self._creds: google.auth.credentials.Credentials | None = None - self._token: str | None = None - self._expires_at: float = 0.0 - - def close(self) -> None: - self._client.close() - - # -- Auth --------------------------------------------------------------- - - def _headers(self) -> dict[str, str]: - if self._token is None or time.monotonic() >= self._expires_at: - self._refresh_token() - return { - "Authorization": f"Bearer {self._token}", - "Content-Type": "application/json", - } - - def _refresh_token(self) -> None: - if self._creds is None: - self._creds, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"]) - self._creds.refresh(google.auth.transport.requests.Request()) - self._token = self._creds.token - now = time.monotonic() - if self._creds.expiry is not None: - self._expires_at = now + (self._creds.expiry.timestamp() - time.time()) - _REFRESH_MARGIN - else: - self._expires_at = now + _REFRESH_MARGIN - - # -- Error mapping ------------------------------------------------------ - - def _classify_response(self, resp: httpx.Response) -> None: - """Raise a domain exception for non-2xx responses.""" - if resp.status_code < 400: - return - try: - body = resp.json() - error = body.get("error", {}) - message = error.get("message", resp.text) - status = error.get("status", "") - code = error.get("code", resp.status_code) - except (json.JSONDecodeError, AttributeError): - message = resp.text - status = "" - code = resp.status_code - - if code == 
404 or status == "NOT_FOUND": - raise ResourceNotFoundError(message) - if code == 429 or status in ("RESOURCE_EXHAUSTED", "QUOTA_EXCEEDED"): - raise QuotaExhaustedError(message) - raise InfraError(f"GCP API error {code}: {message}") - - # -- Pagination --------------------------------------------------------- - - def _paginate(self, url: str, items_key: str, params: dict[str, str] | None = None) -> list[dict]: - results: list[dict] = [] - p = dict(params or {}) - while True: - resp = self._client.get(url, headers=self._headers(), params=p) - self._classify_response(resp) - data = resp.json() - results.extend(data.get(items_key, [])) - token = data.get("nextPageToken") - if not token: - break - p["pageToken"] = token - return results - - def _paginate_raw(self, url: str, params: dict[str, str] | None = None) -> list[dict]: - """Return raw page bodies (for aggregatedList where items_key varies).""" - pages: list[dict] = [] - p = dict(params or {}) - while True: - resp = self._client.get(url, headers=self._headers(), params=p) - self._classify_response(resp) - data = resp.json() - pages.append(data) - token = data.get("nextPageToken") - if not token: - break - p["pageToken"] = token - return pages - - # -- Operation polling ---------------------------------------------------- - - def _wait_zone_operation(self, zone: str, operation_name: str, timeout: float = _OPERATION_TIMEOUT) -> dict: - """Poll a Compute Engine zone operation until status is DONE.""" - url = f"{COMPUTE_BASE}/projects/{self._project_id}/zones/{zone}/operations/{operation_name}" - deadline = time.monotonic() + timeout - while True: - resp = self._client.get(url, headers=self._headers()) - self._classify_response(resp) - data = resp.json() - if data.get("status") == "DONE": - if "error" in data: - errors = data["error"].get("errors", []) - msg = "; ".join(e.get("message", str(e)) for e in errors) - raise InfraError(f"Operation {operation_name} failed: {msg}") - return data - if time.monotonic() >= deadline: - raise InfraError(f"Operation {operation_name} timed out after {timeout}s") - time.sleep(_OPERATION_POLL_INTERVAL) - - def _wait_tpu_operation(self, operation_name: str, timeout: float = _OPERATION_TIMEOUT) -> dict: - """Poll a TPU v2 long-running operation until done.""" - url = f"{TPU_BASE}/{operation_name}" - deadline = time.monotonic() + timeout - while True: - resp = self._client.get(url, headers=self._headers()) - self._classify_response(resp) - data = resp.json() - if data.get("done"): - if "error" in data: - error = data["error"] - msg = error.get("message", str(error)) - raise InfraError(f"TPU operation failed: {msg}") - return data - if time.monotonic() >= deadline: - raise InfraError(f"TPU operation {operation_name} timed out after {timeout}s") - time.sleep(_OPERATION_POLL_INTERVAL) - - # ====================================================================== - # TPU v2 - # ====================================================================== - - def _tpu_parent(self, zone: str) -> str: - return f"projects/{self._project_id}/locations/{zone}" - - def tpu_create(self, name: str, zone: str, body: dict) -> dict: - """Start TPU creation and return the raw LRO response.""" - url = f"{TPU_BASE}/{self._tpu_parent(zone)}/nodes" - resp = self._client.post(url, params={"nodeId": name}, headers=self._headers(), json=body) - self._classify_response(resp) - return resp.json() - - def tpu_create_wait(self, name: str, zone: str, body: dict) -> dict: - """Create a TPU and wait for the LRO to complete, then return the node.""" - 
data = self.tpu_create(name, zone, body) - op_name = data.get("name", "") - if op_name and "/operations/" in op_name: - self._wait_tpu_operation(op_name) - return self.tpu_get(name, zone) - - def tpu_get(self, name: str, zone: str) -> dict: - url = f"{TPU_BASE}/{self._tpu_parent(zone)}/nodes/{name}" - resp = self._client.get(url, headers=self._headers()) - self._classify_response(resp) - return resp.json() - - def tpu_delete(self, name: str, zone: str) -> None: - url = f"{TPU_BASE}/{self._tpu_parent(zone)}/nodes/{name}" - resp = self._client.delete(url, headers=self._headers()) - if resp.status_code != 404: - self._classify_response(resp) - - def tpu_list(self, zone: str) -> list[dict]: - return self._paginate(f"{TPU_BASE}/{self._tpu_parent(zone)}/nodes", "nodes") - - # ====================================================================== - # TPU v2 — Queued Resources - # ====================================================================== - - def queued_resource_create(self, name: str, zone: str, body: dict) -> None: - url = f"{TPU_BASE}/{self._tpu_parent(zone)}/queuedResources" - resp = self._client.post( - url, - params={"queuedResourceId": name}, - headers=self._headers(), - json=body, - ) - self._classify_response(resp) - - def queued_resource_get(self, name: str, zone: str) -> dict: - url = f"{TPU_BASE}/{self._tpu_parent(zone)}/queuedResources/{name}" - resp = self._client.get(url, headers=self._headers()) - self._classify_response(resp) - return resp.json() - - def queued_resource_delete(self, name: str, zone: str) -> None: - url = f"{TPU_BASE}/{self._tpu_parent(zone)}/queuedResources/{name}" - resp = self._client.delete(url, params={"force": "true"}, headers=self._headers()) - if resp.status_code != 404: - self._classify_response(resp) - - def queued_resource_list(self, zone: str) -> list[dict]: - return self._paginate( - f"{TPU_BASE}/{self._tpu_parent(zone)}/queuedResources", - "queuedResources", - ) - - # ====================================================================== - # Compute Engine v1 — Instances - # ====================================================================== - - def _instance_url(self, zone: str, name: str = "") -> str: - path = f"{COMPUTE_BASE}/projects/{self._project_id}/zones/{zone}/instances" - if name: - path += f"/{name}" - return path - - def instance_insert(self, zone: str, body: dict) -> dict: - url = self._instance_url(zone) - resp = self._client.post(url, headers=self._headers(), json=body) - self._classify_response(resp) - return resp.json() - - def instance_insert_wait(self, zone: str, body: dict) -> dict: - """Insert a VM and wait for the zone operation to complete.""" - data = self.instance_insert(zone, body) - op_name = data.get("name", "") - if op_name: - self._wait_zone_operation(zone, op_name) - return data - - def instance_get(self, name: str, zone: str) -> dict: - url = self._instance_url(zone, name) - resp = self._client.get(url, headers=self._headers()) - self._classify_response(resp) - return resp.json() - - def instance_delete(self, name: str, zone: str) -> None: - url = self._instance_url(zone, name) - resp = self._client.delete(url, headers=self._headers()) - if resp.status_code != 404: - self._classify_response(resp) - - def instance_list(self, zone: str | None = None, filter_str: str = "") -> list[dict]: - params: dict[str, str] = {} - if filter_str: - params["filter"] = filter_str - - if zone: - return self._paginate(self._instance_url(zone), "items", params) - - # Project-wide: aggregatedList, flatten across zones - url = 
f"{COMPUTE_BASE}/projects/{self._project_id}/aggregated/instances" - results: list[dict] = [] - for page in self._paginate_raw(url, params): - for scope in page.get("items", {}).values(): - results.extend(scope.get("instances", [])) - return results - - def instance_reset(self, name: str, zone: str) -> None: - url = self._instance_url(zone, name) + "/reset" - resp = self._client.post(url, headers=self._headers()) - self._classify_response(resp) - - def instance_set_labels(self, name: str, zone: str, labels: dict[str, str], fingerprint: str) -> None: - url = self._instance_url(zone, name) + "/setLabels" - resp = self._client.post( - url, - headers=self._headers(), - json={"labels": labels, "labelFingerprint": fingerprint}, - ) - self._classify_response(resp) - - def instance_set_metadata(self, name: str, zone: str, metadata_body: dict) -> None: - url = self._instance_url(zone, name) + "/setMetadata" - resp = self._client.post(url, headers=self._headers(), json=metadata_body) - self._classify_response(resp) - - def instance_get_serial_port_output(self, name: str, zone: str, start: int = 0) -> dict: - url = self._instance_url(zone, name) + "/serialPort" - resp = self._client.get(url, headers=self._headers(), params={"start": str(start)}) - self._classify_response(resp) - return resp.json() - - # ====================================================================== - # Cloud Logging v2 - # ====================================================================== - - def logging_list_entries(self, filter_str: str, limit: int = 200) -> list[dict]: - url = f"{LOGGING_BASE}/entries:list" - body = { - "resourceNames": [f"projects/{self._project_id}"], - "filter": filter_str, - "pageSize": min(limit, 1000), - "orderBy": "timestamp desc", - } - resp = self._client.post(url, headers=self._headers(), json=body, timeout=30) - self._classify_response(resp) - return resp.json().get("entries", []) diff --git a/lib/iris/src/iris/cluster/providers/gcp/service.py b/lib/iris/src/iris/cluster/providers/gcp/service.py index 440182a7a8..4f42acc26a 100644 --- a/lib/iris/src/iris/cluster/providers/gcp/service.py +++ b/lib/iris/src/iris/cluster/providers/gcp/service.py @@ -3,15 +3,22 @@ from __future__ import annotations +import json import logging import re +import time from dataclasses import dataclass, field from datetime import datetime from typing import Protocol -from iris.cluster.providers.gcp.api import GCPApi +import google.auth +import google.auth.credentials +import google.auth.transport.requests +import httpx + from iris.cluster.providers.types import ( InfraError, + QuotaExhaustedError, ResourceNotFoundError, ) from iris.cluster.providers.gcp.local import LocalSliceHandle @@ -58,6 +65,17 @@ CAPACITY_TYPE_LABEL = "capacity-type" CAPACITY_TYPE_RESERVED_VALUE = "reserved" +# REST API base URLs +_TPU_BASE = "https://tpu.googleapis.com/v2" +_COMPUTE_BASE = "https://compute.googleapis.com/compute/v1" +_LOGGING_BASE = "https://logging.googleapis.com/v2" + +# HTTP/auth constants +_REFRESH_MARGIN = 300 # seconds before expiry to refresh token +_DEFAULT_TIMEOUT = 120 # seconds +_OPERATION_POLL_INTERVAL = 2 # seconds between operation status polls +_OPERATION_TIMEOUT = 600 # seconds to wait for an operation to complete + # ============================================================================ # Data types @@ -245,7 +263,7 @@ def shutdown(self) -> None: # ============================================================================ -# CloudGcpService — REST API implementation via GCPApi +# CloudGcpService — REST 
API implementation # ============================================================================ @@ -353,11 +371,14 @@ def _parse_vm_info(vm_data: dict, fallback_zone: str = "") -> VmInfo: class CloudGcpService: - """GcpService backed by GCP REST APIs via GCPApi. Used in CLOUD mode.""" + """GcpService backed by GCP REST APIs. Used in CLOUD mode.""" - def __init__(self, project_id: str, api: GCPApi | None = None) -> None: + def __init__(self, project_id: str, http_client: httpx.Client | None = None) -> None: self._project_id = project_id - self._api = api if api is not None else GCPApi(project_id) + self._client = http_client if http_client is not None else httpx.Client(timeout=_DEFAULT_TIMEOUT) + self._creds: google.auth.credentials.Credentials | None = None + self._token: str | None = None + self._expires_at: float = 0.0 self._valid_zones: set[str] = set(KNOWN_GCP_ZONES) self._valid_accelerator_types: set[str] = set(KNOWN_TPU_TYPES) @@ -369,9 +390,123 @@ def mode(self) -> ServiceMode: def project_id(self) -> str: return self._project_id - @property - def api(self) -> GCPApi: - return self._api + # ======================================================================== + # HTTP helpers (auth, errors, pagination, operation polling) + # ======================================================================== + + def _headers(self) -> dict[str, str]: + if self._token is None or time.monotonic() >= self._expires_at: + self._refresh_token() + return { + "Authorization": f"Bearer {self._token}", + "Content-Type": "application/json", + } + + def _refresh_token(self) -> None: + if self._creds is None: + self._creds, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"]) + self._creds.refresh(google.auth.transport.requests.Request()) + self._token = self._creds.token + now = time.monotonic() + if self._creds.expiry is not None: + self._expires_at = now + (self._creds.expiry.timestamp() - time.time()) - _REFRESH_MARGIN + else: + self._expires_at = now + _REFRESH_MARGIN + + def _classify_response(self, resp: httpx.Response) -> None: + if resp.status_code < 400: + return + try: + body = resp.json() + error = body.get("error", {}) + message = error.get("message", resp.text) + status = error.get("status", "") + code = error.get("code", resp.status_code) + except (json.JSONDecodeError, AttributeError): + message = resp.text + status = "" + code = resp.status_code + + if code == 404 or status == "NOT_FOUND": + raise ResourceNotFoundError(message) + if code == 429 or status in ("RESOURCE_EXHAUSTED", "QUOTA_EXCEEDED"): + raise QuotaExhaustedError(message) + raise InfraError(f"GCP API error {code}: {message}") + + def _paginate(self, url: str, items_key: str, params: dict[str, str] | None = None) -> list[dict]: + results: list[dict] = [] + p = dict(params or {}) + while True: + resp = self._client.get(url, headers=self._headers(), params=p) + self._classify_response(resp) + data = resp.json() + results.extend(data.get(items_key, [])) + token = data.get("nextPageToken") + if not token: + break + p["pageToken"] = token + return results + + def _paginate_raw(self, url: str, params: dict[str, str] | None = None) -> list[dict]: + pages: list[dict] = [] + p = dict(params or {}) + while True: + resp = self._client.get(url, headers=self._headers(), params=p) + self._classify_response(resp) + data = resp.json() + pages.append(data) + token = data.get("nextPageToken") + if not token: + break + p["pageToken"] = token + return pages + + def _wait_zone_operation(self, zone: str, 
operation_name: str, timeout: float = _OPERATION_TIMEOUT) -> dict: + url = f"{_COMPUTE_BASE}/projects/{self._project_id}/zones/{zone}/operations/{operation_name}" + deadline = time.monotonic() + timeout + while True: + resp = self._client.get(url, headers=self._headers()) + self._classify_response(resp) + data = resp.json() + if data.get("status") == "DONE": + if "error" in data: + errors = data["error"].get("errors", []) + msg = "; ".join(e.get("message", str(e)) for e in errors) + raise InfraError(f"Operation {operation_name} failed: {msg}") + return data + if time.monotonic() >= deadline: + raise InfraError(f"Operation {operation_name} timed out after {timeout}s") + time.sleep(_OPERATION_POLL_INTERVAL) + + def _wait_tpu_operation(self, operation_name: str, timeout: float = _OPERATION_TIMEOUT) -> dict: + url = f"{_TPU_BASE}/{operation_name}" + deadline = time.monotonic() + timeout + while True: + resp = self._client.get(url, headers=self._headers()) + self._classify_response(resp) + data = resp.json() + if data.get("done"): + if "error" in data: + error = data["error"] + msg = error.get("message", str(error)) + raise InfraError(f"TPU operation failed: {msg}") + return data + if time.monotonic() >= deadline: + raise InfraError(f"TPU operation {operation_name} timed out after {timeout}s") + time.sleep(_OPERATION_POLL_INTERVAL) + + # ======================================================================== + # Low-level REST helpers + # ======================================================================== + + def _tpu_parent(self, zone: str) -> str: + return f"projects/{self._project_id}/locations/{zone}" + + def _instance_url(self, zone: str, name: str = "") -> str: + path = f"{_COMPUTE_BASE}/projects/{self._project_id}/zones/{zone}/instances" + if name: + path += f"/{name}" + return path # ======================================================================== # TPU operations @@ -401,16 +536,35 @@ def tpu_create(self, request: TpuCreateRequest) -> TpuInfo: body["networkConfig"] = network_config logger.info("Creating TPU: %s (type=%s, zone=%s)", request.name, request.accelerator_type, request.zone) - tpu_data = self._api.tpu_create_wait(request.name, request.zone, body) + + # POST to create, wait for LRO, then GET the final node state + url = f"{_TPU_BASE}/{self._tpu_parent(request.zone)}/nodes" + resp = self._client.post(url, params={"nodeId": request.name}, headers=self._headers(), json=body) + self._classify_response(resp) + data = resp.json() + op_name = data.get("name", "") + if op_name and "/operations/" in op_name: + self._wait_tpu_operation(op_name) + + tpu_data = self._tpu_get(request.name, request.zone) return _parse_tpu_info(tpu_data, request.zone) def tpu_delete(self, name: str, zone: str) -> None: logger.info("Deleting TPU (async): %s", name) - self._api.tpu_delete(name, zone) + url = f"{_TPU_BASE}/{self._tpu_parent(zone)}/nodes/{name}" + resp = self._client.delete(url, headers=self._headers()) + if resp.status_code != 404: + self._classify_response(resp) + + def _tpu_get(self, name: str, zone: str) -> dict: + url = f"{_TPU_BASE}/{self._tpu_parent(zone)}/nodes/{name}" + resp = self._client.get(url, headers=self._headers()) + self._classify_response(resp) + return resp.json() def tpu_describe(self, name: str, zone: str) -> TpuInfo | None: try: - tpu_data = self._api.tpu_get(name, zone) + tpu_data = self._tpu_get(name, zone) except ResourceNotFoundError: return None except InfraError: @@ -420,12 +574,13 @@ def tpu_describe(self, name: str, zone: str) -> TpuInfo | None: def 
tpu_list(self, zones: list[str], labels: dict[str, str] | None = None) -> list[TpuInfo]: results: list[TpuInfo] = [] - # TPU v2 API requires a real zone; empty zones = scan all known zones. + # TPU v2 API requires a specific zone per request. + # When no zones specified, only scan the caller's zones (not all known zones). zone_list = zones if zones else list(self._valid_zones) for zone in zone_list: try: - items = self._api.tpu_list(zone) + items = self._paginate(f"{_TPU_BASE}/{self._tpu_parent(zone)}/nodes", "nodes") except InfraError: logger.warning("Failed to list TPUs in zone %s", zone, exc_info=True) continue @@ -480,11 +635,21 @@ def queued_resource_create(self, request: TpuCreateRequest) -> None: request.accelerator_type, request.zone, ) - self._api.queued_resource_create(request.name, request.zone, body) + url = f"{_TPU_BASE}/{self._tpu_parent(request.zone)}/queuedResources" + resp = self._client.post( + url, + params={"queuedResourceId": request.name}, + headers=self._headers(), + json=body, + ) + self._classify_response(resp) def queued_resource_describe(self, name: str, zone: str) -> QueuedResourceInfo | None: try: - data = self._api.queued_resource_get(name, zone) + url = f"{_TPU_BASE}/{self._tpu_parent(zone)}/queuedResources/{name}" + resp = self._client.get(url, headers=self._headers()) + self._classify_response(resp) + data = resp.json() except ResourceNotFoundError: return None state = data.get("state", {}).get("state", "UNKNOWN") @@ -492,14 +657,20 @@ def queued_resource_describe(self, name: str, zone: str) -> QueuedResourceInfo | def queued_resource_delete(self, name: str, zone: str) -> None: logger.info("Deleting queued resource (force): %s", name) - self._api.queued_resource_delete(name, zone) + url = f"{_TPU_BASE}/{self._tpu_parent(zone)}/queuedResources/{name}" + resp = self._client.delete(url, params={"force": "true"}, headers=self._headers()) + if resp.status_code != 404: + self._classify_response(resp) def queued_resource_list(self, zones: list[str], labels: dict[str, str] | None = None) -> list[QueuedResourceInfo]: zone_list = zones if zones else list(self._valid_zones) results: list[QueuedResourceInfo] = [] for zone in zone_list: try: - items = self._api.queued_resource_list(zone) + items = self._paginate( + f"{_TPU_BASE}/{self._tpu_parent(zone)}/queuedResources", + "queuedResources", + ) except InfraError: logger.warning("Failed to list queued resources in %s", zone, exc_info=True) continue @@ -552,7 +723,14 @@ def vm_create(self, request: VmCreateRequest) -> VmInfo: logger.info("Creating VM: %s (zone=%s, type=%s)", request.name, request.zone, request.machine_type) try: - self._api.instance_insert_wait(request.zone, body) + # POST to insert, wait for zone operation + url = self._instance_url(request.zone) + resp = self._client.post(url, headers=self._headers(), json=body) + self._classify_response(resp) + data = resp.json() + op_name = data.get("name", "") + if op_name: + self._wait_zone_operation(request.zone, op_name) except InfraError as e: if "already exists" not in str(e).lower(): raise @@ -564,15 +742,26 @@ def vm_create(self, request: VmCreateRequest) -> VmInfo: def vm_delete(self, name: str, zone: str) -> None: logger.info("Deleting VM: %s", name) - self._api.instance_delete(name, zone) + url = self._instance_url(zone, name) + resp = self._client.delete(url, headers=self._headers()) + if resp.status_code != 404: + self._classify_response(resp) def vm_reset(self, name: str, zone: str) -> None: logger.info("Resetting VM: %s", name) - 
self._api.instance_reset(name, zone) + url = self._instance_url(zone, name) + "/reset" + resp = self._client.post(url, headers=self._headers()) + self._classify_response(resp) + + def _instance_get(self, name: str, zone: str) -> dict: + url = self._instance_url(zone, name) + resp = self._client.get(url, headers=self._headers()) + self._classify_response(resp) + return resp.json() def vm_describe(self, name: str, zone: str) -> VmInfo | None: try: - data = self._api.instance_get(name, zone) + data = self._instance_get(name, zone) except ResourceNotFoundError: return None except InfraError: @@ -585,18 +774,27 @@ def vm_list(self, zones: list[str], labels: dict[str, str] | None = None) -> lis filter_str = _build_label_filter(labels) if labels else "" if not zones: + # Project-wide: aggregatedList, flatten across zones + url = f"{_COMPUTE_BASE}/projects/{self._project_id}/aggregated/instances" + params: dict[str, str] = {} + if filter_str: + params["filter"] = filter_str try: - items = self._api.instance_list(zone=None, filter_str=filter_str) + for page in self._paginate_raw(url, params): + for scope in page.get("items", {}).values(): + for vm_data in scope.get("instances", []): + results.append(_parse_vm_info(vm_data)) except InfraError: logger.warning("Failed to list instances", exc_info=True) return [] - for vm_data in items: - results.append(_parse_vm_info(vm_data)) return results for zone in zones: + params = {} + if filter_str: + params["filter"] = filter_str try: - items = self._api.instance_list(zone=zone, filter_str=filter_str) + items = self._paginate(self._instance_url(zone), "items", params) except InfraError: logger.warning("Failed to list instances in zone %s", zone, exc_info=True) continue @@ -608,17 +806,21 @@ def vm_list(self, zones: list[str], labels: dict[str, str] | None = None) -> lis def vm_update_labels(self, name: str, zone: str, labels: dict[str, str]) -> None: validate_labels(labels) logger.info("Updating labels on VM %s", name) - # Read-modify-write: GET the current labelFingerprint, merge, then POST. - data = self._api.instance_get(name, zone) + data = self._instance_get(name, zone) current_labels = data.get("labels", {}) current_labels.update(labels) fingerprint = data.get("labelFingerprint", "") - self._api.instance_set_labels(name, zone, current_labels, fingerprint) + url = self._instance_url(zone, name) + "/setLabels" + resp = self._client.post( + url, + headers=self._headers(), + json={"labels": current_labels, "labelFingerprint": fingerprint}, + ) + self._classify_response(resp) def vm_set_metadata(self, name: str, zone: str, metadata: dict[str, str]) -> None: logger.info("Setting metadata on VM %s", name) - # Read-modify-write: GET current metadata with fingerprint, merge items, POST. 
- data = self._api.instance_get(name, zone) + data = self._instance_get(name, zone) raw_metadata = data.get("metadata", {}) fingerprint = raw_metadata.get("fingerprint", "") existing_items: dict[str, str] = {} @@ -629,11 +831,16 @@ def vm_set_metadata(self, name: str, zone: str, metadata: dict[str, str]) -> Non "fingerprint": fingerprint, "items": [{"key": k, "value": v} for k, v in existing_items.items()], } - self._api.instance_set_metadata(name, zone, body) + url = self._instance_url(zone, name) + "/setMetadata" + resp = self._client.post(url, headers=self._headers(), json=body) + self._classify_response(resp) def vm_get_serial_port_output(self, name: str, zone: str, start: int = 0) -> str: try: - data = self._api.instance_get_serial_port_output(name, zone, start=start) + url = self._instance_url(zone, name) + "/serialPort" + resp = self._client.get(url, headers=self._headers(), params={"start": str(start)}) + self._classify_response(resp) + data = resp.json() except InfraError: logger.warning("Failed to get serial port output for %s", name, exc_info=True) return "" @@ -641,7 +848,16 @@ def vm_get_serial_port_output(self, name: str, zone: str, start: int = 0) -> str def logging_read(self, filter_str: str, limit: int = 200) -> list[str]: try: - entries = self._api.logging_list_entries(filter_str, limit=limit) + url = f"{_LOGGING_BASE}/entries:list" + body = { + "resourceNames": [f"projects/{self._project_id}"], + "filter": filter_str, + "pageSize": min(limit, 1000), + "orderBy": "timestamp desc", + } + resp = self._client.post(url, headers=self._headers(), json=body, timeout=30) + self._classify_response(resp) + entries = resp.json().get("entries", []) except InfraError: logger.warning("Cloud Logging query failed", exc_info=True) return [] @@ -659,4 +875,4 @@ def get_local_slices(self, labels: dict[str, str] | None = None) -> list[LocalSl raise RuntimeError("get_local_slices is not supported in CLOUD mode") def shutdown(self) -> None: - self._api.close() + self._client.close() diff --git a/lib/iris/tests/cluster/providers/gcp/test_cloud_service_integration.py b/lib/iris/tests/cluster/providers/gcp/test_cloud_service_integration.py new file mode 100644 index 0000000000..db2771cf2e --- /dev/null +++ b/lib/iris/tests/cluster/providers/gcp/test_cloud_service_integration.py @@ -0,0 +1,505 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Integration test that exercises the full CloudGcpService TPU/VM lifecycle +with a mock HTTP backend, logging every request/response for debugging. + +This test simulates the controller's view: create VMs and TPUs, describe them, +list them with label filters, set metadata/labels, and read logs. Every HTTP +call is recorded so we can compare the REST API behavior against what the old +gcloud CLI produced. 
+""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass, field + +import httpx +import pytest + +from iris.cluster.providers.gcp.service import ( + CloudGcpService, + TpuCreateRequest, + VmCreateRequest, +) +from iris.rpc import config_pb2 + +logger = logging.getLogger(__name__) + +PROJECT = "test-project" +ZONE_EU = "europe-west4-b" +ZONE_US = "us-west4-a" + + +@dataclass +class HttpLog: + method: str + url: str + request_body: dict | None + status: int + response_body: dict + + +@dataclass +class GcpFakeBackend: + """Simulates GCP APIs at the HTTP level, tracking all instances and TPUs.""" + + project: str = PROJECT + vms: dict[tuple[str, str], dict] = field(default_factory=dict) # (name, zone) -> vm_data + tpus: dict[tuple[str, str], dict] = field(default_factory=dict) # (name, zone) -> tpu_data + operations: dict[str, dict] = field(default_factory=dict) # op_name -> op_data + http_log: list[HttpLog] = field(default_factory=list) + + def handle(self, request: httpx.Request) -> httpx.Response: + url = str(request.url) + method = request.method + req_body = None + if request.content: + try: + req_body = json.loads(request.content) + except (json.JSONDecodeError, UnicodeDecodeError): + pass + + resp = self._route(method, url, req_body) + resp_body = {} + try: + resp_body = resp.json() if resp.content else {} + except Exception: + pass + + self.http_log.append( + HttpLog( + method=method, + url=url, + request_body=req_body, + status=resp.status_code, + response_body=resp_body, + ) + ) + return resp + + def _route(self, method: str, url: str, body: dict | None) -> httpx.Response: + # --- TPU operations --- + if "tpu.googleapis.com" in url: + return self._handle_tpu(method, url, body) + + # --- Compute operations --- + if "compute.googleapis.com" in url: + return self._handle_compute(method, url, body) + + # --- Logging --- + if "logging.googleapis.com" in url: + return httpx.Response(200, json={"entries": []}) + + return httpx.Response(404, json={"error": {"code": 404, "message": f"Unknown URL: {url}"}}) + + def _handle_tpu(self, method: str, url: str, body: dict | None) -> httpx.Response: + # POST .../nodes?nodeId=NAME — create TPU + if method == "POST" and "/nodes" in url and "nodeId=" in url: + node_id = url.split("nodeId=")[1].split("&")[0] + zone = self._extract_tpu_zone(url) + op_name = f"projects/{self.project}/locations/{zone}/operations/op-tpu-{node_id}" + tpu_data = { + "name": f"projects/{self.project}/locations/{zone}/nodes/{node_id}", + "state": "CREATING", + "acceleratorType": (body or {}).get("acceleratorType", ""), + "runtimeVersion": (body or {}).get("runtimeVersion", ""), + "labels": (body or {}).get("labels", {}), + "metadata": (body or {}).get("metadata", {}), + "networkEndpoints": [], + "serviceAccount": (body or {}).get("serviceAccount"), + "createTime": "2026-01-01T00:00:00Z", + } + if (body or {}).get("networkConfig"): + tpu_data["networkConfig"] = body["networkConfig"] + self.tpus[(node_id, zone)] = tpu_data + # Mark operation as pending, will complete on poll + self.operations[op_name] = {"name": op_name, "done": False, "tpu_key": (node_id, zone)} + return httpx.Response(200, json={"name": op_name, "done": False}) + + # GET .../operations/op-* — poll TPU operation + if method == "GET" and "/operations/" in url: + op_name = url.split("tpu.googleapis.com/v2/")[1].split("?")[0] + op = self.operations.get(op_name) + if op is None: + return httpx.Response(404, json={"error": {"code": 404, "message": "Operation not 
found"}}) + # Complete the operation and make TPU READY with endpoints + tpu_key = op.get("tpu_key") + if tpu_key and tpu_key in self.tpus: + tpu = self.tpus[tpu_key] + tpu["state"] = "READY" + tpu["networkEndpoints"] = [{"ipAddress": f"10.128.0.{i}", "port": 8470} for i in range(4)] + return httpx.Response(200, json={"name": op_name, "done": True}) + + # GET .../nodes/NAME — describe TPU + if method == "GET" and "/nodes/" in url and "/operations/" not in url: + parts = url.split("/nodes/") + if len(parts) == 2: + node_name = parts[1].split("?")[0] + zone = self._extract_tpu_zone(url) + key = (node_name, zone) + if key in self.tpus: + return httpx.Response(200, json=self.tpus[key]) + return httpx.Response(404, json={"error": {"code": 404, "message": "Not found", "status": "NOT_FOUND"}}) + + # GET .../nodes — list TPUs + if (method == "GET" and url.endswith("/nodes")) or ("/nodes?" in url and "nodeId" not in url): + zone = self._extract_tpu_zone(url) + nodes = [t for (_, z), t in self.tpus.items() if z == zone] + return httpx.Response(200, json={"nodes": nodes}) + + # DELETE .../nodes/NAME + if method == "DELETE" and "/nodes/" in url: + parts = url.split("/nodes/") + node_name = parts[1].split("?")[0] + zone = self._extract_tpu_zone(url) + self.tpus.pop((node_name, zone), None) + return httpx.Response(200, json={}) + + return httpx.Response(404, json={"error": {"code": 404, "message": f"Unhandled TPU URL: {url}"}}) + + def _handle_compute(self, method: str, url: str, body: dict | None) -> httpx.Response: + # POST .../instances — create VM + if method == "POST" and url.endswith("/instances"): + zone = self._extract_compute_zone(url) + name = (body or {}).get("name", "unknown") + op_name = f"op-vm-{name}" + vm_data = { + "name": name, + "status": "RUNNING", + "zone": f"projects/{self.project}/zones/{zone}", + "networkInterfaces": [ + {"networkIP": "10.164.0.42", "accessConfigs": [{"natIP": "35.1.2.3", "type": "ONE_TO_ONE_NAT"}]} + ], + "labels": (body or {}).get("labels", {}), + "metadata": (body or {}).get("metadata", {}), + "serviceAccounts": (body or {}).get("serviceAccounts", []), + "labelFingerprint": "abc123", + "creationTimestamp": "2026-01-01T00:00:00Z", + } + self.vms[(name, zone)] = vm_data + self.operations[op_name] = {"name": op_name, "status": "DONE"} + return httpx.Response(200, json={"name": op_name, "status": "RUNNING"}) + + # GET .../operations/OP — poll compute operation + if method == "GET" and "/operations/" in url: + op_name = url.split("/operations/")[1].split("?")[0] + op = self.operations.get(op_name, {"name": op_name, "status": "DONE"}) + op["status"] = "DONE" # Always complete immediately + return httpx.Response(200, json=op) + + # GET .../serialPort + if method == "GET" and "/serialPort" in url: + return httpx.Response(200, json={"contents": "serial output", "next": 100}) + + # GET .../instances/NAME — describe VM (must come after serialPort/setLabels/etc) + if method == "GET" and "/instances/" in url and "/instances?" 
not in url: + zone = self._extract_compute_zone(url) + name = url.split("/instances/")[1].split("?")[0].split("/")[0] + key = (name, zone) + if key in self.vms: + return httpx.Response(200, json=self.vms[key]) + return httpx.Response(404, json={"error": {"code": 404, "message": "Not found", "status": "NOT_FOUND"}}) + + # GET .../instances (list, zone-scoped) + if method == "GET" and "/instances" in url and "/instances/" not in url: + if "/aggregated/" in url: + # aggregatedList + result: dict = {} + for (_name, zone), vm in self.vms.items(): + scope_key = f"zones/{zone}" + if scope_key not in result: + result[scope_key] = {"instances": []} + result[scope_key]["instances"].append(vm) + return httpx.Response(200, json={"items": result}) + else: + zone = self._extract_compute_zone(url) + filter_str = "" + if "filter=" in url: + filter_str = url.split("filter=")[1].split("&")[0] + vms = [v for (_, z), v in self.vms.items() if z == zone] + if filter_str: + vms = [v for v in vms if self._matches_label_filter(v.get("labels", {}), filter_str)] + return httpx.Response(200, json={"items": vms}) + + # DELETE .../instances/NAME + if method == "DELETE" and "/instances/" in url: + zone = self._extract_compute_zone(url) + name = url.split("/instances/")[1].split("?")[0] + self.vms.pop((name, zone), None) + return httpx.Response(200, json={}) + + # POST .../setLabels + if method == "POST" and "/setLabels" in url: + zone = self._extract_compute_zone(url) + name = url.split("/instances/")[1].split("/setLabels")[0] + if (name, zone) in self.vms: + self.vms[(name, zone)]["labels"] = (body or {}).get("labels", {}) + self.vms[(name, zone)]["labelFingerprint"] = "updated" + return httpx.Response(200, json={"name": f"op-labels-{name}", "status": "DONE"}) + + # POST .../setMetadata + if method == "POST" and "/setMetadata" in url: + zone = self._extract_compute_zone(url) + name = url.split("/instances/")[1].split("/setMetadata")[0] + if (name, zone) in self.vms: + self.vms[(name, zone)]["metadata"] = body + return httpx.Response(200, json={"name": f"op-meta-{name}", "status": "DONE"}) + + # POST .../reset + if method == "POST" and "/reset" in url: + return httpx.Response(200, json={"name": "op-reset", "status": "DONE"}) + + return httpx.Response(404, json={"error": {"code": 404, "message": f"Unhandled compute URL: {url}"}}) + + def _extract_tpu_zone(self, url: str) -> str: + if "/locations/" in url: + return url.split("/locations/")[1].split("/")[0] + return "unknown" + + def _extract_compute_zone(self, url: str) -> str: + if "/zones/" in url: + return url.split("/zones/")[1].split("/")[0] + return "unknown" + + def _matches_label_filter(self, labels: dict[str, str], filter_str: str) -> bool: + import urllib.parse + + decoded = urllib.parse.unquote(filter_str) + for part in decoded.split(" AND "): + part = part.strip() + if part.startswith("labels."): + kv = part[len("labels.") :] + if "=" in kv: + k, v = kv.split("=", 1) + if labels.get(k) != v: + return False + return True + + +@pytest.fixture +def backend() -> GcpFakeBackend: + return GcpFakeBackend() + + +@pytest.fixture +def svc(backend: GcpFakeBackend) -> CloudGcpService: + client = httpx.Client(transport=httpx.MockTransport(backend.handle), timeout=10) + s = CloudGcpService(PROJECT, http_client=client) + s._token = "fake-token" + s._expires_at = float("inf") + return s + + +def _dump_http_log(log: list[HttpLog]) -> str: + lines = [] + for i, entry in enumerate(log): + lines.append(f"[{i}] {entry.method} {entry.url}") + if entry.request_body: + 
lines.append(f" REQ: {json.dumps(entry.request_body, indent=2)[:500]}") + lines.append(f" RSP ({entry.status}): {json.dumps(entry.response_body, indent=2)[:500]}") + return "\n".join(lines) + + +# ======================================================================== +# Full lifecycle: VM create → describe → set labels → set metadata → list +# ======================================================================== + + +def test_vm_full_lifecycle(svc: CloudGcpService, backend: GcpFakeBackend): + vm = svc.vm_create( + VmCreateRequest( + name="ctrl-vm", + zone=ZONE_EU, + machine_type="e2-standard-4", + labels={"iris-managed": "true"}, + startup_script="#!/bin/bash\necho hello", + service_account="sa@test.iam.gserviceaccount.com", + ) + ) + + assert vm.name == "ctrl-vm", f"Expected ctrl-vm, got {vm.name}\n{_dump_http_log(backend.http_log)}" + assert vm.internal_ip == "10.164.0.42", f"Expected IP, got {vm.internal_ip!r}\n{_dump_http_log(backend.http_log)}" + assert vm.status == "RUNNING" + + # Set labels (read-modify-write) + svc.vm_update_labels("ctrl-vm", ZONE_EU, {"controller": "true"}) + assert backend.vms[("ctrl-vm", ZONE_EU)]["labels"]["controller"] == "true" + + # Set metadata (read-modify-write) + svc.vm_set_metadata("ctrl-vm", ZONE_EU, {"controller-address": "http://10.164.0.42:10000"}) + + # List VMs with label filter + vms = svc.vm_list(zones=[ZONE_EU], labels={"iris-managed": "true"}) + assert len(vms) >= 1, f"Expected >=1 VM, got {len(vms)}\n{_dump_http_log(backend.http_log)}" + + # List VMs project-wide + all_vms = svc.vm_list(zones=[]) + assert len(all_vms) >= 1 + + # Describe + described = svc.vm_describe("ctrl-vm", ZONE_EU) + assert described is not None + assert described.internal_ip == "10.164.0.42" + + +# ======================================================================== +# Full lifecycle: TPU create → describe → list → bootstrap health polling +# ======================================================================== + + +def test_tpu_full_lifecycle(svc: CloudGcpService, backend: GcpFakeBackend): + tpu = svc.tpu_create( + TpuCreateRequest( + name="test-slice", + zone=ZONE_EU, + accelerator_type="v5litepod-16", + runtime_version="v2-alpha-tpuv5-lite", + capacity_type=config_pb2.CAPACITY_TYPE_PREEMPTIBLE, + labels={"iris-managed": "true", "env": "test"}, + metadata={"startup-script": "#!/bin/bash\necho bootstrap"}, + service_account="worker@test.iam.gserviceaccount.com", + ) + ) + + log_str = _dump_http_log(backend.http_log) + + assert tpu.name == "test-slice", f"Expected test-slice, got {tpu.name}\n{log_str}" + assert tpu.state == "READY", f"Expected READY, got {tpu.state}\n{log_str}" + assert len(tpu.network_endpoints) == 4, f"Expected 4 endpoints, got {tpu.network_endpoints}\n{log_str}" + + # Verify the create request body was correct + create_req = next(e for e in backend.http_log if e.method == "POST" and "nodeId=test-slice" in e.url) + assert create_req.request_body is not None + assert create_req.request_body["acceleratorType"] == "v5litepod-16" + assert create_req.request_body["runtimeVersion"] == "v2-alpha-tpuv5-lite" + assert create_req.request_body["labels"] == {"iris-managed": "true", "env": "test"} + assert create_req.request_body["metadata"] == {"startup-script": "#!/bin/bash\necho bootstrap"} + assert create_req.request_body["schedulingConfig"] == {"preemptible": True} + assert create_req.request_body["serviceAccount"] == {"email": "worker@test.iam.gserviceaccount.com"} + + # Describe should return the same TPU + described = 
svc.tpu_describe("test-slice", ZONE_EU) + assert described is not None + assert described.state == "READY" + + # List with label filter + results = svc.tpu_list(zones=[ZONE_EU], labels={"iris-managed": "true"}) + assert len(results) == 1, f"Expected 1 TPU, got {len(results)}\n{_dump_http_log(backend.http_log)}" + assert results[0].name == "test-slice" + + # List with non-matching label + no_match = svc.tpu_list(zones=[ZONE_EU], labels={"iris-managed": "false"}) + assert len(no_match) == 0 + + +def test_tpu_create_across_zones(svc: CloudGcpService, backend: GcpFakeBackend): + """Controller creates TPU slices in multiple zones and lists all of them.""" + svc.tpu_create( + TpuCreateRequest( + name="slice-eu", + zone=ZONE_EU, + accelerator_type="v5litepod-16", + runtime_version="v2-alpha-tpuv5-lite", + capacity_type=config_pb2.CAPACITY_TYPE_PREEMPTIBLE, + labels={"managed": "true"}, + ) + ) + svc.tpu_create( + TpuCreateRequest( + name="slice-us", + zone=ZONE_US, + accelerator_type="v5litepod-16", + runtime_version="v2-alpha-tpuv5-lite", + capacity_type=config_pb2.CAPACITY_TYPE_PREEMPTIBLE, + labels={"managed": "true"}, + ) + ) + + # List both zones + all_tpus = svc.tpu_list(zones=[ZONE_EU, ZONE_US], labels={"managed": "true"}) + log_str = _dump_http_log(backend.http_log) + assert len(all_tpus) == 2, f"Expected 2 TPUs, got {len(all_tpus)}\n{log_str}" + + names = {t.name for t in all_tpus} + assert names == {"slice-eu", "slice-us"}, f"Got names {names}\n{log_str}" + + # Verify zones were correctly extracted + for tpu in all_tpus: + if tpu.name == "slice-eu": + assert tpu.zone == ZONE_EU + else: + assert tpu.zone == ZONE_US + + +def test_tpu_metadata_and_network_config(svc: CloudGcpService, backend: GcpFakeBackend): + """Verify metadata (startup-script) and network config are passed correctly.""" + large_script = "#!/bin/bash\n" + "echo line\n" * 200 # >256 chars + + svc.tpu_create( + TpuCreateRequest( + name="net-tpu", + zone=ZONE_EU, + accelerator_type="v5litepod-16", + runtime_version="v2-alpha-tpuv5-lite", + capacity_type=config_pb2.CAPACITY_TYPE_ON_DEMAND, + labels={}, + metadata={"startup-script": large_script, "other-key": "other-value"}, + network="projects/test/global/networks/default", + subnetwork="projects/test/regions/europe-west4/subnetworks/default", + ) + ) + + create_req = next(e for e in backend.http_log if e.method == "POST" and "nodeId=net-tpu" in e.url) + body = create_req.request_body + + # Metadata must be a flat dict (not items array like Compute) + assert body["metadata"]["startup-script"] == large_script + assert body["metadata"]["other-key"] == "other-value" + + # Network config + assert body["networkConfig"]["network"] == "projects/test/global/networks/default" + assert body["networkConfig"]["subnetwork"] == "projects/test/regions/europe-west4/subnetworks/default" + + # No schedulingConfig for on-demand + assert "schedulingConfig" not in body + + +def test_vm_create_startup_script_in_metadata(svc: CloudGcpService, backend: GcpFakeBackend): + """Verify VM startup script is passed as metadata items (not flat dict like TPU).""" + svc.vm_create( + VmCreateRequest( + name="boot-vm", + zone=ZONE_EU, + machine_type="e2-standard-4", + startup_script="#!/bin/bash\necho hello", + ) + ) + + create_req = next(e for e in backend.http_log if e.method == "POST" and e.url.endswith("/instances")) + body = create_req.request_body + + # VM metadata uses {"items": [...]} format + metadata = body.get("metadata", {}) + assert "items" in metadata, f"VM metadata should use items format, got: 
{metadata}" + items = {item["key"]: item["value"] for item in metadata["items"]} + assert items["startup-script"] == "#!/bin/bash\necho hello" + + +def test_logging_read(svc: CloudGcpService, backend: GcpFakeBackend): + entries = svc.logging_read('resource.type="gce_instance"', limit=100) + assert entries == [] # Backend returns empty + + # Verify the request was correct + log_req = next(e for e in backend.http_log if "logging.googleapis.com" in e.url) + assert log_req.method == "POST" + + +def test_serial_port_output(svc: CloudGcpService, backend: GcpFakeBackend): + # Create a VM first + svc.vm_create(VmCreateRequest(name="serial-vm", zone=ZONE_EU, machine_type="e2-standard-4")) + + output = svc.vm_get_serial_port_output("serial-vm", ZONE_EU, start=0) + assert output == "serial output" diff --git a/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py b/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py index bed8d496a1..383d3df8bd 100644 --- a/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py +++ b/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py @@ -1,10 +1,10 @@ # Copyright The Marin Authors # SPDX-License-Identifier: Apache-2.0 -"""Tests for GCPApi — HTTP client for GCP REST APIs. +"""Tests for CloudGcpService REST API integration. Uses httpx.MockTransport to verify URL construction, error mapping, -pagination, and auth header injection without hitting real GCP. +pagination, operation waiting, and auth header injection without hitting real GCP. """ from __future__ import annotations @@ -16,17 +16,16 @@ import httpx import pytest -from iris.cluster.providers.gcp.api import ( - COMPUTE_BASE, - LOGGING_BASE, - TPU_BASE, - GCPApi, +from iris.cluster.providers.gcp.service import ( + CloudGcpService, + TpuCreateRequest, + VmCreateRequest, ) from iris.cluster.providers.types import ( InfraError, QuotaExhaustedError, - ResourceNotFoundError, ) +from iris.rpc import config_pb2 PROJECT = "test-project" ZONE = "us-central1-a" @@ -43,17 +42,16 @@ def _mock_credentials(): "refresh": lambda self, req: None, }, )() - return patch("iris.cluster.providers.gcp.api.google.auth.default", return_value=(cred, PROJECT)) + return patch("iris.cluster.providers.gcp.service.google.auth.default", return_value=(cred, PROJECT)) -def _make_api(handler: Callable[[httpx.Request], httpx.Response]) -> GCPApi: - """Create a GCPApi with a mock HTTP transport and fake credentials.""" - api = GCPApi(PROJECT) - api._client = httpx.Client(transport=httpx.MockTransport(handler), timeout=10) - # Inject fake token so _refresh_token isn't called - api._token = "fake-token" - api._expires_at = float("inf") - return api +def _make_svc(handler: Callable[[httpx.Request], httpx.Response]) -> CloudGcpService: + """Create a CloudGcpService with a mock HTTP transport and fake credentials.""" + client = httpx.Client(transport=httpx.MockTransport(handler), timeout=10) + svc = CloudGcpService(PROJECT, http_client=client) + svc._token = "fake-token" + svc._expires_at = float("inf") + return svc def _json_response(body: dict, status: int = 200) -> httpx.Response: @@ -69,10 +67,9 @@ def test_404_raises_resource_not_found(): def handler(request: httpx.Request) -> httpx.Response: return httpx.Response(404, json={"error": {"code": 404, "message": "Not found", "status": "NOT_FOUND"}}) - api = _make_api(handler) - with pytest.raises(ResourceNotFoundError, match="Not found"): - api.tpu_get("no-such-tpu", ZONE) - api.close() + svc = _make_svc(handler) + assert svc.tpu_describe("no-such-tpu", ZONE) is None + svc.shutdown() def 
test_429_raises_quota_exhausted(): @@ -81,30 +78,20 @@ def handler(request: httpx.Request) -> httpx.Response: 429, json={"error": {"code": 429, "message": "Quota exceeded", "status": "RESOURCE_EXHAUSTED"}} ) - api = _make_api(handler) + svc = _make_svc(handler) with pytest.raises(QuotaExhaustedError, match="Quota exceeded"): - api.tpu_get("some-tpu", ZONE) - api.close() + svc.vm_reset("some-vm", ZONE) + svc.shutdown() def test_500_raises_infra_error(): def handler(request: httpx.Request) -> httpx.Response: return httpx.Response(500, json={"error": {"code": 500, "message": "Internal error"}}) - api = _make_api(handler) + svc = _make_svc(handler) with pytest.raises(InfraError, match="Internal error"): - api.tpu_get("some-tpu", ZONE) - api.close() - - -def test_non_json_error_body(): - def handler(request: httpx.Request) -> httpx.Response: - return httpx.Response(502, text="Bad Gateway") - - api = _make_api(handler) - with pytest.raises(InfraError, match="Bad Gateway"): - api.tpu_get("some-tpu", ZONE) - api.close() + svc.vm_reset("some-vm", ZONE) + svc.shutdown() # ======================================================================== @@ -119,230 +106,194 @@ def handler(request: httpx.Request) -> httpx.Response: requests_seen.append(request) return _json_response({"name": "test", "state": "READY"}) - api = _make_api(handler) - api.tpu_get("my-tpu", ZONE) - api.close() + svc = _make_svc(handler) + svc.tpu_describe("my-tpu", ZONE) + svc.shutdown() assert len(requests_seen) == 1 assert requests_seen[0].headers["authorization"] == "Bearer fake-token" def test_token_refresh_on_expiry(): - """When token is expired, _refresh_token is called.""" with _mock_credentials(): def handler(request: httpx.Request) -> httpx.Response: return _json_response({"name": "test", "state": "READY"}) - api = _make_api(handler) - api._token = None # Force refresh - api._expires_at = 0.0 - api.tpu_get("my-tpu", ZONE) - assert api._token == "fake-token" - api.close() + svc = _make_svc(handler) + svc._token = None + svc._expires_at = 0.0 + svc.tpu_describe("my-tpu", ZONE) + assert svc._token == "fake-token" + svc.shutdown() # ======================================================================== -# TPU operations — URL construction +# VM create waits for operation before describing # ======================================================================== -def test_tpu_create_url_and_params(): - requests_seen: list[httpx.Request] = [] - - def handler(request: httpx.Request) -> httpx.Response: - requests_seen.append(request) - return _json_response({"name": "operations/op-123"}) - - api = _make_api(handler) - api.tpu_create("my-tpu", ZONE, {"acceleratorType": "v4-8"}) - api.close() - - req = requests_seen[0] - assert req.method == "POST" - assert f"{TPU_BASE}/projects/{PROJECT}/locations/{ZONE}/nodes" in str(req.url) - assert "nodeId=my-tpu" in str(req.url) - - -def test_tpu_get_url(): - def handler(request: httpx.Request) -> httpx.Response: - return _json_response({"name": f"projects/{PROJECT}/locations/{ZONE}/nodes/my-tpu", "state": "READY"}) - - api = _make_api(handler) - result = api.tpu_get("my-tpu", ZONE) - api.close() - - assert result["state"] == "READY" - +def test_vm_create_waits_for_operation(): + """vm_create must wait for the insert operation before describing the VM.""" + operation_done = False -def test_tpu_delete_ignores_404(): def handler(request: httpx.Request) -> httpx.Response: - return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) - - api = _make_api(handler) - 
api.tpu_delete("gone-tpu", ZONE) # should not raise - api.close() + nonlocal operation_done + url = str(request.url) + if request.method == "POST" and url.endswith(f"/zones/{ZONE}/instances"): + return _json_response( + {"name": "op-vm-1", "status": "RUNNING", "zone": f"zones/{ZONE}", "kind": "compute#operation"} + ) -def test_tpu_list_with_pagination(): - call_count = 0 + if "/operations/op-vm-1" in url and request.method == "GET": + operation_done = True + return _json_response({"name": "op-vm-1", "status": "DONE"}) - def handler(request: httpx.Request) -> httpx.Response: - nonlocal call_count - call_count += 1 - if call_count == 1: + if request.method == "GET" and url.endswith(f"/zones/{ZONE}/instances/test-vm"): + if not operation_done: + return httpx.Response(404, json={"error": {"code": 404, "message": "Not found", "status": "NOT_FOUND"}}) return _json_response( { - "nodes": [{"name": "tpu-1", "state": "READY"}], - "nextPageToken": "page2", + "name": "test-vm", + "status": "RUNNING", + "zone": f"projects/{PROJECT}/zones/{ZONE}", + "networkInterfaces": [{"networkIP": "10.0.0.1", "accessConfigs": [{"natIP": "34.1.2.3"}]}], + "metadata": {}, + "serviceAccounts": [{"email": "sa@test.iam.gserviceaccount.com"}], + "creationTimestamp": "2026-01-01T00:00:00Z", } ) - return _json_response( - { - "nodes": [{"name": "tpu-2", "state": "READY"}], - } - ) - - api = _make_api(handler) - results = api.tpu_list(ZONE) - api.close() - assert len(results) == 2 - assert results[0]["name"] == "tpu-1" - assert results[1]["name"] == "tpu-2" - assert call_count == 2 + return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) + svc = _make_svc(handler) + info = svc.vm_create(VmCreateRequest(name="test-vm", zone=ZONE, machine_type="n1-standard-4", labels={})) + svc.shutdown() -# ======================================================================== -# Queued resource operations -# ======================================================================== + assert info.name == "test-vm" + assert info.internal_ip == "10.0.0.1" + assert operation_done, "vm_create must poll the operation before describing" -def test_queued_resource_create_url(): - requests_seen: list[httpx.Request] = [] +def test_tpu_create_waits_for_operation(): + """tpu_create must wait for the LRO before describing the TPU.""" + operation_done = False def handler(request: httpx.Request) -> httpx.Response: - requests_seen.append(request) - return _json_response({"name": "operations/op-456"}) + nonlocal operation_done + url = str(request.url) - api = _make_api(handler) - api.queued_resource_create("my-qr", ZONE, {"tpu": {"nodeSpec": []}}) - api.close() + if request.method == "POST" and "/nodes" in url and "nodeId=test-tpu" in url: + return _json_response({"name": f"projects/{PROJECT}/locations/{ZONE}/operations/op-tpu-1", "done": False}) - req = requests_seen[0] - assert req.method == "POST" - assert "/queuedResources" in str(req.url) - assert "queuedResourceId=my-qr" in str(req.url) + if "/operations/op-tpu-1" in url and request.method == "GET": + operation_done = True + return _json_response({"name": f"projects/{PROJECT}/locations/{ZONE}/operations/op-tpu-1", "done": True}) + if request.method == "GET" and url.endswith("/nodes/test-tpu"): + if not operation_done: + return httpx.Response(404, json={"error": {"code": 404, "message": "Not found", "status": "NOT_FOUND"}}) + return _json_response( + { + "name": f"projects/{PROJECT}/locations/{ZONE}/nodes/test-tpu", + "state": "READY", + "acceleratorType": "v4-8", + 
"networkEndpoints": [{"ipAddress": "10.0.0.2"}], + "labels": {}, + "metadata": {}, + "createTime": "2026-01-01T00:00:00Z", + } + ) -def test_queued_resource_delete_ignores_404(): - def handler(request: httpx.Request) -> httpx.Response: return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) - api = _make_api(handler) - api.queued_resource_delete("gone-qr", ZONE) # should not raise - api.close() - - -def test_queued_resource_delete_passes_force(): - requests_seen: list[httpx.Request] = [] - - def handler(request: httpx.Request) -> httpx.Response: - requests_seen.append(request) - return _json_response({"name": "operations/op-789"}) - - api = _make_api(handler) - api.queued_resource_delete("my-qr", ZONE) - api.close() + svc = _make_svc(handler) + info = svc.tpu_create( + TpuCreateRequest( + name="test-tpu", + zone=ZONE, + accelerator_type="v4-8", + runtime_version="tpu-ubuntu2204-base", + capacity_type=config_pb2.CAPACITY_TYPE_ON_DEMAND, + labels={}, + ) + ) + svc.shutdown() - assert "force=true" in str(requests_seen[0].url) + assert info.name == "test-tpu" + assert info.state == "READY" + assert operation_done, "tpu_create must poll the operation before describing" # ======================================================================== -# Compute operations +# TPU list — only queries requested zones # ======================================================================== -def test_instance_insert_url(): - requests_seen: list[httpx.Request] = [] +def test_tpu_list_queries_only_requested_zones(): + zones_queried: list[str] = [] def handler(request: httpx.Request) -> httpx.Response: - requests_seen.append(request) - return _json_response({"name": "operations/op-vm"}) + url = str(request.url) + if "/locations/" in url and "/nodes" in url: + zone = url.split("/locations/")[1].split("/nodes")[0] + zones_queried.append(zone) + return _json_response({"nodes": []}) - api = _make_api(handler) - api.instance_insert(ZONE, {"name": "my-vm", "machineType": "n1-standard-4"}) - api.close() + svc = _make_svc(handler) + svc.tpu_list(zones=["europe-west4-b", "us-west4-a"]) + svc.shutdown() - req = requests_seen[0] - assert req.method == "POST" - assert f"{COMPUTE_BASE}/projects/{PROJECT}/zones/{ZONE}/instances" in str(req.url) + assert sorted(zones_queried) == ["europe-west4-b", "us-west4-a"] -def test_instance_delete_ignores_404(): +def test_tpu_list_label_filtering(): def handler(request: httpx.Request) -> httpx.Response: - return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) - - api = _make_api(handler) - api.instance_delete("gone-vm", ZONE) # should not raise - api.close() - + return _json_response( + { + "nodes": [ + { + "name": f"projects/{PROJECT}/locations/{ZONE}/nodes/tpu-match", + "state": "READY", + "acceleratorType": "v4-8", + "labels": {"env": "test", "managed": "true"}, + }, + { + "name": f"projects/{PROJECT}/locations/{ZONE}/nodes/tpu-nomatch", + "state": "READY", + "acceleratorType": "v4-8", + "labels": {"env": "prod"}, + }, + ] + } + ) -def test_instance_reset_url(): - requests_seen: list[httpx.Request] = [] + svc = _make_svc(handler) + results = svc.tpu_list(zones=[ZONE], labels={"env": "test"}) + svc.shutdown() - def handler(request: httpx.Request) -> httpx.Response: - requests_seen.append(request) - return _json_response({"name": "operations/reset"}) + assert len(results) == 1 + assert results[0].name == "tpu-match" - api = _make_api(handler) - api.instance_reset("my-vm", ZONE) - api.close() - assert "/my-vm/reset" in 
str(requests_seen[0].url) +# ======================================================================== +# VM list +# ======================================================================== -def test_instance_set_labels_url_and_body(): +def test_vm_list_project_wide_uses_aggregated_list(): requests_seen: list[httpx.Request] = [] def handler(request: httpx.Request) -> httpx.Response: requests_seen.append(request) - return _json_response({"name": "operations/labels"}) - - api = _make_api(handler) - api.instance_set_labels("my-vm", ZONE, {"env": "test"}, "abc123") - api.close() - - req = requests_seen[0] - assert "/my-vm/setLabels" in str(req.url) - body = json.loads(req.content) - assert body["labels"] == {"env": "test"} - assert body["labelFingerprint"] == "abc123" - - -def test_instance_get_serial_port_output(): - def handler(request: httpx.Request) -> httpx.Response: - return _json_response({"contents": "serial output here", "next": 42}) - - api = _make_api(handler) - result = api.instance_get_serial_port_output("my-vm", ZONE, start=10) - api.close() - - assert result["contents"] == "serial output here" - - -def test_instance_list_project_wide(): - """Project-wide list uses aggregatedList and flattens across zones.""" - - def handler(request: httpx.Request) -> httpx.Response: return _json_response( { "items": { "zones/us-central1-a": { - "instances": [{"name": "vm-1", "status": "RUNNING"}], - }, - "zones/us-east1-b": { - "instances": [{"name": "vm-2", "status": "RUNNING"}], + "instances": [ + {"name": "vm-1", "status": "RUNNING", "zone": f"projects/{PROJECT}/zones/us-central1-a"} + ], }, "zones/us-west1-a": { "warning": {"code": "NO_RESULTS_ON_PAGE"}, @@ -351,203 +302,129 @@ def handler(request: httpx.Request) -> httpx.Response: } ) - api = _make_api(handler) - results = api.instance_list(zone=None) - api.close() - - names = {r["name"] for r in results} - assert names == {"vm-1", "vm-2"} - - -def test_instance_list_with_zone(): - def handler(request: httpx.Request) -> httpx.Response: - return _json_response( - { - "items": [{"name": "vm-1", "status": "RUNNING"}], - } - ) - - api = _make_api(handler) - results = api.instance_list(zone=ZONE) - api.close() + svc = _make_svc(handler) + results = svc.vm_list(zones=[]) + svc.shutdown() assert len(results) == 1 - assert results[0]["name"] == "vm-1" + assert results[0].name == "vm-1" + assert "/aggregated/instances" in str(requests_seen[0].url) -def test_instance_list_with_filter(): +def test_vm_list_with_labels_passes_filter(): requests_seen: list[httpx.Request] = [] def handler(request: httpx.Request) -> httpx.Response: requests_seen.append(request) return _json_response({"items": []}) - api = _make_api(handler) - api.instance_list(zone=ZONE, filter_str="labels.env=test") - api.close() + svc = _make_svc(handler) + svc.vm_list(zones=[ZONE], labels={"env": "test"}) + svc.shutdown() assert "filter=labels.env%3Dtest" in str(requests_seen[0].url) # ======================================================================== -# Cloud Logging +# Pagination # ======================================================================== -def test_logging_list_entries(): - requests_seen: list[httpx.Request] = [] +def test_tpu_list_with_pagination(): + call_count = 0 def handler(request: httpx.Request) -> httpx.Response: - requests_seen.append(request) - return _json_response( - { - "entries": [ - {"textPayload": "line 1"}, - {"textPayload": "line 2"}, - ] - } - ) + nonlocal call_count + call_count += 1 + if call_count == 1: + return _json_response( + { + "nodes": 
[{"name": "tpu-1", "state": "READY"}], + "nextPageToken": "page2", + } + ) + return _json_response({"nodes": [{"name": "tpu-2", "state": "READY"}]}) - api = _make_api(handler) - entries = api.logging_list_entries("some filter", limit=50) - api.close() + svc = _make_svc(handler) + results = svc.tpu_list(zones=[ZONE]) + svc.shutdown() - assert len(entries) == 2 - req = requests_seen[0] - assert req.method == "POST" - assert f"{LOGGING_BASE}/entries:list" in str(req.url) - body = json.loads(req.content) - assert body["filter"] == "some filter" - assert body["pageSize"] == 50 + assert len(results) == 2 + assert call_count == 2 # ======================================================================== -# Operation waiting — vm_create / tpu_create must wait for async operations +# Cloud Logging # ======================================================================== -def test_vm_create_waits_for_operation(): - """vm_create must wait for the insert operation before describing the VM. - - This is the core regression from replacing `gcloud compute instances create` - (which blocks until RUNNING) with the REST API (which returns immediately). - The mock simulates real GCP behavior: instance_get returns 404 until the - zone operation reaches DONE. - """ - from iris.cluster.providers.gcp.service import CloudGcpService, VmCreateRequest - - operation_done = False +def test_logging_read(): + requests_seen: list[httpx.Request] = [] def handler(request: httpx.Request) -> httpx.Response: - nonlocal operation_done - url = str(request.url) + requests_seen.append(request) + return _json_response({"entries": [{"textPayload": "line 1"}, {"textPayload": "line 2"}]}) - # POST instances — create (returns operation) - if request.method == "POST" and url.endswith(f"/zones/{ZONE}/instances"): - return _json_response( - {"name": "op-vm-1", "status": "RUNNING", "zone": f"zones/{ZONE}", "kind": "compute#operation"} - ) + svc = _make_svc(handler) + entries = svc.logging_read("some filter", limit=50) + svc.shutdown() - # GET operation poll — marks operation as done - if "/operations/op-vm-1" in url and request.method == "GET": - operation_done = True - return _json_response({"name": "op-vm-1", "status": "DONE"}) + assert entries == ["line 1", "line 2"] + body = json.loads(requests_seen[0].content) + assert body["filter"] == "some filter" + assert body["pageSize"] == 50 - # GET instance — only succeeds after operation completed (real GCP behavior) - if request.method == "GET" and url.endswith(f"/zones/{ZONE}/instances/test-vm"): - if not operation_done: - return httpx.Response(404, json={"error": {"code": 404, "message": "Not found", "status": "NOT_FOUND"}}) - return _json_response( - { - "name": "test-vm", - "status": "RUNNING", - "zone": f"projects/{PROJECT}/zones/{ZONE}", - "networkInterfaces": [{"networkIP": "10.0.0.1", "accessConfigs": [{"natIP": "34.1.2.3"}]}], - "metadata": {}, - "serviceAccounts": [{"email": "sa@test.iam.gserviceaccount.com"}], - "creationTimestamp": "2026-01-01T00:00:00Z", - } - ) - return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) +def test_logging_read_empty(): + def handler(request: httpx.Request) -> httpx.Response: + return _json_response({}) - api = _make_api(handler) - svc = CloudGcpService(PROJECT, api=api) - info = svc.vm_create(VmCreateRequest(name="test-vm", zone=ZONE, machine_type="n1-standard-4", labels={})) - api.close() + svc = _make_svc(handler) + assert svc.logging_read("no match") == [] + svc.shutdown() - assert info.name == "test-vm" - assert 
info.internal_ip == "10.0.0.1" - assert operation_done, "vm_create must poll the operation before describing" +# ======================================================================== +# Delete operations ignore 404 +# ======================================================================== -def test_tpu_create_waits_for_operation(): - """tpu_create must wait for the LRO before describing the TPU. - Same race condition as vm_create: the REST API create returns an LRO, - and the TPU node may not be visible via tpu_get until the operation completes. - """ - from iris.cluster.providers.gcp.service import CloudGcpService, TpuCreateRequest - from iris.rpc import config_pb2 +def test_tpu_delete_ignores_404(): + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) - operation_done = False + svc = _make_svc(handler) + svc.tpu_delete("gone-tpu", ZONE) + svc.shutdown() - def handler(request: httpx.Request) -> httpx.Response: - nonlocal operation_done - url = str(request.url) - # POST nodes — create (returns LRO) - if request.method == "POST" and "/nodes" in url and "nodeId=test-tpu" in url: - return _json_response({"name": f"projects/{PROJECT}/locations/{ZONE}/operations/op-tpu-1", "done": False}) +def test_vm_delete_ignores_404(): + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) - # GET operation poll (TPU LRO) — marks operation as done - if "/operations/op-tpu-1" in url and request.method == "GET": - operation_done = True - return _json_response({"name": f"projects/{PROJECT}/locations/{ZONE}/operations/op-tpu-1", "done": True}) + svc = _make_svc(handler) + svc.vm_delete("gone-vm", ZONE) + svc.shutdown() - # GET node — only succeeds after operation completed - if request.method == "GET" and url.endswith("/nodes/test-tpu"): - if not operation_done: - return httpx.Response(404, json={"error": {"code": 404, "message": "Not found", "status": "NOT_FOUND"}}) - return _json_response( - { - "name": f"projects/{PROJECT}/locations/{ZONE}/nodes/test-tpu", - "state": "READY", - "acceleratorType": "v4-8", - "networkEndpoints": [{"ipAddress": "10.0.0.2"}], - "labels": {}, - "metadata": {}, - "createTime": "2026-01-01T00:00:00Z", - } - ) +def test_queued_resource_delete_ignores_404(): + def handler(request: httpx.Request) -> httpx.Response: return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) - api = _make_api(handler) - svc = CloudGcpService(PROJECT, api=api) - info = svc.tpu_create( - TpuCreateRequest( - name="test-tpu", - zone=ZONE, - accelerator_type="v4-8", - runtime_version="tpu-ubuntu2204-base", - capacity_type=config_pb2.CAPACITY_TYPE_ON_DEMAND, - labels={}, - ) - ) - api.close() + svc = _make_svc(handler) + svc.queued_resource_delete("gone-qr", ZONE) + svc.shutdown() - assert info.name == "test-tpu" - assert info.state == "READY" - assert operation_done, "tpu_create must poll the operation before describing" +def test_queued_resource_delete_passes_force(): + requests_seen: list[httpx.Request] = [] -def test_logging_list_entries_empty(): def handler(request: httpx.Request) -> httpx.Response: - return _json_response({}) + requests_seen.append(request) + return _json_response({"name": "operations/op-789"}) - api = _make_api(handler) - entries = api.logging_list_entries("no match") - api.close() + svc = _make_svc(handler) + svc.queued_resource_delete("my-qr", ZONE) + svc.shutdown() - assert entries == 
[] + assert "force=true" in str(requests_seen[0].url) From 9fe55f23cd3535c2684ca361db811cc4d4183fdd Mon Sep 17 00:00:00 2001 From: Russell Power Date: Fri, 3 Apr 2026 13:38:09 -0700 Subject: [PATCH 5/8] [iris] Use locations/- wildcard for project-wide TPU listing Matches gcloud's --zone=- behavior: a single API call to list TPUs across all zones instead of iterating each zone individually. --- lib/iris/src/iris/cluster/providers/gcp/service.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/lib/iris/src/iris/cluster/providers/gcp/service.py b/lib/iris/src/iris/cluster/providers/gcp/service.py index 4f42acc26a..aba7208193 100644 --- a/lib/iris/src/iris/cluster/providers/gcp/service.py +++ b/lib/iris/src/iris/cluster/providers/gcp/service.py @@ -574,9 +574,8 @@ def tpu_describe(self, name: str, zone: str) -> TpuInfo | None: def tpu_list(self, zones: list[str], labels: dict[str, str] | None = None) -> list[TpuInfo]: results: list[TpuInfo] = [] - # TPU v2 API requires a specific zone per request. - # When no zones specified, only scan the caller's zones (not all known zones). - zone_list = zones if zones else list(self._valid_zones) + # Use locations/- for project-wide listing (matches gcloud --zone=-). + zone_list = zones if zones else ["-"] for zone in zone_list: try: @@ -587,7 +586,6 @@ def tpu_list(self, zones: list[str], labels: dict[str, str] | None = None) -> li for tpu_data in items: if labels and not _labels_match(tpu_data.get("labels", {}), labels): continue - # Extract zone from resource name if present tpu_zone = zone raw_name = tpu_data.get("name", "") if "/" in raw_name: @@ -663,7 +661,7 @@ def queued_resource_delete(self, name: str, zone: str) -> None: self._classify_response(resp) def queued_resource_list(self, zones: list[str], labels: dict[str, str] | None = None) -> list[QueuedResourceInfo]: - zone_list = zones if zones else list(self._valid_zones) + zone_list = zones if zones else ["-"] results: list[QueuedResourceInfo] = [] for zone in zone_list: try: From 01bdfd0c3663cd90cd17752c53a64fedbd2c00af Mon Sep 17 00:00:00 2001 From: Russell Power Date: Fri, 3 Apr 2026 14:26:26 -0700 Subject: [PATCH 6/8] [iris] Enable external IPs on TPU VMs created via REST API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The gcloud CLI defaults to enableExternalIps=true when creating TPU VMs. The REST API does not — TPUs were created without external IPs, so the startup-script couldn't pull Docker images and workers never bootstrapped. 
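For reference, a sketch of the nodes.create body the service now sends
(accelerator/runtime values are placeholders; networkConfig is the
relevant part, and enableExternalIps is the field gcloud was setting
implicitly):

    body = {
        "acceleratorType": "v5litepod-16",
        "runtimeVersion": "v2-alpha-tpuv5-lite",
        "networkConfig": {
            # Always present now, so TPU VMs get external IPs by default.
            "enableExternalIps": True,
            # network/subnetwork are still added only when configured.
        },
    }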
--- .../src/iris/cluster/providers/gcp/service.py | 26 +++++++++---------- .../gcp/test_cloud_service_integration.py | 1 + 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/lib/iris/src/iris/cluster/providers/gcp/service.py b/lib/iris/src/iris/cluster/providers/gcp/service.py index aba7208193..479a169726 100644 --- a/lib/iris/src/iris/cluster/providers/gcp/service.py +++ b/lib/iris/src/iris/cluster/providers/gcp/service.py @@ -527,13 +527,12 @@ def tpu_create(self, request: TpuCreateRequest) -> TpuInfo: body["schedulingConfig"] = {"preemptible": True} if request.service_account: body["serviceAccount"] = {"email": request.service_account} - if request.network or request.subnetwork: - network_config: dict = {} - if request.network: - network_config["network"] = request.network - if request.subnetwork: - network_config["subnetwork"] = request.subnetwork - body["networkConfig"] = network_config + network_config: dict = {"enableExternalIps": True} + if request.network: + network_config["network"] = request.network + if request.subnetwork: + network_config["subnetwork"] = request.subnetwork + body["networkConfig"] = network_config logger.info("Creating TPU: %s (type=%s, zone=%s)", request.name, request.accelerator_type, request.zone) @@ -614,13 +613,12 @@ def queued_resource_create(self, request: TpuCreateRequest) -> None: } if request.service_account: node_spec["node"]["serviceAccount"] = {"email": request.service_account} - if request.network or request.subnetwork: - network_config: dict = {} - if request.network: - network_config["network"] = request.network - if request.subnetwork: - network_config["subnetwork"] = request.subnetwork - node_spec["node"]["networkConfig"] = network_config + qr_network_config: dict = {"enableExternalIps": True} + if request.network: + qr_network_config["network"] = request.network + if request.subnetwork: + qr_network_config["subnetwork"] = request.subnetwork + node_spec["node"]["networkConfig"] = qr_network_config body = { "tpu": {"nodeSpec": [node_spec]}, diff --git a/lib/iris/tests/cluster/providers/gcp/test_cloud_service_integration.py b/lib/iris/tests/cluster/providers/gcp/test_cloud_service_integration.py index db2771cf2e..af1bfffdc7 100644 --- a/lib/iris/tests/cluster/providers/gcp/test_cloud_service_integration.py +++ b/lib/iris/tests/cluster/providers/gcp/test_cloud_service_integration.py @@ -379,6 +379,7 @@ def test_tpu_full_lifecycle(svc: CloudGcpService, backend: GcpFakeBackend): assert create_req.request_body["metadata"] == {"startup-script": "#!/bin/bash\necho bootstrap"} assert create_req.request_body["schedulingConfig"] == {"preemptible": True} assert create_req.request_body["serviceAccount"] == {"email": "worker@test.iam.gserviceaccount.com"} + assert create_req.request_body["networkConfig"]["enableExternalIps"] is True # Describe should return the same TPU described = svc.tpu_describe("test-slice", ZONE_EU) From fd0f2704acbb82dc202480451881f3dab03e412d Mon Sep 17 00:00:00 2001 From: Russell Power Date: Fri, 3 Apr 2026 16:22:10 -0700 Subject: [PATCH 7/8] Delete lib/iris/tests/cluster/providers/gcp/test_gcp_api.py --- .../cluster/providers/gcp/test_gcp_api.py | 430 ------------------ 1 file changed, 430 deletions(-) delete mode 100644 lib/iris/tests/cluster/providers/gcp/test_gcp_api.py diff --git a/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py b/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py deleted file mode 100644 index 383d3df8bd..0000000000 --- a/lib/iris/tests/cluster/providers/gcp/test_gcp_api.py +++ /dev/null @@ 
-1,430 +0,0 @@ -# Copyright The Marin Authors -# SPDX-License-Identifier: Apache-2.0 - -"""Tests for CloudGcpService REST API integration. - -Uses httpx.MockTransport to verify URL construction, error mapping, -pagination, operation waiting, and auth header injection without hitting real GCP. -""" - -from __future__ import annotations - -import json -from collections.abc import Callable -from unittest.mock import patch - -import httpx -import pytest - -from iris.cluster.providers.gcp.service import ( - CloudGcpService, - TpuCreateRequest, - VmCreateRequest, -) -from iris.cluster.providers.types import ( - InfraError, - QuotaExhaustedError, -) -from iris.rpc import config_pb2 - -PROJECT = "test-project" -ZONE = "us-central1-a" - - -def _mock_credentials(): - """Patch google.auth.default to return a fake credential.""" - cred = type( - "FakeCred", - (), - { - "token": "fake-token", - "expiry": None, - "refresh": lambda self, req: None, - }, - )() - return patch("iris.cluster.providers.gcp.service.google.auth.default", return_value=(cred, PROJECT)) - - -def _make_svc(handler: Callable[[httpx.Request], httpx.Response]) -> CloudGcpService: - """Create a CloudGcpService with a mock HTTP transport and fake credentials.""" - client = httpx.Client(transport=httpx.MockTransport(handler), timeout=10) - svc = CloudGcpService(PROJECT, http_client=client) - svc._token = "fake-token" - svc._expires_at = float("inf") - return svc - - -def _json_response(body: dict, status: int = 200) -> httpx.Response: - return httpx.Response(status, json=body) - - -# ======================================================================== -# Error mapping -# ======================================================================== - - -def test_404_raises_resource_not_found(): - def handler(request: httpx.Request) -> httpx.Response: - return httpx.Response(404, json={"error": {"code": 404, "message": "Not found", "status": "NOT_FOUND"}}) - - svc = _make_svc(handler) - assert svc.tpu_describe("no-such-tpu", ZONE) is None - svc.shutdown() - - -def test_429_raises_quota_exhausted(): - def handler(request: httpx.Request) -> httpx.Response: - return httpx.Response( - 429, json={"error": {"code": 429, "message": "Quota exceeded", "status": "RESOURCE_EXHAUSTED"}} - ) - - svc = _make_svc(handler) - with pytest.raises(QuotaExhaustedError, match="Quota exceeded"): - svc.vm_reset("some-vm", ZONE) - svc.shutdown() - - -def test_500_raises_infra_error(): - def handler(request: httpx.Request) -> httpx.Response: - return httpx.Response(500, json={"error": {"code": 500, "message": "Internal error"}}) - - svc = _make_svc(handler) - with pytest.raises(InfraError, match="Internal error"): - svc.vm_reset("some-vm", ZONE) - svc.shutdown() - - -# ======================================================================== -# Auth headers -# ======================================================================== - - -def test_auth_header_injected(): - requests_seen: list[httpx.Request] = [] - - def handler(request: httpx.Request) -> httpx.Response: - requests_seen.append(request) - return _json_response({"name": "test", "state": "READY"}) - - svc = _make_svc(handler) - svc.tpu_describe("my-tpu", ZONE) - svc.shutdown() - - assert len(requests_seen) == 1 - assert requests_seen[0].headers["authorization"] == "Bearer fake-token" - - -def test_token_refresh_on_expiry(): - with _mock_credentials(): - - def handler(request: httpx.Request) -> httpx.Response: - return _json_response({"name": "test", "state": "READY"}) - - svc = _make_svc(handler) - 
svc._token = None - svc._expires_at = 0.0 - svc.tpu_describe("my-tpu", ZONE) - assert svc._token == "fake-token" - svc.shutdown() - - -# ======================================================================== -# VM create waits for operation before describing -# ======================================================================== - - -def test_vm_create_waits_for_operation(): - """vm_create must wait for the insert operation before describing the VM.""" - operation_done = False - - def handler(request: httpx.Request) -> httpx.Response: - nonlocal operation_done - url = str(request.url) - - if request.method == "POST" and url.endswith(f"/zones/{ZONE}/instances"): - return _json_response( - {"name": "op-vm-1", "status": "RUNNING", "zone": f"zones/{ZONE}", "kind": "compute#operation"} - ) - - if "/operations/op-vm-1" in url and request.method == "GET": - operation_done = True - return _json_response({"name": "op-vm-1", "status": "DONE"}) - - if request.method == "GET" and url.endswith(f"/zones/{ZONE}/instances/test-vm"): - if not operation_done: - return httpx.Response(404, json={"error": {"code": 404, "message": "Not found", "status": "NOT_FOUND"}}) - return _json_response( - { - "name": "test-vm", - "status": "RUNNING", - "zone": f"projects/{PROJECT}/zones/{ZONE}", - "networkInterfaces": [{"networkIP": "10.0.0.1", "accessConfigs": [{"natIP": "34.1.2.3"}]}], - "metadata": {}, - "serviceAccounts": [{"email": "sa@test.iam.gserviceaccount.com"}], - "creationTimestamp": "2026-01-01T00:00:00Z", - } - ) - - return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) - - svc = _make_svc(handler) - info = svc.vm_create(VmCreateRequest(name="test-vm", zone=ZONE, machine_type="n1-standard-4", labels={})) - svc.shutdown() - - assert info.name == "test-vm" - assert info.internal_ip == "10.0.0.1" - assert operation_done, "vm_create must poll the operation before describing" - - -def test_tpu_create_waits_for_operation(): - """tpu_create must wait for the LRO before describing the TPU.""" - operation_done = False - - def handler(request: httpx.Request) -> httpx.Response: - nonlocal operation_done - url = str(request.url) - - if request.method == "POST" and "/nodes" in url and "nodeId=test-tpu" in url: - return _json_response({"name": f"projects/{PROJECT}/locations/{ZONE}/operations/op-tpu-1", "done": False}) - - if "/operations/op-tpu-1" in url and request.method == "GET": - operation_done = True - return _json_response({"name": f"projects/{PROJECT}/locations/{ZONE}/operations/op-tpu-1", "done": True}) - - if request.method == "GET" and url.endswith("/nodes/test-tpu"): - if not operation_done: - return httpx.Response(404, json={"error": {"code": 404, "message": "Not found", "status": "NOT_FOUND"}}) - return _json_response( - { - "name": f"projects/{PROJECT}/locations/{ZONE}/nodes/test-tpu", - "state": "READY", - "acceleratorType": "v4-8", - "networkEndpoints": [{"ipAddress": "10.0.0.2"}], - "labels": {}, - "metadata": {}, - "createTime": "2026-01-01T00:00:00Z", - } - ) - - return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) - - svc = _make_svc(handler) - info = svc.tpu_create( - TpuCreateRequest( - name="test-tpu", - zone=ZONE, - accelerator_type="v4-8", - runtime_version="tpu-ubuntu2204-base", - capacity_type=config_pb2.CAPACITY_TYPE_ON_DEMAND, - labels={}, - ) - ) - svc.shutdown() - - assert info.name == "test-tpu" - assert info.state == "READY" - assert operation_done, "tpu_create must poll the operation before describing" - - -# 
======================================================================== -# TPU list — only queries requested zones -# ======================================================================== - - -def test_tpu_list_queries_only_requested_zones(): - zones_queried: list[str] = [] - - def handler(request: httpx.Request) -> httpx.Response: - url = str(request.url) - if "/locations/" in url and "/nodes" in url: - zone = url.split("/locations/")[1].split("/nodes")[0] - zones_queried.append(zone) - return _json_response({"nodes": []}) - - svc = _make_svc(handler) - svc.tpu_list(zones=["europe-west4-b", "us-west4-a"]) - svc.shutdown() - - assert sorted(zones_queried) == ["europe-west4-b", "us-west4-a"] - - -def test_tpu_list_label_filtering(): - def handler(request: httpx.Request) -> httpx.Response: - return _json_response( - { - "nodes": [ - { - "name": f"projects/{PROJECT}/locations/{ZONE}/nodes/tpu-match", - "state": "READY", - "acceleratorType": "v4-8", - "labels": {"env": "test", "managed": "true"}, - }, - { - "name": f"projects/{PROJECT}/locations/{ZONE}/nodes/tpu-nomatch", - "state": "READY", - "acceleratorType": "v4-8", - "labels": {"env": "prod"}, - }, - ] - } - ) - - svc = _make_svc(handler) - results = svc.tpu_list(zones=[ZONE], labels={"env": "test"}) - svc.shutdown() - - assert len(results) == 1 - assert results[0].name == "tpu-match" - - -# ======================================================================== -# VM list -# ======================================================================== - - -def test_vm_list_project_wide_uses_aggregated_list(): - requests_seen: list[httpx.Request] = [] - - def handler(request: httpx.Request) -> httpx.Response: - requests_seen.append(request) - return _json_response( - { - "items": { - "zones/us-central1-a": { - "instances": [ - {"name": "vm-1", "status": "RUNNING", "zone": f"projects/{PROJECT}/zones/us-central1-a"} - ], - }, - "zones/us-west1-a": { - "warning": {"code": "NO_RESULTS_ON_PAGE"}, - }, - } - } - ) - - svc = _make_svc(handler) - results = svc.vm_list(zones=[]) - svc.shutdown() - - assert len(results) == 1 - assert results[0].name == "vm-1" - assert "/aggregated/instances" in str(requests_seen[0].url) - - -def test_vm_list_with_labels_passes_filter(): - requests_seen: list[httpx.Request] = [] - - def handler(request: httpx.Request) -> httpx.Response: - requests_seen.append(request) - return _json_response({"items": []}) - - svc = _make_svc(handler) - svc.vm_list(zones=[ZONE], labels={"env": "test"}) - svc.shutdown() - - assert "filter=labels.env%3Dtest" in str(requests_seen[0].url) - - -# ======================================================================== -# Pagination -# ======================================================================== - - -def test_tpu_list_with_pagination(): - call_count = 0 - - def handler(request: httpx.Request) -> httpx.Response: - nonlocal call_count - call_count += 1 - if call_count == 1: - return _json_response( - { - "nodes": [{"name": "tpu-1", "state": "READY"}], - "nextPageToken": "page2", - } - ) - return _json_response({"nodes": [{"name": "tpu-2", "state": "READY"}]}) - - svc = _make_svc(handler) - results = svc.tpu_list(zones=[ZONE]) - svc.shutdown() - - assert len(results) == 2 - assert call_count == 2 - - -# ======================================================================== -# Cloud Logging -# ======================================================================== - - -def test_logging_read(): - requests_seen: list[httpx.Request] = [] - - def handler(request: httpx.Request) -> 
httpx.Response: - requests_seen.append(request) - return _json_response({"entries": [{"textPayload": "line 1"}, {"textPayload": "line 2"}]}) - - svc = _make_svc(handler) - entries = svc.logging_read("some filter", limit=50) - svc.shutdown() - - assert entries == ["line 1", "line 2"] - body = json.loads(requests_seen[0].content) - assert body["filter"] == "some filter" - assert body["pageSize"] == 50 - - -def test_logging_read_empty(): - def handler(request: httpx.Request) -> httpx.Response: - return _json_response({}) - - svc = _make_svc(handler) - assert svc.logging_read("no match") == [] - svc.shutdown() - - -# ======================================================================== -# Delete operations ignore 404 -# ======================================================================== - - -def test_tpu_delete_ignores_404(): - def handler(request: httpx.Request) -> httpx.Response: - return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) - - svc = _make_svc(handler) - svc.tpu_delete("gone-tpu", ZONE) - svc.shutdown() - - -def test_vm_delete_ignores_404(): - def handler(request: httpx.Request) -> httpx.Response: - return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) - - svc = _make_svc(handler) - svc.vm_delete("gone-vm", ZONE) - svc.shutdown() - - -def test_queued_resource_delete_ignores_404(): - def handler(request: httpx.Request) -> httpx.Response: - return httpx.Response(404, json={"error": {"code": 404, "message": "Not found"}}) - - svc = _make_svc(handler) - svc.queued_resource_delete("gone-qr", ZONE) - svc.shutdown() - - -def test_queued_resource_delete_passes_force(): - requests_seen: list[httpx.Request] = [] - - def handler(request: httpx.Request) -> httpx.Response: - requests_seen.append(request) - return _json_response({"name": "operations/op-789"}) - - svc = _make_svc(handler) - svc.queued_resource_delete("my-qr", ZONE) - svc.shutdown() - - assert "force=true" in str(requests_seen[0].url) From 31c5feac4222279e86876a64302be25cb42b686e Mon Sep 17 00:00:00 2001 From: Russell Power Date: Sat, 4 Apr 2026 11:48:48 -0700 Subject: [PATCH 8/8] [iris] Wait for zonal operations on VM delete, labels, and metadata vm_delete gains a wait parameter (default False) so controller replacement blocks until the old VM is fully gone, preventing a create-after-delete race on the fixed controller name. Worker deletions remain fire-and-forget. vm_update_labels and vm_set_metadata now always poll the returned operation to completion so callers observe consistent state on return. 
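All three call sites go through the zonal-operation poller that
vm_create already uses. Roughly (a sketch only; the real
_wait_zone_operation in service.py also handles poll intervals and
timeouts not shown here):

    def _wait_zone_operation(self, zone: str, op_name: str) -> None:
        # Poll the Compute v1 zonal operation until it reports DONE.
        url = f"{COMPUTE_BASE}/projects/{self._project_id}/zones/{zone}/operations/{op_name}"
        while True:
            resp = self._client.get(url, headers=self._headers())
            self._classify_response(resp)
            if resp.json().get("status") == "DONE":
                return
            time.sleep(2)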
--- .../iris/cluster/controller/vm_lifecycle.py | 6 +++--- .../src/iris/cluster/providers/gcp/fake.py | 2 +- .../src/iris/cluster/providers/gcp/handles.py | 10 +++++----- .../src/iris/cluster/providers/gcp/local.py | 2 +- .../src/iris/cluster/providers/gcp/service.py | 19 +++++++++++++++---- .../iris/cluster/providers/manual/provider.py | 4 ++-- lib/iris/src/iris/cluster/providers/types.py | 4 ++-- .../cluster/controller/test_vm_lifecycle.py | 2 +- lib/iris/tests/cluster/providers/conftest.py | 2 +- .../cluster/test_snapshot_reconciliation.py | 2 +- 10 files changed, 32 insertions(+), 21 deletions(-) diff --git a/lib/iris/src/iris/cluster/controller/vm_lifecycle.py b/lib/iris/src/iris/cluster/controller/vm_lifecycle.py index 9fadac54e3..0ed21c0122 100644 --- a/lib/iris/src/iris/cluster/controller/vm_lifecycle.py +++ b/lib/iris/src/iris/cluster/controller/vm_lifecycle.py @@ -350,7 +350,7 @@ def start_controller( logger.info("Existing controller at %s is healthy", address) return address, existing_vm logger.info("Existing controller is unhealthy, terminating and recreating") - existing_vm.terminate() + existing_vm.terminate(wait=True) # Create new controller VM vm_config = _build_controller_vm_config(config) @@ -359,7 +359,7 @@ def start_controller( # Wait for connection if not vm.wait_for_connection(timeout=Duration.from_seconds(300)): - vm.terminate() + vm.terminate(wait=True) raise RuntimeError(f"Controller VM {vm_config.name} did not become reachable within 300s") # Bootstrap @@ -424,6 +424,6 @@ def stop_controller(platform: WorkerInfraProvider, config: config_pb2.IrisCluste vm = _discover_controller_vm(platform, label_prefix) if vm: logger.info("Stopping controller VM %s", vm.vm_id) - vm.terminate() + vm.terminate(wait=True) else: logger.info("No controller VM found to stop") diff --git a/lib/iris/src/iris/cluster/providers/gcp/fake.py b/lib/iris/src/iris/cluster/providers/gcp/fake.py index 0566e7661a..ae4af1b04d 100644 --- a/lib/iris/src/iris/cluster/providers/gcp/fake.py +++ b/lib/iris/src/iris/cluster/providers/gcp/fake.py @@ -324,7 +324,7 @@ def vm_create(self, request: VmCreateRequest) -> VmInfo: self._vms[(request.name, request.zone)] = info return info - def vm_delete(self, name: str, zone: str) -> None: + def vm_delete(self, name: str, zone: str, *, wait: bool = False) -> None: self._check_injected_failure("vm_delete") self._vms.pop((name, zone), None) diff --git a/lib/iris/src/iris/cluster/providers/gcp/handles.py b/lib/iris/src/iris/cluster/providers/gcp/handles.py index 87bffd4d73..c571853907 100644 --- a/lib/iris/src/iris/cluster/providers/gcp/handles.py +++ b/lib/iris/src/iris/cluster/providers/gcp/handles.py @@ -179,10 +179,10 @@ def reboot(self) -> None: logger.info("Rebooting GCE instance: %s", self._gce_vm_name) self._gcp_service.vm_reset(self._gce_vm_name, self._zone) - def terminate(self) -> None: + def terminate(self, *, wait: bool = False) -> None: assert self._gcp_service is not None logger.info("Deleting GCE instance: %s", self._gce_vm_name) - self._gcp_service.vm_delete(self._gce_vm_name, self._zone) + self._gcp_service.vm_delete(self._gce_vm_name, self._zone, wait=wait) def set_labels(self, labels: dict[str, str]) -> None: assert self._gcp_service is not None @@ -366,7 +366,7 @@ def _describe_queued_resource(self) -> SliceStatus: # QUEUED, PROVISIONING, WAITING_FOR_RESOURCES → still creating return SliceStatus(state=CloudSliceState.CREATING, worker_count=0) - def terminate(self) -> None: + def terminate(self, *, wait: bool = False) -> None: if 
self.is_queued_resource: logger.info("Terminating queued resource (force): %s", self._slice_id) self._gcp_service.queued_resource_delete(self._slice_id, self._zone) @@ -478,6 +478,6 @@ def _describe_cloud(self) -> SliceStatus: ) return SliceStatus(state=state, worker_count=1, workers=[worker]) - def terminate(self) -> None: + def terminate(self, *, wait: bool = False) -> None: logger.info("Terminating VM slice: %s (vm=%s)", self._slice_id, self._vm_name) - self._gcp_service.vm_delete(self._vm_name, self._zone) + self._gcp_service.vm_delete(self._vm_name, self._zone, wait=wait) diff --git a/lib/iris/src/iris/cluster/providers/gcp/local.py b/lib/iris/src/iris/cluster/providers/gcp/local.py index fb87c3fb3b..e1c68fcaae 100644 --- a/lib/iris/src/iris/cluster/providers/gcp/local.py +++ b/lib/iris/src/iris/cluster/providers/gcp/local.py @@ -162,7 +162,7 @@ def describe(self) -> SliceStatus: ] return SliceStatus(state=CloudSliceState.READY, worker_count=len(self._vm_ids), workers=workers) - def terminate(self) -> None: + def terminate(self, *, wait: bool = False) -> None: if self._terminated: return self._terminated = True diff --git a/lib/iris/src/iris/cluster/providers/gcp/service.py b/lib/iris/src/iris/cluster/providers/gcp/service.py index 479a169726..19770814ba 100644 --- a/lib/iris/src/iris/cluster/providers/gcp/service.py +++ b/lib/iris/src/iris/cluster/providers/gcp/service.py @@ -232,7 +232,7 @@ def queued_resource_list( ) -> list[QueuedResourceInfo]: ... def vm_create(self, request: VmCreateRequest) -> VmInfo: ... - def vm_delete(self, name: str, zone: str) -> None: ... + def vm_delete(self, name: str, zone: str, *, wait: bool = False) -> None: ... def vm_describe(self, name: str, zone: str) -> VmInfo | None: ... def vm_list(self, zones: list[str], labels: dict[str, str] | None = None) -> list[VmInfo]: ... def vm_reset(self, name: str, zone: str) -> None: ... 
@@ -736,12 +736,17 @@ def vm_create(self, request: VmCreateRequest) -> VmInfo: raise InfraError(f"VM {request.name} created but could not be described") return info - def vm_delete(self, name: str, zone: str) -> None: + def vm_delete(self, name: str, zone: str, *, wait: bool = False) -> None: logger.info("Deleting VM: %s", name) url = self._instance_url(zone, name) resp = self._client.delete(url, headers=self._headers()) - if resp.status_code != 404: - self._classify_response(resp) + if resp.status_code == 404: + return + self._classify_response(resp) + if wait: + op_name = resp.json().get("name", "") + if op_name: + self._wait_zone_operation(zone, op_name) def vm_reset(self, name: str, zone: str) -> None: logger.info("Resetting VM: %s", name) @@ -813,6 +818,9 @@ def vm_update_labels(self, name: str, zone: str, labels: dict[str, str]) -> None json={"labels": current_labels, "labelFingerprint": fingerprint}, ) self._classify_response(resp) + op_name = resp.json().get("name", "") + if op_name: + self._wait_zone_operation(zone, op_name) def vm_set_metadata(self, name: str, zone: str, metadata: dict[str, str]) -> None: logger.info("Setting metadata on VM %s", name) @@ -830,6 +838,9 @@ def vm_set_metadata(self, name: str, zone: str, metadata: dict[str, str]) -> Non url = self._instance_url(zone, name) + "/setMetadata" resp = self._client.post(url, headers=self._headers(), json=body) self._classify_response(resp) + op_name = resp.json().get("name", "") + if op_name: + self._wait_zone_operation(zone, op_name) def vm_get_serial_port_output(self, name: str, zone: str, start: int = 0) -> str: try: diff --git a/lib/iris/src/iris/cluster/providers/manual/provider.py b/lib/iris/src/iris/cluster/providers/manual/provider.py index b1998b9131..580a87623d 100644 --- a/lib/iris/src/iris/cluster/providers/manual/provider.py +++ b/lib/iris/src/iris/cluster/providers/manual/provider.py @@ -79,7 +79,7 @@ def metadata(self) -> dict[str, str]: def status(self) -> WorkerStatus: return WorkerStatus(state=CloudWorkerState.RUNNING) - def terminate(self) -> None: + def terminate(self, *, wait: bool = False) -> None: if self._on_terminate: self._on_terminate() @@ -174,7 +174,7 @@ def describe(self) -> SliceStatus: return SliceStatus(state=state, worker_count=len(self._hosts), workers=workers) - def terminate(self) -> None: + def terminate(self, *, wait: bool = False) -> None: if self._terminated: return self._terminated = True diff --git a/lib/iris/src/iris/cluster/providers/types.py b/lib/iris/src/iris/cluster/providers/types.py index 31299e74b7..020603a885 100644 --- a/lib/iris/src/iris/cluster/providers/types.py +++ b/lib/iris/src/iris/cluster/providers/types.py @@ -261,7 +261,7 @@ def bootstrap(self, script: str) -> None: """Run the bootstrap script on the worker.""" ... - def terminate(self) -> None: + def terminate(self, *, wait: bool = False) -> None: """Destroy the worker.""" ... @@ -317,7 +317,7 @@ def describe(self) -> SliceStatus: """Query cloud state, returning status and worker handles.""" ... - def terminate(self) -> None: + def terminate(self, *, wait: bool = False) -> None: """Destroy the slice and all its workers.""" ... 
diff --git a/lib/iris/tests/cluster/controller/test_vm_lifecycle.py b/lib/iris/tests/cluster/controller/test_vm_lifecycle.py index 5780c9ae23..24babf1566 100644 --- a/lib/iris/tests/cluster/controller/test_vm_lifecycle.py +++ b/lib/iris/tests/cluster/controller/test_vm_lifecycle.py @@ -127,7 +127,7 @@ def bootstrap(self, script: str) -> None: def reboot(self) -> None: pass - def terminate(self) -> None: + def terminate(self, *, wait: bool = False) -> None: self.terminated = True def set_labels(self, labels: dict[str, str]) -> None: diff --git a/lib/iris/tests/cluster/providers/conftest.py b/lib/iris/tests/cluster/providers/conftest.py index f8028a56ec..64c758e366 100644 --- a/lib/iris/tests/cluster/providers/conftest.py +++ b/lib/iris/tests/cluster/providers/conftest.py @@ -120,7 +120,7 @@ def created_at(self) -> Timestamp: def describe(self) -> SliceStatus: return self._status - def terminate(self) -> None: + def terminate(self, *, wait: bool = False) -> None: self.terminated = True if self.terminate_error is not None: raise self.terminate_error diff --git a/lib/iris/tests/cluster/test_snapshot_reconciliation.py b/lib/iris/tests/cluster/test_snapshot_reconciliation.py index 2d8e2a4aa8..03bf425ffd 100644 --- a/lib/iris/tests/cluster/test_snapshot_reconciliation.py +++ b/lib/iris/tests/cluster/test_snapshot_reconciliation.py @@ -102,7 +102,7 @@ def describe(self) -> SliceStatus: workers=list(self._workers), ) - def terminate(self) -> None: + def terminate(self, *, wait: bool = False) -> None: pass