10
10
from contextlib import asynccontextmanager
11
11
from functools import cache , partial
12
12
from pathlib import Path
13
- from typing import Optional
13
+ from typing import Optional , Union
14
14
15
15
import anyio
16
16
import packaging .version
17
17
import yaml
18
18
from asgi_correlation_id import CorrelationIdMiddleware , correlation_id
19
- from fastapi import APIRouter , Depends , FastAPI , HTTPException , Request , Response
19
+ from fastapi import Depends , FastAPI , HTTPException , Request , Response
20
20
from fastapi .exception_handlers import http_exception_handler
21
21
from fastapi .middleware .cors import CORSMiddleware
22
22
from fastapi .openapi .utils import get_openapi
52
52
from .dependencies import get_root_tree
53
53
from .router import get_router
54
54
from .settings import Settings , get_settings
55
- from .utils import (
56
- API_KEY_COOKIE_NAME ,
57
- CSRF_COOKIE_NAME ,
58
- get_authenticators ,
59
- get_root_url ,
60
- record_timing ,
61
- )
55
+ from .utils import API_KEY_COOKIE_NAME , CSRF_COOKIE_NAME , get_root_url , record_timing
62
56
63
57
SAFE_METHODS = {"GET" , "HEAD" , "OPTIONS" , "TRACE" }
64
58
SENSITIVE_COOKIES = {
@@ -134,7 +128,7 @@ def build_app(
134
128
Dict of other server configuration.
135
129
"""
136
130
authentication = authentication or {}
137
- authenticators = {
131
+ authenticators : dict [ str , Union [ ExternalAuthenticator , InternalAuthenticator ]] = {
138
132
spec ["provider" ]: spec ["authenticator" ]
139
133
for spec in authentication .get ("providers" , [])
140
134
}
@@ -354,6 +348,7 @@ async def unhandled_exception_handler(
354
348
serialization_registry ,
355
349
deserialization_registry ,
356
350
validation_registry ,
351
+ authenticators ,
357
352
)
358
353
app .include_router (router , prefix = "/api/v1" )
359
354
@@ -368,13 +363,9 @@ async def unhandled_exception_handler(
368
363
# Delay this imports to avoid delaying startup with the SQL and cryptography
369
364
# imports if they are not needed.
370
365
from .authentication import (
371
- base_authentication_router ,
372
- build_auth_code_route ,
373
- build_device_code_authorize_route ,
374
- build_device_code_token_route ,
375
- build_device_code_user_code_form_route ,
376
- build_device_code_user_code_submit_route ,
377
- build_handle_credentials_route ,
366
+ add_external_routes ,
367
+ add_internal_routes ,
368
+ authentication_router ,
378
369
oauth2_scheme ,
379
370
)
380
371
@@ -385,41 +376,17 @@ async def unhandled_exception_handler(
385
376
)
386
377
# Authenticators provide Router(s) for their particular flow.
387
378
# Collect them in the authentication_router.
388
- authentication_router = APIRouter ()
379
+ authentication_router = authentication_router ()
389
380
# This adds the universal routes like /session/refresh and /session/revoke.
390
381
# Below we will add routes specific to our authentication providers.
391
- authentication_router . include_router ( base_authentication_router )
382
+
392
383
for spec in authentication ["providers" ]:
393
384
provider = spec ["provider" ]
394
385
authenticator = spec ["authenticator" ]
395
386
if isinstance (authenticator , InternalAuthenticator ):
396
- authentication_router .post (f"/provider/{ provider } /token" )(
397
- build_handle_credentials_route (authenticator , provider )
398
- )
387
+ add_internal_routes (authentication_router , provider , authenticator )
399
388
elif isinstance (authenticator , ExternalAuthenticator ):
400
- # Client starts here to create a PendingSession.
401
- authentication_router .post (f"/provider/{ provider } /authorize" )(
402
- build_device_code_authorize_route (authenticator , provider )
403
- )
404
- # External OAuth redirects here with code, presenting form for user code.
405
- authentication_router .get (f"/provider/{ provider } /device_code" )(
406
- build_device_code_user_code_form_route (authenticator , provider )
407
- )
408
- # User code and auth code are submitted here.
409
- authentication_router .post (f"/provider/{ provider } /device_code" )(
410
- build_device_code_user_code_submit_route (authenticator , provider )
411
- )
412
- # Client polls here for token.
413
- authentication_router .post (f"/provider/{ provider } /token" )(
414
- build_device_code_token_route (authenticator , provider )
415
- )
416
- # Normal code flow end point for web UIs
417
- authentication_router .get (f"/provider/{ provider } /code" )(
418
- build_auth_code_route (authenticator , provider )
419
- )
420
- # authentication_router.post(f"/provider/{provider}/code")(
421
- # build_auth_code_route(authenticator, provider)
422
- # )
389
+ add_external_routes (authentication_router , provider , authenticator )
423
390
else :
424
391
raise ValueError (f"unknown authenticator type { type (authenticator )} " )
425
392
for custom_router in getattr (authenticator , "include_routers" , []):
@@ -432,10 +399,6 @@ async def unhandled_exception_handler(
432
399
else :
433
400
app .state .authenticated = False
434
401
435
- @cache
436
- def override_get_authenticators ():
437
- return authenticators
438
-
439
402
@cache
440
403
def override_get_root_tree ():
441
404
return tree
@@ -761,7 +724,6 @@ async def set_cookies(request: Request, call_next):
761
724
return response
762
725
763
726
app .openapi = partial (custom_openapi , app )
764
- app .dependency_overrides [get_authenticators ] = override_get_authenticators
765
727
app .dependency_overrides [get_root_tree ] = override_get_root_tree
766
728
app .dependency_overrides [get_settings ] = override_get_settings
767
729
0 commit comments