|
1 | 1 | # Copyright The Marin Authors |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 |
|
4 | | -"""End-to-end test for the iris resolver plugin.""" |
5 | | - |
6 | | -import socket |
7 | | -import threading |
8 | | -from collections.abc import Iterator |
9 | | -from typing import Any |
| 4 | +"""Unit tests for the iris:// resolver plugin.""" |
10 | 5 |
|
11 | 6 | import pytest |
12 | | -import uvicorn |
13 | | -from starlette.applications import Starlette |
14 | | -from starlette.middleware.wsgi import WSGIMiddleware |
15 | | -from starlette.routing import Mount |
16 | 7 |
|
17 | 8 | import iris.client # noqa: F401 -- side-effect import: registers iris:// scheme |
| 9 | +from iris.client import resolver_plugin |
18 | 10 | from iris.rpc import controller_pb2 |
19 | | -from iris.rpc.controller_connect import ControllerServiceSync, ControllerServiceWSGIApplication |
20 | | -from rigging import resolver as resolver_module |
21 | | -from rigging.resolver import resolve |
22 | | -from rigging.timing import Duration, ExponentialBackoff |
| 11 | +from rigging.resolver import is_registered, resolve |
| 12 | + |
23 | 13 |
|
| 14 | +class _FakeControllerClient: |
| 15 | + """Stubs ControllerServiceClientSync + its context manager.""" |
24 | 16 |
|
25 | | -class _StubControllerService(ControllerServiceSync): |
26 | | - """Minimal ``ControllerServiceSync`` that only implements ``list_endpoints``. |
| 17 | + def __init__(self, endpoints: dict[str, str]): |
| 18 | + self._endpoints = endpoints |
| 19 | + self.last_request: controller_pb2.Controller.ListEndpointsRequest | None = None |
27 | 20 |
|
28 | | - Inherits the Protocol base class so all other RPCs default to |
29 | | - ``UNIMPLEMENTED`` errors, exactly what we want for an isolated test. |
30 | | - """ |
| 21 | + def __enter__(self) -> "_FakeControllerClient": |
| 22 | + return self |
31 | 23 |
|
32 | | - def __init__(self) -> None: |
33 | | - self.endpoints: dict[str, str] = {} |
| 24 | + def __exit__(self, *_exc) -> None: |
| 25 | + return None |
34 | 26 |
|
35 | 27 | def list_endpoints( |
36 | 28 | self, |
37 | 29 | request: controller_pb2.Controller.ListEndpointsRequest, |
38 | | - ctx: Any, |
39 | 30 | ) -> controller_pb2.Controller.ListEndpointsResponse: |
40 | | - results: list[controller_pb2.Controller.Endpoint] = [] |
41 | | - for name, address in self.endpoints.items(): |
42 | | - if request.exact: |
43 | | - if name == request.prefix: |
44 | | - results.append(controller_pb2.Controller.Endpoint(name=name, address=address)) |
45 | | - else: |
46 | | - if name.startswith(request.prefix): |
47 | | - results.append(controller_pb2.Controller.Endpoint(name=name, address=address)) |
48 | | - return controller_pb2.Controller.ListEndpointsResponse(endpoints=results) |
49 | | - |
50 | | - |
51 | | -def _free_port() -> int: |
52 | | - with socket.socket() as s: |
53 | | - s.bind(("127.0.0.1", 0)) |
54 | | - return s.getsockname()[1] |
55 | | - |
56 | | - |
57 | | -def _build_app(service: _StubControllerService) -> Starlette: |
58 | | - wsgi = ControllerServiceWSGIApplication(service=service) |
59 | | - return Starlette(routes=[Mount(wsgi.path, app=WSGIMiddleware(wsgi))]) |
60 | | - |
61 | | - |
62 | | -class _BackgroundServer: |
63 | | - def __init__(self, app: Starlette, port: int) -> None: |
64 | | - config = uvicorn.Config( |
65 | | - app, |
66 | | - host="127.0.0.1", |
67 | | - port=port, |
68 | | - log_level="error", |
69 | | - log_config=None, |
70 | | - timeout_keep_alive=5, |
71 | | - ) |
72 | | - self.server = uvicorn.Server(config) |
73 | | - self.port = port |
74 | | - self._thread = threading.Thread( |
75 | | - target=self.server.run, |
76 | | - name=f"resolver-plugin-test-{port}", |
77 | | - daemon=True, |
78 | | - ) |
79 | | - |
80 | | - def start(self) -> None: |
81 | | - self._thread.start() |
82 | | - started = ExponentialBackoff(initial=0.01, maximum=0.2).wait_until( |
83 | | - lambda: self.server.started, |
84 | | - timeout=Duration.from_seconds(5.0), |
85 | | - ) |
86 | | - if not started: |
87 | | - raise RuntimeError(f"uvicorn did not start within 5s on port {self.port}") |
88 | | - |
89 | | - def stop(self) -> None: |
90 | | - self.server.should_exit = True |
91 | | - self._thread.join(timeout=5.0) |
| 31 | + self.last_request = request |
| 32 | + matches = [ |
| 33 | + controller_pb2.Controller.Endpoint(name=n, address=a) |
| 34 | + for n, a in self._endpoints.items() |
| 35 | + if n == request.prefix |
| 36 | + ] |
| 37 | + return controller_pb2.Controller.ListEndpointsResponse(endpoints=matches) |
92 | 38 |
|
93 | 39 |
|
94 | 40 | @pytest.fixture |
95 | | -def stub_controller() -> Iterator[tuple[_StubControllerService, int]]: |
96 | | - svc = _StubControllerService() |
97 | | - port = _free_port() |
98 | | - bg = _BackgroundServer(_build_app(svc), port) |
99 | | - bg.start() |
100 | | - try: |
101 | | - yield svc, port |
102 | | - finally: |
103 | | - bg.stop() |
104 | | - |
105 | | - |
106 | | -def test_resolve_iris_round_trips(monkeypatch, stub_controller): |
107 | | - svc, controller_port = stub_controller |
108 | | - svc.endpoints["/system/x"] = "host.example.com:1234" |
109 | | - |
110 | | - captured: list[tuple] = [] |
| 41 | +def patch_resolver(monkeypatch): |
| 42 | + """Replace gcp_vm_address + controller client with in-process stubs.""" |
| 43 | + vm_calls: list[str] = [] |
| 44 | + |
| 45 | + def _install(endpoints: dict[str, str]) -> _FakeControllerClient: |
| 46 | + fake = _FakeControllerClient(endpoints) |
| 47 | + |
| 48 | + def _fake_vm_address(name: str, *, port: int = 10002) -> tuple[str, int]: |
| 49 | + vm_calls.append(name) |
| 50 | + return ("127.0.0.1", 65000) |
| 51 | + |
| 52 | + monkeypatch.setattr(resolver_plugin, "gcp_vm_address", _fake_vm_address) |
| 53 | + monkeypatch.setattr( |
| 54 | + resolver_plugin, |
| 55 | + "ControllerServiceClientSync", |
| 56 | + lambda address: fake, |
| 57 | + ) |
| 58 | + return fake |
111 | 59 |
|
112 | | - def _fake_vm_address(name: str, provider: str) -> tuple[str, int]: |
113 | | - captured.append((name, provider)) |
114 | | - # Direct test traffic at the in-process stub rather than GCP. |
115 | | - return ("127.0.0.1", controller_port) |
| 60 | + _install.vm_calls = vm_calls # type: ignore[attr-defined] |
| 61 | + return _install |
116 | 62 |
|
117 | | - # The plugin binds vm_address as a module-global at import time; patch |
118 | | - # it where it's looked up. |
119 | | - from iris.client import resolver_plugin |
120 | 63 |
|
121 | | - monkeypatch.setattr(resolver_plugin, "vm_address", _fake_vm_address) |
| 64 | +def test_resolve_iris_returns_endpoint_address(patch_resolver): |
| 65 | + patch_resolver({"/system/x": "host.example.com:1234"}) |
| 66 | + assert resolve("iris://marin?endpoint=/system/x") == ("host.example.com", 1234) |
| 67 | + assert patch_resolver.vm_calls == ["iris-controller-marin"] |
122 | 68 |
|
123 | | - host, port = resolve("iris://marin?endpoint=/system/x") |
124 | | - assert (host, port) == ("host.example.com", 1234) |
125 | | - assert captured == [("iris-controller-marin", "gcp")] |
126 | 69 |
|
| 70 | +def test_resolve_iris_not_found_raises(patch_resolver): |
| 71 | + patch_resolver({}) |
| 72 | + with pytest.raises(KeyError, match="iris endpoint not found"): |
| 73 | + resolve("iris://marin?endpoint=/system/missing") |
127 | 74 |
|
128 | | -def test_resolve_iris_not_found(monkeypatch, stub_controller): |
129 | | - _svc, controller_port = stub_controller |
130 | 75 |
|
131 | | - from iris.client import resolver_plugin |
| 76 | +def test_resolve_iris_requires_endpoint_query(patch_resolver): |
| 77 | + patch_resolver({}) |
| 78 | + with pytest.raises(ValueError, match="requires \\?endpoint="): |
| 79 | + resolve("iris://marin") |
132 | 80 |
|
133 | | - monkeypatch.setattr( |
134 | | - resolver_plugin, |
135 | | - "vm_address", |
136 | | - lambda name, provider: ("127.0.0.1", controller_port), |
137 | | - ) |
138 | 81 |
|
139 | | - with pytest.raises(KeyError, match="iris endpoint not found"): |
140 | | - resolve("iris://marin?endpoint=/system/missing") |
| 82 | +def test_resolve_iris_rejects_port(patch_resolver): |
| 83 | + patch_resolver({}) |
| 84 | + with pytest.raises(ValueError, match="cannot have a port"): |
| 85 | + resolve("iris://marin:9000?endpoint=/x") |
141 | 86 |
|
142 | 87 |
|
143 | 88 | def test_iris_scheme_registered_after_iris_client_import(): |
144 | | - # Sanity check: importing iris.client (done at the top of this module) |
145 | | - # installs the iris:// handler in rigging.resolver's registry. |
146 | | - assert "iris" in resolver_module._HANDLERS |
| 89 | + assert is_registered("iris") |
0 commit comments