Skip to content

Commit a0d9104

Browse files
Refactor auth (#891)
* Refactor API key handling * Clarify registry usage * Create routes after registries known * FIx minor issues from refactors * Nit: remove Security scopes when not used * Fix CI * Refactor Scope checking * refactor SecureEntry into standard Security * Remove unused dependency_overrides * Remove unused parameters * Document requirement for APIKeyHeader handling * Clarify returned Scopes * Prevent mutation of Scopes * Prevent closure of QueryRegistry * Prevent trying to edit frozenset * Add changelog entries * Resolve missed merge conflicts * Amend changelog
1 parent 43bc1fa commit a0d9104

23 files changed

+1866
-1870
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ Write the date in place of the "Unreleased" in the case a new version is release
99

1010
- `Composite` structure family to enable direct access to table columns in a single namespace.
1111

12+
### Maintenance
13+
14+
- Extract API key handling
15+
- Extract scope fetching and checking
16+
- Refactor router construction
17+
1218

1319
## 0.1.0-b20 (2025-03-07)
1420

docs/source/explanations/access-control.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ integrate with our proposal system.
7070
import cachetools
7171
import httpx
7272
from tiled.queries import In
73+
from tiled.scopes import PUBLIC_SCOPES
7374

7475

7576
# To reduce load on the external service and to expedite repeated lookups, use a
@@ -102,12 +103,12 @@ class PASSAccessPolicy:
102103
)
103104

104105
def allowed_scopes(self, node, principal, path_parts):
105-
return {"read:metadata", "read:data"}
106+
return PUBLIC_SCOPES
106107

107108
def filters(self, node, principal, scopes, path_parts):
108109
queries = []
109110
id = self._get_id(principal)
110-
if not scopes.issubset({"read:metadata", "read:data"}):
111+
if not scopes.issubset(PUBLIC_SCOPES):
111112
return NO_ACCESS
112113
try:
113114
response = response_cache[id]

docs/source/reference/service.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ When registering new types, make reference to the
7979
.. autosummary::
8080
:toctree: generated
8181
82-
tiled.media_type_registration.serialization_registry
8382
tiled.media_type_registration.SerializationRegistry
8483
tiled.media_type_registration.SerializationRegistry.register
8584
tiled.media_type_registration.SerializationRegistry.media_types

tiled/_tests/test_access_control.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,12 @@
55
from fastapi import HTTPException
66
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
77

8-
from ..access_policies import (
9-
ALL_SCOPES,
10-
PUBLIC_SCOPES,
11-
SimpleAccessPolicy,
12-
SpecialUsers,
13-
)
8+
from ..access_policies import SimpleAccessPolicy, SpecialUsers
149
from ..adapters.array import ArrayAdapter
1510
from ..adapters.mapping import MapAdapter
1611
from ..client import Context, from_context
1712
from ..client.utils import ClientError
13+
from ..scopes import ALL_SCOPES, PUBLIC_SCOPES
1814
from ..server.app import build_app_from_config
1915
from ..server.core import NoEntry
2016
from .utils import enter_username_password, fail_with_status_code
@@ -36,8 +32,6 @@ async def allowed_scopes(self, node, principal, path_parts):
3632
# If this is being called, filter_access has let us get this far.
3733
if principal is SpecialUsers.public:
3834
allowed = PUBLIC_SCOPES
39-
elif principal.type == "service":
40-
allowed = self.scopes
4135
else:
4236
allowed = self.scopes
4337

@@ -64,7 +58,7 @@ async def allowed_scopes(self, node, principal, path_parts):
6458
)
6559
remove_scope = node.metadata().get("remove_scope", None)
6660
if remove_scope in allowed:
67-
allowed = allowed.copy()
61+
allowed = set(allowed)
6862
allowed.remove(remove_scope)
6963
return allowed
7064

tiled/_tests/test_protocols.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from numpy.typing import NDArray
1010
from pytest_mock import MockFixture
1111

