diff --git a/changes/10771.feature.md b/changes/10771.feature.md new file mode 100644 index 00000000000..59c8d8a9f86 --- /dev/null +++ b/changes/10771.feature.md @@ -0,0 +1 @@ +Migrate keypair auth plugin into the core repository under `src/ai/backend/manager/plugin/keypair/` diff --git a/src/ai/backend/manager/BUILD b/src/ai/backend/manager/BUILD index a01e0fb8a35..ef49ba0b1f5 100644 --- a/src/ai/backend/manager/BUILD +++ b/src/ai/backend/manager/BUILD @@ -60,10 +60,12 @@ python_distribution( "backendai_hook_v20": { "totp": "ai.backend.manager.plugin.totp.hook:TOTPHook", "openid": "ai.backend.manager.plugin.openid.hook:OIDCHookPlugin", + "auth_keypair": "ai.backend.manager.plugin.keypair.hook:KeypairAuthHookPlugin", }, "backendai_webapp_v20": { "totp": "ai.backend.manager.plugin.totp.webapp:TOTPWebapp", "openid": "ai.backend.manager.plugin.openid.webapp:OIDCWebAppPlugin", + "auth_keypair": "ai.backend.manager.plugin.keypair.webapp:KeypairAuthWebAppPlugin", }, "backendai_network_manager_v1": { "overlay": "ai.backend.manager.network.overlay:OverlayNetworkPlugin", diff --git a/src/ai/backend/manager/plugin/keypair/__init__.py b/src/ai/backend/manager/plugin/keypair/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/ai/backend/manager/plugin/keypair/exception.py b/src/ai/backend/manager/plugin/keypair/exception.py new file mode 100644 index 00000000000..17a31797e07 --- /dev/null +++ b/src/ai/backend/manager/plugin/keypair/exception.py @@ -0,0 +1,10 @@ +class ExternalError(Exception): + pass + + +class InvalidSToken(ExternalError): + pass + + +class ExpiredSToken(ExternalError): + pass diff --git a/src/ai/backend/manager/plugin/keypair/hook.py b/src/ai/backend/manager/plugin/keypair/hook.py new file mode 100644 index 00000000000..69c476f2e27 --- /dev/null +++ b/src/ai/backend/manager/plugin/keypair/hook.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import hashlib +import hmac +import logging +import re +from collections.abc import Mapping, Sequence +from typing import Any + +import sqlalchemy as sa +import trafaret as t +from aiohttp import web +from dateutil.parser import parse as dateutil_parse + +from ai.backend.common.logging_utils import BraceStyleAdapter +from ai.backend.common.plugin.hook import HookHandler, HookPlugin, Reject +from ai.backend.common.utils import nmget +from ai.backend.manager.errors.auth import AuthorizationFailed, InvalidAuthParameters +from ai.backend.manager.models.keypair import KeyPairRow, keypairs +from ai.backend.manager.models.user import UserStatus, users + +from .utils import deserialize_stoken + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) + +plugin_config_checker = t.Dict({ + t.Key("auth_token_name", default="sToken"): t.Null | t.String, +}).allow_extra("*") + + +DEFAULT_STOKEN_COOKIE_VALUE = "BackendAI" + + +class KeypairAuthHookPlugin(HookPlugin): + def __init__(self, plugin_config: Mapping[str, Any], local_config: Mapping[str, Any]) -> None: + super().__init__(plugin_config, local_config) + self.plugin_config = plugin_config_checker.check(self.plugin_config) + + def get_handlers(self) -> Sequence[tuple[str, HookHandler]]: + return [ + ("AUTHORIZE", self.authorize), + ] + + async def update_plugin_config(self, plugin_config: Mapping[str, Any]) -> None: + self.plugin_config = plugin_config + + async def init(self, context: Any = None) -> None: + pass + + async def cleanup(self) -> None: + pass + + def parse_token(self, token: str) -> tuple[str, str, str] | None: + pattern = r"BackendAI signMethod=(?P[A-Z0-9-]+), credential=(?P\w+):(?P\w+)" + match = re.search(pattern, token) + if match: + sign_method = match.group("sign_method") + access_key = match.group("access_key") + signature = match.group("signature") + return (sign_method, access_key, signature) + return None + + async def sign_token(self, sign_method: str, secret_key: str, params: Mapping[str, Any]) -> str: + try: + mac_type, hash_type = map(lambda s: s.lower(), sign_method.split("-")) + if mac_type != "hmac": + raise InvalidAuthParameters("Unsupported signing method (MAC type)") + if hash_type not in hashlib.algorithms_guaranteed: + raise InvalidAuthParameters("Unsupported signing method (hash type)") + + date_obj = dateutil_parse(params["date"]) + date = date_obj.isoformat() + endpoint = params["endpoint"] + api_version = params["api_version"] + if date is None: + raise InvalidAuthParameters("Request date is missing") + if endpoint is None: + raise InvalidAuthParameters("Request endpoint is missing") + if api_version is None: + raise InvalidAuthParameters("API version is missing") + + body = b"" + body_hash = hashlib.new(hash_type, body).hexdigest() + sign_bytes = ( + "{0}\n{1}\n{2}\nhost:{3}\ncontent-type:{4}\nx-{name}-version:{5}\n{6}".format( + "POST", + "/authorize/keypair", + date, + endpoint, + "application/json", + api_version, + body_hash, + name="backendai", + ) + ).encode() + sign_key = hmac.new( + secret_key.encode(), date_obj.strftime("%Y%m%d").encode(), hash_type + ).digest() + sign_key = hmac.new(sign_key, endpoint.encode(), hash_type).digest() + return hmac.new(sign_key, sign_bytes, hash_type).hexdigest() + except ValueError: + raise AuthorizationFailed("Invalid signature") from None + + async def authorize( + self, + request: web.Request, + params: Mapping[str, Any], + ) -> Any: + root_app = request.app["_root_app"] + db = root_app["_db"] + config_provider = root_app["_config_provider"] + shared_config = await config_provider.legacy_etcd_config_loader.load() + plugin_config = nmget(shared_config, "plugins.webapp.keypair_auth") + auth_token_name = self.plugin_config["auth_token_name"] + + try: + body = await request.json() + except Exception: + body = {} + + stoken = params[auth_token_name] + if stoken: + secret = plugin_config["secret"] + try: + payload = deserialize_stoken(stoken, secret) + query = sa.select(KeyPairRow).where(KeyPairRow.access_key == payload.access_key) + async with db.begin_readonly_session() as db_session: + keypair_row = await db_session.scalar(query) + user_id = keypair_row.user + + except Exception: + try: + result = self.parse_token(stoken) + if not result: + raise Reject("invalid authentication token") + sign_method, access_key, signature = result + + async with db.begin() as conn: + query = ( + sa.select(keypairs.c.user, keypairs.c.secret_key) + .select_from(keypairs) + .where(keypairs.c.access_key == access_key) + ) + result = await conn.execute(query) + keypair = result.fetchone() + + sign_params = { + "date": body.get("date"), + "endpoint": body.get("endpoint"), + "api_version": body.get("api_version"), + } + generated_token = await self.sign_token( + sign_method, keypair.secret_key, sign_params + ) + if generated_token != signature: + raise Reject("Invalid auth token") + user_id = keypair.user + + except Exception as e: + log.error("AUTHORIZE_KEYPAIR_HOOK: invalid auth token {}", stoken) + log.error(repr(e)) + raise Reject("Invalid auth token") from None + + else: + return None # no-op for normal login + + async with db.begin() as conn: + query = sa.select(users).select_from(users).where(users.c.uuid == user_id) + result = await conn.execute(query) + user = result.fetchone() + if not user: + raise Reject("No such user with access key") + if user.status != UserStatus.ACTIVE: + raise Reject("user is inactivated with access key") + return user diff --git a/src/ai/backend/manager/plugin/keypair/utils.py b/src/ai/backend/manager/plugin/keypair/utils.py new file mode 100644 index 00000000000..e37fc75db96 --- /dev/null +++ b/src/ai/backend/manager/plugin/keypair/utils.py @@ -0,0 +1,43 @@ +from typing import Any + +import jwt +import jwt.exceptions +from pydantic import BaseModel + +from ai.backend.common.utils import nmget + +from .exception import ExpiredSToken, InvalidSToken + +KEYPAIR_PLUGIN_CONFIG_KEY = "plugins.webapp.keypair_auth" + + +class STokenData(BaseModel): + access_key: str + secret_key: str + + +def get_plugin_config(shared_config: dict[str, Any]) -> Any: + return nmget(shared_config, KEYPAIR_PLUGIN_CONFIG_KEY) + + +def encode_jwt_token(token_data: dict[str, Any], secret: str) -> str: + return jwt.encode(token_data, secret, algorithm="HS256") + + +def decode_jwt_token(val: str, secret: str) -> dict[str, Any]: + result: dict[str, Any] = jwt.decode(val, secret, algorithms=["HS256"]) + return result + + +def serialize_stoken(data: STokenData, secret: str) -> str: + return encode_jwt_token(data.model_dump(mode="json"), secret=secret) + + +def deserialize_stoken(val: str, secret: str) -> STokenData: + try: + raw = decode_jwt_token(val, secret=secret) + return STokenData.model_validate(raw) + except jwt.ExpiredSignatureError: + raise ExpiredSToken from None + except (jwt.PyJWTError, KeyError): + raise InvalidSToken from None diff --git a/src/ai/backend/manager/plugin/keypair/webapp.py b/src/ai/backend/manager/plugin/keypair/webapp.py new file mode 100644 index 00000000000..d2051bfb1e2 --- /dev/null +++ b/src/ai/backend/manager/plugin/keypair/webapp.py @@ -0,0 +1,87 @@ +import json +import logging +from collections.abc import Mapping, Sequence +from typing import Any + +import aiohttp_cors +import yarl +from aiohttp import web +from pydantic import BaseModel, ValidationError + +from ai.backend.common.logging_utils import BraceStyleAdapter +from ai.backend.manager.api.rest.types import CORSOptions, WebMiddleware +from ai.backend.manager.plugin.webapp import WebappPlugin + +from .utils import ( + STokenData, + get_plugin_config, + serialize_stoken, +) + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) + + +class LoginRequestData(BaseModel): + access_key: str + secret_key: str + + +async def login(request: web.Request) -> web.Response: + root_app = request.app["_root_app"] + config_provider = root_app["_config_provider"] + shared_config = await config_provider.legacy_etcd_config_loader.load() + plugin_config = get_plugin_config(shared_config) + + try: + raw_data = await request.json() + json_data = LoginRequestData(**raw_data) + except (json.decoder.JSONDecodeError, ValidationError, TypeError) as e: + log.warning( + "Invalid login request data: {}", + repr(e), + ) + raise web.HTTPBadRequest(reason="Invalid JSON data in request body.") from None + + token_secret = plugin_config["secret"] + redirect_uri = yarl.URL(plugin_config["login_uri"]) + token = serialize_stoken( + data=STokenData( + access_key=json_data.access_key, + secret_key=json_data.secret_key, + ), + secret=token_secret, + ) + redirect_location = redirect_uri.update_query({"sToken": token}) + return web.HTTPFound(redirect_location) + + +async def _webapp_init(app: web.Application) -> None: + pass + + +async def _webapp_shutdown(app: web.Application) -> None: + pass + + +class KeypairAuthWebAppPlugin(WebappPlugin): + async def init(self, context: Any = None) -> None: + pass + + async def cleanup(self) -> None: + pass + + async def update_plugin_config(self, new_plugin_config: Mapping[str, Any]) -> None: + self.plugin_config = new_plugin_config + + async def create_app( + self, + cors_options: CORSOptions, + ) -> tuple[web.Application, Sequence[WebMiddleware]]: + app = web.Application() + app["prefix"] = "custom-auth" + app["api_versions"] = (4, 5, 6) + app.on_startup.append(_webapp_init) + app.on_shutdown.append(_webapp_shutdown) + cors = aiohttp_cors.setup(app, defaults=cors_options) + cors.add(app.router.add_route("POST", "/login", login)) + return app, [] diff --git a/src/ai/backend/manager/plugin/openid/__init__.py b/src/ai/backend/manager/plugin/openid/__init__.py index a73339bf813..e69de29bb2d 100644 --- a/src/ai/backend/manager/plugin/openid/__init__.py +++ b/src/ai/backend/manager/plugin/openid/__init__.py @@ -1 +0,0 @@ -__version__ = "0.0.8" diff --git a/src/ai/backend/manager/plugin/openid/webapp.py b/src/ai/backend/manager/plugin/openid/webapp.py index 0eb71f60324..164f9057259 100644 --- a/src/ai/backend/manager/plugin/openid/webapp.py +++ b/src/ai/backend/manager/plugin/openid/webapp.py @@ -39,7 +39,6 @@ UserSystemRoleSpec, ) -from . import __version__ from .config import OIDCWebAppConfig from .valkey_client import ValkeyOpenIDClient @@ -53,7 +52,7 @@ class OpenIDError(Exception): async def ping(_request: web.Request) -> web.Response: - return web.Response(status=200, body=f"Backend.AI OpenID Connect SSO plugin ({__version__}).") + return web.Response(status=200, body="Backend.AI OpenID Connect SSO plugin.") def generate_random_string(length: int = 10) -> str: diff --git a/src/ai/backend/manager/plugin/totp/__init__.py b/src/ai/backend/manager/plugin/totp/__init__.py index 976498ab9ca..e69de29bb2d 100644 --- a/src/ai/backend/manager/plugin/totp/__init__.py +++ b/src/ai/backend/manager/plugin/totp/__init__.py @@ -1 +0,0 @@ -__version__ = "1.0.3"