Skip to content

feat: introduce OAuth2TokenAuthProvider and notion of "principal" #2185

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions llama_stack/distribution/server/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def __call__(self, scope, receive, send):

# Validate token and get access attributes
try:
access_attributes = await self.auth_provider.validate_token(token, scope)
validation_result = await self.auth_provider.validate_token(token, scope)
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout")
Expand All @@ -105,17 +105,20 @@ async def __call__(self, scope, receive, send):
return await self._send_auth_error(send, "Authentication service error")

# Store attributes in request scope for access control
if access_attributes:
user_attributes = access_attributes.model_dump(exclude_none=True)
if validation_result.access_attributes:
user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
else:
logger.warning("No access attributes, setting namespace to token by default")
user_attributes = {
"namespaces": [token],
"roles": [token],
}

# Store attributes in request scope
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes")
scope["principal"] = validation_result.principal
logger.debug(
f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
)

return await self.app(scope, receive, send)

Expand Down
189 changes: 151 additions & 38 deletions llama_stack/distribution/server/auth_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,26 @@
# the root directory of this source tree.

import json
import time
from abc import ABC, abstractmethod
from enum import Enum
from urllib.parse import parse_qs

import httpx
from pydantic import BaseModel, Field
from jose import jwt
from pydantic import BaseModel, Field, field_validator

from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="auth")


class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""

class TokenValidationResult(BaseModel):
principal: str | None = Field(
default=None,
description="The principal (username or persistent identifier) of the authenticated user",
)
access_attributes: AccessAttributes | None = Field(
default=None,
description="""
Expand All @@ -43,6 +47,10 @@ class AuthResponse(BaseModel):
""",
)


class AuthResponse(TokenValidationResult):
"""The format of the authentication response from the auth endpoint."""

message: str | None = Field(
default=None, description="Optional message providing additional context about the authentication result."
)
Expand All @@ -69,6 +77,7 @@ class AuthProviderType(str, Enum):

KUBERNETES = "kubernetes"
CUSTOM = "custom"
OAUTH2_TOKEN = "oauth2_token"


class AuthProviderConfig(BaseModel):
Expand All @@ -82,7 +91,7 @@ class AuthProvider(ABC):
"""Abstract base class for authentication providers."""

@abstractmethod
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token and return access attributes."""
pass

Expand All @@ -92,12 +101,16 @@ async def close(self):
pass


class KubernetesAuthProviderConfig(BaseModel):
api_server_url: str
ca_cert_path: str | None = None


class KubernetesAuthProvider(AuthProvider):
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""

def __init__(self, config: dict[str, str]):
self.api_server_url = config["api_server_url"]
self.ca_cert_path = config.get("ca_cert_path")
def __init__(self, config: KubernetesAuthProviderConfig):
self.config = config
self._client = None

async def _get_client(self):
Expand All @@ -110,16 +123,16 @@ async def _get_client(self):

# Configure the client
configuration = client.Configuration()
configuration.host = self.api_server_url
if self.ca_cert_path:
configuration.ssl_ca_cert = self.ca_cert_path
configuration.verify_ssl = bool(self.ca_cert_path)
configuration.host = self.config.api_server_url
if self.config.ca_cert_path:
configuration.ssl_ca_cert = self.config.ca_cert_path
configuration.verify_ssl = bool(self.config.ca_cert_path)

# Create API client
self._client = ApiClient(configuration)
return self._client

async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a Kubernetes token and return access attributes."""
try:
client = await self._get_client()
Expand All @@ -146,9 +159,12 @@ async def validate_token(self, token: str, scope: dict | None = None) -> AccessA
username = payload.get("sub", "")
groups = payload.get("groups", [])

return AccessAttributes(
roles=[username], # Use username as a role
teams=groups, # Use Kubernetes groups as teams
return TokenValidationResult(
principal=username,
access_attributes=AccessAttributes(
roles=[username], # Use username as a role
teams=groups, # Use Kubernetes groups as teams
),
)

except Exception as e:
Expand All @@ -162,18 +178,125 @@ async def close(self):
self._client = None


def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
attributes = AccessAttributes()
for claim_key, attribute_key in mapping.items():
if claim_key not in claims or not hasattr(attributes, attribute_key):
continue
claim = claims[claim_key]
if isinstance(claim, list):
values = claim
else:
values = claim.split()

current = getattr(attributes, attribute_key)
if current:
current.extend(values)
else:
setattr(attributes, attribute_key, values)
return attributes


class OAuth2TokenAuthProviderConfig(BaseModel):
# The JWKS URI for collecting public keys
jwks_uri: str
cache_ttl: int = 3600
audience: str = "llama-stack"
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"sub": "roles",
"username": "roles",
"groups": "teams",
"team": "teams",
"project": "projects",
"tenant": "namespaces",
"namespace": "namespaces",
},
)

