Skip to content

Commit 920ceab

Browse files
committed
Refactoring and type hints
1 parent cd69a67 commit 920ceab

File tree

8 files changed

+90
-81
lines changed

8 files changed

+90
-81
lines changed

example_configs/external_service/custom.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,36 @@
11
import numpy
2+
from pydantic import Secret
23

34
from tiled.adapters.array import ArrayAdapter
4-
from tiled.authenticators import Mode, UserSessionState
55
from tiled.structures.core import StructureFamily
66

77

8-
class Authenticator:
9-
"This accepts any password and stashes it in session state as 'token'."
10-
mode = Mode.password
11-
12-
async def authenticate(self, username: str, password: str) -> UserSessionState:
13-
return UserSessionState(username, {"token": password})
14-
15-
16-
# This stands in for a secret token issued by the external service.
17-
SERVICE_ISSUED_TOKEN = "secret"
18-
19-
208
class MockClient:
21-
def __init__(self, base_url):
22-
self.base_url = base_url
9+
10+
def __init__(self, base_url: str, example_token: str = "secret"):
11+
self._base_url = base_url
12+
# This stands in for a secret token issued by the external service.
13+
self._example_token = Secret(example_token)
2314

2415
# This API (get_contents, get_metadata, get_data) is just made up and not important.
2516
# Could be anything.
2617

2718
async def get_metadata(self, url, token):
2819
# This assert stands in for the mocked service
2920
# authenticating a request.
30-
assert token == SERVICE_ISSUED_TOKEN
21+
assert token == self._example_token.get_secret_value()
3122
return {"metadata": str(url)}
3223

3324
async def get_contents(self, url, token):
3425
# This assert stands in for the mocked service
3526
# authenticating a request.
36-
assert token == SERVICE_ISSUED_TOKEN
27+
assert token == self._example_token.get_secret_value()
3728
return ["a", "b", "c"]
3829

3930
async def get_data(self, url, token):
4031
# This assert stands in for the mocked service
4132
# authenticating a request.
42-
assert token == SERVICE_ISSUED_TOKEN
33+
assert token == self._example_token.get_secret_value()
4334
return numpy.ones((3, 3))
4435

4536

example_configs/mock-oidc-server.yml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ authentication:
99
client_secret: secret
1010
well_known_uri: http://localhost:8080/.well-known/openid-configuration
1111
trees:
12-
# Just some arbitrary example data...
13-
# The point of this example is the authenticaiton above.
14-
- tree: tiled.examples.generated_minimal:tree
15-
path: /
12+
- path: /
13+
tree: catalog
14+
args:
15+
uri: "sqlite+aiosqlite:///:memory:"
16+
writable_storage: "/tmp/data"
17+
init_if_not_exists: true

tiled/authenticators.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,43 +9,37 @@
99

1010
import httpx
1111
from fastapi import APIRouter, Request
12-
from jose import JWTError, jwk, jwt
12+
from jose import JWTError, jwt
1313
from pydantic import Secret
1414
from starlette.responses import RedirectResponse
1515

16-
from .server.authentication import Mode
17-
from .server.protocols import UserSessionState
16+
from .server.protocols import ExternalAuthenticator, UserSessionState, PasswordAuthenticator
1817
from .server.utils import get_root_url
1918
from .utils import modules_available
2019

2120
logger = logging.getLogger(__name__)
2221

2322

24-
class DummyAuthenticator:
23+
class DummyAuthenticator(PasswordAuthenticator):
2524
"""
2625
For test and demo purposes only!
2726
2827
Accept any username and any password.
2928
3029
"""
31-
32-
mode = Mode.password
33-
3430
def __init__(self, confirmation_message=""):
3531
self.confirmation_message = confirmation_message
3632

3733
async def authenticate(self, username: str, password: str) -> UserSessionState:
3834
return UserSessionState(username, {})
3935

4036

41-
class DictionaryAuthenticator:
37+
class DictionaryAuthenticator(PasswordAuthenticator):
4238
"""
4339
For test and demo purposes only!
4440
4541
Check passwords from a dictionary of usernames mapped to passwords.
4642
"""
47-
48-
mode = Mode.password
4943
configuration_schema = """
5044
$schema": http://json-schema.org/draft-07/schema#
5145
type: object
@@ -74,8 +68,7 @@ async def authenticate(self, username: str, password: str) -> UserSessionState:
7468
return UserSessionState(username, {})
7569

7670

77-
class PAMAuthenticator:
78-
mode = Mode.password
71+
class PAMAuthenticator(PasswordAuthenticator):
7972
configuration_schema = """
8073
$schema": http://json-schema.org/draft-07/schema#
8174
type: object
@@ -110,8 +103,7 @@ async def authenticate(self, username: str, password: str) -> UserSessionState:
110103
return UserSessionState(username, {})
111104

