Skip to content

Commit a08872b

Browse files
authored
Merge pull request #113 from nebulabroadcast/sso-support
SSO
2 parents ad740c8 + ce1d80f commit a08872b

File tree

19 files changed

+563
-46
lines changed

19 files changed

+563
-46
lines changed

backend/api/auth/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
1-
__all__ = ["LoginRequest", "LogoutRequest", "SetPasswordRequest"]
1+
__all__ = [
2+
"LoginRequest",
3+
"LogoutRequest",
4+
"SetPasswordRequest",
5+
"SSOLoginRequest",
6+
"SSOLoginCallback",
7+
"TokenExchangeRequest",
8+
]
29

310
from .login_request import LoginRequest
411
from .logout_request import LogoutRequest
512
from .set_password_request import SetPasswordRequest
13+
from .sso import SSOLoginCallback, SSOLoginRequest
14+
from .token_exchange import TokenExchangeRequest

backend/api/auth/login_request.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,14 @@
11
import time
22

33
from fastapi import Request
4-
from pydantic import Field
54

65
import nebula
76
from server.clientinfo import get_real_ip
8-
from server.models import RequestModel, ResponseModel
7+
from server.models.login import LoginRequestModel, LoginResponseModel
98
from server.request import APIRequest
109
from server.session import Session
1110

1211

13-
class LoginRequestModel(RequestModel):
14-
username: str = Field(
15-
...,
16-
title="Username",
17-
examples=["admin"],
18-
pattern=r"^[a-zA-Z0-9_\-\.]{2,}$",
19-
)
20-
password: str = Field(
21-
...,
22-
title="Password",
23-
description="Password in plain text",
24-
examples=["Password.123"],
25-
)
26-
27-
28-
class LoginResponseModel(ResponseModel):
29-
access_token: str = Field(
30-
...,
31-
title="Access token",
32-
description="Access token to be used in Authorization header"
33-
"for the subsequent requests",
34-
)
35-
36-
3712
async def check_failed_login(ip_address: str) -> None:
3813
banned_until = await nebula.redis.get("banned-ip-until", ip_address)
3914
if banned_until is None:

backend/api/auth/sso.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from urllib.parse import urlparse
2+
3+
from authlib.integrations.starlette_client import OAuthError
4+
from fastapi import Request
5+
from fastapi.responses import RedirectResponse
6+
7+
import nebula
8+
from server.request import APIRequest
9+
from server.session import Session
10+
from server.sso import NebulaSSO
11+
12+
13+
class SSOLoginRequest(APIRequest):
14+
name = "sso_login"
15+
path = "/api/sso/login/{provider}"
16+
methods = ["GET"]
17+
18+
async def handle(self, request: Request, provider: str) -> RedirectResponse:
19+
client = NebulaSSO.client(provider)
20+
21+
referer = request.headers.get("referer")
22+
if referer:
23+
parsed_url = urlparse(referer)
24+
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
25+
else:
26+
base_url = "http://localhost:4455"
27+
28+
# We cannot use request.url_for here because it screws the frontend
29+
# dev server proxy
30+
31+
redirect_uri = f"{base_url}/api/sso/callback/{provider}"
32+
nebula.log.debug(f"Redirect URI: {redirect_uri}")
33+
return await client.authorize_redirect(request, redirect_uri)
34+
35+
36+
class SSOLoginCallback(APIRequest):
37+
name = "sso_callback"
38+
path = "/api/sso/callback/{provider}"
39+
methods = ["GET"]
40+
41+
async def handle(self, request: Request, provider: str) -> RedirectResponse:
42+
client = NebulaSSO.client(provider)
43+
44+
try:
45+
token = await client.authorize_access_token(request)
46+
except OAuthError as error:
47+
return RedirectResponse(f"/?error={error}")
48+
user = token.get("userinfo", {})
49+
email = user.get("email")
50+
51+
if not email:
52+
return RedirectResponse("/?error=User email not found")
53+
54+
try:
55+
user = await nebula.User.by_email(email)
56+
except nebula.NotFoundException:
57+
return RedirectResponse("/?error=User not found")
58+
session = await Session.create(user, request, transient=True)
59+
return RedirectResponse(f"/?authorize={session.token}")

