Skip to content

Commit f5f572e

Browse files
committed
Reorganize CapabilitySet
1 parent b570afe commit f5f572e

File tree

5 files changed

+78
-63
lines changed

5 files changed

+78
-63
lines changed

README.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ oauth = OAuthManager(
3232
"prompt": "select_account",
3333
}
3434
# Set of capabilities that can be assigned to users
35-
capset=CapabilitySet(
36-
capabilities={
35+
capabilities=CapabilitySet(
36+
graph={
3737
# Basic capability
3838
"read": [],
3939
# write, also implies read
@@ -48,9 +48,9 @@ oauth = OAuthManager(
4848
},
4949
# Create the "admin" capability that has every other capability
5050
auto_admin=True,
51+
# File where each user's capability is stored
52+
user_file="caps.yaml",
5153
),
52-
# File where each user's capability is stored
53-
capability_file="caps.yaml",
5454
)
5555

5656
app = FastAPI()
@@ -113,15 +113,15 @@ client_secret: "<CLIENT_SECRET>"
113113
client_kwargs:
114114
scope: openid email
115115
prompt: select_account
116-
capset:
117-
capabilities:
116+
capabilities:
117+
graph:
118118
read: []
119119
write: [read]
120120
moderate: [read, write]
121121
announce: [read, write]
122122
user_management: []
123123
auto_admin: true
124-
capability_file: caps.yaml
124+
user_file: caps.yaml
125125
```
126126

127127
And instantiated like this:

src/easy_oauth/cap.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
from __future__ import annotations
2-
31
from dataclasses import dataclass, field
2+
from functools import cached_property
3+
from pathlib import Path
44

5+
from serieux import deserialize
6+
from serieux.features.filebacked import DefaultFactory, FileBacked
57
from serieux.features.registered import Registry
68

79

810
@dataclass(eq=False)
911
class Capability:
1012
name: str = None
11-
implies: set[Capability] = field(default_factory=set)
13+
implies: set["Capability"] = field(default_factory=set)
1214

1315
def __contains__(self, cap):
1416
return cap is self or any(cap in cap2 for cap2 in self.implies)
@@ -20,17 +22,33 @@ def __str__(self):
2022

2123

2224
class CapabilitySet:
23-
def __init__(self, capabilities: dict[str, list[str]], auto_admin: bool = True):
25+
def __init__(
26+
self,
27+
graph: dict[str, list[str]],
28+
auto_admin: bool = True,
29+
user_file: Path = None,
30+
):
2431
self.registry = Registry()
25-
for name in capabilities:
32+
for name in graph:
2633
self.registry.register(name, Capability(name))
27-
for name, implies in capabilities.items():
34+
for name, implies in graph.items():
2835
self[name].implies.update(self[n] for n in implies)
2936
if auto_admin:
3037
self.registry.register(
3138
"admin", Capability("admin", set(self.registry.registry.values()))
3239
)
3340
self.captype = Capability @ self.registry
41+
self.user_file = user_file
3442

3543
def __getitem__(self, item):
3644
return self.registry.registry[item]
45+
46+
@cached_property
47+
def db(self):
48+
return deserialize(
49+
FileBacked[dict[str, set[self.captype]] @ DefaultFactory(dict)],
50+
self.user_file,
51+
)
52+
53+
def check(self, email, cap):
54+
return cap in Capability(implies=self.db.value.get(email, set()))

src/easy_oauth/manager.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,18 @@
11
import secrets
22
from dataclasses import dataclass, field
33
from datetime import datetime, timedelta
4-
from functools import cached_property
5-
from pathlib import Path
64

75
import httpx
86
from authlib.integrations.starlette_client import OAuth
97
from itsdangerous import URLSafeSerializer
108
from serieux import deserialize, serialize
119
from serieux.features.encrypt import Secret
12-
from serieux.features.filebacked import DefaultFactory, FileBacked
1310
from starlette.exceptions import HTTPException
1411
from starlette.middleware.sessions import SessionMiddleware
1512
from starlette.requests import Request
1613
from starlette.responses import JSONResponse, PlainTextResponse, RedirectResponse
1714

18-
from .cap import Capability, CapabilitySet
15+
from .cap import CapabilitySet
1916
from .structs import OpenIDConfiguration, Payload, UserInfo
2017

2118

@@ -27,8 +24,7 @@ class OAuthManager:
2724
client_id: Secret[str] = None
2825
client_secret: Secret[str] = None
2926
enable: bool = True
30-
capability_file: Path = None
31-
capset: CapabilitySet = field(default_factory=lambda: CapabilitySet({}))
27+
capabilities: CapabilitySet = field(default_factory=lambda: CapabilitySet({}))
3228

3329
# [serieux: ignore]
3430
server_metadata: OpenIDConfiguration = None
@@ -39,23 +35,12 @@ class OAuthManager:
3935
def __post_init__(self):
4036
self.server_metadata = deserialize(OpenIDConfiguration, self.server_metadata_url)
4137
self.secrets_serializer = URLSafeSerializer(self.secret_key)
42-
self.user_management_capability = self.capset.registry.registry.get(
38+
self.user_management_capability = self.capabilities.registry.registry.get(
4339
"user_management", None
4440
)
4541

46-
@cached_property
47-
def capability_db(self):
48-
return deserialize(
49-
FileBacked[dict[str, set[self.capset.captype]] @ DefaultFactory(dict)],
50-
self.capability_file,
51-
)
52-
53-
def has_capability(self, email, cap):
54-
cd = self.capability_db.value
55-
return cap in Capability(implies=cd.get(email, set()))
56-
5742
def ensure_user_manager(self, email):
58-
if self.user_management_capability is None or not self.has_capability(
43+
if self.user_management_capability is None or not self.capabilities.check(
5944
email, self.user_management_capability
6045
):
6146
raise HTTPException(
@@ -94,7 +79,7 @@ async def ensure_email(self, request: Request):
9479

9580
def get_email_capability(self, cap=None, redirect=False):
9681
if isinstance(cap, str):
97-
cap = deserialize(self.capset.captype, cap)
82+
cap = deserialize(self.capabilities.captype, cap)
9883

9984
async def get(request: Request):
10085
if redirect:
@@ -103,7 +88,7 @@ async def get(request: Request):
10388
email = await self.get_email(request)
10489
if email is None:
10590
raise HTTPException(status_code=401, detail="Authentication required")
106-
elif cap is None or self.has_capability(email, cap):
91+
elif cap is None or self.capabilities.check(email, cap):
10792
yield email
10893
else:
10994
raise HTTPException(status_code=403, detail=f"{cap} capability required")
@@ -145,6 +130,10 @@ async def assimilate_payload(self, request):
145130
request.session["access_token"] = payload.access_token
146131
request.session["refresh_token"] = payload.refresh_token
147132

133+
##########
134+
# Routes #
135+
##########
136+
148137
async def route_login(self, request):
149138
red = request.session.get("redirect_after_login", "/")
150139
request.session.clear()
@@ -183,13 +172,19 @@ async def route_logout(self, request):
183172
request.session.clear()
184173
return RedirectResponse(url="/")
185174

175+
##########################
176+
# User management routes #
177+
##########################
178+
186179
def _manage_cap_response(self, email):
187-
db = self.capability_db
180+
db = self.capabilities.db
188181
return JSONResponse(
189182
{
190183
"status": "ok",
191184
"email": email,
192-
"capabilities": serialize(set[self.capset.captype], db.value.get(email, set())),
185+
"capabilities": serialize(
186+
set[self.capabilities.captype], db.value.get(email, set())
187+
),
193188
}
194189
)
195190

@@ -199,7 +194,7 @@ async def _manage_generic(self, request, reqcls):
199194

200195
req = deserialize(reqcls, await request.json())
201196

202-
db = self.capability_db
197+
db = self.capabilities.db
203198
req.apply(db.value)
204199
db.save()
205200

@@ -209,7 +204,7 @@ async def route_manage_capabilities_add(self, request):
209204
@dataclass
210205
class AddRequest:
211206
email: str
212-
capability: self.capset.captype
207+
capability: self.capabilities.captype
213208

214209
def apply(self, caps):
215210
caps.setdefault(self.email, set()).add(self.capability)
@@ -220,7 +215,7 @@ async def route_manage_capabilities_remove(self, request):
220215
@dataclass
221216
class RemoveRequest:
222217
email: str
223-
capability: self.capset.captype
218+
capability: self.capabilities.captype
224219

225220
def apply(self, caps):
226221
caps.setdefault(self.email, set()).discard(self.capability)
@@ -231,7 +226,7 @@ async def route_manage_capabilities_set(self, request):
231226
@dataclass
232227
class SetRequest:
233228
email: str
234-
capabilities: set[self.capset.captype]
229+
capabilities: set[self.capabilities.captype]
235230

236231
def apply(self, caps):
237232
caps[self.email] = self.capabilities
@@ -252,6 +247,10 @@ class ListRequest:
252247

253248
return self._manage_cap_response(req.email)
254249

250+
##################
251+
# Install to app #
252+
##################
253+
255254
def install(self, app):
256255
if not self.enable: # pragma: no cover
257256
return

tests/app.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from fastapi import Depends, FastAPI
55
from fastapi.responses import JSONResponse, PlainTextResponse
6-
from serieux import Sources, deserialize
6+
from serieux import deserialize
77
from starlette.requests import Request
88

99
from easy_oauth.manager import OAuthManager
@@ -14,28 +14,12 @@
1414
def make_app(tmpdir: Path = None):
1515
app = FastAPI()
1616

17-
capgraph = {
18-
"user_management": [],
19-
"villager": [],
20-
"mafia": ["villager"],
21-
"police": ["villager"],
22-
"mayor": ["villager", "police"],
23-
"baker": ["villager"],
24-
}
25-
26-
oauth = deserialize(
27-
OAuthManager,
28-
Sources(
29-
Path(here / "appconfig.yaml"),
30-
{
31-
"capset": {"capabilities": capgraph},
32-
},
33-
),
34-
)
17+
oauth = deserialize(OAuthManager, Path(here / "appconfig.yaml"))
18+
3519
if tmpdir is not None:
36-
dest_cap_file = Path(tmpdir) / oauth.capability_file.name
37-
shutil.copy(oauth.capability_file, dest_cap_file)
38-
oauth.capability_file = dest_cap_file
20+
dest_cap_file = Path(tmpdir) / oauth.capabilities.user_file.name
21+
shutil.copy(oauth.capabilities.user_file, dest_cap_file)
22+
oauth.capabilities.user_file = dest_cap_file
3923

4024
oauth.install(app)
4125

tests/appconfig.yaml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,18 @@ client_secret: "dHPf4JQhK9VutnlJlcln"
55
client_kwargs:
66
scope: openid email
77
prompt: select_account
8-
capability_file: caps.yaml
8+
capabilities:
9+
auto_admin: true
10+
user_file: caps.yaml
11+
graph:
12+
user_management: []
13+
villager: []
14+
mafia:
15+
- villager
16+
police:
17+
- villager
18+
mayor:
19+
- villager
20+
- police
21+
baker:
22+
- villager

0 commit comments

Comments
 (0)