Skip to content

Commit 9931883

Browse files
rjpowerclaude
andcommitted
[rigging] Cleanup: drop dead JwtTokenManager wrappers, simplify resolver API
- Rename `vm_address(name, provider="gcp", ...)` → `gcp_vm_address(name, ...)`, drop the speculative provider arg. - Lazy-import `google.auth`/`httpx` inside their only call sites so importing `iris.client` no longer pulls them in eagerly. - Add `is_registered(scheme)` so callers (and tests) don't have to peek at `_HANDLERS`. - Replace the uvicorn-based `test_resolver_plugin.py` integration test with an in-process unit test that monkeypatches `ControllerServiceClientSync`. - Drop `JwtTokenManager.signing_key` and `JwtTokenManager.verifier` properties; surface the signing key on `ControllerAuth` instead so the log-server subprocess wiring goes through one explicit field. - Trim the `VerifiedIdentity` re-export comment in `iris/rpc/auth.py`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 91027b1 commit 9931883

7 files changed

Lines changed: 115 additions & 171 deletions

File tree

lib/iris/src/iris/client/resolver_plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
"""Registers ``iris://<cluster>?endpoint=<name>`` with ``rigging.resolver``."""
55

6-
from rigging.resolver import ServiceURL, register_scheme, vm_address
6+
from rigging.resolver import ServiceURL, gcp_vm_address, register_scheme
77

88
from iris.rpc.controller_connect import ControllerServiceClientSync
99
from iris.rpc.controller_pb2 import Controller as _Controller
@@ -16,7 +16,7 @@ def _resolve_iris(url: ServiceURL) -> tuple[str, int]:
1616
name = url.query.get("endpoint")
1717
if not name:
1818
raise ValueError(f"iris:// URL requires ?endpoint=<name>: {url!r}")
19-
controller_host, controller_port = vm_address(f"iris-controller-{cluster}", provider="gcp")
19+
controller_host, controller_port = gcp_vm_address(f"iris-controller-{cluster}")
2020
with ControllerServiceClientSync(address=f"http://{controller_host}:{controller_port}") as client:
2121
response = client.list_endpoints(_Controller.ListEndpointsRequest(prefix=name, exact=True))
2222
if not response.endpoints:

lib/iris/src/iris/cluster/controller/auth.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -181,22 +181,12 @@ class JwtTokenManager:
181181
"""
182182

183183
def __init__(self, signing_key: str, db: ControllerDB | None = None):
184+
self._signing_key = signing_key
184185
self._verifier = JwtVerifier(signing_key)
185186
self._db = db
186187
# Tracks the last wall-clock time we wrote last_used_at per jti.
187188
self._last_touched: dict[str, float] = {}
188189

189-
@property
190-
def signing_key(self) -> str:
191-
"""HMAC secret used to sign and verify JWTs. Do not log or serialize."""
192-
return self._verifier.signing_key
193-
194-
@property
195-
def verifier(self) -> JwtVerifier:
196-
"""Underlying stateless verifier, suitable for handing to a log server
197-
or other process that should validate but not issue tokens."""
198-
return self._verifier
199-
200190
def create_token(
201191
self,
202192
user_id: str,
@@ -212,7 +202,7 @@ def create_token(
212202
"iat": int(now),
213203
"exp": int(now + ttl_seconds),
214204
}
215-
return jwt.encode(payload, self._verifier.signing_key, algorithm=JWT_ALGORITHM)
205+
return jwt.encode(payload, self._signing_key, algorithm=JWT_ALGORITHM)
216206

217207
def verify(self, token: str) -> VerifiedIdentity:
218208
"""Verify JWT signature and claims, check revocation.
@@ -276,6 +266,9 @@ class ControllerAuth:
276266
login_verifier: TokenVerifier | None = None
277267
gcp_project_id: str | None = None
278268
jwt_manager: JwtTokenManager | None = None
269+
# HMAC signing key — handed to the log-server subprocess via env var. Do
270+
# not log or serialize.
271+
signing_key: str | None = None
279272
optional: bool = False
280273

281274

