Skip to content

Commit cc42fc8

Browse files
authored
feat(BA-5584): migrate keypair auth plugin into core repository (#10771)
1 parent bf2d491 commit cc42fc8

10 files changed

Lines changed: 319 additions & 4 deletions

File tree

changes/10771.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Migrate keypair auth plugin into the core repository under `src/ai/backend/manager/plugin/keypair/`

src/ai/backend/manager/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,12 @@ python_distribution(
6060
"backendai_hook_v20": {
6161
"totp": "ai.backend.manager.plugin.totp.hook:TOTPHook",
6262
"openid": "ai.backend.manager.plugin.openid.hook:OIDCHookPlugin",
63+
"auth_keypair": "ai.backend.manager.plugin.keypair.hook:KeypairAuthHookPlugin",
6364
},
6465
"backendai_webapp_v20": {
6566
"totp": "ai.backend.manager.plugin.totp.webapp:TOTPWebapp",
6667
"openid": "ai.backend.manager.plugin.openid.webapp:OIDCWebAppPlugin",
68+
"auth_keypair": "ai.backend.manager.plugin.keypair.webapp:KeypairAuthWebAppPlugin",
6769
},
6870
"backendai_network_manager_v1": {
6971
"overlay": "ai.backend.manager.network.overlay:OverlayNetworkPlugin",

src/ai/backend/manager/plugin/keypair/__init__.py

Whitespace-only changes.
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class ExternalError(Exception):
2+
pass
3+
4+
5+
class InvalidSToken(ExternalError):
6+
pass
7+
8+
9+
class ExpiredSToken(ExternalError):
10+
pass
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
from __future__ import annotations
2+
3+
import hashlib
4+
import hmac
5+
import logging
6+
import re
7+
from collections.abc import Mapping, Sequence
8+
from typing import Any
9+
10+
import sqlalchemy as sa
11+
import trafaret as t
12+
from aiohttp import web
13+
from dateutil.parser import parse as dateutil_parse
14+
15+
from ai.backend.common.logging_utils import BraceStyleAdapter
16+
from ai.backend.common.plugin.hook import HookHandler, HookPlugin, Reject
17+
from ai.backend.common.utils import nmget
18+
from ai.backend.manager.errors.auth import AuthorizationFailed, InvalidAuthParameters
19+
from ai.backend.manager.models.keypair import KeyPairRow, keypairs
20+
from ai.backend.manager.models.user import UserStatus, users
21+
22+
from .utils import deserialize_stoken
23+
24+
log = BraceStyleAdapter(logging.getLogger(__spec__.name))
25+
26+
plugin_config_checker = t.Dict({
27+
t.Key("auth_token_name", default="sToken"): t.Null | t.String,
28+
}).allow_extra("*")
29+
30+
31+
DEFAULT_STOKEN_COOKIE_VALUE = "BackendAI"
32+
33+
34+
class KeypairAuthHookPlugin(HookPlugin):
35+
def __init__(self, plugin_config: Mapping[str, Any], local_config: Mapping[str, Any]) -> None:
36+
super().__init__(plugin_config, local_config)
37+
self.plugin_config = plugin_config_checker.check(self.plugin_config)
38+
39+
def get_handlers(self) -> Sequence[tuple[str, HookHandler]]:
40+
return [
41+
("AUTHORIZE", self.authorize),
42+
]
43+
44+
async def update_plugin_config(self, plugin_config: Mapping[str, Any]) -> None:
45+
self.plugin_config = plugin_config
46+
47+
async def init(self, context: Any = None) -> None:
48+
pass
49+
50+
async def cleanup(self) -> None:
51+
pass
52+
53+
def parse_token(self, token: str) -> tuple[str, str, str] | None:
54+
pattern = r"BackendAI signMethod=(?P<sign_method>[A-Z0-9-]+), credential=(?P<access_key>\w+):(?P<signature>\w+)"
55+
match = re.search(pattern, token)
56+
if match:
57+
sign_method = match.group("sign_method")
58+
access_key = match.group("access_key")
59+
signature = match.group("signature")
60+
return (sign_method, access_key, signature)
61+
return None
62+
63+
async def sign_token(self, sign_method: str, secret_key: str, params: Mapping[str, Any]) -> str:
64+
try:
65+
mac_type, hash_type = map(lambda s: s.lower(), sign_method.split("-"))
66+
if mac_type != "hmac":
67+
raise InvalidAuthParameters("Unsupported signing method (MAC type)")
68+
if hash_type not in hashlib.algorithms_guaranteed:
69+
raise InvalidAuthParameters("Unsupported signing method (hash type)")
70+
71+
date_obj = dateutil_parse(params["date"])
72+
date = date_obj.isoformat()
73+
endpoint = params["endpoint"]
74+
api_version = params["api_version"]
75+
if date is None:
76+
raise InvalidAuthParameters("Request date is missing")
77+
if endpoint is None:
78+
raise InvalidAuthParameters("Request endpoint is missing")
79+
if api_version is None:
80+
raise InvalidAuthParameters("API version is missing")
81+
82+
body = b""
83+
body_hash = hashlib.new(hash_type, body).hexdigest()
84+
sign_bytes = (
85+
"{0}\n{1}\n{2}\nhost:{3}\ncontent-type:{4}\nx-{name}-version:{5}\n{6}".format(
86+
"POST",
87+
"/authorize/keypair",
88+
date,
89+
endpoint,
90+
"application/json",
91+
api_version,
92+
body_hash,
93+
name="backendai",
94+
)
95+
).encode()
96+
sign_key = hmac.new(
97+
secret_key.encode(), date_obj.strftime("%Y%m%d").encode(), hash_type
98+
).digest()
99+
sign_key = hmac.new(sign_key, endpoint.encode(), hash_type).digest()
100+
return hmac.new(sign_key, sign_bytes, hash_type).hexdigest()
101+
except ValueError:
102+
raise AuthorizationFailed("Invalid signature") from None
103+
104+
async def authorize(
105+
self,
106+
request: web.Request,
107+
params: Mapping[str, Any],
108+
) -> Any:
109+
root_app = request.app["_root_app"]
110+
db = root_app["_db"]
111+
config_provider = root_app["_config_provider"]
112+
shared_config = await config_provider.legacy_etcd_config_loader.load()
113+
plugin_config = nmget(shared_config, "plugins.webapp.keypair_auth")
114+
auth_token_name = self.plugin_config["auth_token_name"]
115+
116+
try:
117+
body = await request.json()
118+
except Exception:
119+
body = {}
120+
121+
stoken = params[auth_token_name]
122+
if stoken:
123+
secret = plugin_config["secret"]
124+
try:
125+
payload = deserialize_stoken(stoken, secret)
126+
query = sa.select(KeyPairRow).where(KeyPairRow.access_key == payload.access_key)
127+
async with db.begin_readonly_session() as db_session:
128+
keypair_row = await db_session.scalar(query)
129+
user_id = keypair_row.user
130+
131+
except Exception:
132+
try:
133+
result = self.parse_token(stoken)
134+
if not result:
135+
raise Reject("invalid authentication token")
136+
sign_method, access_key, signature = result
137+
138+
async with db.begin() as conn:
139+
query = (
140+
sa.select(keypairs.c.user, keypairs.c.secret_key)
141+
.select_from(keypairs)
142+
.where(keypairs.c.access_key == access_key)
143+
)
144+
result = await conn.execute(query)
145+
keypair = result.fetchone()
146+
147+
sign_params = {
148+
"date": body.get("date"),
149+
"endpoint": body.get("endpoint"),
150+
"api_version": body.get("api_version"),
151+
}
152+
generated_token = await self.sign_token(
153+
sign_method, keypair.secret_key, sign_params
154+
)
155+
if generated_token != signature:
156+
raise Reject("Invalid auth token")
157+
user_id = keypair.user
158+
159+
except Exception as e:
160+
log.error("AUTHORIZE_KEYPAIR_HOOK: invalid auth token {}", stoken)
161+
log.error(repr(e))
162+
raise Reject("Invalid auth token") from None
163+
164+
else:
165+
return None # no-op for normal login
166+
167+
async with db.begin() as conn:
168+
query = sa.select(users).select_from(users).where(users.c.uuid == user_id)
169+
result = await conn.execute(query)
170+
user = result.fetchone()
171+
if not user:
172+
raise Reject("No such user with access key")
173+
if user.status != UserStatus.ACTIVE:
174+
raise Reject("user is inactivated with access key")
175+
return user
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import Any
2+
3+
import jwt
4+
import jwt.exceptions
5+
from pydantic import BaseModel
6+
7+
from ai.backend.common.utils import nmget
8+
9+
from .exception import ExpiredSToken, InvalidSToken
10+
11+
KEYPAIR_PLUGIN_CONFIG_KEY = "plugins.webapp.keypair_auth"
12+
13+
14+
class STokenData(BaseModel):
15+
access_key: str
16+
secret_key: str
17+
18+
19+
def get_plugin_config(shared_config: dict[str, Any]) -> Any:
20+
return nmget(shared_config, KEYPAIR_PLUGIN_CONFIG_KEY)
21+
22+
23+
def encode_jwt_token(token_data: dict[str, Any], secret: str) -> str:
24+
return jwt.encode(token_data, secret, algorithm="HS256")
25+
26+
27+
def decode_jwt_token(val: str, secret: str) -> dict[str, Any]:
28+
result: dict[str, Any] = jwt.decode(val, secret, algorithms=["HS256"])
29+
return result
30+
31+
32+
def serialize_stoken(data: STokenData, secret: str) -> str:
33+
return encode_jwt_token(data.model_dump(mode="json"), secret=secret)
34+
35+
36+
def deserialize_stoken(val: str, secret: str) -> STokenData:
37+
try:
38+
raw = decode_jwt_token(val, secret=secret)
39+
return STokenData.model_validate(raw)
40+
except jwt.ExpiredSignatureError:
41+
raise ExpiredSToken from None
42+
except (jwt.PyJWTError, KeyError):
43+
raise InvalidSToken from None
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import json
2+
import logging
3+
from collections.abc import Mapping, Sequence
4+
from typing import Any
5+
6+
import aiohttp_cors
7+
import yarl
8+
from aiohttp import web
9+
from pydantic import BaseModel, ValidationError
10+
11+
from ai.backend.common.logging_utils import BraceStyleAdapter
12+
from ai.backend.manager.api.rest.types import CORSOptions, WebMiddleware
13+
from ai.backend.manager.plugin.webapp import WebappPlugin
14+
15+
from .utils import (
16+
STokenData,
17+
get_plugin_config,
18+
serialize_stoken,
19+
)
20+
21+
log = BraceStyleAdapter(logging.getLogger(__spec__.name))
22+
23+
24+
class LoginRequestData(BaseModel):
25+
access_key: str
26+
secret_key: str
27+
28+
29+
async def login(request: web.Request) -> web.Response:
30+
root_app = request.app["_root_app"]
31+
config_provider = root_app["_config_provider"]
32+
shared_config = await config_provider.legacy_etcd_config_loader.load()
33+
plugin_config = get_plugin_config(shared_config)
34+
35+
try:
36+
raw_data = await request.json()
37+
json_data = LoginRequestData(**raw_data)
38+
except (json.decoder.JSONDecodeError, ValidationError, TypeError) as e:
39+
log.warning(
40+
"Invalid login request data: {}",
41+
repr(e),
42+
)
43+
raise web.HTTPBadRequest(reason="Invalid JSON data in request body.") from None
44+
45+
token_secret = plugin_config["secret"]
46+
redirect_uri = yarl.URL(plugin_config["login_uri"])
47+
token = serialize_stoken(
48+
data=STokenData(
49+
access_key=json_data.access_key,
50+
secret_key=json_data.secret_key,
51+
),
52+
secret=token_secret,
53+
)
54+
redirect_location = redirect_uri.update_query({"sToken": token})
55+
return web.HTTPFound(redirect_location)
56+
57+
58+
async def _webapp_init(app: web.Application) -> None:
59+
pass
60+
61+
62+
async def _webapp_shutdown(app: web.Application) -> None:
63+
pass
64+
65+
66+
class KeypairAuthWebAppPlugin(WebappPlugin):
67+
async def init(self, context: Any = None) -> None:
68+
pass
69+
70+
async def cleanup(self) -> None:
71+
pass
72+
73+
async def update_plugin_config(self, new_plugin_config: Mapping[str, Any]) -> None:
74+
self.plugin_config = new_plugin_config
75+
76+
async def create_app(
77+
self,
78+
cors_options: CORSOptions,
79+
) -> tuple[web.Application, Sequence[WebMiddleware]]:
80+
app = web.Application()
81+
app["prefix"] = "custom-auth"
82+
app["api_versions"] = (4, 5, 6)
83+
app.on_startup.append(_webapp_init)
84+
app.on_shutdown.append(_webapp_shutdown)
85+
cors = aiohttp_cors.setup(app, defaults=cors_options)
86+
cors.add(app.router.add_route("POST", "/login", login))
87+
return app, []
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +0,0 @@
1-
__version__ = "0.0.8"

src/ai/backend/manager/plugin/openid/webapp.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
UserSystemRoleSpec,
4040
)
4141

42-
from . import __version__
4342
from .config import OIDCWebAppConfig
4443
from .valkey_client import ValkeyOpenIDClient
4544

@@ -53,7 +52,7 @@ class OpenIDError(Exception):
5352

5453

5554
async def ping(_request: web.Request) -> web.Response:
56-
return web.Response(status=200, body=f"Backend.AI OpenID Connect SSO plugin ({__version__}).")
55+
return web.Response(status=200, body="Backend.AI OpenID Connect SSO plugin.")
5756

5857

5958
def generate_random_string(length: int = 10) -> str:
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +0,0 @@
1-
__version__ = "1.0.3"

0 commit comments

Comments
 (0)