112105

113-
class OIDCAuthenticator:
114-
mode = Mode.external
106+
class OIDCAuthenticator(ExternalAuthenticator):
115107
configuration_schema = """
116108
$schema": http://json-schema.org/draft-07/schema#
117109
type: object
@@ -164,7 +156,7 @@ def jwks_uri(self) -> str:
164156
def token_endpoint(self) -> str:
165157
return cast(str, self._config_from_oidc_url.get("token_endpoint"))
166158

167-
async def authenticate(self, request: Request) -> UserSessionState:
159+
async def authenticate(self, request: Request) -> UserSessionState | None:
168160
code = request.query_params["code"]
169161
# A proxy in the middle may make the request into something like
170162
# 'http://localhost:8000/...' so we fix the first part but keep
@@ -228,8 +220,7 @@ async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect
228220
return response
229221

230222

231-
class SAMLAuthenticator:
232-
mode = Mode.external
223+
class SAMLAuthenticator(ExternalAuthenticator):
233224

234225
def __init__(
235226
self,
@@ -271,7 +262,7 @@ async def saml_login(request: Request):
271262

272263
self.include_routers = [router]
273264

274-
async def authenticate(self, request) -> UserSessionState:
265+
async def authenticate(self, request) -> UserSessionState | None:
275266
if not modules_available("onelogin"):
276267
raise ModuleNotFoundError(
277268
"This SAMLAuthenticator requires the module 'oneline' to be installed."
@@ -323,7 +314,7 @@ async def prepare_saml_from_fastapi_request(request, debug=False):
323314
return rv
324315

325316

326-
class LDAPAuthenticator:
317+
class LDAPAuthenticator(PasswordAuthenticator):
327318
"""
328319
The authenticator code is based on https://github.com/jupyterhub/ldapauthenticator
329320
The parameter ``use_tls`` was added for convenience of testing.
@@ -506,8 +497,6 @@ class LDAPAuthenticator:
506497
id: user02
507498
"""
508499

509-
mode = Mode.password
510-
511500
def __init__(
512501
self,
513502
server_address,

tiled/server/app.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from contextlib import asynccontextmanager
1111
from functools import lru_cache, partial
1212
from pathlib import Path
13-
from typing import List
13+
from typing import Any, Dict, List
1414

1515
import anyio
1616
import packaging.version
@@ -34,7 +34,8 @@
3434
HTTP_500_INTERNAL_SERVER_ERROR,
3535
)
3636

37-
from ..authenticators import Mode
37+
from tiled.server.protocols import ExternalAuthenticator, PasswordAuthenticator
38+
3839
from ..config import construct_build_app_kwargs
3940
from ..media_type_registration import (
4041
compression_registry as default_compression_registry,
@@ -81,7 +82,7 @@
8182
current_principal = contextvars.ContextVar("current_principal")
8283

8384

84-
def custom_openapi(app: FastAPI):
85+
def custom_openapi(app: FastAPI) -> Dict[str, Any]:
8586
"""
8687
The app's openapi method will be monkey-patched with this.
8788
@@ -118,7 +119,7 @@ def build_app(
118119
validation_registry=None,
119120
tasks=None,
120121
scalable=False,
121-
):
122+
) -> FastAPI:
122123
"""
123124
Serve a Tree
124125
@@ -385,12 +386,11 @@ async def unhandled_exception_handler(
385386
for spec in authentication["providers"]:
386387
provider = spec["provider"]
387388
authenticator = spec["authenticator"]
388-
mode = authenticator.mode
389-
if mode == Mode.password:
389+
if isinstance(authenticator, PasswordAuthenticator):
390390
authentication_router.post(f"/provider/{provider}/token")(
391391
build_handle_credentials_route(authenticator, provider)
392392
)
393-
elif mode == Mode.external:
393+
elif isinstance(authenticator, ExternalAuthenticator):
394394
# Client starts here to create a PendingSession.
395395
authentication_router.post(f"/provider/{provider}/authorize")(
396396
build_device_code_authorize_route(authenticator, provider)
@@ -415,7 +415,7 @@ async def unhandled_exception_handler(
415415
# build_auth_code_route(authenticator, provider)
416416
# )
417417
else:
418-
raise ValueError(f"unknown authentication mode {mode}")
418+
raise ValueError(f"Unexpected authenticator type {type(authenticator)}")
419419
for custom_router in getattr(authenticator, "include_routers", []):
420420
authentication_router.include_router(
421421
custom_router, prefix=f"/provider/{provider}"

0 commit comments

Comments
 (0)