From cfa1ac728f07ef90949500941c6dbfac32258412 Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Tue, 11 Mar 2025 11:32:32 +0000 Subject: [PATCH 01/18] Refactor API key handling --- tiled/server/app.py | 9 ++-- tiled/server/authentication.py | 86 ++++++++++++++-------------------- tiled/server/metrics.py | 6 +-- tiled/server/router.py | 5 +- 4 files changed, 40 insertions(+), 66 deletions(-) diff --git a/tiled/server/app.py b/tiled/server/app.py index 54429b2c5..e076f40c9 100644 --- a/tiled/server/app.py +++ b/tiled/server/app.py @@ -16,7 +16,7 @@ import packaging.version import yaml from asgi_correlation_id import CorrelationIdMiddleware, correlation_id -from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, Security +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response from fastapi.exception_handlers import http_exception_handler from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi @@ -34,6 +34,7 @@ HTTP_500_INTERNAL_SERVER_ERROR, ) +from tiled.server.authentication import move_api_key from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator from ..config import construct_build_app_kwargs @@ -43,7 +44,6 @@ from ..utils import SHARE_TILED_PATH, Conflicts, SpecialUsers, UnsupportedQueryType from ..validation_registration import validation_registry as default_validation_registry from . import schemas -from .authentication import get_current_principal from .compression import CompressionMiddleware from .dependencies import ( get_query_registry, @@ -215,7 +215,7 @@ async def lifespan(app: FastAPI): yield await shutdown_event() - app = FastAPI(lifespan=lifespan) + app = FastAPI(lifespan=lifespan, dependencies=[Depends(move_api_key)]) # Healthcheck for deployment to containerized systems, needs to preempt other responses. # Standardized for Kubernetes, but also used by other systems. @@ -265,9 +265,6 @@ async def lookup_file(path, try_app=True): @app.get("/", response_class=HTMLResponse) async def index( request: Request, - # This dependency is here because it runs the code that moves - # API key from the query parameter to a cookie (if it is valid). - principal=Security(get_current_principal, scopes=[]), ): return templates.TemplateResponse( request, diff --git a/tiled/server/authentication.py b/tiled/server/authentication.py index dc974fe04..f987a055c 100644 --- a/tiled/server/authentication.py +++ b/tiled/server/authentication.py @@ -17,13 +17,12 @@ Response, Security, ) -from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.security import ( OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes, ) -from fastapi.security.api_key import APIKeyBase, APIKeyCookie, APIKeyQuery +from fastapi.security.api_key import APIKeyCookie, APIKeyHeader, APIKeyQuery from fastapi.security.utils import get_authorization_scheme_param from fastapi.templating import Jinja2Templates from sqlalchemy.future import select @@ -93,33 +92,14 @@ class TokenData(BaseModel): username: Optional[str] = None -class APIKeyAuthorizationHeader(APIKeyBase): - """ - Expect a header like - - Authorization: Apikey SECRET - - where Apikey is case-insensitive. 
- """ - - def __init__( - self, - *, - name: str, - scheme_name: Optional[str] = None, - description: Optional[str] = None, - ): - self.model: APIKey = APIKey( - **{"in": APIKeyIn.header}, name=name, description=description - ) - self.scheme_name = scheme_name or self.__class__.__name__ - +# TODO: remove custom subclass https://github.com/bluesky/tiled/issues/921 +class StrictAPIKeyHeader(APIKeyHeader): async def __call__(self, request: Request) -> Optional[str]: - authorization: str = request.headers.get("Authorization") - scheme, param = get_authorization_scheme_param(authorization) - if not authorization or scheme.lower() == "bearer": - return None - if scheme.lower() != "apikey": + api_key: Optional[str] = request.headers.get(self.model.name) + scheme, param = get_authorization_scheme_param(api_key) + if not scheme or scheme.lower() == "bearer": + return self.check_api_key(None, self.auto_error) + if scheme.lower() != self.scheme_name.lower(): raise HTTPException( status_code=HTTP_400_BAD_REQUEST, detail=( @@ -128,17 +108,11 @@ async def __call__(self, request: Request) -> Optional[str]: "'Bearer SECRET' or 'Apikey SECRET'. " ), ) - return param + return self.check_api_key(param, self.auto_error) # The tokenUrl below is patched at app startup when we know it. oauth2_scheme = OAuth2PasswordBearer(tokenUrl="PLACEHOLDER", auto_error=False) -api_key_query = APIKeyQuery(name="api_key", auto_error=False) -api_key_header = APIKeyAuthorizationHeader( - name="Authorization", - description="Prefix value with 'Apikey ' as in, 'Apikey SECRET'", -) -api_key_cookie = APIKeyCookie(name=API_KEY_COOKIE_NAME, auto_error=False) def create_access_token(data, secret_key, expires_delta): @@ -185,10 +159,21 @@ def decode_token(token, secret_keys): async def get_api_key( - api_key_query: str = Security(api_key_query), - api_key_header: str = Security(api_key_header), - api_key_cookie: str = Security(api_key_cookie), -): + api_key_query: Optional[str] = Depends( + APIKeyQuery(name="api_key", auto_error=False) + ), + api_key_header: Optional[str] = Depends( + StrictAPIKeyHeader( + name="Authorization", + description="Prefix value with 'Apikey ' as in, 'Apikey SECRET'", + scheme_name="Apikey", + auto_error=False, + ) + ), + api_key_cookie: Optional[str] = Depends( + APIKeyCookie(name=API_KEY_COOKIE_NAME, auto_error=False) + ), +) -> str | None: for api_key in [api_key_query, api_key_header, api_key_cookie]: if api_key is not None: return api_key @@ -233,6 +218,15 @@ async def get_session_state(decoded_access_token=Depends(get_decoded_access_toke return decoded_access_token.get("state") +async def move_api_key(request: Request, api_key: Optional[str] = Depends(get_api_key)): + if ("api_key" in request.query_params) and ( + request.cookies.get(API_KEY_COOKIE_NAME) != api_key + ): + request.state.cookies_to_set.append( + {"key": API_KEY_COOKIE_NAME, "value": api_key} + ) + + async def get_current_principal( request: Request, security_scopes: SecurityScopes, @@ -314,15 +308,6 @@ async def get_current_principal( detail="Invalid API key", headers=headers_for_401(request, security_scopes), ) - # If we made it to this point, we have a valid API key. - # If the API key was given in query param, move to cookie. - # This is convenient for browser-based access. 
- if ("api_key" in request.query_params) and ( - request.cookies.get(API_KEY_COOKIE_NAME) != api_key - ): - request.state.cookies_to_set.append( - {"key": API_KEY_COOKIE_NAME, "value": api_key} - ) elif decoded_access_token is not None: principal = schemas.Principal( uuid=uuid_module.UUID(hex=decoded_access_token["sub"]), @@ -875,7 +860,7 @@ async def create_service_principal( async def principal( request: Request, uuid: uuid_module.UUID, - principal=Security(get_current_principal, scopes=["read:principals"]), + _=Security(lambda: None, scopes=["read:principals"]), db=Depends(get_database_session), ): "Get information about one Principal (user or service)." @@ -911,7 +896,7 @@ async def revoke_apikey_for_principal( request: Request, uuid: uuid_module.UUID, first_eight: str, - principal=Security(get_current_principal, scopes=["admin:apikeys"]), + _=Security(lambda: None, scopes=["admin:apikeys"]), db=Depends(get_database_session), ): "Allow Tiled Admins to delete any user's apikeys e.g." @@ -1205,7 +1190,6 @@ async def whoami( async def logout( request: Request, response: Response, - principal=Security(get_current_principal, scopes=[]), ): "Deprecated. See revoke_session: POST /session/revoke." request.state.endpoint = "auth" diff --git a/tiled/server/metrics.py b/tiled/server/metrics.py index d6e0d79ee..043065b65 100644 --- a/tiled/server/metrics.py +++ b/tiled/server/metrics.py @@ -10,8 +10,6 @@ from fastapi import APIRouter, Request, Response, Security from prometheus_client import CONTENT_TYPE_LATEST, Histogram, generate_latest -from .authentication import get_current_principal - router = APIRouter() REQUEST_DURATION = Histogram( @@ -157,9 +155,7 @@ def prometheus_registry(): @router.get("/metrics") -async def metrics( - request: Request, principal=Security(get_current_principal, scopes=["metrics"]) -): +async def metrics(request: Request, _=Security(lambda: None, scopes=["metrics"])): """ Prometheus metrics """ diff --git a/tiled/server/router.py b/tiled/server/router.py index d5f613920..2f5bbe005 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -10,7 +10,7 @@ import anyio import packaging -from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request, Security +from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request from jmespath.exceptions import JMESPathError from json_merge_patch import merge as apply_merge_patch from jsonpatch import apply_patch as apply_json_patch @@ -77,9 +77,6 @@ async def about( authenticators=Depends(get_authenticators), serialization_registry=Depends(get_serialization_registry), query_registry=Depends(get_query_registry), - # This dependency is here because it runs the code that moves - # API key from the query parameter to a cookie (if it is valid). - principal=Security(get_current_principal, scopes=[]), ): # TODO The lazy import of entry modules and serializers means that the # lists of formats are not populated until they are first used. 
Not very From e52b48853168aab57d5c7be8189eff938d7a0919 Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Tue, 11 Mar 2025 11:46:47 +0000 Subject: [PATCH 02/18] Clarify registry usage --- tiled/client/container.py | 4 ++-- tiled/config.py | 26 +++++++++++++-------- tiled/media_type_registration.py | 16 ++++++------- tiled/query_registration.py | 4 ++-- tiled/serialization/array.py | 39 ++++++++++++++++++++------------ tiled/serialization/awkward.py | 17 +++++++++----- tiled/serialization/container.py | 6 ++--- tiled/serialization/sparse.py | 25 ++++++++++++-------- tiled/serialization/table.py | 39 ++++++++++++++++++++------------ tiled/serialization/xarray.py | 18 +++++++-------- tiled/server/app.py | 18 +++++++++------ tiled/server/dependencies.py | 10 ++++---- tiled/validation_registration.py | 2 +- 13 files changed, 132 insertions(+), 92 deletions(-) diff --git a/tiled/client/container.py b/tiled/client/container.py index 2b39b8c8e..1b4f26001 100644 --- a/tiled/client/container.py +++ b/tiled/client/container.py @@ -15,7 +15,7 @@ from ..adapters.utils import IndexersMixin from ..iterviews import ItemsView, KeysView, ValuesView from ..queries import KeyLookup -from ..query_registration import query_registry +from ..query_registration import default_query_registry from ..structures.core import Spec, StructureFamily from ..structures.data_source import DataSource from ..utils import UNCHANGED, OneShotCachedMap, Sentinel, node_repr, safe_json_dump @@ -1055,7 +1055,7 @@ def _queries_to_params(*queries): "Compute GET params from the queries." params = collections.defaultdict(list) for query in queries: - name = query_registry.query_type_to_name[type(query)] + name = default_query_registry.query_type_to_name[type(query)] for field, value in query.encode().items(): if value is not None: params[f"filter[{name}][condition][{field}]"].append(value) diff --git a/tiled/config.py b/tiled/config.py index 991d2bf48..f2cfe5228 100644 --- a/tiled/config.py +++ b/tiled/config.py @@ -10,19 +10,21 @@ from datetime import timedelta from functools import cache from pathlib import Path +from typing import Optional import jsonschema from .adapters.mapping import MapAdapter from .media_type_registration import ( - compression_registry as default_compression_registry, + CompressionRegistry, + SerializationRegistry, + default_compression_registry, + default_deserialization_registry, + default_serialization_registry, ) -from .media_type_registration import ( - serialization_registry as default_serialization_registry, -) -from .query_registration import query_registry as default_query_registry +from .query_registration import QueryRegistry, default_query_registry from .utils import import_object, parse, prepend_to_sys_path -from .validation_registration import validation_registry as default_validation_registry +from .validation_registration import ValidationRegistry, default_validation_registry @cache @@ -40,10 +42,11 @@ def construct_build_app_kwargs( config, *, source_filepath=None, - query_registry=None, - compression_registry=None, - serialization_registry=None, - validation_registry=None, + query_registry: Optional[QueryRegistry] = None, + compression_registry: Optional[CompressionRegistry] = None, + serialization_registry: Optional[SerializationRegistry] = None, + deserialization_registry: Optional[SerializationRegistry] = None, + validation_registry: Optional[ValidationRegistry] = None, ): """ Given parsed configuration, construct arguments for build_app(...). 
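
With the registries accepted as explicit keyword arguments, a deployment can pass
its own instances to construct_build_app_kwargs instead of mutating the
module-level defaults. A minimal sketch, not part of this patch: it assumes a
config file parsed with tiled.config.parse_configs, plus a hypothetical
serialize_custom encoder for a made-up "text/x-custom" media type.

    from tiled.config import construct_build_app_kwargs, parse_configs
    from tiled.media_type_registration import SerializationRegistry
    from tiled.server.app import build_app

    def serialize_custom(array, metadata):
        # Hypothetical encoder: render the array as comma-separated bytes.
        return ",".join(map(str, array.tolist())).encode()

    # A registry built from scratch serves only the formats registered on it;
    # extending default_serialization_registry instead would keep the built-in
    # formats alongside the custom one.
    registry = SerializationRegistry()
    registry.register("array", "text/x-custom", serialize_custom)

    # construct_build_app_kwargs fills in defaults for any registry not
    # supplied and returns keyword arguments for build_app(...).
    kwargs = construct_build_app_kwargs(
        parse_configs("config.yml"), serialization_registry=registry
    )
    app = build_app(**kwargs)
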
@@ -61,6 +64,8 @@ def construct_build_app_kwargs( query_registry = default_query_registry if serialization_registry is None: serialization_registry = default_serialization_registry + if deserialization_registry is None: + deserialization_registry = default_deserialization_registry if compression_registry is None: compression_registry = default_compression_registry if validation_registry is None: @@ -220,6 +225,7 @@ def construct_build_app_kwargs( "server_settings": server_settings, "query_registry": query_registry, "serialization_registry": serialization_registry, + "deserialization_registry": deserialization_registry, "compression_registry": compression_registry, "validation_registry": validation_registry, "tasks": { diff --git a/tiled/media_type_registration.py b/tiled/media_type_registration.py index 854f20699..7b41c54c6 100644 --- a/tiled/media_type_registration.py +++ b/tiled/media_type_registration.py @@ -197,13 +197,13 @@ def __call__(self, media_type, encoder, *args, **kwargs): return self.dispatch(media_type, encoder)(*args, **kwargs) -serialization_registry = SerializationRegistry() +default_serialization_registry = SerializationRegistry() "Global serialization registry. See Registry for usage examples." -deserialization_registry = SerializationRegistry() +default_deserialization_registry = SerializationRegistry() "Global deserialization registry. See Registry for usage examples." -compression_registry = CompressionRegistry() +default_compression_registry = CompressionRegistry() "Global compression registry. See Registry for usage examples." @@ -211,7 +211,7 @@ def __call__(self, media_type, encoder, *args, **kwargs): "application/json", "application/x-msgpack", ]: - compression_registry.register( + default_compression_registry.register( media_type, "gzip", lambda buffer: gzip.GzipFile(mode="wb", fileobj=buffer, compresslevel=9), @@ -225,7 +225,7 @@ def __call__(self, media_type, encoder, *args, **kwargs): "text/plain", "text/html", ]: - compression_registry.register( + default_compression_registry.register( media_type, "gzip", # Use a lower compression level. High compression is extremely slow @@ -270,7 +270,7 @@ def close(self): "text/html", "text/plain", ]: - compression_registry.register(media_type, "zstd", ZstdBuffer) + default_compression_registry.register(media_type, "zstd", ZstdBuffer) if modules_available("lz4"): import lz4 @@ -326,7 +326,7 @@ def close(self): "text/html", "text/plain", ]: - compression_registry.register(media_type, "lz4", LZ4Buffer) + default_compression_registry.register(media_type, "lz4", LZ4Buffer) if modules_available("blosc2"): import blosc2 @@ -355,4 +355,4 @@ def close(self): pass for media_type in ["application/octet-stream", APACHE_ARROW_FILE_MIME_TYPE]: - compression_registry.register(media_type, "blosc2", BloscBuffer) + default_compression_registry.register(media_type, "blosc2", BloscBuffer) diff --git a/tiled/query_registration.py b/tiled/query_registration.py index ec4bca9b7..a42a54236 100644 --- a/tiled/query_registration.py +++ b/tiled/query_registration.py @@ -79,8 +79,8 @@ def inner(cls): # Make a global registry. 
-query_registry = QueryRegistry() -register = query_registry.register +default_query_registry = QueryRegistry() +register = default_query_registry.register """Register a new type of query.""" diff --git a/tiled/serialization/array.py b/tiled/serialization/array.py index 12d6df055..01abf74f4 100644 --- a/tiled/serialization/array.py +++ b/tiled/serialization/array.py @@ -3,7 +3,10 @@ import numpy -from ..media_type_registration import deserialization_registry, serialization_registry +from ..media_type_registration import ( + default_deserialization_registry, + default_serialization_registry, +) from ..utils import ( SerializationError, UnsupportedShape, @@ -22,13 +25,13 @@ def as_buffer(array, metadata): return numpy.asarray(array).tobytes() -serialization_registry.register( +default_serialization_registry.register( "array", "application/octet-stream", as_buffer, ) if modules_available("orjson"): - serialization_registry.register( + default_serialization_registry.register( "array", "application/json", lambda array, metadata: safe_json_dump(array), @@ -43,10 +46,12 @@ def serialize_csv(array, metadata): return file.getvalue().encode() -serialization_registry.register("array", "text/csv", serialize_csv) -serialization_registry.register("array", "text/x-comma-separated-values", serialize_csv) -serialization_registry.register("array", "text/plain", serialize_csv) -deserialization_registry.register( +default_serialization_registry.register("array", "text/csv", serialize_csv) +default_serialization_registry.register( + "array", "text/x-comma-separated-values", serialize_csv +) +default_serialization_registry.register("array", "text/plain", serialize_csv) +default_deserialization_registry.register( "array", "application/octet-stream", lambda buffer, dtype, shape: numpy.frombuffer(buffer, dtype=dtype).reshape(shape), @@ -90,10 +95,10 @@ def array_from_buffer_PIL(buffer, format, dtype, shape): image = Image.open(file, format=format) return numpy.asarray(image).asdtype(dtype).reshape(shape) - serialization_registry.register( + default_serialization_registry.register( "array", "image/png", lambda array, metadata: save_to_buffer_PIL(array, "png") ) - deserialization_registry.register( + default_deserialization_registry.register( "array", "image/png", lambda buffer, dtype, shape: array_from_buffer_PIL(buffer, "png", dtype, shape), @@ -120,18 +125,24 @@ def save_to_buffer_tifffile(array, metadata): imwrite(file, normalized_array) return file.getbuffer() - serialization_registry.register("array", "image/tiff", save_to_buffer_tifffile) - deserialization_registry.register("array", "image/tiff", array_from_buffer_tifffile) + default_serialization_registry.register( + "array", "image/tiff", save_to_buffer_tifffile + ) + default_deserialization_registry.register( + "array", "image/tiff", array_from_buffer_tifffile + ) def serialize_html(array, metadata): "Try to display as image. Fall back to CSV." 
try: - png_data = serialization_registry.dispatch("array", "image/png")( + png_data = default_serialization_registry.dispatch("array", "image/png")( array, metadata ) except Exception: - csv_data = serialization_registry.dispatch("array", "text/csv")(array, metadata) + csv_data = default_serialization_registry.dispatch("array", "text/csv")( + array, metadata + ) return "" "" f"{csv_data.decode()!s}" "" "" else: return ( @@ -145,4 +156,4 @@ def serialize_html(array, metadata): ) -serialization_registry.register("array", "text/html", serialize_html) +default_serialization_registry.register("array", "text/html", serialize_html) diff --git a/tiled/serialization/awkward.py b/tiled/serialization/awkward.py index dfb842504..f7a67a0eb 100644 --- a/tiled/serialization/awkward.py +++ b/tiled/serialization/awkward.py @@ -3,12 +3,15 @@ import awkward -from ..media_type_registration import deserialization_registry, serialization_registry +from ..media_type_registration import ( + default_deserialization_registry, + default_serialization_registry, +) from ..structures.core import StructureFamily from ..utils import APACHE_ARROW_FILE_MIME_TYPE, modules_available -@serialization_registry.register(StructureFamily.awkward, "application/zip") +@default_serialization_registry.register(StructureFamily.awkward, "application/zip") def to_zipped_buffers(components, metadata): (form, length, container) = components file = io.BytesIO() @@ -22,7 +25,7 @@ def to_zipped_buffers(components, metadata): return file.getbuffer() -@deserialization_registry.register(StructureFamily.awkward, "application/zip") +@default_deserialization_registry.register(StructureFamily.awkward, "application/zip") def from_zipped_buffers(buffer, form, length): file = io.BytesIO(buffer) with zipfile.ZipFile(file, "r") as zip: @@ -33,7 +36,7 @@ def from_zipped_buffers(buffer, form, length): return container -@serialization_registry.register(StructureFamily.awkward, "application/json") +@default_serialization_registry.register(StructureFamily.awkward, "application/json") def to_json(components, metadata): (form, length, container) = components file = io.StringIO() @@ -44,7 +47,7 @@ def to_json(components, metadata): if modules_available("pyarrow"): - @serialization_registry.register( + @default_serialization_registry.register( StructureFamily.awkward, APACHE_ARROW_FILE_MIME_TYPE ) def to_arrow(components, metadata): @@ -60,7 +63,9 @@ def to_arrow(components, metadata): # There seems to be no official Parquet MIME type. 
# https://issues.apache.org/jira/browse/PARQUET-1889 - @serialization_registry.register(StructureFamily.awkward, "application/x-parquet") + @default_serialization_registry.register( + StructureFamily.awkward, "application/x-parquet" + ) def to_parquet(components, metadata): import pyarrow.parquet diff --git a/tiled/serialization/container.py b/tiled/serialization/container.py index 5b82074c7..e475724b4 100644 --- a/tiled/serialization/container.py +++ b/tiled/serialization/container.py @@ -1,6 +1,6 @@ import io -from ..media_type_registration import serialization_registry +from ..media_type_registration import default_serialization_registry from ..structures.core import StructureFamily from ..utils import ( SerializationError, @@ -78,7 +78,7 @@ async def serialize_hdf5(node, metadata, filter_for_access): dataset.attrs.create(k, v) return buffer.getbuffer() - serialization_registry.register( + default_serialization_registry.register( StructureFamily.container, "application/x-hdf5", serialize_hdf5 ) @@ -101,6 +101,6 @@ async def serialize_json(node, metadata, filter_for_access): d = d[key]["contents"] return safe_json_dump(to_serialize) - serialization_registry.register( + default_serialization_registry.register( StructureFamily.container, "application/json", serialize_json ) diff --git a/tiled/serialization/sparse.py b/tiled/serialization/sparse.py index d44839b1a..aeba83a8e 100644 --- a/tiled/serialization/sparse.py +++ b/tiled/serialization/sparse.py @@ -1,6 +1,9 @@ import io -from ..media_type_registration import deserialization_registry, serialization_registry +from ..media_type_registration import ( + default_deserialization_registry, + default_serialization_registry, +) from ..utils import modules_available if modules_available("h5py"): @@ -19,7 +22,9 @@ def serialize_hdf5(sparse_arr, metadata): file.attrs.create(k, v) return buffer.getbuffer() - serialization_registry.register("sparse", "application/x-hdf5", serialize_hdf5) + default_serialization_registry.register( + "sparse", "application/x-hdf5", serialize_hdf5 + ) if modules_available("pandas", "pyarrow"): import pandas @@ -37,7 +42,7 @@ def serialize_hdf5(sparse_arr, metadata): if modules_available("openpyxl"): from .table import serialize_excel - serialization_registry.register( + default_serialization_registry.register( "sparse", XLSX_MIME_TYPE, lambda sparse_arr, metadata: serialize_excel( @@ -52,38 +57,38 @@ def to_dataframe(sparse_arr): d["data"] = sparse_arr.data return pandas.DataFrame(d) - deserialization_registry.register( + default_deserialization_registry.register( "sparse", APACHE_ARROW_FILE_MIME_TYPE, deserialize_arrow ) - serialization_registry.register( + default_serialization_registry.register( "sparse", APACHE_ARROW_FILE_MIME_TYPE, lambda sparse_arr, metadata: serialize_arrow( to_dataframe(sparse_arr), metadata, preserve_index=False ), ) - serialization_registry.register( + default_serialization_registry.register( "sparse", "application/x-parquet", lambda sparse_arr, metadata: serialize_parquet( to_dataframe(sparse_arr), metadata, preserve_index=False ), ) - serialization_registry.register( + default_serialization_registry.register( "sparse", "text/csv", lambda sparse_arr, metadata: serialize_csv( to_dataframe(sparse_arr), metadata, preserve_index=False ), ) - serialization_registry.register( + default_serialization_registry.register( "sparse", "text/plain", lambda sparse_arr, metadata: serialize_csv( to_dataframe(sparse_arr), metadata, preserve_index=False ), ) - serialization_registry.register( + 
default_serialization_registry.register( "sparse", "text/html", lambda sparse_arr, metadata: serialize_html( @@ -99,7 +104,7 @@ def serialize_json(sparse_arr, metadata): {column: df[column].tolist() for column in df}, ) - serialization_registry.register( + default_serialization_registry.register( "sparse", "application/json", serialize_json, diff --git a/tiled/serialization/table.py b/tiled/serialization/table.py index 214c5522a..cf1352dae 100644 --- a/tiled/serialization/table.py +++ b/tiled/serialization/table.py @@ -1,12 +1,17 @@ import io import mimetypes -from ..media_type_registration import deserialization_registry, serialization_registry +from ..media_type_registration import ( + default_deserialization_registry, + default_serialization_registry, +) from ..structures.core import StructureFamily from ..utils import APACHE_ARROW_FILE_MIME_TYPE, XLSX_MIME_TYPE, modules_available -@serialization_registry.register(StructureFamily.table, APACHE_ARROW_FILE_MIME_TYPE) +@default_serialization_registry.register( + StructureFamily.table, APACHE_ARROW_FILE_MIME_TYPE +) def serialize_arrow(df, metadata, preserve_index=True): import pyarrow @@ -22,7 +27,9 @@ def serialize_arrow(df, metadata, preserve_index=True): return memoryview(sink.getvalue()) -@deserialization_registry.register(StructureFamily.table, APACHE_ARROW_FILE_MIME_TYPE) +@default_deserialization_registry.register( + StructureFamily.table, APACHE_ARROW_FILE_MIME_TYPE +) def deserialize_arrow(buffer): import pyarrow @@ -31,7 +38,7 @@ def deserialize_arrow(buffer): # There seems to be no official Parquet MIME type. # https://issues.apache.org/jira/browse/PARQUET-1889 -@serialization_registry.register(StructureFamily.table, "application/x-parquet") +@default_serialization_registry.register(StructureFamily.table, "application/x-parquet") def serialize_parquet(df, metadata, preserve_index=True): import pyarrow.parquet @@ -48,21 +55,25 @@ def serialize_csv(df, metadata, preserve_index=False): return file.getvalue().encode() -@deserialization_registry.register(StructureFamily.table, "text/csv") +@default_deserialization_registry.register(StructureFamily.table, "text/csv") def deserialize_csv(buffer): import pandas return pandas.read_csv(io.BytesIO(buffer), header=None) -serialization_registry.register(StructureFamily.table, "text/csv", serialize_csv) -serialization_registry.register( +default_serialization_registry.register( + StructureFamily.table, "text/csv", serialize_csv +) +default_serialization_registry.register( StructureFamily.table, "text/x-comma-separated-values", serialize_csv ) -serialization_registry.register(StructureFamily.table, "text/plain", serialize_csv) +default_serialization_registry.register( + StructureFamily.table, "text/plain", serialize_csv +) -@serialization_registry.register(StructureFamily.table, "text/html") +@default_serialization_registry.register(StructureFamily.table, "text/html") def serialize_html(df, metadata, preserve_index=False): file = io.StringIO() df.to_html(file, index=preserve_index) @@ -73,20 +84,20 @@ def serialize_html(df, metadata, preserve_index=False): # The optional pandas dependency openpyxel is required for Excel read/write. 
import pandas - @serialization_registry.register(StructureFamily.table, XLSX_MIME_TYPE) + @default_serialization_registry.register(StructureFamily.table, XLSX_MIME_TYPE) def serialize_excel(df, metadata, preserve_index=False): file = io.BytesIO() df.to_excel(file, index=preserve_index) return file.getbuffer() - deserialization_registry.register( + default_deserialization_registry.register( StructureFamily.table, XLSX_MIME_TYPE, pandas.read_excel ) mimetypes.types_map.setdefault(".xlsx", XLSX_MIME_TYPE) if modules_available("orjson"): import orjson - serialization_registry.register( + default_serialization_registry.register( StructureFamily.table, "application/json", lambda df, metadata: orjson.dumps( @@ -103,7 +114,7 @@ def serialize_excel(df, metadata, preserve_index=False): # {'a': 1, 'b': 4} # {'a': 2, 'b': 5} # {'a': 3, 'b': 6} - @serialization_registry.register( + @default_serialization_registry.register( StructureFamily.table, "application/json-seq", # official mimetype for newline-delimited JSON ) @@ -126,6 +137,6 @@ def json_sequence(df, metadata): if modules_available("h5py"): from .container import serialize_hdf5 - serialization_registry.register( + default_serialization_registry.register( StructureFamily.table, "application/x-hdf5", serialize_hdf5 ) diff --git a/tiled/serialization/xarray.py b/tiled/serialization/xarray.py index 6bc4fc251..3677b53df 100644 --- a/tiled/serialization/xarray.py +++ b/tiled/serialization/xarray.py @@ -1,6 +1,6 @@ import io -from ..media_type_registration import serialization_registry +from ..media_type_registration import default_serialization_registry from ..utils import ensure_awaitable, modules_available from .container import walk from .table import ( @@ -49,7 +49,7 @@ def close(self): if modules_available("scipy"): # Both application/netcdf and application/x-netcdf are used. # https://en.wikipedia.org/wiki/NetCDF - @serialization_registry.register( + @default_serialization_registry.register( "xarray_dataset", ["application/netcdf", "application/x-netcdf"] ) async def serialize_netcdf(node, metadata, filter_for_access): @@ -65,29 +65,29 @@ async def serialize_netcdf(node, metadata, filter_for_access): # 1-dimensional variables it is useful. 
-@serialization_registry.register("xarray_dataset", APACHE_ARROW_FILE_MIME_TYPE) +@default_serialization_registry.register("xarray_dataset", APACHE_ARROW_FILE_MIME_TYPE) async def serialize_dataset_arrow(node, metadata, filter_for_access): return serialize_arrow((await as_dataset(node)).to_dataframe(), metadata) -@serialization_registry.register("xarray_dataset", "application/x-parquet") +@default_serialization_registry.register("xarray_dataset", "application/x-parquet") async def serialize_dataset_parquet(node, metadata, filter_for_access): return serialize_parquet((await as_dataset(node)).to_dataframe(), metadata) -@serialization_registry.register( +@default_serialization_registry.register( "xarray_dataset", ["text/csv", "text/comma-separated-values", "text/plain"] ) async def serialize_dataset_csv(node, metadata, filter_for_access): return serialize_csv((await as_dataset(node)).to_dataframe(), metadata) -@serialization_registry.register("xarray_dataset", "text/html") +@default_serialization_registry.register("xarray_dataset", "text/html") async def serialize_dataset_html(node, metadata, filter_for_access): return serialize_html((await as_dataset(node)).to_dataframe(), metadata) -@serialization_registry.register("xarray_dataset", XLSX_MIME_TYPE) +@default_serialization_registry.register("xarray_dataset", XLSX_MIME_TYPE) async def serialize_dataset_excel(node, metadata, filter_for_access): return serialize_excel((await as_dataset(node)).to_dataframe(), metadata) @@ -95,7 +95,7 @@ async def serialize_dataset_excel(node, metadata, filter_for_access): if modules_available("orjson"): import orjson - @serialization_registry.register("xarray_dataset", "application/json") + @default_serialization_registry.register("xarray_dataset", "application/json") async def serialize_json(node, metadata, filter_for_access): df = (await as_dataset(node)).to_dataframe() return orjson.dumps( @@ -105,7 +105,7 @@ async def serialize_json(node, metadata, filter_for_access): if modules_available("h5py"): - @serialization_registry.register("xarray_dataset", "application/x-hdf5") + @default_serialization_registry.register("xarray_dataset", "application/x-hdf5") async def serialize_hdf5(node, metadata, filter_for_access): """ Like for node, but encode everything under 'attrs' in attrs. diff --git a/tiled/server/app.py b/tiled/server/app.py index e076f40c9..61fed4e5b 100644 --- a/tiled/server/app.py +++ b/tiled/server/app.py @@ -10,7 +10,7 @@ from contextlib import asynccontextmanager from functools import cache, partial from pathlib import Path -from typing import List +from typing import List, Optional import anyio import packaging.version @@ -34,15 +34,18 @@ HTTP_500_INTERNAL_SERVER_ERROR, ) +from tiled.query_registration import QueryRegistry from tiled.server.authentication import move_api_key from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator from ..config import construct_build_app_kwargs from ..media_type_registration import ( - compression_registry as default_compression_registry, + CompressionRegistry, + SerializationRegistry, + default_compression_registry, ) from ..utils import SHARE_TILED_PATH, Conflicts, SpecialUsers, UnsupportedQueryType -from ..validation_registration import validation_registry as default_validation_registry +from ..validation_registration import ValidationRegistry, default_validation_registry from . 
import schemas from .compression import CompressionMiddleware from .dependencies import ( @@ -113,10 +116,11 @@ def build_app( tree, authentication=None, server_settings=None, - query_registry=None, - serialization_registry=None, - compression_registry=None, - validation_registry=None, + query_registry: Optional[QueryRegistry] = None, + serialization_registry: Optional[SerializationRegistry] = None, + deserialization_registry: Optional[SerializationRegistry] = None, + compression_registry: Optional[CompressionRegistry] = None, + validation_registry: Optional[ValidationRegistry] = None, tasks=None, scalable=False, ): diff --git a/tiled/server/dependencies.py b/tiled/server/dependencies.py index 09b8d8233..6a4ffa284 100644 --- a/tiled/server/dependencies.py +++ b/tiled/server/dependencies.py @@ -6,13 +6,11 @@ from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND from ..media_type_registration import ( - deserialization_registry as default_deserialization_registry, + default_deserialization_registry, + default_serialization_registry, ) -from ..media_type_registration import ( - serialization_registry as default_serialization_registry, -) -from ..query_registration import query_registry as default_query_registry -from ..validation_registration import validation_registry as default_validation_registry +from ..query_registration import default_query_registry +from ..validation_registration import default_validation_registry from .authentication import get_current_principal, get_session_state from .core import NoEntry from .utils import filter_for_access, record_timing diff --git a/tiled/validation_registration.py b/tiled/validation_registration.py index 133c3811a..a8712c177 100644 --- a/tiled/validation_registration.py +++ b/tiled/validation_registration.py @@ -23,7 +23,7 @@ def __contains__(self, spec): return spec in self._lookup -validation_registry = ValidationRegistry() +default_validation_registry = ValidationRegistry() "Global validation registry" From 8ad8ea38634feca354b8dbcb208235c87ea4fe04 Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Tue, 11 Mar 2025 12:02:05 +0000 Subject: [PATCH 03/18] Create routes after registries known --- tiled/server/app.py | 63 +- tiled/server/router.py | 3115 ++++++++++++++++++++-------------------- 2 files changed, 1556 insertions(+), 1622 deletions(-) diff --git a/tiled/server/app.py b/tiled/server/app.py index 61fed4e5b..c2565f059 100644 --- a/tiled/server/app.py +++ b/tiled/server/app.py @@ -10,7 +10,7 @@ from contextlib import asynccontextmanager from functools import cache, partial from pathlib import Path -from typing import List, Optional +from typing import Optional import anyio import packaging.version @@ -34,7 +34,7 @@ HTTP_500_INTERNAL_SERVER_ERROR, ) -from tiled.query_registration import QueryRegistry +from tiled.query_registration import QueryRegistry, default_query_registry from tiled.server.authentication import move_api_key from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator @@ -46,15 +46,9 @@ ) from ..utils import SHARE_TILED_PATH, Conflicts, SpecialUsers, UnsupportedQueryType from ..validation_registration import ValidationRegistry, default_validation_registry -from . 
import schemas from .compression import CompressionMiddleware -from .dependencies import ( - get_query_registry, - get_root_tree, - get_serialization_registry, - get_validation_registry, -) -from .router import distinct, patch_route_signature, router, search +from .dependencies import get_root_tree +from .router import get_router from .settings import get_settings from .utils import ( API_KEY_COOKIE_NAME, @@ -143,7 +137,7 @@ def build_app( for spec in authentication.get("providers", []) } server_settings = server_settings or {} - query_registry = query_registry or get_query_registry() + query_registry = query_registry or default_query_registry compression_registry = compression_registry or default_compression_registry validation_registry = validation_registry or default_validation_registry tasks = tasks or {} @@ -349,6 +343,12 @@ async def unhandled_exception_handler( ), ) + router = get_router( + query_registry, + serialization_registry, + deserialization_registry, + validation_registry, + ) app.include_router(router, prefix="/api/v1") # The Tree and Authenticator have the opportunity to add custom routes to @@ -423,21 +423,6 @@ async def unhandled_exception_handler( # And add this authentication_router itself to the app. app.include_router(authentication_router, prefix="/api/v1/auth") - # The /search route is defined after import time so that the user has the - # opporunity to register custom query types before startup. - app.get( - "/api/v1/search/{path:path}", - response_model=schemas.Response[ - List[schemas.Resource[schemas.NodeAttributes, dict, dict]], - schemas.PaginationLinks, - dict, - ], - )(patch_route_signature(search, query_registry)) - app.get( - "/api/v1/distinct/{path:path}", - response_model=schemas.GetDistinctResponse, - )(patch_route_signature(distinct, query_registry)) - @cache def override_get_authenticators(): return authenticators @@ -770,32 +755,6 @@ async def set_cookies(request: Request, call_next): app.dependency_overrides[get_authenticators] = override_get_authenticators app.dependency_overrides[get_root_tree] = override_get_root_tree app.dependency_overrides[get_settings] = override_get_settings - if query_registry is not None: - - @cache - def override_get_query_registry(): - return query_registry - - app.dependency_overrides[get_query_registry] = override_get_query_registry - if serialization_registry is not None: - - @cache - def override_get_serialization_registry(): - return serialization_registry - - app.dependency_overrides[ - get_serialization_registry - ] = override_get_serialization_registry - - if validation_registry is not None: - - @cache - def override_get_validation_registry(): - return validation_registry - - app.dependency_overrides[ - get_validation_registry - ] = override_get_validation_registry @app.middleware("http") async def capture_metrics(request: Request, call_next): diff --git a/tiled/server/router.py b/tiled/server/router.py index 2f5bbe005..0275a1125 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -6,7 +6,7 @@ from datetime import datetime, timedelta, timezone from functools import partial from pathlib import Path -from typing import Any, List, Optional +from typing import Any, Callable, List, Optional, TypeVar import anyio import packaging @@ -26,13 +26,15 @@ HTTP_422_UNPROCESSABLE_ENTITY, ) +from tiled.media_type_registration import SerializationRegistry +from tiled.query_registration import QueryRegistry from tiled.schemas import About from tiled.server.protocols import ExternalAuthenticator, 
InternalAuthenticator from .. import __version__ from ..structures.core import Spec, StructureFamily from ..utils import ensure_awaitable, patch_mimetypes, path_from_uri -from ..validation_registration import ValidationError +from ..validation_registration import ValidationError, ValidationRegistry from . import schemas from .authentication import get_authenticators, get_current_principal from .core import ( @@ -54,10 +56,6 @@ SecureEntry, block, expected_shape, - get_deserialization_registry, - get_query_registry, - get_serialization_registry, - get_validation_registry, offset_param, shape_param, slice_, @@ -67,354 +65,437 @@ from .settings import Settings, get_settings from .utils import filter_for_access, get_base_url, record_timing -router = APIRouter() - - -@router.get("/", response_model=About) -async def about( - request: Request, - settings: Settings = Depends(get_settings), - authenticators=Depends(get_authenticators), - serialization_registry=Depends(get_serialization_registry), - query_registry=Depends(get_query_registry), -): - # TODO The lazy import of entry modules and serializers means that the - # lists of formats are not populated until they are first used. Not very - # helpful for discovery! The registration can be made non-lazy, while the - # imports of the underlying I/O libraries themselves (openpyxl, pillow, - # etc.) can remain lazy. - request.state.endpoint = "about" - base_url = get_base_url(request) - authentication = { - "required": not settings.allow_anonymous_access, - } - provider_specs = [] - user_agent = request.headers.get("user-agent", "") - # The name of the "internal" mode used to be "password". - # This ensures back-compat with older Python clients. - internal_mode_name = "internal" - MINIMUM_INTERNAL_PYTHON_CLIENT_VERSION = packaging.version.parse("0.1.0b17") - if user_agent.startswith("python-tiled/"): - agent, _, raw_version = user_agent.partition("/") - try: - parsed_version = packaging.version.parse(raw_version) - except Exception: - pass - else: - if parsed_version < MINIMUM_INTERNAL_PYTHON_CLIENT_VERSION: - internal_mode_name = "password" - for provider, authenticator in authenticators.items(): - if isinstance(authenticator, InternalAuthenticator): - spec = { - "provider": provider, - "mode": internal_mode_name, - "links": { - "auth_endpoint": f"{base_url}/auth/provider/{provider}/token" - }, - "confirmation_message": getattr( - authenticator, "confirmation_message", None - ), - } - elif isinstance(authenticator, ExternalAuthenticator): - spec = { - "provider": provider, - "mode": "external", - "links": { - "auth_endpoint": f"{base_url}/auth/provider/{provider}/authorize" - }, - "confirmation_message": getattr( - authenticator, "confirmation_message", None - ), - } - else: - # It should be impossible to reach here. - assert False - provider_specs.append(spec) - if provider_specs: - # If there are *any* authenticaiton providers, these - # endpoints will be added. 
- authentication["links"] = { - "whoami": f"{base_url}/auth/whoami", - "apikey": f"{base_url}/auth/apikey", - "refresh_session": f"{base_url}/auth/session/refresh", - "revoke_session": f"{base_url}/auth/session/revoke/{{session_id}}", - "logout": f"{base_url}/auth/logout", - } - authentication["providers"] = provider_specs - - return json_or_msgpack( - request, - About( - library_version=__version__, - api_version=0, - formats={ - structure_family: list( - serialization_registry.media_types(structure_family) - ) - for structure_family in serialization_registry.structure_families - }, - aliases={ - structure_family: serialization_registry.aliases(structure_family) - for structure_family in serialization_registry.structure_families - }, - queries=list(query_registry.name_to_query_type), - authentication=authentication, - links={ - "self": base_url, - "documentation": f"{base_url}/docs", - }, - meta={"root_path": request.scope.get("root_path") or "" + "/api"}, - ).model_dump(), - expires=datetime.now(timezone.utc) + timedelta(seconds=600), - ) - -async def search( - request: Request, - path: str, - fields: Optional[List[schemas.EntryFields]] = Query(list(schemas.EntryFields)), - select_metadata: Optional[str] = Query(None), - offset: Optional[int] = Query(0, alias="page[offset]", ge=0), - limit: Optional[int] = Query( - DEFAULT_PAGE_SIZE, alias="page[limit]", ge=0, le=MAX_PAGE_SIZE - ), - sort: Optional[str] = Query(None), - max_depth: Optional[int] = Query(None, ge=0, le=DEPTH_LIMIT), - omit_links: bool = Query(False), - include_data_sources: bool = Query(False), - entry: Any = SecureEntry(scopes=["read:metadata"]), - query_registry=Depends(get_query_registry), - principal: str = Depends(get_current_principal), - **filters, -): - request.state.endpoint = "search" - if entry.structure_family != StructureFamily.container: - raise WrongTypeForRoute("This is not a Node; it cannot be searched or listed.") - try: - resource, metadata_stale_at, must_revalidate = await construct_entries_response( - query_registry, - entry, - "/search", - path, - offset, - limit, - fields, - select_metadata, - omit_links, - include_data_sources, - filters, - sort, - get_base_url(request), - resolve_media_type(request), - max_depth=max_depth, - ) - # We only get one Expires header, so if different parts - # of this response become stale at different times, we - # cite the earliest one. - entries_stale_at = getattr(entry, "entries_stale_at", None) - headers = {} - if (metadata_stale_at is None) or (entries_stale_at is None): - expires = None - else: - expires = min(metadata_stale_at, entries_stale_at) - if must_revalidate: - headers["Cache-Control"] = "must-revalidate" - return json_or_msgpack( - request, - resource.model_dump(), - expires=expires, - headers=headers, - ) - except NoEntry: - raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="No such entry.") - except WrongTypeForRoute as err: - raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail=err.args[0]) - except JMESPathError as err: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=f"Malformed 'select_metadata' parameter raised JMESPathError: {err}", - ) +def get_router( + query_registry: QueryRegistry, + serialization_registry: SerializationRegistry, + deserialization_registry: SerializationRegistry, + validation_registry: ValidationRegistry, +) -> APIRouter: + router = APIRouter() + + T = TypeVar("T") + + def patch_route_signature(route: Callable[..., T]) -> Callable[..., T]: + """ + This is done dynamically at router startup. 
+ + We check the registry of known search query types, which is user + configurable, and use that to define the allowed HTTP query parameters for + this route. + + Take a route that accept unspecified search queries as **filters. + Return a wrapped version of the route that has the supported + search queries explicitly spelled out in the function signature. + + This has no change in the actual behavior of the function, + but it enables FastAPI to generate good OpenAPI documentation + showing the supported search queries. + + """ + + # Build a wrapper so that we can modify the signature + # without mutating the wrapped original. + + async def route_with_sig(*args, **kwargs): + return await route(*args, **kwargs) + + # Black magic here! FastAPI bases its validation and auto-generated swagger + # documentation on the signature of the route function. We do not know what + # that signature should be at compile-time. We only know it once we have a + # chance to check the user-configurable registry of query types. Therefore, + # we modify the signature here, at runtime, just before handing it to + # FastAPI in the usual way. + + # When FastAPI calls the function with these added parameters, they will be + # accepted via **filters. + + # Make a copy of the original parameters. + signature = inspect.signature(route) + parameters = list(signature.parameters.values()) + # Drop the **filters parameter from the signature. + del parameters[-1] + # Add a parameter for each field in each type of query. + for name, query in query_registry.name_to_query_type.items(): + for field in dataclasses.fields(query): + # The structured "alias" here is based on + # https://mglaman.dev/blog/using-json-router-query-your-search-router-indexes + if getattr(field.type, "__origin__", None) is list: + field_type = str + else: + field_type = field.type + injected_parameter = inspect.Parameter( + name=f"filter___{name}___{field.name}", + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=Query( + None, alias=f"filter[{name}][condition][{field.name}]" + ), + annotation=Optional[List[field_type]], + ) + parameters.append(injected_parameter) + route_with_sig.__signature__ = signature.replace(parameters=parameters) + # End black magic + return route_with_sig -async def distinct( - request: Request, - structure_families: bool = False, - specs: bool = False, - metadata: Optional[List[str]] = Query(default=[]), - counts: bool = False, - entry: Any = SecureEntry(scopes=["read:metadata"]), - query_registry=Depends(get_query_registry), - **filters, -): - if hasattr(entry, "get_distinct"): - filtered = await apply_search(entry, filters, query_registry) - distinct = await ensure_awaitable( - filtered.get_distinct, metadata, structure_families, specs, counts - ) + @router.get("/", response_model=About) + async def about( + request: Request, + settings: Settings = Depends(get_settings), + authenticators=Depends(get_authenticators), + ): + # TODO The lazy import of entry modules and serializers means that the + # lists of formats are not populated until they are first used. Not very + # helpful for discovery! The registration can be made non-lazy, while the + # imports of the underlying I/O libraries themselves (openpyxl, pillow, + # etc.) can remain lazy. + request.state.endpoint = "about" + base_url = get_base_url(request) + authentication = { + "required": not settings.allow_anonymous_access, + } + provider_specs = [] + user_agent = request.headers.get("user-agent", "") + # The name of the "internal" mode used to be "password". 
+ # This ensures back-compat with older Python clients. + internal_mode_name = "internal" + MINIMUM_INTERNAL_PYTHON_CLIENT_VERSION = packaging.version.parse("0.1.0b17") + if user_agent.startswith("python-tiled/"): + agent, _, raw_version = user_agent.partition("/") + try: + parsed_version = packaging.version.parse(raw_version) + except Exception: + pass + else: + if parsed_version < MINIMUM_INTERNAL_PYTHON_CLIENT_VERSION: + internal_mode_name = "password" + for provider, authenticator in authenticators.items(): + if isinstance(authenticator, InternalAuthenticator): + spec = { + "provider": provider, + "mode": internal_mode_name, + "links": { + "auth_endpoint": f"{base_url}/auth/provider/{provider}/token" + }, + "confirmation_message": getattr( + authenticator, "confirmation_message", None + ), + } + elif isinstance(authenticator, ExternalAuthenticator): + spec = { + "provider": provider, + "mode": "external", + "links": { + "auth_endpoint": f"{base_url}/auth/provider/{provider}/authorize" + }, + "confirmation_message": getattr( + authenticator, "confirmation_message", None + ), + } + else: + # It should be impossible to reach here. + assert False + provider_specs.append(spec) + if provider_specs: + # If there are *any* authenticaiton providers, these + # endpoints will be added. + authentication["links"] = { + "whoami": f"{base_url}/auth/whoami", + "apikey": f"{base_url}/auth/apikey", + "refresh_session": f"{base_url}/auth/session/refresh", + "revoke_session": f"{base_url}/auth/session/revoke/{{session_id}}", + "logout": f"{base_url}/auth/logout", + } + authentication["providers"] = provider_specs return json_or_msgpack( - request, schemas.GetDistinctResponse.model_validate(distinct).model_dump() - ) - else: - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not support distinct.", + request, + About( + library_version=__version__, + api_version=0, + formats={ + structure_family: list( + serialization_registry.media_types(structure_family) + ) + for structure_family in serialization_registry.structure_families + }, + aliases={ + structure_family: serialization_registry.aliases(structure_family) + for structure_family in serialization_registry.structure_families + }, + queries=list(query_registry.name_to_query_type), + authentication=authentication, + links={ + "self": base_url, + "documentation": f"{base_url}/docs", + }, + meta={"root_path": request.scope.get("root_path") or "" + "/api"}, + ).model_dump(), + expires=datetime.now(timezone.utc) + timedelta(seconds=600), ) - -def patch_route_signature(route, query_registry): - """ - This is done dynamically at router startup. - - We check the registry of known search query types, which is user - configurable, and use that to define the allowed HTTP query parameters for - this route. - - Take a route that accept unspecified search queries as **filters. - Return a wrapped version of the route that has the supported - search queries explicitly spelled out in the function signature. - - This has no change in the actual behavior of the function, - but it enables FastAPI to generate good OpenAPI documentation - showing the supported search queries. - - """ - - # Build a wrapper so that we can modify the signature - # without mutating the wrapped original. - - async def route_with_sig(*args, **kwargs): - return await route(*args, **kwargs) - - # Black magic here! FastAPI bases its validation and auto-generated swagger - # documentation on the signature of the route function. 
We do not know what - # that signature should be at compile-time. We only know it once we have a - # chance to check the user-configurable registry of query types. Therefore, - # we modify the signature here, at runtime, just before handing it to - # FastAPI in the usual way. - - # When FastAPI calls the function with these added parameters, they will be - # accepted via **filters. - - # Make a copy of the original parameters. - signature = inspect.signature(route) - parameters = list(signature.parameters.values()) - # Drop the **filters parameter from the signature. - del parameters[-1] - # Add a parameter for each field in each type of query. - for name, query in query_registry.name_to_query_type.items(): - for field in dataclasses.fields(query): - # The structured "alias" here is based on - # https://mglaman.dev/blog/using-json-router-query-your-search-router-indexes - if getattr(field.type, "__origin__", None) is list: - field_type = str + @router.get( + "/api/v1/search/{path:path}", + response_model=schemas.Response[ + List[schemas.Resource[schemas.NodeAttributes, dict, dict]], + schemas.PaginationLinks, + dict, + ], + ) + @patch_route_signature + async def search( + request: Request, + path: str, + fields: Optional[List[schemas.EntryFields]] = Query(list(schemas.EntryFields)), + select_metadata: Optional[str] = Query(None), + offset: Optional[int] = Query(0, alias="page[offset]", ge=0), + limit: Optional[int] = Query( + DEFAULT_PAGE_SIZE, alias="page[limit]", ge=0, le=MAX_PAGE_SIZE + ), + sort: Optional[str] = Query(None), + max_depth: Optional[int] = Query(None, ge=0, le=DEPTH_LIMIT), + omit_links: bool = Query(False), + include_data_sources: bool = Query(False), + entry: Any = SecureEntry(scopes=["read:metadata"]), + principal: str = Depends(get_current_principal), + **filters, + ): + request.state.endpoint = "search" + if entry.structure_family != StructureFamily.container: + raise WrongTypeForRoute( + "This is not a Node; it cannot be searched or listed." + ) + try: + ( + resource, + metadata_stale_at, + must_revalidate, + ) = await construct_entries_response( + query_registry, + entry, + "/search", + path, + offset, + limit, + fields, + select_metadata, + omit_links, + include_data_sources, + filters, + sort, + get_base_url(request), + resolve_media_type(request), + max_depth=max_depth, + ) + # We only get one Expires header, so if different parts + # of this response become stale at different times, we + # cite the earliest one. 
+ entries_stale_at = getattr(entry, "entries_stale_at", None) + headers = {} + if (metadata_stale_at is None) or (entries_stale_at is None): + expires = None else: - field_type = field.type - injected_parameter = inspect.Parameter( - name=f"filter___{name}___{field.name}", - kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, - default=Query(None, alias=f"filter[{name}][condition][{field.name}]"), - annotation=Optional[List[field_type]], + expires = min(metadata_stale_at, entries_stale_at) + if must_revalidate: + headers["Cache-Control"] = "must-revalidate" + return json_or_msgpack( + request, + resource.model_dump(), + expires=expires, + headers=headers, + ) + except NoEntry: + raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="No such entry.") + except WrongTypeForRoute as err: + raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail=err.args[0]) + except JMESPathError as err: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=f"Malformed 'select_metadata' parameter raised JMESPathError: {err}", ) - parameters.append(injected_parameter) - route_with_sig.__signature__ = signature.replace(parameters=parameters) - # End black magic - - return route_with_sig + @router.get( + "/api/v1/distinct/{path:path}", + response_model=schemas.GetDistinctResponse, + ) + @patch_route_signature + async def distinct( + request: Request, + structure_families: bool = False, + specs: bool = False, + metadata: Optional[List[str]] = Query(default=[]), + counts: bool = False, + entry: Any = SecureEntry(scopes=["read:metadata"]), + **filters, + ): + if hasattr(entry, "get_distinct"): + filtered = await apply_search(entry, filters, query_registry) + distinct = await ensure_awaitable( + filtered.get_distinct, metadata, structure_families, specs, counts + ) -@router.get( - "/metadata/{path:path}", - response_model=schemas.Response[ - schemas.Resource[schemas.NodeAttributes, dict, dict], dict, dict - ], -) -async def metadata( - request: Request, - path: str, - fields: Optional[List[schemas.EntryFields]] = Query(list(schemas.EntryFields)), - select_metadata: Optional[str] = Query(None), - max_depth: Optional[int] = Query(None, ge=0, le=DEPTH_LIMIT), - omit_links: bool = Query(False), - include_data_sources: bool = Query(False), - entry: Any = SecureEntry(scopes=["read:metadata"]), - root_path: bool = Query(False), -): - """Fetch the metadata and structure information for one entry""" - - request.state.endpoint = "metadata" - base_url = get_base_url(request) - path_parts = [segment for segment in path.split("/") if segment] - try: - resource = await construct_resource( - base_url, - path_parts, - entry, - fields, - select_metadata, - omit_links, - include_data_sources, - resolve_media_type(request), - max_depth=max_depth, - ) - except JMESPathError as err: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=f"Malformed 'select_metadata' parameter raised JMESPathError: {err}", - ) - meta = {"root_path": request.scope.get("root_path") or "/"} if root_path else {} + return json_or_msgpack( + request, + schemas.GetDistinctResponse.model_validate(distinct).model_dump(), + ) + else: + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node does not support distinct.", + ) - return json_or_msgpack( - request, - schemas.Response(data=resource, meta=meta).model_dump(), - expires=getattr(entry, "metadata_stale_at", None), + @router.get( + "/metadata/{path:path}", + response_model=schemas.Response[ + schemas.Resource[schemas.NodeAttributes, dict, dict], dict, 
dict + ], ) + async def metadata( + request: Request, + path: str, + fields: Optional[List[schemas.EntryFields]] = Query(list(schemas.EntryFields)), + select_metadata: Optional[str] = Query(None), + max_depth: Optional[int] = Query(None, ge=0, le=DEPTH_LIMIT), + omit_links: bool = Query(False), + include_data_sources: bool = Query(False), + entry: Any = SecureEntry(scopes=["read:metadata"]), + root_path: bool = Query(False), + ): + """Fetch the metadata and structure information for one entry""" + request.state.endpoint = "metadata" + base_url = get_base_url(request) + path_parts = [segment for segment in path.split("/") if segment] + try: + resource = await construct_resource( + base_url, + path_parts, + entry, + fields, + select_metadata, + omit_links, + include_data_sources, + resolve_media_type(request), + max_depth=max_depth, + ) + except JMESPathError as err: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=f"Malformed 'select_metadata' parameter raised JMESPathError: {err}", + ) + meta = {"root_path": request.scope.get("root_path") or "/"} if root_path else {} -@router.get( - "/array/block/{path:path}", response_model=schemas.Response, name="array block" -) -async def array_block( - request: Request, - entry=SecureEntry( - scopes=["read:data"], - structure_families={StructureFamily.array, StructureFamily.sparse}, - ), - block=Depends(block), - slice=Depends(slice_), - expected_shape=Depends(expected_shape), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), - settings: Settings = Depends(get_settings), -): - """ - Fetch a chunk of array-like data. - """ - shape = entry.structure().shape - # Check that block dimensionality matches array dimensionality. - ndim = len(shape) - if len(block) != ndim: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=( - f"Block parameter must have {ndim} comma-separated parameters, " - f"corresponding to the dimensions of this {ndim}-dimensional array." - ), + return json_or_msgpack( + request, + schemas.Response(data=resource, meta=meta).model_dump(), + expires=getattr(entry, "metadata_stale_at", None), ) - if block == (): - # Handle special case of numpy scalar. - if shape != (): + + @router.get( + "/array/block/{path:path}", response_model=schemas.Response, name="array block" + ) + async def array_block( + request: Request, + entry=SecureEntry( + scopes=["read:data"], + structure_families={StructureFamily.array, StructureFamily.sparse}, + ), + block=Depends(block), + slice=Depends(slice_), + expected_shape=Depends(expected_shape), + format: Optional[str] = None, + filename: Optional[str] = None, + settings: Settings = Depends(get_settings), + ): + """ + Fetch a chunk of array-like data. + """ + shape = entry.structure().shape + # Check that block dimensionality matches array dimensionality. + ndim = len(shape) + if len(block) != ndim: raise HTTPException( status_code=HTTP_400_BAD_REQUEST, - detail=f"Requested scalar but shape is {entry.structure().shape}", + detail=( + f"Block parameter must have {ndim} comma-separated parameters, " + f"corresponding to the dimensions of this {ndim}-dimensional array." + ), + ) + if block == (): + # Handle special case of numpy scalar. 
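+            # A sketch of the two cases handled below (derived from this
+            # function alone): an empty block tuple addresses a
+            # zero-dimensional (scalar) array, so there is no block index to
+            # look up and the entry is read whole; if the underlying array is
+            # not zero-dimensional, an empty block parameter is rejected with
+            # a 400 response.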
+ if shape != (): + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=f"Requested scalar but shape is {entry.structure().shape}", + ) + with record_timing(request.state.metrics, "read"): + array = await ensure_awaitable(entry.read) + else: + try: + with record_timing(request.state.metrics, "read"): + array = await ensure_awaitable(entry.read_block, block, slice) + except IndexError: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, detail="Block index out of range" + ) + if (expected_shape is not None) and (expected_shape != array.shape): + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=f"The expected_shape {expected_shape} does not match the actual shape {array.shape}", + ) + if array.nbytes > settings.response_bytesize_limit: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=( + f"Response would exceed {settings.response_bytesize_limit}. " + "Use slicing ('?slice=...') to request smaller chunks." + ), ) - with record_timing(request.state.metrics, "read"): - array = await ensure_awaitable(entry.read) - else: + try: + with record_timing(request.state.metrics, "pack"): + return await construct_data_response( + entry.structure_family, + serialization_registry, + array, + entry.metadata(), + request, + format, + specs=getattr(entry, "specs", []), + expires=getattr(entry, "content_stale_at", None), + filename=filename, + ) + except UnsupportedMediaTypes as err: + # raise HTTPException(status_code=406, detail=", ".join(err.supported)) + raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) + + @router.get( + "/array/full/{path:path}", response_model=schemas.Response, name="full array" + ) + async def array_full( + request: Request, + entry=SecureEntry( + scopes=["read:data"], + structure_families={StructureFamily.array, StructureFamily.sparse}, + ), + slice=Depends(slice_), + expected_shape=Depends(expected_shape), + format: Optional[str] = None, + filename: Optional[str] = None, + settings: Settings = Depends(get_settings), + ): + """ + Fetch a slice of array-like data. + """ + structure_family = entry.structure_family + # Deferred import because this is not a required dependency of the server + # for some use cases. + import numpy + try: with record_timing(request.state.metrics, "read"): - array = await ensure_awaitable(entry.read_block, block, slice) + array = await ensure_awaitable(entry.read, slice) + if structure_family == StructureFamily.array: + array = numpy.asarray(array) # Force dask or PIMS or ... to do I/O. except IndexError: raise HTTPException( status_code=HTTP_400_BAD_REQUEST, detail="Block index out of range" @@ -424,1330 +505,1224 @@ async def array_block( status_code=HTTP_400_BAD_REQUEST, detail=f"The expected_shape {expected_shape} does not match the actual shape {array.shape}", ) - if array.nbytes > settings.response_bytesize_limit: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=( - f"Response would exceed {settings.response_bytesize_limit}. " - "Use slicing ('?slice=...') to request smaller chunks." 
- ), - ) - try: - with record_timing(request.state.metrics, "pack"): - return await construct_data_response( - entry.structure_family, - serialization_registry, - array, - entry.metadata(), - request, - format, - specs=getattr(entry, "specs", []), - expires=getattr(entry, "content_stale_at", None), - filename=filename, + if array.nbytes > settings.response_bytesize_limit: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=( + f"Response would exceed {settings.response_bytesize_limit}. " + "Use slicing ('?slice=...') to request smaller chunks." + ), ) - except UnsupportedMediaTypes as err: - # raise HTTPException(status_code=406, detail=", ".join(err.supported)) - raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) - + try: + with record_timing(request.state.metrics, "pack"): + return await construct_data_response( + structure_family, + serialization_registry, + array, + entry.metadata(), + request, + format, + specs=getattr(entry, "specs", []), + expires=getattr(entry, "content_stale_at", None), + filename=filename, + ) + except UnsupportedMediaTypes as err: + raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) -@router.get( - "/array/full/{path:path}", response_model=schemas.Response, name="full array" -) -async def array_full( - request: Request, - entry=SecureEntry( - scopes=["read:data"], - structure_families={StructureFamily.array, StructureFamily.sparse}, - ), - slice=Depends(slice_), - expected_shape=Depends(expected_shape), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), - settings: Settings = Depends(get_settings), -): - """ - Fetch a slice of array-like data. - """ - structure_family = entry.structure_family - # Deferred import because this is not a required dependency of the server - # for some use cases. - import numpy - - try: - with record_timing(request.state.metrics, "read"): - array = await ensure_awaitable(entry.read, slice) - if structure_family == StructureFamily.array: - array = numpy.asarray(array) # Force dask or PIMS or ... to do I/O. - except IndexError: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, detail="Block index out of range" - ) - if (expected_shape is not None) and (expected_shape != array.shape): - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=f"The expected_shape {expected_shape} does not match the actual shape {array.shape}", - ) - if array.nbytes > settings.response_bytesize_limit: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=( - f"Response would exceed {settings.response_bytesize_limit}. " - "Use slicing ('?slice=...') to request smaller chunks." 
- ), - ) - try: - with record_timing(request.state.metrics, "pack"): - return await construct_data_response( - structure_family, - serialization_registry, - array, - entry.metadata(), - request, - format, - specs=getattr(entry, "specs", []), - expires=getattr(entry, "content_stale_at", None), - filename=filename, + @router.get( + "/table/partition/{path:path}", + response_model=schemas.Response, + name="table partition", + ) + async def get_table_partition( + request: Request, + partition: int, + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.table} + ), + column: Optional[List[str]] = Query(None, min_length=1), + field: Optional[List[str]] = Query(None, min_length=1, deprecated=True), + format: Optional[str] = None, + filename: Optional[str] = None, + settings: Settings = Depends(get_settings), + ): + """ + Fetch a partition (continuous block of rows) from a DataFrame [GET route]. + """ + if (field is not None) and (column is not None): + redundant_field_and_column = " ".join( + ( + "Cannot accept both 'column' and 'field' query parameters", + "in the same /table/partition request.", + "Include these query values using only the 'column' parameter.", + ) ) - except UnsupportedMediaTypes as err: - raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) - - -@router.get( - "/table/partition/{path:path}", - response_model=schemas.Response, - name="table partition", -) -async def get_table_partition( - request: Request, - partition: int, - entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}), - column: Optional[List[str]] = Query(None, min_length=1), - field: Optional[List[str]] = Query(None, min_length=1, deprecated=True), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), - settings: Settings = Depends(get_settings), -): - """ - Fetch a partition (continuous block of rows) from a DataFrame [GET route]. 
- """ - if (field is not None) and (column is not None): - redundant_field_and_column = " ".join( - ( - "Cannot accept both 'column' and 'field' query parameters", - "in the same /table/partition request.", - "Include these query values using only the 'column' parameter.", + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, detail=redundant_field_and_column ) - ) - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, detail=redundant_field_and_column - ) - elif field is not None: - field_is_deprecated = " ".join( - ( - "Query parameter 'field' is deprecated for the /table/partition route.", - "Instead use the query parameter 'column'.", + elif field is not None: + field_is_deprecated = " ".join( + ( + "Query parameter 'field' is deprecated for the /table/partition route.", + "Instead use the query parameter 'column'.", + ) ) + warnings.warn(field_is_deprecated, DeprecationWarning) + return await table_partition( + request=request, + partition=partition, + entry=entry, + column=(column or field), + format=format, + filename=filename, + serialization_registry=serialization_registry, + settings=settings, ) - warnings.warn(field_is_deprecated, DeprecationWarning) - return await table_partition( - request=request, - partition=partition, - entry=entry, - column=(column or field), - format=format, - filename=filename, - serialization_registry=serialization_registry, - settings=settings, - ) - -@router.post( - "/table/partition/{path:path}", - response_model=schemas.Response, - name="table partition", -) -async def post_table_partition( - request: Request, - partition: int, - entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}), - column: Optional[List[str]] = Body(None, min_length=1), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), - settings: Settings = Depends(get_settings), -): - """ - Fetch a partition (continuous block of rows) from a DataFrame [POST route]. - """ - return await table_partition( - request=request, - partition=partition, - entry=entry, - column=column, - format=format, - filename=filename, - serialization_registry=serialization_registry, - settings=settings, + @router.post( + "/table/partition/{path:path}", + response_model=schemas.Response, + name="table partition", ) - - -async def table_partition( - request: Request, - partition: int, - entry, - column: Optional[List[str]], - format: Optional[str], - filename: Optional[str], - serialization_registry, - settings: Settings, -): - """ - Fetch a partition (continuous block of rows) from a DataFrame. - """ - try: - # The singular/plural mismatch here of "fields" and "field" is - # due to the ?field=A&field=B&field=C... encodes in a URL. - with record_timing(request.state.metrics, "read"): - df = await ensure_awaitable(entry.read_partition, partition, column) - except IndexError: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, detail="Partition out of range" - ) - except KeyError as err: - (key,) = err.args - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, detail=f"No such field {key}." - ) - if df.memory_usage().sum() > settings.response_bytesize_limit: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=( - f"Response would exceed {settings.response_bytesize_limit}. " - "Select a subset of the columns ('?field=...') to " - "request a smaller chunks." 
- ), + async def post_table_partition( + request: Request, + partition: int, + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.table} + ), + column: Optional[List[str]] = Body(None, min_length=1), + format: Optional[str] = None, + filename: Optional[str] = None, + settings: Settings = Depends(get_settings), + ): + """ + Fetch a partition (continuous block of rows) from a DataFrame [POST route]. + """ + return await table_partition( + request=request, + partition=partition, + entry=entry, + column=column, + format=format, + filename=filename, + serialization_registry=serialization_registry, + settings=settings, ) - try: - with record_timing(request.state.metrics, "pack"): - return await construct_data_response( - StructureFamily.table, - serialization_registry, - df, - entry.metadata(), - request, - format, - specs=getattr(entry, "specs", []), - expires=getattr(entry, "content_stale_at", None), - filename=filename, - ) - except UnsupportedMediaTypes as err: - raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) + async def table_partition( + request: Request, + partition: int, + entry, + column: Optional[List[str]], + format: Optional[str], + filename: Optional[str], + serialization_registry, + settings: Settings, + ): + """ + Fetch a partition (continuous block of rows) from a DataFrame. + """ + try: + # The singular/plural mismatch here of "fields" and "field" is + # due to the ?field=A&field=B&field=C... encodes in a URL. + with record_timing(request.state.metrics, "read"): + df = await ensure_awaitable(entry.read_partition, partition, column) + except IndexError: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, detail="Partition out of range" + ) + except KeyError as err: + (key,) = err.args + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, detail=f"No such field {key}." + ) + if df.memory_usage().sum() > settings.response_bytesize_limit: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=( + f"Response would exceed {settings.response_bytesize_limit}. " + "Select a subset of the columns ('?field=...') to " + "request a smaller chunks." + ), + ) + try: + with record_timing(request.state.metrics, "pack"): + return await construct_data_response( + StructureFamily.table, + serialization_registry, + df, + entry.metadata(), + request, + format, + specs=getattr(entry, "specs", []), + expires=getattr(entry, "content_stale_at", None), + filename=filename, + ) + except UnsupportedMediaTypes as err: + raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) -@router.get( - "/table/full/{path:path}", - response_model=schemas.Response, - name="full 'table' data", -) -async def get_table_full( - request: Request, - entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}), - column: Optional[List[str]] = Query(None, min_length=1), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), - settings: Settings = Depends(get_settings), -): - """ - Fetch the data for the given table [GET route]. 
- """ - return await table_full( - request=request, - entry=entry, - column=column, - format=format, - filename=filename, - serialization_registry=serialization_registry, - settings=settings, + @router.get( + "/table/full/{path:path}", + response_model=schemas.Response, + name="full 'table' data", ) + async def get_table_full( + request: Request, + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.table} + ), + column: Optional[List[str]] = Query(None, min_length=1), + format: Optional[str] = None, + filename: Optional[str] = None, + settings: Settings = Depends(get_settings), + ): + """ + Fetch the data for the given table [GET route]. + """ + return await table_full( + request=request, + entry=entry, + column=column, + format=format, + filename=filename, + serialization_registry=serialization_registry, + settings=settings, + ) - -@router.post( - "/table/full/{path:path}", - response_model=schemas.Response, - name="full 'table' data", -) -async def post_table_full( - request: Request, - entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}), - column: Optional[List[str]] = Body(None, min_length=1), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), - settings: Settings = Depends(get_settings), -): - """ - Fetch the data for the given table [POST route]. - """ - return await table_full( - request=request, - entry=entry, - column=column, - format=format, - filename=filename, - serialization_registry=serialization_registry, - settings=settings, + @router.post( + "/table/full/{path:path}", + response_model=schemas.Response, + name="full 'table' data", ) - - -async def table_full( - request: Request, - entry, - column: Optional[List[str]], - format: Optional[str], - filename: Optional[str], - serialization_registry, - settings: Settings, -): - """ - Fetch the data for the given table. - """ - try: - with record_timing(request.state.metrics, "read"): - data = await ensure_awaitable(entry.read, column) - except KeyError as err: - (key,) = err.args - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, detail=f"No such field {key}." - ) - if data.memory_usage().sum() > settings.response_bytesize_limit: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=( - f"Response would exceed {settings.response_bytesize_limit}. " - "Select a subset of the columns to " - "request a smaller chunks." - ), + async def post_table_full( + request: Request, + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.table} + ), + column: Optional[List[str]] = Body(None, min_length=1), + format: Optional[str] = None, + filename: Optional[str] = None, + settings: Settings = Depends(get_settings), + ): + """ + Fetch the data for the given table [POST route]. 
+ """ + return await table_full( + request=request, + entry=entry, + column=column, + format=format, + filename=filename, + serialization_registry=serialization_registry, + settings=settings, ) - try: - with record_timing(request.state.metrics, "pack"): - return await construct_data_response( - entry.structure_family, - serialization_registry, - data, - entry.metadata(), - request, - format, - specs=getattr(entry, "specs", []), - expires=getattr(entry, "content_stale_at", None), - filename=filename, - filter_for_access=None, - ) - except UnsupportedMediaTypes as err: - raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) + async def table_full( + request: Request, + entry, + column: Optional[List[str]], + format: Optional[str], + filename: Optional[str], + serialization_registry, + settings: Settings, + ): + """ + Fetch the data for the given table. + """ + try: + with record_timing(request.state.metrics, "read"): + data = await ensure_awaitable(entry.read, column) + except KeyError as err: + (key,) = err.args + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, detail=f"No such field {key}." + ) + if data.memory_usage().sum() > settings.response_bytesize_limit: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=( + f"Response would exceed {settings.response_bytesize_limit}. " + "Select a subset of the columns to " + "request a smaller chunks." + ), + ) + try: + with record_timing(request.state.metrics, "pack"): + return await construct_data_response( + entry.structure_family, + serialization_registry, + data, + entry.metadata(), + request, + format, + specs=getattr(entry, "specs", []), + expires=getattr(entry, "content_stale_at", None), + filename=filename, + filter_for_access=None, + ) + except UnsupportedMediaTypes as err: + raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) -@router.get( - "/container/full/{path:path}", - response_model=schemas.Response, - name="full 'container' metadata and data", -) -async def get_container_full( - request: Request, - entry=SecureEntry( - scopes=["read:data"], structure_families={StructureFamily.container} - ), - principal: str = Depends(get_current_principal), - field: Optional[List[str]] = Query(None, min_length=1), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), -): - """ - Fetch the data for the given container via a GET request. - """ - return await container_full( - request=request, - entry=entry, - principal=principal, - field=field, - format=format, - filename=filename, - serialization_registry=serialization_registry, + @router.get( + "/container/full/{path:path}", + response_model=schemas.Response, + name="full 'container' metadata and data", ) + async def get_container_full( + request: Request, + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.container} + ), + principal: str = Depends(get_current_principal), + field: Optional[List[str]] = Query(None, min_length=1), + format: Optional[str] = None, + filename: Optional[str] = None, + ): + """ + Fetch the data for the given container via a GET request. 
+ """ + return await container_full( + request=request, + entry=entry, + principal=principal, + field=field, + format=format, + filename=filename, + serialization_registry=serialization_registry, + ) - -@router.post( - "/container/full/{path:path}", - response_model=schemas.Response, - name="full 'container' metadata and data", -) -async def post_container_full( - request: Request, - entry=SecureEntry( - scopes=["read:data"], structure_families={StructureFamily.container} - ), - principal: str = Depends(get_current_principal), - field: Optional[List[str]] = Body(None, min_length=1), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), -): - """ - Fetch the data for the given container via a POST request. - """ - return await container_full( - request=request, - entry=entry, - principal=principal, - field=field, - format=format, - filename=filename, - serialization_registry=serialization_registry, + @router.post( + "/container/full/{path:path}", + response_model=schemas.Response, + name="full 'container' metadata and data", ) - - -async def container_full( - request: Request, - entry, - principal: str, - field: Optional[List[str]], - format: Optional[str], - filename: Optional[str], - serialization_registry, -): - """ - Fetch the data for the given container. - """ - try: - with record_timing(request.state.metrics, "read"): - data = await ensure_awaitable(entry.read, fields=field) - except KeyError as err: - (key,) = err.args - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, detail=f"No such field {key}." + async def post_container_full( + request: Request, + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.container} + ), + principal: str = Depends(get_current_principal), + field: Optional[List[str]] = Body(None, min_length=1), + format: Optional[str] = None, + filename: Optional[str] = None, + ): + """ + Fetch the data for the given container via a POST request. + """ + return await container_full( + request=request, + entry=entry, + principal=principal, + field=field, + format=format, + filename=filename, + serialization_registry=serialization_registry, ) - curried_filter = partial( - filter_for_access, - principal=principal, - scopes=["read:data"], - metrics=request.state.metrics, - ) - # TODO Walk node to determine size before handing off to serializer. - try: - with record_timing(request.state.metrics, "pack"): - return await construct_data_response( - entry.structure_family, - serialization_registry, - data, - entry.metadata(), - request, - format, - specs=getattr(entry, "specs", []), - expires=getattr(entry, "content_stale_at", None), - filename=filename, - filter_for_access=curried_filter, - ) - except UnsupportedMediaTypes as err: - raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) - -@router.get( - "/node/full/{path:path}", - response_model=schemas.Response, - name="full 'container' or 'table'", - deprecated=True, -) -async def node_full( - request: Request, - entry=SecureEntry( - scopes=["read:data"], - structure_families={StructureFamily.table, StructureFamily.container}, - ), - principal: str = Depends(get_current_principal), - field: Optional[List[str]] = Query(None, min_length=1), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), - settings: Settings = Depends(get_settings), -): - """ - Fetch the data below the given node. 
- """ - try: - with record_timing(request.state.metrics, "read"): - data = await ensure_awaitable(entry.read, field) - except KeyError as err: - (key,) = err.args - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, detail=f"No such field {key}." - ) - if (entry.structure_family == StructureFamily.table) and ( - data.memory_usage().sum() > settings.response_bytesize_limit + async def container_full( + request: Request, + entry, + principal: str, + field: Optional[List[str]], + format: Optional[str], + filename: Optional[str], + serialization_registry, ): - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=( - f"Response would exceed {settings.response_bytesize_limit}. " - "Select a subset of the columns ('?field=...') to " - "request a smaller chunks." - ), - ) - if entry.structure_family == StructureFamily.container: + """ + Fetch the data for the given container. + """ + try: + with record_timing(request.state.metrics, "read"): + data = await ensure_awaitable(entry.read, fields=field) + except KeyError as err: + (key,) = err.args + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, detail=f"No such field {key}." + ) curried_filter = partial( filter_for_access, principal=principal, scopes=["read:data"], metrics=request.state.metrics, ) - else: - curried_filter = None # TODO Walk node to determine size before handing off to serializer. - try: - with record_timing(request.state.metrics, "pack"): - return await construct_data_response( - entry.structure_family, - serialization_registry, - data, - entry.metadata(), - request, - format, - specs=getattr(entry, "specs", []), - expires=getattr(entry, "content_stale_at", None), - filename=filename, - filter_for_access=curried_filter, + try: + with record_timing(request.state.metrics, "pack"): + return await construct_data_response( + entry.structure_family, + serialization_registry, + data, + entry.metadata(), + request, + format, + specs=getattr(entry, "specs", []), + expires=getattr(entry, "content_stale_at", None), + filename=filename, + filter_for_access=curried_filter, + ) + except UnsupportedMediaTypes as err: + raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) + + @router.get( + "/node/full/{path:path}", + response_model=schemas.Response, + name="full 'container' or 'table'", + deprecated=True, + ) + async def node_full( + request: Request, + entry=SecureEntry( + scopes=["read:data"], + structure_families={StructureFamily.table, StructureFamily.container}, + ), + principal: str = Depends(get_current_principal), + field: Optional[List[str]] = Query(None, min_length=1), + format: Optional[str] = None, + filename: Optional[str] = None, + settings: Settings = Depends(get_settings), + ): + """ + Fetch the data below the given node. + """ + try: + with record_timing(request.state.metrics, "read"): + data = await ensure_awaitable(entry.read, field) + except KeyError as err: + (key,) = err.args + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, detail=f"No such field {key}." ) - except UnsupportedMediaTypes as err: - raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) - + if (entry.structure_family == StructureFamily.table) and ( + data.memory_usage().sum() > settings.response_bytesize_limit + ): + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=( + f"Response would exceed {settings.response_bytesize_limit}. " + "Select a subset of the columns ('?field=...') to " + "request a smaller chunks." 
+ ), + ) + if entry.structure_family == StructureFamily.container: + curried_filter = partial( + filter_for_access, + principal=principal, + scopes=["read:data"], + metrics=request.state.metrics, + ) + else: + curried_filter = None + # TODO Walk node to determine size before handing off to serializer. + try: + with record_timing(request.state.metrics, "pack"): + return await construct_data_response( + entry.structure_family, + serialization_registry, + data, + entry.metadata(), + request, + format, + specs=getattr(entry, "specs", []), + expires=getattr(entry, "content_stale_at", None), + filename=filename, + filter_for_access=curried_filter, + ) + except UnsupportedMediaTypes as err: + raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) -@router.get( - "/awkward/buffers/{path:path}", - response_model=schemas.Response, - name="AwkwardArray buffers", -) -async def get_awkward_buffers( - request: Request, - entry=SecureEntry( - scopes=["read:data"], structure_families={StructureFamily.awkward} - ), - form_key: Optional[List[str]] = Query(None, min_length=1), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), - settings: Settings = Depends(get_settings), -): - """ - Fetch a slice of AwkwardArray data. - - Note that there is a POST route on this same path with equivalent functionality. - HTTP caches tends to engage with GET but not POST, so that GET route may be - preferred for that reason. However, HTTP clients, servers, and proxies - typically impose a length limit on URLs. (The HTTP spec does not specify - one, but this is a pragmatic measure.) For requests with large numbers of - form_key parameters, POST may be the only option. - """ - return await _awkward_buffers( - request=request, - entry=entry, - form_key=form_key, - format=format, - filename=filename, - serialization_registry=serialization_registry, - settings=settings, + @router.get( + "/awkward/buffers/{path:path}", + response_model=schemas.Response, + name="AwkwardArray buffers", ) + async def get_awkward_buffers( + request: Request, + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.awkward} + ), + form_key: Optional[List[str]] = Query(None, min_length=1), + format: Optional[str] = None, + filename: Optional[str] = None, + settings: Settings = Depends(get_settings), + ): + """ + Fetch a slice of AwkwardArray data. + + Note that there is a POST route on this same path with equivalent functionality. + HTTP caches tends to engage with GET but not POST, so that GET route may be + preferred for that reason. However, HTTP clients, servers, and proxies + typically impose a length limit on URLs. (The HTTP spec does not specify + one, but this is a pragmatic measure.) For requests with large numbers of + form_key parameters, POST may be the only option. 
+ """ + return await _awkward_buffers( + request=request, + entry=entry, + form_key=form_key, + format=format, + filename=filename, + serialization_registry=serialization_registry, + settings=settings, + ) - -@router.post( - "/awkward/buffers/{path:path}", - response_model=schemas.Response, - name="AwkwardArray buffers", -) -async def post_awkward_buffers( - request: Request, - body: List[str], - entry=SecureEntry( - scopes=["read:data"], structure_families={StructureFamily.awkward} - ), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), - settings: Settings = Depends(get_settings), -): - """ - Fetch a slice of AwkwardArray data. - - Note that there is a GET route on this same path with equivalent functionality. - HTTP caches tends to engage with GET but not POST, so that GET route may be - preferred for that reason. However, HTTP clients, servers, and proxies - typically impose a length limit on URLs. (The HTTP spec does not specify - one, but this is a pragmatic measure.) For requests with large numbers of - form_key parameters, POST may be the only option. - """ - return await _awkward_buffers( - request=request, - entry=entry, - form_key=body, - format=format, - filename=filename, - serialization_registry=serialization_registry, - settings=settings, + @router.post( + "/awkward/buffers/{path:path}", + response_model=schemas.Response, + name="AwkwardArray buffers", ) - - -async def _awkward_buffers( - request: Request, - entry, - form_key: Optional[List[str]], - format: Optional[str], - filename: Optional[str], - serialization_registry, - settings: Settings, -): - structure_family = entry.structure_family - structure = entry.structure() - with record_timing(request.state.metrics, "read"): - # The plural vs. singular mismatch is due to the way query parameters - # are given as ?form_key=A&form_key=B&form_key=C. - container = await ensure_awaitable(entry.read_buffers, form_key) - if ( - sum(len(buffer) for buffer in container.values()) - > settings.response_bytesize_limit + async def post_awkward_buffers( + request: Request, + body: List[str], + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.awkward} + ), + format: Optional[str] = None, + filename: Optional[str] = None, + settings: Settings = Depends(get_settings), ): - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=( - f"Response would exceed {settings.response_bytesize_limit}. " - "Use slicing ('?slice=...') to request smaller chunks." - ), + """ + Fetch a slice of AwkwardArray data. + + Note that there is a GET route on this same path with equivalent functionality. + HTTP caches tends to engage with GET but not POST, so that GET route may be + preferred for that reason. However, HTTP clients, servers, and proxies + typically impose a length limit on URLs. (The HTTP spec does not specify + one, but this is a pragmatic measure.) For requests with large numbers of + form_key parameters, POST may be the only option. 
+ """ + return await _awkward_buffers( + request=request, + entry=entry, + form_key=body, + format=format, + filename=filename, + serialization_registry=serialization_registry, + settings=settings, ) - components = (structure.form, structure.length, container) - try: - with record_timing(request.state.metrics, "pack"): - return await construct_data_response( - structure_family, - serialization_registry, - components, - entry.metadata(), - request, - format, - specs=getattr(entry, "specs", []), - expires=getattr(entry, "content_stale_at", None), - filename=filename, + + async def _awkward_buffers( + request: Request, + entry, + form_key: Optional[List[str]], + format: Optional[str], + filename: Optional[str], + serialization_registry, + settings: Settings, + ): + structure_family = entry.structure_family + structure = entry.structure() + with record_timing(request.state.metrics, "read"): + # The plural vs. singular mismatch is due to the way query parameters + # are given as ?form_key=A&form_key=B&form_key=C. + container = await ensure_awaitable(entry.read_buffers, form_key) + if ( + sum(len(buffer) for buffer in container.values()) + > settings.response_bytesize_limit + ): + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=( + f"Response would exceed {settings.response_bytesize_limit}. " + "Use slicing ('?slice=...') to request smaller chunks." + ), ) - except UnsupportedMediaTypes as err: - raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) + components = (structure.form, structure.length, container) + try: + with record_timing(request.state.metrics, "pack"): + return await construct_data_response( + structure_family, + serialization_registry, + components, + entry.metadata(), + request, + format, + specs=getattr(entry, "specs", []), + expires=getattr(entry, "content_stale_at", None), + filename=filename, + ) + except UnsupportedMediaTypes as err: + raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) + @router.get( + "/awkward/full/{path:path}", + response_model=schemas.Response, + name="Full AwkwardArray", + ) + async def awkward_full( + request: Request, + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.awkward} + ), + # slice=Depends(slice_), + format: Optional[str] = None, + filename: Optional[str] = None, + settings: Settings = Depends(get_settings), + ): + """ + Fetch a slice of AwkwardArray data. + """ + structure_family = entry.structure_family + # Deferred import because this is not a required dependency of the server + # for some use cases. + import awkward -@router.get( - "/awkward/full/{path:path}", - response_model=schemas.Response, - name="Full AwkwardArray", -) -async def awkward_full( - request: Request, - entry=SecureEntry( - scopes=["read:data"], structure_families={StructureFamily.awkward} - ), - # slice=Depends(slice_), - format: Optional[str] = None, - filename: Optional[str] = None, - serialization_registry=Depends(get_serialization_registry), - settings: Settings = Depends(get_settings), -): - """ - Fetch a slice of AwkwardArray data. - """ - structure_family = entry.structure_family - # Deferred import because this is not a required dependency of the server - # for some use cases. 
- import awkward - - with record_timing(request.state.metrics, "read"): - container = await ensure_awaitable(entry.read) - structure = entry.structure() - components = (structure.form, structure.length, container) - array = awkward.from_buffers(*components) - if array.nbytes > settings.response_bytesize_limit: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=( - f"Response would exceed {settings.response_bytesize_limit}. " - "Use slicing ('?slice=...') to request smaller chunks." - ), - ) - try: - with record_timing(request.state.metrics, "pack"): - return await construct_data_response( - structure_family, - serialization_registry, - components, - entry.metadata(), - request, - format, - specs=getattr(entry, "specs", []), - expires=getattr(entry, "content_stale_at", None), - filename=filename, + with record_timing(request.state.metrics, "read"): + container = await ensure_awaitable(entry.read) + structure = entry.structure() + components = (structure.form, structure.length, container) + array = awkward.from_buffers(*components) + if array.nbytes > settings.response_bytesize_limit: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=( + f"Response would exceed {settings.response_bytesize_limit}. " + "Use slicing ('?slice=...') to request smaller chunks." + ), ) - except UnsupportedMediaTypes as err: - raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) - - -@router.post("/metadata/{path:path}", response_model=schemas.PostMetadataResponse) -async def post_metadata( - request: Request, - path: str, - body: schemas.PostMetadataRequest, - validation_registry=Depends(get_validation_registry), - settings: Settings = Depends(get_settings), - entry=SecureEntry(scopes=["write:metadata", "create"]), -): - for data_source in body.data_sources: - if data_source.assets: + try: + with record_timing(request.state.metrics, "pack"): + return await construct_data_response( + structure_family, + serialization_registry, + components, + entry.metadata(), + request, + format, + specs=getattr(entry, "specs", []), + expires=getattr(entry, "content_stale_at", None), + filename=filename, + ) + except UnsupportedMediaTypes as err: + raise HTTPException(status_code=HTTP_406_NOT_ACCEPTABLE, detail=err.args[0]) + + @router.post("/metadata/{path:path}", response_model=schemas.PostMetadataResponse) + async def post_metadata( + request: Request, + path: str, + body: schemas.PostMetadataRequest, + settings: Settings = Depends(get_settings), + entry=SecureEntry(scopes=["write:metadata", "create"]), + ): + for data_source in body.data_sources: + if data_source.assets: + raise HTTPException( + "Externally-managed assets cannot be registered " + "using POST /metadata/{path} Use POST /register/{path} instead." + ) + if body.data_sources and not getattr(entry, "writable", False): raise HTTPException( - "Externally-managed assets cannot be registered " - "using POST /metadata/{path} Use POST /register/{path} instead." 
+ status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail=f"Data cannot be written at the path {path}", ) - if body.data_sources and not getattr(entry, "writable", False): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail=f"Data cannot be written at the path {path}", + return await _create_node( + request=request, + path=path, + body=body, + validation_registry=validation_registry, + settings=settings, + entry=entry, ) - return await _create_node( - request=request, - path=path, - body=body, - validation_registry=validation_registry, - settings=settings, - entry=entry, - ) - - -@router.post("/register/{path:path}", response_model=schemas.PostMetadataResponse) -async def post_register( - request: Request, - path: str, - body: schemas.PostMetadataRequest, - validation_registry=Depends(get_validation_registry), - settings: Settings = Depends(get_settings), - entry=SecureEntry(scopes=["write:metadata", "create", "register"]), -): - return await _create_node( - request=request, - path=path, - body=body, - validation_registry=validation_registry, - settings=settings, - entry=entry, - ) - - -async def _create_node( - request: Request, - path: str, - body: schemas.PostMetadataRequest, - validation_registry, - settings: Settings, - entry, -): - metadata, structure_family, specs = ( - body.metadata, - body.structure_family, - body.specs, - ) - if structure_family == StructureFamily.container: - structure = None - else: - if len(body.data_sources) != 1: - raise NotImplementedError - structure = body.data_sources[0].structure - - metadata_modified, metadata = await validate_metadata( - metadata=metadata, - structure_family=structure_family, - structure=structure, - specs=specs, - validation_registry=validation_registry, - settings=settings, - ) - - key, node = await entry.create_node( - metadata=body.metadata, - structure_family=body.structure_family, - key=body.id, - specs=body.specs, - data_sources=body.data_sources, - ) - links = links_for_node( - structure_family, structure, get_base_url(request), path + f"/{key}" - ) - response_data = { - "id": key, - "links": links, - "data_sources": [ds.model_dump() for ds in node.data_sources], - } - if metadata_modified: - response_data["metadata"] = metadata - - return json_or_msgpack(request, response_data) - - -@router.put("/data_source/{path:path}") -async def put_data_source( - request: Request, - path: str, - data_source: int, - body: schemas.PutDataSourceRequest, - settings: Settings = Depends(get_settings), - entry=SecureEntry(scopes=["write:metadata", "register"]), -): - await entry.put_data_source( - data_source=body.data_source, - ) + @router.post("/register/{path:path}", response_model=schemas.PostMetadataResponse) + async def post_register( + request: Request, + path: str, + body: schemas.PostMetadataRequest, + settings: Settings = Depends(get_settings), + entry=SecureEntry(scopes=["write:metadata", "create", "register"]), + ): + return await _create_node( + request=request, + path=path, + body=body, + validation_registry=validation_registry, + settings=settings, + entry=entry, + ) -@router.delete("/metadata/{path:path}") -async def delete( - request: Request, - entry=SecureEntry(scopes=["write:data", "write:metadata"]), -): - if hasattr(entry, "delete"): - await entry.delete() - else: - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not support deletion.", + async def _create_node( + request: Request, + path: str, + body: schemas.PostMetadataRequest, + validation_registry, + settings: 
Settings, + entry, + ): + metadata, structure_family, specs = ( + body.metadata, + body.structure_family, + body.specs, ) - return json_or_msgpack(request, None) - - -@router.delete("/nodes/{path:path}") -async def bulk_delete( - request: Request, - entry=SecureEntry(scopes=["write:data", "write:metadata"]), -): - if hasattr(entry, "delete_tree"): - await entry.delete_tree() - else: - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not support bulk deletion.", + if structure_family == StructureFamily.container: + structure = None + else: + if len(body.data_sources) != 1: + raise NotImplementedError + structure = body.data_sources[0].structure + + metadata_modified, metadata = await validate_metadata( + metadata=metadata, + structure_family=structure_family, + structure=structure, + specs=specs, + settings=settings, ) - return json_or_msgpack(request, None) - - -@router.put("/array/full/{path:path}") -async def put_array_full( - request: Request, - entry=SecureEntry( - scopes=["write:data"], - structure_families={StructureFamily.array, StructureFamily.sparse}, - ), - deserialization_registry=Depends(get_deserialization_registry), -): - body = await request.body() - if not hasattr(entry, "write"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node cannot accept array data.", + + key, node = await entry.create_node( + metadata=body.metadata, + structure_family=body.structure_family, + key=body.id, + specs=body.specs, + data_sources=body.data_sources, ) - media_type = request.headers["content-type"] - if entry.structure_family == "array": - dtype = entry.structure().data_type.to_numpy_dtype() - shape = entry.structure().shape - deserializer = deserialization_registry.dispatch("array", media_type) - data = await ensure_awaitable(deserializer, body, dtype, shape) - elif entry.structure_family == "sparse": - deserializer = deserialization_registry.dispatch("sparse", media_type) - data = await ensure_awaitable(deserializer, body) - else: - raise NotImplementedError(entry.structure_family) - await ensure_awaitable(entry.write, data) - return json_or_msgpack(request, None) - - -@router.put("/array/block/{path:path}") -async def put_array_block( - request: Request, - entry=SecureEntry( - scopes=["write:data"], - structure_families={StructureFamily.array, StructureFamily.sparse}, - ), - deserialization_registry=Depends(get_deserialization_registry), - block=Depends(block), -): - if not hasattr(entry, "write_block"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node cannot accept array data.", + links = links_for_node( + structure_family, structure, get_base_url(request), path + f"/{key}" + ) + response_data = { + "id": key, + "links": links, + "data_sources": [ds.model_dump() for ds in node.data_sources], + } + if metadata_modified: + response_data["metadata"] = metadata + + return json_or_msgpack(request, response_data) + + @router.put("/data_source/{path:path}") + async def put_data_source( + request: Request, + path: str, + data_source: int, + body: schemas.PutDataSourceRequest, + settings: Settings = Depends(get_settings), + entry=SecureEntry(scopes=["write:metadata", "register"]), + ): + await entry.put_data_source( + data_source=body.data_source, ) - from tiled.adapters.array import slice_and_shape_from_block_and_chunks - body = await request.body() - media_type = request.headers["content-type"] - if entry.structure_family == "array": + @router.delete("/metadata/{path:path}") + async 
def delete( + request: Request, + entry=SecureEntry(scopes=["write:data", "write:metadata"]), + ): + if hasattr(entry, "delete"): + await entry.delete() + else: + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node does not support deletion.", + ) + return json_or_msgpack(request, None) + + @router.delete("/nodes/{path:path}") + async def bulk_delete( + request: Request, + entry=SecureEntry(scopes=["write:data", "write:metadata"]), + ): + if hasattr(entry, "delete_tree"): + await entry.delete_tree() + else: + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node does not support bulk deletion.", + ) + return json_or_msgpack(request, None) + + @router.put("/array/full/{path:path}") + async def put_array_full( + request: Request, + entry=SecureEntry( + scopes=["write:data"], + structure_families={StructureFamily.array, StructureFamily.sparse}, + ), + ): + body = await request.body() + if not hasattr(entry, "write"): + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node cannot accept array data.", + ) + media_type = request.headers["content-type"] + if entry.structure_family == "array": + dtype = entry.structure().data_type.to_numpy_dtype() + shape = entry.structure().shape + deserializer = deserialization_registry.dispatch("array", media_type) + data = await ensure_awaitable(deserializer, body, dtype, shape) + elif entry.structure_family == "sparse": + deserializer = deserialization_registry.dispatch("sparse", media_type) + data = await ensure_awaitable(deserializer, body) + else: + raise NotImplementedError(entry.structure_family) + await ensure_awaitable(entry.write, data) + return json_or_msgpack(request, None) + + @router.put("/array/block/{path:path}") + async def put_array_block( + request: Request, + entry=SecureEntry( + scopes=["write:data"], + structure_families={StructureFamily.array, StructureFamily.sparse}, + ), + block=Depends(block), + ): + if not hasattr(entry, "write_block"): + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node cannot accept array data.", + ) + from tiled.adapters.array import slice_and_shape_from_block_and_chunks + + body = await request.body() + media_type = request.headers["content-type"] + if entry.structure_family == "array": + dtype = entry.structure().data_type.to_numpy_dtype() + _, shape = slice_and_shape_from_block_and_chunks( + block, entry.structure().chunks + ) + deserializer = deserialization_registry.dispatch("array", media_type) + data = await ensure_awaitable(deserializer, body, dtype, shape) + elif entry.structure_family == "sparse": + deserializer = deserialization_registry.dispatch("sparse", media_type) + data = await ensure_awaitable(deserializer, body) + else: + raise NotImplementedError(entry.structure_family) + await ensure_awaitable(entry.write_block, data, block) + return json_or_msgpack(request, None) + + @router.patch("/array/full/{path:path}") + async def patch_array_full( + request: Request, + offset=Depends(offset_param), + shape=Depends(shape_param), + extend: bool = False, + entry=SecureEntry( + scopes=["write:data"], + structure_families={StructureFamily.array}, + ), + ): + if not hasattr(entry, "patch"): + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node cannot accept array data.", + ) + dtype = entry.structure().data_type.to_numpy_dtype() - _, shape = slice_and_shape_from_block_and_chunks( - block, entry.structure().chunks - ) + body = await request.body() + 
media_type = request.headers["content-type"] deserializer = deserialization_registry.dispatch("array", media_type) data = await ensure_awaitable(deserializer, body, dtype, shape) - elif entry.structure_family == "sparse": - deserializer = deserialization_registry.dispatch("sparse", media_type) - data = await ensure_awaitable(deserializer, body) - else: - raise NotImplementedError(entry.structure_family) - await ensure_awaitable(entry.write_block, data, block) - return json_or_msgpack(request, None) - - -@router.patch("/array/full/{path:path}") -async def patch_array_full( - request: Request, - offset=Depends(offset_param), - shape=Depends(shape_param), - extend: bool = False, - entry=SecureEntry( - scopes=["write:data"], - structure_families={StructureFamily.array}, - ), - deserialization_registry=Depends(get_deserialization_registry), -): - if not hasattr(entry, "patch"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node cannot accept array data.", - ) - - dtype = entry.structure().data_type.to_numpy_dtype() - body = await request.body() - media_type = request.headers["content-type"] - deserializer = deserialization_registry.dispatch("array", media_type) - data = await ensure_awaitable(deserializer, body, dtype, shape) - structure = await ensure_awaitable(entry.patch, data, offset, extend) - return json_or_msgpack(request, structure) - - -@router.put("/table/full/{path:path}") -@router.put("/node/full/{path:path}", deprecated=True) -async def put_node_full( - request: Request, - entry=SecureEntry( - scopes=["write:data"], structure_families={StructureFamily.table} - ), - deserialization_registry=Depends(get_deserialization_registry), -): - if not hasattr(entry, "write"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not support writing.", - ) - body = await request.body() - media_type = request.headers["content-type"] - deserializer = deserialization_registry.dispatch(StructureFamily.table, media_type) - data = await ensure_awaitable(deserializer, body) - await ensure_awaitable(entry.write, data) - return json_or_msgpack(request, None) - - -@router.put("/table/partition/{path:path}") -async def put_table_partition( - partition: int, - request: Request, - entry=SecureEntry(scopes=["write:data"]), - deserialization_registry=Depends(get_deserialization_registry), -): - if not hasattr(entry, "write_partition"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not supporting writing a partition.", - ) - body = await request.body() - media_type = request.headers["content-type"] - deserializer = deserialization_registry.dispatch(StructureFamily.table, media_type) - data = await ensure_awaitable(deserializer, body) - await ensure_awaitable(entry.write_partition, data, partition) - return json_or_msgpack(request, None) - - -@router.patch("/table/partition/{path:path}") -async def patch_table_partition( - partition: int, - request: Request, - entry=SecureEntry(scopes=["write:data"]), - deserialization_registry=Depends(get_deserialization_registry), -): - if not hasattr(entry, "write_partition"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not supporting writing a partition.", - ) - body = await request.body() - media_type = request.headers["content-type"] - deserializer = deserialization_registry.dispatch(StructureFamily.table, media_type) - data = await ensure_awaitable(deserializer, body) - await 
ensure_awaitable(entry.append_partition, data, partition)
-    return json_or_msgpack(request, None)
-
-
-@router.put("/awkward/full/{path:path}")
-async def put_awkward_full(
-    request: Request,
-    entry=SecureEntry(
-        scopes=["write:data"], structure_families={StructureFamily.awkward}
-    ),
-    deserialization_registry=Depends(get_deserialization_registry),
-):
-    body = await request.body()
-    if not hasattr(entry, "write"):
-        raise HTTPException(
-            status_code=HTTP_405_METHOD_NOT_ALLOWED,
-            detail="This node cannot be written to.",
+        structure = await ensure_awaitable(entry.patch, data, offset, extend)
+        return json_or_msgpack(request, structure)
+
+    @router.put("/table/full/{path:path}")
+    @router.put("/node/full/{path:path}", deprecated=True)
+    async def put_node_full(
+        request: Request,
+        entry=SecureEntry(
+            scopes=["write:data"], structure_families={StructureFamily.table}
+        ),
+    ):
+        if not hasattr(entry, "write"):
+            raise HTTPException(
+                status_code=HTTP_405_METHOD_NOT_ALLOWED,
+                detail="This node does not support writing.",
+            )
+        body = await request.body()
+        media_type = request.headers["content-type"]
+        deserializer = deserialization_registry.dispatch(
+            StructureFamily.table, media_type
        )
-    media_type = request.headers["content-type"]
-    deserializer = deserialization_registry.dispatch(
-        StructureFamily.awkward, media_type
-    )
-    structure = entry.structure()
-    data = await ensure_awaitable(deserializer, body, structure.form, structure.length)
-    await ensure_awaitable(entry.write, data)
-    return json_or_msgpack(request, None)
-
-
-@router.patch("/metadata/{path:path}", response_model=schemas.PatchMetadataResponse)
-async def patch_metadata(
-    request: Request,
-    body: schemas.PatchMetadataRequest,
-    validation_registry=Depends(get_validation_registry),
-    settings: Settings = Depends(get_settings),
-    entry=SecureEntry(scopes=["write:metadata"]),
-):
-    if not hasattr(entry, "replace_metadata"):
-        raise HTTPException(
-            status_code=HTTP_405_METHOD_NOT_ALLOWED,
-            detail="This node does not support update of metadata.",
+        data = await ensure_awaitable(deserializer, body)
+        await ensure_awaitable(entry.write, data)
+        return json_or_msgpack(request, None)
+
+    @router.put("/table/partition/{path:path}")
+    async def put_table_partition(
+        partition: int,
+        request: Request,
+        entry=SecureEntry(scopes=["write:data"]),
+    ):
+        if not hasattr(entry, "write_partition"):
+            raise HTTPException(
+                status_code=HTTP_405_METHOD_NOT_ALLOWED,
+                detail="This node does not support writing a partition.",
+            )
+        body = await request.body()
+        media_type = request.headers["content-type"]
+        deserializer = deserialization_registry.dispatch(
+            StructureFamily.table, media_type
        )
-    if body.content_type == patch_mimetypes.JSON_PATCH:
-        metadata = apply_json_patch(entry.metadata(), (body.metadata or []))
-        specs = apply_json_patch((entry.specs or []), (body.specs or []))
-    elif body.content_type == patch_mimetypes.MERGE_PATCH:
-        metadata = apply_merge_patch(entry.metadata(), (body.metadata or {}))
-        # body.specs = [] clears specs, as per json merge patch specification
-        # but we treat body.specs = None as "no modifications"
-        current_specs = entry.specs or []
-        target_specs = current_specs if body.specs is None else body.specs
-        specs = apply_merge_patch(current_specs, target_specs)
-    else:
-        raise HTTPException(
-            status_code=HTTP_406_NOT_ACCEPTABLE,
-            detail=f"valid content types: {', '.join(patch_mimetypes)}",
+        data = await ensure_awaitable(deserializer, body)
+        await ensure_awaitable(entry.write_partition, data, partition)
+        return json_or_msgpack(request, None)
+
+    @router.patch("/table/partition/{path:path}")
+    async def patch_table_partition(
+        partition: int,
+        request: Request,
+        entry=SecureEntry(scopes=["write:data"]),
+    ):
+        if not hasattr(entry, "write_partition"):
+            raise HTTPException(
+                status_code=HTTP_405_METHOD_NOT_ALLOWED,
+                detail="This node does not support writing a partition.",
+            )
+        body = await request.body()
+        media_type = request.headers["content-type"]
+        deserializer = deserialization_registry.dispatch(
+            StructureFamily.table, media_type
        )
-
-    # Manually validate limits that bypass pydantic validation via patch
-    if len(specs) > schemas.MAX_ALLOWED_SPECS:
-        raise HTTPException(
-            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
-            detail=f"Update cannot result in more than {schemas.MAX_ALLOWED_SPECS} specs",
+        data = await ensure_awaitable(deserializer, body)
+        await ensure_awaitable(entry.append_partition, data, partition)
+        return json_or_msgpack(request, None)
+
+    @router.put("/awkward/full/{path:path}")
+    async def put_awkward_full(
+        request: Request,
+        entry=SecureEntry(
+            scopes=["write:data"], structure_families={StructureFamily.awkward}
+        ),
+    ):
+        body = await request.body()
+        if not hasattr(entry, "write"):
+            raise HTTPException(
+                status_code=HTTP_405_METHOD_NOT_ALLOWED,
+                detail="This node cannot be written to.",
+            )
+        media_type = request.headers["content-type"]
+        deserializer = deserialization_registry.dispatch(
+            StructureFamily.awkward, media_type
        )
-    if len(specs) != len(set(specs)):
-        raise HTTPException(
-            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
-            detail="Update cannot result in non-unique specs",
+        structure = entry.structure()
+        data = await ensure_awaitable(
+            deserializer, body, structure.form, structure.length
        )
+        await ensure_awaitable(entry.write, data)
+        return json_or_msgpack(request, None)
+
+    @router.patch("/metadata/{path:path}", response_model=schemas.PatchMetadataResponse)
+    async def patch_metadata(
+        request: Request,
+        body: schemas.PatchMetadataRequest,
+        settings: Settings = Depends(get_settings),
+        entry=SecureEntry(scopes=["write:metadata"]),
+    ):
+        if not hasattr(entry, "replace_metadata"):
+            raise HTTPException(
+                status_code=HTTP_405_METHOD_NOT_ALLOWED,
+                detail="This node does not support update of metadata.",
+            )
+        if body.content_type == patch_mimetypes.JSON_PATCH:
+            metadata = apply_json_patch(entry.metadata(), (body.metadata or []))
+            specs = apply_json_patch((entry.specs or []), (body.specs or []))
+        elif body.content_type == patch_mimetypes.MERGE_PATCH:
+            metadata = apply_merge_patch(entry.metadata(), (body.metadata or {}))
+            # body.specs = [] clears specs, as per json merge patch specification
+            # but we treat body.specs = None as "no modifications"
+            current_specs = entry.specs or []
+            target_specs = current_specs if body.specs is None else body.specs
+            specs = apply_merge_patch(current_specs, target_specs)
+        else:
+            raise HTTPException(
+                status_code=HTTP_406_NOT_ACCEPTABLE,
+                detail=f"valid content types: {', '.join(patch_mimetypes)}",
+            )
-
-    structure_family, structure = (
-        entry.structure_family,
-        entry.structure(),
-    )
+        # Manually validate limits that bypass pydantic validation via patch
+        if len(specs) > schemas.MAX_ALLOWED_SPECS:
+            raise HTTPException(
+                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                detail=f"Update cannot result in more than {schemas.MAX_ALLOWED_SPECS} specs",
+            )
+        if len(specs) != len(set(specs)):
+            raise HTTPException(
+                status_code=HTTP_422_UNPROCESSABLE_ENTITY,
+                detail="Update cannot 
result in non-unique specs", + ) - metadata_modified, metadata = await validate_metadata( - metadata=metadata, - structure_family=structure_family, - structure=structure, - specs=[Spec(x) for x in specs], - validation_registry=validation_registry, - settings=settings, - ) + structure_family, structure = ( + entry.structure_family, + entry.structure(), + ) - await entry.replace_metadata(metadata=metadata, specs=specs) - - response_data = {"id": entry.key} - if metadata_modified: - response_data["metadata"] = metadata - return json_or_msgpack(request, response_data) - - -@router.put("/metadata/{path:path}", response_model=schemas.PutMetadataResponse) -async def put_metadata( - request: Request, - body: schemas.PutMetadataRequest, - validation_registry=Depends(get_validation_registry), - settings: Settings = Depends(get_settings), - entry=SecureEntry(scopes=["write:metadata"]), -): - if not hasattr(entry, "replace_metadata"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not support update of metadata.", + metadata_modified, metadata = await validate_metadata( + metadata=metadata, + structure_family=structure_family, + structure=structure, + specs=[Spec(x) for x in specs], + settings=settings, ) - metadata, structure_family, structure, specs = ( - body.metadata if body.metadata is not None else entry.metadata(), - entry.structure_family, - entry.structure(), - body.specs if body.specs is not None else entry.specs, - ) + await entry.replace_metadata(metadata=metadata, specs=specs) - metadata_modified, metadata = await validate_metadata( - metadata=metadata, - structure_family=structure_family, - structure=structure, - specs=specs, - validation_registry=validation_registry, - settings=settings, - ) + response_data = {"id": entry.key} + if metadata_modified: + response_data["metadata"] = metadata + return json_or_msgpack(request, response_data) - await entry.replace_metadata(metadata=metadata, specs=specs) - - response_data = {"id": entry.key} - if metadata_modified: - response_data["metadata"] = metadata - return json_or_msgpack(request, response_data) - - -@router.get("/revisions/{path:path}") -async def get_revisions( - request: Request, - path: str, - offset: Optional[int] = Query(0, alias="page[offset]", ge=0), - limit: Optional[int] = Query( - DEFAULT_PAGE_SIZE, alias="page[limit]", ge=0, le=MAX_PAGE_SIZE - ), - entry=SecureEntry(scopes=["read:metadata"]), -): - if not hasattr(entry, "revisions"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not support revisions.", - ) + @router.put("/metadata/{path:path}", response_model=schemas.PutMetadataResponse) + async def put_metadata( + request: Request, + body: schemas.PutMetadataRequest, + settings: Settings = Depends(get_settings), + entry=SecureEntry(scopes=["write:metadata"]), + ): + if not hasattr(entry, "replace_metadata"): + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node does not support update of metadata.", + ) - base_url = get_base_url(request) - resource = await construct_revisions_response( - entry, - base_url, - "/revisions", - path, - offset, - limit, - resolve_media_type(request), - ) - return json_or_msgpack(request, resource.model_dump()) - - -@router.delete("/revisions/{path:path}") -async def delete_revision( - request: Request, - number: int, - entry=SecureEntry(scopes=["write:metadata"]), -): - if not hasattr(entry, "revisions"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - 
detail="This node does not support a del request for revisions.", + metadata, structure_family, structure, specs = ( + body.metadata if body.metadata is not None else entry.metadata(), + entry.structure_family, + entry.structure(), + body.specs if body.specs is not None else entry.specs, ) - await entry.delete_revision(number) - return json_or_msgpack(request, None) - - -# For simplicity of implementation, we support a restricted subset of the full -# Range spec. This could be extended if the need arises. -# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Range -RANGE_HEADER_PATTERN = re.compile(r"^bytes=(\d+)-(\d+)$") - - -@router.get("/asset/bytes/{path:path}") -async def get_asset( - request: Request, - id: int, - relative_path: Optional[Path] = None, - entry=SecureEntry(scopes=["read:data"]), # TODO: Separate scope for assets? - settings: Settings = Depends(get_settings), -): - if not settings.expose_raw_assets: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, - detail=( - "This Tiled server is configured not to allow " - "downloading raw assets." - ), - ) - if not hasattr(entry, "asset_by_id"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not support downloading assets.", + metadata_modified, metadata = await validate_metadata( + metadata=metadata, + structure_family=structure_family, + structure=structure, + specs=specs, + settings=settings, ) - asset = await entry.asset_by_id(id) - if asset is None: - raise HTTPException( - status_code=HTTP_404_NOT_FOUND, - detail=f"This node exists but it does not have an Asset with id {id}", - ) - if asset.is_directory: - if relative_path is None: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=( - "This asset is a directory. Must specify relative path, " - f"from manifest provided by /asset/manifest/...?id={id}" - ), - ) - if relative_path.is_absolute(): - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail="relative_path query parameter must be a *relative* path", - ) - else: - if relative_path is not None: + await entry.replace_metadata(metadata=metadata, specs=specs) + + response_data = {"id": entry.key} + if metadata_modified: + response_data["metadata"] = metadata + return json_or_msgpack(request, response_data) + + @router.get("/revisions/{path:path}") + async def get_revisions( + request: Request, + path: str, + offset: Optional[int] = Query(0, alias="page[offset]", ge=0), + limit: Optional[int] = Query( + DEFAULT_PAGE_SIZE, alias="page[limit]", ge=0, le=MAX_PAGE_SIZE + ), + entry=SecureEntry(scopes=["read:metadata"]), + ): + if not hasattr(entry, "revisions"): raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail="This asset is not a directory. The relative_path query parameter must not be set.", + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node does not support revisions.", ) - if not asset.data_uri.startswith("file:"): - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail="Only download assets stored as file:// is currently supported.", + + base_url = get_base_url(request) + resource = await construct_revisions_response( + entry, + base_url, + "/revisions", + path, + offset, + limit, + resolve_media_type(request), ) - path = path_from_uri(asset.data_uri) - if relative_path is not None: - # Be doubly sure that this is under the Asset's data_uri directory - # and not sneakily escaping it. - if not os.path.commonpath([path, path / relative_path]) != path: - # This should not be possible. 
- raise RuntimeError( - f"Refusing to serve {path / relative_path} because it is outside " - "of the Asset's directory" + return json_or_msgpack(request, resource.model_dump()) + + @router.delete("/revisions/{path:path}") + async def delete_revision( + request: Request, + number: int, + entry=SecureEntry(scopes=["write:metadata"]), + ): + if not hasattr(entry, "revisions"): + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node does not support a del request for revisions.", ) - full_path = path / relative_path - else: - full_path = path - stat_result = await anyio.to_thread.run_sync(os.stat, full_path) - filename = full_path.name - if "range" in request.headers: - range_header = request.headers["range"] - match = RANGE_HEADER_PATTERN.match(range_header) - if match is None: + + await entry.delete_revision(number) + return json_or_msgpack(request, None) + + # For simplicity of implementation, we support a restricted subset of the full + # Range spec. This could be extended if the need arises. + # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Range + RANGE_HEADER_PATTERN = re.compile(r"^bytes=(\d+)-(\d+)$") + + @router.get("/asset/bytes/{path:path}") + async def get_asset( + request: Request, + id: int, + relative_path: Optional[Path] = None, + entry=SecureEntry(scopes=["read:data"]), # TODO: Separate scope for assets? + settings: Settings = Depends(get_settings), + ): + if not settings.expose_raw_assets: raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, + status_code=HTTP_403_FORBIDDEN, detail=( - "Only a Range headers of the form 'bytes=start-end' are supported. " - f"Could not parse Range header: {range_header}", + "This Tiled server is configured not to allow " + "downloading raw assets." ), ) - range = start, _ = (int(match.group(1)), int(match.group(2))) - if start > stat_result.st_size: + if not hasattr(entry, "asset_by_id"): raise HTTPException( - status_code=HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE, - headers={"content-range": f"bytes */{stat_result.st_size}"}, + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node does not support downloading assets.", ) - status_code = HTTP_206_PARTIAL_CONTENT - else: - range = None - status_code = HTTP_200_OK - return FileResponseWithRange( - full_path, - stat_result=stat_result, - status_code=status_code, - headers={"Content-Disposition": f'attachment; filename="{filename}"'}, - range=range, - ) - - -@router.get("/asset/manifest/{path:path}") -async def get_asset_manifest( - request: Request, - id: int, - entry=SecureEntry(scopes=["read:data"]), # TODO: Separate scope for assets? - settings: Settings = Depends(get_settings), -): - if not settings.expose_raw_assets: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, - detail=( - "This Tiled server is configured not to allow " - "downloading raw assets." - ), - ) - if not hasattr(entry, "asset_by_id"): - raise HTTPException( - status_code=HTTP_405_METHOD_NOT_ALLOWED, - detail="This node does not support downloading assets.", - ) - asset = await entry.asset_by_id(id) - if asset is None: - raise HTTPException( - status_code=HTTP_404_NOT_FOUND, - detail=f"This node exists but it does not have an Asset with id {id}", - ) - if not asset.is_directory: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail="This asset is not a directory. 
There is no manifest.",
-        )
-    if not asset.data_uri.startswith("file:"):
-        raise HTTPException(
-            status_code=HTTP_400_BAD_REQUEST,
-            detail="Only download assets stored as file:// is currently supported.",
-        )
-    path = path_from_uri(asset.data_uri)
-    manifest = []
-    # Walk the directory and any subdirectories. Aggregate a list of all the
-    # files, given as paths relative to the directory root.
-    for root, _directories, files in os.walk(path):
-        manifest.extend(Path(root, file) for file in files)
-    return json_or_msgpack(request, {"manifest": manifest})
-
-
-async def validate_metadata(
-    metadata: dict,
-    structure_family: StructureFamily,
-    structure,
-    specs: List[Spec],
-    validation_registry=Depends(get_validation_registry),
-    settings: Settings = Depends(get_settings),
-):
-    metadata_modified = False
-
-    # Specs should be ordered from most specific/constrained to least.
-    # Validate them in reverse order, with the least constrained spec first,
-    # because it may do normalization that helps pass the more constrained one.
-    # Known Issue:
-    # When there is more than one spec, it's possible for the validator for
-    # Spec 2 to make a modification that breaks the validation for Spec 1.
-    # For now we leave it to the server maintainer to ensure that validators
-    # won't step on each other in this way, but this may need revisiting.
-    for spec in reversed(specs):
-        if spec.name not in validation_registry:
-            if settings.reject_undeclared_specs:
+        asset = await entry.asset_by_id(id)
+        if asset is None:
+            raise HTTPException(
+                status_code=HTTP_404_NOT_FOUND,
+                detail=f"This node exists but it does not have an Asset with id {id}",
+            )
+        if asset.is_directory:
+            if relative_path is None:
+                raise HTTPException(
+                    status_code=HTTP_400_BAD_REQUEST,
+                    detail=(
+                        "This asset is a directory. Must specify relative path, "
+                        f"from manifest provided by /asset/manifest/...?id={id}"
+                    ),
+                )
+            if relative_path.is_absolute():
                raise HTTPException(
                    status_code=HTTP_400_BAD_REQUEST,
-                detail=f"Unrecognized spec: {spec.name}",
+                    detail="relative_path query parameter must be a *relative* path",
                )
        else:
-            validator = validation_registry(spec.name)
-            try:
-                result = validator(metadata, structure_family, structure, spec)
-            except ValidationError as e:
+            if relative_path is not None:
+                raise HTTPException(
+                    status_code=HTTP_400_BAD_REQUEST,
+                    detail="This asset is not a directory. The relative_path query parameter must not be set.",
+                )
+        if not asset.data_uri.startswith("file:"):
+            raise HTTPException(
+                status_code=HTTP_400_BAD_REQUEST,
+                detail="Only download assets stored as file:// is currently supported.",
+            )
+        path = path_from_uri(asset.data_uri)
+        if relative_path is not None:
+            # Be doubly sure that this is under the Asset's data_uri directory
+            # and not sneakily escaping it.
+            if os.path.commonpath([path, path / relative_path]) != str(path):
+                # This should not be possible.
+                raise RuntimeError(
+                    f"Refusing to serve {path / relative_path} because it is outside "
+                    "of the Asset's directory"
+                )
+            full_path = path / relative_path
+        else:
+            full_path = path
+        stat_result = await anyio.to_thread.run_sync(os.stat, full_path)
+        filename = full_path.name
+        if "range" in request.headers:
+            range_header = request.headers["range"]
+            match = RANGE_HEADER_PATTERN.match(range_header)
+            if match is None:
                raise HTTPException(
                    status_code=HTTP_400_BAD_REQUEST,
-                detail=f"failed validation for spec {spec.name}:\n{e}",
+                    detail=(
+                        "Only Range headers of the form 'bytes=start-end' are supported. 
" + f"Could not parse Range header: {range_header}", + ), ) - if result is not None: - metadata_modified = True - metadata = result + range = start, _ = (int(match.group(1)), int(match.group(2))) + if start > stat_result.st_size: + raise HTTPException( + status_code=HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE, + headers={"content-range": f"bytes */{stat_result.st_size}"}, + ) + status_code = HTTP_206_PARTIAL_CONTENT + else: + range = None + status_code = HTTP_200_OK + return FileResponseWithRange( + full_path, + stat_result=stat_result, + status_code=status_code, + headers={"Content-Disposition": f'attachment; filename="{filename}"'}, + range=range, + ) + + @router.get("/asset/manifest/{path:path}") + async def get_asset_manifest( + request: Request, + id: int, + entry=SecureEntry(scopes=["read:data"]), # TODO: Separate scope for assets? + settings: Settings = Depends(get_settings), + ): + if not settings.expose_raw_assets: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail=( + "This Tiled server is configured not to allow " + "downloading raw assets." + ), + ) + if not hasattr(entry, "asset_by_id"): + raise HTTPException( + status_code=HTTP_405_METHOD_NOT_ALLOWED, + detail="This node does not support downloading assets.", + ) - return metadata_modified, metadata + asset = await entry.asset_by_id(id) + if asset is None: + raise HTTPException( + status_code=HTTP_404_NOT_FOUND, + detail=f"This node exists but it does not have an Asset with id {id}", + ) + if not asset.is_directory: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="This asset is not a directory. There is no manifest.", + ) + if not asset.data_uri.startswith("file:"): + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="Only download assets stored as file:// is currently supported.", + ) + path = path_from_uri(asset.data_uri) + manifest = [] + # Walk the directory and any subdirectories. Aggregate a list of all the + # files, given as paths relative to the directory root. + for root, _directories, files in os.walk(path): + manifest.extend(Path(root, file) for file in files) + return json_or_msgpack(request, {"manifest": manifest}) + + async def validate_metadata( + metadata: dict, + structure_family: StructureFamily, + structure, + specs: List[Spec], + settings: Settings = Depends(get_settings), + ): + metadata_modified = False + + # Specs should be ordered from most specific/constrained to least. + # Validate them in reverse order, with the least constrained spec first, + # because it may do normalization that helps pass the more constrained one. + # Known Issue: + # When there is more than one spec, it's possible for the validator for + # Spec 2 to make a modification that breaks the validation for Spec 1. + # For now we leave it to the server maintainer to ensure that validators + # won't step on each other in this way, but this may need revisiting. 
+ for spec in reversed(specs): + if spec.name not in validation_registry: + if settings.reject_undeclared_specs: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=f"Unrecognized spec: {spec.name}", + ) + else: + validator = validation_registry(spec.name) + try: + result = validator(metadata, structure_family, structure, spec) + except ValidationError as e: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=f"failed validation for spec {spec.name}:\n{e}", + ) + if result is not None: + metadata_modified = True + metadata = result + + return metadata_modified, metadata + + return router From 8848ed1a807bc64e4a1439175b352765e25f152a Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Tue, 11 Mar 2025 12:36:25 +0000 Subject: [PATCH 04/18] FIx minor issues from refactors --- tiled/server/app.py | 4 ++++ tiled/server/router.py | 5 ++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tiled/server/app.py b/tiled/server/app.py index c2565f059..905b55f12 100644 --- a/tiled/server/app.py +++ b/tiled/server/app.py @@ -43,6 +43,8 @@ CompressionRegistry, SerializationRegistry, default_compression_registry, + default_serialization_registry, + default_deserialization_registry ) from ..utils import SHARE_TILED_PATH, Conflicts, SpecialUsers, UnsupportedQueryType from ..validation_registration import ValidationRegistry, default_validation_registry @@ -138,6 +140,8 @@ def build_app( } server_settings = server_settings or {} query_registry = query_registry or default_query_registry + serialization_registry = serialization_registry or default_serialization_registry + deserialization_registry = deserialization_registry or default_deserialization_registry compression_registry = compression_registry or default_compression_registry validation_registry = validation_registry or default_validation_registry tasks = tasks or {} diff --git a/tiled/server/router.py b/tiled/server/router.py index 0275a1125..e2d2e4db1 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -235,7 +235,7 @@ async def about( ) @router.get( - "/api/v1/search/{path:path}", + "/search/{path:path}", response_model=schemas.Response[ List[schemas.Resource[schemas.NodeAttributes, dict, dict]], schemas.PaginationLinks, @@ -257,7 +257,6 @@ async def search( omit_links: bool = Query(False), include_data_sources: bool = Query(False), entry: Any = SecureEntry(scopes=["read:metadata"]), - principal: str = Depends(get_current_principal), **filters, ): request.state.endpoint = "search" @@ -315,7 +314,7 @@ async def search( ) @router.get( - "/api/v1/distinct/{path:path}", + "/distinct/{path:path}", response_model=schemas.GetDistinctResponse, ) @patch_route_signature From ea78168e717014c85ad9248201eb924018b3c536 Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Tue, 11 Mar 2025 12:39:26 +0000 Subject: [PATCH 05/18] Nit: remove Security scopes when not used --- tiled/server/authentication.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tiled/server/authentication.py b/tiled/server/authentication.py index f987a055c..28fc96dab 100644 --- a/tiled/server/authentication.py +++ b/tiled/server/authentication.py @@ -980,7 +980,7 @@ async def revoke_session( async def revoke_session_by_id( session_id: str, # from path parameter request: Request, - principal: schemas.Principal = Security(get_current_principal, scopes=[]), + principal: schemas.Principal = Depends(get_current_principal), db=Depends(get_database_session), ): "Mark a Session as 
revoked so it cannot be refreshed again." @@ -1154,7 +1154,7 @@ async def revoke_apikey( ) async def whoami( request: Request, - principal=Security(get_current_principal, scopes=[]), + principal=Depends(get_current_principal), db=Depends(get_database_session), ): # TODO Permit filtering the fields of the response. From 0e0188a0b670026f6e24ab8f5144ecae5fab437f Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Tue, 11 Mar 2025 14:15:36 +0000 Subject: [PATCH 06/18] Fix CI --- docs/source/reference/service.md | 1 - tiled/server/app.py | 6 ++++-- tiled/server/authentication.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/source/reference/service.md b/docs/source/reference/service.md index 69f87acb8..38f89e37f 100644 --- a/docs/source/reference/service.md +++ b/docs/source/reference/service.md @@ -79,7 +79,6 @@ When registering new types, make reference to the .. autosummary:: :toctree: generated - tiled.media_type_registration.serialization_registry tiled.media_type_registration.SerializationRegistry tiled.media_type_registration.SerializationRegistry.register tiled.media_type_registration.SerializationRegistry.media_types diff --git a/tiled/server/app.py b/tiled/server/app.py index 905b55f12..4b740258a 100644 --- a/tiled/server/app.py +++ b/tiled/server/app.py @@ -43,8 +43,8 @@ CompressionRegistry, SerializationRegistry, default_compression_registry, + default_deserialization_registry, default_serialization_registry, - default_deserialization_registry ) from ..utils import SHARE_TILED_PATH, Conflicts, SpecialUsers, UnsupportedQueryType from ..validation_registration import ValidationRegistry, default_validation_registry @@ -141,7 +141,9 @@ def build_app( server_settings = server_settings or {} query_registry = query_registry or default_query_registry serialization_registry = serialization_registry or default_serialization_registry - deserialization_registry = deserialization_registry or default_deserialization_registry + deserialization_registry = ( + deserialization_registry or default_deserialization_registry + ) compression_registry = compression_registry or default_compression_registry validation_registry = validation_registry or default_validation_registry tasks = tasks or {} diff --git a/tiled/server/authentication.py b/tiled/server/authentication.py index 28fc96dab..ef9c8a73a 100644 --- a/tiled/server/authentication.py +++ b/tiled/server/authentication.py @@ -173,7 +173,7 @@ async def get_api_key( api_key_cookie: Optional[str] = Depends( APIKeyCookie(name=API_KEY_COOKIE_NAME, auto_error=False) ), -) -> str | None: +) -> Optional[str]: for api_key in [api_key_query, api_key_header, api_key_cookie]: if api_key is not None: return api_key From d56b886f945df5019d44fde1638dd6a189e51369 Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Wed, 12 Mar 2025 10:27:44 +0000 Subject: [PATCH 07/18] Refactor Scope checking --- tiled/server/authentication.py | 138 +++++++++++++++++++++------------ tiled/server/metrics.py | 4 +- 2 files changed, 93 insertions(+), 49 deletions(-) diff --git a/tiled/server/authentication.py b/tiled/server/authentication.py index ef9c8a73a..cbb5055d6 100644 --- a/tiled/server/authentication.py +++ b/tiled/server/authentication.py @@ -4,7 +4,7 @@ import warnings from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Optional +from typing import Any, Optional, Sequence import sqlalchemy.exc from fastapi import ( @@ -227,6 +227,83 @@ async def move_api_key(request: 
Request, api_key: Optional[str] = Depends(get_ap
 
+async def get_scopes_from_api_key(
+    api_key: str, settings: Settings, authenticators, db
+) -> Sequence[str]:
+    if not authenticators:
+        # Tiled is in a "single user" mode with only one API key.
+        return (
+            {
+                "read:metadata",
+                "read:data",
+                "write:metadata",
+                "write:data",
+                "create",
+                "register",
+                "metrics",
+            }
+            if secrets.compare_digest(api_key, settings.single_user_api_key)
+            else set()
+        )
+    # Tiled is in a multi-user configuration with authentication providers.
+    # We store the hashed value of the API key secret.
+    # By comparing hashes we protect against timing attacks.
+    # By storing only the hash of the (high-entropy) secret
+    # we reduce the value of what an attacker can extract from a
+    # stolen database backup.
+    try:
+        secret = bytes.fromhex(api_key)
+    except Exception:
+        return set()
+    api_key_orm = await lookup_valid_api_key(db, secret)
+    if api_key_orm is None:
+        return set()
+    else:
+        principal = api_key_orm.principal
+        principal_scopes = set().union(*[role.scopes for role in principal.roles])
+        # This intersection addresses the case where the Principal has
+        # lost a scope that they had when this key was created.
+        scopes = set(api_key_orm.scopes).intersection(principal_scopes | {"inherit"})
+        if "inherit" in scopes:
+            # The scope "inherit" is a metascope that confers all the
+            # scopes for the Principal associated with this API key,
+            # resolved at access time.
+            scopes.update(principal_scopes)
+        return scopes
+
+
+async def get_current_scopes(
+    decoded_access_token: Optional[dict[str, Any]] = Depends(get_decoded_access_token),
+    api_key: Optional[str] = Depends(get_api_key),
+    settings: Settings = Depends(get_settings),
+    authenticators=Depends(get_authenticators),
+    db=Depends(get_database_session),
+) -> set[str]:
+    if api_key is not None:
+        return await get_scopes_from_api_key(api_key, settings, authenticators, db)
+    elif decoded_access_token is not None:
+        return decoded_access_token["scp"]
+    else:
+        return {"read:metadata", "read:data"} if settings.allow_anonymous_access else set()
+
+
+async def check_scopes(
+    request: Request,
+    security_scopes: SecurityScopes,
+    scopes: set[str] = Depends(get_current_scopes),
+) -> None:
+    if not set(security_scopes.scopes).issubset(scopes):
+        raise HTTPException(
+            status_code=HTTP_401_UNAUTHORIZED,
+            detail=(
+                "Not enough permissions. "
+                f"Requires scopes {security_scopes.scopes}. "
+                f"Request had scopes {list(scopes)}"
+            ),
+            headers=headers_for_401(request, security_scopes),
+        )
+
+
 async def get_current_principal(
     request: Request,
     security_scopes: SecurityScopes,
@@ -293,15 +370,6 @@ async def get_current_principal(
         # Tiled is in a "single user" mode with only one API key.
         if secrets.compare_digest(api_key, settings.single_user_api_key):
             principal = SpecialUsers.admin
-            scopes = {
-                "read:metadata",
-                "read:data",
-                "write:metadata",
-                "write:data",
-                "create",
-                "register",
-                "metrics",
-            }
         else:
             raise HTTPException(
                 status_code=HTTP_401_UNAUTHORIZED,
@@ -317,40 +385,9 @@ async def get_current_principal(
                for identity in decoded_access_token["ids"]
            ],
        )
-        scopes = decoded_access_token["scp"]
    else:
        # No form of authentication is present.
        principal = SpecialUsers.public
-        # Is anonymous public access permitted?
-        if settings.allow_anonymous_access:
-            # Any user who can see the server can make unauthenticated requests.
-            # This is a sentinel that has special meaning to the authorization
-            # code (the access control policies). 
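
The intersection-then-update logic in get_scopes_from_api_key is easiest to follow with concrete values; a standalone sketch with made-up role and key scopes:

    principal_scopes = {"read:metadata", "read:data", "write:data"}  # union of role scopes
    api_key_scopes = ["inherit", "read:metadata", "admin:apikeys"]   # as stored on the key

    # Drop scopes the Principal no longer holds, but keep the "inherit" marker.
    scopes = set(api_key_scopes).intersection(principal_scopes | {"inherit"})
    assert scopes == {"inherit", "read:metadata"}  # "admin:apikeys" was lost

    if "inherit" in scopes:
        # The key tracks the Principal, picking up its current role scopes.
        scopes.update(principal_scopes)
    assert scopes == {"inherit", "read:metadata", "read:data", "write:data"}
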
- scopes = {"read:metadata", "read:data"} - else: - # In this mode, there may still be entries that are visible to all, - # but users have to authenticate as *someone* to see anything. - # They can still access the / and /docs routes. - scopes = {} - # Scope enforcement happens here. - # https://fastapi.tiangolo.com/advanced/security/oauth2-scopes/ - if not set(security_scopes.scopes).issubset(scopes): - # Include a link to the root page which provides a list of - # authenticators. The use case here is: - # 1. User is emailed a link like https://example.com/subpath//metadata/a/b/c - # 2. Tiled Client tries to connect to that and gets 401. - # 3. Client can use this header to find its way to - # https://examples.com/subpath/ and obtain a list of - # authentication providers and endpoints. - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail=( - "Not enough permissions. " - f"Requires scopes {security_scopes.scopes}. " - f"Request had scopes {list(scopes)}" - ), - headers=headers_for_401(request, security_scopes), - ) # This is used to pass the currently-authenticated principal into the logger. request.state.principal = principal return principal @@ -787,7 +824,8 @@ async def principal_list( limit: Optional[int] = Query( DEFAULT_PAGE_SIZE, alias="page[limit]", ge=0, le=MAX_PAGE_SIZE ), - principal=Security(get_current_principal, scopes=["read:principals"]), + principal=Depends(get_current_principal), + _=Security(check_scopes, scopes=["read:principals"]), db=Depends(get_database_session), ): "List Principals (users and services)." @@ -825,7 +863,8 @@ async def principal_list( ) async def create_service_principal( request: Request, - principal=Security(get_current_principal, scopes=["write:principals"]), + principal=Depends(get_current_principal), + _=Security(check_scopes, scopes=["write:principals"]), db=Depends(get_database_session), role: str = Query(...), ): @@ -860,7 +899,7 @@ async def create_service_principal( async def principal( request: Request, uuid: uuid_module.UUID, - _=Security(lambda: None, scopes=["read:principals"]), + _=Security(check_scopes, scopes=["read:principals"]), db=Depends(get_database_session), ): "Get information about one Principal (user or service)." @@ -896,7 +935,7 @@ async def revoke_apikey_for_principal( request: Request, uuid: uuid_module.UUID, first_eight: str, - _=Security(lambda: None, scopes=["admin:apikeys"]), + _=Security(check_scopes, scopes=["admin:apikeys"]), db=Depends(get_database_session), ): "Allow Tiled Admins to delete any user's apikeys e.g." @@ -925,7 +964,8 @@ async def apikey_for_principal( request: Request, uuid: uuid_module.UUID, apikey_params: schemas.APIKeyRequestParams, - principal=Security(get_current_principal, scopes=["admin:apikeys"]), + principal=Depends(get_current_principal), + _=Security(check_scopes, scopes=["admin:apikeys"]), db=Depends(get_database_session), ): "Generate an API key for a Principal." 
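
The same Security(check_scopes, scopes=[...]) pattern generalizes to any endpoint. A self-contained FastAPI sketch of the idea (toy app; current_scopes is a stand-in for the real get_current_scopes):

    from fastapi import Depends, FastAPI, HTTPException, Security
    from fastapi.security import SecurityScopes

    app = FastAPI()

    async def current_scopes() -> set:
        return {"read:metadata"}  # stand-in for the real scope resolution

    async def check(security_scopes: SecurityScopes, scopes: set = Depends(current_scopes)):
        # FastAPI injects the scopes declared at the Security(...) call site.
        if not set(security_scopes.scopes).issubset(scopes):
            raise HTTPException(status_code=401, detail="Not enough permissions")

    @app.get("/example")
    async def example(_=Security(check, scopes=["read:metadata"])):
        return {"ok": True}
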
@@ -1071,7 +1111,8 @@ async def slide_session(refresh_token, settings, db): async def new_apikey( request: Request, apikey_params: schemas.APIKeyRequestParams, - principal=Security(get_current_principal, scopes=["apikeys"]), + principal=Depends(get_current_principal), + _=Security(check_scopes, scopes=["apikeys"]), db=Depends(get_database_session), ): """ @@ -1124,7 +1165,8 @@ async def current_apikey_info( async def revoke_apikey( request: Request, first_eight: str, - principal=Security(get_current_principal, scopes=["apikeys"]), + principal=Depends(get_current_principal), + _=Security(check_scopes, scopes=["apikeys"]), db=Depends(get_database_session), ): """ diff --git a/tiled/server/metrics.py b/tiled/server/metrics.py index 043065b65..77fb4440f 100644 --- a/tiled/server/metrics.py +++ b/tiled/server/metrics.py @@ -10,6 +10,8 @@ from fastapi import APIRouter, Request, Response, Security from prometheus_client import CONTENT_TYPE_LATEST, Histogram, generate_latest +from tiled.server.authentication import check_scopes + router = APIRouter() REQUEST_DURATION = Histogram( @@ -155,7 +157,7 @@ def prometheus_registry(): @router.get("/metrics") -async def metrics(request: Request, _=Security(lambda: None, scopes=["metrics"])): +async def metrics(request: Request, _=Security(check_scopes, scopes=["metrics"])): """ Prometheus metrics """ From 16f49fec5dbae3020e70223dac210d6e28492c8a Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Wed, 12 Mar 2025 12:10:19 +0000 Subject: [PATCH 08/18] refactor SecureEntry into standard Security --- tiled/server/dependencies.py | 18 ++++-- tiled/server/router.py | 122 +++++++++++++++++++---------------- 2 files changed, 80 insertions(+), 60 deletions(-) diff --git a/tiled/server/dependencies.py b/tiled/server/dependencies.py index 6a4ffa284..09abec0d5 100644 --- a/tiled/server/dependencies.py +++ b/tiled/server/dependencies.py @@ -3,15 +3,19 @@ import pydantic_settings from fastapi import Depends, HTTPException, Query, Request, Security +from fastapi.security import SecurityScopes from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND +from tiled.adapters.mapping import MapAdapter +from tiled.structures.core import StructureFamily + from ..media_type_registration import ( default_deserialization_registry, default_serialization_registry, ) from ..query_registration import default_query_registry from ..validation_registration import default_validation_registry -from .authentication import get_current_principal, get_session_state +from .authentication import check_scopes, get_current_principal, get_session_state from .core import NoEntry from .utils import filter_for_access, record_timing @@ -53,14 +57,16 @@ def get_root_tree(): ) -def SecureEntry(scopes, structure_families=None): +def get_entry(structure_families: Optional[set[StructureFamily]] = None): async def inner( path: str, request: Request, + security_scopes: SecurityScopes, principal: str = Depends(get_current_principal), root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), session_state: dict = Depends(get_session_state), - ): + _ = Security(check_scopes) + ) -> MapAdapter: """ Obtain a node in the tree from its path. 
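
With /metrics now guarded by check_scopes, a Prometheus scrape must present credentials granting the "metrics" scope. A client-side sketch (the host, key, and route prefix are hypothetical and deployment-specific):

    import httpx

    response = httpx.get(
        "http://localhost:8000/api/v1/metrics",
        headers={"Authorization": "Apikey secret"},  # key must carry the "metrics" scope
    )
    response.raise_for_status()
    print(response.text.splitlines()[0])  # e.g. a "# HELP ..." line from prometheus_client
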
@@ -131,7 +137,7 @@ async def inner( allowed_scopes = await access_policy.allowed_scopes( entry_with_access_policy, principal, path_parts_relative ) - if not set(scopes).issubset(allowed_scopes): + if not set(security_scopes.scopes).issubset(allowed_scopes): if "read:metadata" not in allowed_scopes: # If you can't read metadata, it does not exist for you. raise NoEntry(path_parts) @@ -142,7 +148,7 @@ async def inner( status_code=HTTP_403_FORBIDDEN, detail=( "Not enough permissions to perform this action on this node. " - f"Requires scopes {scopes}. " + f"Requires scopes {security_scopes.scopes}. " f"Principal had scopes {list(allowed_scopes)} on this node." ), ) @@ -164,7 +170,7 @@ async def inner( ), ) - return Security(inner, scopes=scopes) + return inner def block( diff --git a/tiled/server/router.py b/tiled/server/router.py index e2d2e4db1..09c4fe344 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -6,11 +6,11 @@ from datetime import datetime, timedelta, timezone from functools import partial from pathlib import Path -from typing import Any, Callable, List, Optional, TypeVar +from typing import Callable, List, Optional, TypeVar import anyio import packaging -from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request +from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request, Security from jmespath.exceptions import JMESPathError from json_merge_patch import merge as apply_merge_patch from jsonpatch import apply_patch as apply_json_patch @@ -26,6 +26,7 @@ HTTP_422_UNPROCESSABLE_ENTITY, ) +from tiled.adapters.mapping import MapAdapter from tiled.media_type_registration import SerializationRegistry from tiled.query_registration import QueryRegistry from tiled.schemas import About @@ -53,9 +54,9 @@ resolve_media_type, ) from .dependencies import ( - SecureEntry, block, expected_shape, + get_entry, offset_param, shape_param, slice_, @@ -256,7 +257,7 @@ async def search( max_depth: Optional[int] = Query(None, ge=0, le=DEPTH_LIMIT), omit_links: bool = Query(False), include_data_sources: bool = Query(False), - entry: Any = SecureEntry(scopes=["read:metadata"]), + entry: MapAdapter = Security(get_entry(), scopes=["read:metadata"]), **filters, ): request.state.endpoint = "search" @@ -324,7 +325,7 @@ async def distinct( specs: bool = False, metadata: Optional[List[str]] = Query(default=[]), counts: bool = False, - entry: Any = SecureEntry(scopes=["read:metadata"]), + entry: MapAdapter = Security(get_entry(), scopes=["read:metadata"]), **filters, ): if hasattr(entry, "get_distinct"): @@ -357,8 +358,8 @@ async def metadata( max_depth: Optional[int] = Query(None, ge=0, le=DEPTH_LIMIT), omit_links: bool = Query(False), include_data_sources: bool = Query(False), - entry: Any = SecureEntry(scopes=["read:metadata"]), root_path: bool = Query(False), + entry: MapAdapter = Security(get_entry(), scopes=["read:metadata"]), ): """Fetch the metadata and structure information for one entry""" @@ -395,9 +396,9 @@ async def metadata( ) async def array_block( request: Request, - entry=SecureEntry( + entry: MapAdapter = Security( + get_entry({StructureFamily.array, StructureFamily.sparse}), scopes=["read:data"], - structure_families={StructureFamily.array, StructureFamily.sparse}, ), block=Depends(block), slice=Depends(slice_), @@ -472,9 +473,9 @@ async def array_block( ) async def array_full( request: Request, - entry=SecureEntry( + entry: MapAdapter = Security( + get_entry({StructureFamily.array, StructureFamily.sparse}), scopes=["read:data"], - 
structure_families={StructureFamily.array, StructureFamily.sparse}, ), slice=Depends(slice_), expected_shape=Depends(expected_shape), @@ -536,8 +537,8 @@ async def array_full( async def get_table_partition( request: Request, partition: int, - entry=SecureEntry( - scopes=["read:data"], structure_families={StructureFamily.table} + entry: MapAdapter = Security( + get_entry({StructureFamily.table}), scopes=["read:data"] ), column: Optional[List[str]] = Query(None, min_length=1), field: Optional[List[str]] = Query(None, min_length=1, deprecated=True), @@ -586,8 +587,8 @@ async def get_table_partition( async def post_table_partition( request: Request, partition: int, - entry=SecureEntry( - scopes=["read:data"], structure_families={StructureFamily.table} + entry: MapAdapter = Security( + get_entry({StructureFamily.table}), scopes=["read:data"] ), column: Optional[List[str]] = Body(None, min_length=1), format: Optional[str] = None, @@ -667,8 +668,8 @@ async def table_partition( ) async def get_table_full( request: Request, - entry=SecureEntry( - scopes=["read:data"], structure_families={StructureFamily.table} + entry: MapAdapter = Security( + get_entry({StructureFamily.table}), scopes=["read:data"] ), column: Optional[List[str]] = Query(None, min_length=1), format: Optional[str] = None, @@ -695,8 +696,8 @@ async def get_table_full( ) async def post_table_full( request: Request, - entry=SecureEntry( - scopes=["read:data"], structure_families={StructureFamily.table} + entry: MapAdapter = Security( + get_entry({StructureFamily.table}), scopes=["read:data"] ), column: Optional[List[str]] = Body(None, min_length=1), format: Optional[str] = None, @@ -769,8 +770,8 @@ async def table_full( ) async def get_container_full( request: Request, - entry=SecureEntry( - scopes=["read:data"], structure_families={StructureFamily.container} + entry: MapAdapter = Security( + get_entry({StructureFamily.container}), scopes=["read:data"] ), principal: str = Depends(get_current_principal), field: Optional[List[str]] = Query(None, min_length=1), @@ -797,8 +798,8 @@ async def get_container_full( ) async def post_container_full( request: Request, - entry=SecureEntry( - scopes=["read:data"], structure_families={StructureFamily.container} + entry: MapAdapter = Security( + get_entry({StructureFamily.container}), scopes=["read:data"] ), principal: str = Depends(get_current_principal), field: Optional[List[str]] = Body(None, min_length=1), @@ -870,9 +871,9 @@ async def container_full( ) async def node_full( request: Request, - entry=SecureEntry( + entry: MapAdapter = Security( + get_entry({StructureFamily.table, StructureFamily.container}), scopes=["read:data"], - structure_families={StructureFamily.table, StructureFamily.container}, ), principal: str = Depends(get_current_principal), field: Optional[List[str]] = Query(None, min_length=1), @@ -936,8 +937,8 @@ async def node_full( ) async def get_awkward_buffers( request: Request, - entry=SecureEntry( - scopes=["read:data"], structure_families={StructureFamily.awkward} + entry: MapAdapter = Security( + get_entry({StructureFamily.awkward}), scopes=["read:data"] ), form_key: Optional[List[str]] = Query(None, min_length=1), format: Optional[str] = None, @@ -972,12 +973,12 @@ async def get_awkward_buffers( async def post_awkward_buffers( request: Request, body: List[str], - entry=SecureEntry( - scopes=["read:data"], structure_families={StructureFamily.awkward} - ), format: Optional[str] = None, filename: Optional[str] = None, settings: Settings = Depends(get_settings), + entry: 
MapAdapter = Security( + get_entry({StructureFamily.awkward}), scopes=["read:data"] + ), ): """ Fetch a slice of AwkwardArray data. @@ -1049,8 +1050,8 @@ async def _awkward_buffers( ) async def awkward_full( request: Request, - entry=SecureEntry( - scopes=["read:data"], structure_families={StructureFamily.awkward} + entry: MapAdapter = Security( + get_entry({StructureFamily.awkward}), scopes=["read:data"] ), # slice=Depends(slice_), format: Optional[str] = None, @@ -1100,7 +1101,7 @@ async def post_metadata( path: str, body: schemas.PostMetadataRequest, settings: Settings = Depends(get_settings), - entry=SecureEntry(scopes=["write:metadata", "create"]), + entry: MapAdapter = Security(get_entry(), scopes=["write:metadata", "create"]), ): for data_source in body.data_sources: if data_source.assets: @@ -1128,7 +1129,9 @@ async def post_register( path: str, body: schemas.PostMetadataRequest, settings: Settings = Depends(get_settings), - entry=SecureEntry(scopes=["write:metadata", "create", "register"]), + entry: MapAdapter = Security( + get_entry(), scopes=["write:metadata", "create", "register"] + ), ): return await _create_node( request=request, @@ -1194,7 +1197,9 @@ async def put_data_source( data_source: int, body: schemas.PutDataSourceRequest, settings: Settings = Depends(get_settings), - entry=SecureEntry(scopes=["write:metadata", "register"]), + entry: MapAdapter = Security( + get_entry(), scopes=["write:metadata", "register"] + ), ): await entry.put_data_source( data_source=body.data_source, @@ -1203,7 +1208,9 @@ async def put_data_source( @router.delete("/metadata/{path:path}") async def delete( request: Request, - entry=SecureEntry(scopes=["write:data", "write:metadata"]), + entry: MapAdapter = Security( + get_entry(), scopes=["write:data", "write:metadata"] + ), ): if hasattr(entry, "delete"): await entry.delete() @@ -1217,7 +1224,9 @@ async def delete( @router.delete("/nodes/{path:path}") async def bulk_delete( request: Request, - entry=SecureEntry(scopes=["write:data", "write:metadata"]), + entry: MapAdapter = Security( + get_entry(), scopes=["write:data", "write:metadata"] + ), ): if hasattr(entry, "delete_tree"): await entry.delete_tree() @@ -1231,9 +1240,9 @@ async def bulk_delete( @router.put("/array/full/{path:path}") async def put_array_full( request: Request, - entry=SecureEntry( + entry: MapAdapter = Security( + get_entry({StructureFamily.array, StructureFamily.sparse}), scopes=["write:data"], - structure_families={StructureFamily.array, StructureFamily.sparse}, ), ): body = await request.body() @@ -1259,9 +1268,9 @@ async def put_array_full( @router.put("/array/block/{path:path}") async def put_array_block( request: Request, - entry=SecureEntry( + entry: MapAdapter = Security( + get_entry({StructureFamily.array, StructureFamily.sparse}), scopes=["write:data"], - structure_families={StructureFamily.array, StructureFamily.sparse}, ), block=Depends(block), ): @@ -1295,9 +1304,9 @@ async def patch_array_full( offset=Depends(offset_param), shape=Depends(shape_param), extend: bool = False, - entry=SecureEntry( + entry: MapAdapter = Security( + get_entry({StructureFamily.array}), scopes=["write:data"], - structure_families={StructureFamily.array}, ), ): if not hasattr(entry, "patch"): @@ -1318,8 +1327,9 @@ async def patch_array_full( @router.put("/node/full/{path:path}", deprecated=True) async def put_node_full( request: Request, - entry=SecureEntry( - scopes=["write:data"], structure_families={StructureFamily.table} + entry: MapAdapter = Security( + 
get_entry({StructureFamily.table}), + scopes=["write:data"], ), ): if not hasattr(entry, "write"): @@ -1340,7 +1350,7 @@ async def put_node_full( async def put_table_partition( partition: int, request: Request, - entry=SecureEntry(scopes=["write:data"]), + entry: MapAdapter = Security(get_entry(), scopes=["write:data"]), ): if not hasattr(entry, "write_partition"): raise HTTPException( @@ -1360,7 +1370,7 @@ async def put_table_partition( async def patch_table_partition( partition: int, request: Request, - entry=SecureEntry(scopes=["write:data"]), + entry: MapAdapter = Security(get_entry(), scopes=["write:data"]), ): if not hasattr(entry, "write_partition"): raise HTTPException( @@ -1379,8 +1389,8 @@ async def patch_table_partition( @router.put("/awkward/full/{path:path}") async def put_awkward_full( request: Request, - entry=SecureEntry( - scopes=["write:data"], structure_families={StructureFamily.awkward} + entry: MapAdapter = Security( + get_entry({StructureFamily.awkward}), scopes=["write:data"] ), ): body = await request.body() @@ -1405,7 +1415,7 @@ async def patch_metadata( request: Request, body: schemas.PatchMetadataRequest, settings: Settings = Depends(get_settings), - entry=SecureEntry(scopes=["write:metadata"]), + entry: MapAdapter = Security(get_entry(), scopes=["write:metadata"]), ): if not hasattr(entry, "replace_metadata"): raise HTTPException( @@ -1465,7 +1475,7 @@ async def put_metadata( request: Request, body: schemas.PutMetadataRequest, settings: Settings = Depends(get_settings), - entry=SecureEntry(scopes=["write:metadata"]), + entry: MapAdapter = Security(get_entry(), scopes=["write:metadata"]), ): if not hasattr(entry, "replace_metadata"): raise HTTPException( @@ -1503,7 +1513,7 @@ async def get_revisions( limit: Optional[int] = Query( DEFAULT_PAGE_SIZE, alias="page[limit]", ge=0, le=MAX_PAGE_SIZE ), - entry=SecureEntry(scopes=["read:metadata"]), + entry: MapAdapter = Security(get_entry(), scopes=["read:metadata"]), ): if not hasattr(entry, "revisions"): raise HTTPException( @@ -1527,7 +1537,7 @@ async def get_revisions( async def delete_revision( request: Request, number: int, - entry=SecureEntry(scopes=["write:metadata"]), + entry: MapAdapter = Security(get_entry(), scopes=["write:metadata"]), ): if not hasattr(entry, "revisions"): raise HTTPException( @@ -1548,7 +1558,9 @@ async def get_asset( request: Request, id: int, relative_path: Optional[Path] = None, - entry=SecureEntry(scopes=["read:data"]), # TODO: Separate scope for assets? + entry: MapAdapter = Security( + get_entry(), scopes=["read:data"] + ), # TODO: Separate scope for assets? settings: Settings = Depends(get_settings), ): if not settings.expose_raw_assets: @@ -1644,7 +1656,9 @@ async def get_asset( async def get_asset_manifest( request: Request, id: int, - entry=SecureEntry(scopes=["read:data"]), # TODO: Separate scope for assets? + entry: MapAdapter = Security( + get_entry(), scopes=["read:data"] + ), # TODO: Separate scope for assets? 
settings: Settings = Depends(get_settings), ): if not settings.expose_raw_assets: From 518a4dfbce3c2cf9b1eb736dbe82b5bf1c4124a6 Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Wed, 12 Mar 2025 12:22:54 +0000 Subject: [PATCH 09/18] Remove unused dependency_overrides --- tiled/server/dependencies.py | 33 +-------------------------------- 1 file changed, 1 insertion(+), 32 deletions(-) diff --git a/tiled/server/dependencies.py b/tiled/server/dependencies.py index 09abec0d5..c92ba3f3f 100644 --- a/tiled/server/dependencies.py +++ b/tiled/server/dependencies.py @@ -1,4 +1,3 @@ -from functools import cache from typing import Optional, Tuple, Union import pydantic_settings @@ -9,12 +8,6 @@ from tiled.adapters.mapping import MapAdapter from tiled.structures.core import StructureFamily -from ..media_type_registration import ( - default_deserialization_registry, - default_serialization_registry, -) -from ..query_registration import default_query_registry -from ..validation_registration import default_validation_registry from .authentication import check_scopes, get_current_principal, get_session_state from .core import NoEntry from .utils import filter_for_access, record_timing @@ -26,30 +19,6 @@ SLICE_REGEX = rf"^{DIM_REGEX}(?:,{DIM_REGEX})*$" -@cache -def get_query_registry(): - "This may be overridden via dependency_overrides." - return default_query_registry - - -@cache -def get_deserialization_registry(): - "This may be overridden via dependency_overrides." - return default_deserialization_registry - - -@cache -def get_serialization_registry(): - "This may be overridden via dependency_overrides." - return default_serialization_registry - - -@cache -def get_validation_registry(): - "This may be overridden via dependency_overrides." - return default_validation_registry - - def get_root_tree(): raise NotImplementedError( "This should be overridden via dependency_overrides. " @@ -65,7 +34,7 @@ async def inner( principal: str = Depends(get_current_principal), root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), session_state: dict = Depends(get_session_state), - _ = Security(check_scopes) + _=Security(check_scopes), ) -> MapAdapter: """ Obtain a node in the tree from its path. 
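
With the cached getters gone, a deployment that previously used dependency_overrides passes registries straight to build_app. A sketch (the tree is a stand-in; the keyword argument mirrors the build_app parameters shown earlier in this series):

    from tiled.adapters.mapping import MapAdapter
    from tiled.media_type_registration import default_serialization_registry
    from tiled.server.app import build_app

    tree = MapAdapter({})  # stand-in for a real catalog/tree
    app = build_app(
        tree,
        serialization_registry=default_serialization_registry,  # or a customized copy
    )
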
From ea3fc60a28930a93894f40f42aa13e960a8307ba Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Wed, 12 Mar 2025 12:26:45 +0000 Subject: [PATCH 10/18] Remove unused parameters --- tiled/server/router.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/tiled/server/router.py b/tiled/server/router.py index 09c4fe344..e0eb33354 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -575,7 +575,6 @@ async def get_table_partition( column=(column or field), format=format, filename=filename, - serialization_registry=serialization_registry, settings=settings, ) @@ -605,18 +604,16 @@ async def post_table_partition( column=column, format=format, filename=filename, - serialization_registry=serialization_registry, settings=settings, ) async def table_partition( request: Request, partition: int, - entry, + entry: MapAdapter, column: Optional[List[str]], format: Optional[str], filename: Optional[str], - serialization_registry, settings: Settings, ): """ @@ -685,7 +682,6 @@ async def get_table_full( column=column, format=format, filename=filename, - serialization_registry=serialization_registry, settings=settings, ) @@ -713,17 +709,15 @@ async def post_table_full( column=column, format=format, filename=filename, - serialization_registry=serialization_registry, settings=settings, ) async def table_full( request: Request, - entry, + entry: MapAdapter, column: Optional[List[str]], format: Optional[str], filename: Optional[str], - serialization_registry, settings: Settings, ): """ @@ -788,7 +782,6 @@ async def get_container_full( field=field, format=format, filename=filename, - serialization_registry=serialization_registry, ) @router.post( @@ -816,7 +809,6 @@ async def post_container_full( field=field, format=format, filename=filename, - serialization_registry=serialization_registry, ) async def container_full( @@ -826,7 +818,6 @@ async def container_full( field: Optional[List[str]], format: Optional[str], filename: Optional[str], - serialization_registry, ): """ Fetch the data for the given container. 
@@ -961,7 +952,6 @@ async def get_awkward_buffers( form_key=form_key, format=format, filename=filename, - serialization_registry=serialization_registry, settings=settings, ) @@ -996,7 +986,6 @@ async def post_awkward_buffers( form_key=body, format=format, filename=filename, - serialization_registry=serialization_registry, settings=settings, ) @@ -1006,7 +995,6 @@ async def _awkward_buffers( form_key: Optional[List[str]], format: Optional[str], filename: Optional[str], - serialization_registry, settings: Settings, ): structure_family = entry.structure_family @@ -1118,7 +1106,6 @@ async def post_metadata( request=request, path=path, body=body, - validation_registry=validation_registry, settings=settings, entry=entry, ) @@ -1137,7 +1124,6 @@ async def post_register( request=request, path=path, body=body, - validation_registry=validation_registry, settings=settings, entry=entry, ) @@ -1146,7 +1132,6 @@ async def _create_node( request: Request, path: str, body: schemas.PostMetadataRequest, - validation_registry, settings: Settings, entry, ): From b636fa6da1668372b75a558964c6934cd49207ac Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Wed, 12 Mar 2025 15:36:30 +0000 Subject: [PATCH 11/18] Document requirement for APIKeyHeader handling --- tiled/server/authentication.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tiled/server/authentication.py b/tiled/server/authentication.py index cbb5055d6..5b8591bcb 100644 --- a/tiled/server/authentication.py +++ b/tiled/server/authentication.py @@ -92,8 +92,14 @@ class TokenData(BaseModel): username: Optional[str] = None -# TODO: remove custom subclass https://github.com/bluesky/tiled/issues/921 class StrictAPIKeyHeader(APIKeyHeader): + # TODO: remove custom subclass https://github.com/bluesky/tiled/issues/921 + """ + APIKeyHeader does not enforce that the scheme matches the expected, potentially + leading to the case where Bearer tokens are treated as API keys. + Additionally strips the scheme, as expected by current handling. + """ + async def __call__(self, request: Request) -> Optional[str]: api_key: Optional[str] = request.headers.get(self.model.name) scheme, param = get_authorization_scheme_param(api_key) From 2a33e2187f3ef7a0e7a27154dd7da4ddf51ee626 Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Thu, 13 Mar 2025 09:27:20 +0000 Subject: [PATCH 12/18] Clarify returned Scopes --- docs/source/explanations/access-control.md | 5 +++-- tiled/_tests/test_access_control.py | 8 ++------ tiled/_tests/test_protocols.py | 3 ++- tiled/access_policies.py | 4 +--- tiled/scopes.py | 13 +++++++++++++ tiled/server/authentication.py | 18 ++++++------------ 6 files changed, 27 insertions(+), 24 deletions(-) diff --git a/docs/source/explanations/access-control.md b/docs/source/explanations/access-control.md index 7ac7ba76d..cfd5ddbe5 100644 --- a/docs/source/explanations/access-control.md +++ b/docs/source/explanations/access-control.md @@ -70,6 +70,7 @@ integrate with our proposal system. 
import cachetools import httpx from tiled.queries import In +from tiled.scopes import PUBLIC_SCOPES # To reduce load on the external service and to expedite repeated lookups, use a @@ -102,12 +103,12 @@ class PASSAccessPolicy: ) def allowed_scopes(self, node, principal, path_parts): - return {"read:metadata", "read:data"} + return PUBLIC_SCOPES def filters(self, node, principal, scopes, path_parts): queries = [] id = self._get_id(principal) - if not scopes.issubset({"read:metadata", "read:data"}): + if not scopes.issubset(PUBLIC_SCOPES): return NO_ACCESS try: response = response_cache[id] diff --git a/tiled/_tests/test_access_control.py b/tiled/_tests/test_access_control.py index 09a5f2029..25c2eb1d7 100644 --- a/tiled/_tests/test_access_control.py +++ b/tiled/_tests/test_access_control.py @@ -5,16 +5,12 @@ from fastapi import HTTPException from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND -from ..access_policies import ( - ALL_SCOPES, - PUBLIC_SCOPES, - SimpleAccessPolicy, - SpecialUsers, -) +from ..access_policies import SimpleAccessPolicy, SpecialUsers from ..adapters.array import ArrayAdapter from ..adapters.mapping import MapAdapter from ..client import Context, from_context from ..client.utils import ClientError +from ..scopes import ALL_SCOPES, PUBLIC_SCOPES from ..server.app import build_app_from_config from ..server.core import NoEntry from .utils import enter_username_password, fail_with_status_code diff --git a/tiled/_tests/test_protocols.py b/tiled/_tests/test_protocols.py index 9f88afd11..40eaf2f54 100644 --- a/tiled/_tests/test_protocols.py +++ b/tiled/_tests/test_protocols.py @@ -9,7 +9,7 @@ from numpy.typing import NDArray from pytest_mock import MockFixture -from tiled.access_policies import ALL_ACCESS, ALL_SCOPES +from tiled.access_policies import ALL_ACCESS from tiled.adapters.awkward_directory_container import DirectoryContainer from tiled.adapters.protocols import ( AccessPolicy, @@ -19,6 +19,7 @@ SparseAdapter, TableAdapter, ) +from tiled.scopes import ALL_SCOPES from tiled.server.schemas import Principal, PrincipalType from tiled.structures.array import ArrayStructure, BuiltinDtype from tiled.structures.awkward import AwkwardStructure diff --git a/tiled/access_policies.py b/tiled/access_policies.py index 593bce7ab..1b2c97e87 100644 --- a/tiled/access_policies.py +++ b/tiled/access_policies.py @@ -1,12 +1,10 @@ from functools import partial from .queries import In, KeysFilter -from .scopes import SCOPES +from .scopes import ALL_SCOPES, PUBLIC_SCOPES from .utils import Sentinel, SpecialUsers, import_object ALL_ACCESS = Sentinel("ALL_ACCESS") -ALL_SCOPES = set(SCOPES) -PUBLIC_SCOPES = {"read:metadata", "read:data"} NO_ACCESS = Sentinel("NO_ACCESS") diff --git a/tiled/scopes.py b/tiled/scopes.py index 55ae0cca6..82a5d0ee9 100644 --- a/tiled/scopes.py +++ b/tiled/scopes.py @@ -19,3 +19,16 @@ "description": "Edit list of all users and services and their attributes." 
}, } + +ALL_SCOPES: set[str] = set(SCOPES) +PUBLIC_SCOPES: set[str] = {"read:metadata", "read:data"} +USER_SCOPES: set[str] = { + "read:metadata", + "read:data", + "write:metadata", + "write:data", + "create", + "register", + "metrics", +} +NO_SCOPES: set[str] = set() diff --git a/tiled/server/authentication.py b/tiled/server/authentication.py index 5b8591bcb..e59d1b658 100644 --- a/tiled/server/authentication.py +++ b/tiled/server/authentication.py @@ -36,6 +36,8 @@ HTTP_409_CONFLICT, ) +from tiled.scopes import NO_SCOPES, PUBLIC_SCOPES, USER_SCOPES + # To hide third-party warning # .../jose/backends/cryptography_backend.py:18: CryptographyDeprecationWarning: # int_from_bytes is deprecated, use int.from_bytes instead @@ -239,15 +241,7 @@ async def get_scopes_from_api_key( if not authenticators: # Tiled is in a "single user" mode with only one API key. return ( - { - "read:metadata", - "read:data", - "write:metadata", - "write:data", - "create", - "register", - "metrics", - } + USER_SCOPES if secrets.compare_digest(api_key, settings.single_user_api_key) else set() ) @@ -260,10 +254,10 @@ async def get_scopes_from_api_key( try: secret = bytes.fromhex(api_key) except Exception: - return set() + return NO_SCOPES api_key_orm = await lookup_valid_api_key(db, secret) if api_key_orm is None: - return set() + return NO_SCOPES else: principal = api_key_orm.principal principal_scopes = set().union(*[role.scopes for role in principal.roles]) @@ -290,7 +284,7 @@ async def get_current_scopes( elif decoded_access_token is not None: return decoded_access_token["scp"] else: - return {"read:metadata", "read:data"} if settings.allow_anonymous_access else {} + return PUBLIC_SCOPES if settings.allow_anonymous_access else NO_SCOPES async def check_scopes( From f8bab0e542df164d4316f6ba8889c8f9ed83da7f Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Thu, 13 Mar 2025 13:48:53 +0000 Subject: [PATCH 13/18] Prevent mutation of Scopes --- tiled/scopes.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/tiled/scopes.py b/tiled/scopes.py index 82a5d0ee9..6a3a4a039 100644 --- a/tiled/scopes.py +++ b/tiled/scopes.py @@ -20,15 +20,17 @@ }, } -ALL_SCOPES: set[str] = set(SCOPES) -PUBLIC_SCOPES: set[str] = {"read:metadata", "read:data"} -USER_SCOPES: set[str] = { - "read:metadata", - "read:data", - "write:metadata", - "write:data", - "create", - "register", - "metrics", -} -NO_SCOPES: set[str] = set() +ALL_SCOPES: set[str] = frozenset(SCOPES) +PUBLIC_SCOPES: set[str] = frozenset(("read:metadata", "read:data")) +USER_SCOPES: set[str] = frozenset( + ( + "read:metadata", + "read:data", + "write:metadata", + "write:data", + "create", + "register", + "metrics", + ) +) +NO_SCOPES: set[str] = frozenset() From 0847d0778b3d9a100c6b188f9b5101803dfdabe5 Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Thu, 13 Mar 2025 14:08:36 +0000 Subject: [PATCH 14/18] Prevent closure of QueryRegistry --- tiled/server/router.py | 54 +++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/tiled/server/router.py b/tiled/server/router.py index e0eb33354..da65a6376 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -66,35 +66,30 @@ from .settings import Settings, get_settings from .utils import filter_for_access, get_base_url, record_timing +T = TypeVar("T") -def get_router( - query_registry: QueryRegistry, - serialization_registry: SerializationRegistry, - deserialization_registry: 
SerializationRegistry, - validation_registry: ValidationRegistry, -) -> APIRouter: - router = APIRouter() - - T = TypeVar("T") - def patch_route_signature(route: Callable[..., T]) -> Callable[..., T]: - """ - This is done dynamically at router startup. +def _patch_route_signature( + query_registry: QueryRegistry, +) -> Callable[[Callable[..., T]], Callable[..., T]]: + """ + This is done dynamically at router startup. - We check the registry of known search query types, which is user - configurable, and use that to define the allowed HTTP query parameters for - this route. + We check the registry of known search query types, which is user + configurable, and use that to define the allowed HTTP query parameters for + this route. - Take a route that accept unspecified search queries as **filters. - Return a wrapped version of the route that has the supported - search queries explicitly spelled out in the function signature. + Take a route that accept unspecified search queries as **filters. + Return a wrapped version of the route that has the supported + search queries explicitly spelled out in the function signature. - This has no change in the actual behavior of the function, - but it enables FastAPI to generate good OpenAPI documentation - showing the supported search queries. + This has no change in the actual behavior of the function, + but it enables FastAPI to generate good OpenAPI documentation + showing the supported search queries. - """ + """ + def inner(route: Callable[..., T]) -> Callable[..., T]: # Build a wrapper so that we can modify the signature # without mutating the wrapped original. @@ -139,6 +134,17 @@ async def route_with_sig(*args, **kwargs): return route_with_sig + return inner + + +def get_router( + query_registry: QueryRegistry, + serialization_registry: SerializationRegistry, + deserialization_registry: SerializationRegistry, + validation_registry: ValidationRegistry, +) -> APIRouter: + router = APIRouter() + @router.get("/", response_model=About) async def about( request: Request, @@ -243,7 +249,7 @@ async def about( dict, ], ) - @patch_route_signature + @_patch_route_signature(query_registry) async def search( request: Request, path: str, @@ -318,7 +324,7 @@ async def search( "/distinct/{path:path}", response_model=schemas.GetDistinctResponse, ) - @patch_route_signature + @_patch_route_signature(query_registry) async def distinct( request: Request, structure_families: bool = False, From 1f85060cf7d71117a9140dafa353244b58972ae0 Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Thu, 13 Mar 2025 14:20:42 +0000 Subject: [PATCH 15/18] Prevent trying to edit frozenset --- tiled/_tests/test_access_control.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tiled/_tests/test_access_control.py b/tiled/_tests/test_access_control.py index 25c2eb1d7..638e6c021 100644 --- a/tiled/_tests/test_access_control.py +++ b/tiled/_tests/test_access_control.py @@ -32,8 +32,6 @@ async def allowed_scopes(self, node, principal, path_parts): # If this is being called, filter_access has let us get this far. 
if principal is SpecialUsers.public:
             allowed = PUBLIC_SCOPES
-        elif principal.type == "service":
-            allowed = self.scopes
         else:
             allowed = self.scopes
 
@@ -60,7 +58,7 @@ async def allowed_scopes(self, node, principal, path_parts):
             )
             remove_scope = node.metadata().get("remove_scope", None)
             if remove_scope in allowed:
-                allowed = allowed.copy()
+                allowed = set(allowed)
                 allowed.remove(remove_scope)
             return allowed
 

From 6202a60cf1f37fb43788be95465b002d28222342 Mon Sep 17 00:00:00 2001
From: "Ware, Joseph (DLSLtd,RAL,LSCI)"
Date: Fri, 14 Mar 2025 11:50:54 +0000
Subject: [PATCH 16/18] Add changelog entries

---
 CHANGELOG.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 94bc96493..2a00416ed 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,7 +8,9 @@ Write the date in place of the "Unreleased" in the case a new version is release
 
 ### Added
 
 - `Composite` structure family to enable direct access to table columns in a single namespace.
-
+- Extract API key handling
+- Extract scope fetching and checking
+- Refactor router construction to pass completed objects
 ## 0.1.0-b20 (2025-03-07)
 

From cd0c6506ec5db8cdfd000299dd047a14f071c04a Mon Sep 17 00:00:00 2001
From: "Ware, Joseph (DLSLtd,RAL,LSCI)"
Date: Fri, 14 Mar 2025 12:11:33 +0000
Subject: [PATCH 17/18] Resolve missed merge conflicts

---
 tiled/serialization/container.py |  4 ++--
 tiled/server/router.py           | 25 +++++++++++++++++++------
 2 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/tiled/serialization/container.py b/tiled/serialization/container.py
index fc8fb5180..c2f0da167 100644
--- a/tiled/serialization/container.py
+++ b/tiled/serialization/container.py
@@ -84,7 +84,7 @@ async def serialize_hdf5(node, metadata, filter_for_access):
 default_serialization_registry.register(
     StructureFamily.container, "application/x-hdf5", serialize_hdf5
 )
-serialization_registry.register(
+default_serialization_registry.register(
     StructureFamily.composite, "application/x-hdf5", serialize_hdf5
 )
 
@@ -110,6 +110,6 @@ async def serialize_json(node, metadata, filter_for_access):
 default_serialization_registry.register(
     StructureFamily.container, "application/json", serialize_json
 )
-serialization_registry.register(
+default_serialization_registry.register(
     StructureFamily.composite, "application/json", serialize_json
 )
diff --git a/tiled/server/router.py b/tiled/server/router.py
index 7db24a5ac..d23f684ef 100644
--- a/tiled/server/router.py
+++ b/tiled/server/router.py
@@ -267,8 +267,10 @@ async def search(
         **filters,
     ):
         request.state.endpoint = "search"
-        if entry.if entry.structure_family not in {
-            StructureFamily.container, StructureFamily.composite}:
+        if entry.structure_family not in {
+            StructureFamily.container,
+            StructureFamily.composite,
+        }:
             raise WrongTypeForRoute(
                 "This is not a Node; it cannot be searched or listed."
         )
@@ -772,7 +774,8 @@ async def get_container_full(
         request: Request,
         entry: MapAdapter = Security(
-            get_entry({StructureFamily.container, StructureFamily.composite}), scopes=["read:data"]
+            get_entry({StructureFamily.container, StructureFamily.composite}),
+            scopes=["read:data"],
         ),
         principal: str = Depends(get_current_principal),
         field: Optional[List[str]] = Query(None, min_length=1),
@@ -799,7 +802,8 @@ async def post_container_full(
         request: Request,
         entry: MapAdapter = Security(
-            get_entry({StructureFamily.container, StructureFamily.composite}), scopes=["read:data"]
+            get_entry({StructureFamily.container, StructureFamily.composite}),
+            scopes=["read:data"],
         ),
         principal: str = Depends(get_current_principal),
         field: Optional[List[str]] = Body(None, min_length=1),
@@ -870,7 +874,13 @@ async def container_full(
     async def node_full(
         request: Request,
         entry: MapAdapter = Security(
-            get_entry({StructureFamily.table, StructureFamily.container, StructureFamily.composite}),
+            get_entry(
+                {
+                    StructureFamily.table,
+                    StructureFamily.container,
+                    StructureFamily.composite,
+                }
+            ),
             scopes=["read:data"],
         ),
         principal: str = Depends(get_current_principal),
@@ -901,7 +911,10 @@ async def node_full(
                     "request a smaller chunks."
                 ),
             )
-        if entry.structure_family in {StructureFamily.container, StructureFamily.composite}:
+        if entry.structure_family in {
+            StructureFamily.container,
+            StructureFamily.composite,
+        }:
             curried_filter = partial(
                 filter_for_access,
                 principal=principal,

From f0ef533dfbd004492af0109af7fd9e015ccb80c5 Mon Sep 17 00:00:00 2001
From: "Ware, Joseph (DLSLtd,RAL,LSCI)"
Date: Fri, 14 Mar 2025 12:32:57 +0000
Subject: [PATCH 18/18] Amend changelog

---
 CHANGELOG.md | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2a00416ed..0f6f5e4eb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,9 +8,13 @@ Write the date in place of the "Unreleased" in the case a new version is release
 
 ### Added
 
 - `Composite` structure family to enable direct access to table columns in a single namespace.
+
+### Maintenance
+
 - Extract API key handling
 - Extract scope fetching and checking
-- Refactor router construction to pass completed objects
+- Refactor router construction
+
 ## 0.1.0-b20 (2025-03-07)
 
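A closing illustration of the scope changes: the sketch below — standalone, not code from this series — shows why the `frozenset` conversion in PATCH 13 forces the `set(allowed)` copy made in PATCH 15.

```python
# PUBLIC_SCOPES here mirrors tiled/scopes.py after PATCH 13.
PUBLIC_SCOPES = frozenset({"read:metadata", "read:data"})

allowed = PUBLIC_SCOPES
try:
    allowed.remove("read:data")  # frozenset defines no mutating methods
except AttributeError:
    print("cannot mutate a frozenset in place")

# Copying into a plain set, as the amended test does, leaves the shared
# module-level constant untouched:
allowed = set(allowed)
allowed.remove("read:data")
assert allowed == {"read:metadata"}
assert PUBLIC_SCOPES == frozenset({"read:metadata", "read:data"})
```

One caveat worth noting: PATCH 13 keeps the `set[str]` annotations while the assigned values become `frozenset` instances. Because `frozenset` is not a subclass of `set`, a strict type checker may flag those assignments; annotating them as `frozenset[str]` would match the runtime type.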