backend/api/auth/token_exchange.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from fastapi import Request
2+
3+
import nebula
4+
from server.models.login import LoginResponseModel, TokenExchangeRequestModel
5+
from server.request import APIRequest
6+
from server.session import Session
7+
8+
9+
class TokenExchangeRequest(APIRequest):
10+
"""Exachange a transient access token for a normal one
11+
12+
This request will exchange an access token for a new one.
13+
The original access token will be invalidated.
14+
"""
15+
16+
name: str = "token-exchange"
17+
response_model = LoginResponseModel
18+
19+
async def handle(
20+
self,
21+
request: Request,
22+
payload: TokenExchangeRequestModel,
23+
) -> LoginResponseModel:
24+
session = await Session.check(payload.access_token, request, transient=True)
25+
if not session:
26+
raise nebula.UnauthorizedException("Invalid token")
27+
user_id = session.user["id"]
28+
user = await nebula.User.load(user_id)
29+
session = await Session.create(user, request)
30+
nebula.log.debug(f"{user} token exchanged")
31+
await Session.delete(payload.access_token)
32+
return LoginResponseModel(access_token=session.token)

backend/api/init/init_request.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Annotated, Any, get_args
1+
from typing import Annotated, get_args
22

33
import fastapi
44
from pydantic import Field
@@ -11,6 +11,7 @@
1111
from server.dependencies import CurrentUserOptional
1212
from server.models import ResponseModel, UserModel
1313
from server.request import APIRequest
14+
from server.sso import NebulaSSO, SSOOption
1415

1516
from .client_settings import ClientSettingsModel, get_client_settings
1617

@@ -63,10 +64,10 @@ class InitResponseModel(ResponseModel):
6364
),
6465
] = None
6566