12-
from tiled.access_policies import ALL_ACCESS, ALL_SCOPES
12+
from tiled.access_policies import ALL_ACCESS
1313
from tiled.adapters.awkward_directory_container import DirectoryContainer
1414
from tiled.adapters.protocols import (
1515
AccessPolicy,
@@ -19,6 +19,7 @@
1919
SparseAdapter,
2020
TableAdapter,
2121
)
22+
from tiled.scopes import ALL_SCOPES
2223
from tiled.server.schemas import Principal, PrincipalType
2324
from tiled.structures.array import ArrayStructure, BuiltinDtype
2425
from tiled.structures.awkward import AwkwardStructure

tiled/access_policies.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
from functools import partial
22

33
from .queries import In, KeysFilter
4-
from .scopes import SCOPES
4+
from .scopes import ALL_SCOPES, PUBLIC_SCOPES
55
from .utils import Sentinel, SpecialUsers, import_object
66

77
ALL_ACCESS = Sentinel("ALL_ACCESS")
8-
ALL_SCOPES = set(SCOPES)
9-
PUBLIC_SCOPES = {"read:metadata", "read:data"}
108
NO_ACCESS = Sentinel("NO_ACCESS")
119

1210

tiled/client/container.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ..adapters.utils import IndexersMixin
1616
from ..iterviews import ItemsView, KeysView, ValuesView
1717
from ..queries import KeyLookup
18-
from ..query_registration import query_registry
18+
from ..query_registration import default_query_registry
1919
from ..structures.core import Spec, StructureFamily
2020
from ..structures.data_source import DataSource
2121
from ..utils import UNCHANGED, OneShotCachedMap, Sentinel, node_repr, safe_json_dump
@@ -1205,7 +1205,7 @@ def _queries_to_params(*queries):
12051205
"Compute GET params from the queries."
12061206
params = collections.defaultdict(list)
12071207
for query in queries:
1208-
name = query_registry.query_type_to_name[type(query)]
1208+
name = default_query_registry.query_type_to_name[type(query)]
12091209
for field, value in query.encode().items():
12101210
if value is not None:
12111211
params[f"filter[{name}][condition][{field}]"].append(value)

tiled/config.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,21 @@
1010
from datetime import timedelta
1111
from functools import cache
1212
from pathlib import Path
13+
from typing import Optional
1314

1415
import jsonschema
1516

1617
from .adapters.mapping import MapAdapter
1718
from .media_type_registration import (
18-
compression_registry as default_compression_registry,
19+
CompressionRegistry,
20+
SerializationRegistry,
21+
default_compression_registry,
22+
default_deserialization_registry,
23+
default_serialization_registry,
1924
)
20-
from .media_type_registration import (
21-
serialization_registry as default_serialization_registry,
22-
)
23-
from .query_registration import query_registry as default_query_registry
25+
from .query_registration import QueryRegistry, default_query_registry
2426
from .utils import import_object, parse, prepend_to_sys_path
25-
from .validation_registration import validation_registry as default_validation_registry
27+
from .validation_registration import ValidationRegistry, default_validation_registry
2628

2729

