Skip to content

Commit b1c3c4b

Browse files
authored
Merge pull request #221 from Point72/tkp/scope
Allow scoping authentication middlewares, allow multiple middlewares
2 parents 6917075 + 740deb7 commit b1c3c4b

File tree

7 files changed

+373
-47
lines changed

7 files changed

+373
-47
lines changed

.vscode/settings.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
{
2-
"eslint.workingDirectories": ["./js"]
2+
"eslint.workingDirectories": [
3+
"./js"
4+
],
5+
"python-envs.defaultEnvManager": "ms-python.python:system",
6+
"python-envs.pythonProjects": []
37
}

csp_gateway/server/gateway/csp/module.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
import abc
2-
import typing
1+
from abc import ABC, abstractmethod
32
from datetime import datetime
4-
from typing import Any, Dict, Generic, List, Optional, Set, Type, Union
3+
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, Type, Union
54

65
from ccflow import BaseModel
76
from pydantic import Field, TypeAdapter, model_validator
@@ -11,11 +10,11 @@
1110

1211
from .channels import ChannelsType
1312

14-
if typing.TYPE_CHECKING:
13+
if TYPE_CHECKING:
1514
from csp_gateway.server import GatewaySettings, GatewayWebApp
1615

1716

18-
class Module(BaseModel, Generic[ChannelsType], abc.ABC):
17+
class Module(BaseModel, Generic[ChannelsType], ABC):
1918
model_config = {"arbitrary_types_allowed": True}
2019

