Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/10771.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Migrate keypair auth plugin into the core repository under `src/ai/backend/manager/plugin/keypair/`
2 changes: 2 additions & 0 deletions src/ai/backend/manager/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ python_distribution(
"backendai_hook_v20": {
"totp": "ai.backend.manager.plugin.totp.hook:TOTPHook",
"openid": "ai.backend.manager.plugin.openid.hook:OIDCHookPlugin",
"auth_keypair": "ai.backend.manager.plugin.keypair.hook:KeypairAuthHookPlugin",
},
"backendai_webapp_v20": {
"totp": "ai.backend.manager.plugin.totp.webapp:TOTPWebapp",
"openid": "ai.backend.manager.plugin.openid.webapp:OIDCWebAppPlugin",
"auth_keypair": "ai.backend.manager.plugin.keypair.webapp:KeypairAuthWebAppPlugin",
},
"backendai_network_manager_v1": {
"overlay": "ai.backend.manager.network.overlay:OverlayNetworkPlugin",
Expand Down
Empty file.
10 changes: 10 additions & 0 deletions src/ai/backend/manager/plugin/keypair/exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class ExternalError(Exception):
pass
Comment thread
fregataa marked this conversation as resolved.


class InvalidSToken(ExternalError):
pass


class ExpiredSToken(ExternalError):
pass
175 changes: 175 additions & 0 deletions src/ai/backend/manager/plugin/keypair/hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from __future__ import annotations

import hashlib
import hmac
import logging
import re
from collections.abc import Mapping, Sequence
from typing import Any

import sqlalchemy as sa
import trafaret as t
from aiohttp import web
from dateutil.parser import parse as dateutil_parse

from ai.backend.common.logging_utils import BraceStyleAdapter
from ai.backend.common.plugin.hook import HookHandler, HookPlugin, Reject
from ai.backend.common.utils import nmget
from ai.backend.manager.errors.auth import AuthorizationFailed, InvalidAuthParameters
from ai.backend.manager.models.keypair import KeyPairRow, keypairs
from ai.backend.manager.models.user import UserStatus, users

from .utils import deserialize_stoken

log = BraceStyleAdapter(logging.getLogger(__spec__.name))

plugin_config_checker = t.Dict({
t.Key("auth_token_name", default="sToken"): t.Null | t.String,
}).allow_extra("*")


DEFAULT_STOKEN_COOKIE_VALUE = "BackendAI"


Comment thread
fregataa marked this conversation as resolved.
class KeypairAuthHookPlugin(HookPlugin):
def __init__(self, plugin_config: Mapping[str, Any], local_config: Mapping[str, Any]) -> None:
super().__init__(plugin_config, local_config)
self.plugin_config = plugin_config_checker.check(self.plugin_config)

def get_handlers(self) -> Sequence[tuple[str, HookHandler]]:
return [
("AUTHORIZE", self.authorize),
]

async def update_plugin_config(self, plugin_config: Mapping[str, Any]) -> None:
self.plugin_config = plugin_config
Comment thread
fregataa marked this conversation as resolved.

async def init(self, context: Any = None) -> None:
pass

async def cleanup(self) -> None:
pass

def parse_token(self, token: str) -> tuple[str, str, str] | None:
pattern = r"BackendAI signMethod=(?P<sign_method>[A-Z0-9-]+), credential=(?P<access_key>\w+):(?P<signature>\w+)"
match = re.search(pattern, token)
if match:
sign_method = match.group("sign_method")
access_key = match.group("access_key")
signature = match.group("signature")
return (sign_method, access_key, signature)
return None

async def sign_token(self, sign_method: str, secret_key: str, params: Mapping[str, Any]) -> str:
try:
mac_type, hash_type = map(lambda s: s.lower(), sign_method.split("-"))
if mac_type != "hmac":
raise InvalidAuthParameters("Unsupported signing method (MAC type)")
if hash_type not in hashlib.algorithms_guaranteed:
raise InvalidAuthParameters("Unsupported signing method (hash type)")

date_obj = dateutil_parse(params["date"])
date = date_obj.isoformat()
endpoint = params["endpoint"]
api_version = params["api_version"]
if date is None:
raise InvalidAuthParameters("Request date is missing")
if endpoint is None:
raise InvalidAuthParameters("Request endpoint is missing")
if api_version is None:
raise InvalidAuthParameters("API version is missing")

body = b""
body_hash = hashlib.new(hash_type, body).hexdigest()
sign_bytes = (
"{0}\n{1}\n{2}\nhost:{3}\ncontent-type:{4}\nx-{name}-version:{5}\n{6}".format(
"POST",
"/authorize/keypair",
date,
endpoint,
"application/json",
api_version,
body_hash,
name="backendai",
)
).encode()
sign_key = hmac.new(
secret_key.encode(), date_obj.strftime("%Y%m%d").encode(), hash_type
).digest()
sign_key = hmac.new(sign_key, endpoint.encode(), hash_type).digest()
return hmac.new(sign_key, sign_bytes, hash_type).hexdigest()
except ValueError:
raise AuthorizationFailed("Invalid signature") from None

async def authorize(
self,
request: web.Request,
params: Mapping[str, Any],
) -> Any:
root_app = request.app["_root_app"]
db = root_app["_db"]
config_provider = root_app["_config_provider"]
shared_config = await config_provider.legacy_etcd_config_loader.load()
plugin_config = nmget(shared_config, "plugins.webapp.keypair_auth")
auth_token_name = self.plugin_config["auth_token_name"]

try:
body = await request.json()
except Exception:
body = {}

stoken = params[auth_token_name]
if stoken:
secret = plugin_config["secret"]
try:
payload = deserialize_stoken(stoken, secret)
query = sa.select(KeyPairRow).where(KeyPairRow.access_key == payload.access_key)
async with db.begin_readonly_session() as db_session:
keypair_row = await db_session.scalar(query)
user_id = keypair_row.user

Comment thread
fregataa marked this conversation as resolved.
except Exception:
try:
result = self.parse_token(stoken)
if not result:
raise Reject("invalid authentication token")
sign_method, access_key, signature = result

async with db.begin() as conn:
query = (
sa.select(keypairs.c.user, keypairs.c.secret_key)
.select_from(keypairs)
.where(keypairs.c.access_key == access_key)
)
result = await conn.execute(query)
keypair = result.fetchone()

sign_params = {
"date": body.get("date"),
"endpoint": body.get("endpoint"),
"api_version": body.get("api_version"),
}
generated_token = await self.sign_token(
sign_method, keypair.secret_key, sign_params
)
if generated_token != signature:
Comment thread
fregataa marked this conversation as resolved.
raise Reject("Invalid auth token")
user_id = keypair.user

except Exception as e:
log.error("AUTHORIZE_KEYPAIR_HOOK: invalid auth token {}", stoken)
log.error(repr(e))
raise Reject("Invalid auth token") from None
Comment thread
fregataa marked this conversation as resolved.

else:
return None # no-op for normal login

async with db.begin() as conn:
query = sa.select(users).select_from(users).where(users.c.uuid == user_id)
result = await conn.execute(query)
user = result.fetchone()
if not user:
raise Reject("No such user with access key")
if user.status != UserStatus.ACTIVE:
raise Reject("user is inactivated with access key")
return user
43 changes: 43 additions & 0 deletions src/ai/backend/manager/plugin/keypair/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Any

import jwt
import jwt.exceptions
from pydantic import BaseModel

from ai.backend.common.utils import nmget

from .exception import ExpiredSToken, InvalidSToken

KEYPAIR_PLUGIN_CONFIG_KEY = "plugins.webapp.keypair_auth"


class STokenData(BaseModel):
access_key: str
secret_key: str


def get_plugin_config(shared_config: dict[str, Any]) -> Any:
return nmget(shared_config, KEYPAIR_PLUGIN_CONFIG_KEY)


def encode_jwt_token(token_data: dict[str, Any], secret: str) -> str:
return jwt.encode(token_data, secret, algorithm="HS256")


def decode_jwt_token(val: str, secret: str) -> dict[str, Any]:
result: dict[str, Any] = jwt.decode(val, secret, algorithms=["HS256"])
return result


def serialize_stoken(data: STokenData, secret: str) -> str:
return encode_jwt_token(data.model_dump(mode="json"), secret=secret)


def deserialize_stoken(val: str, secret: str) -> STokenData:
try:
raw = decode_jwt_token(val, secret=secret)
return STokenData.model_validate(raw)
except jwt.ExpiredSignatureError:
raise ExpiredSToken from None
except (jwt.PyJWTError, KeyError):
raise InvalidSToken from None
Comment thread
fregataa marked this conversation as resolved.
87 changes: 87 additions & 0 deletions src/ai/backend/manager/plugin/keypair/webapp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import json
import logging
from collections.abc import Mapping, Sequence
from typing import Any

import aiohttp_cors
import yarl
from aiohttp import web
from pydantic import BaseModel, ValidationError

from ai.backend.common.logging_utils import BraceStyleAdapter
from ai.backend.manager.api.rest.types import CORSOptions, WebMiddleware
from ai.backend.manager.plugin.webapp import WebappPlugin

from .utils import (
STokenData,
get_plugin_config,
serialize_stoken,
)

log = BraceStyleAdapter(logging.getLogger(__spec__.name))


class LoginRequestData(BaseModel):
access_key: str
secret_key: str


async def login(request: web.Request) -> web.Response:
root_app = request.app["_root_app"]
config_provider = root_app["_config_provider"]
shared_config = await config_provider.legacy_etcd_config_loader.load()
plugin_config = get_plugin_config(shared_config)

Comment thread
fregataa marked this conversation as resolved.
try:
raw_data = await request.json()
json_data = LoginRequestData(**raw_data)
except (json.decoder.JSONDecodeError, ValidationError, TypeError) as e:
log.warning(
"Invalid login request data: {}",
repr(e),
)
raise web.HTTPBadRequest(reason="Invalid JSON data in request body.") from None
Comment thread
fregataa marked this conversation as resolved.

token_secret = plugin_config["secret"]
redirect_uri = yarl.URL(plugin_config["login_uri"])
token = serialize_stoken(
data=STokenData(
access_key=json_data.access_key,
secret_key=json_data.secret_key,
),
secret=token_secret,
)
redirect_location = redirect_uri.update_query({"sToken": token})
Comment thread
fregataa marked this conversation as resolved.
return web.HTTPFound(redirect_location)
Comment thread
fregataa marked this conversation as resolved.


async def _webapp_init(app: web.Application) -> None:
pass


async def _webapp_shutdown(app: web.Application) -> None:
pass


class KeypairAuthWebAppPlugin(WebappPlugin):
async def init(self, context: Any = None) -> None:
pass

async def cleanup(self) -> None:
pass

async def update_plugin_config(self, new_plugin_config: Mapping[str, Any]) -> None:
self.plugin_config = new_plugin_config

async def create_app(
self,
cors_options: CORSOptions,
) -> tuple[web.Application, Sequence[WebMiddleware]]:
app = web.Application()
app["prefix"] = "custom-auth"
app["api_versions"] = (4, 5, 6)
app.on_startup.append(_webapp_init)
app.on_shutdown.append(_webapp_shutdown)
cors = aiohttp_cors.setup(app, defaults=cors_options)
cors.add(app.router.add_route("POST", "/login", login))
return app, []
Comment thread
fregataa marked this conversation as resolved.
1 change: 0 additions & 1 deletion src/ai/backend/manager/plugin/openid/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
__version__ = "0.0.8"
3 changes: 1 addition & 2 deletions src/ai/backend/manager/plugin/openid/webapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
UserSystemRoleSpec,
)

from . import __version__
from .config import OIDCWebAppConfig
from .valkey_client import ValkeyOpenIDClient

Expand All @@ -53,7 +52,7 @@ class OpenIDError(Exception):


async def ping(_request: web.Request) -> web.Response:
return web.Response(status=200, body=f"Backend.AI OpenID Connect SSO plugin ({__version__}).")
return web.Response(status=200, body="Backend.AI OpenID Connect SSO plugin.")


def generate_random_string(length: int = 10) -> str:
Expand Down
1 change: 0 additions & 1 deletion src/ai/backend/manager/plugin/totp/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
__version__ = "1.0.3"
Comment thread
jopemachine marked this conversation as resolved.
Loading