Skip to content

Commit d86134f

Browse files
committed
Remove enable, add force_user
1 parent 4640526 commit d86134f

File tree

6 files changed

+105
-14
lines changed

6 files changed

+105
-14
lines changed

src/easy_oauth/manager.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919
@dataclass(kw_only=True)
2020
class OAuthManager:
2121
server_metadata_url: str
22-
client_kwargs: dict[str, str]
22+
client_kwargs: dict[str, str] = field(default_factory=dict)
2323
secret_key: Secret[str] = field(default_factory=lambda: secrets.token_urlsafe(32))
2424
client_id: Secret[str] = None
2525
client_secret: Secret[str] = None
26-
enable: bool = True
26+
force_user: UserInfo = None
2727
capabilities: CapabilitySet = field(default_factory=lambda: CapabilitySet({}))
2828

2929
# [serieux: ignore]
@@ -33,8 +33,9 @@ class OAuthManager:
3333
token_cache: dict = field(default_factory=dict)
3434

3535
def __post_init__(self):
36-
self.server_metadata = deserialize(OpenIDConfiguration, self.server_metadata_url)
37-
self.secrets_serializer = URLSafeSerializer(self.secret_key)
36+
if not self.force_user:
37+
self.server_metadata = deserialize(OpenIDConfiguration, self.server_metadata_url)
38+
self.secrets_serializer = URLSafeSerializer(self.secret_key)
3839
self.user_management_capability = self.capabilities.registry.registry.get(
3940
"user_management", None
4041
)
@@ -53,6 +54,8 @@ def ensure_user_manager(self, email):
5354
)
5455

5556
async def get_user(self, request: Request):
57+
if self.force_user:
58+
return serialize(UserInfo, self.force_user)
5659
if auth := request.headers.get("Authorization"):
5760
match auth.split("Bearer "):
5861
case ("", rtoken):
@@ -143,6 +146,10 @@ async def route_login(self, request):
143146
red = request.session.get("redirect_after_login", "/")
144147
request.session.clear()
145148
request.session["redirect_after_login"] = red
149+
if self.force_user: # pragma: no cover
150+
# Pages won't redirect to /login when force_user is True,
151+
# so this won't happen unless the user directly goes to /login
152+
return RedirectResponse(url=red)
146153
auth_route = request.query_params.get("redirect", "auth")
147154
redirect_uri = request.url_for(auth_route)
148155
params = {}
@@ -155,11 +162,14 @@ async def route_login(self, request):
155162
)
156163

157164
async def route_auth(self, request):
158-
await self.assimilate_payload(request)
165+
if not self.force_user:
166+
await self.assimilate_payload(request)
159167
red = request.session.get("redirect_after_login", "/")
160168
return RedirectResponse(url=red)
161169

162170
async def route_token(self, request):
171+
if self.force_user:
172+
return JSONResponse({"refresh_token": "XXX"})
163173
if state := request.query_params.get("state"):
164174
await self.assimilate_payload(request)
165175

@@ -257,9 +267,6 @@ class ListRequest:
257267
##################
258268

