Skip to content

Commit 32ae536

Browse files
HyeockJinKimclaude
authored andcommitted
feat(BA-2825): Implement JWT authentication module for GraphQL Federation (#6410)
Co-authored-by: Claude <noreply@anthropic.com>
1 parent c78f7ea commit 32ae536

11 files changed

Lines changed: 1326 additions & 0 deletions

File tree

changes/6410.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Implement JWT authentication module for GraphQL Federation
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""
2+
JWT authentication module for GraphQL Federation.
3+
4+
This module provides JWT-based authentication for GraphQL requests going through
5+
Hive Router. It uses the X-BackendAI-Token custom header to avoid conflicts with
6+
existing Bearer token usage in appproxy.
7+
8+
Key components:
9+
- JWTSigner: Generates JWT tokens from authenticated user context (webserver)
10+
- JWTValidator: Validates JWT tokens and extracts user claims (manager)
11+
- JWTConfig: Configuration for JWT authentication
12+
- JWTClaims: Dataclass representing JWT payload claims
13+
14+
Example usage in webserver:
15+
from ai.backend.common.jwt import JWTSigner, JWTConfig, JWTUserContext
16+
17+
config = JWTConfig(secret_key=os.environ["JWT_SECRET_KEY"])
18+
signer = JWTSigner(config)
19+
20+
user_context = JWTUserContext(
21+
user_id=user_uuid,
22+
access_key=access_key,
23+
role="user",
24+
domain_name="default",
25+
is_admin=False,
26+
is_superadmin=False,
27+
)
28+
token = signer.generate_token(user_context)
29+
30+
# Add to request headers
31+
headers["X-BackendAI-Token"] = token
32+
33+
Example usage in manager:
34+
from ai.backend.common.jwt import JWTValidator, JWTConfig
35+
36+
config = JWTConfig(secret_key=os.environ["JWT_SECRET_KEY"])
37+
validator = JWTValidator(config)
38+
39+
token = request.headers.get("X-BackendAI-Token")
40+
claims = validator.validate_token(token)
41+
42+
# Use claims for authentication
43+
user_id = claims.sub
44+
access_key = claims.access_key
45+
"""
46+
47+
from .config import JWTConfig
48+
from .exceptions import (
49+
JWTDecodeError,
50+
JWTError,
51+
JWTExpiredError,
52+
JWTInvalidClaimsError,
53+
JWTInvalidSignatureError,
54+
)
55+
from .signer import JWTSigner
56+
from .types import JWTClaims, JWTUserContext
57+
from .validator import JWTValidator
58+
59+
__all__ = [
60+
# Configuration
61+
"JWTConfig",
62+
# Types
63+
"JWTClaims",
64+
"JWTUserContext",
65+
# Core classes
66+
"JWTSigner",
67+
"JWTValidator",
68+
# Exceptions
69+
"JWTError",
70+
"JWTExpiredError",
71+
"JWTInvalidSignatureError",
72+
"JWTInvalidClaimsError",
73+
"JWTDecodeError",
74+
]
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""JWT authentication configuration for GraphQL Federation."""
2+
3+
from __future__ import annotations
4+
5+
from datetime import timedelta
6+
7+
from pydantic import Field
8+
9+
from ai.backend.common.config import BaseConfigSchema
10+
11+
12+
class JWTConfig(BaseConfigSchema):
13+
"""
14+
Configuration for JWT-based authentication in GraphQL Federation.
15+
16+
This configuration must be consistent between webserver (which generates tokens)
17+
and manager (which validates tokens). The secret_key must be kept secure and
18+
should be the same on both sides.
19+
20+
Attributes:
21+
enabled: Whether JWT authentication is enabled
22+
secret_key: Secret key for HS256 signing and verification
23+
algorithm: JWT signing algorithm (must be HS256)
24+
token_expiration_seconds: Token validity duration in seconds
25+
issuer: JWT issuer claim value for validation
26+
header_name: HTTP header name for JWT token transmission
27+
"""
28+
29+
enabled: bool = Field(
30+
default=True,
31+
description="Enable JWT authentication for GraphQL Federation requests",
32+
)
33+
34+
secret_key: str = Field(
35+
description="Secret key for HS256 signing and verification. "
36+
"MUST be the same between webserver and manager. "
37+
"Should be at least 32 bytes of random data.",
38+
)
39+
40+
algorithm: str = Field(
41+
default="HS256",
42+
description="JWT signing algorithm (only HS256 is supported)",
43+
)
44+
45+
token_expiration_seconds: int = Field(
46+
default=900, # 15 minutes
47+
description="JWT token expiration time in seconds (default: 900 = 15 minutes)",
48+
)
49+
50+
issuer: str = Field(
51+
default="backend.ai-webserver",
52+
description="JWT issuer claim value for GraphQL Federation tokens",
53+
)
54+
55+
header_name: str = Field(
56+
default="X-BackendAI-Token",
57+
description="Custom HTTP header name for JWT token transmission",
58+
)
59+
60+
@property
61+
def token_expiration(self) -> timedelta:
62+
"""
63+
Get token expiration as a timedelta object.
64+
65+
Returns:
66+
Token expiration duration as timedelta
67+
"""
68+
return timedelta(seconds=self.token_expiration_seconds)
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""JWT authentication exceptions for GraphQL Federation."""
2+
3+
from __future__ import annotations
4+
5+
from aiohttp import web
6+
7+
from ai.backend.common.exception import (
8+
BackendAIError,
9+
ErrorCode,
10+
ErrorDetail,
11+
ErrorDomain,
12+
ErrorOperation,
13+
)
14+
15+
16+
class JWTError(BackendAIError):
17+
"""
18+
Base exception for JWT-related errors in GraphQL Federation authentication.
19+
20+
All JWT-specific exceptions inherit from this base class.
21+
"""
22+
23+
@classmethod
24+
def error_code(cls) -> ErrorCode:
25+
return ErrorCode(
26+
domain=ErrorDomain.USER,
27+
operation=ErrorOperation.AUTH,
28+
error_detail=ErrorDetail.UNAUTHORIZED,
29+
)
30+
31+
32+
class JWTExpiredError(JWTError, web.HTTPUnauthorized):
33+
"""
34+
JWT token has expired.
35+
36+
Raised when attempting to use a token past its expiration time.
37+
"""
38+
39+
error_type = "https://api.backend.ai/probs/jwt-expired"
40+
error_title = "JWT token has expired."
41+
42+
@classmethod
43+
def error_code(cls) -> ErrorCode:
44+
return ErrorCode(
45+
domain=ErrorDomain.USER,
46+
operation=ErrorOperation.AUTH,
47+
error_detail=ErrorDetail.DATA_EXPIRED,
48+
)
49+
50+
51+
class JWTInvalidSignatureError(JWTError, web.HTTPUnauthorized):
52+
"""
53+
JWT signature verification failed.
54+
55+
Raised when the token's signature doesn't match the expected signature,
56+
indicating the token may have been tampered with or was signed with
57+
a different secret key.
58+
"""
59+
60+
error_type = "https://api.backend.ai/probs/jwt-invalid-signature"
61+
error_title = "JWT signature verification failed."
62+
63+
64+
class JWTInvalidClaimsError(JWTError, web.HTTPUnauthorized):
65+
"""
66+
JWT claims are missing or invalid.
67+
68+
Raised when required claims are missing from the token or when
69+
claim values don't meet validation requirements (e.g., invalid role,
70+
wrong issuer).
71+
"""
72+
73+
error_type = "https://api.backend.ai/probs/jwt-invalid-claims"
74+
error_title = "JWT claims are invalid."
75+
76+
77+
class JWTDecodeError(JWTError, web.HTTPUnauthorized):
78+
"""
79+
Failed to decode JWT token.
80+
81+
Raised when the token cannot be decoded, typically due to malformed
82+
token structure or encoding issues.
83+
"""
84+
85+
error_type = "https://api.backend.ai/probs/jwt-decode-error"
86+
error_title = "Failed to decode JWT token."
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""JWT token signer for generating authentication tokens."""
2+
3+
from __future__ import annotations
4+
5+
from datetime import datetime, timezone
6+
7+
import jwt
8+
9+
from .config import JWTConfig
10+
from .exceptions import JWTError
11+
from .types import JWTClaims, JWTUserContext
12+
13+
14+
class JWTSigner:
15+
"""
16+
JWT token generator for GraphQL Federation authentication.
17+
18+
This class is used by the webserver to generate JWT tokens after successful
19+
HMAC authentication. The generated tokens are then forwarded to the manager
20+
via Hive Router using the X-BackendAI-Token header.
21+
22+
Usage:
23+
from ai.backend.common.jwt import JWTSigner, JWTConfig, JWTUserContext
24+
25+
config = JWTConfig(secret_key="your-secret-key")
26+
signer = JWTSigner(config)
27+
28+
user_context = JWTUserContext(
29+
user_id=user_uuid,
30+
access_key=access_key,
31+
role="user",
32+
domain_name="default",
33+
is_admin=False,
34+
is_superadmin=False,
35+
)
36+
token = signer.generate_token(user_context)
37+
"""
38+
39+
_config: JWTConfig
40+
41+
def __init__(self, config: JWTConfig) -> None:
42+
"""
43+
Initialize JWT signer with configuration.
44+
45+
Args:
46+
config: JWT configuration containing secret key and other settings
47+
"""
48+
self._config = config
49+
50+
def generate_token(self, user_context: JWTUserContext) -> str:
51+
"""
52+
Generate a JWT token from authenticated user context.
53+
54+
This method creates a JWT token containing all necessary user authentication
55+
information. The token is signed using HS256 with the configured secret key.
56+
57+
Args:
58+
user_context: User context data containing authentication information
59+
60+
Returns:
61+
Encoded JWT token string
62+
63+
Raises:
64+
JWTError: If token generation fails
65+
"""
66+
now = datetime.now(timezone.utc)
67+
68+
claims = JWTClaims(
69+
sub=user_context.user_id,
70+
exp=now + self._config.token_expiration,
71+
iat=now,
72+
iss=self._config.issuer,
73+
access_key=user_context.access_key,
74+
role=user_context.role,
75+
domain_name=user_context.domain_name,
76+
is_admin=user_context.is_admin,
77+
is_superadmin=user_context.is_superadmin,
78+
)
79+
80+
try:
81+
return jwt.encode(
82+
claims.to_dict(),
83+
self._config.secret_key,
84+
algorithm=self._config.algorithm,
85+
)
86+
except Exception as e:
87+
raise JWTError(f"JWT generation failed: {e}") from e

0 commit comments

Comments
 (0)