Skip to content

feat(quota): add server‑side per‑client request quotas (requires auth) #2096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions docs/source/distributions/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,80 @@ And must respond with:

If no access attributes are returned, the token is used as a namespace.

### Quota Configuration

The `quota` section allows you to enable server-side request throttling for both
authenticated and anonymous clients. This is useful for preventing abuse, enforcing
fairness across tenants, and controlling infrastructure costs without requiring
client-side rate limiting or external proxies.

Quotas are disabled by default. When enabled, each client is tracked using either:

* Their authenticated `client_id` (derived from the Bearer token), or
* Their IP address (fallback for anonymous requests)

Quota state is stored in a SQLite-backed key-value store, and rate limits are applied
within a configurable time window (currently only `day` is supported).

#### Example

```yaml
server:
quota:
kvstore:
type: sqlite
db_path: ./quotas.db
anonymous_max_requests: 100
authenticated_max_requests: 1000
period: day
```

#### Configuration Options

| Field | Description |
| ---------------------------- | -------------------------------------------------------------------------- |
| `kvstore` | Required. Backend storage config for tracking request counts. |
| `kvstore.type` | Must be `"sqlite"` for now. Other backends may be supported in the future. |
| `kvstore.db_path` | File path to the SQLite database. |
| `anonymous_max_requests` | Max requests per period for unauthenticated clients. |
| `authenticated_max_requests` | Max requests per period for authenticated clients. |
| `period` | Time window for quota enforcement. Only `"day"` is supported. |

> Note: if `authenticated_max_requests` is set but no authentication provider is
configured, the server will fall back to applying `anonymous_max_requests` to all
clients.

#### Example with Authentication Enabled

```yaml
server:
port: 8321
auth:
provider_type: custom
config:
endpoint: https://auth.example.com/validate
quota:
kvstore:
type: sqlite
db_path: ./quotas.db
anonymous_max_requests: 100
authenticated_max_requests: 1000
period: day
```

If a client exceeds their limit, the server responds with:

```http
HTTP/1.1 429 Too Many Requests
Content-Type: application/json

{
"error": {
"message": "Quota exceeded"
}
}
```

## Extending to handle Safety

Configuring Safety can be a little involved so it is instructive to go through an example.
Expand Down
19 changes: 18 additions & 1 deletion llama_stack/distribution/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig

LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
Expand Down Expand Up @@ -235,6 +235,19 @@ class AuthenticationConfig(BaseModel):
)


class QuotaPeriod(str, Enum):
DAY = "day"


class QuotaConfig(BaseModel):
kvstore: SqliteKVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period")
authenticated_max_requests: int = Field(
default=1000, description="Max requests for authenticated clients per period"
)
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")


class ServerConfig(BaseModel):
port: int = Field(
default=8321,
Expand Down Expand Up @@ -262,6 +275,10 @@ class ServerConfig(BaseModel):
default=None,
description="The host the server should listen on",
)
quota: QuotaConfig | None = Field(
default=None,
description="Per client quota request configuration",
)


class StackRunConfig(BaseModel):
Expand Down
4 changes: 4 additions & 0 deletions llama_stack/distribution/server/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ async def __call__(self, scope, receive, send):
"roles": [token],
}

# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
# can identify the requester and enforce per-client rate limits.
scope["authenticated_client_id"] = token

# Store attributes in request scope
scope["user_attributes"] = user_attributes
scope["principal"] = validation_result.principal
Expand Down
110 changes: 110 additions & 0 deletions llama_stack/distribution/server/quota.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import time
from datetime import datetime, timedelta, timezone

from starlette.types import ASGIApp, Receive, Scope, Send

from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl

logger = get_logger(name=__name__, category="quota")


class QuotaMiddleware:
"""
ASGI middleware that enforces separate quotas for authenticated and anonymous clients
within a configurable time window.

- For authenticated requests, it reads the client ID from the
`Authorization: Bearer <client_id>` header.
- For anonymous requests, it falls back to the IP address of the client.
Requests are counted in a KV store (e.g., SQLite), and HTTP 429 is returned
once a client exceeds its quota.
"""

def __init__(
self,
app: ASGIApp,
kv_config: KVStoreConfig,
anonymous_max_requests: int,
authenticated_max_requests: int,
window_seconds: int = 86400,
):
self.app = app
self.kv_config = kv_config
self.kv: KVStore | None = None
self.anonymous_max_requests = anonymous_max_requests
self.authenticated_max_requests = authenticated_max_requests
self.window_seconds = window_seconds

if isinstance(self.kv_config, SqliteKVStoreConfig):
logger.warning(
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
f"window_seconds={self.window_seconds}"
)

async def _get_kv(self) -> KVStore:
if self.kv is None:
self.kv = await kvstore_impl(self.kv_config)
return self.kv

async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] == "http":
# pick key & limit based on auth
auth_id = scope.get("authenticated_client_id")
if auth_id:
key_id = auth_id
limit = self.authenticated_max_requests
else:
# fallback to IP
client = scope.get("client")
key_id = client[0] if client else "anonymous"
limit = self.anonymous_max_requests

current_window = int(time.time() // self.window_seconds)
key = f"quota:{key_id}:{current_window}"

try:
kv = await self._get_kv()
prev = await kv.get(key) or "0"
count = int(prev) + 1

if int(prev) == 0:
# Set with expiration datetime when it is the first request in the window.
expiration = datetime.now(timezone.utc) + timedelta(seconds=self.window_seconds)
await kv.set(key, str(count), expiration=expiration)
else:
await kv.set(key, str(count))
except Exception:
logger.exception("Failed to access KV store for quota")
return await self._send_error(send, 500, "Quota service error")

if count > limit:
logger.warning(
"Quota exceeded for client %s: %d/%d",
key_id,
count,
limit,
)
return await self._send_error(send, 429, "Quota exceeded")

return await self.app(scope, receive, send)

async def _send_error(self, send: Send, status: int, message: str):
await send(
{
"type": "http.response.start",
"status": status,
"headers": [[b"content-type", b"application/json"]],
}
)
body = json.dumps({"error": {"message": message}}).encode()
await send({"type": "http.response.body", "body": body})
30 changes: 30 additions & 0 deletions llama_stack/distribution/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@

from .auth import AuthenticationMiddleware
from .endpoints import get_all_api_endpoints
from .quota import QuotaMiddleware

REPO_ROOT = Path(__file__).parent.parent.parent.parent

Expand Down Expand Up @@ -434,6 +435,35 @@ def main(args: argparse.Namespace | None = None):
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
else:
if config.server.quota:
quota = config.server.quota
logger.warning(
"Configured authenticated_max_requests (%d) but no auth is enabled; "
"falling back to anonymous_max_requests (%d) for all the requests",
quota.authenticated_max_requests,
quota.anonymous_max_requests,
)

if config.server.quota:
logger.info("Enabling quota middleware for authenticated and anonymous clients")

quota = config.server.quota
anonymous_max_requests = quota.anonymous_max_requests
# if auth is disabled, use the anonymous max requests
authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests

kv_config = quota.kvstore
window_map = {"day": 86400}
window_seconds = window_map[quota.period.value]

app.add_middleware(
QuotaMiddleware,
kv_config=kv_config,
anonymous_max_requests=anonymous_max_requests,
authenticated_max_requests=authenticated_max_requests,
window_seconds=window_seconds,
)

try:
impls = asyncio.run(construct_stack(config))
Expand Down
Loading
Loading