Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions src/ai/backend/common/jwt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""
JWT authentication module for GraphQL Federation.

This module provides JWT-based authentication for GraphQL requests going through
Hive Router. It uses the X-BackendAI-Token custom header to avoid conflicts with
existing Bearer token usage in appproxy.

Key components:
- JWTSigner: Generates JWT tokens from authenticated user context (webserver)
- JWTValidator: Validates JWT tokens and extracts user claims (manager)
- JWTConfig: Configuration for JWT authentication
- JWTClaims: Dataclass representing JWT payload claims

Example usage in webserver:
from ai.backend.common.jwt import JWTSigner, JWTConfig, JWTUserContext

config = JWTConfig(secret_key=os.environ["JWT_SECRET_KEY"])
signer = JWTSigner(config)

user_context = JWTUserContext(
user_id=user_uuid,
access_key=access_key,
role="user",
domain_name="default",
is_admin=False,
is_superadmin=False,
)
token = signer.generate_token(user_context)

# Add to request headers
headers["X-BackendAI-Token"] = token

Example usage in manager:
from ai.backend.common.jwt import JWTValidator, JWTConfig

config = JWTConfig(secret_key=os.environ["JWT_SECRET_KEY"])
validator = JWTValidator(config)

token = request.headers.get("X-BackendAI-Token")
claims = validator.validate_token(token)

# Use claims for authentication
user_id = claims.sub
access_key = claims.access_key
"""

from .config import JWTConfig
from .exceptions import (
JWTDecodeError,
JWTError,
JWTExpiredError,
JWTInvalidClaimsError,
JWTInvalidSignatureError,
)
from .signer import JWTSigner
from .types import JWTClaims, JWTUserContext
from .validator import JWTValidator

__all__ = [
# Configuration
"JWTConfig",
# Types
"JWTClaims",
"JWTUserContext",
# Core classes
"JWTSigner",
"JWTValidator",
# Exceptions
"JWTError",
"JWTExpiredError",
"JWTInvalidSignatureError",
"JWTInvalidClaimsError",
"JWTDecodeError",
]
68 changes: 68 additions & 0 deletions src/ai/backend/common/jwt/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""JWT authentication configuration for GraphQL Federation."""

from __future__ import annotations

from datetime import timedelta

from pydantic import Field

from ai.backend.common.config import BaseConfigSchema


class JWTConfig(BaseConfigSchema):
"""
Configuration for JWT-based authentication in GraphQL Federation.

This configuration must be consistent between webserver (which generates tokens)
and manager (which validates tokens). The secret_key must be kept secure and
should be the same on both sides.

Attributes:
enabled: Whether JWT authentication is enabled
secret_key: Secret key for HS256 signing and verification
algorithm: JWT signing algorithm (must be HS256)
token_expiration_seconds: Token validity duration in seconds
issuer: JWT issuer claim value for validation
header_name: HTTP header name for JWT token transmission
"""

enabled: bool = Field(
default=True,
description="Enable JWT authentication for GraphQL Federation requests",
)

secret_key: str = Field(
description="Secret key for HS256 signing and verification. "
"MUST be the same between webserver and manager. "
"Should be at least 32 bytes of random data.",
)

algorithm: str = Field(
default="HS256",
description="JWT signing algorithm (only HS256 is supported)",
)

token_expiration_seconds: int = Field(
default=900, # 15 minutes
description="JWT token expiration time in seconds (default: 900 = 15 minutes)",
)

issuer: str = Field(
default="backend.ai-webserver",
description="JWT issuer claim value for GraphQL Federation tokens",
)

header_name: str = Field(
default="X-BackendAI-Token",
description="Custom HTTP header name for JWT token transmission",
)

@property
def token_expiration(self) -> timedelta:
"""
Get token expiration as a timedelta object.

Returns:
Token expiration duration as timedelta
"""
return timedelta(seconds=self.token_expiration_seconds)
86 changes: 86 additions & 0 deletions src/ai/backend/common/jwt/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""JWT authentication exceptions for GraphQL Federation."""

from __future__ import annotations

from aiohttp import web

from ai.backend.common.exception import (
BackendAIError,
ErrorCode,
ErrorDetail,
ErrorDomain,
ErrorOperation,
)


class JWTError(BackendAIError):
"""
Base exception for JWT-related errors in GraphQL Federation authentication.

