Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llama_stack/distribution/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ class AuthenticationConfig(BaseModel):
...,
description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
)
config: dict[str, str] = Field(
config: dict[str, Any] = Field(
...,
description="Provider-specific configuration",
)
Expand Down
98 changes: 89 additions & 9 deletions llama_stack/distribution/server/auth_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
# the root directory of this source tree.

import json
import ssl
import time
from abc import ABC, abstractmethod
from asyncio import Lock
from enum import Enum
from typing import Any
from urllib.parse import parse_qs

import httpx
from jose import jwt
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self

from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
Expand Down Expand Up @@ -85,7 +88,7 @@ class AuthProviderConfig(BaseModel):
"""Base configuration for authentication providers."""

provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
config: dict[str, str] = Field(..., description="Provider-specific configuration")
config: dict[str, Any] = Field(..., description="Provider-specific configuration")


class AuthProvider(ABC):
Expand Down Expand Up @@ -198,10 +201,21 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
return attributes


class OAuth2TokenAuthProviderConfig(BaseModel):
class OAuth2JWKSConfig(BaseModel):
# The JWKS URI for collecting public keys
jwks_uri: str
uri: str
cache_ttl: int = 3600


class OAuth2IntrospectionConfig(BaseModel):
url: str
client_id: str
client_secret: str
send_secret_in_body: bool = False
tls_cafile: str | None = None


class OAuth2TokenAuthProviderConfig(BaseModel):
audience: str = "llama-stack"
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
Expand All @@ -214,6 +228,8 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
"namespace": "namespaces",
},
)
jwks: OAuth2JWKSConfig | None
introspection: OAuth2IntrospectionConfig | None = None

@classmethod
@field_validator("claims_mapping")
Expand All @@ -225,6 +241,14 @@ def validate_claims_mapping(cls, v):
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
return v

@model_validator(mode="after")
def validate_mode(self) -> Self:
if not self.jwks and not self.introspection:
raise ValueError("One of jwks or introspection must be configured")
if self.jwks and self.introspection:
raise ValueError("At present only one of jwks or introspection should be configured")
return self