66-
oauth2_options: Annotated[
67-
list[dict[str, Any]] | None,
67+
sso_options: Annotated[
68+
list[SSOOption] | None,
6869
Field(
69-
title="OAuth2 options",
70+
title="SSO options",
7071
),
7172
] = None
7273

@@ -104,10 +105,11 @@ async def handle(
104105
return InitResponseModel(installed=False)
105106

106107
# Not logged in. Only return motd and oauth2 options.
107-
# TODO: return oauth2 options
108108
if user is None:
109+
sso_options = await NebulaSSO.options() or None
109110
return InitResponseModel(
110111
motd=motd,
112+
sso_options=sso_options,
111113
experimental=nebula.config.enable_experimental or None,
112114
)
113115

backend/nebula/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ class NebulaConfig(BaseModel):
4141
description="Password hashing method",
4242
)
4343

44+
session_secret: str = Field(
45+
default_factory=lambda: os.urandom(32).hex(),
46+
description="Session secret. MUST be set when the server is scaled",
47+
)
48+
4449
max_failed_login_attempts: int = Field(
4550
10,
4651
description="Maximum number of failed login attempts before the IP is banned",

backend/nebula/objects/user.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,19 @@ async def by_api_key(cls, api_key: str) -> "User":
8888
raise NotFoundException(f"User with API key {api_key} not found")
8989
return cls.from_row(row[0])
9090

91+
@classmethod
92+
async def by_email(cls, email: str) -> "User":
93+
"""Return the user with the given email."""
94+
row = await db.fetch(
95+
"""
96+
SELECT meta FROM users WHERE LOWER(meta->>'email') = $1
97+
""",
98+
email.lower(),
99+
)
100+
if not row:
101+
raise NotFoundException(f"User with email {email} not found")
102+
return cls.from_row(row[0])
103+
91104
@classmethod
92105
async def login(cls, username: str, password: str) -> "User":
93106
"""Return a User instance based on username and password."""

backend/nebula/settings/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,14 @@ class BaseSystemSettings(SettingsModel):
9393
)
9494

9595

96+
class SSOProvider(SettingsModel):
97+
name: str
98+
title: str | None = None
99+
entrypoint: str
100+
client_id: str
101+
client_secret: str
102+
103+
96104
class SystemSettings(BaseSystemSettings):
97105
"""System settings.
98106
@@ -107,6 +115,7 @@ class SystemSettings(BaseSystemSettings):
107115
upload_storage: int | None = Field(default=None)
108116
upload_dir: str | None = Field(default=None)
109117
upload_base_name: str = Field(default="{id}")
118+
sso_providers: list[SSOProvider] = Field(default_factory=list)
110119

111120
smtp_host: str | None = Field(
112121
default=None,

backend/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@ requires-python = ">=3.11,<3.13"
77
dependencies = [
88
"aiofiles >=24.1.0",
99
"asyncpg >=0.29.0",
10+
"authlib>=1.5.1",
1011
"email-validator >=2.1.1",
1112
"fastapi >=0.115.0",
1213
"geoip2 >=4.8.0",
1314
"gunicorn >=22.0.0",
1415
"httpx >=0.27.2",
16+
"itsdangerous>=2.2.0",
1517
"mistune >=3.0.1",
1618
"nxtools >=1.6",
1719
"pydantic >=2.9.2",
@@ -33,6 +35,7 @@ dev = [
3335
"pytest-asyncio >=0.20.3",
3436
"ruff >=0.6.8",
3537
"types-aiofiles >=23.2.0.20240311",
38+
"types-authlib>=1.4.0.20241230",
3639
"types-requests >=2.31.0.20240311",
3740
]
3841

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import json
2+
import typing
3+
from base64 import b64decode, b64encode
4+
5+
import itsdangerous
6+
from itsdangerous.exc import BadSignature
7+
from starlette.datastructures import MutableHeaders
8+
from starlette.requests import HTTPConnection
9+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
10+
11+
import nebula
12+
from nebula.common import create_hash
13+
14+
15+
async def get_session_key() -> str:
16+
key_candidate = create_hash()
17+
18+
query = """
19+
INSERT INTO settings (key, value)
20+
VALUES ('.session_key', $1)
21+
ON CONFLICT (key) DO UPDATE
22+
SET value = settings.value
23+
RETURNING value
24+
"""
25+
26+
res = await nebula.db.fetchrow(query, key_candidate)
27+
assert res, "Failed to retrieve session key. This shouldn't happen"
28+
29+
if res["value"] == key_candidate:
30+
nebula.log.info("Created new session key")
31+
32+
return res["value"]
33+
34+
35+
class SessionMiddleware:
36+
"""Custom session middleware for Nebula.
37+
38+
The main difference with the default Starlette session middleware is that
39+
it loads the session key from the database so it can be shared between
40+
multiple replicas of the application.
41+
"""
42+
43+
_signer: itsdangerous.Signer | None = None
44+
45+
def __init__(
46+
self,
47+
app: ASGIApp,
48+
session_cookie: str = "session",
49+
max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds
50+
path: str = "/",
51+
same_site: typing.Literal["lax", "strict", "none"] = "lax",
52+
https_only: bool = False,
53+
domain: str | None = None,
54+
) -> None:
55+
self.app = app
56+
self.session_cookie = session_cookie
57+
self.max_age = max_age
58+
self.path = path
59+
self.security_flags = "httponly; samesite=" + same_site
60+
if https_only: # Secure flag can be used with HTTPS only
61+
self.security_flags += "; secure"
62+
if domain is not None:
63+
self.security_flags += f"; domain={domain}"
64+
65+
async def get_signer(self) -> itsdangerous.Signer:
66+
if self._signer is None:
67+
key = await get_session_key()
68+
self._signer = itsdangerous.TimestampSigner(key)
69+
return self._signer
70+
71+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
72+
if scope["type"] not in ("http", "websocket"): # pragma: no cover
73+
await self.app(scope, receive, send)
74+
return
75+
76+
signer = await self.get_signer()
77+
78+
connection = HTTPConnection(scope)
79+
initial_session_was_empty = True
80+
81+
if self.session_cookie in connection.cookies:
82+
data = connection.cookies[self.session_cookie].encode("utf-8")
83+
try:
84+
data = signer.unsign(data)
85+
scope["session"] = json.loads(b64decode(data))
86+
initial_session_was_empty = False
87+
except BadSignature:
88+
scope["session"] = {}
89+
else:
90+
scope["session"] = {}
91+
92+
async def send_wrapper(message: Message) -> None:
93+
if message["type"] == "http.response.start":
94+
if scope["session"]:
95+
# We have session data to persist.
96+
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
97+
data = signer.sign(data)
98+
headers = MutableHeaders(scope=message)
99+
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( # noqa: E501
100+
session_cookie=self.session_cookie,
101+
data=data.decode("utf-8"),
102+
path=self.path,
103+
max_age=f"Max-Age={self.max_age}; " if self.max_age else "",
104+
security_flags=self.security_flags,
105+
)
106+
headers.append("Set-Cookie", header_value)
107+
elif not initial_session_was_empty:
108+
# The session has been cleared.
109+
headers = MutableHeaders(scope=message)
110+
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( # noqa: E501
111+
session_cookie=self.session_cookie,
112+
data="null",
113+
path=self.path,
114+
expires="expires=Thu, 01 Jan 1970 00:00:00 GMT; ",
115+
security_flags=self.security_flags,
116+
)
117+
headers.append("Set-Cookie", header_value)
118+
await send(message)
119+
120+
await self.app(scope, receive, send_wrapper)

0 commit comments

Comments
 (0)