diff --git a/changes/6410.feature.md b/changes/6410.feature.md new file mode 100644 index 00000000000..b0f9212e2e5 --- /dev/null +++ b/changes/6410.feature.md @@ -0,0 +1 @@ +Implement JWT authentication module for GraphQL Federation \ No newline at end of file diff --git a/src/ai/backend/common/jwt/__init__.py b/src/ai/backend/common/jwt/__init__.py new file mode 100644 index 00000000000..c5fb3536d85 --- /dev/null +++ b/src/ai/backend/common/jwt/__init__.py @@ -0,0 +1,74 @@ +""" +JWT authentication module for GraphQL Federation. + +This module provides JWT-based authentication for GraphQL requests going through +Hive Router. It uses the X-BackendAI-Token custom header to avoid conflicts with +existing Bearer token usage in appproxy. + +Key components: +- JWTSigner: Generates JWT tokens from authenticated user context (webserver) +- JWTValidator: Validates JWT tokens and extracts user claims (manager) +- JWTConfig: Configuration for JWT authentication +- JWTClaims: Dataclass representing JWT payload claims + +Example usage in webserver: + from ai.backend.common.jwt import JWTSigner, JWTConfig, JWTUserContext + + config = JWTConfig(secret_key=os.environ["JWT_SECRET_KEY"]) + signer = JWTSigner(config) + + user_context = JWTUserContext( + user_id=user_uuid, + access_key=access_key, + role="user", + domain_name="default", + is_admin=False, + is_superadmin=False, + ) + token = signer.generate_token(user_context) + + # Add to request headers + headers["X-BackendAI-Token"] = token + +Example usage in manager: + from ai.backend.common.jwt import JWTValidator, JWTConfig + + config = JWTConfig(secret_key=os.environ["JWT_SECRET_KEY"]) + validator = JWTValidator(config) + + token = request.headers.get("X-BackendAI-Token") + claims = validator.validate_token(token) + + # Use claims for authentication + user_id = claims.sub + access_key = claims.access_key +""" + +from .config import JWTConfig +from .exceptions import ( + JWTDecodeError, + JWTError, + JWTExpiredError, + JWTInvalidClaimsError, + JWTInvalidSignatureError, +) +from .signer import JWTSigner +from .types import JWTClaims, JWTUserContext +from .validator import JWTValidator + +__all__ = [ + # Configuration + "JWTConfig", + # Types + "JWTClaims", + "JWTUserContext", + # Core classes + "JWTSigner", + "JWTValidator", + # Exceptions + "JWTError", + "JWTExpiredError", + "JWTInvalidSignatureError", + "JWTInvalidClaimsError", + "JWTDecodeError", +] diff --git a/src/ai/backend/common/jwt/config.py b/src/ai/backend/common/jwt/config.py new file mode 100644 index 00000000000..0667e71111b --- /dev/null +++ b/src/ai/backend/common/jwt/config.py @@ -0,0 +1,68 @@ +"""JWT authentication configuration for GraphQL Federation.""" + +from __future__ import annotations + +from datetime import timedelta + +from pydantic import Field + +from ai.backend.common.config import BaseConfigSchema + + +class JWTConfig(BaseConfigSchema): + """ + Configuration for JWT-based authentication in GraphQL Federation. + + This configuration must be consistent between webserver (which generates tokens) + and manager (which validates tokens). The secret_key must be kept secure and + should be the same on both sides. + + Attributes: + enabled: Whether JWT authentication is enabled + secret_key: Secret key for HS256 signing and verification + algorithm: JWT signing algorithm (must be HS256) + token_expiration_seconds: Token validity duration in seconds + issuer: JWT issuer claim value for validation + header_name: HTTP header name for JWT token transmission + """ + + enabled: bool = Field( + default=True, + description="Enable JWT authentication for GraphQL Federation requests", + ) + + secret_key: str = Field( + description="Secret key for HS256 signing and verification. " + "MUST be the same between webserver and manager. " + "Should be at least 32 bytes of random data.", + ) + + algorithm: str = Field( + default="HS256", + description="JWT signing algorithm (only HS256 is supported)", + ) + + token_expiration_seconds: int = Field( + default=900, # 15 minutes + description="JWT token expiration time in seconds (default: 900 = 15 minutes)", + ) + + issuer: str = Field( + default="backend.ai-webserver", + description="JWT issuer claim value for GraphQL Federation tokens", + ) + + header_name: str = Field( + default="X-BackendAI-Token", + description="Custom HTTP header name for JWT token transmission", + ) + + @property + def token_expiration(self) -> timedelta: + """ + Get token expiration as a timedelta object. + + Returns: + Token expiration duration as timedelta + """ + return timedelta(seconds=self.token_expiration_seconds) diff --git a/src/ai/backend/common/jwt/exceptions.py b/src/ai/backend/common/jwt/exceptions.py new file mode 100644 index 00000000000..a0c7e55fa82 --- /dev/null +++ b/src/ai/backend/common/jwt/exceptions.py @@ -0,0 +1,86 @@ +"""JWT authentication exceptions for GraphQL Federation.""" + +from __future__ import annotations + +from aiohttp import web + +from ai.backend.common.exception import ( + BackendAIError, + ErrorCode, + ErrorDetail, + ErrorDomain, + ErrorOperation, +) + + +class JWTError(BackendAIError): + """ + Base exception for JWT-related errors in GraphQL Federation authentication. + + All JWT-specific exceptions inherit from this base class. + """ + + @classmethod + def error_code(cls) -> ErrorCode: + return ErrorCode( + domain=ErrorDomain.USER, + operation=ErrorOperation.AUTH, + error_detail=ErrorDetail.UNAUTHORIZED, + ) + + +class JWTExpiredError(JWTError, web.HTTPUnauthorized): + """ + JWT token has expired. + + Raised when attempting to use a token past its expiration time. + """ + + error_type = "https://api.backend.ai/probs/jwt-expired" + error_title = "JWT token has expired." + + @classmethod + def error_code(cls) -> ErrorCode: + return ErrorCode( + domain=ErrorDomain.USER, + operation=ErrorOperation.AUTH, + error_detail=ErrorDetail.DATA_EXPIRED, + ) + + +class JWTInvalidSignatureError(JWTError, web.HTTPUnauthorized): + """ + JWT signature verification failed. + + Raised when the token's signature doesn't match the expected signature, + indicating the token may have been tampered with or was signed with + a different secret key. + """ + + error_type = "https://api.backend.ai/probs/jwt-invalid-signature" + error_title = "JWT signature verification failed." + + +class JWTInvalidClaimsError(JWTError, web.HTTPUnauthorized): + """ + JWT claims are missing or invalid. + + Raised when required claims are missing from the token or when + claim values don't meet validation requirements (e.g., invalid role, + wrong issuer). + """ + + error_type = "https://api.backend.ai/probs/jwt-invalid-claims" + error_title = "JWT claims are invalid." + + +class JWTDecodeError(JWTError, web.HTTPUnauthorized): + """ + Failed to decode JWT token. + + Raised when the token cannot be decoded, typically due to malformed + token structure or encoding issues. + """ + + error_type = "https://api.backend.ai/probs/jwt-decode-error" + error_title = "Failed to decode JWT token." diff --git a/src/ai/backend/common/jwt/signer.py b/src/ai/backend/common/jwt/signer.py new file mode 100644 index 00000000000..1c17e522a44 --- /dev/null +++ b/src/ai/backend/common/jwt/signer.py @@ -0,0 +1,87 @@ +"""JWT token signer for generating authentication tokens.""" + +from __future__ import annotations + +from datetime import datetime, timezone + +import jwt + +from .config import JWTConfig +from .exceptions import JWTError +from .types import JWTClaims, JWTUserContext + + +class JWTSigner: + """ + JWT token generator for GraphQL Federation authentication. + + This class is used by the webserver to generate JWT tokens after successful + HMAC authentication. The generated tokens are then forwarded to the manager + via Hive Router using the X-BackendAI-Token header. + + Usage: + from ai.backend.common.jwt import JWTSigner, JWTConfig, JWTUserContext + + config = JWTConfig(secret_key="your-secret-key") + signer = JWTSigner(config) + + user_context = JWTUserContext( + user_id=user_uuid, + access_key=access_key, + role="user", + domain_name="default", + is_admin=False, + is_superadmin=False, + ) + token = signer.generate_token(user_context) + """ + + _config: JWTConfig + + def __init__(self, config: JWTConfig) -> None: + """ + Initialize JWT signer with configuration. + + Args: + config: JWT configuration containing secret key and other settings + """ + self._config = config + + def generate_token(self, user_context: JWTUserContext) -> str: + """ + Generate a JWT token from authenticated user context. + + This method creates a JWT token containing all necessary user authentication + information. The token is signed using HS256 with the configured secret key. + + Args: + user_context: User context data containing authentication information + + Returns: + Encoded JWT token string + + Raises: + JWTError: If token generation fails + """ + now = datetime.now(timezone.utc) + + claims = JWTClaims( + sub=user_context.user_id, + exp=now + self._config.token_expiration, + iat=now, + iss=self._config.issuer, + access_key=user_context.access_key, + role=user_context.role, + domain_name=user_context.domain_name, + is_admin=user_context.is_admin, + is_superadmin=user_context.is_superadmin, + ) + + try: + return jwt.encode( + claims.to_dict(), + self._config.secret_key, + algorithm=self._config.algorithm, + ) + except Exception as e: + raise JWTError(f"JWT generation failed: {e}") from e diff --git a/src/ai/backend/common/jwt/types.py b/src/ai/backend/common/jwt/types.py new file mode 100644 index 00000000000..f33e861d92d --- /dev/null +++ b/src/ai/backend/common/jwt/types.py @@ -0,0 +1,123 @@ +"""JWT token types and claims for GraphQL Federation authentication.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any +from uuid import UUID + +from ai.backend.common.types import AccessKey + + +@dataclass(frozen=True) +class JWTUserContext: + """ + User context data for JWT token generation. + + This dataclass encapsulates all user information needed to generate + a JWT token. It provides a structured way to pass user data from + the authentication layer to the JWT signer. + + Attributes: + user_id: User's UUID + access_key: User's access key + role: User's role ("admin", "user", or "superadmin") + domain_name: User's domain name + is_admin: Whether the user has admin privileges + is_superadmin: Whether the user has superadmin privileges + """ + + user_id: UUID + access_key: AccessKey + role: str + domain_name: str + is_admin: bool + is_superadmin: bool + + +@dataclass(frozen=True) +class JWTClaims: + """ + JWT token payload for GraphQL Federation authentication. + + This dataclass represents the claims contained in a JWT token used for + authenticating GraphQL requests through Hive Router. The token is distinguished + from other JWT uses (like appproxy) by the 'iss' (issuer) claim. + + Attributes: + sub: Subject - User UUID + exp: Expiration time (UTC) + iat: Issued at time (UTC) + iss: Issuer identifier (e.g., "backend.ai-webserver") + access_key: User's access key + role: User role ("admin", "user", or "superadmin") + domain_name: User's domain + is_admin: Whether user has admin privileges + is_superadmin: Whether user has superadmin privileges + """ + + # Standard JWT claims (RFC 7519) + sub: UUID + exp: datetime + iat: datetime + iss: str + + # Backend.AI specific claims + access_key: AccessKey + role: str + domain_name: str + is_admin: bool + is_superadmin: bool + + def to_dict(self) -> dict[str, Any]: + """ + Convert JWTClaims to a dictionary suitable for JWT payload. + + Datetime objects are converted to Unix timestamps (integers) as required + by the JWT standard. + + Returns: + Dictionary representation of claims with timestamps as integers. + """ + return { + "sub": str(self.sub), + "exp": int(self.exp.timestamp()), + "iat": int(self.iat.timestamp()), + "iss": self.iss, + "access_key": str(self.access_key), + "role": self.role, + "domain_name": self.domain_name, + "is_admin": self.is_admin, + "is_superadmin": self.is_superadmin, + } + + @classmethod + def from_dict(cls, payload: dict[str, Any]) -> JWTClaims: + """ + Parse JWT payload dictionary to JWTClaims. + + Converts Unix timestamps back to datetime objects and validates + the structure of the payload. + + Args: + payload: Dictionary containing JWT claims + + Returns: + JWTClaims instance + + Raises: + KeyError: If required claims are missing + ValueError: If claim values are invalid + """ + return cls( + sub=UUID(payload["sub"]), + exp=datetime.fromtimestamp(payload["exp"], tz=timezone.utc), + iat=datetime.fromtimestamp(payload["iat"], tz=timezone.utc), + iss=payload["iss"], + access_key=AccessKey(payload["access_key"]), + role=payload["role"], + domain_name=payload["domain_name"], + is_admin=payload["is_admin"], + is_superadmin=payload["is_superadmin"], + ) diff --git a/src/ai/backend/common/jwt/validator.py b/src/ai/backend/common/jwt/validator.py new file mode 100644 index 00000000000..19c293633eb --- /dev/null +++ b/src/ai/backend/common/jwt/validator.py @@ -0,0 +1,128 @@ +"""JWT token validator for verifying authentication tokens.""" + +from __future__ import annotations + +import jwt +from jwt.exceptions import ( + DecodeError, + ExpiredSignatureError, + InvalidIssuerError, + InvalidSignatureError, +) + +from .config import JWTConfig +from .exceptions import ( + JWTDecodeError, + JWTExpiredError, + JWTInvalidClaimsError, + JWTInvalidSignatureError, +) +from .types import JWTClaims + + +class JWTValidator: + """ + JWT token validator for GraphQL Federation authentication. + + This class is used by the manager to validate JWT tokens received from + Hive Router via the X-BackendAI-Token header. It verifies the token's + signature, expiration, and claims. + + Usage: + config = JWTConfig(secret_key="your-secret-key") + validator = JWTValidator(config) + claims = validator.validate_token(token_string) + """ + + _config: JWTConfig + + def __init__(self, config: JWTConfig) -> None: + """ + Initialize JWT validator with configuration. + + Args: + config: JWT configuration containing secret key and validation settings + """ + self._config = config + + def validate_token(self, token: str) -> JWTClaims: + """ + Validate JWT token and extract claims. + + This method performs comprehensive validation: + 1. Verifies the token signature using the configured secret key + 2. Checks token expiration + 3. Validates the issuer claim + 4. Ensures all required claims are present and valid + + Args: + token: Encoded JWT token string + + Returns: + JWTClaims object containing validated user information + + Raises: + JWTExpiredError: If the token has expired + JWTInvalidSignatureError: If signature verification fails + JWTInvalidClaimsError: If claims are missing or invalid + JWTDecodeError: If the token cannot be decoded + """ + try: + # Decode and verify token + payload = jwt.decode( + token, + self._config.secret_key, + algorithms=[self._config.algorithm], + issuer=self._config.issuer, + options={ + "verify_signature": True, + "verify_exp": True, + "verify_iat": True, + "verify_iss": True, + }, + ) + + # Parse claims from payload + claims = JWTClaims.from_dict(payload) + + # Validate claim values + self._validate_claims(claims) + + return claims + + except ExpiredSignatureError as e: + raise JWTExpiredError("JWT token has expired") from e + except InvalidSignatureError as e: + raise JWTInvalidSignatureError("JWT signature verification failed") from e + except InvalidIssuerError as e: + raise JWTInvalidClaimsError(f"Invalid issuer: {e}") from e + except (KeyError, ValueError, TypeError) as e: + raise JWTInvalidClaimsError(f"JWT claims are invalid: {e}") from e + except DecodeError as e: + raise JWTDecodeError(f"Failed to decode JWT token: {e}") from e + + def _validate_claims(self, claims: JWTClaims) -> None: + """ + Validate claim values meet requirements. + + Ensures that role is one of the expected values and that the issuer + matches the configured value. + + Args: + claims: Parsed JWT claims to validate + + Raises: + JWTInvalidClaimsError: If validation fails + """ + # Validate role is one of expected values + valid_roles = {"admin", "user", "superadmin"} + if claims.role not in valid_roles: + raise JWTInvalidClaimsError( + f"Invalid role: {claims.role}. Must be one of {valid_roles}" + ) + + # Validate issuer matches expected value + if claims.iss != self._config.issuer: + raise JWTInvalidClaimsError( + f"Invalid issuer: expected '{self._config.issuer}', got '{claims.iss}'" + ) diff --git a/tests/common/jwt/BUILD b/tests/common/jwt/BUILD new file mode 100644 index 00000000000..dabf212d7e7 --- /dev/null +++ b/tests/common/jwt/BUILD @@ -0,0 +1 @@ +python_tests() diff --git a/tests/common/jwt/test_signer.py b/tests/common/jwt/test_signer.py new file mode 100644 index 00000000000..d106912fbe5 --- /dev/null +++ b/tests/common/jwt/test_signer.py @@ -0,0 +1,248 @@ +"""Tests for JWT signer.""" + +from __future__ import annotations + +import time +from datetime import datetime, timedelta, timezone +from uuid import uuid4 + +import jwt as pyjwt +import pytest + +from ai.backend.common.jwt.config import JWTConfig +from ai.backend.common.jwt.signer import JWTSigner +from ai.backend.common.jwt.types import JWTUserContext +from ai.backend.common.types import AccessKey + + +@pytest.fixture +def jwt_config() -> JWTConfig: + """Create test JWT configuration.""" + return JWTConfig( + secret_key="test-secret-key-at-least-32-bytes-long", + algorithm="HS256", + token_expiration_seconds=900, + issuer="backend.ai-webserver", + ) + + +@pytest.fixture +def jwt_signer(jwt_config: JWTConfig) -> JWTSigner: + """Create JWT signer instance.""" + return JWTSigner(jwt_config) + + +@pytest.fixture +def user_context() -> JWTUserContext: + """Create test user context.""" + return JWTUserContext( + user_id=uuid4(), + access_key=AccessKey("AKIAIOSFODNN7EXAMPLE"), + role="user", + domain_name="default", + is_admin=False, + is_superadmin=False, + ) + + +def test_generate_token_creates_valid_jwt( + jwt_signer: JWTSigner, + jwt_config: JWTConfig, + user_context: JWTUserContext, +) -> None: + """Test that generate_token creates a valid JWT token.""" + token = jwt_signer.generate_token(user_context) + + # Verify token is a string + assert isinstance(token, str) + assert len(token) > 0 + + # Decode without verification to check structure + decoded = pyjwt.decode( + token, + options={"verify_signature": False}, + ) + + # Verify claims are present + assert "sub" in decoded + assert "exp" in decoded + assert "iat" in decoded + assert "iss" in decoded + assert "access_key" in decoded + assert "role" in decoded + assert "domain_name" in decoded + assert "is_admin" in decoded + assert "is_superadmin" in decoded + + +def test_generate_token_includes_user_data( + jwt_signer: JWTSigner, + jwt_config: JWTConfig, + user_context: JWTUserContext, +) -> None: + """Test that generated token includes all user context data.""" + token = jwt_signer.generate_token(user_context) + + decoded = pyjwt.decode( + token, + jwt_config.secret_key, + algorithms=[jwt_config.algorithm], + issuer=jwt_config.issuer, + ) + + assert decoded["sub"] == str(user_context.user_id) + assert decoded["access_key"] == str(user_context.access_key) + assert decoded["role"] == user_context.role + assert decoded["domain_name"] == user_context.domain_name + assert decoded["is_admin"] == user_context.is_admin + assert decoded["is_superadmin"] == user_context.is_superadmin + assert decoded["iss"] == jwt_config.issuer + + +def test_generate_token_sets_expiration( + jwt_signer: JWTSigner, + jwt_config: JWTConfig, + user_context: JWTUserContext, +) -> None: + """Test that generated token has correct expiration time.""" + before_generation = datetime.now(timezone.utc) + token = jwt_signer.generate_token(user_context) + after_generation = datetime.now(timezone.utc) + + decoded = pyjwt.decode( + token, + jwt_config.secret_key, + algorithms=[jwt_config.algorithm], + ) + + exp_time = datetime.fromtimestamp(decoded["exp"], tz=timezone.utc) + expected_min = before_generation + jwt_config.token_expiration - timedelta(seconds=2) + expected_max = after_generation + jwt_config.token_expiration + timedelta(seconds=2) + + # Expiration should be close to configured time (with 2 second margin due to timestamp precision) + assert expected_min <= exp_time <= expected_max + + +def test_generate_token_sets_issued_at( + jwt_signer: JWTSigner, + jwt_config: JWTConfig, + user_context: JWTUserContext, +) -> None: + """Test that generated token has correct issued-at time.""" + before_generation = datetime.now(timezone.utc) + token = jwt_signer.generate_token(user_context) + after_generation = datetime.now(timezone.utc) + + decoded = pyjwt.decode( + token, + jwt_config.secret_key, + algorithms=[jwt_config.algorithm], + ) + + iat_time = datetime.fromtimestamp(decoded["iat"], tz=timezone.utc) + + # Add tolerance for timestamp precision (JWT uses seconds, not microseconds) + before_with_margin = before_generation - timedelta(seconds=1) + after_with_margin = after_generation + timedelta(seconds=1) + + # Issued-at should be within test execution time (with 1 second margin) + assert before_with_margin <= iat_time <= after_with_margin + + +def test_generate_token_with_admin_user( + jwt_signer: JWTSigner, + jwt_config: JWTConfig, +) -> None: + """Test token generation for admin user.""" + admin_context = JWTUserContext( + user_id=uuid4(), + access_key=AccessKey("AKIAADMIN123456789"), + role="admin", + domain_name="admin-domain", + is_admin=True, + is_superadmin=False, + ) + + token = jwt_signer.generate_token(admin_context) + + decoded = pyjwt.decode( + token, + jwt_config.secret_key, + algorithms=[jwt_config.algorithm], + ) + + assert decoded["role"] == "admin" + assert decoded["is_admin"] is True + assert decoded["is_superadmin"] is False + + +def test_generate_token_with_superadmin_user( + jwt_signer: JWTSigner, + jwt_config: JWTConfig, +) -> None: + """Test token generation for superadmin user.""" + superadmin_context = JWTUserContext( + user_id=uuid4(), + access_key=AccessKey("AKIASUPERADMIN123456"), + role="superadmin", + domain_name="super-domain", + is_admin=True, + is_superadmin=True, + ) + + token = jwt_signer.generate_token(superadmin_context) + + decoded = pyjwt.decode( + token, + jwt_config.secret_key, + algorithms=[jwt_config.algorithm], + ) + + assert decoded["role"] == "superadmin" + assert decoded["is_admin"] is True + assert decoded["is_superadmin"] is True + + +def test_generate_token_signature_verification( + jwt_signer: JWTSigner, + jwt_config: JWTConfig, + user_context: JWTUserContext, +) -> None: + """Test that generated token can be verified with correct secret.""" + token = jwt_signer.generate_token(user_context) + + # Should not raise exception + pyjwt.decode( + token, + jwt_config.secret_key, + algorithms=[jwt_config.algorithm], + ) + + +def test_generate_token_wrong_secret_fails_verification( + jwt_signer: JWTSigner, + user_context: JWTUserContext, +) -> None: + """Test that token fails verification with wrong secret.""" + token = jwt_signer.generate_token(user_context) + + with pytest.raises(pyjwt.InvalidSignatureError): + pyjwt.decode( + token, + "wrong-secret-key", + algorithms=["HS256"], + ) + + +def test_multiple_tokens_are_different( + jwt_signer: JWTSigner, + user_context: JWTUserContext, +) -> None: + """Test that generating multiple tokens produces different tokens.""" + token1 = jwt_signer.generate_token(user_context) + # Sleep to ensure different timestamps (JWT uses second precision) + time.sleep(1.1) + token2 = jwt_signer.generate_token(user_context) + + # Tokens should be different due to different iat/exp timestamps + assert token1 != token2 diff --git a/tests/common/jwt/test_types.py b/tests/common/jwt/test_types.py new file mode 100644 index 00000000000..62d62802360 --- /dev/null +++ b/tests/common/jwt/test_types.py @@ -0,0 +1,198 @@ +"""Tests for JWT types.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from uuid import uuid4 + +import pytest + +from ai.backend.common.jwt.types import JWTClaims, JWTUserContext +from ai.backend.common.types import AccessKey + + +def test_jwt_user_context_creation() -> None: + """Test JWTUserContext dataclass creation.""" + user_id = uuid4() + access_key = AccessKey("AKIAIOSFODNN7EXAMPLE") + + context = JWTUserContext( + user_id=user_id, + access_key=access_key, + role="user", + domain_name="default", + is_admin=False, + is_superadmin=False, + ) + + assert context.user_id == user_id + assert context.access_key == access_key + assert context.role == "user" + assert context.domain_name == "default" + assert context.is_admin is False + assert context.is_superadmin is False + + +def test_jwt_user_context_immutable() -> None: + """Test that JWTUserContext is immutable.""" + context = JWTUserContext( + user_id=uuid4(), + access_key=AccessKey("AKIAIOSFODNN7EXAMPLE"), + role="user", + domain_name="default", + is_admin=False, + is_superadmin=False, + ) + + with pytest.raises(AttributeError): + context.role = "admin" # type: ignore + + +def test_jwt_claims_creation() -> None: + """Test JWTClaims dataclass creation.""" + user_id = uuid4() + access_key = AccessKey("AKIAIOSFODNN7EXAMPLE") + now = datetime.now(timezone.utc) + + claims = JWTClaims( + sub=user_id, + exp=now, + iat=now, + iss="backend.ai-webserver", + access_key=access_key, + role="user", + domain_name="default", + is_admin=False, + is_superadmin=False, + ) + + assert claims.sub == user_id + assert claims.access_key == access_key + assert claims.role == "user" + assert claims.iss == "backend.ai-webserver" + + +def test_jwt_claims_to_dict() -> None: + """Test JWTClaims serialization to dictionary.""" + user_id = uuid4() + access_key = AccessKey("AKIAIOSFODNN7EXAMPLE") + now = datetime.now(timezone.utc) + + claims = JWTClaims( + sub=user_id, + exp=now, + iat=now, + iss="backend.ai-webserver", + access_key=access_key, + role="admin", + domain_name="test-domain", + is_admin=True, + is_superadmin=False, + ) + + claims_dict = claims.to_dict() + + assert claims_dict["sub"] == str(user_id) + assert claims_dict["exp"] == int(now.timestamp()) + assert claims_dict["iat"] == int(now.timestamp()) + assert claims_dict["iss"] == "backend.ai-webserver" + assert claims_dict["access_key"] == str(access_key) + assert claims_dict["role"] == "admin" + assert claims_dict["domain_name"] == "test-domain" + assert claims_dict["is_admin"] is True + assert claims_dict["is_superadmin"] is False + + +def test_jwt_claims_from_dict() -> None: + """Test JWTClaims deserialization from dictionary.""" + user_id = uuid4() + access_key = AccessKey("AKIAIOSFODNN7EXAMPLE") + now = datetime.now(timezone.utc) + + payload = { + "sub": str(user_id), + "exp": int(now.timestamp()), + "iat": int(now.timestamp()), + "iss": "backend.ai-webserver", + "access_key": str(access_key), + "role": "superadmin", + "domain_name": "prod-domain", + "is_admin": True, + "is_superadmin": True, + } + + claims = JWTClaims.from_dict(payload) + + assert claims.sub == user_id + assert claims.access_key == access_key + assert claims.role == "superadmin" + assert claims.domain_name == "prod-domain" + assert claims.is_admin is True + assert claims.is_superadmin is True + assert claims.iss == "backend.ai-webserver" + + +def test_jwt_claims_roundtrip() -> None: + """Test that JWTClaims can be serialized and deserialized correctly.""" + user_id = uuid4() + access_key = AccessKey("AKIAIOSFODNN7EXAMPLE") + now = datetime.now(timezone.utc) + + original_claims = JWTClaims( + sub=user_id, + exp=now, + iat=now, + iss="backend.ai-webserver", + access_key=access_key, + role="user", + domain_name="default", + is_admin=False, + is_superadmin=False, + ) + + # Serialize to dict and back + claims_dict = original_claims.to_dict() + restored_claims = JWTClaims.from_dict(claims_dict) + + # Compare all fields + assert restored_claims.sub == original_claims.sub + assert restored_claims.access_key == original_claims.access_key + assert restored_claims.role == original_claims.role + assert restored_claims.domain_name == original_claims.domain_name + assert restored_claims.is_admin == original_claims.is_admin + assert restored_claims.is_superadmin == original_claims.is_superadmin + assert restored_claims.iss == original_claims.iss + + +def test_jwt_claims_from_dict_missing_field() -> None: + """Test that JWTClaims.from_dict raises error when required field is missing.""" + payload = { + "sub": str(uuid4()), + "exp": int(datetime.now(timezone.utc).timestamp()), + "iat": int(datetime.now(timezone.utc).timestamp()), + "iss": "backend.ai-webserver", + # Missing access_key, role, domain_name, etc. + } + + with pytest.raises(KeyError): + JWTClaims.from_dict(payload) + + +def test_jwt_claims_from_dict_invalid_uuid() -> None: + """Test that JWTClaims.from_dict raises error for invalid UUID.""" + now = datetime.now(timezone.utc) + + payload = { + "sub": "invalid-uuid", + "exp": int(now.timestamp()), + "iat": int(now.timestamp()), + "iss": "backend.ai-webserver", + "access_key": "AKIAIOSFODNN7EXAMPLE", + "role": "user", + "domain_name": "default", + "is_admin": False, + "is_superadmin": False, + } + + with pytest.raises(ValueError): + JWTClaims.from_dict(payload) diff --git a/tests/common/jwt/test_validator.py b/tests/common/jwt/test_validator.py new file mode 100644 index 00000000000..dab67b0d387 --- /dev/null +++ b/tests/common/jwt/test_validator.py @@ -0,0 +1,312 @@ +"""Tests for JWT validator.""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from uuid import uuid4 + +import jwt as pyjwt +import pytest + +from ai.backend.common.jwt.config import JWTConfig +from ai.backend.common.jwt.exceptions import ( + JWTDecodeError, + JWTExpiredError, + JWTInvalidClaimsError, + JWTInvalidSignatureError, +) +from ai.backend.common.jwt.signer import JWTSigner +from ai.backend.common.jwt.types import JWTUserContext +from ai.backend.common.jwt.validator import JWTValidator +from ai.backend.common.types import AccessKey + + +@pytest.fixture +def jwt_config() -> JWTConfig: + """Create test JWT configuration.""" + return JWTConfig( + secret_key="test-secret-key-at-least-32-bytes-long", + algorithm="HS256", + token_expiration_seconds=900, + issuer="backend.ai-webserver", + ) + + +@pytest.fixture +def jwt_signer(jwt_config: JWTConfig) -> JWTSigner: + """Create JWT signer instance.""" + return JWTSigner(jwt_config) + + +@pytest.fixture +def jwt_validator(jwt_config: JWTConfig) -> JWTValidator: + """Create JWT validator instance.""" + return JWTValidator(jwt_config) + + +@pytest.fixture +def user_context() -> JWTUserContext: + """Create test user context.""" + return JWTUserContext( + user_id=uuid4(), + access_key=AccessKey("AKIAIOSFODNN7EXAMPLE"), + role="user", + domain_name="default", + is_admin=False, + is_superadmin=False, + ) + + +def test_validate_token_with_valid_token( + jwt_signer: JWTSigner, + jwt_validator: JWTValidator, + user_context: JWTUserContext, +) -> None: + """Test that validator accepts valid tokens.""" + token = jwt_signer.generate_token(user_context) + claims = jwt_validator.validate_token(token) + + assert claims.sub == user_context.user_id + assert claims.access_key == user_context.access_key + assert claims.role == user_context.role + assert claims.domain_name == user_context.domain_name + assert claims.is_admin == user_context.is_admin + assert claims.is_superadmin == user_context.is_superadmin + + +def test_validate_token_with_expired_token( + jwt_config: JWTConfig, + jwt_validator: JWTValidator, + user_context: JWTUserContext, +) -> None: + """Test that validator rejects expired tokens.""" + # Create token that expired 1 hour ago + past_time = datetime.now(timezone.utc) - timedelta(hours=1) + + payload = { + "sub": str(user_context.user_id), + "exp": int((past_time + timedelta(seconds=900)).timestamp()), + "iat": int(past_time.timestamp()), + "iss": jwt_config.issuer, + "access_key": str(user_context.access_key), + "role": user_context.role, + "domain_name": user_context.domain_name, + "is_admin": user_context.is_admin, + "is_superadmin": user_context.is_superadmin, + } + + expired_token = pyjwt.encode( + payload, + jwt_config.secret_key, + algorithm=jwt_config.algorithm, + ) + + with pytest.raises(JWTExpiredError): + jwt_validator.validate_token(expired_token) + + +def test_validate_token_with_invalid_signature( + jwt_signer: JWTSigner, + jwt_config: JWTConfig, + user_context: JWTUserContext, +) -> None: + """Test that validator rejects tokens with invalid signature.""" + token = jwt_signer.generate_token(user_context) + + # Create validator with different secret + wrong_config = JWTConfig( + secret_key="wrong-secret-key-different-from-original", + algorithm="HS256", + token_expiration_seconds=900, + issuer="backend.ai-webserver", + ) + wrong_validator = JWTValidator(wrong_config) + + with pytest.raises(JWTInvalidSignatureError): + wrong_validator.validate_token(token) + + +def test_validate_token_with_malformed_token( + jwt_validator: JWTValidator, +) -> None: + """Test that validator rejects malformed tokens.""" + malformed_token = "not.a.valid.jwt.token" + + with pytest.raises(JWTDecodeError): + jwt_validator.validate_token(malformed_token) + + +def test_validate_token_with_missing_claims( + jwt_config: JWTConfig, + jwt_validator: JWTValidator, +) -> None: + """Test that validator rejects tokens with missing required claims.""" + # Create token with missing claims + payload = { + "sub": str(uuid4()), + "exp": int((datetime.now(timezone.utc) + timedelta(seconds=900)).timestamp()), + "iat": int(datetime.now(timezone.utc).timestamp()), + "iss": jwt_config.issuer, + # Missing: access_key, role, domain_name, is_admin, is_superadmin + } + + incomplete_token = pyjwt.encode( + payload, + jwt_config.secret_key, + algorithm=jwt_config.algorithm, + ) + + with pytest.raises(JWTInvalidClaimsError): + jwt_validator.validate_token(incomplete_token) + + +def test_validate_token_with_invalid_role( + jwt_config: JWTConfig, + jwt_validator: JWTValidator, + user_context: JWTUserContext, +) -> None: + """Test that validator rejects tokens with invalid role.""" + payload = { + "sub": str(user_context.user_id), + "exp": int((datetime.now(timezone.utc) + timedelta(seconds=900)).timestamp()), + "iat": int(datetime.now(timezone.utc).timestamp()), + "iss": jwt_config.issuer, + "access_key": str(user_context.access_key), + "role": "invalid_role", # Not in valid_roles + "domain_name": user_context.domain_name, + "is_admin": user_context.is_admin, + "is_superadmin": user_context.is_superadmin, + } + + invalid_token = pyjwt.encode( + payload, + jwt_config.secret_key, + algorithm=jwt_config.algorithm, + ) + + with pytest.raises(JWTInvalidClaimsError) as exc_info: + jwt_validator.validate_token(invalid_token) + + assert "Invalid role" in str(exc_info.value) + + +def test_validate_token_with_wrong_issuer( + jwt_config: JWTConfig, + jwt_validator: JWTValidator, + user_context: JWTUserContext, +) -> None: + """Test that validator rejects tokens with wrong issuer.""" + payload = { + "sub": str(user_context.user_id), + "exp": int((datetime.now(timezone.utc) + timedelta(seconds=900)).timestamp()), + "iat": int(datetime.now(timezone.utc).timestamp()), + "iss": "wrong-issuer", # Different from expected issuer + "access_key": str(user_context.access_key), + "role": user_context.role, + "domain_name": user_context.domain_name, + "is_admin": user_context.is_admin, + "is_superadmin": user_context.is_superadmin, + } + + wrong_issuer_token = pyjwt.encode( + payload, + jwt_config.secret_key, + algorithm=jwt_config.algorithm, + ) + + with pytest.raises(JWTInvalidClaimsError) as exc_info: + jwt_validator.validate_token(wrong_issuer_token) + + assert "Invalid issuer" in str(exc_info.value) + + +def test_validate_token_with_admin_role( + jwt_signer: JWTSigner, + jwt_validator: JWTValidator, +) -> None: + """Test validation of token with admin role.""" + admin_context = JWTUserContext( + user_id=uuid4(), + access_key=AccessKey("AKIAADMIN123456789"), + role="admin", + domain_name="admin-domain", + is_admin=True, + is_superadmin=False, + ) + + token = jwt_signer.generate_token(admin_context) + claims = jwt_validator.validate_token(token) + + assert claims.role == "admin" + assert claims.is_admin is True + assert claims.is_superadmin is False + + +def test_validate_token_with_superadmin_role( + jwt_signer: JWTSigner, + jwt_validator: JWTValidator, +) -> None: + """Test validation of token with superadmin role.""" + superadmin_context = JWTUserContext( + user_id=uuid4(), + access_key=AccessKey("AKIASUPERADMIN123456"), + role="superadmin", + domain_name="super-domain", + is_admin=True, + is_superadmin=True, + ) + + token = jwt_signer.generate_token(superadmin_context) + claims = jwt_validator.validate_token(token) + + assert claims.role == "superadmin" + assert claims.is_admin is True + assert claims.is_superadmin is True + + +def test_validate_token_roundtrip( + jwt_signer: JWTSigner, + jwt_validator: JWTValidator, + user_context: JWTUserContext, +) -> None: + """Test complete roundtrip: sign -> validate -> verify data.""" + # Generate token + token = jwt_signer.generate_token(user_context) + + # Validate token + claims = jwt_validator.validate_token(token) + + # Verify all data matches + assert claims.sub == user_context.user_id + assert claims.access_key == user_context.access_key + assert claims.role == user_context.role + assert claims.domain_name == user_context.domain_name + assert claims.is_admin == user_context.is_admin + assert claims.is_superadmin == user_context.is_superadmin + + +def test_validate_token_with_invalid_uuid_in_sub( + jwt_config: JWTConfig, + jwt_validator: JWTValidator, +) -> None: + """Test that validator rejects tokens with invalid UUID in sub claim.""" + payload = { + "sub": "not-a-valid-uuid", + "exp": int((datetime.now(timezone.utc) + timedelta(seconds=900)).timestamp()), + "iat": int(datetime.now(timezone.utc).timestamp()), + "iss": jwt_config.issuer, + "access_key": "AKIAIOSFODNN7EXAMPLE", + "role": "user", + "domain_name": "default", + "is_admin": False, + "is_superadmin": False, + } + + invalid_token = pyjwt.encode( + payload, + jwt_config.secret_key, + algorithm=jwt_config.algorithm, + ) + + with pytest.raises(JWTInvalidClaimsError): + jwt_validator.validate_token(invalid_token)