diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bf4603fe4..b934742a9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -44,9 +44,11 @@ jobs: - name: Install tiled run: uv sync --all-extras - - name: Start LDAP service in container. - shell: bash -l {0} - run: source continuous_integration/scripts/start_LDAP.sh + # TODO Find a new image to use. + # https://github.com/bluesky/tiled/issues/1109 + # - name: Start LDAP service in container. + # shell: bash -l {0} + # run: source continuous_integration/scripts/start_LDAP.sh - name: Download SQLite example data. shell: bash -l {0} @@ -56,6 +58,10 @@ jobs: shell: bash -l {0} run: source continuous_integration/scripts/start_postgres.sh + - name: Start Redis service in container. + shell: bash -l {0} + run: source continuous_integration/scripts/start_redis.sh + - name: Ensure example data is migrated to current catalog database schema. # The example data is expected to be kept up to date to the latest Tiled @@ -75,10 +81,13 @@ jobs: uv run coverage report -m uv run coverage xml -o cov.xml env: - # Provide test suite with a PostgreSQL database to use. + # Provide test suite with PostgreSQL and Redis databases to use. TILED_TEST_POSTGRESQL_URI: postgresql://postgres:secret@localhost:5432 + # Name must match os.getenv("TILED_TEST_REDIS") in tiled/_tests/conftest.py. + TILED_TEST_REDIS: redis://localhost:6379 + # TODO Reinstate after finding a new image to use + # https://github.com/bluesky/tiled/issues/1109 + # # Opt in to LDAPAuthenticator tests. 
+ # TILED_TEST_LDAP: 1 - name: Upload coverage to Codecov uses: codecov/codecov-action@v4 diff --git a/.gitignore b/.gitignore index bd1a455ea..bb11e147c 100644 --- a/.gitignore +++ b/.gitignore @@ -82,3 +82,6 @@ pixi.lock # uv environments uv.lock .venv +# pixi environments +.pixi/* +!.pixi/config.toml diff --git a/continuous_integration/scripts/start_redis.sh b/continuous_integration/scripts/start_redis.sh new file mode 100644 index 000000000..b8f830952 --- /dev/null +++ b/continuous_integration/scripts/start_redis.sh @@ -0,0 +1,5 @@ +#!/bin/bash +set -e + +docker run -d --rm --name tiled-test-redis -p 6379:6379 docker.io/redis:7-alpine +docker ps diff --git a/locust/README.md b/locust/README.md index 36b707ee0..3f4350944 100644 --- a/locust/README.md +++ b/locust/README.md @@ -1,35 +1,107 @@ # Tiled Load Testing with Locust -Simple load testing for Tiled using the `reader.py` file. +Load testing for Tiled using Locust. Two test files are available: +- `reader.py` - Tests HTTP read operations and search endpoints +- `streaming.py` - Tests streaming data writes and WebSocket delivery latency ## Quick Start ```bash -# Install dependencies (dev environment includes locust) -pixi install -e dev +# Install dependencies (locust should already be available in the environment) +# If not installed, add it to your requirements or install with: +# uv add locust ``` +## Starting Test Server + +Before running locust tests, start a Tiled server: + +```bash +# Basic server (works for most tests) +uv run tiled serve catalog \ + --host 0.0.0.0 \ + --port 8000 \ + --api-key secret \ + --temp \ + --init +``` + +For streaming tests with Redis cache (optional): +```bash +# Start Redis first +redis-server + +# Start Tiled server with Redis cache +uv run tiled serve catalog \ + --host 0.0.0.0 \ + --port 8000 \ + --api-key secret \ + --cache "redis://localhost:6379" \ + --cache-ttl 60 \ + --temp \ + --init +``` + +This creates a temporary catalog with: +- API key authentication (key: 
"secret") +- Temporary writable storage (automatically cleaned up) +- Optional Redis cache for enhanced streaming performance +- Server running on http://localhost:8000 + +## Reading Performance Tests (`reader.py`) + +Tests various HTTP endpoints for reading data, metadata, and search operations. + ### Examples Run with default localhost server (uses default API key 'secret'): ```bash -pixi run -e dev locust -f reader.py --host http://localhost:8000 +uv run locust -f reader.py --headless -u 100 -r 10 -t 60s --host http://localhost:8000 ``` Run with custom API key: ```bash -pixi run -e dev locust -f reader.py --host http://localhost:8000 --api-key your-api-key +uv run locust -f reader.py --headless -u 100 -r 10 -t 60s --host http://localhost:8000 --api-key your-api-key ``` Run with custom container name (defaults to locust_testing): ```bash -pixi run -e dev locust -f reader.py --host http://localhost:8000 --container-name my_test_container +uv run locust -f reader.py --headless -u 100 -r 10 -t 60s --host http://localhost:8000 --container-name my_test_container +``` + +## Streaming Performance Tests (`streaming.py`) + +Tests streaming data writes and WebSocket delivery with end-to-end latency measurement. + +**Note:** The `--node-name` parameter is required for streaming tests to avoid conflicts when multiple test runs create nodes with the same name. 
+ +### Examples +Run with required node name: +```bash +uv run locust -f streaming.py --headless -u 10 -r 2 -t 120s --host http://localhost:8000 --node-name my_test_stream ``` -## Headless Mode -Run without the web interface: +Run with custom API key: ```bash -pixi run -e dev locust -f reader.py --headless -u 100 -r 10 -t 60s +uv run locust -f streaming.py --headless -u 10 -r 2 -t 120s --host http://localhost:8000 --api-key your-api-key --node-name my_test_stream ``` -- `-u 100`: 100 concurrent users -- `-r 10`: Spawn 10 users per second -- `-t 60s`: Run for 60 seconds + +Control user types with environment variables: +```bash +# 2 writers for every 1 streaming reader +WRITER_WEIGHT=2 STREAMING_WEIGHT=1 uv run locust -f streaming.py --headless -u 10 -r 2 -t 120s --host http://localhost:8000 --node-name my_test_stream +``` + +### Streaming Test Components +- **WriterUser**: Writes timestamped array data to streaming nodes +- **StreamingUser**: Connects via WebSocket to measure write-to-delivery latency + +## Parameters +- `-u N`: N concurrent users +- `-r N`: Spawn N users per second +- `-t Ns`: Run for N seconds +- `--headless`: Run without web interface (required for automation) + +## Notes +- All examples use `--headless` mode for reliable automation +- For streaming tests, `--node-name` is required to avoid conflicts +- Use environment variables `WRITER_WEIGHT` and `STREAMING_WEIGHT` to control user distribution diff --git a/locust/streaming.py b/locust/streaming.py new file mode 100644 index 000000000..7c3be8237 --- /dev/null +++ b/locust/streaming.py @@ -0,0 +1,238 @@ +import json +import logging +import os +import threading +import time +from urllib.parse import urlparse + +import msgpack +import numpy as np +import websocket + +from locust import HttpUser, between, events, task + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@events.init_command_line_parser.add_listener +def _(parser): + 
parser.add_argument( + "--api-key", + type=str, + default="secret", + help="API key for Tiled authentication (default: secret)", + ) + parser.add_argument( + "--node-name", + type=str, + required=True, + help="Node name for streaming test (required)", + ) + + +@events.init.add_listener +def on_locust_init(environment, **kwargs): + if environment.host is None: + raise ValueError( + "Host must be specified with --host argument, or through the web-ui." + ) + + # Create the streaming node once for all users + create_streaming_node( + environment.host, + environment.parsed_options.api_key, + environment.parsed_options.node_name, + ) + + +def create_streaming_node(host, api_key, node_name): + """Create a streaming array node using Tiled client""" + from tiled.client import from_uri + + # Connect to Tiled server using client + client = from_uri(host, api_key=api_key) + + # Create initial streaming array + arr = np.full(5, 0.0, dtype=np.float64) # Initial array with zeros + client.write_array(arr, key=node_name) + + logger.info(f"Created streaming node: {node_name}") + client.logout() + + +class WriterUser(HttpUser): + """User that writes streaming data to a Tiled node""" + + wait_time = between(0.1, 0.2) # Wait 0.1-0.2 seconds between writes + weight = int(os.getenv("WRITER_WEIGHT", 1)) + + def on_start(self): + """Initialize user state""" + self.node_name = self.environment.parsed_options.node_name + self.message_count = 0 + self.api_key = self.environment.parsed_options.api_key + + # Set authentication header + self.client.headers.update({"Authorization": f"Apikey {self.api_key}"}) + + @task(10) # Run 10x as often as cleanup + def write_data(self): + """Write streaming data to the node""" + # Create data with current timestamp as all values + current_time = time.time() + data = np.full(5, current_time, dtype=np.float64) + binary_data = data.tobytes() + + # Post data to the streaming endpoint + response = self.client.put( + f"/api/v1/array/full/{self.node_name}", + 
data=binary_data, + headers={"Content-Type": "application/octet-stream"}, + ) + + # Log status + if response.status_code == 200: + logger.debug(f"Wrote message {self.message_count} to node {self.node_name}") + self.message_count += 1 + else: + logger.error( + f"Failed to write message {self.message_count}: {response.status_code} - {response.text}" + ) + + @task(1) + def cleanup(self): + """Periodically cleanup the stream""" + if self.message_count > 50: + # Close the stream + response = self.client.delete(f"/api/v1/stream/close/{self.node_name}") + if response.status_code == 200: + logger.info(f"Closed stream for node {self.node_name}") + + # Reset message count (node persists for other users) + self.message_count = 0 + + +class StreamingUser(HttpUser): + """User that connects to websocket stream and measures latency""" + + wait_time = between(1, 2) + weight = int(os.getenv("STREAMING_WEIGHT", 1)) + + def on_start(self): + """Connect to the streaming endpoint""" + self.node_name = self.environment.parsed_options.node_name + self.api_key = self.environment.parsed_options.api_key + self.envelope_format = "msgpack" # Use msgpack for efficiency + self.ws = None + self.connected = False + + # Set up authentication for HTTP requests + self.client.headers.update({"Authorization": f"Apikey {self.api_key}"}) + + self._connect_websocket() + + def _connect_websocket(self): + """Connect to the websocket stream""" + try: + # Parse host to get websocket URL + parsed = urlparse(self.host) + ws_scheme = "wss" if parsed.scheme == "https" else "ws" + host = f"{ws_scheme}://{parsed.netloc}" + + ws_url = f"{host}/api/v1/stream/single/{self.node_name}?envelope_format={self.envelope_format}&start=0" + + # Create websocket connection + self.ws = websocket.WebSocketApp( + ws_url, + header=[ + f"Authorization: Apikey {self.api_key}" + ], # Proper Apikey format for websockets + on_open=self._on_open, + on_message=self._on_message, + on_error=self._on_error, + on_close=self._on_close, + ) + 
+ # Start websocket in background thread + self.ws_thread = threading.Thread(target=self.ws.run_forever) + self.ws_thread.daemon = True + self.ws_thread.start() + + # Wait a bit for connection + time.sleep(0.5) + + except Exception as e: + logger.error(f"Failed to connect websocket: {e}") + + def _on_open(self, ws): + """Websocket connection opened""" + self.connected = True + logger.info(f"WebSocket connected to {self.node_name}") + + def _on_message(self, ws, message): + """Process websocket messages and measure latency""" + try: + received_time = time.time() + + if isinstance(message, bytes): + data = msgpack.unpackb(message) + else: + data = json.loads(message) + + # Extract timestamp from the payload (first element of the array) + payload = data.get("payload") + if payload and len(payload) > 0: + # Convert bytes back to numpy array to get the timestamp + payload_array = np.frombuffer(payload, dtype=np.float64) + if len(payload_array) > 0: + write_time = payload_array[0] + latency_ms = (received_time - write_time) * 1000 + + logger.debug( + f"WS latency (sequence {data.get('sequence', 'N/A')}): {latency_ms:.1f}ms" + ) + + # Report to Locust + events.request.fire( + request_type="WS", + name="write_to_websocket_delivery", + response_time=latency_ms, + response_length=len(message), + exception=None, + ) + + except Exception as e: + logger.error(f"Error processing message: {e}") + events.request.fire( + request_type="WS", + name="write_to_websocket_delivery", + response_time=0, + response_length=0, + exception=e, + ) + + def _on_error(self, ws, error): + """Websocket error occurred""" + logger.error(f"WebSocket error: {error}") + self.connected = False + + def _on_close(self, ws, close_status_code, close_msg): + """Websocket connection closed""" + logger.info(f"WebSocket closed: {close_status_code} - {close_msg}") + self.connected = False + + @task + def keep_alive(self): + """Dummy task to keep the user active while listening for messages""" + if not 
self.connected and self.ws: + # Try to reconnect if disconnected + logger.info("Attempting to reconnect WebSocket...") + self._connect_websocket() + + def on_stop(self): + """Clean up websocket connection""" + if self.ws: + self.ws.close() + logger.info("WebSocket connection closed") diff --git a/pixi.toml b/pixi.toml index 678544e1f..fceeff68d 100644 --- a/pixi.toml +++ b/pixi.toml @@ -35,6 +35,7 @@ entrypoints = "*" rich = "*" stamina = "*" watchfiles = "*" +websockets = "*" [feature.compression.dependencies] # python-blosc2 = "*" # not available on Python < 3.11 @@ -98,6 +99,7 @@ prometheus_client = "*" python-dateutil = "*" python-jose = "*" python-multipart = "*" +redis-py = "*" sqlalchemy = ">=2" starlette = ">=0.38.0" uvicorn = "*" diff --git a/pyproject.toml b/pyproject.toml index dc4df63b6..83f839ebc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ all = [ "python-dateutil", "python-jose[cryptography]", "python-multipart", + "redis", "rich", "sparse >=0.15.5", "sqlalchemy[asyncio] >=2", @@ -126,6 +127,7 @@ client = [ "rich", "sparse >=0.15.5", "stamina", + "websockets", "watchfiles", "xarray", "zstandard", @@ -183,6 +185,7 @@ minimal-client = [ "rich", "stamina", "watchfiles", + "websockets", ] # These are the requirements needed for basic server functionality. 
minimal-server = [ @@ -203,6 +206,7 @@ minimal-server = [ "python-dateutil", "python-jose[cryptography]", "python-multipart", + "redis", "sqlalchemy[asyncio] >=2", "starlette >=0.38.0", "uvicorn[standard]", @@ -248,10 +252,12 @@ server = [ "python-jose[cryptography]", "python-multipart", "sparse >=0.15.5", + "redis", "sqlalchemy[asyncio] >=2", "starlette >=0.38.0", "tifffile", "uvicorn[standard]", + "websockets", "xarray", "zarr", "zstandard", diff --git a/tiled/_tests/conftest.py b/tiled/_tests/conftest.py index 6d206f7f2..0d3a5b4cb 100644 --- a/tiled/_tests/conftest.py +++ b/tiled/_tests/conftest.py @@ -293,3 +293,17 @@ def url_limit(request: pytest.FixtureRequest): yield # Then restore the original value. BaseClient.URL_CHARACTER_LIMIT = PREVIOUS_LIMIT + + +@pytest.fixture +def redis_uri(): + if uri := os.getenv("TILED_TEST_REDIS"): + import redis + + client = redis.from_url(uri, socket_timeout=10, socket_connect_timeout=30) + # Delete all keys from the current database before and after test. 
+ client.flushdb() + yield uri + client.flushdb() + else: + raise pytest.skip("No TILED_TEST_REDIS configured") diff --git a/tiled/_tests/test_catalog.py b/tiled/_tests/test_catalog.py index 7e8d829c3..200c7749c 100644 --- a/tiled/_tests/test_catalog.py +++ b/tiled/_tests/test_catalog.py @@ -277,6 +277,8 @@ async def test_write_dataframe_external_direct(a, tmpdir): @pytest.mark.asyncio async def test_write_array_internal_direct(a, tmpdir): + from ..media_type_registration import default_deserialization_registry + arr = numpy.ones((5, 3)) ad = ArrayAdapter.from_array(arr) structure = ad.structure() @@ -293,7 +295,12 @@ async def test_write_array_internal_direct(a, tmpdir): ], ) x = await a.lookup_adapter(["x"]) - await x.write(arr) + + media_type = "application/octet-stream" + body = arr.tobytes() + deserializer = default_deserialization_registry.dispatch("array", media_type) + await x.write(media_type, deserializer, x, body) + val = await x.read() assert numpy.array_equal(val, arr) diff --git a/tiled/_tests/test_subscription.py b/tiled/_tests/test_subscription.py new file mode 100644 index 000000000..efaa2da94 --- /dev/null +++ b/tiled/_tests/test_subscription.py @@ -0,0 +1,249 @@ +import sys +import threading + +import numpy as np +import pytest +from starlette.testclient import WebSocketDenialResponse + +from ..catalog import from_uri +from ..client import Context, from_context +from ..client.stream import Subscription +from ..server.app import build_app + +pytestmark = pytest.mark.skipif( + sys.platform == "win32", reason="Requires Redis service" +) + + +@pytest.fixture(scope="function") +def tiled_websocket_context(tmpdir, redis_uri): + """Fixture that provides a Tiled context with websocket support.""" + + tree = from_uri( + "sqlite:///:memory:", + writable_storage=[ + f"file://localhost{str(tmpdir / 'data')}", + f"duckdb:///{tmpdir / 'data.duckdb'}", + ], + readable_storage=None, + init_if_not_exists=True, + cache_settings={ + "uri": redis_uri, + "ttl": 60, + 
"socket_timeout": 10.0, + "socket_connect_timeout": 10.0, + }, + ) + + app = build_app( + tree, + authentication={"single_user_api_key": "secret"}, + ) + + with Context.from_app(app) as context: + yield context + + +def test_subscribe_immediately_after_creation_websockets(tiled_websocket_context): + context = tiled_websocket_context + client = from_context(context) + + # Create streaming array node using Tiled client + arr = np.arange(10) + streaming_node = client.write_array(arr, key="test_stream_immediate") + + # Set up subscription using the Subscription class + received = [] + received_event = threading.Event() + + def callback(subscription, data): + """Callback to collect received messages.""" + received.append(data) + if len(received) >= 3: + received_event.set() + + # Create subscription for the streaming node + subscription = Subscription( + context=context, + segments=["test_stream_immediate"], + ) + subscription.add_callback(callback) + + # Start the subscription + subscription.start() + + # Write updates using Tiled client + for i in range(1, 4): + new_arr = np.arange(10) + i + streaming_node.write(new_arr) + + # Wait for all messages to be received + assert received_event.wait(timeout=5.0), "Timeout waiting for messages" + + # Verify all updates received in order + assert len(received) == 3 + + # Check that we received messages with the expected data + for i, msg in enumerate(received): + assert "timestamp" in msg + assert "payload" in msg + assert msg["shape"] == [10] + + # Verify payload contains the expected array data + payload_array = np.frombuffer(msg["payload"], dtype=np.int64) + expected_array = np.arange(10) + (i + 1) + np.testing.assert_array_equal(payload_array, expected_array) + + # Clean up the subscription + subscription.stop() + + +def test_websocket_connection_to_non_existent_node_subscription( + tiled_websocket_context, +): + """Test subscription to non-existent node raises appropriate error.""" + context = tiled_websocket_context + + 
non_existent_node_id = "definitely_non_existent_websocket_node_99999999" + + # Create subscription for non-existent node + subscription = Subscription( + context=context, + segments=[non_existent_node_id], + ) + + # Attempting to start should raise WebSocketDenialResponse + with pytest.raises(WebSocketDenialResponse): + subscription.start() + + +def test_subscribe_after_first_update_subscription(tiled_websocket_context): + """Client that subscribes after first update sees only subsequent updates.""" + context = tiled_websocket_context + client = from_context(context) + + # Create streaming array node using Tiled client + arr = np.arange(10) + streaming_node = client.write_array(arr, key="test_stream_after_update") + + # Write first update before subscribing + first_update = np.arange(10) + 1 + streaming_node.write(first_update) + + # Set up subscription using the Subscription class (no start parameter = only new updates) + received = [] + received_event = threading.Event() + + def callback(subscription, data): + """Callback to collect received messages.""" + received.append(data) + if len(received) >= 2: + received_event.set() + + # Create subscription for the streaming node + subscription = Subscription( + context=context, + segments=["test_stream_after_update"], + ) + subscription.add_callback(callback) + + # Start the subscription + subscription.start() + + # Write more updates + for i in range(2, 4): + new_arr = np.arange(10) + i + streaming_node.write(new_arr) + + # Wait for messages to be received + assert received_event.wait(timeout=5.0), "Timeout waiting for messages" + + # Should only receive the 2 new updates (not the first one) + assert len(received) == 2 + + # Check that we received messages with the expected data + for i, msg in enumerate(received): + assert "timestamp" in msg + assert "payload" in msg + assert msg["shape"] == [10] + + # Verify payload contains the expected array data + payload_array = np.frombuffer(msg["payload"], dtype=np.int64) + 
expected_array = np.arange(10) + (i + 2) # i+2 because we start from update 2 + np.testing.assert_array_equal(payload_array, expected_array) + + # Clean up the subscription + subscription.stop() + + +def test_subscribe_after_first_update_from_beginning_subscription( + tiled_websocket_context, +): + """Client that subscribes after first update but requests from start=0 sees all updates.""" + context = tiled_websocket_context + client = from_context(context) + + # Use unique key to avoid interference from other tests + import uuid + + unique_key = f"test_stream_from_beginning_{uuid.uuid4().hex[:8]}" + + # Create streaming array node using Tiled client + arr = np.arange(10) + streaming_node = client.write_array(arr, key=unique_key) + + # Write first update before subscribing + first_update = np.arange(10) + 1 + streaming_node.write(first_update) + + # Set up subscription using the Subscription class with start=0 + received = [] + received_event = threading.Event() + + def callback(subscription, data): + """Callback to collect received messages.""" + received.append(data) + if len(received) >= 4: # initial + first update + 2 new updates + received_event.set() + + # Create subscription for the streaming node with start=0 + subscription = Subscription(context=context, segments=[unique_key], start=0) + subscription.add_callback(callback) + + # Start the subscription + subscription.start() + + # Write more updates + for i in range(2, 4): + new_arr = np.arange(10) + i + streaming_node.write(new_arr) + + # Wait for all messages to be received + assert received_event.wait(timeout=5.0), "Timeout waiting for messages" + + # Should receive: initial array + first update + 2 new updates = 4 total + assert len(received) == 4 + + # Check the messages in order + # First message: initial array creation + msg = received[0] + assert "timestamp" in msg + assert "payload" in msg + assert msg["shape"] == [10] + payload_array = np.frombuffer(msg["payload"], dtype=np.int64) + expected_array 
= np.arange(10) # Initial array + np.testing.assert_array_equal(payload_array, expected_array) + + # Remaining messages: updates 1, 2, 3 + for i, msg in enumerate(received[1:], 1): + assert "timestamp" in msg + assert "payload" in msg + assert msg["shape"] == [10] + + # Verify payload contains the expected array data + payload_array = np.frombuffer(msg["payload"], dtype=np.int64) + expected_array = np.arange(10) + i + np.testing.assert_array_equal(payload_array, expected_array) + + # Clean up the subscription + subscription.stop() diff --git a/tiled/_tests/test_websockets.py b/tiled/_tests/test_websockets.py new file mode 100644 index 000000000..073f1e544 --- /dev/null +++ b/tiled/_tests/test_websockets.py @@ -0,0 +1,301 @@ +import sys + +import msgpack +import numpy as np +import pytest + +from ..catalog import from_uri +from ..client import Context, from_context +from ..server.app import build_app + +pytestmark = pytest.mark.skipif( + sys.platform == "win32", reason="Requires Redis service" +) + + +@pytest.fixture +def tiled_websocket_context(tmpdir, redis_uri): + """Fixture that provides a Tiled context with websocket support.""" + tree = from_uri( + "sqlite:///:memory:", + writable_storage=[ + f"file://localhost{str(tmpdir / 'data')}", + f"duckdb:///{tmpdir / 'data.duckdb'}", + ], + readable_storage=None, + init_if_not_exists=True, + cache_settings={ + "uri": redis_uri, + "ttl": 60, + "socket_timeout": 10.0, + "socket_connect_timeout": 10.0, + }, + ) + + app = build_app( + tree, + authentication={"single_user_api_key": "secret"}, + ) + + with Context.from_app(app) as context: + yield context + + +def test_subscribe_immediately_after_creation_websockets(tiled_websocket_context): + context = tiled_websocket_context + client = from_context(context) + test_client = context.http_client + + # Create streaming array node using Tiled client + arr = np.arange(10) + streaming_node = client.write_array(arr, key="test_stream_immediate") + + # Connect WebSocket using 
TestClient with msgpack format and authorization + with test_client.websocket_connect( + "/api/v1/stream/single/test_stream_immediate?envelope_format=msgpack", + headers={"Authorization": "Apikey secret"}, + ) as websocket: + # Write updates using Tiled client + for i in range(1, 4): + new_arr = np.arange(10) + i + streaming_node.write(new_arr) + + # Receive all updates + received = [] + for _ in range(3): + msg_bytes = websocket.receive_bytes() + msg = msgpack.unpackb(msg_bytes) + received.append(msg) + + # Verify all updates received in order + assert len(received) == 3 + + # Check that we received messages with the expected data + for i, msg in enumerate(received): + assert "timestamp" in msg + assert "payload" in msg + assert msg["shape"] == [10] + + # Verify payload contains the expected array data + payload_array = np.frombuffer(msg["payload"], dtype=np.int64) + expected_array = np.arange(10) + (i + 1) + np.testing.assert_array_equal(payload_array, expected_array) + + +def test_websocket_connection_to_non_existent_node(tiled_websocket_context): + """Test websocket connection to non-existent node returns 404.""" + context = tiled_websocket_context + test_client = context.http_client + + non_existent_node_id = "definitely_non_existent_websocket_node_99999999" + + # Try to connect to websocket for non-existent node + # This should result in an HTTP 404 response during the handshake + response = test_client.get( + f"/api/v1/stream/single/{non_existent_node_id}", + headers={"Authorization": "Apikey secret"}, + ) + assert response.status_code == 404 + + +def test_subscribe_after_first_update_websockets(tiled_websocket_context): + """Client that subscribes after first update sees only subsequent updates.""" + context = tiled_websocket_context + client = from_context(context) + test_client = context.http_client + + # Create streaming array node using Tiled client + arr = np.arange(10) + streaming_node = client.write_array(arr, key="test_stream_after_update") + + # 
Write first update before subscribing + first_update = np.arange(10) + 1 + streaming_node.write(first_update) + + # Connect WebSocket after first update + with test_client.websocket_connect( + "/api/v1/stream/single/test_stream_after_update?envelope_format=msgpack", + headers={"Authorization": "Apikey secret"}, + ) as websocket: + # Write more updates + for i in range(2, 4): + new_arr = np.arange(10) + i + streaming_node.write(new_arr) + + # Should only receive the 2 new updates + received = [] + for _ in range(2): + msg_bytes = websocket.receive_bytes() + msg = msgpack.unpackb(msg_bytes) + received.append(msg) + + # Verify only new updates received + assert len(received) == 2 + + # Check that we received messages with the expected data + for i, msg in enumerate(received): + assert "timestamp" in msg + assert "payload" in msg + assert msg["shape"] == [10] + + # Verify payload contains the expected array data + payload_array = np.frombuffer(msg["payload"], dtype=np.int64) + expected_array = np.arange(10) + ( + i + 2 + ) # i+2 because we start from update 2 + np.testing.assert_array_equal(payload_array, expected_array) + + +def test_subscribe_after_first_update_from_beginning_websockets( + tiled_websocket_context, +): + """Client that subscribes after first update but requests from seq_num=0 sees all updates. + + Note: seq_num starts at 1 for the first data point. 
seq_num=0 means "start as far back + as you have" (similar to Bluesky social) + """ + context = tiled_websocket_context + client = from_context(context) + test_client = context.http_client + + # Create streaming array node using Tiled client + arr = np.arange(10) + streaming_node = client.write_array(arr, key="test_stream_from_beginning") + + # Write first update before subscribing + first_update = np.arange(10) + 1 + streaming_node.write(first_update) + + # Connect WebSocket requesting from beginning + with test_client.websocket_connect( + "/api/v1/stream/single/test_stream_from_beginning?envelope_format=msgpack&start=0", + headers={"Authorization": "Apikey secret"}, + ) as websocket: + # First, should receive the initial array creation + historical_msg_bytes = websocket.receive_bytes() + historical_msg = msgpack.unpackb(historical_msg_bytes) + assert "timestamp" in historical_msg + assert "payload" in historical_msg + assert historical_msg["shape"] == [10] + + # Verify historical payload (initial array creation - sequence 0) + historical_payload = np.frombuffer(historical_msg["payload"], dtype=np.int64) + expected_historical = np.arange(10) # Initial array + np.testing.assert_array_equal(historical_payload, expected_historical) + + # Next, should receive the first update (sequence 1) + first_update_bytes = websocket.receive_bytes() + first_update_msg = msgpack.unpackb(first_update_bytes) + first_update_payload = np.frombuffer( + first_update_msg["payload"], dtype=np.int64 + ) + expected_first_update = np.arange(10) + 1 + np.testing.assert_array_equal(first_update_payload, expected_first_update) + + # Write more updates + for i in range(2, 4): + new_arr = np.arange(10) + i + streaming_node.write(new_arr) + + # Receive the new updates + for i in range(2, 4): + msg_bytes = websocket.receive_bytes() + msg = msgpack.unpackb(msg_bytes) + assert "timestamp" in msg + assert "payload" in msg + assert msg["shape"] == [10] + + # Verify payload contains the expected array 
data + payload_array = np.frombuffer(msg["payload"], dtype=np.int64) + expected_array = np.arange(10) + i + np.testing.assert_array_equal(payload_array, expected_array) + + +def test_close_stream_success(tiled_websocket_context): + """Test successful close of an existing stream.""" + context = tiled_websocket_context + client = from_context(context) + test_client = context.http_client + + # Create a streaming array node + arr = np.arange(10) + streaming_node = client.write_array(arr, key="test_close_stream") + + # Upload some data + streaming_node.write(np.arange(10) + 1) + + # Add a small delay to ensure the stream is fully established + import time + + time.sleep(0.5) + + # Now close the stream + response = test_client.delete( + "/api/v1/stream/close/test_close_stream", + headers={"Authorization": "Apikey secret"}, + ) + assert response.status_code == 200 + + # Now close the stream again + response = test_client.delete( + "/api/v1/stream/close/test_close_stream", + headers={"Authorization": "Apikey secret"}, + ) + + # TODO: I think the test is correct and the server should be updated. 
+ assert response.status_code == 404 + + +def test_close_stream_not_found(tiled_websocket_context): + """Test close endpoint returns 404 for non-existent node.""" + context = tiled_websocket_context + test_client = context.http_client + + non_existent_node_id = "definitely_non_existent_node_99999999" + + response = test_client.delete( + f"/api/v1/stream/close/{non_existent_node_id}", + headers={"Authorization": "Apikey secret"}, + ) + assert response.status_code == 404 + + +def test_websocket_connection_wrong_api_key(tiled_websocket_context): + """Test websocket connection with wrong API key fails with 401.""" + from starlette.testclient import WebSocketDenialResponse + + context = tiled_websocket_context + client = from_context(context) + test_client = context.http_client + + # Create streaming array node using correct key + arr = np.arange(10) + client.write_array(arr, key="test_auth_websocket") + + # Try to connect to websocket with wrong API key + with pytest.raises(WebSocketDenialResponse) as exc_info: + with test_client.websocket_connect( + "/api/v1/stream/single/test_auth_websocket?envelope_format=msgpack", + headers={"Authorization": "Apikey wrong_key"}, + ): + pass + + assert exc_info.value.status_code == 401 + + +def test_close_stream_wrong_api_key(tiled_websocket_context): + """Test close endpoint returns 403 with wrong API key.""" + context = tiled_websocket_context + client = from_context(context) + test_client = context.http_client + + # Create streaming array node using correct key + arr = np.arange(10) + client.write_array(arr, key="test_auth_close") + + # Try to close stream with wrong API key + response = test_client.delete( + "/api/v1/stream/close/test_auth_close", + headers={"Authorization": "Apikey wrong_key"}, + ) + assert response.status_code == 401 diff --git a/tiled/catalog/adapter.py b/tiled/catalog/adapter.py index e6d685778..d29495cc5 100644 --- a/tiled/catalog/adapter.py +++ b/tiled/catalog/adapter.py @@ -1,3 +1,4 @@ +import asyncio 
import collections import copy import dataclasses @@ -11,13 +12,15 @@ import sys import uuid from contextlib import closing +from datetime import datetime from functools import partial, reduce from pathlib import Path from typing import Callable, Dict, List, Optional, Union, cast from urllib.parse import urlparse import anyio -from fastapi import HTTPException +import orjson +from fastapi import HTTPException, WebSocketDisconnect from sqlalchemy import ( and_, delete, @@ -169,6 +172,7 @@ def __init__( writable_storage=None, readable_storage=None, adapters_by_mimetype=None, + cache_settings=None, key_maker=lambda: str(uuid.uuid4()), storage_pool_size=None, storage_max_overflow=None, @@ -214,6 +218,7 @@ def __init__( adapters_by_mimetype, DEFAULT_ADAPTERS_BY_MIMETYPE ) self.adapters_by_mimetype = merged_adapters_by_mimetype + self.cache_settings = cache_settings def session(self): "Convenience method for constructing an AsyncSession context" @@ -226,6 +231,36 @@ async def execute(self, statement, explain=None): await db.commit() return result + async def startup(self): + if (self.engine.dialect.name == "sqlite") and ( + self.engine.url.database == ":memory:" + or self.engine.url.query.get("mode") == "memory" + ): + # Special-case for in-memory SQLite: Because it is transient we can + # skip over anything related to migrations. 
+ await initialize_database(self.engine) + else: + await check_catalog_database(self.engine) + + cache_client = None + cache_ttl = 0 + if self.cache_settings: + if self.cache_settings["uri"].startswith("redis"): + from redis import asyncio as redis + + socket_timeout = self.cache_settings.get("socket_timeout", 10.0) + socket_connect_timeout = self.cache_settings.get( + "socket_connect_timeout", 10.0 + ) + cache_client = redis.from_url( + self.cache_settings["uri"], + socket_timeout=socket_timeout, + socket_connect_timeout=socket_connect_timeout, + ) + cache_ttl = self.cache_settings.get("ttl", 3600) + self.cache_client = cache_client + self.cache_ttl = cache_ttl + class CatalogNodeAdapter: query_registry = QueryTranslationRegistry() @@ -275,15 +310,7 @@ def metadata(self): return self.node.metadata_ async def startup(self): - if (self.context.engine.dialect.name == "sqlite") and ( - self.context.engine.url.database == ":memory:" - or self.context.engine.url.query.get("mode") == "memory" - ): - # Special-case for in-memory SQLite: Because it is transient we can - # skip over anything related to migrations. - await initialize_database(self.context.engine) - else: - await check_catalog_database(self.context.engine) + await self.context.startup() async def create_mount(self, mount_path: list[str]): statement = node_from_segments(mount_path).with_only_columns(orm.Node.id) @@ -762,6 +789,38 @@ async def create_node( ) ) ).scalar() + if self.context.cache_client: + # Allocate a counter for the new node. + await self.context.cache_client.setnx(f"sequence:{node.id}", 0) + # Notify subscribers of the *parent* node about the new child. 
+ sequence = await self.context.cache_client.incr( + f"sequence:{self.node.id}" + ) + metadata = { + "sequence": sequence, + "timestamp": datetime.now().isoformat(), + "key": key, + "structure_family": structure_family, + "specs": [spec.model_dump() for spec in (specs or [])], + "metadata": metadata, + "data_sources": [d.model_dump() for d in data_sources], + } + + # Cache data in Redis with a TTL, and publish + # a notification about it. + pipeline = self.context.cache_client.pipeline() + pipeline.hset( + f"data:{self.node.id}:{sequence}", + mapping={ + "sequence": sequence, + "metadata": safe_json_dump(metadata), + }, + ) + pipeline.expire( + f"data:{self.node.id}:{sequence}", self.context.cache_ttl + ) + pipeline.publish(f"notify:{self.node.id}", sequence) + await pipeline.execute() return type(self)(self.context, refreshed_node) async def _put_asset(self, db: AsyncSession, asset): @@ -780,7 +839,7 @@ async def _put_asset(self, db: AsyncSession, asset): return asset_id - async def put_data_source(self, data_source): + async def put_data_source(self, data_source, patch): # Obtain and hash the canonical (RFC 8785) representation of # the JSON structure. structure = _prepare_structure( @@ -830,6 +889,28 @@ async def put_data_source(self, data_source): db.add(assoc_orm) await db.commit() + if self.context.cache_client: + sequence = await self.context.cache_client.incr(f"sequence:{self.node.id}") + metadata = { + "sequence": sequence, + "timestamp": datetime.now().isoformat(), + "data_source": data_source.dict(), + "patch": patch.dict() if patch else None, + } + + # Cache data in Redis with a TTL, and publish + # a notification about it. 
+ pipeline = self.context.cache_client.pipeline() + pipeline.hset( + f"data:{self.node.id}:{sequence}", + mapping={ + "sequence": sequence, + "metadata": orjson.dumps(metadata), + }, + ) + pipeline.expire(f"data:{self.node.id}:{sequence}", self.context.cache_ttl) + pipeline.publish(f"notify:{self.node.id}", sequence) + await pipeline.execute() async def revisions(self, offset, limit): async with self.context.session() as db: @@ -1012,6 +1093,117 @@ async def replace_metadata( ) await db.commit() + async def close_stream(self): + # Check the node status. + # ttl returns -2 if the key does not exist. + # ttl returns -1 if the key exists but has no associated expire. + # ttl greater than 0 means that it is marked to expire. + node_ttl = await self.context.cache_client.ttl(f"sequence:{self.node.id}") + if node_ttl > 0: + raise HTTPException( + status_code=HTTP_404_NOT_FOUND, + detail=f"Stream for node {self.node.id} is already closed.", + ) + if node_ttl == -2: + raise HTTPException( + status_code=HTTP_404_NOT_FOUND, + detail=f"Node {self.node.id} not found.", + ) + + metadata = { + "timestamp": datetime.now().isoformat(), + "end_of_stream": True, + } + # Increment the counter for this node. + sequence = await self.context.cache_client.incr(f"sequence:{self.node.id}") + + # Cache data in Redis with a TTL, and publish + # a notification about it. 
+ pipeline = self.context.cache_client.pipeline() + pipeline.hset( + f"data:{self.node.id}:{sequence}", + mapping={ + "sequence": sequence, + "metadata": orjson.dumps(metadata), + }, + ) + pipeline.expire(f"data:{self.node.id}:{sequence}", self.context.cache_ttl) + pipeline.expire(f"sequence:{self.node.id}", self.context.cache_ttl) + pipeline.publish(f"notify:{self.node.id}", sequence) + await pipeline.execute() + + def make_ws_handler(self, websocket, formatter, uri): + async def handler(sequence: Optional[int] = None): + await websocket.accept() + end_stream = asyncio.Event() + cache_client = self.context.cache_client + + async def stream_data(sequence): + key = f"data:{self.node.id}:{sequence}" + payload_bytes, metadata_bytes = await cache_client.hmget( + key, "payload", "metadata" + ) + if metadata_bytes is None: + # This means that redis ttl has expired for this sequence + return + metadata = orjson.loads(metadata_bytes) + if metadata.get("end_of_stream"): + # This means that the stream is closed by the producer + end_stream.set() + return + metadata["uri"] = uri + if metadata.get("patch"): + s = ",".join( + f"{offset}:{offset+shape}" + for offset, shape in zip( + metadata["patch"]["offset"], metadata["patch"]["shape"] + ) + ) + metadata["uri"] = f"{uri}?slice={s}" + await formatter(websocket, metadata, payload_bytes) + + # Setup buffer + stream_buffer = asyncio.Queue() + + async def buffer_live_events(): + pubsub = cache_client.pubsub() + await pubsub.subscribe(f"notify:{self.node.id}") + try: + async for message in pubsub.listen(): + if message.get("type") == "message": + try: + live_seq = int(message["data"]) + await stream_buffer.put(live_seq) + except Exception as e: + print(f"Error parsing live message: {e}") + except Exception as e: + print(f"Live subscription error: {e}") + finally: + await pubsub.unsubscribe(f"notify:{self.node.id}") + await pubsub.aclose() + + live_task = asyncio.create_task(buffer_live_events()) + + if sequence is not None: + 
current_seq = await cache_client.get(f"sequence:{self.node.id}") + current_seq = int(current_seq) if current_seq is not None else 0 + print("Replaying old data...") + for s in range(sequence, current_seq + 1): + await stream_data(s) + # New data + try: + while not end_stream.is_set(): + live_seq = await stream_buffer.get() + await stream_data(live_seq) + else: + await websocket.close(code=1000, reason="Producer ended stream") + except WebSocketDisconnect: + print(f"Client disconnected from node {self.node.id}") + finally: + live_task.cancel() + + return handler + class CatalogContainerAdapter(CatalogNodeAdapter): async def keys_range(self, offset, limit): @@ -1088,19 +1280,70 @@ async def read_block(self, *args, **kwargs): (await self.get_adapter()).read_block, *args, **kwargs ) - async def write(self, *args, **kwargs): - return await ensure_awaitable((await self.get_adapter()).write, *args, **kwargs) + async def _stream(self, media_type, entry, body, shape, block=None, offset=None): + sequence = await self.context.cache_client.incr(f"sequence:{self.node.id}") + metadata = { + "timestamp": datetime.now().isoformat(), + "content-type": media_type, + "shape": shape, + "offset": offset, + "block": block, + } + + pipeline = self.context.cache_client.pipeline() + pipeline.hset( + f"data:{self.node.id}:{sequence}", + mapping={ + "sequence": sequence, + "metadata": orjson.dumps(metadata), + "payload": body, # raw user input + }, + ) + pipeline.expire(f"data:{self.node.id}:{sequence}", self.context.cache_ttl) + pipeline.publish(f"notify:{self.node.id}", sequence) + await pipeline.execute() + + async def write(self, media_type, deserializer, entry, body): + shape = entry.structure().shape + if self.context.cache_client: + await self._stream(media_type, entry, body, shape) + if entry.structure_family == "array": + dtype = entry.structure().data_type.to_numpy_dtype() + data = await ensure_awaitable(deserializer, body, dtype, shape) + elif entry.structure_family == "sparse": + 
data = await ensure_awaitable(deserializer, body) + else: + raise NotImplementedError(entry.structure_family) + return await ensure_awaitable((await self.get_adapter()).write, data) - async def write_block(self, *args, **kwargs): + async def write_block(self, block, media_type, deserializer, entry, body): + from tiled.adapters.array import slice_and_shape_from_block_and_chunks + + _, shape = slice_and_shape_from_block_and_chunks( + block, entry.structure().chunks + ) + if self.context.cache_client: + await self._stream(media_type, entry, body, shape, block=block) + if entry.structure_family == "array": + dtype = entry.structure().data_type.to_numpy_dtype() + data = await ensure_awaitable(deserializer, body, dtype, shape) + elif entry.structure_family == "sparse": + data = await ensure_awaitable(deserializer, body) + else: + raise NotImplementedError(entry.structure_family) return await ensure_awaitable( - (await self.get_adapter()).write_block, *args, **kwargs + (await self.get_adapter()).write_block, data, block ) - async def patch(self, *args, **kwargs): + async def patch(self, shape, offset, extend, media_type, deserializer, entry, body): + if self.context.cache_client: + await self._stream(media_type, entry, body, shape, offset=offset) + dtype = entry.structure().data_type.to_numpy_dtype() + data = await ensure_awaitable(deserializer, body, dtype, shape) # assumes a single DataSource (currently only supporting zarr) async with self.context.session() as db: new_shape_and_chunks = await ensure_awaitable( - (await self.get_adapter()).patch, *args, **kwargs + (await self.get_adapter()).patch, data, offset, extend ) node = await db.get(orm.Node, self.node.id) if len(node.data_sources) != 1: @@ -1490,6 +1733,7 @@ def in_memory( echo=DEFAULT_ECHO, adapters_by_mimetype=None, top_level_access_blob=None, + cache_settings=None, ): if not named_memory: uri = "sqlite:///:memory:" @@ -1507,6 +1751,7 @@ def in_memory( echo=echo, adapters_by_mimetype=adapters_by_mimetype, 
top_level_access_blob=top_level_access_blob, + cache_settings=cache_settings, ) @@ -1522,6 +1767,7 @@ def from_uri( adapters_by_mimetype=None, top_level_access_blob=None, mount_node: Optional[Union[str, List[str]]] = None, + cache_settings=None, catalog_pool_size=None, storage_pool_size=None, catalog_max_overflow=None, @@ -1592,10 +1838,12 @@ def from_uri( writable_storage, readable_storage, adapters_by_mimetype, + cache_settings, storage_pool_size=storage_pool_size, storage_max_overflow=storage_max_overflow, ) adapter = CatalogContainerAdapter(context, node, mount_path=mount_path) + return adapter diff --git a/tiled/client/base.py b/tiled/client/base.py index 9ff349ff7..8214f4202 100644 --- a/tiled/client/base.py +++ b/tiled/client/base.py @@ -881,5 +881,12 @@ def delete(self, recursive: bool = False, external_only: bool = True) -> None: ) ) + def close_stream(self): + "Declare the end of a stream of writes to this node." + endpoint = self.uri.replace("/metadata/", "/stream/close/", 1) + for attempt in retry_context(): + with attempt: + handle_error(self.context.http_client.delete(endpoint)) + def __dask_tokenize__(self): return (type(self), self.uri) diff --git a/tiled/client/stream.py b/tiled/client/stream.py new file mode 100644 index 000000000..2b7f2001b --- /dev/null +++ b/tiled/client/stream.py @@ -0,0 +1,201 @@ +import inspect +import threading +import weakref +from typing import Callable, List + +import anyio +import httpx +import msgpack +from starlette.testclient import TestClient +from websockets.sync.client import connect + +from tiled.client.context import Context + +Callback = Callable[["Subscription", dict], None] +"A Callback will be called with the Subscription calling it and a dict with the update." 
+ + +class TestClientWebsocketWrapper: + """Wrapper for TestClient websockets.""" + + def __init__(self, http_client, uri: httpx.URL): + self._http_client = http_client + self._uri = uri + self._websocket = None + + def connect(self, api_key: str): + """Connect to the websocket.""" + query_string = self._uri.query.decode() if self._uri.query else "" + path = self._uri.path + ("?" + query_string if query_string else "") + self._websocket = self._http_client.websocket_connect( + path, headers={"Authorization": f"Apikey {api_key}"} + ) + self._websocket.__enter__() + + def recv(self, timeout=None): + """Receive data from websocket with consistent interface.""" + return self._websocket.receive_bytes() + + def close(self): + """Close websocket connection.""" + self._websocket.__exit__(None, None, None) + + +class RegularWebsocketWrapper: + """Wrapper for regular websockets.""" + + def __init__(self, http_client, uri: httpx.URL): + self._http_client = http_client + self._uri = uri + self._websocket = None + + def connect(self, api_key: str): + """Connect to the websocket.""" + self._websocket = connect( + str(self._uri), + additional_headers={"Authorization": f"Apikey {api_key}"}, + ) + + def recv(self, timeout=None): + """Receive data from websocket with consistent interface.""" + return self._websocket.recv(timeout=timeout) + + def close(self): + """Close websocket connection.""" + self._websocket.close() + + +class Subscription: + """ + Subscribe to streaming updates from a node. + + Parameters + ---------- + context : tiled.client.Context + Provides connection to Tiled server + segments : list[str] + Path to node of interest, given as a list of path segments + start : int, optional + By default, the stream begins from the most recent update. Use this parameter + to replay from some earlier update. 
Use 1 to start from the first item, 0 + to start from as far back as available (which may be later than the first item), + or any positive integer to start from a specific point in the sequence. + """ + + def __init__(self, context: Context, segments: List[str] = None, start: int = None): + segments = segments or ["/"] + self._context = context + self._segments = segments + params = {"envelope_format": "msgpack"} + if start is not None: + params["start"] = start + scheme = "wss" if context.api_uri.scheme == "https" else "ws" + path = "stream/single" + "/".join(f"/{segment}" for segment in segments) + self._uri = httpx.URL( + str(context.api_uri.copy_with(scheme=scheme)) + path, + params=params, + ) + name = f"tiled-subscription-{self._uri}" + self._thread = threading.Thread(target=self._receive, daemon=True, name=name) + self._callbacks = set() + self._close_event = threading.Event() + if isinstance(context.http_client, TestClient): + self._websocket = TestClientWebsocketWrapper(context.http_client, self._uri) + else: + self._websocket = RegularWebsocketWrapper(context.http_client, self._uri) + + @property + def context(self) -> Context: + return self._context + + @property + def segments(self) -> List[str]: + return self._segments + + def add_callback(self, callback: Callback) -> None: + """ + Register a callback to be run when the Subscription receives an update. + + The callback registry only holds a weak reference to the callback. If + no hard references are held elsewhere in the program, the callback will + be silently removed. + + Examples + -------- + + Simply subscribe the print function. + + >>> sub.add_callback(print) + + Subscribe a custom function. + + >>> def f(sub, data): + ... + + >>> sub.add_callback(f) + """ + + def cleanup(ref: weakref.ref) -> None: + # When an object is garbage collected, remove its entry + # from the set of callbacks. 
+ self._callbacks.remove(ref) + + if inspect.ismethod(callback): + # This holds the reference to the method until the object it is + # bound to is garbage collected. + ref = weakref.WeakMethod(callback, cleanup) + else: + ref = weakref.ref(callback, cleanup) + self._callbacks.add(ref) + + def remove_callback(self, callback: Callback) -> None: + """ + Unregister a callback. + """ + self._callbacks.remove(callback) + + def _receive(self) -> None: + "This method is executed on self._thread." + TIMEOUT = 0.1 # seconds + while not self._close_event.is_set(): + try: + data_bytes = self._websocket.recv(timeout=TIMEOUT) + except (TimeoutError, anyio.EndOfStream): + continue + data = msgpack.unpackb(data_bytes) + for ref in self._callbacks: + callback = ref() + if callback is not None: + callback(self, data) + + def start(self) -> None: + "Connect to the websocket and launch a thread to receive and process updates." + if self._close_event.is_set(): + raise RuntimeError("Cannot be restarted once stopped.") + API_KEY_LIFETIME = 30 # seconds + needs_api_key = self.context.server_info.authentication.providers + if needs_api_key: + # Request a short-lived API key to use for authenticating the WS connection. + key_info = self.context.create_api_key( + expires_in=API_KEY_LIFETIME, note="websocket" + ) + api_key = key_info["secret"] + else: + # Use single-user API key. + api_key = self.context.api_key + + # Connect using the websocket wrapper + self._websocket.connect(api_key) + + if needs_api_key: + # The connection is made, so we no longer need the API key. + # TODO: Implement single-use API keys so that revoking is not + # necessary. + self.context.revoke_api_key(key_info["first_eight"]) + self._thread.start() + + def stop(self) -> None: + "Close the websocket connection." 
+ self._close_event.set() + self._websocket.close() + self._thread.join() diff --git a/tiled/commandline/_serve.py b/tiled/commandline/_serve.py index 21f0b9768..825d4b051 100644 --- a/tiled/commandline/_serve.py +++ b/tiled/commandline/_serve.py @@ -329,6 +329,12 @@ def serve_catalog( "By default, a random key is generated at startup and printed." ), ), + cache_uri: Optional[str] = typer.Option( + None, "--cache", help=("Provide cache URI") + ), + cache_ttl: Optional[int] = typer.Option( + None, "--cache-ttl", help=("Provide cache ttl") + ), host: str = typer.Option( "127.0.0.1", help=( @@ -460,11 +466,18 @@ def serve_catalog( err=True, ) + cache_settings = {} + if cache_uri: + cache_settings["uri"] = cache_uri + if cache_ttl: + cache_settings["ttl"] = cache_ttl + tree = from_uri( database, writable_storage=write, readable_storage=read, init_if_not_exists=init, + cache_settings=cache_settings, ) web_app = build_app( tree, diff --git a/tiled/config.py b/tiled/config.py index 5a61da346..d68825238 100644 --- a/tiled/config.py +++ b/tiled/config.py @@ -16,7 +16,7 @@ import jsonschema from .adapters.mapping import MapAdapter -from .catalog import from_uri +from .catalog import from_uri, in_memory from .media_type_registration import ( default_compression_registry, default_deserialization_registry, @@ -147,6 +147,8 @@ def construct_build_app_kwargs( } } args.update(from_server_settings) + if (obj is from_uri) or (obj is in_memory): + args.update({"cache_settings": config.get("streaming_cache")}) tree = obj(**args) else: # Interpret obj as a tree *instance*. 
@@ -237,6 +239,7 @@ def merge(configs: dict[Path, dict[str, Any]]) -> dict[str, Any]: media_types = defaultdict(dict) specs = [] reject_undeclared_specs_source = None + streaming_cache_source = None file_extensions = {} paths = {} # map each item's path to config file that specified it @@ -314,6 +317,15 @@ def merge(configs: dict[Path, dict[str, Any]]) -> dict[str, Any]: ) reject_undeclared_specs_source = filepath merged["reject_undeclared_specs"] = config["reject_undeclared_specs"] + if "streaming_cache" in config: + if "streaming_cache" in merged: + raise ConfigError( + "'streaming_cache' can only be specified in one file. " + f"It was found in both {streaming_cache_source} and " + f"{filepath}" + ) + streaming_cache_source = filepath + merged["streaming_cache"] = config["streaming_cache"] for item in config.get("trees", []): if item["path"] in paths: msg = "A given path may be only be specified once." diff --git a/tiled/config_schemas/service_configuration.yml b/tiled/config_schemas/service_configuration.yml index 99d9ddd72..5246dbb7a 100644 --- a/tiled/config_schemas/service_configuration.yml +++ b/tiled/config_schemas/service_configuration.yml @@ -73,6 +73,22 @@ properties: args: uri: "catalog.db" ``` + streaming_cache: + description: | + Streaming cache database + type: object + additionalProperties: false + properties: + uri: + type: string + ttl: + type: number + socket_timeout: + type: number + socket_connect_timeout: + type: number + required: + - uri media_types: type: object additionalProperties: true diff --git a/tiled/server/app.py b/tiled/server/app.py index 0a4099688..e5f63db74 100644 --- a/tiled/server/app.py +++ b/tiled/server/app.py @@ -17,7 +17,7 @@ import packaging.version import yaml from asgi_correlation_id import CorrelationIdMiddleware, correlation_id -from fastapi import Depends, FastAPI, HTTPException, Request, Response +from fastapi import FastAPI, HTTPException, Request, Response from fastapi.exception_handlers import 
http_exception_handler from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi @@ -37,7 +37,6 @@ ) from tiled.query_registration import QueryRegistry, default_query_registry -from tiled.server.authentication import move_api_key from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator from ..catalog.adapter import WouldDeleteData @@ -234,7 +233,7 @@ async def lifespan(app: FastAPI): yield await shutdown_event() - app = FastAPI(lifespan=lifespan, dependencies=[Depends(move_api_key)]) + app = FastAPI(lifespan=lifespan) # Healthcheck for deployment to containerized systems, needs to preempt other responses. # Standardized for Kubernetes, but also used by other systems. diff --git a/tiled/server/authentication.py b/tiled/server/authentication.py index 696de2a95..3e8b0e734 100644 --- a/tiled/server/authentication.py +++ b/tiled/server/authentication.py @@ -4,17 +4,19 @@ import warnings from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Any, Optional, Sequence, Set +from typing import Annotated, Any, Optional, Sequence, Set from fastapi import ( APIRouter, Depends, Form, + Header, HTTPException, Query, Request, Response, Security, + WebSocket, ) from fastapi.security import ( OAuth2PasswordBearer, @@ -264,7 +266,45 @@ async def get_current_access_tags( return None +def get_api_key_websocket( + authorization: Annotated[Optional[str], Header()] = None, +): + if authorization is None: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="An API key must be passed in the Authorization header", + ) + scheme, api_key = get_authorization_scheme_param(authorization) + if scheme.lower() != "apikey": + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="Authorization header must be formatted like 'Apikey SECRET'", + ) + return api_key + + +async def get_current_access_tags_websocket( + websocket: WebSocket, + api_key: Optional[str] = 
Depends(get_api_key_websocket), + db: Optional[AsyncSession] = Depends(get_database_session), +) -> Optional[Set[str]]: + if api_key is not None: + return await get_access_tags_from_api_key( + api_key, websocket.app.state.authenticated, db + ) + else: + # Limits on access tags only available via API key auth + return None + + async def move_api_key(request: Request, api_key: Optional[str] = Depends(get_api_key)): + """ + Move API key from query parameter to cookie. + + When a URL with an API key in the query parameter is opened in a browser, + the API key is set as a cookie so that subsequent requests from the browser + are authenticated. (This approach was inspired by Jupyter notebook.) + """ if ("api_key" in request.query_params) and ( request.cookies.get(API_KEY_COOKIE_NAME) != api_key ): @@ -327,6 +367,20 @@ async def get_current_scopes( return PUBLIC_SCOPES if settings.allow_anonymous_access else NO_SCOPES +async def get_current_scopes_websocket( + websocket: WebSocket, + api_key: Optional[str] = Depends(get_api_key_websocket), + settings: Settings = Depends(get_settings), + db: Optional[AsyncSession] = Depends(get_database_session), +) -> set[str]: + if api_key is not None: + return await get_scopes_from_api_key( + api_key, settings, websocket.app.state.authenticated, db + ) + else: + return PUBLIC_SCOPES if settings.allow_anonymous_access else NO_SCOPES + + async def check_scopes( request: Request, security_scopes: SecurityScopes, @@ -344,6 +398,67 @@ async def check_scopes( ) +async def get_current_principal_from_api_key( + api_key: str, authenticated: bool, db: AsyncSession, settings: Settings +): + if authenticated: + # Tiled is in a multi-user configuration with authentication providers. + # We store the hashed value of the API key secret. + # By comparing hashes we protect against timing attacks. + # By storing only the hash of the (high-entropy) secret + # we reduce the value of that an attacker can extracted from a + # stolen database backup. 
+ try: + secret = bytes.fromhex(api_key) + except Exception: + # Not valid hex, therefore not a valid API key + return None + + api_key_orm = await lookup_valid_api_key(db, secret) + if api_key_orm is not None: + principal = api_key_orm.principal + principal_scopes = set().union(*[role.scopes for role in principal.roles]) + # This intersection addresses the case where the Principal has + # lost a scope that they had when this key was created. + scopes = set(api_key_orm.scopes).intersection( + principal_scopes | {"inherit"} + ) + if "inherit" in scopes: + # The scope "inherit" is a metascope that confers all the + # scopes for the Principal associated with this API, + # resolved at access time. + scopes.update(principal_scopes) + api_key_orm.latest_activity = utcnow() + await db.commit() + return principal + else: + return None + else: + # Tiled is in a "single user" mode with only one API key. + if secrets.compare_digest(api_key, settings.single_user_api_key): + # Valid single-user API key - return None to indicate valid single user + return None + else: + # Invalid single-user API key - raise exception directly + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, detail="Invalid API key" + ) + + +async def get_current_principal_websocket( + websocket: WebSocket, + api_key: str = Depends(get_api_key_websocket), + settings: Settings = Depends(get_settings), + db: Optional[AsyncSession] = Depends(get_database_session), +): + principal = await get_current_principal_from_api_key( + api_key, websocket.app.state.authenticated, db, settings + ) + if principal is None and websocket.app.state.authenticated: + raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid API key") + return principal + + async def get_current_principal( request: Request, security_scopes: SecurityScopes, @@ -351,8 +466,7 @@ async def get_current_principal( api_key: str = Depends(get_api_key), settings: Settings = Depends(get_settings), db: Optional[AsyncSession] = 
Depends(get_database_session), - # TODO: https://github.com/bluesky/tiled/issues/923 - # Remove non-Principal return types + _=Depends(move_api_key), ) -> Optional[schemas.Principal]: """ Get current Principal from: @@ -368,55 +482,18 @@ async def get_current_principal( """ if api_key is not None: - if request.app.state.authenticated: - # Tiled is in a multi-user configuration with authentication providers. - # We store the hashed value of the API key secret. - # By comparing hashes we protect against timing attacks. - # By storing only the hash of the (high-entropy) secret - # we reduce the value of that an attacker can extracted from a - # stolen database backup. - try: - secret = bytes.fromhex(api_key) - except Exception: - # Not valid hex, therefore not a valid API key - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid API key", - headers=headers_for_401(request, security_scopes), - ) - api_key_orm = await lookup_valid_api_key(db, secret) - if api_key_orm is not None: - principal = api_key_orm.principal - principal_scopes = set().union( - *[role.scopes for role in principal.roles] - ) - # This intersection addresses the case where the Principal has - # lost a scope that they had when this key was created. - scopes = set(api_key_orm.scopes).intersection( - principal_scopes | {"inherit"} - ) - if "inherit" in scopes: - # The scope "inherit" is a metascope that confers all the - # scopes for the Principal associated with this API, - # resolved at access time. - scopes.update(principal_scopes) - api_key_orm.latest_activity = utcnow() - await db.commit() - else: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid API key", - headers=headers_for_401(request, security_scopes), - ) - else: - # Tiled is in a "single user" mode with only one API key. 
- principal = None - if not secrets.compare_digest(api_key, settings.single_user_api_key): - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid API key", - headers=headers_for_401(request, security_scopes), - ) + principal = await get_current_principal_from_api_key( + api_key, + request.app.state.authenticated, + db, + settings, + ) + if principal is None and request.app.state.authenticated: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + headers=headers_for_401(request, security_scopes), + ) elif decoded_access_token is not None: principal = schemas.Principal( uuid=uuid_module.UUID(hex=decoded_access_token["sub"]), @@ -594,9 +671,9 @@ async def auth_code_route( db: Optional[AsyncSession] = Depends(get_database_session), ): request.state.endpoint = "auth" - user_session_state: UserSessionState | None = await authenticator.authenticate( - request - ) + user_session_state: Optional[ + UserSessionState + ] = await authenticator.authenticate(request) if not user_session_state: raise HTTPException( status_code=HTTP_401_UNAUTHORIZED, detail="Authentication failure" @@ -689,9 +766,9 @@ async def device_code_user_code_submit_route( }, status_code=HTTP_401_UNAUTHORIZED, ) - user_session_state: UserSessionState | None = await authenticator.authenticate( - request - ) + user_session_state: Optional[ + UserSessionState + ] = await authenticator.authenticate(request) if not user_session_state: return templates.TemplateResponse( request, @@ -773,7 +850,9 @@ async def handle_credentials_route( db: Optional[AsyncSession] = Depends(get_database_session), ): request.state.endpoint = "auth" - user_session_state: UserSessionState | None = await authenticator.authenticate( + user_session_state: Optional[ + UserSessionState + ] = await authenticator.authenticate( username=form_data.username, password=form_data.password ) if not user_session_state or not user_session_state.user_name: diff --git a/tiled/server/core.py 
b/tiled/server/core.py index eb0680fea..52486545c 100644 --- a/tiled/server/core.py +++ b/tiled/server/core.py @@ -10,13 +10,13 @@ from collections import defaultdict from datetime import datetime, timedelta, timezone from hashlib import md5 -from typing import Any +from typing import Any, Optional import anyio import dateutil.tz import jmespath import msgpack -from fastapi import HTTPException, Response +from fastapi import HTTPException, Response, WebSocket from starlette.responses import JSONResponse, StreamingResponse from starlette.status import HTTP_200_OK, HTTP_304_NOT_MODIFIED, HTTP_400_BAD_REQUEST @@ -744,6 +744,57 @@ def json_or_msgpack( ) +def get_websocket_envelope_formatter( + envelope_format: schemas.EnvelopeFormat, entry, deserialization_registry +): + if envelope_format == "msgpack": + + async def stream_msgpack( + websocket: WebSocket, + metadata: dict, + payload_bytes: Optional[bytes], + ): + if payload_bytes is not None: + metadata["payload"] = payload_bytes + data = msgpack.packb(metadata) + await websocket.send_bytes(data) + + return stream_msgpack + + elif envelope_format == "json": + + async def stream_json( + websocket: WebSocket, + metadata: dict, + payload_bytes: Optional[bytes], + ): + if payload_bytes is not None: + media_type = metadata.get("content-type", "application/octet-stream") + if media_type == "application/json": + # nothing to do, the payload is already JSON + payload_decoded = payload_bytes + else: + # Transcode the payload to JSON.
+ metadata["content-type"] = "application/json" + structure_family = ( + StructureFamily.array + ) # TODO: generalize beyond array + structure = entry.structure() + deserializer = deserialization_registry.dispatch( + structure_family, media_type + ) + payload_decoded = deserializer( + payload_bytes, + structure.data_type.to_numpy_dtype(), + metadata.get("shape"), + ) + metadata["payload"] = payload_decoded + data = safe_json_dump(metadata) + await websocket.send_text(data) + + return stream_json + + class UnsupportedMediaTypes(Exception): pass diff --git a/tiled/server/router.py b/tiled/server/router.py index f76c85156..65a2c66b2 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -1,3 +1,4 @@ +import collections import dataclasses import inspect import os @@ -11,10 +12,21 @@ import anyio import packaging -from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request, Security +import pydantic_settings +from fastapi import ( + APIRouter, + Body, + Depends, + HTTPException, + Query, + Request, + Security, + WebSocket, +) from jmespath.exceptions import JMESPathError from json_merge_patch import merge as apply_merge_patch from jsonpatch import apply_patch as apply_json_patch +from starlette.requests import URL from starlette.status import ( HTTP_200_OK, HTTP_206_PARTIAL_CONTENT, @@ -45,8 +57,11 @@ from .authentication import ( check_scopes, get_current_access_tags, + get_current_access_tags_websocket, get_current_principal, + get_current_principal_websocket, get_current_scopes, + get_current_scopes_websocket, get_session_state, ) from .core import ( @@ -61,6 +76,7 @@ construct_entries_response, construct_resource, construct_revisions_response, + get_websocket_envelope_formatter, json_or_msgpack, resolve_media_type, ) @@ -75,7 +91,12 @@ from .file_response_with_range import FileResponseWithRange from .links import links_for_node from .settings import Settings, get_settings -from .utils import filter_for_access, get_base_url, record_timing 
+from .utils import ( + filter_for_access, + get_base_url, + get_base_url_websocket, + record_timing, +) T = TypeVar("T") @@ -631,6 +652,76 @@ async def array_full( except UnsupportedMediaTypes as err: raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) + @router.delete("/stream/close/{path:path}") + async def close_stream( + request: Request, + path: str, + principal: Optional[schemas.Principal] = Depends(get_current_principal), + root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), + session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), + authn_scopes: Scopes = Depends(get_current_scopes), + _=Security(check_scopes, scopes=["write:data"]), + ): + entry = await get_entry( + path, + ["write:data"], + principal, + authn_access_tags, + authn_scopes, + root_tree, + session_state, + request.state.metrics, + {StructureFamily.array, StructureFamily.sparse}, + getattr(request.app.state, "access_policy", None), + ) + await entry.close_stream() + + @router.websocket("/stream/single/{path:path}") + async def websocket_endpoint( + websocket: WebSocket, + path: str, + envelope_format: schemas.EnvelopeFormat = schemas.EnvelopeFormat.json, + start: Optional[int] = None, + principal: Optional[schemas.Principal] = Depends( + get_current_principal_websocket + ), + authn_access_tags: Optional[Set[str]] = Depends( + get_current_access_tags_websocket + ), + authn_scopes: Scopes = Depends(get_current_scopes_websocket), + ): + root_tree = websocket.app.state.root_tree + websocket.state.metrics = collections.defaultdict( + lambda: collections.defaultdict(lambda: 0) + ) + entry = await get_entry( + path, + ["read:data", "read:metadata"], + principal, + authn_access_tags, + authn_scopes, + root_tree, + {}, # session_state, + websocket.state.metrics, + { + StructureFamily.array, + StructureFamily.container, + StructureFamily.sparse, + }, + getattr(websocket.app.state, 
"access_policy", None), + ) + formatter = get_websocket_envelope_formatter( + envelope_format, entry, deserialization_registry + ) + base_websocket_url = URL(get_base_url_websocket(websocket)) + scheme = "https" if base_websocket_url.scheme == "wss" else "http" + path_parts = [segment for segment in path.split("/") if segment] + path_str = "/".join(path_parts) + uri = f"{base_websocket_url.replace(scheme=scheme)}/array/full/{path_str}" + handler = entry.make_ws_handler(websocket, formatter, uri) + await handler(start) + @router.get( "/table/partition/{path:path}", response_model=schemas.Response, @@ -1535,7 +1626,7 @@ async def put_data_source( None, getattr(request.app.state, "access_policy", None), ) - await entry.put_data_source(data_source=body.data_source) + await entry.put_data_source(data_source=body.data_source, patch=body.patch) @router.delete("/metadata/{path:path}") async def delete( @@ -1610,16 +1701,12 @@ async def put_array_full( ) media_type = request.headers["content-type"] if entry.structure_family == "array": - dtype = entry.structure().data_type.to_numpy_dtype() - shape = entry.structure().shape deserializer = deserialization_registry.dispatch("array", media_type) - data = await ensure_awaitable(deserializer, body, dtype, shape) elif entry.structure_family == "sparse": deserializer = deserialization_registry.dispatch("sparse", media_type) - data = await ensure_awaitable(deserializer, body) else: raise NotImplementedError(entry.structure_family) - await ensure_awaitable(entry.write, data) + await ensure_awaitable(entry.write, media_type, deserializer, entry, body) return json_or_msgpack(request, None) @router.put("/array/block/{path:path}") @@ -1651,23 +1738,15 @@ async def put_array_block( status_code=HTTP_405_METHOD_NOT_ALLOWED, detail="This node cannot accept array data.", ) - from tiled.adapters.array import slice_and_shape_from_block_and_chunks body = await request.body() media_type = request.headers["content-type"] - if 
entry.structure_family == "array": - dtype = entry.structure().data_type.to_numpy_dtype() - _, shape = slice_and_shape_from_block_and_chunks( - block, entry.structure().chunks - ) - deserializer = deserialization_registry.dispatch("array", media_type) - data = await ensure_awaitable(deserializer, body, dtype, shape) - elif entry.structure_family == "sparse": - deserializer = deserialization_registry.dispatch("sparse", media_type) - data = await ensure_awaitable(deserializer, body) - else: - raise NotImplementedError(entry.structure_family) - await ensure_awaitable(entry.write_block, data, block) + deserializer = deserialization_registry.dispatch( + entry.structure_family, media_type + ) + await ensure_awaitable( + entry.write_block, block, media_type, deserializer, entry, body + ) return json_or_msgpack(request, None) @router.patch("/array/full/{path:path}") @@ -1702,12 +1781,12 @@ async def patch_array_full( detail="This node cannot accept array data.", ) - dtype = entry.structure().data_type.to_numpy_dtype() body = await request.body() media_type = request.headers["content-type"] deserializer = deserialization_registry.dispatch("array", media_type) - data = await ensure_awaitable(deserializer, body, dtype, shape) - structure = await ensure_awaitable(entry.patch, data, offset, extend) + structure = await ensure_awaitable( + entry.patch, shape, offset, extend, media_type, deserializer, entry, body + ) return json_or_msgpack(request, structure) @router.put("/table/full/{path:path}") diff --git a/tiled/server/schemas.py b/tiled/server/schemas.py index e2c6e196c..e66a1881e 100644 --- a/tiled/server/schemas.py +++ b/tiled/server/schemas.py @@ -3,7 +3,17 @@ import enum import uuid from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Optional, + Tuple, + TypeVar, + Union, +) import pydantic.generics from pydantic import ConfigDict, 
Field, StringConstraints @@ -143,6 +153,12 @@ def from_orm(cls, orm: tiled.catalog.orm.Revision) -> Revision: ) +class Patch(pydantic.BaseModel): + offset: Tuple[int, ...] + shape: Tuple[int, ...] + extend: bool + + class DataSource(pydantic.BaseModel, Generic[StructureT]): id: Optional[int] = None structure_family: StructureFamily @@ -448,6 +464,7 @@ def narrow_structure_type(self): class PutDataSourceRequest(pydantic.BaseModel): data_source: DataSource + patch: Optional[Patch] class PostMetadataResponse(pydantic.BaseModel, Generic[ResourceLinksT]): @@ -571,4 +588,10 @@ class PatchMetadataResponse(pydantic.BaseModel, Generic[ResourceLinksT]): List[Resource[NodeAttributes, Dict, Dict]], PaginationLinks, Dict ] + +class EnvelopeFormat(str, enum.Enum): + json = "json" + msgpack = "msgpack" + + NodeStructure.model_rebuild() diff --git a/tiled/server/utils.py b/tiled/server/utils.py index 536616e24..d415e0f80 100644 --- a/tiled/server/utils.py +++ b/tiled/server/utils.py @@ -3,7 +3,7 @@ from collections.abc import Generator from typing import Any, Literal, Mapping -from fastapi import Request +from fastapi import Request, WebSocket from starlette.types import Scope from ..access_control.access_policies import NO_ACCESS @@ -32,6 +32,14 @@ def get_root_url(request: Request) -> str: return f"{get_root_url_low_level(request.headers, request.scope)}" +def get_root_url_websocket(websocket: WebSocket) -> str: + return f"{get_root_url_low_level(websocket.headers, websocket.scope)}" + + +def get_base_url_websocket(websocket: WebSocket) -> str: + return f"{get_root_url_websocket(websocket)}/api/v1" + + def get_base_url(request: Request) -> str: """ Base URL for the API