diff --git a/changes/10899.feature.md b/changes/10899.feature.md new file mode 100644 index 00000000000..35b132a505e --- /dev/null +++ b/changes/10899.feature.md @@ -0,0 +1 @@ +Add RS256 asymmetric signing and JWKS fetching/caching support to the `common/jwt` library for OAuth2/OIDC token validation. diff --git a/src/ai/backend/common/configs/jwt.py b/src/ai/backend/common/configs/jwt.py index 5ca655a6156..561c46c0c57 100644 --- a/src/ai/backend/common/configs/jwt.py +++ b/src/ai/backend/common/configs/jwt.py @@ -11,7 +11,12 @@ from pydantic import AliasChoices, Field from ai.backend.common.config import BaseConfigSchema -from ai.backend.common.jwt.config import JWTConfig as CoreJWTConfig +from ai.backend.common.jwt.config import ( + JWTAlgorithm as CoreJWTAlgorithm, +) +from ai.backend.common.jwt.config import ( + JWTConfig as CoreJWTConfig, +) from ai.backend.common.meta import BackendAIConfigMeta, ConfigExample @@ -58,6 +63,6 @@ class SharedJWTConfig(BaseConfigSchema): def to_jwt_config(self) -> CoreJWTConfig: """Convert to ai.backend.common.jwt.config.JWTConfig.""" return CoreJWTConfig( - algorithm=self.algorithm, + algorithm=CoreJWTAlgorithm(self.algorithm), token_expiration_seconds=self.token_expiration_seconds, ) diff --git a/src/ai/backend/common/jwt/__init__.py b/src/ai/backend/common/jwt/__init__.py index c5fb3536d85..80ef1e8f190 100644 --- a/src/ai/backend/common/jwt/__init__.py +++ b/src/ai/backend/common/jwt/__init__.py @@ -5,59 +5,68 @@ Hive Router. It uses the X-BackendAI-Token custom header to avoid conflicts with existing Bearer token usage in appproxy. +Supports both HS256 (symmetric, per-user secret keys) and RS256 (asymmetric, +RSA key pairs) signing algorithms, with JWKS utilities for distributed key +management. + Key components: - JWTSigner: Generates JWT tokens from authenticated user context (webserver) - JWTValidator: Validates JWT tokens and extracts user claims (manager) - JWTConfig: Configuration for JWT authentication - JWTClaims: Dataclass representing JWT payload claims +- JWKSKeySet: Public key set indexed by key ID for RS256 validation +- JWKSFetcher: Async JWKS endpoint fetcher with TTL caching +- Key utilities: RSA key generation, loading, serialization, and JWK conversion -Example usage in webserver: +Example usage (HS256): from ai.backend.common.jwt import JWTSigner, JWTConfig, JWTUserContext - config = JWTConfig(secret_key=os.environ["JWT_SECRET_KEY"]) + config = JWTConfig() signer = JWTSigner(config) user_context = JWTUserContext( - user_id=user_uuid, access_key=access_key, role="user", - domain_name="default", - is_admin=False, - is_superadmin=False, ) - token = signer.generate_token(user_context) - - # Add to request headers - headers["X-BackendAI-Token"] = token - -Example usage in manager: - from ai.backend.common.jwt import JWTValidator, JWTConfig - - config = JWTConfig(secret_key=os.environ["JWT_SECRET_KEY"]) - validator = JWTValidator(config) + token = signer.generate_token(user_context, secret_key) - token = request.headers.get("X-BackendAI-Token") - claims = validator.validate_token(token) +Example usage (RS256): + from ai.backend.common.jwt import JWTSigner, JWTConfig, JWTUserContext + from ai.backend.common.jwt.keys import load_private_key - # Use claims for authentication - user_id = claims.sub - access_key = claims.access_key + config = JWTConfig(algorithm="RS256") + signer = JWTSigner(config) + private_key = load_private_key(Path("/path/to/private.pem")) + token = signer.generate_token(user_context, private_key=private_key, kid="key-1") """ -from .config import JWTConfig -from .exceptions import ( +from ai.backend.common.jwt.config import JWTAlgorithm, JWTConfig +from ai.backend.common.jwt.exceptions import ( + JWKSError, + JWKSFetchError, + JWKSKeyNotFoundError, JWTDecodeError, JWTError, JWTExpiredError, JWTInvalidClaimsError, JWTInvalidSignatureError, ) -from .signer import JWTSigner -from .types import JWTClaims, JWTUserContext -from .validator import JWTValidator +from ai.backend.common.jwt.jwks import JWKSFetcher, JWKSKeySet +from ai.backend.common.jwt.keys import ( + generate_rsa_key_pair, + load_private_key, + load_public_key, + private_key_to_pem, + public_key_to_jwk, + public_key_to_pem, +) +from ai.backend.common.jwt.signer import JWTSigner +from ai.backend.common.jwt.types import JWTClaims, JWTUserContext +from ai.backend.common.jwt.validator import JWTValidator __all__ = [ # Configuration + "JWTAlgorithm", "JWTConfig", # Types "JWTClaims", @@ -65,10 +74,23 @@ # Core classes "JWTSigner", "JWTValidator", + # JWKS + "JWKSKeySet", + "JWKSFetcher", + # Key management + "generate_rsa_key_pair", + "load_private_key", + "load_public_key", + "private_key_to_pem", + "public_key_to_pem", + "public_key_to_jwk", # Exceptions "JWTError", "JWTExpiredError", "JWTInvalidSignatureError", "JWTInvalidClaimsError", "JWTDecodeError", + "JWKSError", + "JWKSFetchError", + "JWKSKeyNotFoundError", ] diff --git a/src/ai/backend/common/jwt/config.py b/src/ai/backend/common/jwt/config.py index e2c8e94bc32..0bc273733a6 100644 --- a/src/ai/backend/common/jwt/config.py +++ b/src/ai/backend/common/jwt/config.py @@ -3,12 +3,21 @@ from __future__ import annotations from datetime import timedelta +from enum import StrEnum +from pathlib import Path from pydantic import Field from ai.backend.common.config import BaseConfigSchema +class JWTAlgorithm(StrEnum): + """Supported JWT signing algorithms.""" + + HS256 = "HS256" + RS256 = "RS256" + + class JWTConfig(BaseConfigSchema): """ Configuration for JWT-based authentication in GraphQL Federation. @@ -16,15 +25,17 @@ class JWTConfig(BaseConfigSchema): This configuration must be consistent between webserver (which generates tokens) and manager (which validates tokens). - Note: JWT tokens are signed using per-user secret keys (from keypair table), - not a shared system secret key. This maintains the same security model as HMAC authentication. + Supports both HS256 (symmetric, per-user secret keys) and RS256 (asymmetric, + RSA key pairs) signing algorithms. JWT tokens are transmitted via X-BackendAI-Token HTTP header. Attributes: enabled: Whether JWT authentication is enabled - algorithm: JWT signing algorithm (must be HS256) + algorithm: JWT signing algorithm (HS256 or RS256) token_expiration_seconds: Token validity duration in seconds + private_key_path: Path to PEM-encoded RSA private key (RS256 only) + public_key_path: Path to PEM-encoded RSA public key (RS256 only) """ enabled: bool = Field( @@ -32,9 +43,9 @@ class JWTConfig(BaseConfigSchema): description="Enable JWT authentication for GraphQL Federation requests", ) - algorithm: str = Field( - default="HS256", - description="JWT signing algorithm (only HS256 is supported)", + algorithm: JWTAlgorithm = Field( + default=JWTAlgorithm.HS256, + description="JWT signing algorithm (HS256 or RS256)", ) token_expiration_seconds: int = Field( @@ -42,6 +53,16 @@ class JWTConfig(BaseConfigSchema): description="JWT token expiration time in seconds (default: 900 = 15 minutes)", ) + private_key_path: Path | None = Field( + default=None, + description="Path to PEM-encoded RSA private key file (required for RS256 signing)", + ) + + public_key_path: Path | None = Field( + default=None, + description="Path to PEM-encoded RSA public key file (required for RS256 validation)", + ) + @property def token_expiration(self) -> timedelta: """ diff --git a/src/ai/backend/common/jwt/exceptions.py b/src/ai/backend/common/jwt/exceptions.py index e81082115b0..2b19113e949 100644 --- a/src/ai/backend/common/jwt/exceptions.py +++ b/src/ai/backend/common/jwt/exceptions.py @@ -82,3 +82,37 @@ class JWTDecodeError(JWTError, web.HTTPUnauthorized): error_type = "https://api.backend.ai/probs/jwt-decode-error" error_title = "Failed to decode JWT token." + + +class JWKSError(JWTError): + """ + Base exception for JWKS-related errors. + + All JWKS-specific exceptions inherit from this base class. + """ + + error_type = "https://api.backend.ai/probs/jwks-error" + error_title = "JWKS error." + + +class JWKSFetchError(JWKSError, web.HTTPUnauthorized): + """ + Failed to fetch JWKS from the remote endpoint. + + Raised when the JWKS endpoint is unreachable or returns invalid data. + """ + + error_type = "https://api.backend.ai/probs/jwks-fetch-error" + error_title = "Failed to fetch JWKS." + + +class JWKSKeyNotFoundError(JWKSError, web.HTTPUnauthorized): + """ + Key ID (kid) not found in the JWKS key set. + + Raised when a token references a key ID that is not present + in the available JWKS key set. + """ + + error_type = "https://api.backend.ai/probs/jwks-key-not-found" + error_title = "Key ID not found in JWKS." diff --git a/src/ai/backend/common/jwt/jwks.py b/src/ai/backend/common/jwt/jwks.py new file mode 100644 index 00000000000..c4c5c03904f --- /dev/null +++ b/src/ai/backend/common/jwt/jwks.py @@ -0,0 +1,197 @@ +"""JWKS (JSON Web Key Set) fetching and caching utilities.""" + +from __future__ import annotations + +import base64 +import time +from typing import Any + +import aiohttp +from cryptography.hazmat.primitives.asymmetric.rsa import ( + RSAPublicKey, + RSAPublicNumbers, +) + +from ai.backend.common.jwt.exceptions import JWKSFetchError, JWKSKeyNotFoundError + + +def _base64url_to_int(value: str) -> int: + """ + Decode a base64url-encoded string (no padding) to an integer. + + Args: + value: Base64url-encoded string + + Returns: + Decoded integer value + """ + padding = 4 - len(value) % 4 + if padding != 4: + value += "=" * padding + decoded = base64.urlsafe_b64decode(value) + return int.from_bytes(decoded, byteorder="big") + + +class JWKSKeySet: + """ + A set of RSA public keys indexed by key ID (kid). + + This class parses a JWKS (JSON Web Key Set) response and provides + key lookup by kid. + """ + + _keys: dict[str, RSAPublicKey] + + def __init__(self, keys: dict[str, RSAPublicKey]) -> None: + self._keys = keys + + @classmethod + def from_jwks_dict(cls, data: dict[str, Any]) -> JWKSKeySet: + """ + Parse a JWKS JSON response into a JWKSKeySet. + + Only RSA keys with ``"use": "sig"`` and a ``kid`` field are included. + + Args: + data: JWKS JSON dict with a ``keys`` array + + Returns: + JWKSKeySet instance containing the parsed public keys + """ + keys: dict[str, RSAPublicKey] = {} + for jwk in data.get("keys", []): + if jwk.get("kty") != "RSA": + continue + kid = jwk.get("kid") + if kid is None: + continue + n = _base64url_to_int(jwk["n"]) + e = _base64url_to_int(jwk["e"]) + public_numbers = RSAPublicNumbers(e=e, n=n) + public_key = public_numbers.public_key() + keys[kid] = public_key + return cls(keys) + + def get_key(self, kid: str) -> RSAPublicKey: + """ + Look up an RSA public key by key ID. + + Args: + kid: Key ID to look up + + Returns: + RSA public key corresponding to the given kid + + Raises: + JWKSKeyNotFoundError: If the kid is not found in the key set + """ + key = self._keys.get(kid) + if key is None: + raise JWKSKeyNotFoundError(f"Key ID '{kid}' not found in JWKS key set") + return key + + @property + def kids(self) -> list[str]: + """Return a list of all key IDs in this key set.""" + return list(self._keys.keys()) + + +class JWKSFetcher: + """ + Async JWKS fetcher with TTL-based caching. + + Fetches a JWKS endpoint and caches the result for a configurable duration. + Thread-safe for concurrent async access. + + Usage: + fetcher = JWKSFetcher(url="https://example.com/.well-known/jwks.json") + public_key = await fetcher.get_key("my-key-id") + """ + + _url: str + _cache_ttl: float + _cached_key_set: JWKSKeySet | None + _last_fetch_time: float + + def __init__(self, url: str, cache_ttl: float = 300.0) -> None: + """ + Initialize the JWKS fetcher. + + Args: + url: URL of the JWKS endpoint + cache_ttl: Cache time-to-live in seconds (default: 300 = 5 minutes) + """ + self._url = url + self._cache_ttl = cache_ttl + self._cached_key_set = None + self._last_fetch_time = 0.0 + + async def get_key(self, kid: str) -> RSAPublicKey: + """ + Get an RSA public key by key ID, fetching JWKS if cache is expired. + + Args: + kid: Key ID to look up + + Returns: + RSA public key corresponding to the given kid + + Raises: + JWKSFetchError: If the JWKS endpoint cannot be reached + JWKSKeyNotFoundError: If the kid is not found in the key set + """ + key_set = await self._get_key_set() + return key_set.get_key(kid) + + async def refresh(self) -> JWKSKeySet: + """ + Force refresh the cached JWKS key set. + + Returns: + The freshly fetched JWKSKeySet + + Raises: + JWKSFetchError: If the JWKS endpoint cannot be reached + """ + return await self._fetch_jwks() + + async def _get_key_set(self) -> JWKSKeySet: + """ + Get the cached key set, refreshing if expired. + + Returns: + The current JWKSKeySet + + Raises: + JWKSFetchError: If the JWKS endpoint cannot be reached + """ + now = time.monotonic() + if self._cached_key_set is not None and (now - self._last_fetch_time) < self._cache_ttl: + return self._cached_key_set + return await self._fetch_jwks() + + async def _fetch_jwks(self) -> JWKSKeySet: + """ + Fetch the JWKS endpoint and update the cache. + + Returns: + The fetched JWKSKeySet + + Raises: + JWKSFetchError: If the JWKS endpoint cannot be reached or returns invalid data + """ + try: + async with aiohttp.ClientSession() as session: + async with session.get(self._url) as response: + if response.status != 200: + raise JWKSFetchError(f"JWKS endpoint returned HTTP {response.status}") + data = await response.json() + except JWKSFetchError: + raise + except Exception as e: + raise JWKSFetchError(f"Failed to fetch JWKS from {self._url}: {e}") from e + + key_set = JWKSKeySet.from_jwks_dict(data) + self._cached_key_set = key_set + self._last_fetch_time = time.monotonic() + return key_set diff --git a/src/ai/backend/common/jwt/keys.py b/src/ai/backend/common/jwt/keys.py new file mode 100644 index 00000000000..08085c60e27 --- /dev/null +++ b/src/ai/backend/common/jwt/keys.py @@ -0,0 +1,137 @@ +"""RSA key management utilities for JWT RS256 signing.""" + +from __future__ import annotations + +import base64 +from pathlib import Path + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric.rsa import ( + RSAPrivateKey, + RSAPublicKey, +) + + +def generate_rsa_key_pair( + key_size: int = 2048, +) -> tuple[RSAPrivateKey, RSAPublicKey]: + """ + Generate a new RSA key pair for JWT RS256 signing. + + Args: + key_size: RSA key size in bits (default: 2048) + + Returns: + Tuple of (private_key, public_key) RSA key objects + """ + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=key_size, + ) + public_key = private_key.public_key() + return private_key, public_key + + +def load_private_key(path: Path) -> RSAPrivateKey: + """ + Load a PEM-encoded RSA private key from a file. + + Args: + path: Path to the PEM-encoded private key file + + Returns: + RSA private key object + """ + key_data = path.read_bytes() + private_key = serialization.load_pem_private_key(key_data, password=None) + if not isinstance(private_key, RSAPrivateKey): + raise TypeError(f"Expected RSA private key, got {type(private_key).__name__}") + return private_key + + +def load_public_key(path: Path) -> RSAPublicKey: + """ + Load a PEM-encoded RSA public key from a file. + + Args: + path: Path to the PEM-encoded public key file + + Returns: + RSA public key object + """ + key_data = path.read_bytes() + public_key = serialization.load_pem_public_key(key_data) + if not isinstance(public_key, RSAPublicKey): + raise TypeError(f"Expected RSA public key, got {type(public_key).__name__}") + return public_key + + +def private_key_to_pem(key: RSAPrivateKey) -> bytes: + """ + Serialize an RSA private key to PEM format. + + Args: + key: RSA private key object + + Returns: + PEM-encoded private key bytes + """ + return key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + +def public_key_to_pem(key: RSAPublicKey) -> bytes: + """ + Serialize an RSA public key to PEM format. + + Args: + key: RSA public key object + + Returns: + PEM-encoded public key bytes + """ + return key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + +def _int_to_base64url(value: int) -> str: + """ + Encode an integer as a base64url string (no padding) for JWK format. + + Args: + value: Integer value to encode + + Returns: + Base64url-encoded string + """ + byte_length = (value.bit_length() + 7) // 8 + value_bytes = value.to_bytes(byte_length, byteorder="big") + return base64.urlsafe_b64encode(value_bytes).rstrip(b"=").decode("ascii") + + +def public_key_to_jwk(key: RSAPublicKey, kid: str) -> dict[str, str]: + """ + Convert an RSA public key to JWK (JSON Web Key) format. + + Args: + key: RSA public key object + kid: Key ID to include in the JWK + + Returns: + Dictionary in JWK format with kty, n, e, kid, use, and alg fields + """ + public_numbers = key.public_numbers() + return { + "kty": "RSA", + "n": _int_to_base64url(public_numbers.n), + "e": _int_to_base64url(public_numbers.e), + "kid": kid, + "use": "sig", + "alg": "RS256", + } diff --git a/src/ai/backend/common/jwt/signer.py b/src/ai/backend/common/jwt/signer.py index d2e5082c83b..9e622108ca3 100644 --- a/src/ai/backend/common/jwt/signer.py +++ b/src/ai/backend/common/jwt/signer.py @@ -4,11 +4,12 @@ from datetime import UTC, datetime -import jwt +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey -from .config import JWTConfig -from .exceptions import JWTError -from .types import JWTClaims, JWTUserContext +import jwt +from ai.backend.common.jwt.config import JWTAlgorithm, JWTConfig +from ai.backend.common.jwt.exceptions import JWTError +from ai.backend.common.jwt.types import JWTClaims, JWTUserContext class JWTSigner: @@ -19,23 +20,18 @@ class JWTSigner: HMAC authentication. The generated tokens are then forwarded to the manager via Hive Router using the X-BackendAI-Token header. - Note: JWT tokens are signed using per-user secret keys (from keypair table), - not a shared system secret key. This maintains the same security model as HMAC authentication. - - Usage: - from ai.backend.common.jwt import JWTSigner, JWTConfig, JWTUserContext + Supports both HS256 (symmetric, per-user secret keys) and RS256 (asymmetric, + RSA key pairs) signing algorithms. + Usage (HS256): config = JWTConfig() signer = JWTSigner(config) + token = signer.generate_token(user_context, secret_key="my-secret") - user_context = JWTUserContext( - user_id=user_uuid, - access_key=access_key, - role="user", - ) - # Get user's secret key from keypair table - secret_key = keypair.secret_key - token = signer.generate_token(user_context, secret_key) + Usage (RS256): + config = JWTConfig(algorithm="RS256") + signer = JWTSigner(config) + token = signer.generate_token(user_context, private_key=rsa_private_key, kid="key-1") """ _config: JWTConfig @@ -49,22 +45,31 @@ def __init__(self, config: JWTConfig) -> None: """ self._config = config - def generate_token(self, user_context: JWTUserContext, secret_key: str) -> str: + def generate_token( + self, + user_context: JWTUserContext, + secret_key: str | None = None, + *, + private_key: RSAPrivateKey | None = None, + kid: str | None = None, + ) -> str: """ Generate a JWT token from authenticated user context. - This method creates a JWT token containing minimal user authentication - information. The token is signed using HS256 with the user's secret key. + For HS256, provide ``secret_key``. For RS256, provide ``private_key`` + and optionally ``kid`` (key ID included in the JWT header). Args: user_context: User context data containing authentication information - secret_key: User's secret key from keypair table for signing the token + secret_key: Secret key string for HS256 signing + private_key: RSA private key object for RS256 signing + kid: Key ID to include in the JWT header (RS256 only) Returns: Encoded JWT token string Raises: - JWTError: If token generation fails + JWTError: If token generation fails or invalid key arguments are provided """ now = datetime.now(UTC) @@ -73,13 +78,30 @@ def generate_token(self, user_context: JWTUserContext, secret_key: str) -> str: iat=now, access_key=user_context.access_key, role=user_context.role, + kid=kid, ) try: + if self._config.algorithm == JWTAlgorithm.RS256: + if private_key is None: + raise JWTError("RS256 algorithm requires a private_key argument") + headers: dict[str, str] = {} + if kid is not None: + headers["kid"] = kid + return jwt.encode( + claims.to_dict(), + private_key, + algorithm=self._config.algorithm, + headers=headers if headers else None, + ) + if secret_key is None: + raise JWTError("HS256 algorithm requires a secret_key argument") return jwt.encode( claims.to_dict(), secret_key, algorithm=self._config.algorithm, ) + except JWTError: + raise except Exception as e: raise JWTError(f"JWT generation failed: {e}") from e diff --git a/src/ai/backend/common/jwt/types.py b/src/ai/backend/common/jwt/types.py index e8ce03fce4e..afd668b93b5 100644 --- a/src/ai/backend/common/jwt/types.py +++ b/src/ai/backend/common/jwt/types.py @@ -2,7 +2,7 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import UTC, datetime from typing import Any @@ -43,6 +43,12 @@ class JWTClaims: iat: Issued at time (UTC) access_key: User's access key role: User role ("admin", "user", or "superadmin") + kid: Key ID used to sign this token (optional, for RS256/JWKS) + iss: Issuer identifier (optional, for OAuth2) + aud: Audience identifier (optional, for OAuth2) + sub: Subject identifier (optional, for OAuth2) + scope: Space-separated scope string (optional, for OAuth2) + jti: JWT ID / token identifier (optional, for OAuth2) """ # Standard JWT claims (RFC 7519) @@ -53,22 +59,43 @@ class JWTClaims: access_key: AccessKey role: str + # Optional claims for RS256/JWKS and OAuth2 + kid: str | None = field(default=None) + iss: str | None = field(default=None) + aud: str | None = field(default=None) + sub: str | None = field(default=None) + scope: str | None = field(default=None) + jti: str | None = field(default=None) + def to_dict(self) -> dict[str, Any]: """ Convert JWTClaims to a dictionary suitable for JWT payload. Datetime objects are converted to Unix timestamps (integers) as required - by the JWT standard. + by the JWT standard. Optional fields are only included when set. Returns: Dictionary representation of claims with timestamps as integers. """ - return { + result: dict[str, Any] = { "exp": int(self.exp.timestamp()), "iat": int(self.iat.timestamp()), "access_key": str(self.access_key), "role": self.role, } + if self.kid is not None: + result["kid"] = self.kid + if self.iss is not None: + result["iss"] = self.iss + if self.aud is not None: + result["aud"] = self.aud + if self.sub is not None: + result["sub"] = self.sub + if self.scope is not None: + result["scope"] = self.scope + if self.jti is not None: + result["jti"] = self.jti + return result @classmethod def from_dict(cls, payload: dict[str, Any]) -> JWTClaims: @@ -93,4 +120,10 @@ def from_dict(cls, payload: dict[str, Any]) -> JWTClaims: iat=datetime.fromtimestamp(payload["iat"], tz=UTC), access_key=AccessKey(payload["access_key"]), role=payload["role"], + kid=payload.get("kid"), + iss=payload.get("iss"), + aud=payload.get("aud"), + sub=payload.get("sub"), + scope=payload.get("scope"), + jti=payload.get("jti"), ) diff --git a/src/ai/backend/common/jwt/validator.py b/src/ai/backend/common/jwt/validator.py index deae277a803..92fc1c9f7d7 100644 --- a/src/ai/backend/common/jwt/validator.py +++ b/src/ai/backend/common/jwt/validator.py @@ -2,21 +2,27 @@ from __future__ import annotations +from typing import TYPE_CHECKING + +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey + import jwt +from ai.backend.common.jwt.config import JWTAlgorithm, JWTConfig +from ai.backend.common.jwt.exceptions import ( + JWTDecodeError, + JWTExpiredError, + JWTInvalidClaimsError, + JWTInvalidSignatureError, +) +from ai.backend.common.jwt.types import JWTClaims from jwt.exceptions import ( DecodeError, ExpiredSignatureError, InvalidSignatureError, ) -from .config import JWTConfig -from .exceptions import ( - JWTDecodeError, - JWTExpiredError, - JWTInvalidClaimsError, - JWTInvalidSignatureError, -) -from .types import JWTClaims +if TYPE_CHECKING: + from ai.backend.common.jwt.jwks import JWKSKeySet class JWTValidator: @@ -27,14 +33,22 @@ class JWTValidator: Hive Router via the X-BackendAI-Token header. It verifies the token's signature, expiration, and claims. - Note: JWT tokens are signed using per-user secret keys (from keypair table), - not a shared system secret key. This maintains the same security model as HMAC authentication. + Supports both HS256 (symmetric) and RS256 (asymmetric) validation. - Usage: + Usage (HS256): config = JWTConfig() validator = JWTValidator(config) - # Get user's secret key from keypair table after extracting access_key from token - claims = validator.validate_token(token_string, secret_key) + claims = validator.validate_token(token_string, secret_key="my-secret") + + Usage (RS256): + config = JWTConfig(algorithm="RS256") + validator = JWTValidator(config) + claims = validator.validate_token(token_string, public_key=rsa_public_key) + + Usage (JWKS): + config = JWTConfig(algorithm="RS256") + validator = JWTValidator(config) + claims = validator.validate_token_with_jwks(token_string, jwks_key_set) """ _config: JWTConfig @@ -48,18 +62,96 @@ def __init__(self, config: JWTConfig) -> None: """ self._config = config - def validate_token(self, token: str, secret_key: str) -> JWTClaims: + def validate_token( + self, + token: str, + secret_key: str | None = None, + *, + public_key: RSAPublicKey | None = None, + ) -> JWTClaims: """ Validate JWT token and extract claims. This method performs comprehensive validation: - 1. Verifies the token signature using the user's secret key + 1. Verifies the token signature using the provided key 2. Checks token expiration 3. Ensures all required claims are present and valid + For HS256, provide ``secret_key``. For RS256, provide ``public_key``. + + Args: + token: Encoded JWT token string + secret_key: Secret key string for HS256 signature verification + public_key: RSA public key object for RS256 signature verification + + Returns: + JWTClaims object containing validated user information + + Raises: + JWTExpiredError: If the token has expired + JWTInvalidSignatureError: If signature verification fails + JWTInvalidClaimsError: If claims are missing or invalid + JWTDecodeError: If the token cannot be decoded + """ + if self._config.algorithm == JWTAlgorithm.RS256: + if public_key is None: + raise JWTDecodeError("RS256 algorithm requires a public_key argument") + key: str | RSAPublicKey = public_key + else: + if secret_key is None: + raise JWTDecodeError("HS256 algorithm requires a secret_key argument") + key = secret_key + + return self._decode_and_validate(token, key) + + def validate_token_with_jwks( + self, + token: str, + jwks_key_set: JWKSKeySet, + ) -> JWTClaims: + """ + Validate JWT token using a JWKS key set. + + Extracts the ``kid`` (key ID) from the token header, looks up the + corresponding public key in the JWKS key set, and validates the token. + + Args: + token: Encoded JWT token string + jwks_key_set: JWKS key set containing public keys indexed by kid + + Returns: + JWTClaims object containing validated user information + + Raises: + JWTExpiredError: If the token has expired + JWTInvalidSignatureError: If signature verification fails + JWTInvalidClaimsError: If claims are missing or invalid + JWTDecodeError: If the token cannot be decoded or has no kid header + JWKSKeyNotFoundError: If the kid is not found in the key set + """ + try: + unverified_header = jwt.get_unverified_header(token) + except DecodeError as e: + raise JWTDecodeError(f"Failed to decode JWT header: {e}") from e + + kid = unverified_header.get("kid") + if kid is None: + raise JWTDecodeError("JWT token header does not contain a 'kid' field") + + public_key = jwks_key_set.get_key(kid) + return self._decode_and_validate(token, public_key) + + def _decode_and_validate( + self, + token: str, + key: str | RSAPublicKey, + ) -> JWTClaims: + """ + Decode a JWT token, verify its signature, and validate claims. + Args: token: Encoded JWT token string - secret_key: User's secret key from keypair table for signature verification + key: Key for signature verification (str for HS256, RSAPublicKey for RS256) Returns: JWTClaims object containing validated user information @@ -71,24 +163,20 @@ def validate_token(self, token: str, secret_key: str) -> JWTClaims: JWTDecodeError: If the token cannot be decoded """ try: - # Decode and verify token payload = jwt.decode( token, - secret_key, + key, algorithms=[self._config.algorithm], options={ "verify_signature": True, "verify_exp": True, "verify_iat": True, + "verify_aud": False, }, ) - # Parse claims from payload claims = JWTClaims.from_dict(payload) - - # Validate claim values self._validate_claims(claims) - return claims except ExpiredSignatureError as e: @@ -112,7 +200,6 @@ def _validate_claims(self, claims: JWTClaims) -> None: Raises: JWTInvalidClaimsError: If validation fails """ - # Validate role is one of expected values valid_roles = {"admin", "user", "superadmin"} if claims.role not in valid_roles: raise JWTInvalidClaimsError( diff --git a/tests/unit/common/jwt/test_jwks.py b/tests/unit/common/jwt/test_jwks.py new file mode 100644 index 00000000000..62de578b943 --- /dev/null +++ b/tests/unit/common/jwt/test_jwks.py @@ -0,0 +1,234 @@ +"""Tests for JWKS key set and fetcher.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey + +from ai.backend.common.jwt.exceptions import JWKSFetchError, JWKSKeyNotFoundError +from ai.backend.common.jwt.jwks import JWKSFetcher, JWKSKeySet +from ai.backend.common.jwt.keys import generate_rsa_key_pair, public_key_to_jwk + + +@pytest.fixture +def rsa_keys_with_jwks() -> dict[str, dict[str, str] | RSAPublicKey]: + """Generate RSA keys and their JWK representations.""" + private_key_a, public_key_a = generate_rsa_key_pair() + private_key_b, public_key_b = generate_rsa_key_pair() + jwk_a = public_key_to_jwk(public_key_a, kid="key-a") + jwk_b = public_key_to_jwk(public_key_b, kid="key-b") + return { + "jwk_a": jwk_a, + "jwk_b": jwk_b, + "public_key_a": public_key_a, + "public_key_b": public_key_b, + } + + +class TestJWKSKeySet: + """Tests for JWKSKeySet parsing and key lookup.""" + + def test_from_jwks_dict_parses_keys(self) -> None: + """Test parsing a JWKS dict with multiple keys.""" + _, public_key_a = generate_rsa_key_pair() + _, public_key_b = generate_rsa_key_pair() + jwk_a = public_key_to_jwk(public_key_a, kid="key-a") + jwk_b = public_key_to_jwk(public_key_b, kid="key-b") + + jwks_dict = {"keys": [jwk_a, jwk_b]} + key_set = JWKSKeySet.from_jwks_dict(jwks_dict) + + assert "key-a" in key_set.kids + assert "key-b" in key_set.kids + assert len(key_set.kids) == 2 + + def test_get_key_returns_correct_key(self) -> None: + """Test that get_key returns the correct public key by kid.""" + _, public_key = generate_rsa_key_pair() + jwk = public_key_to_jwk(public_key, kid="my-key") + + key_set = JWKSKeySet.from_jwks_dict({"keys": [jwk]}) + retrieved_key = key_set.get_key("my-key") + + assert isinstance(retrieved_key, RSAPublicKey) + assert retrieved_key.public_numbers().n == public_key.public_numbers().n + + def test_get_key_not_found_raises_error(self) -> None: + """Test that get_key raises JWKSKeyNotFoundError for unknown kid.""" + key_set = JWKSKeySet.from_jwks_dict({"keys": []}) + + with pytest.raises(JWKSKeyNotFoundError): + key_set.get_key("nonexistent-key") + + def test_from_jwks_dict_ignores_non_rsa_keys(self) -> None: + """Test that non-RSA keys in the JWKS are ignored.""" + _, public_key = generate_rsa_key_pair() + rsa_jwk = public_key_to_jwk(public_key, kid="rsa-key") + ec_jwk = {"kty": "EC", "kid": "ec-key", "crv": "P-256", "x": "abc", "y": "def"} + + key_set = JWKSKeySet.from_jwks_dict({"keys": [rsa_jwk, ec_jwk]}) + assert key_set.kids == ["rsa-key"] + + def test_from_jwks_dict_ignores_keys_without_kid(self) -> None: + """Test that keys without a kid field are ignored.""" + _, public_key = generate_rsa_key_pair() + jwk = public_key_to_jwk(public_key, kid="has-kid") + jwk_no_kid = public_key_to_jwk(public_key, kid="temp") + del jwk_no_kid["kid"] + + key_set = JWKSKeySet.from_jwks_dict({"keys": [jwk, jwk_no_kid]}) + assert key_set.kids == ["has-kid"] + + def test_from_jwks_dict_empty_keys(self) -> None: + """Test parsing a JWKS dict with no keys.""" + key_set = JWKSKeySet.from_jwks_dict({"keys": []}) + assert key_set.kids == [] + + def test_from_jwks_dict_missing_keys_field(self) -> None: + """Test parsing a JWKS dict with no 'keys' field.""" + key_set = JWKSKeySet.from_jwks_dict({}) + assert key_set.kids == [] + + +class TestJWKSFetcher: + """Tests for JWKSFetcher caching and fetching.""" + + async def test_get_key_fetches_on_first_call(self) -> None: + """Test that the first call to get_key fetches the JWKS endpoint.""" + _, public_key = generate_rsa_key_pair() + jwk = public_key_to_jwk(public_key, kid="key-1") + jwks_response = {"keys": [jwk]} + + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value=jwks_response) + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + mock_session = MagicMock() + mock_session.get = MagicMock(return_value=mock_response) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + with patch("ai.backend.common.jwt.jwks.aiohttp.ClientSession", return_value=mock_session): + fetcher = JWKSFetcher(url="https://example.com/.well-known/jwks.json") + result = await fetcher.get_key("key-1") + + assert isinstance(result, RSAPublicKey) + mock_session.get.assert_called_once_with("https://example.com/.well-known/jwks.json") + + async def test_get_key_uses_cache_within_ttl(self) -> None: + """Test that subsequent calls use the cache within TTL.""" + _, public_key = generate_rsa_key_pair() + jwk = public_key_to_jwk(public_key, kid="key-1") + jwks_response = {"keys": [jwk]} + + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value=jwks_response) + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + mock_session = MagicMock() + mock_session.get = MagicMock(return_value=mock_response) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + with patch("ai.backend.common.jwt.jwks.aiohttp.ClientSession", return_value=mock_session): + fetcher = JWKSFetcher( + url="https://example.com/.well-known/jwks.json", + cache_ttl=300.0, + ) + # First call fetches + await fetcher.get_key("key-1") + # Second call should use cache + await fetcher.get_key("key-1") + + # Should only have fetched once + assert mock_session.get.call_count == 1 + + async def test_get_key_refetches_after_ttl(self) -> None: + """Test that the cache is refreshed after TTL expires.""" + _, public_key = generate_rsa_key_pair() + jwk = public_key_to_jwk(public_key, kid="key-1") + jwks_response = {"keys": [jwk]} + + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value=jwks_response) + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + mock_session = MagicMock() + mock_session.get = MagicMock(return_value=mock_response) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + with patch("ai.backend.common.jwt.jwks.aiohttp.ClientSession", return_value=mock_session): + fetcher = JWKSFetcher( + url="https://example.com/.well-known/jwks.json", + cache_ttl=0.0, # Immediate expiry + ) + await fetcher.get_key("key-1") + # With TTL=0, next call should refetch + await fetcher.get_key("key-1") + + assert mock_session.get.call_count == 2 + + async def test_refresh_forces_fetch(self) -> None: + """Test that refresh() forces a new fetch regardless of TTL.""" + _, public_key = generate_rsa_key_pair() + jwk = public_key_to_jwk(public_key, kid="key-1") + jwks_response = {"keys": [jwk]} + + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value=jwks_response) + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + mock_session = MagicMock() + mock_session.get = MagicMock(return_value=mock_response) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + with patch("ai.backend.common.jwt.jwks.aiohttp.ClientSession", return_value=mock_session): + fetcher = JWKSFetcher( + url="https://example.com/.well-known/jwks.json", + cache_ttl=9999.0, + ) + await fetcher.get_key("key-1") + await fetcher.refresh() + + assert mock_session.get.call_count == 2 + + async def test_fetch_error_raises_jwks_fetch_error(self) -> None: + """Test that HTTP errors raise JWKSFetchError.""" + mock_response = AsyncMock() + mock_response.status = 500 + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + mock_session = MagicMock() + mock_session.get = MagicMock(return_value=mock_response) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + with patch("ai.backend.common.jwt.jwks.aiohttp.ClientSession", return_value=mock_session): + fetcher = JWKSFetcher(url="https://example.com/.well-known/jwks.json") + with pytest.raises(JWKSFetchError): + await fetcher.get_key("key-1") + + async def test_connection_error_raises_jwks_fetch_error(self) -> None: + """Test that connection errors raise JWKSFetchError.""" + mock_session = MagicMock() + mock_session.get = MagicMock(side_effect=ConnectionError("Connection refused")) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + with patch("ai.backend.common.jwt.jwks.aiohttp.ClientSession", return_value=mock_session): + fetcher = JWKSFetcher(url="https://example.com/.well-known/jwks.json") + with pytest.raises(JWKSFetchError): + await fetcher.get_key("key-1") diff --git a/tests/unit/common/jwt/test_jwks_validation.py b/tests/unit/common/jwt/test_jwks_validation.py new file mode 100644 index 00000000000..7722b08747d --- /dev/null +++ b/tests/unit/common/jwt/test_jwks_validation.py @@ -0,0 +1,257 @@ +"""Tests for JWKS-based token validation and key rotation.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta + +import jwt as pyjwt +import pytest + +from ai.backend.common.jwt.config import JWTAlgorithm, JWTConfig +from ai.backend.common.jwt.exceptions import JWKSKeyNotFoundError, JWTDecodeError +from ai.backend.common.jwt.jwks import JWKSKeySet +from ai.backend.common.jwt.keys import generate_rsa_key_pair, public_key_to_jwk +from ai.backend.common.jwt.signer import JWTSigner +from ai.backend.common.jwt.types import JWTClaims, JWTUserContext +from ai.backend.common.jwt.validator import JWTValidator +from ai.backend.common.types import AccessKey + + +@pytest.fixture +def rs256_config() -> JWTConfig: + """Create RS256 JWT configuration.""" + return JWTConfig( + algorithm=JWTAlgorithm.RS256, + token_expiration_seconds=900, + ) + + +@pytest.fixture +def user_context() -> JWTUserContext: + """Create test user context.""" + return JWTUserContext( + access_key=AccessKey("AKIAIOSFODNN7EXAMPLE"), + role="user", + ) + + +class TestJWKSValidation: + """Tests for validate_token_with_jwks.""" + + def test_validate_with_jwks_key_set( + self, + rs256_config: JWTConfig, + user_context: JWTUserContext, + ) -> None: + """Test token validation using a JWKS key set.""" + private_key, public_key = generate_rsa_key_pair() + jwk = public_key_to_jwk(public_key, kid="key-1") + key_set = JWKSKeySet.from_jwks_dict({"keys": [jwk]}) + + signer = JWTSigner(rs256_config) + validator = JWTValidator(rs256_config) + + token = signer.generate_token(user_context, private_key=private_key, kid="key-1") + claims = validator.validate_token_with_jwks(token, key_set) + + assert claims.access_key == user_context.access_key + assert claims.role == user_context.role + + def test_validate_with_jwks_kid_not_found( + self, + rs256_config: JWTConfig, + user_context: JWTUserContext, + ) -> None: + """Test that missing kid in JWKS raises JWKSKeyNotFoundError.""" + private_key, public_key = generate_rsa_key_pair() + jwk = public_key_to_jwk(public_key, kid="key-1") + key_set = JWKSKeySet.from_jwks_dict({"keys": [jwk]}) + + signer = JWTSigner(rs256_config) + validator = JWTValidator(rs256_config) + + token = signer.generate_token(user_context, private_key=private_key, kid="key-unknown") + + with pytest.raises(JWKSKeyNotFoundError): + validator.validate_token_with_jwks(token, key_set) + + def test_validate_with_jwks_no_kid_in_header( + self, + rs256_config: JWTConfig, + user_context: JWTUserContext, + ) -> None: + """Test that token without kid header raises JWTDecodeError.""" + private_key, public_key = generate_rsa_key_pair() + jwk = public_key_to_jwk(public_key, kid="key-1") + key_set = JWKSKeySet.from_jwks_dict({"keys": [jwk]}) + + signer = JWTSigner(rs256_config) + validator = JWTValidator(rs256_config) + + # Generate token without kid + token = signer.generate_token(user_context, private_key=private_key) + + with pytest.raises(JWTDecodeError, match="kid"): + validator.validate_token_with_jwks(token, key_set) + + +class TestKeyRotation: + """Tests for key rotation scenarios.""" + + def test_old_token_still_valid_after_adding_new_key( + self, + rs256_config: JWTConfig, + user_context: JWTUserContext, + ) -> None: + """Test that tokens signed with key A still verify after adding key B.""" + private_key_a, public_key_a = generate_rsa_key_pair() + private_key_b, public_key_b = generate_rsa_key_pair() + + signer = JWTSigner(rs256_config) + validator = JWTValidator(rs256_config) + + # Sign with key A + token_a = signer.generate_token(user_context, private_key=private_key_a, kid="key-a") + + # Create JWKS with both keys (simulating key rotation) + jwk_a = public_key_to_jwk(public_key_a, kid="key-a") + jwk_b = public_key_to_jwk(public_key_b, kid="key-b") + key_set = JWKSKeySet.from_jwks_dict({"keys": [jwk_a, jwk_b]}) + + # Old token should still validate + claims = validator.validate_token_with_jwks(token_a, key_set) + assert claims.access_key == user_context.access_key + + def test_new_token_with_new_key_validates( + self, + rs256_config: JWTConfig, + user_context: JWTUserContext, + ) -> None: + """Test that tokens signed with key B validate with updated JWKS.""" + private_key_a, public_key_a = generate_rsa_key_pair() + private_key_b, public_key_b = generate_rsa_key_pair() + + signer = JWTSigner(rs256_config) + validator = JWTValidator(rs256_config) + + # Sign with key B + token_b = signer.generate_token(user_context, private_key=private_key_b, kid="key-b") + + # JWKS with both keys + jwk_a = public_key_to_jwk(public_key_a, kid="key-a") + jwk_b = public_key_to_jwk(public_key_b, kid="key-b") + key_set = JWKSKeySet.from_jwks_dict({"keys": [jwk_a, jwk_b]}) + + claims = validator.validate_token_with_jwks(token_b, key_set) + assert claims.access_key == user_context.access_key + + def test_token_fails_after_key_removal( + self, + rs256_config: JWTConfig, + user_context: JWTUserContext, + ) -> None: + """Test that tokens fail after their signing key is removed from JWKS.""" + private_key_a, public_key_a = generate_rsa_key_pair() + private_key_b, public_key_b = generate_rsa_key_pair() + + signer = JWTSigner(rs256_config) + validator = JWTValidator(rs256_config) + + # Sign with key A + token_a = signer.generate_token(user_context, private_key=private_key_a, kid="key-a") + + # JWKS with only key B (key A removed) + jwk_b = public_key_to_jwk(public_key_b, kid="key-b") + key_set = JWKSKeySet.from_jwks_dict({"keys": [jwk_b]}) + + with pytest.raises(JWKSKeyNotFoundError): + validator.validate_token_with_jwks(token_a, key_set) + + +class TestOAuth2Claims: + """Tests for OAuth2-style claims round-trip.""" + + def test_oauth2_claims_roundtrip_rs256( + self, + rs256_config: JWTConfig, + ) -> None: + """Test that OAuth2 claims (iss, aud, sub, scope, jti) survive round-trip.""" + private_key, public_key = generate_rsa_key_pair() + + now = datetime.now(UTC) + payload = { + "exp": int((now + timedelta(seconds=900)).timestamp()), + "iat": int(now.timestamp()), + "access_key": "AKIAIOSFODNN7EXAMPLE", + "role": "user", + "iss": "https://auth.example.com", + "aud": "https://api.example.com", + "sub": "550e8400-e29b-41d4-a716-446655440000", + "scope": "read write admin", + "jti": "unique-token-id-123", + } + + token = pyjwt.encode( + payload, + private_key, + algorithm="RS256", + headers={"kid": "key-1"}, + ) + + validator = JWTValidator(rs256_config) + claims = validator.validate_token(token, public_key=public_key) + + assert claims.iss == "https://auth.example.com" + assert claims.aud == "https://api.example.com" + assert claims.sub == "550e8400-e29b-41d4-a716-446655440000" + assert claims.scope == "read write admin" + assert claims.jti == "unique-token-id-123" + + def test_oauth2_claims_optional_fields( + self, + rs256_config: JWTConfig, + ) -> None: + """Test that OAuth2 claims are optional (backward compat).""" + private_key, public_key = generate_rsa_key_pair() + + now = datetime.now(UTC) + payload = { + "exp": int((now + timedelta(seconds=900)).timestamp()), + "iat": int(now.timestamp()), + "access_key": "AKIAIOSFODNN7EXAMPLE", + "role": "user", + } + + token = pyjwt.encode( + payload, + private_key, + algorithm="RS256", + ) + + validator = JWTValidator(rs256_config) + claims = validator.validate_token(token, public_key=public_key) + + assert claims.iss is None + assert claims.aud is None + assert claims.sub is None + assert claims.scope is None + assert claims.jti is None + + def test_oauth2_claims_to_dict_includes_set_fields(self) -> None: + """Test that to_dict includes OAuth2 fields when set.""" + now = datetime.now(UTC) + claims = JWTClaims( + exp=now, + iat=now, + access_key=AccessKey("AKIAIOSFODNN7EXAMPLE"), + role="user", + iss="https://auth.example.com", + sub="user-id-123", + scope="read", + ) + d = claims.to_dict() + assert d["iss"] == "https://auth.example.com" + assert d["sub"] == "user-id-123" + assert d["scope"] == "read" + assert "aud" not in d + assert "jti" not in d diff --git a/tests/unit/common/jwt/test_keys.py b/tests/unit/common/jwt/test_keys.py new file mode 100644 index 00000000000..182063bdfb7 --- /dev/null +++ b/tests/unit/common/jwt/test_keys.py @@ -0,0 +1,142 @@ +"""Tests for RSA key management utilities.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest +from cryptography.hazmat.primitives.asymmetric.rsa import ( + RSAPrivateKey, + RSAPublicKey, +) + +from ai.backend.common.jwt.keys import ( + generate_rsa_key_pair, + load_private_key, + load_public_key, + private_key_to_pem, + public_key_to_jwk, + public_key_to_pem, +) + + +class TestGenerateRSAKeyPair: + """Tests for generate_rsa_key_pair.""" + + def test_generates_valid_key_pair(self) -> None: + """Test that a valid RSA key pair is generated.""" + private_key, public_key = generate_rsa_key_pair() + assert isinstance(private_key, RSAPrivateKey) + assert isinstance(public_key, RSAPublicKey) + + def test_default_key_size_is_2048(self) -> None: + """Test that the default key size is 2048 bits.""" + private_key, _ = generate_rsa_key_pair() + assert private_key.key_size == 2048 + + def test_custom_key_size(self) -> None: + """Test generating keys with a custom key size.""" + private_key, _ = generate_rsa_key_pair(key_size=4096) + assert private_key.key_size == 4096 + + def test_public_key_matches_private_key(self) -> None: + """Test that the public key corresponds to the private key.""" + private_key, public_key = generate_rsa_key_pair() + derived_public_key = private_key.public_key() + assert public_key.public_numbers().n == derived_public_key.public_numbers().n + assert public_key.public_numbers().e == derived_public_key.public_numbers().e + + +class TestPEMSerialization: + """Tests for PEM serialization and loading.""" + + def test_private_key_pem_roundtrip(self, tmp_path: Path) -> None: + """Test private key serialization and loading roundtrip.""" + private_key, _ = generate_rsa_key_pair() + pem_bytes = private_key_to_pem(private_key) + + key_file = tmp_path / "private.pem" + key_file.write_bytes(pem_bytes) + + loaded_key = load_private_key(key_file) + assert isinstance(loaded_key, RSAPrivateKey) + assert ( + loaded_key.private_numbers().public_numbers.n + == private_key.private_numbers().public_numbers.n + ) + + def test_public_key_pem_roundtrip(self, tmp_path: Path) -> None: + """Test public key serialization and loading roundtrip.""" + _, public_key = generate_rsa_key_pair() + pem_bytes = public_key_to_pem(public_key) + + key_file = tmp_path / "public.pem" + key_file.write_bytes(pem_bytes) + + loaded_key = load_public_key(key_file) + assert isinstance(loaded_key, RSAPublicKey) + assert loaded_key.public_numbers().n == public_key.public_numbers().n + + def test_private_key_pem_starts_with_header(self) -> None: + """Test that PEM-encoded private key has correct header.""" + private_key, _ = generate_rsa_key_pair() + pem_bytes = private_key_to_pem(private_key) + assert pem_bytes.startswith(b"-----BEGIN PRIVATE KEY-----") + + def test_public_key_pem_starts_with_header(self) -> None: + """Test that PEM-encoded public key has correct header.""" + _, public_key = generate_rsa_key_pair() + pem_bytes = public_key_to_pem(public_key) + assert pem_bytes.startswith(b"-----BEGIN PUBLIC KEY-----") + + +class TestPublicKeyToJWK: + """Tests for public_key_to_jwk conversion.""" + + def test_jwk_has_required_fields(self) -> None: + """Test that JWK output contains all required fields.""" + _, public_key = generate_rsa_key_pair() + jwk = public_key_to_jwk(public_key, kid="test-key-1") + + assert jwk["kty"] == "RSA" + assert jwk["kid"] == "test-key-1" + assert jwk["use"] == "sig" + assert jwk["alg"] == "RS256" + assert "n" in jwk + assert "e" in jwk + + def test_jwk_n_and_e_are_strings(self) -> None: + """Test that n and e values are base64url-encoded strings.""" + _, public_key = generate_rsa_key_pair() + jwk = public_key_to_jwk(public_key, kid="test-key-1") + + assert isinstance(jwk["n"], str) + assert isinstance(jwk["e"], str) + # Base64url should not contain padding + assert "=" not in jwk["n"] + assert "=" not in jwk["e"] + + def test_jwk_different_kid(self) -> None: + """Test that different kid values are correctly set.""" + _, public_key = generate_rsa_key_pair() + jwk_a = public_key_to_jwk(public_key, kid="key-a") + jwk_b = public_key_to_jwk(public_key, kid="key-b") + + assert jwk_a["kid"] == "key-a" + assert jwk_b["kid"] == "key-b" + # Same key, so n and e should be identical + assert jwk_a["n"] == jwk_b["n"] + + +class TestLoadKeyErrors: + """Tests for key loading error cases.""" + + def test_load_private_key_nonexistent_file(self, tmp_path: Path) -> None: + """Test that loading from a nonexistent file raises an error.""" + with pytest.raises(FileNotFoundError): + load_private_key(tmp_path / "nonexistent.pem") + + def test_load_public_key_nonexistent_file(self, tmp_path: Path) -> None: + """Test that loading from a nonexistent file raises an error.""" + with pytest.raises(FileNotFoundError): + load_public_key(tmp_path / "nonexistent.pem") diff --git a/tests/unit/common/jwt/test_rs256.py b/tests/unit/common/jwt/test_rs256.py new file mode 100644 index 00000000000..beeabc9d872 --- /dev/null +++ b/tests/unit/common/jwt/test_rs256.py @@ -0,0 +1,174 @@ +"""Tests for RS256 signing and validation.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta + +import jwt as pyjwt +import pytest +from cryptography.hazmat.primitives.asymmetric.rsa import ( + RSAPrivateKey, + RSAPublicKey, +) + +from ai.backend.common.jwt.config import JWTAlgorithm, JWTConfig +from ai.backend.common.jwt.exceptions import ( + JWTExpiredError, + JWTInvalidSignatureError, +) +from ai.backend.common.jwt.keys import generate_rsa_key_pair +from ai.backend.common.jwt.signer import JWTSigner +from ai.backend.common.jwt.types import JWTUserContext +from ai.backend.common.jwt.validator import JWTValidator +from ai.backend.common.types import AccessKey + + +@pytest.fixture +def rsa_key_pair() -> tuple[RSAPrivateKey, RSAPublicKey]: + """Generate an RSA key pair for testing.""" + return generate_rsa_key_pair() + + +@pytest.fixture +def rs256_config() -> JWTConfig: + """Create RS256 JWT configuration.""" + return JWTConfig( + algorithm=JWTAlgorithm.RS256, + token_expiration_seconds=900, + ) + + +@pytest.fixture +def rs256_signer(rs256_config: JWTConfig) -> JWTSigner: + """Create RS256 JWT signer.""" + return JWTSigner(rs256_config) + + +@pytest.fixture +def rs256_validator(rs256_config: JWTConfig) -> JWTValidator: + """Create RS256 JWT validator.""" + return JWTValidator(rs256_config) + + +@pytest.fixture +def user_context() -> JWTUserContext: + """Create test user context.""" + return JWTUserContext( + access_key=AccessKey("AKIAIOSFODNN7EXAMPLE"), + role="user", + ) + + +class TestRS256SignAndVerify: + """Tests for RS256 sign and verify round-trip.""" + + def test_rs256_roundtrip( + self, + rs256_signer: JWTSigner, + rs256_validator: JWTValidator, + rsa_key_pair: tuple[RSAPrivateKey, RSAPublicKey], + user_context: JWTUserContext, + ) -> None: + """Test RS256 sign -> verify round-trip.""" + private_key, public_key = rsa_key_pair + token = rs256_signer.generate_token(user_context, private_key=private_key, kid="test-key-1") + claims = rs256_validator.validate_token(token, public_key=public_key) + + assert claims.access_key == user_context.access_key + assert claims.role == user_context.role + + def test_rs256_token_has_kid_header( + self, + rs256_signer: JWTSigner, + rsa_key_pair: tuple[RSAPrivateKey, RSAPublicKey], + user_context: JWTUserContext, + ) -> None: + """Test that RS256 token includes kid in header.""" + private_key, _ = rsa_key_pair + token = rs256_signer.generate_token(user_context, private_key=private_key, kid="my-kid") + header = pyjwt.get_unverified_header(token) + assert header["kid"] == "my-kid" + assert header["alg"] == "RS256" + + def test_rs256_token_without_kid( + self, + rs256_signer: JWTSigner, + rs256_validator: JWTValidator, + rsa_key_pair: tuple[RSAPrivateKey, RSAPublicKey], + user_context: JWTUserContext, + ) -> None: + """Test RS256 token without kid still works for direct validation.""" + private_key, public_key = rsa_key_pair + token = rs256_signer.generate_token(user_context, private_key=private_key) + claims = rs256_validator.validate_token(token, public_key=public_key) + assert claims.access_key == user_context.access_key + + def test_rs256_with_wrong_key_fails( + self, + rs256_signer: JWTSigner, + rs256_validator: JWTValidator, + rsa_key_pair: tuple[RSAPrivateKey, RSAPublicKey], + user_context: JWTUserContext, + ) -> None: + """Test that RS256 verification fails with a different key pair.""" + private_key, _ = rsa_key_pair + _, wrong_public_key = generate_rsa_key_pair() + + token = rs256_signer.generate_token(user_context, private_key=private_key, kid="key-1") + + with pytest.raises(JWTInvalidSignatureError): + rs256_validator.validate_token(token, public_key=wrong_public_key) + + def test_rs256_expired_token( + self, + rs256_config: JWTConfig, + rs256_validator: JWTValidator, + rsa_key_pair: tuple[RSAPrivateKey, RSAPublicKey], + user_context: JWTUserContext, + ) -> None: + """Test that expired RS256 token raises JWTExpiredError.""" + private_key, public_key = rsa_key_pair + past_time = datetime.now(UTC) - timedelta(hours=1) + + payload = { + "exp": int((past_time + timedelta(seconds=900)).timestamp()), + "iat": int(past_time.timestamp()), + "access_key": str(user_context.access_key), + "role": user_context.role, + } + + expired_token = pyjwt.encode( + payload, + private_key, + algorithm="RS256", + headers={"kid": "key-1"}, + ) + + with pytest.raises(JWTExpiredError): + rs256_validator.validate_token(expired_token, public_key=public_key) + + +class TestHS256BackwardCompat: + """Tests that HS256 still works after RS256 additions.""" + + def test_hs256_roundtrip_still_works(self) -> None: + """Test that HS256 sign/verify still works.""" + config = JWTConfig(algorithm=JWTAlgorithm.HS256, token_expiration_seconds=900) + signer = JWTSigner(config) + validator = JWTValidator(config) + context = JWTUserContext( + access_key=AccessKey("AKIAIOSFODNN7EXAMPLE"), + role="admin", + ) + secret = "test-secret-key-at-least-32-bytes-long" + + token = signer.generate_token(context, secret) + claims = validator.validate_token(token, secret) + + assert claims.access_key == context.access_key + assert claims.role == context.role + + def test_hs256_default_config(self) -> None: + """Test that default config is HS256.""" + config = JWTConfig() + assert config.algorithm == JWTAlgorithm.HS256 diff --git a/tests/unit/common/jwt/test_signer.py b/tests/unit/common/jwt/test_signer.py index 5bfe1749015..7a7b38e2d6f 100644 --- a/tests/unit/common/jwt/test_signer.py +++ b/tests/unit/common/jwt/test_signer.py @@ -8,7 +8,7 @@ import jwt as pyjwt import pytest -from ai.backend.common.jwt.config import JWTConfig +from ai.backend.common.jwt.config import JWTAlgorithm, JWTConfig from ai.backend.common.jwt.signer import JWTSigner from ai.backend.common.jwt.types import JWTUserContext from ai.backend.common.types import AccessKey @@ -18,7 +18,7 @@ def jwt_config() -> JWTConfig: """Create test JWT configuration.""" return JWTConfig( - algorithm="HS256", + algorithm=JWTAlgorithm.HS256, token_expiration_seconds=900, ) diff --git a/tests/unit/common/jwt/test_validator.py b/tests/unit/common/jwt/test_validator.py index d2c7f0119c8..f375f63e9c0 100644 --- a/tests/unit/common/jwt/test_validator.py +++ b/tests/unit/common/jwt/test_validator.py @@ -7,7 +7,7 @@ import jwt as pyjwt import pytest -from ai.backend.common.jwt.config import JWTConfig +from ai.backend.common.jwt.config import JWTAlgorithm, JWTConfig from ai.backend.common.jwt.exceptions import ( JWTDecodeError, JWTExpiredError, @@ -24,7 +24,7 @@ def jwt_config() -> JWTConfig: """Create test JWT configuration.""" return JWTConfig( - algorithm="HS256", + algorithm=JWTAlgorithm.HS256, token_expiration_seconds=900, )