@@ -301,7 +294,12 @@ def create_controller_auth(
301294

302295
worker_token = _create_worker_jwt(db, jwt_mgr, now)
303296
logger.info("Authentication disabled — null-auth mode (workers use JWT)")
304-
return ControllerAuth(verifier=jwt_mgr, worker_token=worker_token, jwt_manager=jwt_mgr)
297+
return ControllerAuth(
298+
verifier=jwt_mgr,
299+
worker_token=worker_token,
300+
jwt_manager=jwt_mgr,
301+
signing_key=signing_key,
302+
)
305303
logger.info("Authentication disabled — null-auth mode, no DB")
306304
return ControllerAuth()
307305

@@ -327,8 +325,8 @@ def create_controller_auth(
327325

328326
verifier: TokenVerifier | None = jwt_mgr
329327
else:
330-
ephemeral_key = secrets.token_hex(32)
331-
jwt_mgr = JwtTokenManager(ephemeral_key)
328+
signing_key = secrets.token_hex(32)
329+
jwt_mgr = JwtTokenManager(signing_key)
332330
worker_token = jwt_mgr.create_token(WORKER_USER, "worker", f"iris_k_worker_{secrets.token_hex(8)}")
333331
verifier = None
334332

@@ -361,6 +359,7 @@ def create_controller_auth(
361359
login_verifier=login_verifier,
362360
gcp_project_id=gcp_project_id,
363361
jwt_manager=jwt_mgr,
362+
signing_key=signing_key,
364363
optional=optional,
365364
)
366365

lib/iris/src/iris/cluster/controller/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def run_controller_serve(
231231
port=log_port,
232232
log_dir=log_dir,
233233
remote_log_dir=remote_log_dir,
234-
signing_key=auth.jwt_manager.signing_key if auth.jwt_manager else None,
234+
signing_key=auth.signing_key,
235235
strict_auth=auth.provider is not None,
236236
)
237237
try:

lib/iris/src/iris/rpc/auth.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,6 @@
3030
SESSION_COOKIE = "iris_session"
3131

3232

33-
# `VerifiedIdentity` is defined in `rigging.auth` so the stateless
34-
# `JwtVerifier` (also in rigging) and the iris-side issuer share the same
35-
# type. Re-exported here for the many iris call sites that import
36-
# `from iris.rpc.auth import VerifiedIdentity`.
37-
38-
3933
def _extract_cookie(cookie_header: str, name: str) -> str | None:
4034
"""Extract a named cookie value from a raw Cookie header."""
4135
if not cookie_header:
Lines changed: 57 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,146 +1,89 @@
11
# Copyright The Marin Authors
22
# SPDX-License-Identifier: Apache-2.0
33

4-
"""End-to-end test for the iris resolver plugin."""
5-
6-
import socket
7-
import threading
8-
from collections.abc import Iterator
9-
from typing import Any
4+
"""Unit tests for the iris:// resolver plugin."""
105

116
import pytest
12-
import uvicorn
13-
from starlette.applications import Starlette
14-
from starlette.middleware.wsgi import WSGIMiddleware
15-
from starlette.routing import Mount
167

178
import iris.client # noqa: F401 -- side-effect import: registers iris:// scheme
9+
from iris.client import resolver_plugin
1810
from iris.rpc import controller_pb2
19-
from iris.rpc.controller_connect import ControllerServiceSync, ControllerServiceWSGIApplication
20-
from rigging import resolver as resolver_module
21-
from rigging.resolver import resolve
22-
from rigging.timing import Duration, ExponentialBackoff
11+
from rigging.resolver import is_registered, resolve
12+
2313

14+
class _FakeControllerClient:
15+
"""Stubs ControllerServiceClientSync + its context manager."""
2416

25-
class _StubControllerService(ControllerServiceSync):
26-
"""Minimal ``ControllerServiceSync`` that only implements ``list_endpoints``.
17+
def __init__(self, endpoints: dict[str, str]):
18+
self._endpoints = endpoints
19+
self.last_request: controller_pb2.Controller.ListEndpointsRequest | None = None
2720

28-
Inherits the Protocol base class so all other RPCs default to
29-
``UNIMPLEMENTED`` errors, exactly what we want for an isolated test.
30-
"""
21+
def __enter__(self) -> "_FakeControllerClient":
22+
return self
3123

32-
def __init__(self) -> None:
33-
self.endpoints: dict[str, str] = {}
24+
def __exit__(self, *_exc) -> None:
25+
return None
3426

3527
def list_endpoints(
3628
self,
3729
request: controller_pb2.Controller.ListEndpointsRequest,
38-
ctx: Any,
3930
) -> controller_pb2.Controller.ListEndpointsResponse:
40-
results: list[controller_pb2.Controller.Endpoint] = []
41-
for name, address in self.endpoints.items():
42-
if request.exact:
43-
if name == request.prefix:
44-
results.append(controller_pb2.Controller.Endpoint(name=name, address=address))
45-
else:
46-
if name.startswith(request.prefix):
47-
results.append(controller_pb2.Controller.Endpoint(name=name, address=address))
48-
return controller_pb2.Controller.ListEndpointsResponse(endpoints=results)
49-
50-
51-
def _free_port() -> int:
52-
with socket.socket() as s:
53-
s.bind(("127.0.0.1", 0))
54-
return s.getsockname()[1]
55-
56-
57-
def _build_app(service: _StubControllerService) -> Starlette:
58-
wsgi = ControllerServiceWSGIApplication(service=service)
59-
return Starlette(routes=[Mount(wsgi.path, app=WSGIMiddleware(wsgi))])
60-
61-
62-
class _BackgroundServer:
63-
def __init__(self, app: Starlette, port: int) -> None:
64-
config = uvicorn.Config(
65-
app,
66-
host="127.0.0.1",
67-
port=port,
68-
log_level="error",
69-
log_config=None,
70-
timeout_keep_alive=5,
71-
)
72-
self.server = uvicorn.Server(config)
73-
self.port = port
74-
self._thread = threading.Thread(
75-
target=self.server.run,
76-
name=f"resolver-plugin-test-{port}",
77-
daemon=True,
78-
)
79-
80-
def start(self) -> None:
81-
self._thread.start()
82-
started = ExponentialBackoff(initial=0.01, maximum=0.2).wait_until(
83-
lambda: self.server.started,
84-
timeout=Duration.from_seconds(5.0),
85-
)
86-
if not started:
87-
raise RuntimeError(f"uvicorn did not start within 5s on port {self.port}")
88-
89-
def stop(self) -> None:
90-
self.server.should_exit = True
91-
self._thread.join(timeout=5.0)
31+
self.last_request = request
32+
matches = [
33+
controller_pb2.Controller.Endpoint(name=n, address=a)
34+
for n, a in self._endpoints.items()
35+
if n == request.prefix
36+
]
37+
return controller_pb2.Controller.ListEndpointsResponse(endpoints=matches)
9238

9339

9440
@pytest.fixture
95-
def stub_controller() -> Iterator[tuple[_StubControllerService, int]]:
96-
svc = _StubControllerService()
97-
port = _free_port()
98-
bg = _BackgroundServer(_build_app(svc), port)
99-
bg.start()
100-
try:
101-
yield svc, port
102-
finally:
103-
bg.stop()
104-
105-
106-
def test_resolve_iris_round_trips(monkeypatch, stub_controller):
107-
svc, controller_port = stub_controller
108-
svc.endpoints["/system/x"] = "host.example.com:1234"
109-
110-
captured: list[tuple] = []
41+
def patch_resolver(monkeypatch):
42+
"""Replace gcp_vm_address + controller client with in-process stubs."""
43+
vm_calls: list[str] = []
44+
45+
def _install(endpoints: dict[str, str]) -> _FakeControllerClient:
46+
fake = _FakeControllerClient(endpoints)
47+
48+
def _fake_vm_address(name: str, *, port: int = 10002) -> tuple[str, int]:
49+
vm_calls.append(name)
50+
return ("127.0.0.1", 65000)
51+
52+
monkeypatch.setattr(resolver_plugin, "gcp_vm_address", _fake_vm_address)
53+
monkeypatch.setattr(
54+
resolver_plugin,
55+
"ControllerServiceClientSync",
56+
lambda address: fake,
57+
)
58+
return fake
11159

112-
def _fake_vm_address(name: str, provider: str) -> tuple[str, int]:
113-
captured.append((name, provider))
114-
# Direct test traffic at the in-process stub rather than GCP.
115-
return ("127.0.0.1", controller_port)
60+
_install.vm_calls = vm_calls # type: ignore[attr-defined]
61+
return _install
11662

117-
# The plugin binds vm_address as a module-global at import time; patch
118-
# it where it's looked up.
119-
from iris.client import resolver_plugin
12063

121-
monkeypatch.setattr(resolver_plugin, "vm_address", _fake_vm_address)
64+
def test_resolve_iris_returns_endpoint_address(patch_resolver):
65+
patch_resolver({"/system/x": "host.example.com:1234"})
66+
assert resolve("iris://marin?endpoint=/system/x") == ("host.example.com", 1234)
67+
assert patch_resolver.vm_calls == ["iris-controller-marin"]
12268

123-
host, port = resolve("iris://marin?endpoint=/system/x")
124-
assert (host, port) == ("host.example.com", 1234)
125-
assert captured == [("iris-controller-marin", "gcp")]
12669

70+
def test_resolve_iris_not_found_raises(patch_resolver):
71+
patch_resolver({})
72+
with pytest.raises(KeyError, match="iris endpoint not found"):
73+
resolve("iris://marin?endpoint=/system/missing")
12774

128-
def test_resolve_iris_not_found(monkeypatch, stub_controller):
129-
_svc, controller_port = stub_controller
13075

131-
from iris.client import resolver_plugin
76+
def test_resolve_iris_requires_endpoint_query(patch_resolver):
77+
patch_resolver({})
78+
with pytest.raises(ValueError, match="requires \\?endpoint="):
79+
resolve("iris://marin")
13280

133-
monkeypatch.setattr(
134-
resolver_plugin,
135-
"vm_address",
136-
lambda name, provider: ("127.0.0.1", controller_port),
137-
)
13881

139-
with pytest.raises(KeyError, match="iris endpoint not found"):
140-
resolve("iris://marin?endpoint=/system/missing")
82+
def test_resolve_iris_rejects_port(patch_resolver):
83+
patch_resolver({})
84+
with pytest.raises(ValueError, match="cannot have a port"):
85+
resolve("iris://marin:9000?endpoint=/x")
14186

14287

14388
def test_iris_scheme_registered_after_iris_client_import():
144-
# Sanity check: importing iris.client (done at the top of this module)
145-
# installs the iris:// handler in rigging.resolver's registry.
146-
assert "iris" in resolver_module._HANDLERS
89+
assert is_registered("iris")

lib/rigging/src/rigging/resolver.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,6 @@
1212
from dataclasses import dataclass
1313
from urllib.parse import parse_qs, urlsplit
1414

15-
import google.auth
16-
import google.auth.transport.requests
17-
import httpx
18-
1915
_COMPUTE_BASE = "https://compute.googleapis.com/compute/v1"
2016
_OAUTH_SCOPE = "https://www.googleapis.com/auth/cloud-platform"
2117
_TIMEOUT_SECONDS = 10.0
@@ -52,6 +48,10 @@ def register_scheme(scheme: str, handler: SchemeHandler) -> None:
5248
_HANDLERS[scheme] = handler
5349

5450

51+
def is_registered(scheme: str) -> bool:
52+
return scheme in _HANDLERS
53+
54+
5555
def resolve(ref: str) -> tuple[str, int]:
5656
if "://" not in ref:
5757
host, port = ref.rsplit(":", 1)
@@ -63,9 +63,7 @@ def resolve(ref: str) -> tuple[str, int]:
6363
return handler(url)
6464

6565

66-
def vm_address(name: str, provider: str, *, port: int = _DEFAULT_GCP_PORT) -> tuple[str, int]:
67-
if provider != "gcp":
68-
raise ValueError(f"unsupported provider: {provider}")
66+
def gcp_vm_address(name: str, *, port: int = _DEFAULT_GCP_PORT) -> tuple[str, int]:
6967
return _gcp_internal_ip(name), port
7068

7169

@@ -84,6 +82,9 @@ def _gcp_internal_ip(name: str) -> str:
8482

8583

8684
def _gcp_credentials() -> tuple[str, str]:
85+
import google.auth
86+
import google.auth.transport.requests
87+
8788
creds, project_id = google.auth.default(scopes=[_OAUTH_SCOPE])
8889
if not project_id:
8990
raise LookupError("google.auth.default() returned no project_id; set GOOGLE_CLOUD_PROJECT")
@@ -92,6 +93,8 @@ def _gcp_credentials() -> tuple[str, str]:
9293

9394

9495
def _fetch_vm_aggregated(project_id: str, token: str, name: str) -> dict | None:
96+
import httpx
97+
9598
url = f"{_COMPUTE_BASE}/projects/{project_id}/aggregated/instances"
9699
headers = {"Authorization": f"Bearer {token}"}
97100
params: dict[str, str] = {"filter": f"name eq {name}"}
@@ -110,4 +113,4 @@ def _fetch_vm_aggregated(project_id: str, token: str, name: str) -> dict | None:
110113
params["pageToken"] = page_token
111114

112115

113-
register_scheme("gcp", lambda url: vm_address(url.host, "gcp", port=url.port or _DEFAULT_GCP_PORT))
116+
register_scheme("gcp", lambda url: gcp_vm_address(url.host, port=url.port or _DEFAULT_GCP_PORT))

0 commit comments

Comments
 (0)