diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d53c25bb..daa4ed9cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,9 +6,27 @@ Write the date in place of the "Unreleased" in the case a new version is release ## v0.1.0-b37 (Unreleased) +### Added + +- The access tags compiler and db schema have been upstreamed into Tiled +- API keys can now be restricted to specific access tags +- New unit tests covering the new access policy and access control features + ### Changed - Remove `SpecialUsers` principals for single-user and anonymous-access cases +- Access control code is now in the `access_control` subdirectory +- `SimpleAccessPolicy` has been removed +- AuthN database can now be in-memory SQLite +- Catalog database can now be shared when using in-memory SQLite +- `TagBasedAccessPolicy` now supports anonymous access +- `AccessTagsParser` is now async +- `toy_authentication` example config now uses `TagBasedAccessPolicy` +- Added helpers for setting up the access tag and catalog databases for `toy_authentication` + +### Fixed + +- Access control on container export was partially broken, now access works as expected. ## v0.1.0-b36 (2025-08-26) @@ -29,7 +47,6 @@ Write the date in place of the "Unreleased" in the case a new version is release - The project ships with a pixi manifest (`pixi.toml`). - ## v0.1.0-b34 (2025-08-14) ### Fixed @@ -41,7 +58,6 @@ Write the date in place of the "Unreleased" in the case a new version is release should be re-run on any databases that could not be upgraded with the previous release. - ## v0.1.0-b33 (2025-08-13) _This release requires a database migration of the catalog database._ diff --git a/docs/source/explanations/access-control.md b/docs/source/explanations/access-control.md index 3fd6479a2..9704a8d25 100644 --- a/docs/source/explanations/access-control.md +++ b/docs/source/explanations/access-control.md @@ -35,95 +35,43 @@ Rephrasing these two items now using the jargon of entities in Tiled: children (if any). This determination can be backed by a call to an external service or by a -static configuration file. We demonstrate both here. +static configuration file. We demonstrate the static file case here. -First, the static configuration file. Consider this simple tree of data: +Consider this simple tree of data: -```{eval-rst} -.. literalinclude:: ../../../tiled/examples/toy_authentication.py - :caption: tiled/examples/toy_authentication.py +``` +/ +├── A -> array=(10 * numpy.ones((10, 10))), access_tags=["data_A"] +├── B -> array=(10 * numpy.ones((10, 10))), access_tags=["data_B"] +├── C -> array=(10 * numpy.ones((10, 10))), access_tags=["data_C"] +└── D -> array=(10 * numpy.ones((10, 10))), access_tags=["data_D"] ``` -protected by this simple Access Control Policy: +which will be protected by Tiled's "Tag Based" Access Control Policy. Note the +"access tags" associated with each node. The tag-based access policy uses ACLs, +which are compiled from provided access-tag definitions (example below), to make +decisions based on these access tags. ```{eval-rst} -.. literalinclude:: ../../../example_configs/toy_authentication.yml - :caption: example_configs/toy_authentication.yml +.. literalinclude:: ../../../example_configs/access_tags/tag_definitions.yml + :caption: example_configs/access_tags/tag_definitions.yml ``` -Under `access_lists:` usernames are mapped to the keys of the entries the user may access. -The section `public:` designates entries that an -unauthenticated (anonymous) user may access *if* the server is configured to -allow anonymous access. (See {doc}`security`.) The special value -``tiled.adapters.mapping:SimpleAccessPolicy.ALL`` designates that the user may access any entry -in the Tree. +Under `tags`, usernames and groupnames are mapped to either a role or list of scopes. +Roles are pre-defined lists of scopes, and are also defined in this file. This mapping +confers these scopes to these users for data which is tagged with the corresponding tagname. -``` -ALICE_PASSWORD=secret1 BOB_PASSWORD=secret2 CARA_PASSWORD=secret3 tiled serve config example_configs/config.yml -``` +Tags can also inherit the ACLs of other tags, using the `auto_tags` field. There is also a +`public` tag which is a special tag used to mark data as public (all users can read). + +Lastly, only "owners" of a tag can apply that tag to a node. Tag owners are defined in +this same tag definitions file, under the `tag_owners` key. -For large-scale deployment, Tiled typically integrates with an existing access management -system. This is sketch of the Access Control Policy used by NSLS-II to -integrate with our proposal system. - -```py -import cachetools -import httpx -from tiled.queries import In -from tiled.scopes import PUBLIC_SCOPES - - -# To reduce load on the external service and to expedite repeated lookups, use a -# process-global cache with a timeout. -response_cache = cachetools.TTLCache(maxsize=10_000, ttl=60) - - -class PASSAccessPolicy: - """ - access_control: - access_policy: pass_access_policy:PASSAccessPolicy - args: - url: ... - beamline: ... - """ - - def __init__(self, url, beamline, provider): - self._client = httpx.Client(base_url=url) - self._beamline = beamline - self.provider = provider - - def _get_id(self, principal): - for identity in principal.identities: - if identity.provider == self.provider: - return identity.id - else: - raise ValueError( - f"Principcal {principal} has no identity from provider {self.provider}. " - f"Its identities are: {principal.identities}" - ) - - def allowed_scopes(self, node, principal, authn_scopes): - return PUBLIC_SCOPES - - def filters(self, node, principal, authn_scopes, scopes): - queries = [] - id = self._get_id(principal) - if not scopes.issubset(PUBLIC_SCOPES): - return NO_ACCESS - try: - response = response_cache[id] - except KeyError: - response = self._client.get(f"/data_session/{id}") - response_cache[id] = response - if response.is_error: - response.raise_for_status() - data = response.json() - if ("nsls2" in (data["facility_all_access"] or [])) or ( - self._beamline in (data["beamline_all_access"] or []) - ): - return queries - queries.append( - In("data_session", data["data_sessions"] or []) - ) - return queries +To try out this access control configuration, an example server can be prepped and launched: +``` +# prep the access tags and catalog databases +python example_configs/access_tags/compile_tags.py +python example_configs/catalog/create_catalog.py +# launch the example server, which loads these databases +ALICE_PASSWORD=secret1 BOB_PASSWORD=secret2 CARA_PASSWORD=secret3 tiled serve config example_configs/toy_authentication.yml ``` diff --git a/docs/source/how-to/api-keys.md b/docs/source/how-to/api-keys.md index 0bd0421bb..b0d65c062 100644 --- a/docs/source/how-to/api-keys.md +++ b/docs/source/how-to/api-keys.md @@ -49,6 +49,14 @@ in the example below with that address. ALICE_PASSWORD=secret1 tiled serve config example_configs/toy_authentication.yml ``` +Note that you will need to run these helper tools to prep the backing databases that Tiled needs, +before you can use the example config shown above: +``` +# prep the access tags and catalog databases +python example_configs/access_tags/compile_tags.py +python example_configs/catalog/create_catalog.py +``` + Using the Tiled commandline interface, log in as `alice` using the password `secret1`. ``` diff --git a/docs/source/reference/authentication.md b/docs/source/reference/authentication.md index 801c51c8b..785ed8261 100644 --- a/docs/source/reference/authentication.md +++ b/docs/source/reference/authentication.md @@ -28,6 +28,14 @@ is included with the tiled source code, and start a server like so. :caption: example_configs/toy_authentication.py ``` +Note that you will need to run these helper tools to prep the backing databases that Tiled needs: +``` +# prep the access tags and catalog databases +python example_configs/access_tags/compile_tags.py +python example_configs/catalog/create_catalog.py +``` + +then, you can launch the server: ``` ALICE_PASSWORD=secret1 BOB_PASSWORD=secret2 CARA_PASSWORD=secret3 tiled serve config example_configs/toy_authentication.yml ``` diff --git a/example_configs/access_tags/compile_tags.py b/example_configs/access_tags/compile_tags.py new file mode 100644 index 000000000..c321751f9 --- /dev/null +++ b/example_configs/access_tags/compile_tags.py @@ -0,0 +1,30 @@ +from pathlib import Path + +from tiled.access_control.access_tags import AccessTagsCompiler +from tiled.access_control.scopes import ALL_SCOPES + + +def group_parser(groupname): + return { + "group_A": ["alice", "bob"], + "admins": ["cara"], + }[groupname] + + +def main(): + file_directory = Path(__file__).resolve().parent + + access_tags_compiler = AccessTagsCompiler( + ALL_SCOPES, + Path(file_directory, "tag_definitions.yml"), + {"uri": f"file:{file_directory}/compiled_tags.sqlite"}, + group_parser, + ) + + access_tags_compiler.load_tag_config() + access_tags_compiler.compile() + access_tags_compiler.connection.close() + + +if __name__ == "__main__": + main() diff --git a/example_configs/access_tags/tag_definitions.yml b/example_configs/access_tags/tag_definitions.yml new file mode 100644 index 000000000..7ed454d6a --- /dev/null +++ b/example_configs/access_tags/tag_definitions.yml @@ -0,0 +1,36 @@ +roles: + facility_user: + scopes: ["read:data", "read:metadata"] + facility_admin: + scopes: ["read:data", "read:metadata", "write:data", "write:metadata", "create", "register"] +tags: + data_A: + groups: + - name: group_A + role: facility_user + auto_tags: + - name: data_admin + data_B: + users: + - name: alice + scopes: ["read:data", "read:metadata"] + auto_tags: + - name: data_admin + data_C: + users: + - name: bob + role: facility_user + auto_tags: + - name: data_admin + data_D: + auto_tags: + - name: data_admin + - name: public + data_admin: + users: + - name: cara + role: facility_admin +tag_owners: + data_admin: + users: + - name: cara diff --git a/example_configs/catalog/create_catalog.py b/example_configs/catalog/create_catalog.py new file mode 100644 index 000000000..8e626b874 --- /dev/null +++ b/example_configs/catalog/create_catalog.py @@ -0,0 +1,34 @@ +from pathlib import Path + +import numpy +import yaml + +from tiled._tests.utils import enter_username_password +from tiled.client import Context, from_context +from tiled.server.app import build_app_from_config + +CONFIG_NAME = "toy_authentication.yml" +CATALOG_STORAGE = "data/" + + +def main(): + file_directory = Path(__file__).resolve().parent + config_directory = file_directory.parent + Path(file_directory, CATALOG_STORAGE).mkdir() + + with open(Path(config_directory, CONFIG_NAME)) as config_file: + config = yaml.load(config_file, Loader=yaml.BaseLoader) + app = build_app_from_config(config) + context = Context.from_app(app) + with enter_username_password("admin", "admin"): + client = from_context(context, remember_me=False) + for n in ["A", "B", "C", "D"]: + client.write_array( + key=n, array=10 * numpy.ones((10, 10)), access_tags=[f"data_{n}"] + ) + client.logout() + context.close() + + +if __name__ == "__main__": + main() diff --git a/example_configs/toy_authentication.yml b/example_configs/toy_authentication.yml index 6f96d50eb..8a34921fe 100644 --- a/example_configs/toy_authentication.yml +++ b/example_configs/toy_authentication.yml @@ -7,27 +7,29 @@ authentication: alice: ${ALICE_PASSWORD} bob: ${BOB_PASSWORD} cara: ${CARA_PASSWORD} + admin: "admin" confirmation_message: "You have logged in as {id}." tiled_admins: - provider: toy - id: alice + id: admin access_control: - access_policy: tiled.access_policies:SimpleAccessPolicy + access_policy: "tiled.access_control.access_policies:TagBasedAccessPolicy" args: - provider: toy # matches provider above - access_lists: - alice: - - A - - B - bob: - - A - - C - cara: tiled.access_policies:SimpleAccessPolicy.ALL + provider: "toy" scopes: - "read:metadata" - "read:data" - public: - - D + - "write:metadata" + - "write:data" + - "create" + tags_db: + uri: "file:example_configs/access_tags/compiled_tags.sqlite" + access_tags_parser: "tiled.access_control.access_tags:AccessTagsParser" trees: - path: / - tree: tiled.examples.toy_authentication:tree + tree: catalog + args: + uri: "sqlite+aiosqlite:///./example_configs/catalog/catalog.db" + writable_storage: "./example_configs/catalog/data" + init_if_not_exists: true + top_level_access_blob: {"tags": ["public"]} diff --git a/tiled/_tests/conftest.py b/tiled/_tests/conftest.py index f1705ab22..6d206f7f2 100644 --- a/tiled/_tests/conftest.py +++ b/tiled/_tests/conftest.py @@ -66,7 +66,7 @@ def buffer(): @pytest.fixture(scope="function") -def buffer_factory(request): +def buffer_factory(): buffers = [] def _buffer(): @@ -74,12 +74,10 @@ def _buffer(): buffers.append(buf) return buf - def teardown(): - for buf in buffers: - buf.close() + yield _buffer - request.addfinalizer(teardown) - return _buffer + for buf in buffers: + buf.close() @pytest.fixture diff --git a/tiled/_tests/test_access_control.py b/tiled/_tests/test_access_control.py index c5976aa84..691d9b6c5 100644 --- a/tiled/_tests/test_access_control.py +++ b/tiled/_tests/test_access_control.py @@ -1,26 +1,33 @@ import json +import sqlite3 +from copy import deepcopy import numpy import pytest from starlette.status import HTTP_403_FORBIDDEN -from tiled.authenticators import DictionaryAuthenticator -from tiled.server.protocols import UserSessionState - -from ..access_policies import NO_ACCESS -from ..adapters.array import ArrayAdapter -from ..adapters.mapping import MapAdapter +from ..access_control.access_tags import AccessTagsCompiler +from ..access_control.scopes import ALL_SCOPES from ..client import Context, from_context -from ..client.utils import ClientError -from ..scopes import ALL_SCOPES, NO_SCOPES, USER_SCOPES from ..server.app import build_app_from_config from .utils import enter_username_password, fail_with_status_code arr = numpy.ones((5, 5)) -arr_ad = ArrayAdapter.from_array(arr) -server_common_config = { + +server_config = { + "access_control": { + "access_policy": "tiled.access_control.access_policies:TagBasedAccessPolicy", + "args": { + "provider": "toy", + "tags_db": { + "uri": "file:compiled_tags_mem?mode=memory&cache=shared" # in-memory and shareable + }, + "access_tags_parser": "tiled.access_control.access_tags:AccessTagsParser", + }, + }, "authentication": { + "tiled_admins": [{"provider": "toy", "id": "admin"}], "allow_anonymous_access": True, "secret_keys": ["SECRET"], "providers": [ @@ -29,8 +36,11 @@ "authenticator": "tiled.authenticators:DictionaryAuthenticator", "args": { "users_to_passwords": { - "alice": "secret1", - "bob": "secret2", + "alice": "alice", + "bob": "bob", + "chris": "chris", + "sue": "sue", + "zoe": "zoe", "admin": "admin", }, }, @@ -38,683 +48,1002 @@ ], }, "database": { - "uri": "sqlite://", # in-memory + "uri": "sqlite:///file:authn_mem?mode=memory&cache=shared&uri=true", # in-memory }, } - -def tree_a(access_policy=None): - return MapAdapter({"A1": arr_ad, "A2": arr_ad}) +access_tag_config = { + "roles": { + "facility_user": { + "scopes": [ + "read:data", + "read:metadata", + ] + }, + "facility_admin": { + "scopes": [ + "read:data", + "read:metadata", + "write:data", + "write:metadata", + "create", + "register", + ] + }, + }, + "tags": { + "alice_tag": { + "users": [ + { + "name": "alice", + "role": "facility_admin", + }, + { + "name": "chris", + "scopes": ["read:data", "read:metadata"], + }, + ], + }, + "chris_tag": { + "users": [ + { + "name": "alice", + "role": "facility_admin", + }, + { + "name": "chris", + "role": "facility_admin", + }, + ], + }, + "biologists_tag": { + "users": [ + { + "name": "alice", + "role": "facility_admin", + }, + ], + "groups": [ + { + "name": "biologists", + "scopes": ["read:data", "read:metadata"], + }, + ], + }, + "chemists_tag": { + "users": [ + { + "name": "sue", + "scopes": ["write:data", "write:metadata"], + }, + ], + "groups": [ + { + "name": "chemists", + "role": "facility_user", + }, + ], + "auto_tags": [ + { + "name": "alice_tag", + }, + ], + }, + "physicists_tag": { + "users": [ + { + "name": "alice", + "role": "facility_admin", + }, + ], + "groups": [ + { + "name": "physicists", + "role": "facility_admin", + }, + ], + }, + }, + "tag_owners": { + "alice_tag": { + "users": [ + { + "name": "alice", + }, + { + "name": "chris", + }, + ], + }, + "biologists_tag": { + "users": [ + { + "name": "alice", + }, + ], + "groups": [ + { + "name": "biologists", + }, + ], + }, + "chemists_tag": { + "users": [ + { + "name": "sue", + }, + ], + "groups": [ + { + "name": "chemists", + }, + ], + }, + "physicists_tag": { + "users": [ + { + "name": "alice", + }, + ], + "groups": [ + { + "name": "physicists", + }, + ], + }, + }, +} -def tree_b(access_policy=None): - return MapAdapter({"B1": arr_ad, "B2": arr_ad}) +def group_parser(groupname): + return { + "chemists": ["bob", "mary"], + "biologists": ["chris", "fred"], + "physicists": ["sue", "tony"], + }[groupname] @pytest.fixture(scope="module") -def context_a(tmpdir_module): - config = { - "access_control": { - "access_policy": "tiled.access_policies:SimpleAccessPolicy", - "args": { - "provider": "toy", - "access_lists": { - "alice": ["a", "A2"], - # This should have no effect because bob - # cannot access the parent node. - "bob": ["A1", "A2"], - }, - "admins": ["admin"], - }, - }, - "trees": [ - { - "tree": f"{__name__}:tree_a", - "path": "/a", - }, - ], - } - - config.update(server_common_config) - app = build_app_from_config(config) - with Context.from_app(app) as context: - yield context +def compile_access_tags_db(): + access_tags_compiler = AccessTagsCompiler( + ALL_SCOPES, + access_tag_config, + {"uri": "file:compiled_tags_mem?mode=memory&cache=shared"}, + group_parser, + ) + access_tags_compiler.load_tag_config() + access_tags_compiler.compile() + yield access_tags_compiler + access_tags_compiler.connection.close() -@pytest.fixture(scope="module") -def context_b(tmpdir_module): - config = { - "access_control": { - "access_policy": "tiled.access_policies:SimpleAccessPolicy", - "args": { - "provider": "toy", - "access_lists": { - "alice": [], - "bob": [], - }, - "admins": ["admin"], - }, - }, - "trees": [ - { - "tree": f"{__name__}:tree_b", - "path": "/b", - }, - ], - } - config.update(server_common_config) - app = build_app_from_config(config) - with Context.from_app(app) as context: - yield context +@pytest.fixture +def compile_access_tags_db_with_reset(compile_access_tags_db): + access_tags_compiler = compile_access_tags_db + access_tag_config_copy = deepcopy(access_tag_config) + access_tags_compiler.tag_config = access_tag_config_copy + yield access_tags_compiler + access_tags_compiler.tag_config = access_tag_config + access_tags_compiler.group_parser = group_parser + access_tags_compiler.clear_raw_tags() + access_tags_compiler.load_tag_config() + access_tags_compiler.recompile() @pytest.fixture(scope="module") -def context_c(tmpdir_module): +def access_control_test_context_factory(tmpdir_module, compile_access_tags_db): config = { "trees": [ { "tree": "tiled.catalog:in_memory", - "args": {"writable_storage": str(tmpdir_module / "c")}, - "path": "/c", - }, - ], - "access_control": { - "access_policy": "tiled.access_policies:SimpleAccessPolicy", - "args": { - "provider": "toy", - "access_lists": { - "alice": "tiled.access_policies:ALL_ACCESS", + "args": { + "named_memory": "catalog_foo", + "writable_storage": str(tmpdir_module / "foo"), + "top_level_access_blob": {"tags": ["alice_tag"]}, }, - "admins": ["admin"], + "path": "/foo", }, - }, - } - - config.update(server_common_config) - app = build_app_from_config(config) - with Context.from_app(app) as context: - admin_client = from_context(context) - with enter_username_password("admin", "admin"): - admin_client.login() - for k in ["c"]: - admin_client[k].write_array(arr, key="A1") - admin_client[k].write_array(arr, key="A2") - admin_client[k].write_array(arr, key="x") - yield context - - -@pytest.fixture(scope="module") -def context_d(tmpdir_module): - config = { - "trees": [ { "tree": "tiled.catalog:in_memory", - "args": {"writable_storage": str(tmpdir_module / "d")}, - "path": "/d", - }, - ], - "access_control": { - "access_policy": "tiled.access_policies:SimpleAccessPolicy", - "args": { - "provider": "toy", - "access_lists": { - "alice": "tiled.access_policies:ALL_ACCESS", + "args": { + "named_memory": "catalog_bar", + "writable_storage": str(tmpdir_module / "bar"), + "top_level_access_blob": {"tags": ["chemists_tag"]}, }, - "admins": ["admin"], - # Block writing. - "scopes": ["read:metadata", "read:data"], + "path": "/bar", }, - }, - } - - config.update(server_common_config) - app = build_app_from_config(config) - with Context.from_app(app) as context: - admin_client = from_context(context) - with enter_username_password("admin", "admin"): - admin_client.login() - for k in ["d"]: - admin_client[k].write_array(arr, key="A1") - admin_client[k].write_array(arr, key="A2") - admin_client[k].write_array(arr, key="x") - yield context - - -@pytest.fixture(scope="module") -def context_e(tmpdir_module): - config = { - "trees": [ { "tree": "tiled.catalog:in_memory", - "args": {"writable_storage": str(tmpdir_module / "e")}, - "path": "/e", + "args": { + "named_memory": "catalog_baz", + "writable_storage": str(tmpdir_module / "baz"), + "top_level_access_blob": {"tags": ["physicists_tag"]}, + }, + "path": "/baz", }, - ], - "access_control": { - "access_policy": "tiled.access_policies:SimpleAccessPolicy", - "args": { - "provider": "toy", - "access_lists": { - "alice": "tiled.access_policies:ALL_ACCESS", + { + "tree": "tiled.catalog:in_memory", + "args": { + "named_memory": "catalog_qux", + "writable_storage": str(tmpdir_module / "qux"), + "top_level_access_blob": {"tags": ["public"]}, }, - "admins": ["admin"], - # Block creation. - "scopes": [ - "read:metadata", - "read:data", - "write:metadata", - "write:data", - ], + "path": "/qux", }, - }, + ], } - config.update(server_common_config) - app = build_app_from_config(config) - with Context.from_app(app) as context: - admin_client = from_context(context) - with enter_username_password("admin", "admin"): - admin_client.login() - for k in ["e"]: - admin_client[k].write_array(arr, key="A1") - admin_client[k].write_array(arr, key="A2") - admin_client[k].write_array(arr, key="x") - yield context + config.update(server_config) + contexts = [] + clients = {} + def _create_and_login_context(username, password=None, api_key=None): + if not any([password, api_key]): + raise ValueError("Please provide either 'password' or 'api_key' for auth") -@pytest.fixture(scope="module") -def context_f(tmpdir_module): - config = { - "access_control": { - "access_policy": "tiled.access_policies:SimpleAccessPolicy", - "args": { - "provider": "toy", - "access_lists": {}, - "admins": ["admin"], - "public": ["f"], - }, - }, - "trees": [ - { - "tree": ArrayAdapter.from_array(arr), - "path": "/f", - }, - ], - } + if client := clients.get(username, None): + return client + app = build_app_from_config(config) + context = Context.from_app( + app, uri=f"http://local-tiled-app-{username}/api/v1", api_key=api_key + ) + contexts.append(context) + client = from_context(context, remember_me=False) + clients[username] = client + if api_key is None: + with enter_username_password(username, password): + client.context.login(remember_me=False) + return client - config.update(server_common_config) - app = build_app_from_config(config) - with Context.from_app(app) as context: - yield context + admin_client = _create_and_login_context("admin", "admin") + for k in ["foo", "bar", "baz", "qux"]: + admin_client[k].write_array(arr, key="data_A", access_tags=["alice_tag"]) + admin_client[k].write_array(arr, key="data_B", access_tags=["chemists_tag"]) + admin_client[k].write_array(arr, key="data_C", access_tags=["public"]) + yield _create_and_login_context -@pytest.fixture(scope="module") -def context_g(tmpdir_module): - config = { - "trees": [ - { - "tree": "tiled.catalog:in_memory", - "args": { - "writable_storage": str(tmpdir_module / "g"), - "metadata": {"project": "all_projects"}, - }, - "path": "/g", - }, - ], - "access_control": { - "access_policy": "tiled.access_policies:SimpleAccessPolicy", - "args": { - "provider": "toy", - "key": "project", - "access_lists": { - "alice": ["all_projects", "projectA"], - "bob": ["projectB"], - }, - "admins": ["admin"], - "public": ["projectC", "all_projects"], - }, - }, - } + for context in contexts: + context.close() - config.update(server_common_config) - app = build_app_from_config(config) - with Context.from_app(app) as context: - admin_client = from_context(context) - with enter_username_password("admin", "admin"): - admin_client.login() - for k, v in {"A3": "projectA", "A4": "projectB", "r": "projectC"}.items(): - admin_client["g"].write_array(arr, key=k, metadata={"project": v}) - yield context - - -def test_basic_access_control(context_a, context_b, context_g, enter_username_password): - alice_client_a = from_context(context_a) - alice_client_b = from_context(context_b) - alice_client_g = from_context(context_g) - with enter_username_password("alice", "secret1"): - alice_client_a.login() - alice_client_b.login() - alice_client_g.login() - assert "a" in alice_client_a - assert "A2" in alice_client_a["a"] - assert "A1" not in alice_client_a["a"] - assert "b" not in alice_client_b - assert "g" in alice_client_g - assert "A3" in alice_client_g["g"] - assert "A4" not in alice_client_g["g"] - alice_client_a["a"]["A2"] - alice_client_g["g"]["A3"] - with pytest.raises(KeyError): - alice_client_b["b"] - with pytest.raises(KeyError): - alice_client_g["g"]["A4"] - alice_client_a.logout() - alice_client_b.logout() - alice_client_g.logout() - - bob_client_a = from_context(context_a) - bob_client_b = from_context(context_b) - bob_client_g = from_context(context_g) - with enter_username_password("bob", "secret2"): - bob_client_a.login() - bob_client_b.login() - bob_client_g.login() - assert not list(bob_client_a) - assert not list(bob_client_b) - assert not list(bob_client_g) - with pytest.raises(KeyError): - bob_client_a["a"] - with pytest.raises(KeyError): - bob_client_b["b"] - with pytest.raises(KeyError): - bob_client_g["g"]["A3"] - bob_client_a.logout() - bob_client_b.logout() - bob_client_g.logout() - - -def test_access_control_with_api_key_auth( - context_a, context_g, enter_username_password -): - # Log in, create an API key, log out. - with enter_username_password("alice", "secret1"): - context_a.authenticate() - context_g.authenticate() - key_info_a = context_a.create_api_key() - context_a.logout() - key_info_g = context_g.create_api_key() - context_g.logout() - - try: - # Use API key auth while exercising the access control code. - context_a.api_key = key_info_a["secret"] - client_a = from_context(context_a) - context_g.api_key = key_info_g["secret"] - client_g = from_context(context_g) - client_a["a"]["A2"] - client_g["g"]["A3"] - finally: - # Clean up Context, which is a module-scope fixture shared with other tests. - context_a.api_key = None - context_g.api_key = None - - -def test_node_export( - enter_username_password, context_a, context_b, context_g, buffer_factory -): - "Exporting a node should include only the children we can see." - alice_client_a = from_context(context_a) - alice_client_b = from_context(context_b) - alice_client_g = from_context(context_g) - with enter_username_password("alice", "secret1"): - alice_client_a.login() - alice_client_b.login() - alice_client_g.login() - buffer_a = buffer_factory() - buffer_b = buffer_factory() - buffer_g = buffer_factory() - alice_client_a.export(buffer_a, format="application/json") - alice_client_b.export(buffer_b, format="application/json") - alice_client_g.export(buffer_g, format="application/json") - alice_client_a.logout() - alice_client_b.logout() - alice_client_g.logout() - buffer_a.seek(0) - buffer_b.seek(0) - buffer_g.seek(0) - exported_dict_a = json.loads(buffer_a.read()) - exported_dict_b = json.loads(buffer_b.read()) - exported_dict_g = json.loads(buffer_g.read()) - assert "a" in exported_dict_a["contents"] - assert "A2" in exported_dict_a["contents"]["a"]["contents"] - assert "A1" not in exported_dict_a["contents"]["a"]["contents"] - assert "b" not in exported_dict_b - assert "g" in exported_dict_g["contents"] - assert "A3" in exported_dict_g["contents"]["g"]["contents"] - assert "A4" not in exported_dict_g["contents"]["g"]["contents"] - exported_dict_a["contents"]["a"]["contents"]["A2"] - exported_dict_g["contents"]["g"]["contents"]["A3"] - - -def test_create_and_update_allowed(enter_username_password, context_c, context_g): - alice_client_c = from_context(context_c) - alice_client_g = from_context(context_g) - with enter_username_password("alice", "secret1"): - alice_client_c.login() - alice_client_g.login() - - # Update - alice_client_c["c"]["x"].metadata - alice_client_c["c"]["x"].update_metadata(metadata={"added_key": 3}) - assert alice_client_c["c"]["x"].metadata["added_key"] == 3 - - alice_client_g["g"]["A3"].metadata - alice_client_g["g"]["A3"].update_metadata(metadata={"added_key": 9}) - assert alice_client_g["g"]["A3"].metadata["added_key"] == 9 - - # Create - alice_client_c["c"].write_array([1, 2, 3]) - alice_client_g["g"].write_array([4, 5, 6], metadata={"project": "projectA"}) - alice_client_c.logout() - alice_client_g.logout() - - -def test_writing_blocked_by_access_policy(enter_username_password, context_d): - alice_client_d = from_context(context_d) - with enter_username_password("alice", "secret1"): - alice_client_d.login() - alice_client_d["d"]["x"].metadata + +def test_access_tag_compiler(compile_access_tags_db_with_reset): + """ + Test that compilation of access tags is working. This tests: + - Adding and removing a tag + - Adding and removing a role + - Adding and removing a user from a tag + - Adding and removing a group from a tag + - Adding and removing from the `auto_tags` for a tag + - Adding and removing a tag from the `tag_owners` section + - Adding and removing users and groups from the owners of a tag + - Adding and removing a member from a group + - Changing a user's role/scopes on a tag + - Changing a group's role/scopes on a tag + - Making a tag public/not-public + - Disallow redefining the `public` tag + """ + access_tags_compiler = compile_access_tags_db_with_reset + compiler_tag_config = access_tags_compiler.tag_config + + def new_group_parser(groupname): + return { + "chemists": ["bob", "mary", "kate"], + "biologists": ["chris", "fred"], + "physicists": ["sue", "tony"], + }[groupname] + + access_tags_compiler.group_parser = new_group_parser + + compiler_tag_config["tags"].update( + {"new_tag": {"users": [{"name": "tony", "scopes": ["read:metadata"]}]}} + ) + compiler_tag_config["roles"].update({"new_role": {"scopes": ["read:metadata"]}}) + compiler_tag_config["tags"]["biologists_tag"]["users"].append( + {"name": "tony", "role": "facility_user"} + ) + compiler_tag_config["tags"]["physicists_tag"]["groups"].append( + {"name": "biologists", "role": "facility_user"} + ) + compiler_tag_config["tags"]["chemists_tag"]["auto_tags"].append({"name": "new_tag"}) + compiler_tag_config["tag_owners"].update({"new_tag": {"users": [{"name": "tony"}]}}) + compiler_tag_config["tag_owners"]["biologists_tag"]["users"].append( + {"name": "tony"} + ) + compiler_tag_config["tag_owners"]["chemists_tag"]["groups"].append( + {"name": "biologists"} + ) + compiler_tag_config["tags"]["alice_tag"]["users"][0]["role"] = "facility_user" + compiler_tag_config["tags"]["biologists_tag"]["groups"][0].pop("scopes") + compiler_tag_config["tags"]["biologists_tag"]["groups"][0].update( + {"role": "facility_admin"} + ) + compiler_tag_config["tags"]["alice_tag"].update({"auto_tags": [{"name": "public"}]}) + + access_tags_compiler.load_tag_config() + access_tags_compiler.recompile() + + db = sqlite3.connect("file:compiled_tags_mem?mode=memory&cache=shared", uri=True) + cursor = db.cursor() + # check that new tag was added and compiled with user+scopes + cursor.execute( + "SELECT EXISTS(SELECT 1 FROM tags WHERE name = ?);", + ("new_tag",), + ) + assert bool(cursor.fetchone()[0]) + cursor.execute( + "SELECT EXISTS(SELECT 1 FROM user_tag_scopes WHERE tag_name = ? AND user_name = ?);", + ("new_tag", "tony"), + ) + assert bool(cursor.fetchone()[0]) + + # check that new role was added - note roles do not get saved in the db + assert "new_role" in access_tags_compiler.roles + + # check that newly added user and group were given scopes on tag + cursor.execute( + "SELECT EXISTS(SELECT 1 FROM user_tag_scopes WHERE tag_name = ? AND user_name = ?);", + ("biologists_tag", "tony"), + ) + assert bool(cursor.fetchone()[0]) + cursor.execute( + "SELECT EXISTS(SELECT 1 FROM user_tag_scopes WHERE tag_name = ? AND user_name = ?);", + ("physicists_tag", "chris"), + ) + assert bool(cursor.fetchone()[0]) + + # check that auto_tag added ACL to parent tag + cursor.execute( + "SELECT EXISTS(SELECT 1 FROM user_tag_scopes WHERE tag_name = ? AND user_name = ?);", + ("chemists_tag", "tony"), + ) + assert bool(cursor.fetchone()[0]) + + # check tag was added to tag_owners section + cursor.execute( + "SELECT EXISTS(SELECT 1 FROM user_tag_owners WHERE tag_name = ?);", + ("new_tag",), + ) + assert bool(cursor.fetchone()[0]) + + # check adding new user and group to owners of tags + cursor.execute( + "SELECT EXISTS(SELECT 1 FROM user_tag_owners WHERE tag_name = ? AND user_name = ?);", + ("biologists_tag", "tony"), + ) + assert bool(cursor.fetchone()[0]) + cursor.execute( + "SELECT EXISTS(SELECT 1 FROM user_tag_owners WHERE tag_name = ? AND user_name = ?);", + ("chemists_tag", "chris"), + ) + assert bool(cursor.fetchone()[0]) + + # check that the role/scopes changes for a user and group on a tag were effective + cursor.execute( + "SELECT NOT EXISTS(SELECT 1 FROM user_tag_scopes " + "WHERE tag_name = ? AND user_name = ? AND scope_name = ?);", + ("alice_tag", "alice", "write:metadata"), + ) + assert bool(cursor.fetchone()[0]) + cursor.execute( + "SELECT EXISTS(SELECT 1 FROM user_tag_scopes " + "WHERE tag_name = ? AND user_name = ? AND scope_name = ?);", + ("biologists_tag", "chris", "write:metadata"), + ) + assert bool(cursor.fetchone()[0]) + + # check tha tag was marked as public after inheriting public tag + cursor.execute( + "SELECT EXISTS(SELECT 1 FROM public_tags WHERE name = ?);", + ("alice_tag",), + ) + assert bool(cursor.fetchone()[0]) + + # check that user added to group was compiled into tag ACL + cursor.execute( + "SELECT EXISTS(SELECT 1 FROM user_tag_scopes WHERE tag_name = ? AND user_name = ?);", + ("chemists_tag", "kate"), + ) + assert bool(cursor.fetchone()[0]) + + # attempt redefining the public tag (and fail) + compiler_tag_config["tags"].update( + {"public": {"users": [{"name": "tony", "scopes": ["read:metadata"]}]}} + ) + access_tags_compiler.load_tag_config() + with pytest.raises(ValueError): + access_tags_compiler.recompile() + + # remove all changes/additions by reverting to the original config + access_tags_compiler.tag_config = access_tag_config + access_tags_compiler.group_parser = group_parser + access_tags_compiler.clear_raw_tags() + access_tags_compiler.load_tag_config() + access_tags_compiler.recompile() + + # check that new tag was removed and no longer compiled + cursor.execute( + "SELECT NOT EXISTS(SELECT 1 FROM tags WHERE name = ?);", + ("new_tag",), + ) + assert bool(cursor.fetchone()[0]) + cursor.execute( + "SELECT NOT EXISTS(SELECT 1 FROM user_tag_scopes WHERE tag_name = ? AND user_name = ?);", + ("new_tag", "tony"), + ) + assert bool(cursor.fetchone()[0]) + + # check that new role was removed - note roles do not get saved in the db + assert "new_role" not in access_tags_compiler.roles + + # check that removed user and group were not given scopes on tag + cursor.execute( + "SELECT NOT EXISTS(SELECT 1 FROM user_tag_scopes WHERE tag_name = ? AND user_name = ?);", + ("biologists_tag", "tony"), + ) + assert bool(cursor.fetchone()[0]) + cursor.execute( + "SELECT NOT EXISTS(SELECT 1 FROM user_tag_scopes WHERE tag_name = ? AND user_name = ?);", + ("physicists_tag", "chris"), + ) + assert bool(cursor.fetchone()[0]) + + # check that auto_tag ACL removed from parent tag + cursor.execute( + "SELECT NOT EXISTS(SELECT 1 FROM user_tag_scopes WHERE tag_name = ? AND user_name = ?);", + ("chemists_tag", "tony"), + ) + assert bool(cursor.fetchone()[0]) + + # check tag was removed from tag_owners section + cursor.execute( + "SELECT NOT EXISTS(SELECT 1 FROM user_tag_owners WHERE tag_name = ?);", + ("new_tag",), + ) + assert bool(cursor.fetchone()[0]) + + # check removing user and group from owners of tags + cursor.execute( + "SELECT NOT EXISTS(SELECT 1 FROM user_tag_owners WHERE tag_name = ? AND user_name = ?);", + ("biologists_tag", "tony"), + ) + assert bool(cursor.fetchone()[0]) + cursor.execute( + "SELECT NOT EXISTS(SELECT 1 FROM user_tag_owners WHERE tag_name = ? AND user_name = ?);", + ("chemists_tag", "chris"), + ) + assert bool(cursor.fetchone()[0]) + + # check that the role/scopes changes for a user and group on a tag were undone + cursor.execute( + "SELECT EXISTS(SELECT 1 FROM user_tag_scopes " + "WHERE tag_name = ? AND user_name = ? AND scope_name = ?);", + ("alice_tag", "alice", "write:metadata"), + ) + assert bool(cursor.fetchone()[0]) + cursor.execute( + "SELECT NOT EXISTS(SELECT 1 FROM user_tag_scopes " + "WHERE tag_name = ? AND user_name = ? AND scope_name = ?);", + ("biologists_tag", "chris", "write:metadata"), + ) + assert bool(cursor.fetchone()[0]) + + # check tha tag was unmarked as public after removing the public auto_tag + cursor.execute( + "SELECT NOT EXISTS(SELECT 1 FROM public_tags WHERE name = ?);", + ("alice_tag",), + ) + assert bool(cursor.fetchone()[0]) + + # check that user removed from group was compiled out of tag ACL + cursor.execute( + "SELECT NOT EXISTS(SELECT 1 FROM user_tag_scopes WHERE tag_name = ? AND user_name = ?);", + ("chemists_tag", "kate"), + ) + assert bool(cursor.fetchone()[0]) + + +def test_basic_access_control(access_control_test_context_factory): + """ + Test that basic access control and tag compilation are working. + Only tests simple visibility of the data (i.e. "read:metadata" scope), + does not tests writing or full reading of the data. + + In other words, tests that compiled tags allow/disallow access including: + - top-level tags + - tags directly on datasets + - tags "inherited" on datasets (auto_tags) + - "public" tags on datasets + - groups compiled into tags + - scopes compiled into tags by a role + - scopes compiled into tags by a scopes list + - nested access blocked by upper tags (even if deeper tags would permit access) + + Note: MapAdapter does not support access control. As such, the server root + does not currently filter top-level entries. + """ + alice_client = access_control_test_context_factory("alice", "alice") + bob_client = access_control_test_context_factory("bob", "bob") + + top = "foo" + assert top in alice_client + # no access control on MapAdapter - can't filter top-level yet + # assert top not in bob_client + for data in ["data_A", "data_B", "data_C"]: + # Alice has access below the top-level, given by a direct tag + # Bob does not have access to any data, blocked by the top-level's tag + # data_A - alice has access given by a direct tag of which they are a user + # data_B - alice has access given by an inherited tag + # data_C - alice has access given by a public tag + assert data in alice_client[top] + alice_client[top][data] + with pytest.raises(KeyError): + bob_client[top][data] + + top = "bar" + assert top in alice_client + assert top in bob_client + for data in ["data_A"]: + # Alice has access below the top-level, given by an inherited tag + # data_A - bob does not have access conferred by any tags + assert data in alice_client[top] + alice_client[top][data] + assert data not in bob_client[top] + with pytest.raises(KeyError): + bob_client[top][data] + for data in ["data_B", "data_C"]: + # Bob has access below the top-level, given by a direct tag of which they are in a group + # data_B - alice has scopes compiled in via role + # data_B - bob has access given by a direct tag of which they are in a group + # data_B - bob has scopes compiled in via list of scopes + # data_C - alice and bob are given access by a public tag + assert data in alice_client[top] + alice_client[top][data] + assert data in bob_client[top] + bob_client[top][data] + + +def test_writing_access_control(access_control_test_context_factory): + """ + Test that writing access control and tag ownership is working. + Only tests that the writing request does not fail. + Does not test the written data for validity. + + This tests the following: + - Writing without applying an access tag + - Writing while applying an access tag the user owns + - Writing while applying an access tag the user does not own + - Writing while applying an access tag that is not defined + - Writing while applying the "public" tag (admin only) + - Writing into a location where the user does not have write access + - Writing while applying an access tag the user owns through group membership + - Writing while applying multiple access tags + - Writing while applying a tag which does not give the user the minimum scopes + """ + + alice_client = access_control_test_context_factory("alice", "alice") + bob_client = access_control_test_context_factory("bob", "bob") + sue_client = access_control_test_context_factory("sue", "sue") + + top = "foo" + alice_client[top].write_array(arr, key="data_Q") + alice_client[top].write_array(arr, key="data_R", access_tags=["alice_tag"]) + with fail_with_status_code(HTTP_403_FORBIDDEN): + alice_client[top].write_array(arr, key="data_S", access_tags=["chemists_tag"]) with fail_with_status_code(HTTP_403_FORBIDDEN): - alice_client_d["d"]["x"].update_metadata(metadata={"added_key": 3}) - alice_client_d.logout() + alice_client[top].write_array(arr, key="data_T", access_tags=["undefined_tag"]) + with fail_with_status_code(HTTP_403_FORBIDDEN): + alice_client[top].write_array(arr, key="data_U", access_tags=["public"]) + top = "bar" + with fail_with_status_code(HTTP_403_FORBIDDEN): + bob_client[top].write_array(arr, key="data_V") -def test_create_blocked_by_access_policy(enter_username_password, context_e): - alice_client_e = from_context(context_e) - with enter_username_password("alice", "secret1"): - alice_client_e.login() + top = "baz" + sue_client[top].write_array( + arr, key="data_W", access_tags=["physicists_tag", "chemists_tag"] + ) + access_tags = sue_client[top]["data_W"].access_blob["tags"] + assert "physicists_tag" in access_tags + assert "chemists_tag" in access_tags with fail_with_status_code(HTTP_403_FORBIDDEN): - alice_client_e["e"].write_array([1, 2, 3]) - alice_client_e.logout() + sue_client[top].write_array(arr, key="data_X", access_tags=["chemists_tag"]) -def test_public_access( - context_a, context_b, context_c, context_d, context_e, context_f, context_g -): - public_client_a = from_context(context_a) - public_client_b = from_context(context_b) - public_client_c = from_context(context_c) - public_client_d = from_context(context_d) - public_client_e = from_context(context_e) - public_client_f = from_context(context_f) - public_client_g = from_context(context_g) - for key, client in zip( - ["a", "b", "c", "d", "e"], - [ - public_client_a, - public_client_b, - public_client_c, - public_client_d, - public_client_e, - ], - ): - assert key not in client - public_client_f["f"].read() - public_client_g["g"]["r"].read() - with pytest.raises(KeyError): - public_client_a["a", "A1"] - with pytest.raises(KeyError): - public_client_g["g", "A3"] - - -def test_service_principal_access(tmpdir, sqlite_or_postgres_uri): - "Test that a service principal can work with SimpleAccessPolicy." - config = { - "authentication": { - "secret_keys": ["SECRET"], - "providers": [ - { - "provider": "toy", - "authenticator": "tiled.authenticators:DictionaryAuthenticator", - "args": { - "users_to_passwords": { - "admin": "admin", - } - }, - } - ], - "tiled_admins": [{"id": "admin", "provider": "toy"}], - }, - "database": { - "uri": f"sqlite:///{tmpdir}/auth.db", - "init_if_not_exists": True, - }, - "trees": [ - { - "tree": "catalog", - "args": { - "uri": sqlite_or_postgres_uri, - "writable_storage": f"file://localhost{tmpdir}/data", - "init_if_not_exists": True, - }, - "path": "/", - } - ], - "access_control": { - "access_policy": "tiled.access_policies:SimpleAccessPolicy", - "args": { - "access_lists": {}, - "provider": "toy", - "admins": ["admin"], - }, - }, - } - with Context.from_app(build_app_from_config(config)) as context: - with enter_username_password("admin", "admin"): - # Prompts for login here because anonymous access is not allowed - admin_client = from_context(context) - sp = admin_client.context.admin.create_service_principal("user") - key_info = admin_client.context.admin.create_api_key(sp["uuid"]) - admin_client.write_array([1, 2, 3], key="x") - admin_client.write_array([4, 5, 6], key="y") - admin_client.logout() - - # Drop the admin, no longer needed. - config["authentication"].pop("tiled_admins") - # Add the service principal to the access_lists. - config["access_control"]["args"]["access_lists"][sp["uuid"]] = ["x"] - with Context.from_app( - build_app_from_config(config), api_key=key_info["secret"] - ) as context: - sp_client = from_context(context) - assert list(sp_client) == ["x"] - - -class CustomAttributesAuthenticator(DictionaryAuthenticator): - """An example authenticator that enriches the stored user information.""" - - def __init__(self, users: dict, confirmation_message: str = ""): - self._users = users - super().__init__( - {username: user["password"] for username, user in users.items()}, - confirmation_message, +def test_user_owned_node_access_control(access_control_test_context_factory): + """ + Test that user-owned nodes (i.e. nodes created without access tags applied) + are visible after creation and can be modified by the user. + Also test that the data is visible after a tag is applied, and + that other users cannot see user-owned nodes. + """ + + alice_client = access_control_test_context_factory("alice", "alice") + bob_client = access_control_test_context_factory("bob", "bob") + + top = "foo" + for data in ["data_Y"]: + # Create a new user-owned node + alice_client[top].write_array(arr, key=data) + assert data in alice_client[top] + alice_client[top][data] + access_blob = alice_client[top][data].access_blob + assert "user" in access_blob + assert "alice" in access_blob["user"] + # Convert from user-owned node to a tagged node + alice_client[top][data].replace_metadata(access_tags=["alice_tag"]) + access_blob = alice_client[top][data].access_blob + assert "user" not in access_blob + assert "tags" in access_blob + assert "alice_tag" in access_blob["tags"] + assert data in alice_client[top] + alice_client[top][data] + + top = "bar" + for data in ["data_Z"]: + # Create a user-owned node and check that it is access restricted + alice_client[top].write_array(arr, key=data) + assert data not in bob_client[top] + with pytest.raises(KeyError): + bob_client[top][data] + + +def test_public_anonymous_access_control(access_control_test_context_factory): + """ + Test that data which is tagged public is visible to unauthenticated + (anonymous) users when the server allows anonymous access. + """ + zoe_client = access_control_test_context_factory("zoe", "zoe") + zoe_client.logout() + anon_client = zoe_client + + top = "qux" + assert top in anon_client + for data in ["data_A", "data_B"]: + assert data not in anon_client[top] + with pytest.raises(KeyError): + anon_client[top][data] + for data in ["data_C"]: + assert data in anon_client[top] + anon_client[top][data] + + +def test_admin_access_control(access_control_test_context_factory): + """ + Test that admin accounts have various elevated privileges, including: + - Apply/remove public tag to/from a node + - Apply/remove tags while ignoring minimum scopes + - Apply/remove tags that the user does not own + - View all data regardless of tags + - Apply an access tag that is not defined (disallowed) + - Remove all tags from a node, but still view that node + - Also includes test of an empty tags list blocking access for regular users + """ + admin_client = access_control_test_context_factory("admin", "admin") + alice_client = access_control_test_context_factory("alice", "alice") + + top = "foo" + for data in ["data_L"]: + # create a node and tag it public + admin_client[top].write_array(arr, key=data, access_tags=["public"]) + assert data in alice_client[top] + alice_client[top][data] + # remove public access, in fact remove all tags and ignore missing scopes + admin_client[top][data].replace_metadata(access_tags=[]) + assert data in admin_client[top] + admin_client[top][data] + assert data not in alice_client[top] + with pytest.raises(KeyError): + alice_client[top][data] + # apply a tag that the admin user does not own and ignore missing scopes + admin_client[top][data].replace_metadata(access_tags=["chemists_tag"]) + assert data in admin_client[top] + admin_client[top][data] + assert data in alice_client[top] + alice_client[top][data] + # remove a tag that the admin user does not own + admin_client[top][data].replace_metadata(access_tags=["chemists_tag"]) + # apply a tag which is not defined + with fail_with_status_code(HTTP_403_FORBIDDEN): + admin_client[top][data].replace_metadata(access_tags=["undefined_tag"]) + + +def test_update_node_access_control(access_control_test_context_factory): + """ + Test that access control on metadata changes is working. + + This tests the following: + - Update metadata while having write access + - Prevent updating metadata without having write access + - Successfully add an access tag and remove an access tag + - Prevent adding or removing an access tag without having write access + - Prevent adding or removing access tags which the user does not own + - Add and remove access tags which do not confer the necessary scopes + - Attempt to add an undefined access tag (not allowed) + - Attempt to add the "public" tag (admin only) + - Attempt to remove the "public" tag (admin only) + - Attempt to remove an undefined access tag (not allowed) + """ + admin_client = access_control_test_context_factory("admin", "admin") + alice_client = access_control_test_context_factory("alice", "alice") + chris_client = access_control_test_context_factory("chris", "chris") + sue_client = access_control_test_context_factory("sue", "sue") + + top = "qux" + for data in ["data_F"]: + admin_client[top].write_array(arr, key=data, access_tags=["alice_tag"]) + # successfully update metadata, user has write access + alice_client[top][data].replace_metadata(metadata={"materials": ["Cu", "Ag"]}) + assert "Ag" in alice_client[top][data].metadata["materials"] + # fail to update metadata, user does not have write access + with fail_with_status_code(HTTP_403_FORBIDDEN): + chris_client[top][data].replace_metadata( + metadata={"materials": ["Ag", "Au"]} + ) + assert "Au" not in chris_client[top][data].metadata["materials"] + + # succeeds to add a new access tag and remove the old access tag + alice_client[top][data].replace_metadata(access_tags=["biologists_tag"]) + access_tags = alice_client[top][data].access_blob["tags"] + assert "alice_tag" not in access_tags + assert "biologists_tag" in access_tags + + # fails to add a new access tag, user does not have write access + with fail_with_status_code(HTTP_403_FORBIDDEN): + chris_client[top][data].replace_metadata( + access_tags=["biologists_tag", "chris_tag"] + ) + admin_client[top][data].replace_metadata( + access_tags=["alice_tag", "biologists_tag"] + ) + # fails to remove an access tag, user does not have write access + with fail_with_status_code(HTTP_403_FORBIDDEN): + chris_client[top][data].replace_metadata(access_tags=["biologists_tag"]) + admin_client[top][data].replace_metadata(access_tags=["biologists_tag"]) + + # fails to add a new access tag, user does not own the tag + with fail_with_status_code(HTTP_403_FORBIDDEN): + alice_client[top][data].replace_metadata( + access_tags=["biologists_tag", "chris_tag"] + ) + admin_client[top][data].replace_metadata( + access_tags=["biologists_tag", "chris_tag"] + ) + # fails to remove an access tag, user does not own the tag + with fail_with_status_code(HTTP_403_FORBIDDEN): + chris_client[top][data].replace_metadata(access_tags=["biologists_tag"]) + admin_client[top][data].replace_metadata(access_tags=["biologists_tag"]) + + # fail to add an undefined tag + with fail_with_status_code(HTTP_403_FORBIDDEN): + alice_client[top][data].replace_metadata( + access_tags=["undefined_tag", "biologists_tag"] + ) + # fail to add the "public" tag + with fail_with_status_code(HTTP_403_FORBIDDEN): + alice_client[top][data].replace_metadata( + access_tags=["public", "biologists_tag"] + ) + admin_client[top][data].replace_metadata( + access_tags=["public", "biologists_tag"] + ) + # fail to remove the "public" tag + with fail_with_status_code(HTTP_403_FORBIDDEN): + alice_client[top][data].replace_metadata(access_tags=["biologists_tag"]) + admin_client[top][data].replace_metadata(access_tags=["biologists_tag"]) + + # surgically add an undefined tag to the node, then fail when trying to remove it + db = sqlite3.connect(f"file:catalog_{top}?mode=memory&cache=shared", uri=True) + cursor = db.cursor() + cursor.execute( + "UPDATE nodes SET access_blob = ? WHERE key = ?", + ('{"tags": ["undefined_tag", "biologists_tag"]}', data), + ) + db.commit() + with fail_with_status_code(HTTP_403_FORBIDDEN): + alice_client[top][data].replace_metadata(access_tags=["biologists_tag"]) + + top = "baz" + for data in ["data_G"]: + sue_client[top].write_array(arr, key=data) + # fail to apply a new access tag as it does not give the user the + # minimum required scopes + # this case only affects user-owned nodes: + # - if we did not have read access, we would not even see the node + # - if we did not have write access, we would be blocked by scopes + # - if an existing tag already gave us read and write, adding a tag would succeed + # - if an existing tag already gave us read and write, and we tried to remove it + # while adding the new tag, it's really the removal operation that prevents this, + # and this operation is tested below + # this leaves only user-owned nodes (full access for the user, but no existing tags) + with fail_with_status_code(HTTP_403_FORBIDDEN): + sue_client[top][data].replace_metadata(access_tags=["chemists_tag"]) + sue_client[top][data].replace_metadata( + access_tags=["physicists_tag", "chemists_tag"] ) + # fail to apply the new access tag as removing the old access tag results + # in insufficent scopes for the user + with fail_with_status_code(HTTP_403_FORBIDDEN): + sue_client[top][data].replace_metadata(access_tags=["chemists_tag"]) - async def authenticate(self, username, password): - state = await super().authenticate(username, password) - if isinstance(state, UserSessionState): - # enrich the auth state - state.state["attributes"] = self._users[username].get("attributes", {}) - return state + +def test_empty_access_blob_access_control(access_control_test_context_factory): + """ + Test the cases where a node in the catalog has an empty access blob. + This case occurs when migrating an older catalog without also + populating the access_blob column. + """ + admin_client = access_control_test_context_factory("admin", "admin") + alice_client = access_control_test_context_factory("alice", "alice") + + top = "qux" + for data in ["data_M"]: + admin_client[top].write_array(arr, key=data, access_tags=["alice_tag"]) + db = sqlite3.connect(f"file:catalog_{top}?mode=memory&cache=shared", uri=True) + cursor = db.cursor() + cursor.execute( + "UPDATE nodes SET access_blob = json('{}') WHERE key == 'data_M'" + ) + db.commit() + + assert data in admin_client[top] + admin_client[top][data] + assert data not in alice_client[top] + with pytest.raises(KeyError): + alice_client[top][data] -class CustomAttributesAccessPolicy: +def test_container_access_control(access_control_test_context_factory): """ - A policy that demonstrates comparing metadata against user information stored at login-time. + Test that access control for data nested in containers allows/denies access. + This mostly checks that if a user does not have access to a container, + that user cannot reach inside the container to view data they would + otherwise have access for. """ + alice_client = access_control_test_context_factory("alice", "alice") + sue_client = access_control_test_context_factory("sue", "sue") + + top = "baz" + for c in ["C1"]: + alice_client[top].create_container(key=c, access_tags=["alice_tag"]) + alice_client[top][c].write_array( + arr, key=f"{c}_array", access_tags=["physicists_tag"] + ) + assert f"{c}_array" in alice_client[top][c] + alice_client[top][c][f"{c}_array"] + assert c not in sue_client[top] + with pytest.raises(KeyError): + sue_client[top][c][f"{c}_array"] - READ_METADATA = ["read:metadata"] - - def __init__(self): - pass - - async def allowed_scopes(self, node, principal, authn_scopes): - if hasattr(principal, "sessions"): - if len(principal.sessions): - auth_state = principal.sessions[-1].state or {} - auth_attributes = auth_state.get("attributes", {}) - if auth_attributes: - if "admins" in auth_attributes.get("groups", []): - return ALL_SCOPES - - if not node.metadata(): - return self.READ_METADATA - - if node.metadata()["beamline"] in auth_attributes.get( - "beamlines", [] - ) or node.metadata()["proposal"] in auth_attributes.get( - "proposals", [] - ): - return USER_SCOPES - - return self.READ_METADATA - return NO_SCOPES - - async def filters(self, node, principal, authn_scopes, scopes): - if not scopes.issubset( - await self.allowed_scopes(node, principal, authn_scopes) - ): - return NO_ACCESS - return [] - - -def tree_enriched_metadata(): - return MapAdapter( - { - "A": ArrayAdapter.from_array( - numpy.ones(10), metadata={"beamline": "bl1", "proposal": "prop1"} - ), - "B": ArrayAdapter.from_array( - numpy.ones(10), metadata={"beamline": "bl1", "proposal": "prop2"} - ), - "C": ArrayAdapter.from_array( - numpy.ones(10), metadata={"beamline": "bl2", "proposal": "prop2"} - ), - "D": ArrayAdapter.from_array( - numpy.ones(10), metadata={"beamline": "bl2", "proposal": "prop3"} - ), - }, + +def test_node_export_access_control( + access_control_test_context_factory, buffer_factory +): + """ + Test access control when exporting from Tiled (here: a container). + These tests include: + - Test that top-level nodes are disincluded appropriately + (MapAdapter->CatalogAdapter transition). + - Test that basic export works - i.e. nodes for which the user has + access are included. + - Test that nodes for which the user does not have access are not included. + - Test that this behavior also works for user-owned (untagged) nodes. + """ + alice_client = access_control_test_context_factory("alice", "alice") + sue_client = access_control_test_context_factory("sue", "sue") + + top = "baz" + alice_client[top].write_array(arr, key="data_D") + sue_client[top].write_array(arr, key="data_E") + + alice_export_buffer = buffer_factory() + sue_export_buffer = buffer_factory() + + alice_client.export(alice_export_buffer, format="application/json") + sue_client.export(sue_export_buffer, format="application/json") + + alice_export_buffer.seek(0) + sue_export_buffer.seek(0) + + alice_exported_data = json.loads(alice_export_buffer.read()) + sue_exported_data = json.loads(sue_export_buffer.read()) + + top = "foo" + assert top in alice_exported_data["contents"] + assert top not in sue_exported_data["contents"] + for data in ["data_A", "data_B", "data_C"]: + assert data in alice_exported_data["contents"][top]["contents"] + alice_exported_data["contents"][top]["contents"][data] + + top = "baz" + assert top in alice_exported_data["contents"] + assert top in sue_exported_data["contents"] + for data in ["data_A", "data_B", "data_D"]: + assert data not in sue_exported_data["contents"][top]["contents"] + with pytest.raises(KeyError): + sue_exported_data["contents"][top]["contents"][data] + for data in ["data_C", "data_E"]: + assert data in sue_exported_data["contents"][top]["contents"] + sue_exported_data["contents"][top]["contents"][data] + + +def test_apikey_auth_access_control(access_control_test_context_factory): + """ + Test access control when authenticated by an API key, including: + - Allow basic access with an API key that is not tag-restricted + - Disallow access to tags that are not added to a tag-restricted API key + - Allow access to tags that are added to a tag-restricted API key + - User-owned node access/writing is blocked when using a tag-restricted API key + """ + alice_client = access_control_test_context_factory("alice", "alice") + alice_apikey_info = alice_client.context.create_api_key() + alice_client.logout() + alice_client.context.api_key = alice_apikey_info["secret"] + + top = "foo" + for data in ["data_A"]: + assert data in alice_client[top] + alice_client[top][data] + + top = "bar" + alice_client[top].write_array(arr, key="data_O") + + alice_apikey_info = alice_client.context.create_api_key( + access_tags=["chemists_tag"] ) + alice_client.context.api_key = alice_apikey_info["secret"] + + top = "bar" + for data in ["data_A"]: + assert data not in alice_client[top] + with pytest.raises(KeyError): + alice_client[top][data] + for data in ["data_B"]: + assert data in alice_client[top] + alice_client[top][data] + for data in ["data_O"]: + assert data not in alice_client[top] + with pytest.raises(KeyError): + alice_client[top][data] + with fail_with_status_code(HTTP_403_FORBIDDEN): + alice_client[top].write_array(arr, key="data_P") -@pytest.fixture(scope="module") -def custom_attributes_context(): - config = { - "authentication": { - "allow_anonymous_access": False, - "secret_keys": ["SECRET"], - "providers": [ - { - "provider": "toy", - "authenticator": f"{__name__}:CustomAttributesAuthenticator", - "args": { - "users": { - "alice": { - "password": "secret1", - "attributes": {"proposals": ["prop1"]}, - }, - "bob": { - "password": "secret2", - "attributes": {"beamlines": ["bl1"]}, - }, - "cara": { - "password": "secret3", - "attributes": { - "beamlines": ["bl2"], - "proposals": ["prop1"], - }, - }, - "john": {"password": "secret4", "attributes": {}}, - "admin": { - "password": "admin", - "attributes": {"groups": ["admins"]}, - }, - } - }, - } - ], - }, - "database": { - "uri": "sqlite://", # in-memory - }, - "access_control": { - "access_policy": f"{__name__}:CustomAttributesAccessPolicy", - "args": {}, - }, - "trees": [ - {"tree": f"{__name__}:tree_enriched_metadata", "path": "/"}, - ], - } - app = build_app_from_config(config) - with Context.from_app(app) as context: - yield context - - -@pytest.mark.parametrize( - ("username", "password", "nodes"), - [ - ("admin", "admin", ["A", "B", "C", "D"]), - ("alice", "secret1", ["A"]), - ("bob", "secret2", ["A", "B"]), - ("cara", "secret3", ["A", "C", "D"]), - ], -) -def test_custom_attributes_with_data_access( - enter_username_password, custom_attributes_context, username, password, nodes +def test_service_principal_access_control( + access_control_test_context_factory, compile_access_tags_db_with_reset ): - """Test that the user has access to the data based on their auth attributes.""" - with enter_username_password(username, password): - custom_attributes_context.authenticate() - key_info = custom_attributes_context.create_api_key() - custom_attributes_context.logout() - - try: - custom_attributes_context.api_key = key_info["secret"] - client = from_context(custom_attributes_context) - - for node in nodes: - client[node].read() - - finally: - custom_attributes_context.api_key = None - - -@pytest.mark.parametrize( - ("username", "password", "nodes"), - [ - ("alice", "secret1", ["B", "C", "D"]), - ("bob", "secret2", ["C", "D"]), - ("cara", "secret3", ["B"]), - ("john", "secret4", ["A", "B", "C", "D"]), - ], -) -def test_custom_attributes_without_data_access( - enter_username_password, custom_attributes_context, username, password, nodes -): - """Test that the user cannot access data due to missing auth attributes.""" - with enter_username_password(username, password): - custom_attributes_context.authenticate() - key_info = custom_attributes_context.create_api_key() - custom_attributes_context.logout() - - try: - custom_attributes_context.api_key = key_info["secret"] - client = from_context(custom_attributes_context) - - for node in nodes: - with pytest.raises(ClientError): - client[node].read() - - finally: - custom_attributes_context.api_key = None + """ + Test that access control works for service principals. + Creates a service principal and updates the access tag config to + add this prinicpal to a tag. + """ + admin_client = access_control_test_context_factory("admin", "admin") + sp = admin_client.context.admin.create_service_principal("user") + sp_apikey_info = admin_client.context.admin.create_api_key(sp["uuid"]) + sp_client = access_control_test_context_factory( + sp["uuid"], api_key=sp_apikey_info["secret"] + ) + + access_tags_compiler = compile_access_tags_db_with_reset + compiler_tag_config = access_tags_compiler.tag_config + + compiler_tag_config["tags"]["physicists_tag"]["users"].append( + {"name": sp["uuid"], "role": "facility_admin"} + ) + compiler_tag_config["tag_owners"]["physicists_tag"]["users"].append( + {"name": sp["uuid"]} + ) + + access_tags_compiler.load_tag_config() + access_tags_compiler.recompile() + + top = "baz" + for data in ["data_A"]: + assert data not in sp_client[top] + with pytest.raises(KeyError): + sp_client[top][data] + for data in ["data_N"]: + sp_client[top].write_array(arr, key=data, access_tags=["physicists_tag"]) + assert data in sp_client[top] + sp_client[top][data] diff --git a/tiled/_tests/test_authentication.py b/tiled/_tests/test_authentication.py index 3b8b385b6..94999d569 100644 --- a/tiled/_tests/test_authentication.py +++ b/tiled/_tests/test_authentication.py @@ -7,7 +7,11 @@ import numpy import pytest -from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED +from starlette.status import ( + HTTP_400_BAD_REQUEST, + HTTP_401_UNAUTHORIZED, + HTTP_403_FORBIDDEN, +) from ..adapters.array import ArrayAdapter from ..adapters.mapping import MapAdapter @@ -419,7 +423,7 @@ def test_api_key_scopes(enter_username_password, config): with enter_username_password("bob", "secret2"): context.authenticate() # Try to request a key with more scopes that the user has. - with fail_with_status_code(HTTP_400_BAD_REQUEST): + with fail_with_status_code(HTTP_403_FORBIDDEN): context.create_api_key(scopes=["admin:apikeys"]) # Request a key with reduced scope that can *only* read metadata. metadata_key_info = context.create_api_key(scopes=["read:metadata"]) @@ -613,7 +617,7 @@ def test_admin_api_key_any_principal_exceeds_scopes( context.authenticate() principal_uuid = principals_context["uuid"]["bob"] - with fail_with_status_code(HTTP_400_BAD_REQUEST) as fail_info: + with fail_with_status_code(HTTP_403_FORBIDDEN) as fail_info: context.admin.create_api_key(principal_uuid, scopes=["read:principals"]) fail_message = " must be a subset of the principal's scopes " assert fail_message in fail_info.value.response.text diff --git a/tiled/_tests/test_catalog.py b/tiled/_tests/test_catalog.py index 96fb4a19e..6c0a1d5c4 100644 --- a/tiled/_tests/test_catalog.py +++ b/tiled/_tests/test_catalog.py @@ -33,7 +33,7 @@ from ..storage import SQLStorage, get_storage, parse_storage from ..structures.core import StructureFamily from ..utils import Conflicts, ensure_specified_sql_driver, ensure_uri -from .utils import enter_username_password, sql_table_exists +from .utils import sql_table_exists @pytest_asyncio.fixture @@ -702,76 +702,6 @@ async def test_delete_external_asset_registered_twice(tmpdir): assert len(assets_after_second_delete) == 0 -@pytest.mark.asyncio -async def test_access_control(tmpdir, sqlite_or_postgres_uri): - config = { - "authentication": { - "allow_anonymous_access": True, - "secret_keys": ["SECRET"], - "providers": [ - { - "provider": "toy", - "authenticator": "tiled.authenticators:DictionaryAuthenticator", - "args": { - "users_to_passwords": { - "alice": "secret1", - "bob": "secret2", - "admin": "admin", - } - }, - } - ], - }, - "access_control": { - "access_policy": "tiled.access_policies:SimpleAccessPolicy", - "args": { - "provider": "toy", - "access_lists": { - "alice": ["outer_x", "inner"], - "bob": ["outer_y"], - }, - "admins": ["admin"], - "public": ["outer_z", "inner"], - }, - }, - "database": { - "uri": "sqlite://", # in-memory - }, - "trees": [ - { - "tree": "catalog", - "path": "/", - "args": { - "uri": sqlite_or_postgres_uri, - "writable_storage": str(tmpdir / "data"), - "init_if_not_exists": True, - }, - }, - ], - } - - app = build_app_from_config(config) - with Context.from_app(app) as context: - admin_client = from_context(context) - with enter_username_password("admin", "admin"): - admin_client.login() - for key in ["outer_x", "outer_y", "outer_z"]: - container = admin_client.create_container(key) - container.write_array([1, 2, 3], key="inner") - admin_client.logout() - alice_client = from_context(context) - with enter_username_password("alice", "secret1"): - alice_client.login() - alice_client["outer_x"]["inner"].read() - with pytest.raises(KeyError): - alice_client["outer_y"] - alice_client.logout() - public_client = from_context(context) - public_client["outer_z"]["inner"].read() - with pytest.raises(KeyError): - public_client["outer_x"] - - @pytest.mark.parametrize( "assets", [ diff --git a/tiled/_tests/test_protocols.py b/tiled/_tests/test_protocols.py index 82ac075d9..6eea5204a 100644 --- a/tiled/_tests/test_protocols.py +++ b/tiled/_tests/test_protocols.py @@ -9,7 +9,8 @@ from numpy.typing import NDArray from pytest_mock import MockFixture -from ..access_policies import ALL_ACCESS +from ..access_control.access_policies import ALL_ACCESS +from ..access_control.scopes import ALL_SCOPES from ..adapters.awkward_directory_container import DirectoryContainer from ..adapters.protocols import ( AccessPolicy, @@ -20,7 +21,6 @@ TableAdapter, ) from ..ndslice import NDSlice -from ..scopes import ALL_SCOPES from ..server.schemas import Principal, PrincipalType from ..storage import Storage from ..structures.array import ArrayStructure, BuiltinDtype @@ -383,6 +383,7 @@ async def allowed_scopes( self, node: BaseAdapter, principal: Principal, + authn_access_tags: Optional[Set[str]], authn_scopes: Scopes, ) -> Scopes: allowed = self.scopes @@ -393,6 +394,7 @@ async def filters( self, node: BaseAdapter, principal: Principal, + authn_access_tags: Optional[Set[str]], authn_scopes: Scopes, scopes: Scopes, ) -> Filters: @@ -405,11 +407,12 @@ async def accesspolicy_protocol_functions( policy: AccessPolicy, node: BaseAdapter, principal: Principal, + authn_access_tags: Optional[Set[str]], authn_scopes: Scopes, scopes: Scopes, ) -> None: - await policy.allowed_scopes(node, principal, authn_scopes) - await policy.filters(node, principal, authn_scopes, scopes) + await policy.allowed_scopes(node, principal, authn_access_tags, authn_scopes) + await policy.filters(node, principal, authn_access_tags, authn_scopes, scopes) @pytest.mark.asyncio # type: ignore @@ -426,6 +429,7 @@ async def test_accesspolicy_protocol(mocker: MockFixture) -> None: principal = Principal( uuid="12345678124123412345678123456781", type=PrincipalType.user ) + authn_access_tags = {"qux", "quux"} authn_scopes = {"abc", "baz"} scopes = {"abc"} @@ -435,6 +439,7 @@ async def test_accesspolicy_protocol(mocker: MockFixture) -> None: anyaccesspolicy, anyawkwardadapter, principal, + authn_access_tags, authn_scopes, scopes, ) diff --git a/tiled/access_policies.py b/tiled/access_control/access_policies.py similarity index 60% rename from tiled/access_policies.py rename to tiled/access_control/access_policies.py index 4c924d568..af8432f9f 100644 --- a/tiled/access_policies.py +++ b/tiled/access_control/access_policies.py @@ -1,12 +1,9 @@ import logging import os -import sqlite3 -from contextlib import closing -from functools import partial -from .queries import AccessBlobFilter, In, KeysFilter +from ..queries import AccessBlobFilter +from ..utils import Sentinel, import_object from .scopes import ALL_SCOPES, PUBLIC_SCOPES -from .utils import Sentinel, import_object ALL_ACCESS = Sentinel("ALL_ACCESS") NO_ACCESS = Sentinel("NO_ACCESS") @@ -26,160 +23,13 @@ class DummyAccessPolicy: "Impose no access restrictions." - async def allowed_scopes(self, node, principal, authn_scopes): + async def allowed_scopes(self, node, principal, authn_access_tags, authn_scopes): return ALL_SCOPES - async def filters(self, node, principal, authn_scopes, scopes): + async def filters(self, node, principal, authn_access_tags, authn_scopes, scopes): return [] -class SimpleAccessPolicy: - """ - A mapping of user names to lists of entries they have access to. - - This simple policy does not provide fine-grained control of scopes. - Any restriction on scopes is applied the same to all users, except - for an optional list of 'admins'. - - This is used in the test suite; it may be suitable for very simple - deployments. - - >>> SimpleAccessPolicy({"alice": ["A", "B"], "bob": ["B"]}, provider="toy") - """ - - ALL = ALL_ACCESS - - def __init__( - self, access_lists, *, provider, key=None, scopes=None, public=None, admins=None - ): - self.access_lists = {} - self.provider = provider - self.key = key - self.scopes = scopes if (scopes is not None) else ALL_SCOPES - self.public = set(public or []) - self.admins = set(admins or []) - for key, value in access_lists.items(): - if isinstance(value, str): - value = import_object(value) - self.access_lists[key] = value - - def _get_id(self, principal): - # Get the id (i.e. username) of this Principal for the - # associated authentication provider. - for identity in principal.identities: - if identity.provider == self.provider: - id = identity.id - break - else: - raise ValueError( - f"Principcal {principal} has no identity from provider {self.provider}. " - f"Its identities are: {principal.identities}" - ) - return id - - async def allowed_scopes(self, node, principal, authn_scopes): - # If this is being called, filter_access has let us get this far. - if principal is None: - allowed = PUBLIC_SCOPES - elif principal.type == "service": - allowed = self.scopes - elif self._get_id(principal) in self.admins: - allowed = ALL_SCOPES - # The simple policy does not provide for different Principals to - # have different scopes on different Nodes. If the Principal has access, - # they have the same hard-coded access everywhere. - else: - allowed = self.scopes - return allowed - - async def filters(self, node, principal, authn_scopes, scopes): - queries = [] - query_filter = KeysFilter if not self.key else partial(In, self.key) - if principal is None: - queries.append(query_filter(self.public)) - else: - # Services have no identities; just use the uuid. - if principal.type == "service": - id = str(principal.uuid) - else: - id = self._get_id(principal) - if id in self.admins: - return queries - if not scopes.issubset(self.scopes): - return NO_ACCESS - access_list = self.access_lists.get(id, []) - if not (access_list == self.ALL): - try: - allowed = set(access_list or []) - except TypeError: - # Provide rich debugging info because we have encountered a confusing - # bug here in a previous implementation. - raise TypeError( - f"Unexpected access_list {access_list} of type {type(access_list)}. " - f"Expected iterable or {self.ALL}, instance of {type(self.ALL)}." - ) - queries.append(query_filter(allowed)) - return queries - - -class AccessTagsParser: - @classmethod - def from_uri(cls, uri): - db = sqlite3.connect(f"{uri}?ro", uri=True, check_same_thread=False) - return cls(db) - - def __init__(self, db): - self.db = db - - def is_tag_defined(self, name): - with closing(self.db.cursor()) as cursor: - cursor.execute("SELECT 1 FROM tags WHERE name = ?;", (name,)) - row = cursor.fetchone() - found_tagname = bool(row) - return found_tagname - - def get_public_tags(self): - with closing(self.db.cursor()) as cursor: - cursor.execute("SELECT name FROM public_tags;") - public_tags = {name for (name,) in cursor.fetchall()} - return public_tags - - def get_scopes_from_tag(self, tagname, username): - with closing(self.db.cursor()) as cursor: - cursor.execute( - "SELECT scope_name FROM user_tag_scopes WHERE tag_name = ? AND user_name = ?;", - (tagname, username), - ) - user_tag_scopes = {scope for (scope,) in cursor.fetchall()} - return user_tag_scopes - - def is_tag_owner(self, tagname, username): - with closing(self.db.cursor()) as cursor: - cursor.execute( - "SELECT 1 FROM user_tag_owners WHERE tag_name = ? AND user_name = ?;", - (tagname, username), - ) - row = cursor.fetchone() - found_owner = bool(row) - return found_owner - - def is_tag_public(self, name): - with closing(self.db.cursor()) as cursor: - cursor.execute("SELECT 1 FROM public_tags WHERE name = ?;", (name,)) - row = cursor.fetchone() - found_public = bool(row) - return found_public - - def get_tags_from_scope(self, scope, username): - with closing(self.db.cursor()) as cursor: - cursor.execute( - "SELECT tag_name FROM user_tag_scopes WHERE user_name = ? AND scope_name = ?;", - (username, scope), - ) - user_scope_tags = {tag for (tag,) in cursor.fetchall()} - return user_scope_tags - - class TagBasedAccessPolicy: def __init__( self, @@ -205,6 +55,7 @@ def __init__( self.unremovable_scopes = ["read:metadata", "write:metadata"] self.admin_scopes = ["admin:apikeys"] self.public_tag = "public".casefold() + self.invalid_tag_names = [name.casefold() for name in []] def _get_id(self, principal): for identity in principal.identities: @@ -221,7 +72,9 @@ def _is_admin(self, authn_scopes): return True return False - async def init_node(self, principal, authn_scopes, access_blob=None): + async def init_node( + self, principal, authn_access_tags, authn_scopes, access_blob=None + ): if principal.type == "service": identifier = str(principal.uuid) else: @@ -241,20 +94,29 @@ async def init_node(self, principal, authn_scopes, access_blob=None): access_tags = set(access_blob["tags"]) include_public_tag = False for tag in access_tags: + if authn_access_tags is not None: + if tag not in authn_access_tags: + raise ValueError( + f"Cannot apply tag to node: API key is restricted to access tags: {authn_access_tags}." + ) if tag.casefold() == self.public_tag: include_public_tag = True if not self._is_admin(authn_scopes): raise ValueError( "Cannot apply 'public' tag to node: only Tiled admins can apply the 'public' tag." ) - elif not self.is_tag_defined(tag): + elif not await self.is_tag_defined(tag): raise ValueError(f"Cannot apply tag to node: {tag=} is not defined") - elif not self.is_tag_owner(tag, identifier): + elif not await self.is_tag_owner(tag, identifier): # admins can ignore the tag ownership check if not self._is_admin(authn_scopes): raise ValueError( f"Cannot apply tag to node: user='{identifier}' is not an owner of {tag=}" ) + elif tag.casefold() in self.invalid_tag_names: + raise ValueError( + f"Cannot apply tag to node: '{tag}' is not a valid tag name." + ) access_tags_from_policy = { tag for tag in access_tags if tag.casefold() != self.public_tag @@ -270,7 +132,7 @@ async def init_node(self, principal, authn_scopes, access_blob=None): # check that the access_blob would not result in invalid scopes for user. new_scopes = set() for tag in access_tags_from_policy: - new_scopes.update(self.get_scopes_from_tag(tag, identifier)) + new_scopes.update(await self.get_scopes_from_tag(tag, identifier)) if not all(scope in new_scopes for scope in self.unremovable_scopes): raise ValueError( f"Cannot init node with tags: operation does not grant necessary scopes.\n" @@ -278,6 +140,12 @@ async def init_node(self, principal, authn_scopes, access_blob=None): f"This access_blob does not confer the minimum scopes: {self.unremovable_scopes}" ) else: + if authn_access_tags is not None: + raise ValueError( + f"Cannot init node as user-owned node.\n" + f"Current API key does not permit action on user-owned nodes.\n" + f"Please provide a tag allowed by this API key: {authn_access_tags}" + ) access_blob_from_policy = {"user": identifier} access_blob_modified = True @@ -287,7 +155,9 @@ async def init_node(self, principal, authn_scopes, access_blob=None): # modified means the blob to-be-used was changed in comparison to the user input return access_blob_modified, access_blob_from_policy - async def modify_node(self, node, principal, authn_scopes, access_blob): + async def modify_node( + self, node, principal, authn_access_tags, authn_scopes, access_blob + ): if principal.type == "service": identifier = str(principal.uuid) else: @@ -314,6 +184,11 @@ async def modify_node(self, node, principal, authn_scopes, access_blob): include_public_tag = False # check for tags that need to be added for tag in access_tags: + if authn_access_tags is not None: + if tag not in authn_access_tags: + raise ValueError( + f"Cannot apply tag to node: API key is restricted to access tags: {authn_access_tags}." + ) if tag in node.access_blob.get("tags", []): # node already has this tag - no action. # or: access_blob does not have "tags" key, @@ -328,14 +203,18 @@ async def modify_node(self, node, principal, authn_scopes, access_blob): raise ValueError( "Cannot apply 'public' tag to node: only Tiled admins can apply the 'public' tag." ) - elif not self.is_tag_defined(tag): + elif not await self.is_tag_defined(tag): raise ValueError(f"Cannot apply tag to node: {tag=} is not defined") - elif not self.is_tag_owner(tag, identifier): + elif not await self.is_tag_owner(tag, identifier): # admins can ignore the tag ownership check if not self._is_admin(authn_scopes): raise ValueError( f"Cannot apply tag to node: user='{identifier}' is not an owner of {tag=}" ) + elif tag.casefold() in self.invalid_tag_names: + raise ValueError( + f"Cannot apply tag to node: '{tag}' is not a valid tag name." + ) access_tags_from_policy = { tag for tag in access_tags if tag.casefold() != self.public_tag @@ -348,21 +227,31 @@ async def modify_node(self, node, principal, authn_scopes, access_blob): for tag in set(node.access_blob["tags"]).difference( access_tags_from_policy ): + if authn_access_tags is not None: + if tag not in authn_access_tags: + raise ValueError( + f"Cannot remove tag from node: " + f"API key is restricted to access tags: {authn_access_tags}." + ) if tag == self.public_tag: if not self._is_admin(authn_scopes): raise ValueError( "Cannot remove 'public' tag from node: only Tiled admins can remove the 'public' tag." ) - elif not self.is_tag_defined(tag): + elif not await self.is_tag_defined(tag): raise ValueError( f"Cannot remove tag from node: {tag=} is not defined" ) - elif not self.is_tag_owner(tag, identifier): + elif not await self.is_tag_owner(tag, identifier): # admins can ignore the tag ownership check if not self._is_admin(authn_scopes): raise ValueError( f"Cannot remove tag from node: user='{identifier}' is not an owner of {tag=}" ) + elif tag.casefold() in self.invalid_tag_names: + raise ValueError( + f"Cannot remove tag from node: '{tag}' is not a valid tag name." + ) access_blob_from_policy = {"tags": list(access_tags_from_policy)} access_blob_modified = access_tags != access_tags_from_policy @@ -374,7 +263,7 @@ async def modify_node(self, node, principal, authn_scopes, access_blob): # converting from user-owned node to shared (tagged) node new_scopes = set() for tag in access_tags_from_policy: - new_scopes.update(self.get_scopes_from_tag(tag, identifier)) + new_scopes.update(await self.get_scopes_from_tag(tag, identifier)) if not all(scope in new_scopes for scope in self.unremovable_scopes): raise ValueError( f"Cannot modify tags on node: operation removes unremovable scopes.\n" @@ -389,7 +278,7 @@ async def modify_node(self, node, principal, authn_scopes, access_blob): # modified means the blob to-be-used was changed in comparison to the user input return access_blob_modified, access_blob_from_policy - async def allowed_scopes(self, node, principal, authn_scopes): + async def allowed_scopes(self, node, principal, authn_access_tags, authn_scopes): # If this is being called, filter_for_access has let us get this far. # However, filters and allowed_scopes should always be implemented to # give answers consistent with each other. @@ -398,31 +287,37 @@ async def allowed_scopes(self, node, principal, authn_scopes): elif self._is_admin(authn_scopes): allowed = self.scopes else: - if principal.type == "service": + if principal is None: + identifier = None + elif principal.type == "service": identifier = str(principal.uuid) else: identifier = self._get_id(principal) allowed = set() if "user" in node.access_blob: - if identifier == node.access_blob["user"]: + if authn_access_tags is None and identifier == node.access_blob["user"]: allowed = self.scopes elif "tags" in node.access_blob: for tag in node.access_blob["tags"]: - if self.is_tag_public(tag): + if authn_access_tags is not None: + if tag not in authn_access_tags: + continue + if await self.is_tag_public(tag): allowed.update(self.read_scopes) if tag == self.public_tag: continue - elif not self.is_tag_defined(tag): + elif not await self.is_tag_defined(tag): continue - tag_scopes = self.get_scopes_from_tag(tag, identifier) - allowed.update( - tag_scopes if tag_scopes.issubset(self.scopes) else set() - ) + if identifier is not None: + tag_scopes = await self.get_scopes_from_tag(tag, identifier) + allowed.update( + tag_scopes if tag_scopes.issubset(self.scopes) else set() + ) return allowed - async def filters(self, node, principal, authn_scopes, scopes): + async def filters(self, node, principal, authn_access_tags, authn_scopes, scopes): queries = [] query_filter = AccessBlobFilter @@ -431,24 +326,37 @@ async def filters(self, node, principal, authn_scopes, scopes): if not scopes.issubset(self.scopes): return NO_ACCESS - if principal.type == "service": - identifier = str(principal.uuid) - elif self._is_admin(authn_scopes): - return queries + tag_list = set() + if principal is None: + identifier = None else: - identifier = self._get_id(principal) + if principal.type == "service": + identifier = str(principal.uuid) + elif self._is_admin(authn_scopes): + return queries + else: + identifier = self._get_id(principal) + tag_list.update( + set.intersection( + *[ + await self.get_tags_from_scope(scope, identifier) + for scope in scopes + ] + ) + ) - tag_list = set.intersection( - *[self.get_tags_from_scope(scope, identifier) for scope in scopes] - ) tag_list.update( set.intersection( *[ - self.get_public_tags() if scope in self.read_scopes else set() + await self.get_public_tags() if scope in self.read_scopes else set() for scope in scopes ] ) ) + if authn_access_tags is not None: + identifier = None + tag_list.intersection_update(authn_access_tags) + queries.append(query_filter(identifier, tag_list)) return queries diff --git a/tiled/access_control/access_tags.py b/tiled/access_control/access_tags.py new file mode 100644 index 000000000..c4f8de70c --- /dev/null +++ b/tiled/access_control/access_tags.py @@ -0,0 +1,536 @@ +import sqlite3 +import warnings +from contextlib import closing +from pathlib import Path +from sys import intern + +import aiosqlite +import yaml + +from ..utils import InterningLoader, ensure_specified_sql_driver + + +class AccessTagsParser: + @classmethod + def from_uri(cls, uri): + if uri.startswith("file:"): + uri = uri.split(":", 1)[1] + uri = ensure_specified_sql_driver(uri) + if not uri.startswith("sqlite+aiosqlite:"): + raise ValueError( + f"AccessTagsParser must be given a SQLite database URI, " + f"i.e. 'sqlite:///...', 'sqlite+aiosqlite:///...'\n" + f"Given URI results in: {uri=}" + ) + uri_path = uri.split(":", 1)[1] + if not uri_path.startswith("///"): + raise ValueError( + "Invalid URI provided, URI must contain 3 forward slashes, " + "e.g. 'sqlite:///...'." + ) + uri = f"file:{uri_path[3:]}" + uri = uri if "?" in uri else f"{uri}?mode=ro" + return cls(uri=uri) + + def __init__(self, db=None, uri=None): + self._uri = uri + self._db = db + + async def connect(self): + if self._db is None: + self._db = await aiosqlite.connect( + self._uri, uri=True, check_same_thread=False + ) + + async def is_tag_defined(self, name): + async with self._db.cursor() as cursor: + await cursor.execute("SELECT 1 FROM tags WHERE name = ?;", (name,)) + row = await cursor.fetchone() + found_tagname = bool(row) + return found_tagname + + async def get_public_tags(self): + async with self._db.cursor() as cursor: + await cursor.execute("SELECT name FROM public_tags;") + public_tags = {name for (name,) in await cursor.fetchall()} + return public_tags + + async def get_scopes_from_tag(self, tagname, username): + async with self._db.cursor() as cursor: + await cursor.execute( + "SELECT scope_name FROM user_tag_scopes WHERE tag_name = ? AND user_name = ?;", + (tagname, username), + ) + user_tag_scopes = {scope for (scope,) in await cursor.fetchall()} + return user_tag_scopes + + async def is_tag_owner(self, tagname, username): + async with self._db.cursor() as cursor: + await cursor.execute( + "SELECT 1 FROM user_tag_owners WHERE tag_name = ? AND user_name = ?;", + (tagname, username), + ) + row = await cursor.fetchone() + found_owner = bool(row) + return found_owner + + async def is_tag_public(self, name): + async with self._db.cursor() as cursor: + await cursor.execute("SELECT 1 FROM public_tags WHERE name = ?;", (name,)) + row = await cursor.fetchone() + found_public = bool(row) + return found_public + + async def get_tags_from_scope(self, scope, username): + async with self._db.cursor() as cursor: + await cursor.execute( + "SELECT tag_name FROM user_tag_scopes WHERE user_name = ? AND scope_name = ?;", + (username, scope), + ) + user_scope_tags = {tag for (tag,) in await cursor.fetchall()} + return user_scope_tags + + +def create_access_tags_tables(db): + with closing(db.cursor()) as cursor: + tables_setup_sql = """ +PRAGMA journal_mode = WAL; +PRAGMA synchronous = NORMAL; +PRAGMA temp_store = MEMORY; +PRAGMA foreign_keys = ON; +BEGIN TRANSACTION; +CREATE TABLE IF NOT EXISTS tags ( + id INTEGER PRIMARY KEY, + name TEXT UNIQUE NOT NULL, + is_public INTEGER NOT NULL DEFAULT 0 + CHECK (is_public IN (0,1)) +); +CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY, + name TEXT UNIQUE NOT NULL +); +CREATE TABLE IF NOT EXISTS scopes ( + id INTEGER PRIMARY KEY, + name TEXT UNIQUE NOT NULL +); +CREATE TABLE IF NOT EXISTS tags_users_scopes ( + tag_id INTEGER NOT NULL + REFERENCES tags(id) + ON UPDATE CASCADE + ON DELETE CASCADE, + user_id INTEGER NOT NULL + REFERENCES users(id) + ON UPDATE CASCADE + ON DELETE CASCADE, + scope_id INTEGER NOT NULL + REFERENCES scopes(id) + ON UPDATE CASCADE + ON DELETE CASCADE, + PRIMARY KEY (tag_id, user_id, scope_id) +); +CREATE TABLE IF NOT EXISTS tag_owners ( + tag_id INTEGER NOT NULL + REFERENCES tags(id) + ON UPDATE CASCADE + ON DELETE CASCADE, + user_id INTEGER NOT NULL + REFERENCES users(id) + ON UPDATE CASCADE + ON DELETE CASCADE, + PRIMARY KEY (tag_id, user_id) +); +CREATE INDEX IF NOT EXISTS idx_tags_is_public ON tags (is_public); +CREATE INDEX IF NOT EXISTS idx_tus_users_scopes ON tags_users_scopes (user_id, scope_id); +CREATE INDEX IF NOT EXISTS idx_tus_users_scopes_scopeid ON tags_users_scopes (scope_id); +CREATE INDEX IF NOT EXISTS idx_tag_owners ON tag_owners (user_id); +CREATE VIEW IF NOT EXISTS public_tags AS + SELECT name + FROM tags + WHERE is_public = 1; +CREATE VIEW IF NOT EXISTS user_tag_scopes AS + SELECT + u.name AS user_name, + t.name AS tag_name, + s.name AS scope_name + FROM tags_users_scopes tus + JOIN users u ON u.id = tus.user_id + JOIN tags t ON t.id = tus.tag_id + JOIN scopes s ON s.id = tus.scope_id; +CREATE VIEW IF NOT EXISTS user_tag_owners AS + SELECT + u.name AS user_name, + t.name AS tag_name + FROM tag_owners towner + JOIN users u ON u.id = towner.user_id + JOIN tags t ON t.id = towner.tag_id; +PRAGMA optimize; +""" + cursor.executescript(tables_setup_sql) + db.commit() + + +def update_access_tags_tables(db, scopes, tags, owners, public_tags): + with closing(db.cursor()) as cursor: + tables_stage_sql = """ +BEGIN TRANSACTION; +CREATE TEMP TABLE IF NOT EXISTS stage_tags ( + id INTEGER, + name TEXT NOT NULL, + is_public INTEGER NOT NULL + CHECK (is_public IN (0,1)) +); +CREATE TEMP TABLE IF NOT EXISTS stage_users ( + id INTEGER, + name TEXT NOT NULL +); +CREATE TEMP TABLE IF NOT EXISTS stage_scopes ( + id INTEGER, + name TEXT NOT NULL +); +CREATE TEMP TABLE IF NOT EXISTS stage_tags_users_scopes ( + tag_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + scope_id INTEGER NOT NULL +); +CREATE TEMP TABLE IF NOT EXISTS stage_tag_owners ( + tag_id INTEGER NOT NULL, + user_id INTEGER NOT NULL +); +""" + cursor.executescript(tables_stage_sql) + + # put all items into staging + all_tags = [(tag, 0) for tag in tags] + [(tag, 0) for tag in owners] + all_public = [(tag,) for tag in public_tags] + all_users = {(user,) for users in tags.values() for user in users} + all_users.update({(user,) for users in owners.values() for user in users}) + all_scopes = [(scope,) for scope in scopes] + cursor.executemany( + "INSERT INTO stage_tags(name, is_public) VALUES (?,?);", + all_tags, + ) + cursor.executemany( + "UPDATE stage_tags SET is_public = 1 WHERE name = (?);", all_public + ) + cursor.executemany( + "INSERT INTO stage_users(name) VALUES (?);", + all_users, + ) + cursor.executemany( + "INSERT INTO stage_scopes(name) VALUES (?);", + all_scopes, + ) + + # push item names and metadata from staging to prod + # then pull back ID values from prod to staging + # note that UPSERT should always have a WHERE clause + # to avoid ambiguity. See the SQLite docs section 2.2 + # https://www.sqlite.org/lang_upsert.html + stage_push_sql = """ +BEGIN TRANSACTION; +INSERT INTO tags (name, is_public) + SELECT name, is_public + FROM stage_tags + WHERE true + ON CONFLICT (name) + DO UPDATE SET is_public = excluded.is_public; +INSERT INTO users(name) SELECT name FROM stage_users WHERE true ON CONFLICT(name) DO NOTHING; +INSERT INTO scopes(name) SELECT name FROM stage_scopes WHERE true ON CONFLICT(name) DO NOTHING; +UPDATE stage_tags SET id = (SELECT id FROM tags WHERE tags.name = stage_tags.name); +UPDATE stage_users SET id = (SELECT id FROM users WHERE users.name = stage_users.name); +UPDATE stage_scopes SET id = (SELECT id FROM scopes WHERE scopes.name = stage_scopes.name); +""" + cursor.executescript(stage_push_sql) + + # load db IDs for items into memory + cursor.execute("SELECT id, name FROM stage_tags") + tags_to_id = {intern(name): tag_id for (tag_id, name) in cursor.fetchall()} + cursor.execute("SELECT id, name FROM stage_users") + users_to_id = {intern(name): user_id for (user_id, name) in cursor.fetchall()} + cursor.execute("SELECT id, name FROM stage_scopes") + scopes_to_id = { + intern(name): scope_id for (scope_id, name) in cursor.fetchall() + } + + # flatten relationships and push to staging + tags_users_scopes = [ + (tags_to_id[tag], users_to_id[user], scopes_to_id[scope]) + for tag, users in tags.items() + for user, scopes in users.items() + for scope in scopes + ] + tag_owners = [ + (tags_to_id[tag], users_to_id[user]) + for tag, users in owners.items() + for user in users + ] + cursor.executemany( + "INSERT INTO stage_tags_users_scopes(tag_id, user_id, scope_id) VALUES (?,?,?);", + tags_users_scopes, + ) + cursor.executemany( + "INSERT INTO stage_tag_owners(tag_id, user_id) VALUES (?,?);", tag_owners + ) + + # delete outdated tags from prod and add updated relationships to db + # finally, drop the staging tables + # to-do: consider refactoring this to indvidual execute statements + # to avoid implicit pre-mature commit by executescript() + upsert_delete_sql = """ +BEGIN TRANSACTION; +DELETE from tags WHERE id NOT in (SELECT id FROM stage_tags); +DELETE from users WHERE id NOT in (SELECT id FROM stage_users); +DELETE from scopes WHERE id NOT in (SELECT id FROM stage_scopes); +INSERT INTO tags_users_scopes (tag_id, user_id, scope_id) + SELECT tag_id, user_id, scope_id FROM stage_tags_users_scopes + WHERE true + ON CONFLICT (tag_id, user_id, scope_id) DO NOTHING; +DELETE FROM tags_users_scopes + WHERE (tag_id, user_id, scope_id) + NOT IN (SELECT tag_id, user_id, scope_id FROM stage_tags_users_scopes); +INSERT INTO tag_owners (tag_id, user_id) + SELECT tag_id, user_id FROM stage_tag_owners + WHERE true + ON CONFLICT (tag_id, user_id) DO NOTHING; +DELETE FROM tag_owners + WHERE (tag_id, user_id) NOT IN (SELECT tag_id, user_id FROM stage_tag_owners); +DROP TABLE IF EXISTS stage_tags; +DROP TABLE IF EXISTS stage_users; +DROP TABLE IF EXISTS stage_scopes; +DROP TABLE IF EXISTS stage_tags_users_scopes; +DROP TABLE IF EXISTS stage_tag_owners; +PRAGMA optimize; +""" + cursor.executescript(upsert_delete_sql) + db.commit() + + +class AccessTagsCompiler: + _MAX_TAG_NESTING = 5 + + def __init__( + self, + scopes, + tag_config, + tags_db, + group_parser, + ): + self.scopes = scopes or {} + self.tag_config = tag_config + self.connection = sqlite3.connect( + tags_db["uri"], uri=True, check_same_thread=False + ) + self.group_parser = group_parser + + self.max_tag_nesting = max(self._MAX_TAG_NESTING, 0) + self.public_tag = intern("public".casefold()) + self.invalid_tag_names = [name.casefold() for name in []] + + self.roles = {} + self.tags = {} + self.tag_owners = {} + self.compiled_tags = {self.public_tag: {}} + self.compiled_public = set({self.public_tag}) + self.compiled_tag_owners = {} + + create_access_tags_tables(self.connection) + + def load_tag_config(self): + if isinstance(self.tag_config, str) or isinstance(self.tag_config, Path): + try: + with open(Path(self.tag_config)) as tag_config_file: + tag_definitions = yaml.load(tag_config_file, Loader=InterningLoader) + self.roles.update(tag_definitions.get("roles", {})) + self.tags.update(tag_definitions["tags"]) + self.tag_owners.update(tag_definitions.get("tag_owners", {})) + except FileNotFoundError as e: + raise ValueError( + f"The tag config file {self.tag_config!s} doesn't exist." + ) from e + elif isinstance(self.tag_config, dict): + tag_definitions = self.tag_config + self.roles.update(tag_definitions.get("roles", {})) + self.tags.update(tag_definitions["tags"]) + self.tag_owners.update(tag_definitions.get("tag_owners", {})) + + def _dfs(self, current_tag, tags, seen_tags, nested_level=0): + if current_tag in self.compiled_tags: + return self.compiled_tags[current_tag], current_tag in self.compiled_public + if current_tag in seen_tags: + return {}, False + if nested_level > self.max_tag_nesting: + raise RecursionError( + f"Exceeded maximum tag nesting of {self.max_tag_nesting} levels" + ) + + public_auto_tag = False + seen_tags.add(current_tag) + users = {} + for tag in tags[current_tag]: + if tag.casefold() == self.public_tag: + public_auto_tag = True + continue + try: + child_users, child_public = self._dfs( + tag, tags, seen_tags, nested_level + 1 + ) + public_auto_tag = public_auto_tag or child_public + users.update(child_users) + except (RecursionError, ValueError) as e: + raise RuntimeError( + f"Tag compilation failed at tag: {current_tag}" + ) from e + + if public_auto_tag: + self.compiled_public.add(current_tag) + + if "users" in self.tags[current_tag]: + for user in self.tags[current_tag]["users"]: + username = user["name"] + if all(k in user for k in ("scopes", "role")): + raise ValueError( + f"Cannot define both 'scopes' and 'role' for a user. {username=}" + ) + elif not any(k in user for k in ("scopes", "role")): + raise ValueError( + f"Must define either 'scopes' or 'role' for a user. {username=}" + ) + + user_scopes = set( + self.roles[user["role"]]["scopes"] + if ("role" in user) and (user["role"] in self.roles) + else user.get("scopes", []) + ) + if not user_scopes: + raise ValueError(f"Scopes must not be empty. {username=}") + if not user_scopes.issubset(self.scopes): + raise ValueError( + f"Scopes for {username=} are not in the valid set of scopes. The invalid scopes are:" + f"{user_scopes.difference(self.scopes)}" + ) + users.setdefault(username, set()) + users[username].update(user_scopes) + + if "groups" in self.tags[current_tag]: + for group in self.tags[current_tag]["groups"]: + groupname = group["name"] + if all(k in group for k in ("scopes", "role")): + raise ValueError( + f"Cannot define both 'scopes' and 'role' for a group. {groupname=}" + ) + elif not any(k in group for k in ("scopes", "role")): + raise ValueError( + f"Must define either 'scopes' or 'role' for a group. {groupname=}" + ) + + group_scopes = set( + self.roles[group["role"]]["scopes"] + if ("role" in group) and (group["role"] in self.roles) + else group.get("scopes", []) + ) + if not group_scopes: + raise ValueError(f"Scopes must not be empty. {groupname=}") + if not group_scopes.issubset(self.scopes): + raise ValueError( + f"Scopes for {groupname=} are not in the valid set of scopes. The invalid scopes are:" + f"{group_scopes.difference(self.scopes)}" + ) + + try: + usernames = self.group_parser(groupname) + except KeyError: + warnings.warn( + f"Group with {groupname=} does not exist - skipping", + UserWarning, + ) + continue + else: + for username in usernames: + username = intern(username) + users.setdefault(username, set()) + users[username].update(group_scopes) + + self.compiled_tags[current_tag] = users + return users, public_auto_tag + + def compile(self): + for role in self.roles.values(): + if "scopes" not in role: + raise ValueError(f"Scopes must be defined for a role. {role=}") + if not role["scopes"]: + raise ValueError(f"Scopes must not be empty. {role=}") + if not set(role["scopes"]).issubset(self.scopes): + raise ValueError( + f"Scopes for {role=} are not in the valid set of scopes. The invalid scopes are:" + f'{set(role["scopes"]).difference(self.scopes)}' + ) + + adjacent_tags = {} + for tag, members in self.tags.items(): + if tag.casefold() == self.public_tag: + raise ValueError( + f"'Public' tag '{self.public_tag}' cannot be redefined." + ) + if tag.casefold() in self.invalid_tag_names: + raise ValueError( + f"Tag 'tag' is an invalid tag name.\n" + f"The invalid tag names are: {self.invalid_tag_names}" + ) + adjacent_tags[tag] = set() + if "auto_tags" in members: + for auto_tag in members["auto_tags"]: + if ( + auto_tag["name"] not in self.tags + and auto_tag["name"].casefold() != self.public_tag + ): + raise KeyError( + f"Tag '{tag}' has nested tag '{auto_tag}' which does not have a definition." + ) + adjacent_tags[tag].add(auto_tag["name"]) + + for tag in adjacent_tags: + try: + self._dfs(tag, adjacent_tags, set()) + except (RecursionError, ValueError) as e: + raise RuntimeError(f"Tag compilation failed at tag: {tag}") from e + + for tag in self.tag_owners: + self.compiled_tag_owners.setdefault(tag, set()) + if "users" in self.tag_owners[tag]: + for user in self.tag_owners[tag]["users"]: + username = user["name"] + self.compiled_tag_owners[tag].add(username) + if "groups" in self.tag_owners[tag]: + for group in self.tag_owners[tag]["groups"]: + groupname = group["name"] + try: + usernames = self.group_parser(groupname) + except KeyError: + warnings.warn( + f"Group with {groupname=} does not exist - skipping", + UserWarning, + ) + continue + else: + for username in usernames: + username = intern(username) + self.compiled_tag_owners[tag].add(username) + + update_access_tags_tables( + self.connection, + self.scopes, + self.compiled_tags, + self.compiled_tag_owners, + self.compiled_public, + ) + + def clear_raw_tags(self): + self.roles = {} + self.tags = {} + self.tag_owners = {} + + def recompile(self): + self.compiled_tags = {self.public_tag: {}} + self.compiled_public = set({self.public_tag}) + self.compiled_tag_owners = {} + self.compile() diff --git a/tiled/scopes.py b/tiled/access_control/scopes.py similarity index 100% rename from tiled/scopes.py rename to tiled/access_control/scopes.py diff --git a/tiled/adapters/protocols.py b/tiled/adapters/protocols.py index 2fbd522a1..22b50e4c4 100644 --- a/tiled/adapters/protocols.py +++ b/tiled/adapters/protocols.py @@ -138,6 +138,7 @@ async def allowed_scopes( self, node: BaseAdapter, principal: Principal, + authn_access_tags: Optional[Set[str]], authn_scopes: Scopes, ) -> Scopes: pass @@ -147,6 +148,7 @@ async def filters( self, node: BaseAdapter, principal: Principal, + authn_access_tags: Optional[Set[str]], authn_scopes: Scopes, scopes: Scopes, ) -> Filters: diff --git a/tiled/authn_database/core.py b/tiled/authn_database/core.py index a97d0749f..5cbd507c9 100644 --- a/tiled/authn_database/core.py +++ b/tiled/authn_database/core.py @@ -13,6 +13,7 @@ # This is list of all valid alembic revisions (from current to oldest). ALL_REVISIONS = [ + "a806cc635ab2", "0c705a02954c", "d88e91ea03f9", "13024b8a6b74", @@ -26,40 +27,47 @@ REQUIRED_REVISION = ALL_REVISIONS[0] -async def create_default_roles(db: AsyncSession) -> None: - db.add_all( - [ - Role( - name="user", - description="Default Role for users.", - scopes=[ - "read:metadata", - "read:data", - "create", - "write:metadata", - "write:data", - "apikeys", - ], - ), - Role( - name="admin", - description="Role with elevated privileges.", - scopes=[ - "read:metadata", - "read:data", - "create", - "register", - "write:metadata", - "write:data", - "admin:apikeys", - "read:principals", - "write:principals", - "metrics", - ], - ), - ] - ) - await db.commit() +async def create_default_roles(db): + default_roles = [ + Role( + name="user", + description="Default Role for users.", + scopes=[ + "read:metadata", + "read:data", + "create", + "write:metadata", + "write:data", + "apikeys", + ], + ), + Role( + name="admin", + description="Role with elevated privileges.", + scopes=[ + "read:metadata", + "read:data", + "create", + "register", + "write:metadata", + "write:data", + "admin:apikeys", + "read:principals", + "write:principals", + "metrics", + ], + ), + ] + + roles_result = await db.execute(select(Role.name)) + existing_role_names = set(roles_result.scalars().all()) + roles_to_add = [ + role for role in default_roles if role.name not in existing_role_names + ] + + if roles_to_add: + db.add_all(roles_to_add) + await db.commit() async def initialize_database(engine: AsyncEngine) -> None: diff --git a/tiled/authn_database/migrations/versions/a806cc635ab2_add_access_tags_to_api_keys.py b/tiled/authn_database/migrations/versions/a806cc635ab2_add_access_tags_to_api_keys.py new file mode 100644 index 000000000..84e0ce4a7 --- /dev/null +++ b/tiled/authn_database/migrations/versions/a806cc635ab2_add_access_tags_to_api_keys.py @@ -0,0 +1,31 @@ +"""Add access_tags to API keys + +Revision ID: a806cc635ab2 +Revises: 0c705a02954c +Create Date: 2025-08-26 17:10:47.717942 + +""" +import sqlalchemy as sa +from alembic import op + +from tiled.authn_database.orm import APIKey, JSONList + +# revision identifiers, used by Alembic. +revision = "a806cc635ab2" +down_revision = "0c705a02954c" +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column( + APIKey.__tablename__, + sa.Column("access_tags", JSONList(511), nullable=True), + ) + + +def downgrade(): + op.drop_column( + APIKey.__tablename__, + "access_tags", + ) diff --git a/tiled/authn_database/orm.py b/tiled/authn_database/orm.py index b3f5ebc88..bf16933bf 100644 --- a/tiled/authn_database/orm.py +++ b/tiled/authn_database/orm.py @@ -38,10 +38,13 @@ class JSONList(TypeDecorator): cache_ok = True def process_bind_param(self, value, dialect): - # Make sure we don't get passed some iterable like a dict. - if not isinstance(value, list): - raise ValueError("JSONList must be given a literal `list` type.") - if value is not None: + if value is None: + # Allow None for columns that are nullable + return None + else: + # Make sure we don't get passed some iterable like a dict. + if not isinstance(value, list): + raise ValueError("JSONList must be given a literal `list` type.") value = json.dumps(value) return value @@ -186,6 +189,7 @@ class APIKey(Timestamped, Base): note = Column(Unicode(1023), nullable=True) principal_id = Column(Integer, ForeignKey("principals.id"), nullable=False) scopes = Column(JSONList(511), nullable=False) + access_tags = Column(JSONList(511), nullable=True) # In the future we could make it possible to disable API keys # without deleting them from the database, for forensics and # record-keeping. diff --git a/tiled/catalog/adapter.py b/tiled/catalog/adapter.py index 03d7f2386..6444f191c 100644 --- a/tiled/catalog/adapter.py +++ b/tiled/catalog/adapter.py @@ -271,6 +271,7 @@ def metadata(self): async def startup(self): if (self.context.engine.dialect.name == "sqlite") and ( self.context.engine.url.database == ":memory:" + or self.context.engine.url.query.get("mode") == "memory" ): # Special-case for in-memory SQLite: Because it is transient we can # skip over anything related to migrations. @@ -1352,21 +1353,23 @@ def access_blob_filter(query, tree): attr_id = access_blob["user"] attr_tags = access_blob["tags"] access_tags_json = func.json_each(attr_tags).table_valued("value") - contains_tags = ( + condition = ( select(1) .select_from(access_tags_json) .where(access_tags_json.c.value.in_(query.tags)) .exists() ) - user_match = func.json_extract(func.json_quote(attr_id), "$") == query.user_id - condition = or_(contains_tags, user_match) + if query.user_id is not None: + user_match = ( + func.json_extract(func.json_quote(attr_id), "$") == query.user_id + ) + condition = or_(condition, user_match) elif dialect_name == "postgresql": access_blob_jsonb = type_coerce(access_blob, JSONB) - contains_tags = access_blob_jsonb["tags"].has_any( - sql_cast(query.tags, ARRAY(TEXT)) - ) - user_match = access_blob_jsonb["user"].astext == query.user_id - condition = or_(contains_tags, user_match) + condition = access_blob_jsonb["tags"].has_any(sql_cast(query.tags, ARRAY(TEXT))) + if query.user_id is not None: + user_match = access_blob_jsonb["user"].astext == query.user_id + condition = or_(condition, user_match) else: raise UnsupportedQueryType("access_blob_filter") @@ -1473,14 +1476,19 @@ def structure_family(query, tree): def in_memory( *, + named_memory=None, metadata=None, specs=None, writable_storage=None, readable_storage=None, echo=DEFAULT_ECHO, adapters_by_mimetype=None, + top_level_access_blob=None, ): - uri = "sqlite:///:memory:" + if not named_memory: + uri = "sqlite:///:memory:" + else: + uri = f"sqlite:///file:{named_memory}?mode=memory&cache=shared&uri=true" return from_uri( uri=uri, metadata=metadata, @@ -1490,6 +1498,7 @@ def in_memory( init_if_not_exists=True, echo=echo, adapters_by_mimetype=adapters_by_mimetype, + top_level_access_blob=top_level_access_blob, ) @@ -1527,8 +1536,10 @@ def from_uri( logger.info(f"Subprocess stderr: {stderr}") parsed_url = make_url(uri) - if (parsed_url.get_dialect().name == "sqlite") and ( - parsed_url.database != ":memory:" + if ( + (parsed_url.get_dialect().name == "sqlite") + and (parsed_url.database != ":memory:") + and (parsed_url.query.get("mode", None) != "memory") ): # For file-backed SQLite databases, connection pooling offers a # significant performance boost. For SQLite databases that exist diff --git a/tiled/client/context.py b/tiled/client/context.py index d142b2a8c..64813a142 100644 --- a/tiled/client/context.py +++ b/tiled/client/context.py @@ -226,10 +226,12 @@ def __init__( # starlette is available. from starlette.testclient import TestClient + base_uri = f"{uri.scheme}://{uri.netloc}" # verify parameter is dropped, as there is no SSL in ASGI mode client = TestClient( app=app, raise_server_exceptions=raise_server_exceptions, + base_url=base_uri, ) client.timeout = timeout client.headers = headers @@ -444,12 +446,13 @@ def from_app( timeout=None, api_key=UNSET, raise_server_exceptions=True, + uri=None, ): """ Construct a Context around a FastAPI app. Primarily for testing. """ context = cls( - uri="http://local-tiled-app/api/v1", + uri="http://local-tiled-app/api/v1" if not uri else uri, headers=headers, api_key=None, cache=cache, @@ -517,7 +520,7 @@ def which_api_key(self): ) ).json() - def create_api_key(self, scopes=None, expires_in=None, note=None): + def create_api_key(self, scopes=None, expires_in=None, note=None, access_tags=None): """ Generate a new API key. @@ -536,6 +539,9 @@ def create_api_key(self, scopes=None, expires_in=None, note=None): have the maximum lifetime allowed by the server. note : Optional[str] Description (for humans). + access_tags : Optional[List[str]] + Restrict the access available to the API key by listing specific tags. + By default, this will have no limits on access tags. """ if isinstance(expires_in, str): expires_in = parse_time_string(expires_in) @@ -545,7 +551,12 @@ def create_api_key(self, scopes=None, expires_in=None, note=None): self.http_client.post( self.server_info.authentication.links.apikey, headers={"Accept": MSGPACK_MIME_TYPE}, - json={"scopes": scopes, "expires_in": expires_in, "note": note}, + json={ + "scopes": scopes, + "access_tags": access_tags, + "expires_in": expires_in, + "note": note, + }, ) ).json() @@ -875,7 +886,9 @@ def show_principal(self, uuid): ) ).json() - def create_api_key(self, uuid, scopes=None, expires_in=None, note=None): + def create_api_key( + self, uuid, scopes=None, expires_in=None, note=None, access_tags=None + ): """ Generate a new API key for another user or service. @@ -892,6 +905,9 @@ def create_api_key(self, uuid, scopes=None, expires_in=None, note=None): allowed by the server. note : Optional[str] Description (for humans). + access_tags : Optional[List[str]] + Restrict the access available to the API key by listing specific tags. + By default, this will have no limits on access tags. """ for attempt in retry_context(): with attempt: @@ -899,7 +915,12 @@ def create_api_key(self, uuid, scopes=None, expires_in=None, note=None): self.context.http_client.post( f"{self.base_url}/auth/principal/{uuid}/apikey", headers={"Accept": MSGPACK_MIME_TYPE}, - json={"scopes": scopes, "expires_in": expires_in, "note": note}, + json={ + "scopes": scopes, + "access_tags": access_tags, + "expires_in": expires_in, + "note": note, + }, ) ).json() diff --git a/tiled/commandline/_api_key.py b/tiled/commandline/_api_key.py index 3b6dfc86a..dda7986ed 100644 --- a/tiled/commandline/_api_key.py +++ b/tiled/commandline/_api_key.py @@ -28,6 +28,13 @@ def create_api_key( "By default, it will inherit the scopes of its owner." ), ), + access_tags: Optional[List[str]] = typer.Option( + None, + help=( + "Restrict the access available to the API key by listing specific tags. " + "By default, it will have no limits on access tags." + ), + ), note: Optional[str] = typer.Option(None, help="Add a note to label this API key."), no_verify: bool = typer.Option(False, "--no-verify", help="Skip SSL verification."), ): @@ -38,7 +45,9 @@ def create_api_key( scopes = None if expires_in and expires_in.isdigit(): expires_in = int(expires_in) - info = context.create_api_key(scopes=scopes, expires_in=expires_in, note=note) + info = context.create_api_key( + scopes=scopes, access_tags=access_tags, expires_in=expires_in, note=note + ) # TODO Print other info to the stderr? typer.echo(info["secret"]) @@ -55,10 +64,32 @@ def list_api_keys( typer.echo("No API keys found", err=True) return max_note_len = max(len(api_key["note"] or "") for api_key in info["api_keys"]) - COLUMNS = f"First 8 Expires at (UTC) Latest activity Note{' ' * (max_note_len - 4)} Scopes" + if (starting_notes_pad := max_note_len) < len("Note"): + starting_notes_pad += 4 - max_note_len + max_scopes_len = max( + sum(len(scope) for scope in api_key["scopes"]) + len(api_key["scopes"]) - 1 + for api_key in info["api_keys"] + ) + if (starting_scopes_pad := max_scopes_len) < len("Scopes"): + starting_scopes_pad += 6 - max_scopes_len + COLUMNS = ( + f"First 8 Expires at (UTC) " + f"Latest activity Note{' ' * (max_note_len - 4)} " + f"Scopes{' ' * (max_scopes_len - 6)} Access tags" + ) typer.echo(COLUMNS) for api_key in info["api_keys"]: - note_padding = 2 + max_note_len - len(api_key["note"] or "") + note_padding = 4 + starting_notes_pad - len(api_key["note"] or "") + # the '1' subtraction works in all cases because the amount of spaces + # in a sentence is count(words) - 1, and also because we otherwise + # print a single 'dash' for an empty list + scopes_padding = ( + 4 + + starting_scopes_pad + - sum(len(scope) for scope in api_key["scopes"]) + + len(api_key["scopes"]) + - 1 + ) if api_key["expiration_time"] is None: expiration_time = "-" else: @@ -75,13 +106,19 @@ def list_api_keys( .replace(microsecond=0, tzinfo=None) .isoformat() ) + access_tags = ( + " ".join([tag.replace(" ", "\\ ") for tag in api_key["access_tags"]]) + if api_key["access_tags"] is not None + else "-" + ) typer.echo( ( f"{api_key['first_eight']:10}" f"{expiration_time:21}" f"{latest_activity:21}" f"{(api_key['note'] or '')}{' ' * note_padding}" - f"{' '.join(api_key['scopes']) or '-'}" + f"{' '.join(api_key['scopes']) or '-'}{' ' * scopes_padding}" + f"{access_tags}" ) ) diff --git a/tiled/config_schemas/service_configuration.yml b/tiled/config_schemas/service_configuration.yml index cefb7119f..c80b8b653 100644 --- a/tiled/config_schemas/service_configuration.yml +++ b/tiled/config_schemas/service_configuration.yml @@ -308,13 +308,12 @@ properties: Example: ```yaml - access_control: - access_policy: "tiled.access_policies:SimpleAccessPolicy" - args: - access_lists: - alice: ["A", "B"] - bob: ["C"] - cara: "tiled.access_policies:ALL_ACCESS" + access_policy: "tiled.access_control.access_policies:TagBasedAccessPolicy" + args: + provider: "pam" + tags_db: + uri: "file:compiled_tags.sqlite" + access_tags_parser: "tiled.access_control.access_tags:AccessTagsParser" ``` args: type: object diff --git a/tiled/examples/toy_authentication.py b/tiled/examples/toy_authentication.py deleted file mode 100644 index 5d446c612..000000000 --- a/tiled/examples/toy_authentication.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -This contains a simple tree for demonstrating access control. -See the configuration: - -example_configs/toy_authentication.yml -""" -import numpy - -from tiled.adapters.array import ArrayAdapter -from tiled.adapters.mapping import MapAdapter - -# Make a MapAdapter with a couple arrays in it. -tree = MapAdapter( - { - "A": ArrayAdapter.from_array(10 * numpy.ones((10, 10))), - "B": ArrayAdapter.from_array(20 * numpy.ones((10, 10))), - "C": ArrayAdapter.from_array(30 * numpy.ones((10, 10))), - "D": ArrayAdapter.from_array(30 * numpy.ones((10, 10))), - }, -) diff --git a/tiled/queries.py b/tiled/queries.py index 25e850e6c..ee005dcea 100644 --- a/tiled/queries.py +++ b/tiled/queries.py @@ -8,7 +8,7 @@ import enum import json from dataclasses import dataclass -from typing import Any, List +from typing import Any, List, Optional from .query_registration import register from .structures.core import StructureFamily as StructureFamilyEnum @@ -89,8 +89,6 @@ class KeysFilter(NoBool): """ Filter entries that do not match one of these keys. - This is used by the SimpleAccessPolicy. - Parameters ---------- keys : List[str] @@ -539,7 +537,7 @@ class AccessBlobFilter: >>> c.search(AccessBlobFilter("bill", ["tag_for_bill", "useful_data"])) """ - user_id: str + user_id: Optional[str] tags: List[str] def encode(self): diff --git a/tiled/serialization/container.py b/tiled/serialization/container.py index 549dd04a0..1ad5e3a10 100644 --- a/tiled/serialization/container.py +++ b/tiled/serialization/container.py @@ -25,17 +25,16 @@ async def walk(node, filter_for_access, pre=None): """ pre = pre[:] if pre else [] if node.structure_family != StructureFamily.array: - if hasattr(node, "items_range"): - for key, value in await (await filter_for_access(node)).items_range( - 0, None - ): + filtered = await filter_for_access(node) + if hasattr(filtered, "items_range"): + for key, value in await filtered.items_range(0, None): async for d in walk(value, filter_for_access, pre + [key]): yield d elif node.structure_family == StructureFamily.table: for key in node.structure().columns: - yield (pre + [key], await filter_for_access(node)) + yield (pre + [key], filtered) else: - for key, value in (await filter_for_access(node)).items(): + for key, value in filtered.items(): async for d in walk(value, filter_for_access, pre + [key]): yield d else: diff --git a/tiled/server/app.py b/tiled/server/app.py index cef9b807c..0a4099688 100644 --- a/tiled/server/app.py +++ b/tiled/server/app.py @@ -557,7 +557,7 @@ async def startup_event(): # registry, keyed on database_settings, where can be retrieved by # the Dependency get_database_session. engine = open_database_connection_pool(settings.database_settings) - if not engine.url.database: + if not engine.url.database or engine.url.query.get("mode") == "memory": # Special-case for in-memory SQLite: Because it is transient we can # skip over anything related to migrations. await initialize_database(engine) @@ -633,6 +633,11 @@ async def startup_event(): id=admin["id"], ) + if app.state.access_policy is not None and hasattr( + app.state.access_policy, "access_tags_parser" + ): + await app.state.access_policy.access_tags_parser.connect() + async def purge_expired_sessions_and_api_keys(): PURGE_INTERVAL = 600 # seconds while True: diff --git a/tiled/server/authentication.py b/tiled/server/authentication.py index 7a0816ff0..696de2a95 100644 --- a/tiled/server/authentication.py +++ b/tiled/server/authentication.py @@ -4,7 +4,7 @@ import warnings from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Any, Optional, Sequence +from typing import Any, Optional, Sequence, Set from fastapi import ( APIRouter, @@ -37,7 +37,7 @@ HTTP_409_CONFLICT, ) -from tiled.scopes import NO_SCOPES, PUBLIC_SCOPES, USER_SCOPES +from tiled.access_control.scopes import NO_SCOPES, PUBLIC_SCOPES, USER_SCOPES # To hide third-party warning # .../jose/backends/cryptography_backend.py:18: CryptographyDeprecationWarning: @@ -226,6 +226,44 @@ async def get_session_state(decoded_access_token=Depends(get_decoded_access_toke return decoded_access_token.get("state") +async def get_access_tags_from_api_key( + api_key: str, authenticated: bool, db: Optional[AsyncSession] +) -> Optional[Set[str]]: + if not authenticated: + # Tiled is in a "single user" mode with only one API key. + # In this mode, there is no meaningful access tag limit. + return None + # Tiled is in a multi-user configuration with authentication providers. + # We store the hashed value of the API key secret. + try: + secret = bytes.fromhex(api_key) + except Exception: + # access tag limit cannot be enforced without key information + return None + api_key_orm = await lookup_valid_api_key(db, secret) + if api_key_orm is None: + # access tag limit cannot be enforced without key information + return None + else: + if (access_tags := api_key_orm.access_tags) is not None: + access_tags = set(access_tags) + return access_tags + + +async def get_current_access_tags( + request: Request, + api_key: Optional[str] = Depends(get_api_key), + db: Optional[AsyncSession] = Depends(get_database_session), +) -> Optional[Set[str]]: + if api_key is not None: + return await get_access_tags_from_api_key( + api_key, request.app.state.authenticated, db + ) + else: + # Limits on access tags only available via API key auth + return None + + async def move_api_key(request: Request, api_key: Optional[str] = Depends(get_api_key)): if ("api_key" in request.query_params) and ( request.cookies.get(API_KEY_COOKIE_NAME) != api_key @@ -763,12 +801,22 @@ async def generate_apikey(db: AsyncSession, principal, apikey_params, request): principal_scopes = set().union(*[role.scopes for role in principal.roles]) if not set(scopes).issubset(principal_scopes | {"inherit"}): raise HTTPException( - 400, + 403, ( f"Requested scopes {apikey_params.scopes} must be a subset of the " f"principal's scopes {list(principal_scopes)}." ), ) + admin_scopes = ["admin:apikeys"] + if (access_tags := apikey_params.access_tags) is not None: + if all(scope in scopes for scope in admin_scopes): + raise HTTPException( + 403, + ( + f"Requested scopes {scopes} contain scopes {admin_scopes}, " + f"which cannot be combined with access tag restrictions." + ), + ) if apikey_params.expires_in is not None: expiration_time = utcnow() + timedelta(seconds=apikey_params.expires_in) else: @@ -797,6 +845,7 @@ async def generate_apikey(db: AsyncSession, principal, apikey_params, request): expiration_time=expiration_time, note=apikey_params.note, scopes=scopes, + access_tags=access_tags, first_eight=secret.hex()[:8], hashed_secret=hashed_secret, ) @@ -1106,7 +1155,8 @@ async def new_apikey( db: Optional[AsyncSession] = Depends(get_database_session), ): """ - Generate an API for the currently-authenticated user or service.""" + Generate an API for the currently-authenticated user or service. + """ # TODO Permit filtering the fields of the response. request.state.endpoint = "auth" if principal is None: diff --git a/tiled/server/dependencies.py b/tiled/server/dependencies.py index c3b422944..b0772effa 100644 --- a/tiled/server/dependencies.py +++ b/tiled/server/dependencies.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Set import pydantic_settings from fastapi import HTTPException, Query, Request @@ -22,6 +22,7 @@ async def get_entry( path: str, security_scopes: List[str], principal: Optional[Principal], + authn_access_tags: Optional[Set[str]], authn_scopes: Scopes, root_tree: pydantic_settings.BaseSettings, session_state: dict, @@ -50,6 +51,7 @@ async def get_entry( entry, access_policy, principal, + authn_access_tags, authn_scopes, ["read:metadata"], metrics, @@ -75,6 +77,7 @@ async def get_entry( entry, access_policy, principal, + authn_access_tags, authn_scopes, ["read:metadata"], metrics, @@ -86,6 +89,7 @@ async def get_entry( allowed_scopes = await access_policy.allowed_scopes( entry, principal, + authn_access_tags, authn_scopes, ) if not set(security_scopes).issubset(allowed_scopes): diff --git a/tiled/server/router.py b/tiled/server/router.py index 111b7794d..f76c85156 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -7,7 +7,7 @@ from datetime import datetime, timedelta, timezone from functools import partial from pathlib import Path -from typing import Callable, List, Optional, TypeVar, Union +from typing import Callable, List, Optional, Set, TypeVar, Union import anyio import packaging @@ -44,6 +44,7 @@ from . import schemas from .authentication import ( check_scopes, + get_current_access_tags, get_current_principal, get_current_scopes, get_session_state, @@ -276,6 +277,7 @@ async def search( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), settings: Settings = Depends(get_settings), _=Security(check_scopes, scopes=["read:metadata"]), @@ -285,6 +287,7 @@ async def search( path, ["read:metadata"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -360,6 +363,7 @@ async def distinct( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:metadata"]), **filters, @@ -368,6 +372,7 @@ async def distinct( path, ["read:metadata"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -409,6 +414,7 @@ async def metadata( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), settings: Settings = Depends(get_settings), _=Security(check_scopes, scopes=["read:metadata"]), @@ -418,6 +424,7 @@ async def metadata( path, ["read:metadata"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -471,6 +478,7 @@ async def array_block( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -481,6 +489,7 @@ async def array_block( path, ["read:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -560,6 +569,7 @@ async def array_full( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -570,6 +580,7 @@ async def array_full( path, ["read:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -637,6 +648,7 @@ async def get_table_partition( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -647,6 +659,7 @@ async def get_table_partition( path, ["read:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -699,6 +712,7 @@ async def post_table_partition( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -709,6 +723,7 @@ async def post_table_partition( path, ["read:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -792,6 +807,7 @@ async def get_table_full( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -802,6 +818,7 @@ async def get_table_full( path, ["read:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -833,6 +850,7 @@ async def post_table_full( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -843,6 +861,7 @@ async def post_table_full( path, ["read:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -913,6 +932,7 @@ async def get_container_full( request: Request, path: str, principal: Optional[Principal] = Depends(get_current_principal), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), field: Optional[List[str]] = Query(None, min_length=1), format: Optional[str] = None, @@ -928,6 +948,7 @@ async def get_container_full( path, ["read:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -939,6 +960,7 @@ async def get_container_full( request=request, entry=entry, principal=principal, + authn_access_tags=authn_access_tags, authn_scopes=authn_scopes, field=field, format=format, @@ -954,6 +976,7 @@ async def post_container_full( request: Request, path: str, principal: Optional[Principal] = Depends(get_current_principal), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), field: Optional[List[str]] = Body(None, min_length=1), format: Optional[str] = None, @@ -969,6 +992,7 @@ async def post_container_full( path, ["read:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -980,6 +1004,7 @@ async def post_container_full( request=request, entry=entry, principal=principal, + authn_access_tags=authn_access_tags, authn_scopes=authn_scopes, field=field, format=format, @@ -990,6 +1015,7 @@ async def container_full( request: Request, entry, principal: Optional[Principal], + authn_access_tags: Optional[Set[str]], authn_scopes: Scopes, field: Optional[List[str]], format: Optional[str], @@ -1010,6 +1036,7 @@ async def container_full( filter_for_access, access_policy=request.app.state.access_policy, principal=principal, + authn_access_tags=authn_access_tags, authn_scopes=authn_scopes, scopes=["read:data"], metrics=request.state.metrics, @@ -1042,6 +1069,7 @@ async def node_full( request: Request, path: str, principal: Optional[Principal] = Depends(get_current_principal), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), field: Optional[List[str]] = Query(None, min_length=1), format: Optional[str] = None, @@ -1058,6 +1086,7 @@ async def node_full( path, ["read:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -1089,6 +1118,7 @@ async def node_full( filter_for_access, access_policy=request.app.state.access_policy, principal=principal, + authn_access_tags=authn_access_tags, authn_scopes=authn_scopes, scopes=["read:data"], metrics=request.state.metrics, @@ -1128,6 +1158,7 @@ async def get_awkward_buffers( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -1145,6 +1176,7 @@ async def get_awkward_buffers( path, ["read:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -1176,6 +1208,7 @@ async def post_awkward_buffers( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -1193,6 +1226,7 @@ async def post_awkward_buffers( path, ["read:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -1265,6 +1299,7 @@ async def awkward_full( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -1275,6 +1310,7 @@ async def awkward_full( path, ["read:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -1323,6 +1359,7 @@ async def post_metadata( body: schemas.PostMetadataRequest, settings: Settings = Depends(get_settings), principal: Optional[Principal] = Depends(get_current_principal), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), @@ -1332,6 +1369,7 @@ async def post_metadata( path, ["write:metadata", "create"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -1357,6 +1395,7 @@ async def post_metadata( settings=settings, entry=entry, principal=principal, + authn_access_tags=authn_access_tags, authn_scopes=authn_scopes, ) @@ -1367,6 +1406,7 @@ async def post_register( body: schemas.PostMetadataRequest, settings: Settings = Depends(get_settings), principal: Optional[Principal] = Depends(get_current_principal), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), @@ -1376,6 +1416,7 @@ async def post_register( path, ["write:metadata", "create", "register"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -1390,6 +1431,7 @@ async def post_register( settings=settings, entry=entry, principal=principal, + authn_access_tags=authn_access_tags, authn_scopes=authn_scopes, ) @@ -1400,6 +1442,7 @@ async def _create_node( settings: Settings, entry, principal: Optional[Principal], + authn_access_tags: Optional[Set[str]], authn_scopes: Scopes, ): metadata, structure_family, specs, access_blob = ( @@ -1433,7 +1476,7 @@ async def _create_node( access_blob_modified, access_blob, ) = await request.app.state.access_policy.init_node( - principal, authn_scopes, access_blob=access_blob + principal, authn_access_tags, authn_scopes, access_blob=access_blob ) except ValueError as e: raise HTTPException( @@ -1476,6 +1519,7 @@ async def put_data_source( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:metadata", "register"]), ): @@ -1483,6 +1527,7 @@ async def put_data_source( path, ["write:metadata", "register"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -1509,6 +1554,7 @@ async def delete( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:data", "write:metadata"]), ): @@ -1516,6 +1562,7 @@ async def delete( path, ["write:data", "write:metadata"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -1539,6 +1586,7 @@ async def put_array_full( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:data"]), ): @@ -1546,6 +1594,7 @@ async def put_array_full( path, ["write:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -1581,6 +1630,7 @@ async def put_array_block( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:data"]), ): @@ -1588,6 +1638,7 @@ async def put_array_block( path, ["write:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -1629,6 +1680,7 @@ async def patch_array_full( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:data"]), ): @@ -1636,6 +1688,7 @@ async def patch_array_full( path, ["write:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -1665,6 +1718,7 @@ async def put_node_full( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:data"]), ): @@ -1672,6 +1726,7 @@ async def put_node_full( path, ["write:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -1701,6 +1756,7 @@ async def put_table_partition( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:data"]), ): @@ -1708,6 +1764,7 @@ async def put_table_partition( path, ["write:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -1737,6 +1794,7 @@ async def patch_table_partition( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:data"]), ): @@ -1744,6 +1802,7 @@ async def patch_table_partition( path, ["write:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -1772,6 +1831,7 @@ async def put_awkward_full( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:data"]), ): @@ -1779,6 +1839,7 @@ async def put_awkward_full( path, ["write:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -1810,6 +1871,7 @@ async def patch_metadata( body: schemas.PatchMetadataRequest, settings: Settings = Depends(get_settings), principal: Optional[Principal] = Depends(get_current_principal), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), drop_revision: bool = False, root_tree=Depends(get_root_tree), @@ -1820,6 +1882,7 @@ async def patch_metadata( path, ["write:metadata"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -1886,7 +1949,7 @@ async def patch_metadata( access_blob_modified, access_blob, ) = await request.app.state.access_policy.modify_node( - entry, principal, authn_scopes, access_blob + entry, principal, authn_access_tags, authn_scopes, access_blob ) except ValueError as e: raise HTTPException( @@ -1919,6 +1982,7 @@ async def put_metadata( body: schemas.PutMetadataRequest, settings: Settings = Depends(get_settings), principal: Optional[Principal] = Depends(get_current_principal), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), drop_revision: bool = False, root_tree=Depends(get_root_tree), @@ -1929,6 +1993,7 @@ async def put_metadata( path, ["write:metadata"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -1963,7 +2028,7 @@ async def put_metadata( access_blob_modified, access_blob, ) = await request.app.state.access_policy.modify_node( - entry, principal, authn_scopes, access_blob + entry, principal, authn_access_tags, authn_scopes, access_blob ) except ValueError as e: raise HTTPException( @@ -2000,6 +2065,7 @@ async def get_revisions( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:metadata"]), ): @@ -2007,6 +2073,7 @@ async def get_revisions( path, ["read:metadata"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -2040,6 +2107,7 @@ async def delete_revision( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["write:metadata"]), ): @@ -2047,6 +2115,7 @@ async def delete_revision( path, ["write:metadata"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -2078,6 +2147,7 @@ async def get_asset( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -2085,6 +2155,7 @@ async def get_asset( path, ["read:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -2190,6 +2261,7 @@ async def get_asset_manifest( principal: Optional[Principal] = Depends(get_current_principal), root_tree=Depends(get_root_tree), session_state: dict = Depends(get_session_state), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), _=Security(check_scopes, scopes=["read:data"]), ): @@ -2197,6 +2269,7 @@ async def get_asset_manifest( path, ["read:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, diff --git a/tiled/server/schemas.py b/tiled/server/schemas.py index 136a93c87..e2c6e196c 100644 --- a/tiled/server/schemas.py +++ b/tiled/server/schemas.py @@ -308,6 +308,7 @@ class APIKey(pydantic.BaseModel): expiration_time: Optional[datetime] = None note: Optional[Annotated[str, StringConstraints(max_length=255)]] = None scopes: List[str] + access_tags: Optional[List[str]] = None latest_activity: Optional[datetime] = None @classmethod @@ -317,6 +318,7 @@ def from_orm(cls, orm: tiled.authn_database.orm.APIKey) -> APIKey: expiration_time=orm.expiration_time, note=orm.note, scopes=orm.scopes, + access_tags=orm.access_tags, latest_activity=orm.latest_activity, ) @@ -333,6 +335,7 @@ def from_orm( expiration_time=orm.expiration_time, note=orm.note, scopes=orm.scopes, + access_tags=orm.access_tags, latest_activity=orm.latest_activity, secret=secret, ) @@ -402,6 +405,9 @@ class APIKeyRequestParams(pydantic.BaseModel): expires_in: Optional[int] = pydantic.Field( ..., json_schema_extra={"example": 600} ) # seconds + access_tags: Optional[List[str]] = pydantic.Field( + default=None, json_schema_extra={"example": ["writing_tag", "public"]} + ) scopes: Optional[List[str]] = pydantic.Field( ..., json_schema_extra={"example": ["inherit"]} ) diff --git a/tiled/server/utils.py b/tiled/server/utils.py index 730088e87..536616e24 100644 --- a/tiled/server/utils.py +++ b/tiled/server/utils.py @@ -6,7 +6,7 @@ from fastapi import Request from starlette.types import Scope -from ..access_policies import NO_ACCESS +from ..access_control.access_policies import NO_ACCESS from ..adapters.mapping import MapAdapter EMPTY_NODE = MapAdapter({}) @@ -75,12 +75,23 @@ def get_root_url_low_level(request_headers: Mapping[str, str], scope: Scope) -> async def filter_for_access( - entry, access_policy, principal, authn_scopes, scopes, metrics + entry, access_policy, principal, authn_access_tags, authn_scopes, scopes, metrics ): if access_policy is not None and hasattr(entry, "search"): with record_timing(metrics, "acl"): + if hasattr(entry, "lookup_adapter") and entry.node.parent is None: + # This conditional only catches for the MapAdapter->CatalogAdapter + # transition, to cover MapAdapter's lack of access control. + # It can be removed once MapAdapter goes away. + if not set(scopes).issubset( + await access_policy.allowed_scopes( + entry, principal, authn_access_tags, authn_scopes + ) + ): + return (entry := EMPTY_NODE) + queries = await access_policy.filters( - entry, principal, authn_scopes, set(scopes) + entry, principal, authn_access_tags, authn_scopes, set(scopes) ) if queries is NO_ACCESS: entry = EMPTY_NODE diff --git a/tiled/server/zarr.py b/tiled/server/zarr.py index cb2a439fa..914af7415 100644 --- a/tiled/server/zarr.py +++ b/tiled/server/zarr.py @@ -1,6 +1,6 @@ import json import re -from typing import Tuple, Union +from typing import Optional, Set, Tuple, Union import numcodecs import orjson @@ -12,7 +12,12 @@ from ..structures.core import StructureFamily from ..type_aliases import Scopes from ..utils import ensure_awaitable -from .authentication import get_current_principal, get_current_scopes, get_session_state +from .authentication import ( + get_current_access_tags, + get_current_principal, + get_current_scopes, + get_session_state, +) from .dependencies import get_entry, get_root_tree from .schemas import Principal from .utils import record_timing @@ -51,6 +56,7 @@ async def get_zarr_attrs( request: Request, path: str, principal: Union[Principal] = Depends(get_current_principal), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), session_state: dict = Depends(get_session_state), @@ -60,6 +66,7 @@ async def get_zarr_attrs( path, ["read:data", "read:metadata"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -85,6 +92,7 @@ async def get_zarr_group_metadata( request: Request, path: str, principal: Union[Principal] = Depends(get_current_principal), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), session_state: dict = Depends(get_session_state), @@ -93,6 +101,7 @@ async def get_zarr_group_metadata( path, ["read:data", "read:metadata"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -111,6 +120,7 @@ async def get_zarr_array_metadata( request: Request, path: str, principal: Union[Principal] = Depends(get_current_principal), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), session_state: dict = Depends(get_session_state), @@ -119,6 +129,7 @@ async def get_zarr_array_metadata( path, ["read:data", "read:metadata"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -153,6 +164,7 @@ async def get_zarr_array( request: Request, path: str, principal: Union[Principal] = Depends(get_current_principal), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), session_state: dict = Depends(get_session_state), @@ -169,6 +181,7 @@ async def get_zarr_array( path, ["read:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -270,6 +283,7 @@ async def get_zarr_metadata( request: Request, path: str, principal: Union[Principal] = Depends(get_current_principal), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), session_state: dict = Depends(get_session_state), @@ -280,6 +294,7 @@ async def get_zarr_metadata( path, ["read:data", "read:metadata"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -360,6 +375,7 @@ async def get_zarr_array( path: str, block: str, principal: Union[Principal] = Depends(get_current_principal), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), session_state: dict = Depends(get_session_state), @@ -368,6 +384,7 @@ async def get_zarr_array( path, ["read:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -439,6 +456,7 @@ async def get_zarr_group( request: Request, path: str, principal: Union[Principal] = Depends(get_current_principal), + authn_access_tags: Optional[Set[str]] = Depends(get_current_access_tags), authn_scopes: Scopes = Depends(get_current_scopes), root_tree: pydantic_settings.BaseSettings = Depends(get_root_tree), session_state: dict = Depends(get_session_state), @@ -447,6 +465,7 @@ async def get_zarr_group( path, ["read:data"], principal, + authn_access_tags, authn_scopes, root_tree, session_state, @@ -483,6 +502,7 @@ async def get_zarr_group( request, path, principal=principal, + authn_access_tags=authn_access_tags, authn_scopes=authn_scopes, root_tree=root_tree, session_state=session_state, diff --git a/tiled/utils.py b/tiled/utils.py index a38eaeb93..297efb442 100644 --- a/tiled/utils.py +++ b/tiled/utils.py @@ -29,6 +29,7 @@ from urllib.parse import urlparse, urlunparse import anyio +import yaml # helper for avoiding re-typing patch mimetypes # namedtuple for the lack of StrEnum in py<3.11 @@ -593,11 +594,13 @@ class UnsupportedQueryType(TypeError): class Conflicts(Exception): "Prompts the server to send 409 Conflicts with message" + pass class BrokenLink(Exception): "Prompts the server to send 410 Gone with message" + pass @@ -883,3 +886,15 @@ def parse_mimetype(mimetype: str) -> tuple[str, dict]: ) params[key] = value return base, params + + +class InterningLoader(yaml.loader.BaseLoader): + pass + + +def interning_constructor(loader, node): + value = loader.construct_scalar(node) + return sys.intern(value) + + +InterningLoader.add_constructor("tag:yaml.org,2002:str", interning_constructor)