1919@dataclass (kw_only = True )
2020class OAuthManager :
2121 server_metadata_url : str
22- client_kwargs : dict [str , str ]
22+ client_kwargs : dict [str , str ] = field ( default_factory = dict )
2323 secret_key : Secret [str ] = field (default_factory = lambda : secrets .token_urlsafe (32 ))
2424 client_id : Secret [str ] = None
2525 client_secret : Secret [str ] = None
26- enable : bool = True
26+ force_user : UserInfo = None
2727 capabilities : CapabilitySet = field (default_factory = lambda : CapabilitySet ({}))
2828
2929 # [serieux: ignore]
@@ -33,8 +33,9 @@ class OAuthManager:
3333 token_cache : dict = field (default_factory = dict )
3434
3535 def __post_init__ (self ):
36- self .server_metadata = deserialize (OpenIDConfiguration , self .server_metadata_url )
37- self .secrets_serializer = URLSafeSerializer (self .secret_key )
36+ if not self .force_user :
37+ self .server_metadata = deserialize (OpenIDConfiguration , self .server_metadata_url )
38+ self .secrets_serializer = URLSafeSerializer (self .secret_key )
3839 self .user_management_capability = self .capabilities .registry .registry .get (
3940 "user_management" , None
4041 )
@@ -53,6 +54,8 @@ def ensure_user_manager(self, email):
5354 )
5455
5556 async def get_user (self , request : Request ):
57+ if self .force_user :
58+ return serialize (UserInfo , self .force_user )
5659 if auth := request .headers .get ("Authorization" ):
5760 match auth .split ("Bearer " ):
5861 case ("" , rtoken ):
@@ -143,6 +146,10 @@ async def route_login(self, request):
143146 red = request .session .get ("redirect_after_login" , "/" )
144147 request .session .clear ()
145148 request .session ["redirect_after_login" ] = red
149+ if self .force_user : # pragma: no cover
150+ # Pages won't redirect to /login when force_user is True,
151+ # so this won't happen unless the user directly goes to /login
152+ return RedirectResponse (url = red )
146153 auth_route = request .query_params .get ("redirect" , "auth" )
147154 redirect_uri = request .url_for (auth_route )
148155 params = {}
@@ -155,11 +162,14 @@ async def route_login(self, request):
155162 )
156163
157164 async def route_auth (self , request ):
158- await self .assimilate_payload (request )
165+ if not self .force_user :
166+ await self .assimilate_payload (request )
159167 red = request .session .get ("redirect_after_login" , "/" )
160168 return RedirectResponse (url = red )
161169
162170 async def route_token (self , request ):
171+ if self .force_user :
172+ return JSONResponse ({"refresh_token" : "XXX" })
163173 if state := request .query_params .get ("state" ):
164174 await self .assimilate_payload (request )
165175
@@ -257,9 +267,6 @@ class ListRequest:
257267 ##################
258268
259269 def install (self , app ):
260- if not self .enable : # pragma: no cover
261- return
262-
263270 app .add_middleware (
264271 SessionMiddleware ,
265272 secret_key = self .secret_key ,
0 commit comments