11import secrets
22from dataclasses import dataclass , field
33from datetime import datetime , timedelta
4- from functools import cached_property
5- from pathlib import Path
64
75import httpx
86from authlib .integrations .starlette_client import OAuth
97from itsdangerous import URLSafeSerializer
108from serieux import deserialize , serialize
119from serieux .features .encrypt import Secret
12- from serieux .features .filebacked import DefaultFactory , FileBacked
1310from starlette .exceptions import HTTPException
1411from starlette .middleware .sessions import SessionMiddleware
1512from starlette .requests import Request
1613from starlette .responses import JSONResponse , PlainTextResponse , RedirectResponse
1714
18- from .cap import Capability , CapabilitySet
15+ from .cap import CapabilitySet
1916from .structs import OpenIDConfiguration , Payload , UserInfo
2017
2118
@@ -27,8 +24,7 @@ class OAuthManager:
2724 client_id : Secret [str ] = None
2825 client_secret : Secret [str ] = None
2926 enable : bool = True
30- capability_file : Path = None
31- capset : CapabilitySet = field (default_factory = lambda : CapabilitySet ({}))
27+ capabilities : CapabilitySet = field (default_factory = lambda : CapabilitySet ({}))
3228
3329 # [serieux: ignore]
3430 server_metadata : OpenIDConfiguration = None
@@ -39,23 +35,12 @@ class OAuthManager:
3935 def __post_init__ (self ):
4036 self .server_metadata = deserialize (OpenIDConfiguration , self .server_metadata_url )
4137 self .secrets_serializer = URLSafeSerializer (self .secret_key )
42- self .user_management_capability = self .capset .registry .registry .get (
38+ self .user_management_capability = self .capabilities .registry .registry .get (
4339 "user_management" , None
4440 )
4541
46- @cached_property
47- def capability_db (self ):
48- return deserialize (
49- FileBacked [dict [str , set [self .capset .captype ]] @ DefaultFactory (dict )],
50- self .capability_file ,
51- )
52-
53- def has_capability (self , email , cap ):
54- cd = self .capability_db .value
55- return cap in Capability (implies = cd .get (email , set ()))
56-
5742 def ensure_user_manager (self , email ):
58- if self .user_management_capability is None or not self .has_capability (
43+ if self .user_management_capability is None or not self .capabilities . check (
5944 email , self .user_management_capability
6045 ):
6146 raise HTTPException (
@@ -94,7 +79,7 @@ async def ensure_email(self, request: Request):
9479
9580 def get_email_capability (self , cap = None , redirect = False ):
9681 if isinstance (cap , str ):
97- cap = deserialize (self .capset .captype , cap )
82+ cap = deserialize (self .capabilities .captype , cap )
9883
9984 async def get (request : Request ):
10085 if redirect :
@@ -103,7 +88,7 @@ async def get(request: Request):
10388 email = await self .get_email (request )
10489 if email is None :
10590 raise HTTPException (status_code = 401 , detail = "Authentication required" )
106- elif cap is None or self .has_capability (email , cap ):
91+ elif cap is None or self .capabilities . check (email , cap ):
10792 yield email
10893 else :
10994 raise HTTPException (status_code = 403 , detail = f"{ cap } capability required" )
@@ -145,6 +130,10 @@ async def assimilate_payload(self, request):
145130 request .session ["access_token" ] = payload .access_token
146131 request .session ["refresh_token" ] = payload .refresh_token
147132
133+ ##########
134+ # Routes #
135+ ##########
136+
148137 async def route_login (self , request ):
149138 red = request .session .get ("redirect_after_login" , "/" )
150139 request .session .clear ()
@@ -183,13 +172,19 @@ async def route_logout(self, request):
183172 request .session .clear ()
184173 return RedirectResponse (url = "/" )
185174
175+ ##########################
176+ # User management routes #
177+ ##########################
178+
186179 def _manage_cap_response (self , email ):
187- db = self .capability_db
180+ db = self .capabilities . db
188181 return JSONResponse (
189182 {
190183 "status" : "ok" ,
191184 "email" : email ,
192- "capabilities" : serialize (set [self .capset .captype ], db .value .get (email , set ())),
185+ "capabilities" : serialize (
186+ set [self .capabilities .captype ], db .value .get (email , set ())
187+ ),
193188 }
194189 )
195190
@@ -199,7 +194,7 @@ async def _manage_generic(self, request, reqcls):
199194
200195 req = deserialize (reqcls , await request .json ())
201196
202- db = self .capability_db
197+ db = self .capabilities . db
203198 req .apply (db .value )
204199 db .save ()
205200
@@ -209,7 +204,7 @@ async def route_manage_capabilities_add(self, request):
209204 @dataclass
210205 class AddRequest :
211206 email : str
212- capability : self .capset .captype
207+ capability : self .capabilities .captype
213208
214209 def apply (self , caps ):
215210 caps .setdefault (self .email , set ()).add (self .capability )
@@ -220,7 +215,7 @@ async def route_manage_capabilities_remove(self, request):
220215 @dataclass
221216 class RemoveRequest :
222217 email : str
223- capability : self .capset .captype
218+ capability : self .capabilities .captype
224219
225220 def apply (self , caps ):
226221 caps .setdefault (self .email , set ()).discard (self .capability )
@@ -231,7 +226,7 @@ async def route_manage_capabilities_set(self, request):
231226 @dataclass
232227 class SetRequest :
233228 email : str
234- capabilities : set [self .capset .captype ]
229+ capabilities : set [self .capabilities .captype ]
235230
236231 def apply (self , caps ):
237232 caps [self .email ] = self .capabilities
@@ -252,6 +247,10 @@ class ListRequest:
252247
253248 return self ._manage_cap_response (req .email )
254249
250+ ##################
251+ # Install to app #
252+ ##################
253+
255254 def install (self , app ):
256255 if not self .enable : # pragma: no cover
257256 return
0 commit comments