2830
@cache
@@ -40,10 +42,11 @@ def construct_build_app_kwargs(
4042
config,
4143
*,
4244
source_filepath=None,
43-
query_registry=None,
44-
compression_registry=None,
45-
serialization_registry=None,
46-
validation_registry=None,
45+
query_registry: Optional[QueryRegistry] = None,
46+
compression_registry: Optional[CompressionRegistry] = None,
47+
serialization_registry: Optional[SerializationRegistry] = None,
48+
deserialization_registry: Optional[SerializationRegistry] = None,
49+
validation_registry: Optional[ValidationRegistry] = None,
4750
):
4851
"""
4952
Given parsed configuration, construct arguments for build_app(...).
@@ -61,6 +64,8 @@ def construct_build_app_kwargs(
6164
query_registry = default_query_registry
6265
if serialization_registry is None:
6366
serialization_registry = default_serialization_registry
67+
if deserialization_registry is None:
68+
deserialization_registry = default_deserialization_registry
6469
if compression_registry is None:
6570
compression_registry = default_compression_registry
6671
if validation_registry is None:
@@ -220,6 +225,7 @@ def construct_build_app_kwargs(
220225
"server_settings": server_settings,
221226
"query_registry": query_registry,
222227
"serialization_registry": serialization_registry,
228+
"deserialization_registry": deserialization_registry,
223229
"compression_registry": compression_registry,
224230
"validation_registry": validation_registry,
225231
"tasks": {

tiled/media_type_registration.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -197,21 +197,21 @@ def __call__(self, media_type, encoder, *args, **kwargs):
197197
return self.dispatch(media_type, encoder)(*args, **kwargs)
198198

199199

200-
serialization_registry = SerializationRegistry()
200+
default_serialization_registry = SerializationRegistry()
201201
"Global serialization registry. See Registry for usage examples."
202202

203-
deserialization_registry = SerializationRegistry()
203+
default_deserialization_registry = SerializationRegistry()
204204
"Global deserialization registry. See Registry for usage examples."
205205

206-
compression_registry = CompressionRegistry()
206+
default_compression_registry = CompressionRegistry()
207207
"Global compression registry. See Registry for usage examples."
208208

209209

210210
for media_type in [
211211
"application/json",
212212
"application/x-msgpack",
213213
]:
214-
compression_registry.register(
214+
default_compression_registry.register(
215215
media_type,
216216
"gzip",
217217
lambda buffer: gzip.GzipFile(mode="wb", fileobj=buffer, compresslevel=9),
@@ -225,7 +225,7 @@ def __call__(self, media_type, encoder, *args, **kwargs):
225225
"text/plain",
226226
"text/html",
227227
]:
228-
compression_registry.register(
228+
default_compression_registry.register(
229229
media_type,
230230
"gzip",
231231
# Use a lower compression level. High compression is extremely slow
@@ -270,7 +270,7 @@ def close(self):
270270
"text/html",
271271
"text/plain",
272272
]:
273-
compression_registry.register(media_type, "zstd", ZstdBuffer)
273+
default_compression_registry.register(media_type, "zstd", ZstdBuffer)
274274

275275
if modules_available("lz4"):
276276
import lz4
@@ -326,7 +326,7 @@ def close(self):
326326
"text/html",
327327
"text/plain",
328328
]:
329-
compression_registry.register(media_type, "lz4", LZ4Buffer)
329+
default_compression_registry.register(media_type, "lz4", LZ4Buffer)
330330

331331
if modules_available("blosc2"):
332332
import blosc2
@@ -355,4 +355,4 @@ def close(self):
355355
pass
356356

357357
for media_type in ["application/octet-stream", APACHE_ARROW_FILE_MIME_TYPE]:
358-
compression_registry.register(media_type, "blosc2", BloscBuffer)
358+
default_compression_registry.register(media_type, "blosc2", BloscBuffer)

tiled/query_registration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ def inner(cls):
7979

8080

8181
# Make a global registry.
82-
query_registry = QueryRegistry()
83-
register = query_registry.register
82+
default_query_registry = QueryRegistry()
83+
register = default_query_registry.register
8484
"""Register a new type of query."""
8585

8686

tiled/scopes.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,18 @@
1919
"description": "Edit list of all users and services and their attributes."
2020
},
2121
}
22+
23+
ALL_SCOPES: set[str] = frozenset(SCOPES)
24+
PUBLIC_SCOPES: set[str] = frozenset(("read:metadata", "read:data"))
25+
USER_SCOPES: set[str] = frozenset(
26+
(
27+
"read:metadata",
28+
"read:data",
29+
"write:metadata",
30+
"write:data",
31+
"create",
32+
"register",
33+
"metrics",
34+
)
35+
)
36+
NO_SCOPES: set[str] = frozenset()

tiled/serialization/array.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33

44
import numpy
55

6-
from ..media_type_registration import deserialization_registry, serialization_registry
6+
from ..media_type_registration import (
7+
default_deserialization_registry,
8+
default_serialization_registry,
9+
)
710
from ..utils import (
811
SerializationError,
912
UnsupportedShape,
@@ -22,13 +25,13 @@ def as_buffer(array, metadata):
2225
return numpy.asarray(array).tobytes()
2326

2427

25-
serialization_registry.register(
28+
default_serialization_registry.register(
2629
"array",
2730
"application/octet-stream",
2831
as_buffer,
2932
)
3033
if modules_available("orjson"):
31-
serialization_registry.register(
34+
default_serialization_registry.register(
3235
"array",
3336
"application/json",
3437
lambda array, metadata: safe_json_dump(array),
@@ -43,10 +46,12 @@ def serialize_csv(array, metadata):
4346
return file.getvalue().encode()
4447

4548

46-
serialization_registry.register("array", "text/csv", serialize_csv)
47-
serialization_registry.register("array", "text/x-comma-separated-values", serialize_csv)
48-
serialization_registry.register("array", "text/plain", serialize_csv)
49-
deserialization_registry.register(
49+
default_serialization_registry.register("array", "text/csv", serialize_csv)
50+
default_serialization_registry.register(
51+
"array", "text/x-comma-separated-values", serialize_csv
52+
)
53+
default_serialization_registry.register("array", "text/plain", serialize_csv)
54+
default_deserialization_registry.register(
5055
"array",
5156
"application/octet-stream",
5257
lambda buffer, dtype, shape: numpy.frombuffer(buffer, dtype=dtype).reshape(shape),
@@ -90,10 +95,10 @@ def array_from_buffer_PIL(buffer, format, dtype, shape):
9095
image = Image.open(file, format=format)
9196
return numpy.asarray(image).asdtype(dtype).reshape(shape)
9297

93-
serialization_registry.register(
98+
default_serialization_registry.register(
9499
"array", "image/png", lambda array, metadata: save_to_buffer_PIL(array, "png")
95100
)
96-
deserialization_registry.register(
101+
default_deserialization_registry.register(
97102
"array",
98103
"image/png",
99104
lambda buffer, dtype, shape: array_from_buffer_PIL(buffer, "png", dtype, shape),
@@ -120,18 +125,24 @@ def save_to_buffer_tifffile(array, metadata):
120125
imwrite(file, normalized_array)
121126
return file.getbuffer()
122127

123-
serialization_registry.register("array", "image/tiff", save_to_buffer_tifffile)
124-
deserialization_registry.register("array", "image/tiff", array_from_buffer_tifffile)
128+
default_serialization_registry.register(
129+
"array", "image/tiff", save_to_buffer_tifffile
130+
)
131+
default_deserialization_registry.register(
132+
"array", "image/tiff", array_from_buffer_tifffile
133+
)
125134

126135

127136
def serialize_html(array, metadata):
128137
"Try to display as image. Fall back to CSV."
129138
try:
130-
png_data = serialization_registry.dispatch("array", "image/png")(
139+
png_data = default_serialization_registry.dispatch("array", "image/png")(
131140
array, metadata
132141
)
133142
except Exception:
134-
csv_data = serialization_registry.dispatch("array", "text/csv")(array, metadata)
143+
csv_data = default_serialization_registry.dispatch("array", "text/csv")(
144+
array, metadata
145+
)
135146
return "<html>" "<body>" f"{csv_data.decode()!s}" "</body>" "</html>"
136147
else:
137148
return (
@@ -145,4 +156,4 @@ def serialize_html(array, metadata):
145156
)
146157

147158

148-
serialization_registry.register("array", "text/html", serialize_html)
159+
default_serialization_registry.register("array", "text/html", serialize_html)

0 commit comments

Comments
 (0)