@classmethod
@field_validator("claims_mapping")
def validate_claims_mapping(cls, v):
for key, value in v.items():
if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}")
if value not in AccessAttributes.model_fields:
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
return v


class OAuth2TokenAuthProvider(AuthProvider):
"""
JWT token authentication provider that validates a JWT token and extracts access attributes.

This should be the standard authentication provider for most use cases.
"""

def __init__(self, config: OAuth2TokenAuthProviderConfig):
self.config = config
self._jwks_at: float = 0.0
self._jwks: dict[str, str] = {}

async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using the JWT token."""
await self._refresh_jwks()

try:
header = jwt.get_unverified_header(token)
kid = header["kid"]
if kid not in self._jwks:
raise ValueError(f"Unknown key ID: {kid}")
key_data = self._jwks[kid]
algorithm = header.get("alg", "RS256")
claims = jwt.decode(
token,
key_data,
algorithms=[algorithm],
audience=self.config.audience,
options={"verify_exp": True},
)
except Exception as exc:
raise ValueError(f"Invalid JWT token: {token}") from exc

# There are other standard claims, the most relevant of which is `scope`.
# We should incorporate these into the access attributes.
principal = claims["sub"]
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
return TokenValidationResult(
principal=principal,
access_attributes=access_attributes,
)

async def close(self):
"""Close the HTTP client."""

async def _refresh_jwks(self) -> None:
if time.time() - self._jwks_at > self.config.cache_ttl:
async with httpx.AsyncClient() as client:
res = await client.get(self.config.jwks_uri, timeout=5)
res.raise_for_status()
jwks_data = res.json()["keys"]
self._jwks = {}
for k in jwks_data:
kid = k["kid"]
# Store the entire key object as it may be needed for different algorithms
self._jwks[kid] = k
self._jwks_at = time.time()


class CustomAuthProviderConfig(BaseModel):
endpoint: str


class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint."""

def __init__(self, config: dict[str, str]):
self.endpoint = config["endpoint"]
def __init__(self, config: CustomAuthProviderConfig):
self.config = config
self._client = None

async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using the custom authentication endpoint."""
if not self.endpoint:
raise ValueError("Authentication endpoint not configured")

if scope is None:
scope = {}

Expand Down Expand Up @@ -202,7 +325,7 @@ async def validate_token(self, token: str, scope: dict | None = None) -> AccessA
try:
async with httpx.AsyncClient() as client:
response = await client.post(
self.endpoint,
self.config.endpoint,
json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout
)
Expand All @@ -214,19 +337,7 @@ async def validate_token(self, token: str, scope: dict | None = None) -> AccessA
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)

# Store attributes in request scope for access control
if auth_response.access_attributes:
return auth_response.access_attributes
else:
logger.warning("No access attributes, setting namespace to api_key by default")
user_attributes = {
"namespaces": [token],
}

scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
return auth_response.access_attributes
return auth_response
except Exception as e:
logger.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e
Expand All @@ -253,9 +364,11 @@ def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
provider_type = config.provider_type.lower()

if provider_type == "kubernetes":
return KubernetesAuthProvider(config.config)
return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
elif provider_type == "custom":
return CustomAuthProvider(config.config)
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
elif provider_type == "oauth2_token":
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
else:
supported_providers = ", ".join([t.value for t in AuthProviderType])
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"openai>=1.66",
"prompt-toolkit",
"python-dotenv",
"python-jose",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not happy about this right now, I am going to split off Auth (and then Credentials) to proper APIs which have proper providers so the overall distribution dependency system takes over. we dont want starter distros to have complex dependencies if they don't need them.

but for now, this will do.

"pydantic>=2",
"requests",
"rich",
Expand Down
6 changes: 4 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ click==8.1.8
colorama==0.4.6 ; sys_platform == 'win32'
distro==1.9.0
durationpy==0.9
ecdsa==0.19.1
exceptiongroup==1.2.2 ; python_full_version < '3.11'
filelock==3.17.0
fire==0.7.0
Expand Down Expand Up @@ -39,14 +40,15 @@ pandas==2.2.3
pillow==11.1.0
prompt-toolkit==3.0.50
pyaml==25.1.0
pyasn1==0.6.1
pyasn1-modules==0.4.2
pyasn1==0.4.8
pyasn1-modules==0.4.1
pycryptodomex==3.21.0
pydantic==2.10.6
pydantic-core==2.27.2
pygments==2.19.1
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-jose==3.4.0
pytz==2025.1
pyyaml==6.0.2
referencing==0.36.2
Expand Down
Loading
Loading