Skip to content

Commit 1cb4e9f

Browse files
rapsealkclaude
andcommitted
feat: add RS256 signing and JWKS support to common/jwt library
Extend the JWT library to support RS256 asymmetric signing alongside existing HS256, and add JWKS fetching/caching utilities for distributed token validation. Foundation for OAuth2/OIDC Provider (Account Manager). Closes #10898 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 68d1033 commit 1cb4e9f

12 files changed

Lines changed: 1440 additions & 80 deletions

File tree

src/ai/backend/common/jwt/__init__.py

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,70 +5,92 @@
55
Hive Router. It uses the X-BackendAI-Token custom header to avoid conflicts with
66
existing Bearer token usage in appproxy.
77
8+
Supports both HS256 (symmetric, per-user secret keys) and RS256 (asymmetric,
9+
RSA key pairs) signing algorithms, with JWKS utilities for distributed key
10+
management.
11+
812
Key components:
913
- JWTSigner: Generates JWT tokens from authenticated user context (webserver)
1014
- JWTValidator: Validates JWT tokens and extracts user claims (manager)
1115
- JWTConfig: Configuration for JWT authentication
1216
- JWTClaims: Dataclass representing JWT payload claims
17+
- JWKSKeySet: Public key set indexed by key ID for RS256 validation
18+
- JWKSFetcher: Async JWKS endpoint fetcher with TTL caching
19+
- Key utilities: RSA key generation, loading, serialization, and JWK conversion
1320
14-
Example usage in webserver:
21+
Example usage (HS256):
1522
from ai.backend.common.jwt import JWTSigner, JWTConfig, JWTUserContext
1623
17-
config = JWTConfig(secret_key=os.environ["JWT_SECRET_KEY"])
24+
config = JWTConfig()
1825
signer = JWTSigner(config)
1926
2027
user_context = JWTUserContext(
21-
user_id=user_uuid,
2228
access_key=access_key,
2329
role="user",
24-
domain_name="default",
25-
is_admin=False,
26-
is_superadmin=False,
2730
)
28-
token = signer.generate_token(user_context)
29-
30-
# Add to request headers
31-
headers["X-BackendAI-Token"] = token
32-
33-
Example usage in manager:
34-
from ai.backend.common.jwt import JWTValidator, JWTConfig
35-
36-
config = JWTConfig(secret_key=os.environ["JWT_SECRET_KEY"])
37-
validator = JWTValidator(config)
31+
token = signer.generate_token(user_context, secret_key)
3832
39-
token = request.headers.get("X-BackendAI-Token")
40-
claims = validator.validate_token(token)
33+
Example usage (RS256):
34+
from ai.backend.common.jwt import JWTSigner, JWTConfig, JWTUserContext
35+
from ai.backend.common.jwt.keys import load_private_key
4136
42-
# Use claims for authentication
43-
user_id = claims.sub
44-
access_key = claims.access_key
37+
config = JWTConfig(algorithm="RS256")
38+
signer = JWTSigner(config)
39+
private_key = load_private_key(Path("/path/to/private.pem"))
40+
token = signer.generate_token(user_context, private_key=private_key, kid="key-1")
4541
"""
4642

47-
from .config import JWTConfig
48-
from .exceptions import (
43+
from ai.backend.common.jwt.config import JWTAlgorithm, JWTConfig
44+
from ai.backend.common.jwt.exceptions import (
45+
JWKSError,
46+
JWKSFetchError,
47+
JWKSKeyNotFoundError,
4948
JWTDecodeError,
5049
JWTError,
5150
JWTExpiredError,
5251
JWTInvalidClaimsError,
5352
JWTInvalidSignatureError,
5453
)
55-
from .signer import JWTSigner
56-
from .types import JWTClaims, JWTUserContext
57-
from .validator import JWTValidator
54+
from ai.backend.common.jwt.jwks import JWKSFetcher, JWKSKeySet
55+
from ai.backend.common.jwt.keys import (
56+
generate_rsa_key_pair,
57+
load_private_key,
58+
load_public_key,
59+
private_key_to_pem,
60+
public_key_to_jwk,
61+
public_key_to_pem,
62+
)
63+
from ai.backend.common.jwt.signer import JWTSigner
64+
from ai.backend.common.jwt.types import JWTClaims, JWTUserContext
65+
from ai.backend.common.jwt.validator import JWTValidator
5866

5967
__all__ = [
6068
# Configuration
69+
"JWTAlgorithm",
6170
"JWTConfig",
6271
# Types
6372
"JWTClaims",
6473
"JWTUserContext",
6574
# Core classes
6675
"JWTSigner",
6776
"JWTValidator",
77+
# JWKS
78+
"JWKSKeySet",
79+
"JWKSFetcher",
80+
# Key management
81+
"generate_rsa_key_pair",
82+
"load_private_key",
83+
"load_public_key",
84+
"private_key_to_pem",
85+
"public_key_to_pem",
86+
"public_key_to_jwk",
6887
# Exceptions
6988
"JWTError",
7089
"JWTExpiredError",
7190
"JWTInvalidSignatureError",
7291
"JWTInvalidClaimsError",
7392
"JWTDecodeError",
93+
"JWKSError",
94+
"JWKSFetchError",
95+
"JWKSKeyNotFoundError",
7496
]

src/ai/backend/common/jwt/config.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,45 +3,66 @@
33
from __future__ import annotations
44

55
from datetime import timedelta
6+
from enum import StrEnum
7+
from pathlib import Path
68

79
from pydantic import Field
810

911
from ai.backend.common.config import BaseConfigSchema
1012

1113

14+
class JWTAlgorithm(StrEnum):
15+
"""Supported JWT signing algorithms."""
16+
17+
HS256 = "HS256"
18+
RS256 = "RS256"
19+
20+
1221
class JWTConfig(BaseConfigSchema):
1322
"""
1423
Configuration for JWT-based authentication in GraphQL Federation.
1524
1625
This configuration must be consistent between webserver (which generates tokens)
1726
and manager (which validates tokens).
1827
19-
Note: JWT tokens are signed using per-user secret keys (from keypair table),
20-
not a shared system secret key. This maintains the same security model as HMAC authentication.
28+
Supports both HS256 (symmetric, per-user secret keys) and RS256 (asymmetric,
29+
RSA key pairs) signing algorithms.
2130
2231
JWT tokens are transmitted via X-BackendAI-Token HTTP header.
2332
2433
Attributes:
2534
enabled: Whether JWT authentication is enabled
26-
algorithm: JWT signing algorithm (must be HS256)
35+
algorithm: JWT signing algorithm (HS256 or RS256)
2736
token_expiration_seconds: Token validity duration in seconds
37+
private_key_path: Path to PEM-encoded RSA private key (RS256 only)
38+
public_key_path: Path to PEM-encoded RSA public key (RS256 only)
2839
"""
2940

3041
enabled: bool = Field(
3142
default=True,
3243
description="Enable JWT authentication for GraphQL Federation requests",
3344
)
3445

35-
algorithm: str = Field(
36-
default="HS256",
37-
description="JWT signing algorithm (only HS256 is supported)",
46+
algorithm: JWTAlgorithm = Field(
47+
default=JWTAlgorithm.HS256,
48+
description="JWT signing algorithm (HS256 or RS256)",
3849
)
3950

4051
token_expiration_seconds: int = Field(
4152
default=900, # 15 minutes
4253
description="JWT token expiration time in seconds (default: 900 = 15 minutes)",
4354
)
4455

56+
private_key_path: Path | None = Field(
57+
default=None,
58+
description="Path to PEM-encoded RSA private key file (required for RS256 signing)",
59+
)
60+
61+
public_key_path: Path | None = Field(
62+
default=None,
63+
description="Path to PEM-encoded RSA public key file (required for RS256 validation)",
64+
)
65+
4566
@property
4667
def token_expiration(self) -> timedelta:
4768
"""