class OAuth2TokenAuthProvider(AuthProvider):
"""
Expand All @@ -240,8 +264,17 @@ def __init__(self, config: OAuth2TokenAuthProviderConfig):
self._jwks_lock = Lock()

async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
if self.config.jwks:
return await self.validate_jwt_token(token, self.config.jwks, scope)
if self.config.introspection:
return await self.introspect_token(token, self.config.introspection, scope)
raise ValueError("One of jwks or introspection must be configured")

async def validate_jwt_token(
self, token: str, config: OAuth2JWKSConfig, scope: dict | None = None
) -> TokenValidationResult:
"""Validate a token using the JWT token."""
await self._refresh_jwks()
await self._refresh_jwks(config)

try:
header = jwt.get_unverified_header(token)
Expand Down Expand Up @@ -269,14 +302,61 @@ async def validate_token(self, token: str, scope: dict | None = None) -> TokenVa
access_attributes=access_attributes,
)

async def introspect_token(
self, token: str, config: OAuth2IntrospectionConfig, scope: dict | None = None
) -> TokenValidationResult:
"""Validate a token using token introspection as defined by RFC 7662."""
form = {
"token": token,
}
if config.send_secret_in_body:
form["client_id"] = config.client_id
form["client_secret"] = config.client_secret
auth = None
else:
auth = (config.client_id, config.client_secret)
ssl_ctxt = None
if config.tls_cafile:
ssl_ctxt = ssl.create_default_context(cafile=config.tls_cafile)
try:
async with httpx.AsyncClient(verify=ssl_ctxt) as client:
response = await client.post(
config.url,
data=form,
auth=auth,
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200:
logger.warning(f"Token introspection failed with status code: {response.status_code}")
raise ValueError(f"Token introspection failed: {response.status_code}")

fields = response.json()
if not fields["active"]:
raise ValueError("Token not active")
principal = fields["sub"] or fields["username"]
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
return TokenValidationResult(
principal=principal,
access_attributes=access_attributes,
)
except httpx.TimeoutException:
logger.exception("Token introspection request timed out")
raise
except ValueError:
# Re-raise ValueError exceptions to preserve their message
raise
except Exception as e:
logger.exception("Error during token introspection")
raise ValueError("Token introspection error") from e

async def close(self):
"""Close the HTTP client."""
pass

async def _refresh_jwks(self) -> None:
async def _refresh_jwks(self, config: OAuth2JWKSConfig) -> None:
async with self._jwks_lock:
if time.time() - self._jwks_at > self.config.cache_ttl:
if time.time() - self._jwks_at > config.cache_ttl:
async with httpx.AsyncClient() as client:
res = await client.get(self.config.jwks_uri, timeout=5)
res = await client.get(config.uri, timeout=5)
res.raise_for_status()
jwks_data = res.json()["keys"]
updated = {}
Expand Down
162 changes: 160 additions & 2 deletions tests/unit/server/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,10 @@ def oauth2_app():
auth_config = AuthProviderConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks_uri": "http://mock-authz-service/token/introspect",
"cache_ttl": "3600",
"jwks": {
"uri": "http://mock-authz-service/token/introspect",
"cache_ttl": "3600",
},
"audience": "llama-stack",
},
)
Expand Down Expand Up @@ -517,3 +519,159 @@ def test_get_attributes_from_claims():


# TODO: add more tests for oauth2 token provider


# oauth token introspection tests
@pytest.fixture
def mock_introspection_endpoint():
return "http://mock-authz-service/token/introspect"


@pytest.fixture
def introspection_app(mock_introspection_endpoint):
app = FastAPI()
auth_config = AuthProviderConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": None,
"introspection": {"url": mock_introspection_endpoint, "client_id": "myclient", "client_secret": "abcdefg"},
},
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)

@app.get("/test")
def test_endpoint():
return {"message": "Authentication successful"}

return app


@pytest.fixture
def introspection_app_with_custom_mapping(mock_introspection_endpoint):
app = FastAPI()
auth_config = AuthProviderConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks": None,
"introspection": {
"url": mock_introspection_endpoint,
"client_id": "myclient",
"client_secret": "abcdefg",
"send_secret_in_body": "true",
},
"claims_mapping": {
"sub": "roles",
"scope": "roles",
"groups": "teams",
"aud": "namespaces",
},
},
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)

@app.get("/test")
def test_endpoint():
return {"message": "Authentication successful"}

return app


@pytest.fixture
def introspection_client(introspection_app):
return TestClient(introspection_app)


@pytest.fixture
def introspection_client_with_custom_mapping(introspection_app_with_custom_mapping):
return TestClient(introspection_app_with_custom_mapping)


def test_missing_auth_header_introspection(introspection_client):
response = introspection_client.get("/test")
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]


def test_invalid_auth_header_format_introspection(introspection_client):
response = introspection_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]


async def mock_introspection_active(*args, **kwargs):
return MockResponse(
200,
{
"active": True,
"sub": "my-user",
"groups": ["group1", "group2"],
"scope": "foo bar",
"aud": ["set1", "set2"],
},
)


async def mock_introspection_inactive(*args, **kwargs):
return MockResponse(
200,
{
"active": False,
},
)


async def mock_introspection_invalid(*args, **kwargs):
class InvalidResponse:
def __init__(self, status_code):
self.status_code = status_code

def json(self):
raise ValueError("Not JSON")

return InvalidResponse(200)


async def mock_introspection_failed(*args, **kwargs):
return MockResponse(
500,
{},
)


@patch("httpx.AsyncClient.post", new=mock_introspection_active)
def test_valid_introspection_authentication(introspection_client, valid_api_key):
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"})
assert response.status_code == 200
assert response.json() == {"message": "Authentication successful"}


@patch("httpx.AsyncClient.post", new=mock_introspection_inactive)
def test_inactive_introspection_authentication(introspection_client, invalid_api_key):
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
assert response.status_code == 401
assert "Token not active" in response.json()["error"]["message"]


@patch("httpx.AsyncClient.post", new=mock_introspection_invalid)
def test_invalid_introspection_authentication(introspection_client, invalid_api_key):
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
assert response.status_code == 401
assert "Not JSON" in response.json()["error"]["message"]


@patch("httpx.AsyncClient.post", new=mock_introspection_failed)
def test_failed_introspection_authentication(introspection_client, invalid_api_key):
response = introspection_client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"})
assert response.status_code == 401
assert "Token introspection failed: 500" in response.json()["error"]["message"]


@patch("httpx.AsyncClient.post", new=mock_introspection_active)
def test_valid_introspection_with_custom_mapping_authentication(
introspection_client_with_custom_mapping, valid_api_key
):
response = introspection_client_with_custom_mapping.get(
"/test", headers={"Authorization": f"Bearer {valid_api_key}"}
)
assert response.status_code == 200
assert response.json() == {"message": "Authentication successful"}
Loading