Skip to content

Commit 8e48926

Browse files
committed
[iris] Replace gcloud CLI with REST API client in CloudGcpService
Extract GCPApi class (httpx + google.auth ADC) that handles auth, pagination, and error mapping for TPU v2, Compute v1, and Cloud Logging APIs. Rewrite CloudGcpService to delegate to GCPApi instead of subprocess gcloud calls. This eliminates the gcloud CLI dependency for resource management, fixing CI failures from gcloud alpha not being installed. Add logging_read to the GcpService Protocol so bootstrap log fetching goes through the same boundary.
1 parent d3b7c64 commit 8e48926

7 files changed

Lines changed: 886 additions & 442 deletions

File tree

lib/iris/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ dependencies = [
1313
"connect-python>=0.9.0",
1414
"fsspec>=2024.0.0",
1515
"gcsfs>=2024.0.0",
16+
"google-auth>=2.0",
1617
"s3fs>=2024.0.0",
1718
"grpcio>=1.76.0",
1819
"httpx>=0.28.1",
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Low-level HTTP client for GCP REST APIs (TPU v2, Compute v1, Cloud Logging).
5+
6+
Handles authentication (Application Default Credentials), token caching,
7+
pagination, and error mapping to domain exceptions. Used by CloudGcpService
8+
as a replacement for gcloud CLI subprocess calls.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
import json
14+
import logging
15+
import time
16+
17+
import google.auth
18+
import google.auth.credentials
19+
import google.auth.transport.requests
20+
import httpx
21+
22+
from iris.cluster.providers.types import (
23+
InfraError,
24+
QuotaExhaustedError,
25+
ResourceNotFoundError,
26+
)
27+
28+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
29+
30+
# Base URLs of the GCP REST APIs this client talks to. Request URLs are built
# by appending a resource path to one of these prefixes.
TPU_BASE = "https://tpu.googleapis.com/v2"
31+
COMPUTE_BASE = "https://compute.googleapis.com/compute/v1"
32+
LOGGING_BASE = "https://logging.googleapis.com/v2"
33+
34+
# Seconds before token expiry at which the cached token is refreshed. Also used
# as the cache lifetime when the credentials report no expiry at all.
_REFRESH_MARGIN = 300 # seconds before expiry to refresh token
35+
# Default timeout, in seconds, passed to the shared httpx.Client.
_DEFAULT_TIMEOUT = 120 # seconds
36+
37+
38+
class GCPApi:
    """Low-level HTTP client for GCP REST APIs with ADC auth and token caching.

    Wraps one ``httpx.Client`` and exposes thin helpers for the TPU v2,
    Compute Engine v1, and Cloud Logging v2 endpoints of a single project.
    Responses with status >= 400 are mapped to ``ResourceNotFoundError``,
    ``QuotaExhaustedError``, or ``InfraError`` by ``_classify_response``.

    The instance can be used as a context manager; the HTTP client is closed
    on exit.
    """

    def __init__(self, project_id: str) -> None:
        """Create a client bound to ``project_id``; credentials load lazily."""
        self._project_id = project_id
        self._client = httpx.Client(timeout=_DEFAULT_TIMEOUT)
        self._creds: google.auth.credentials.Credentials | None = None
        self._token: str | None = None
        # Deadline on the time.monotonic() clock after which the token is refreshed.
        self._expires_at: float = 0.0

    def __enter__(self) -> GCPApi:
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._client.close()

    # -- Auth ---------------------------------------------------------------

    def _headers(self) -> dict[str, str]:
        """Return JSON request headers carrying a bearer token.

        The token is refreshed first when none is cached or the cached one has
        reached its refresh deadline.
        """
        if self._token is None or time.monotonic() >= self._expires_at:
            self._refresh_token()
        return {
            "Authorization": f"Bearer {self._token}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _seconds_until_expiry(expiry) -> float:
        """Return the number of seconds from now until ``expiry``.

        ``expiry`` is a ``datetime``. google-auth reports credential expiry as
        a *naive* datetime expressed in UTC, while ``datetime.timestamp()``
        interprets naive values as local time. A naive value is therefore
        tagged as UTC first so the result does not depend on the host's
        timezone. Aware datetimes are used as-is.
        """
        from datetime import timezone

        if expiry.tzinfo is None:
            expiry = expiry.replace(tzinfo=timezone.utc)
        return expiry.timestamp() - time.time()

    def _refresh_token(self) -> None:
        """Fetch a fresh access token and record when it must be refreshed.

        Credentials are resolved through Application Default Credentials on
        first use and reused afterwards.
        """
        if self._creds is None:
            self._creds, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
        self._creds.refresh(google.auth.transport.requests.Request())
        self._token = self._creds.token
        now = time.monotonic()
        expiry = self._creds.expiry
        if expiry is not None:
            # Translate the wall-clock expiry into a monotonic deadline, minus
            # a safety margin so the token is renewed before it lapses.
            self._expires_at = now + self._seconds_until_expiry(expiry) - _REFRESH_MARGIN
        else:
            # No expiry reported: re-fetch after one margin interval.
            self._expires_at = now + _REFRESH_MARGIN

    # -- Error mapping ------------------------------------------------------

    def _classify_response(self, resp: httpx.Response) -> None:
        """Raise a domain exception for responses with status >= 400.

        Raises:
            ResourceNotFoundError: on code 404 or status ``NOT_FOUND``.
            QuotaExhaustedError: on code 429 or status ``RESOURCE_EXHAUSTED``
                / ``QUOTA_EXCEEDED``.
            InfraError: for every other error response.
        """
        if resp.status_code < 400:
            return
        try:
            body = resp.json()
            error = body.get("error", {})
            message = error.get("message", resp.text)
            status = error.get("status", "")
            code = error.get("code", resp.status_code)
        except (ValueError, AttributeError):
            # ValueError covers malformed JSON and undecodable bytes;
            # AttributeError covers JSON whose shape is not {"error": {...}}.
            message = resp.text
            status = ""
            code = resp.status_code

        if code == 404 or status == "NOT_FOUND":
            raise ResourceNotFoundError(message)
        if code == 429 or status in ("RESOURCE_EXHAUSTED", "QUOTA_EXCEEDED"):
            raise QuotaExhaustedError(message)
        raise InfraError(f"GCP API error {code}: {message}")

    # -- Pagination ---------------------------------------------------------

    def _paginate(self, url: str, items_key: str, params: dict[str, str] | None = None) -> list[dict]:
        """Return the ``items_key`` entries of every page of a list endpoint."""
        return [item for page in self._paginate_raw(url, params) for item in page.get(items_key, [])]

    def _paginate_raw(self, url: str, params: dict[str, str] | None = None) -> list[dict]:
        """Return raw page bodies (for aggregatedList where items_key varies).

        Follows ``nextPageToken`` until a page omits it. ``params`` is copied,
        so the caller's dict is never mutated.
        """
        pages: list[dict] = []
        query = dict(params or {})
        while True:
            resp = self._client.get(url, headers=self._headers(), params=query)
            self._classify_response(resp)
            data = resp.json()
            pages.append(data)
            token = data.get("nextPageToken")
            if not token:
                break
            query["pageToken"] = token
        return pages

    # ======================================================================
    # TPU v2
    # ======================================================================

    def _tpu_parent(self, zone: str) -> str:
        """Return the TPU API parent path for ``zone`` in this project."""
        return f"projects/{self._project_id}/locations/{zone}"

    def tpu_create(self, name: str, zone: str, body: dict) -> dict | None:
        """Request creation of TPU node ``name`` in ``zone``.

        Returns the response body only if its ``name`` ends with
        ``/nodes/{name}``; otherwise returns ``None``.
        """
        url = f"{TPU_BASE}/{self._tpu_parent(zone)}/nodes"
        resp = self._client.post(url, params={"nodeId": name}, headers=self._headers(), json=body)
        self._classify_response(resp)
        data = resp.json()
        # REST create returns a long-running operation, not the node itself.
        return data if data.get("name", "").endswith(f"/nodes/{name}") else None

    def tpu_get(self, name: str, zone: str) -> dict:
        """Return the JSON body describing TPU node ``name`` in ``zone``."""
        url = f"{TPU_BASE}/{self._tpu_parent(zone)}/nodes/{name}"
        resp = self._client.get(url, headers=self._headers())
        self._classify_response(resp)
        return resp.json()

    def tpu_delete(self, name: str, zone: str) -> None:
        """Delete TPU node ``name``; a 404 response is silently accepted."""
        url = f"{TPU_BASE}/{self._tpu_parent(zone)}/nodes/{name}"
        resp = self._client.delete(url, headers=self._headers())
        if resp.status_code != 404:
            self._classify_response(resp)

    def tpu_list(self, zone: str) -> list[dict]:
        """Return all TPU nodes in ``zone``, following pagination."""
        return self._paginate(f"{TPU_BASE}/{self._tpu_parent(zone)}/nodes", "nodes")

    # ======================================================================
    # TPU v2 — Queued Resources
    # ======================================================================

    def queued_resource_create(self, name: str, zone: str, body: dict) -> None:
        """Request creation of queued resource ``name`` in ``zone``."""
        url = f"{TPU_BASE}/{self._tpu_parent(zone)}/queuedResources"
        resp = self._client.post(
            url,
            params={"queuedResourceId": name},
            headers=self._headers(),
            json=body,
        )
        self._classify_response(resp)

    def queued_resource_get(self, name: str, zone: str) -> dict:
        """Return the JSON body describing queued resource ``name``."""
        url = f"{TPU_BASE}/{self._tpu_parent(zone)}/queuedResources/{name}"
        resp = self._client.get(url, headers=self._headers())
        self._classify_response(resp)
        return resp.json()

    def queued_resource_delete(self, name: str, zone: str) -> None:
        """Delete queued resource ``name`` with ``force=true``; 404 is accepted."""
        url = f"{TPU_BASE}/{self._tpu_parent(zone)}/queuedResources/{name}"
        resp = self._client.delete(url, params={"force": "true"}, headers=self._headers())
        if resp.status_code != 404:
            self._classify_response(resp)

    def queued_resource_list(self, zone: str) -> list[dict]:
        """Return all queued resources in ``zone``, following pagination."""
        return self._paginate(
            f"{TPU_BASE}/{self._tpu_parent(zone)}/queuedResources",
            "queuedResources",
        )

    # ======================================================================
    # Compute Engine v1 — Instances
    # ======================================================================

    def _instance_url(self, zone: str, name: str = "") -> str:
        """Return the instances collection URL for ``zone``, or one instance's URL."""
        path = f"{COMPUTE_BASE}/projects/{self._project_id}/zones/{zone}/instances"
        if name:
            path += f"/{name}"
        return path

    def instance_insert(self, zone: str, body: dict) -> dict:
        """POST ``body`` to the instances collection of ``zone``; return the JSON reply."""
        url = self._instance_url(zone)
        resp = self._client.post(url, headers=self._headers(), json=body)
        self._classify_response(resp)
        return resp.json()

    def instance_get(self, name: str, zone: str) -> dict:
        """Return the JSON body describing instance ``name`` in ``zone``."""
        url = self._instance_url(zone, name)
        resp = self._client.get(url, headers=self._headers())
        self._classify_response(resp)
        return resp.json()

    def instance_delete(self, name: str, zone: str) -> None:
        """Delete instance ``name``; a 404 response is silently accepted."""
        url = self._instance_url(zone, name)
        resp = self._client.delete(url, headers=self._headers())
        if resp.status_code != 404:
            self._classify_response(resp)

    def instance_list(self, zone: str | None = None, filter_str: str = "") -> list[dict]:
        """List instances in ``zone``, or across the whole project when no zone is given.

        ``filter_str``, when non-empty, is sent as the ``filter`` query parameter.
        """
        params: dict[str, str] = {}
        if filter_str:
            params["filter"] = filter_str

        if zone:
            return self._paginate(self._instance_url(zone), "items", params)

        # Project-wide: aggregatedList, flatten across zones
        url = f"{COMPUTE_BASE}/projects/{self._project_id}/aggregated/instances"
        results: list[dict] = []
        for page in self._paginate_raw(url, params):
            for scope in page.get("items", {}).values():
                results.extend(scope.get("instances", []))
        return results

    def instance_reset(self, name: str, zone: str) -> None:
        """POST to the ``/reset`` action of instance ``name``."""
        url = self._instance_url(zone, name) + "/reset"
        resp = self._client.post(url, headers=self._headers())
        self._classify_response(resp)

    def instance_set_labels(self, name: str, zone: str, labels: dict[str, str], fingerprint: str) -> None:
        """POST ``labels`` and the label ``fingerprint`` to the ``/setLabels`` action."""
        url = self._instance_url(zone, name) + "/setLabels"
        resp = self._client.post(
            url,
            headers=self._headers(),
            json={"labels": labels, "labelFingerprint": fingerprint},
        )
        self._classify_response(resp)

    def instance_set_metadata(self, name: str, zone: str, metadata_body: dict) -> None:
        """POST ``metadata_body`` unchanged to the ``/setMetadata`` action."""
        url = self._instance_url(zone, name) + "/setMetadata"
        resp = self._client.post(url, headers=self._headers(), json=metadata_body)
        self._classify_response(resp)

    def instance_get_serial_port_output(self, name: str, zone: str, start: int = 0) -> dict:
        """Return the JSON reply of the ``/serialPort`` endpoint.

        ``start`` is forwarded as the ``start`` query parameter.
        """
        url = self._instance_url(zone, name) + "/serialPort"
        resp = self._client.get(url, headers=self._headers(), params={"start": str(start)})
        self._classify_response(resp)
        return resp.json()

    # ======================================================================
    # Cloud Logging v2
    # ======================================================================

    def logging_list_entries(self, filter_str: str, limit: int = 200) -> list[dict]:
        """Return log entries of this project matching ``filter_str``.

        Issues a single ``entries:list`` request ordered by ``timestamp desc``
        with a page size of ``min(limit, 1000)``; further pages are not fetched.
        Uses a 30-second timeout instead of the client default.
        """
        url = f"{LOGGING_BASE}/entries:list"
        body = {
            "resourceNames": [f"projects/{self._project_id}"],
            "filter": filter_str,
            "pageSize": min(limit, 1000),
            "orderBy": "timestamp desc",
        }
        resp = self._client.post(url, headers=self._headers(), json=body, timeout=30)
        self._classify_response(resp)
        return resp.json().get("entries", [])

lib/iris/src/iris/cluster/providers/gcp/fake.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,9 @@ def vm_get_serial_port_output(self, name: str, zone: str, start: int = 0) -> str
375375
full_output = self._serial_port_output.get((name, zone), "")
376376
return full_output[start:]
377377

378+
def logging_read(self, filter_str: str, limit: int = 200) -> list[str]:
    """Fake log read: ignores ``filter_str`` and ``limit`` and returns no entries."""
    no_entries: list[str] = []
    return no_entries
380+
378381
# ========================================================================
379382
# LOCAL mode: worker spawning
380383
# ========================================================================

0 commit comments

Comments
 (0)