All JWT-specific exceptions inherit from this base class.
"""

@classmethod
def error_code(cls) -> ErrorCode:
return ErrorCode(
domain=ErrorDomain.USER,
operation=ErrorOperation.AUTH,
error_detail=ErrorDetail.UNAUTHORIZED,
)


class JWTExpiredError(JWTError, web.HTTPUnauthorized):
"""
JWT token has expired.

Raised when attempting to use a token past its expiration time.
"""

error_type = "https://api.backend.ai/probs/jwt-expired"
error_title = "JWT token has expired."

@classmethod
def error_code(cls) -> ErrorCode:
return ErrorCode(
domain=ErrorDomain.USER,
operation=ErrorOperation.AUTH,
error_detail=ErrorDetail.DATA_EXPIRED,
)


class JWTInvalidSignatureError(JWTError, web.HTTPUnauthorized):
"""
JWT signature verification failed.

Raised when the token's signature doesn't match the expected signature,
indicating the token may have been tampered with or was signed with
a different secret key.
"""

error_type = "https://api.backend.ai/probs/jwt-invalid-signature"
error_title = "JWT signature verification failed."


class JWTInvalidClaimsError(JWTError, web.HTTPUnauthorized):
"""
JWT claims are missing or invalid.

Raised when required claims are missing from the token or when
claim values don't meet validation requirements (e.g., invalid role,
wrong issuer).
"""

error_type = "https://api.backend.ai/probs/jwt-invalid-claims"
error_title = "JWT claims are invalid."


class JWTDecodeError(JWTError, web.HTTPUnauthorized):
"""
Failed to decode JWT token.

Raised when the token cannot be decoded, typically due to malformed
token structure or encoding issues.
"""

error_type = "https://api.backend.ai/probs/jwt-decode-error"
error_title = "Failed to decode JWT token."
87 changes: 87 additions & 0 deletions src/ai/backend/common/jwt/signer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""JWT token signer for generating authentication tokens."""

from __future__ import annotations

from datetime import datetime, timezone

import jwt

from .config import JWTConfig
from .exceptions import JWTError
from .types import JWTClaims, JWTUserContext


class JWTSigner:
"""
JWT token generator for GraphQL Federation authentication.

This class is used by the webserver to generate JWT tokens after successful
HMAC authentication. The generated tokens are then forwarded to the manager
via Hive Router using the X-BackendAI-Token header.

Usage:
from ai.backend.common.jwt import JWTSigner, JWTConfig, JWTUserContext

config = JWTConfig(secret_key="your-secret-key")
signer = JWTSigner(config)

user_context = JWTUserContext(
user_id=user_uuid,
access_key=access_key,
role="user",
domain_name="default",
is_admin=False,
is_superadmin=False,
)
token = signer.generate_token(user_context)
"""

_config: JWTConfig

def __init__(self, config: JWTConfig) -> None:
"""
Initialize JWT signer with configuration.

Args:
config: JWT configuration containing secret key and other settings
"""
self._config = config

def generate_token(self, user_context: JWTUserContext) -> str:
"""
Generate a JWT token from authenticated user context.

This method creates a JWT token containing all necessary user authentication
information. The token is signed using HS256 with the configured secret key.

Args:
user_context: User context data containing authentication information

Returns:
Encoded JWT token string

Raises:
JWTError: If token generation fails
"""
now = datetime.now(timezone.utc)

claims = JWTClaims(
sub=user_context.user_id,
exp=now + self._config.token_expiration,
iat=now,
iss=self._config.issuer,
access_key=user_context.access_key,
role=user_context.role,
domain_name=user_context.domain_name,
is_admin=user_context.is_admin,
is_superadmin=user_context.is_superadmin,
)

try:
return jwt.encode(
claims.to_dict(),
self._config.secret_key,
algorithm=self._config.algorithm,
)
except Exception as e:
raise JWTError(f"JWT generation failed: {e}") from e
Loading
Loading