259269
def install(self, app):
260-
if not self.enable: # pragma: no cover
261-
return
262-
263270
app.add_middleware(
264271
SessionMiddleware,
265272
secret_key=self.secret_key,

src/easy_oauth/structs.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@ class UserInfo(Base):
4242
email: str
4343

4444
# The user's unique ID
45-
sub: str
45+
sub: str = None
46+
47+
def __post_init__(self):
48+
if self.sub is None:
49+
self.sub = self.email
4650

4751
@classmethod
4852
def serieux_from_string(cls, idtoken):

tests/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
here = Path(__file__).parent
1212

1313

14-
def make_app(tmpdir: Path = None):
14+
def make_app(config_path, tmpdir: Path = None):
1515
app = FastAPI()
1616

17-
oauth = deserialize(OAuthManager, Path(here / "appconfig.yaml"))
17+
oauth = deserialize(OAuthManager, config_path)
1818

1919
if tmpdir is not None:
2020
dest_cap_file = Path(tmpdir) / oauth.capabilities.user_file.name

tests/conftest.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22

33
import threading
44
import time
5+
from contextlib import contextmanager
56
from dataclasses import dataclass
67
from itertools import count
8+
from pathlib import Path
79
from random import randint
810

911
import httpx
1012
import pytest
1113
import uvicorn
14+
from serieux import Sources
1215

1316
from .app import make_app
1417
from .oauth_mock import PORT as OAUTH_PORT
@@ -17,6 +20,9 @@
1720
_port = count(OAUTH_PORT + randint(1, 1000))
1821

1922

23+
here = Path(__file__).parent
24+
25+
2026
class ServerThread(threading.Thread):
2127
# Generated by Claude
2228

@@ -45,6 +51,7 @@ def stop(self):
4551
self.server.should_exit = True
4652

4753

54+
@contextmanager
4855
def create_endpoint(app, host, port):
4956
# Mostly generated by Claude
5057

@@ -84,7 +91,8 @@ def oauth_endpoint():
8491
Yields:
8592
str: The base URL of the mock OAuth server (e.g., "http://127.0.0.1:29313")
8693
"""
87-
yield from create_endpoint(oauth_app, "127.0.0.1", OAUTH_PORT)
94+
with create_endpoint(oauth_app, "127.0.0.1", OAUTH_PORT) as endpoint:
95+
yield endpoint
8896

8997

9098
@pytest.fixture
@@ -141,7 +149,8 @@ def post(self, endpoint, expect=None, **data):
141149
@pytest.fixture(scope="session")
142150
def app():
143151
port = next(_port)
144-
yield from create_endpoint(make_app(), "127.0.0.1", port)
152+
with create_endpoint(make_app(Path(here / "appconfig.yaml")), "127.0.0.1", port) as endpoint:
153+
yield endpoint
145154

146155

147156
@pytest.fixture
@@ -156,7 +165,10 @@ def make_interactor(email):
156165
@pytest.fixture
157166
def app_write(tmpdir):
158167
port = next(_port)
159-
yield from create_endpoint(make_app(tmpdir), "127.0.0.1", port)
168+
with create_endpoint(
169+
make_app(Path(here / "appconfig.yaml"), tmpdir), "127.0.0.1", port
170+
) as endpoint:
171+
yield endpoint
160172

161173

162174
@pytest.fixture
@@ -166,3 +178,15 @@ def make_interactor(email):
166178
return TokenInteractor.make(app_write, email)
167179

168180
yield make_interactor
181+
182+
183+
@pytest.fixture
184+
def app_force_user(tmpdir):
185+
@contextmanager
186+
def make(email):
187+
port = next(_port)
188+
sources = Sources(Path(here / "noauthconfig.yaml"), {"force_user": {"email": email}})
189+
with create_endpoint(make_app(sources, tmpdir), "127.0.0.1", port) as endpoint:
190+
yield endpoint
191+
192+
yield make

tests/noauthconfig.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
server_metadata_url: "n/a"
2+
capabilities:
3+
auto_admin: true
4+
user_file: caps.yaml
5+
graph:
6+
user_management: []
7+
villager: []
8+
mafia:
9+
- villager
10+
police:
11+
- villager
12+
mayor:
13+
- villager
14+
- police
15+
baker:
16+
- villager

tests/test_manager.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,43 @@ def test_set_capability(user_write, tmpdir):
209209

210210
new_caps = deserialize(dict[str, set[str]], Path(tmpdir / "caps.yaml"))
211211
assert new_caps[u.email] == {"baker"}
212+
213+
214+
def test_force_admin(app_force_user):
215+
with app_force_user("[email protected]") as app:
216+
resp = httpx.get(f"{app}/hello")
217+
assert resp.status_code == 200
218+
assert resp.text == "Hello, [email protected]!"
219+
220+
resp = httpx.get(f"{app}/murder", params={"target": "Bart"})
221+
assert resp.status_code == 200
222+
assert resp.text == "Bart was murdered by [email protected]"
223+
224+
resp = httpx.get(f"{app}/bake", params={"food": "chocolate cake"})
225+
assert resp.status_code == 200
226+
assert resp.text == "chocolate cake was baked by [email protected]"
227+
228+
229+
def test_force_cap(app_force_user):
230+
with app_force_user("[email protected]") as app:
231+
resp = httpx.get(f"{app}/hello")
232+
assert resp.status_code == 200
233+
assert resp.text == "Hello, [email protected]!"
234+
235+
resp = httpx.get(f"{app}/murder", params={"target": "Lisa"})
236+
assert resp.status_code == 200
237+
assert resp.text == "Lisa was murdered by [email protected]"
238+
239+
resp = httpx.get(f"{app}/bake", params={"food": "baguette"})
240+
assert resp.status_code == 403 # boss does not have baker capability
241+
242+
243+
def test_force_user_token(app_force_user):
244+
# Make sure the token flow is still valid
245+
with app_force_user("[email protected]") as app:
246+
response = httpx.get(f"{app}/token", follow_redirects=True)
247+
token = response.json()["refresh_token"]
248+
assert token == "XXX"
249+
response = httpx.get(f"{app}/hello", headers={"Authorization": f"Bearer {token}"})
250+
assert response.text == "Hello, [email protected]!"
251+
assert response.status_code == 200

0 commit comments

Comments
 (0)