2120
requires: Optional[ChannelSelection] = None
@@ -28,14 +27,14 @@ class Module(BaseModel, Generic[ChannelsType], abc.ABC):
2827
""",
2928
)
3029

31-
@abc.abstractmethod
30+
@abstractmethod
3231
def connect(self, Channels: ChannelsType) -> None: ...
3332

3433
def rest(self, app: "GatewayWebApp") -> None: ...
3534

3635
def info(self, settings: "GatewaySettings") -> Optional[str]: ...
3736

38-
@abc.abstractmethod
37+
@abstractmethod
3938
def shutdown(self) -> None: ...
4039

4140
def dynamic_keys(self) -> Optional[Dict[str, List[Any]]]: ...

csp_gateway/server/middleware/api_key.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,18 @@ def info(self, settings: GatewaySettings) -> str:
3737
return f"\tUI: {url}?token={self.api_key}"
3838
return f"\tAPI: {url}/openapi.json?token={self.api_key}"
3939

40-
def rest(self, app: GatewayWebApp) -> None:
41-
# reinitialize header
42-
api_key_query = APIKeyQuery(name=self.api_key_name, auto_error=False)
43-
api_key_header = APIKeyHeader(name=self.api_key_name, auto_error=False)
44-
api_key_cookie = APIKeyCookie(name=self.api_key_name, auto_error=False)
45-
46-
# routers
47-
auth_router: APIRouter = app.get_router("auth")
48-
49-
# now mount middleware
50-
async def get_api_key(
51-
api_key_query: str = Security(api_key_query),
52-
api_key_header: str = Security(api_key_header),
53-
api_key_cookie: str = Security(api_key_cookie),
54-
):
55-
# Support both single string and list of valid API keys
40+
def validate(self):
41+
"""Return a FastAPI dependency function for API key validation."""
42+
api_key_query_security = Security(APIKeyQuery(name=self.api_key_name, auto_error=False))
43+
api_key_header_security = Security(APIKeyHeader(name=self.api_key_name, auto_error=False))
44+
api_key_cookie_security = Security(APIKeyCookie(name=self.api_key_name, auto_error=False))
45+
46+
async def validate_credentials(
47+
api_key_query: str = api_key_query_security,
48+
api_key_header: str = api_key_header_security,
49+
api_key_cookie: str = api_key_cookie_security,
50+
) -> str:
51+
"""Validate API key from query, header, or cookie."""
5652
valid_keys = self.api_key if isinstance(self.api_key, list) else [self.api_key]
5753
for provided_key in (api_key_query, api_key_header, api_key_cookie):
5854
if provided_key in valid_keys:
@@ -62,8 +58,15 @@ async def get_api_key(
6258
detail=self.unauthorized_status_message,
6359
)
6460

61+
return validate_credentials
62+
63+
def rest(self, app: GatewayWebApp) -> None:
64+
# routers
65+
auth_router: APIRouter = app.get_router("auth")
66+
check = self.get_check_dependency()
67+
6568
@auth_router.get("/login")
66-
async def route_login_and_add_cookie(api_key: str = Depends(get_api_key)):
69+
async def route_login_and_add_cookie(api_key: str = Depends(check)):
6770
response = RedirectResponse(url="/")
6871
response.set_cookie(
6972
self.api_key_name,
@@ -81,10 +84,10 @@ async def route_logout_and_remove_cookie():
8184
response.delete_cookie(self.api_key_name, domain=self.domain)
8285
return response
8386

84-
self._setup_public_routes(app, get_api_key)
87+
self._setup_public_routes(app)
8588

86-
def _setup_public_routes(self, app: GatewayWebApp, get_api_key) -> None:
87-
"""Setup public routes, middleware, and exception handler. Shared by subclasses."""
89+
def _setup_public_routes(self, app: GatewayWebApp) -> None:
90+
"""Setup public routes, middleware, and exception handler. Shared by subclasses.""" ""
8891
public_router: APIRouter = app.get_router("public")
8992

9093
@public_router.get("/login", response_class=HTMLResponse, include_in_schema=False)
@@ -102,7 +105,7 @@ async def get_logout_page(request: Request = None):
102105
return app.templates.TemplateResponse("logout.html.j2", {"request": request})
103106

104107
# add auth to all other routes
105-
app.add_middleware(Depends(get_api_key))
108+
app.add_middleware(Depends(self.get_check_dependency()))
106109

107110
@app.app.exception_handler(403)
108111
async def custom_403_handler(request: Request = None, *args):

csp_gateway/server/middleware/api_key_external.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,24 +40,24 @@ def _invoke_external(self, api_key: str, settings: GatewaySettings, module=None)
4040
return None
4141
return self.external_validator.object(api_key, settings, module)
4242

43-
def _get_api_key_dependency(self, app: GatewayWebApp):
44-
"""Returns the get_api_key dependency for external validation."""
45-
api_key_query = APIKeyQuery(name=self.api_key_name, auto_error=False)
46-
api_key_header = APIKeyHeader(name=self.api_key_name, auto_error=False)
47-
api_key_cookie = APIKeyCookie(name=self.api_key_name, auto_error=False)
48-
49-
async def get_api_key(
50-
api_key_query: str = Security(api_key_query),
51-
api_key_header: str = Security(api_key_header),
52-
api_key_cookie: str = Security(api_key_cookie),
53-
):
43+
def validate(self):
44+
"""Return a FastAPI dependency function for external API key validation."""
45+
api_key_query_security = Security(APIKeyQuery(name=self.api_key_name, auto_error=False))
46+
api_key_header_security = Security(APIKeyHeader(name=self.api_key_name, auto_error=False))
47+
api_key_cookie_security = Security(APIKeyCookie(name=self.api_key_name, auto_error=False))
48+
49+
async def validate_credentials(
50+
api_key_query: str = api_key_query_security,
51+
api_key_header: str = api_key_header_security,
52+
api_key_cookie: str = api_key_cookie_security,
53+
) -> str:
54+
"""Validate API key using external validator and return a session UUID."""
5455
try:
5556
for provided_key in (api_key_query, api_key_header, api_key_cookie):
56-
identity = self._invoke_external(provided_key, app.settings, app)
57+
identity = self._invoke_external(provided_key, self._app_settings, self._app_module)
5758
if identity and isinstance(identity, dict):
5859
user_uuid = str(uuid4())
5960
while user_uuid in self._identity_store:
60-
# Should never happen, but just in case of a uuid collision, generate a new one
6161
user_uuid = str(uuid4())
6262
self._identity_store[user_uuid] = identity
6363
return user_uuid
@@ -66,20 +66,23 @@ async def get_api_key(
6666
status_code=HTTP_403_FORBIDDEN,
6767
detail=self.unauthorized_status_message,
6868
) from e
69-
# No valid key found
7069
raise HTTPException(
7170
status_code=HTTP_403_FORBIDDEN,
7271
detail=self.unauthorized_status_message,
7372
)
7473

75-
return get_api_key
74+
return validate_credentials
7675

7776
def rest(self, app: GatewayWebApp) -> None:
77+
# Store app references for use in check()
78+
self._app_settings = app.settings
79+
self._app_module = app
80+
7881
auth_router: APIRouter = app.get_router("auth")
79-
get_api_key = self._get_api_key_dependency(app)
82+
check = self.get_check_dependency()
8083

8184
@auth_router.get("/login")
82-
async def route_login_and_add_cookie(api_key: str = Depends(get_api_key)):
85+
async def route_login_and_add_cookie(api_key: str = Depends(check)):
8386
response = RedirectResponse(url="/")
8487
if api_key in self._identity_store:
8588
response.set_cookie(
@@ -102,4 +105,4 @@ async def route_logout_and_remove_cookie(request: Request = None):
102105
return response
103106

104107
# Call parent to set up public routes, middleware, and exception handler
105-
self._setup_public_routes(app, get_api_key)
108+
self._setup_public_routes(app)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,61 @@
1+
from fnmatch import fnmatch
2+
from typing import Callable, List, Optional, Union
3+
4+
from ccflow import PyObjectPath
5+
from fastapi import Request
6+
from starlette.middleware.base import RequestResponseEndpoint
7+
from starlette.responses import Response
8+
19
from csp_gateway.server import GatewayChannels, GatewayModule
210

311
__all__ = ("AuthenticationMiddleware",)
412

513

614
class AuthenticationMiddleware(GatewayModule):
15+
scope: Optional[Union[str, List[str]]] = "*"
16+
check: Optional[Union[PyObjectPath, Callable]] = None
17+
18+
def get_check_callable(self) -> Optional[Callable]:
19+
"""Return the check callable from PyObjectPath or direct callable."""
20+
if self.check is None:
21+
return None
22+
return self.check if callable(self.check) else self.check.object
23+
24+
def _matches_scope(self, path: str) -> bool:
25+
"""Check if path matches any of the scope glob patterns."""
26+
if self.scope is None:
27+
return True
28+
patterns = self.scope if isinstance(self.scope, list) else [self.scope]
29+
return any(fnmatch(path, pattern) for pattern in patterns)
30+
31+
def validate(self) -> Callable:
32+
"""Return a FastAPI dependency function for credential validation.
33+
34+
Subclasses must implement this method. The returned function should:
35+
- Accept credentials (extracted via Security dependencies)
36+
- Return a validated identity/token on success
37+
- Raise HTTPException on failure
38+
39+
Note: Scope checking via _matches_scope() is available but not automatically
40+
applied in validate() due to WebSocket route compatibility constraints.
41+
"""
42+
raise NotImplementedError("Subclasses must implement validate()")
43+
44+
def _skip_if_out_of_scope(self, request: Request) -> bool:
45+
"""Check if request is out of scope. Returns True if should skip auth."""
46+
return not self._matches_scope(request.url.path)
47+
48+
def get_check_dependency(self) -> Callable:
49+
"""Return the validate() dependency. Scope checking is handled in validate()."""
50+
return self.validate()
51+
52+
async def check_scope(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
53+
"""Check that request path is valid in the scope/s. Returns True if in scope."""
54+
if self._matches_scope(request.url.path):
55+
return await call_next(request)
56+
# Path not in scope, skip authentication middleware
57+
return await call_next(request)
58+
759
def connect(self, channels: GatewayChannels) -> None:
860
# NO-OP
961
...

0 commit comments

Comments
 (0)