src/ai/backend/common/jwt/exceptions.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,37 @@ class JWTDecodeError(JWTError, web.HTTPUnauthorized):
8282

8383
error_type = "https://api.backend.ai/probs/jwt-decode-error"
8484
error_title = "Failed to decode JWT token."
85+
86+
87+
class JWKSError(JWTError):
88+
"""
89+
Base exception for JWKS-related errors.
90+
91+
All JWKS-specific exceptions inherit from this base class.
92+
"""
93+
94+
error_type = "https://api.backend.ai/probs/jwks-error"
95+
error_title = "JWKS error."
96+
97+
98+
class JWKSFetchError(JWKSError, web.HTTPUnauthorized):
99+
"""
100+
Failed to fetch JWKS from the remote endpoint.
101+
102+
Raised when the JWKS endpoint is unreachable or returns invalid data.
103+
"""
104+
105+
error_type = "https://api.backend.ai/probs/jwks-fetch-error"
106+
error_title = "Failed to fetch JWKS."
107+
108+
109+
class JWKSKeyNotFoundError(JWKSError, web.HTTPUnauthorized):
110+
"""
111+
Key ID (kid) not found in the JWKS key set.
112+
113+
Raised when a token references a key ID that is not present
114+
in the available JWKS key set.
115+
"""
116+
117+
error_type = "https://api.backend.ai/probs/jwks-key-not-found"
118+
error_title = "Key ID not found in JWKS."

0 commit comments

Comments
 (0)