Skip to content

Commit e863b81

Browse files
liangwen12yearWen Liang
authored and
Wen Liang
committed
feat(quota): add server‑side per‑client request quotas (requires auth)
Unrestricted usage can lead to runaway costs and fragmented client-side workarounds. This commit introduces a native quota mechanism to the server, giving operators a unified, centrally managed throttle for per-client requests—without needing extra proxies or custom client logic. This helps contain cloud-compute expenses, enables fine-grained usage control, and simplifies deployment and monitoring of Llama Stack services.

Quotas are fully opt-in: they have no effect unless explicitly configured, and they require authentication to be enabled. `sqlite` is the only supported quota `type` at this time; any other `type` will be rejected.

Highlights:
- Adds `QuotaMiddleware` to enforce per-client request quotas:
  - Uses `Authorization: Bearer <client_id>` (from AuthenticationMiddleware)
  - Tracks usage via a SQLite-based KV store
  - Returns 429 when the quota is exceeded
- Extends `ServerConfig` with a `quota` section (type + config)
- Enforces strict coupling: quotas require authentication or the server will fail to start

Behavior changes:
- Quotas are disabled by default unless explicitly configured
- SQLite defaults to `./quotas.db` if no DB path is set
- The server requires authentication when quotas are enabled

To enable per-client request quotas in `run.yaml`, add:

```
server:
  port: 8321
  auth:
    provider_type: "custom"
    config:
      endpoint: "https://auth.example.com/validate"
  quota:
    type: sqlite
    config:
      db_path: ./quotas.db
      quota_requests_per_day: 1000
      quota_window_seconds: 86400
```

Signed-off-by: Wen Liang <[email protected]>
1 parent dd49ef3 commit e863b81

File tree

6 files changed

+2455
-2207
lines changed

6 files changed

+2455
-2207
lines changed

llama_stack/distribution/datatypes.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# the root directory of this source tree.
66

77
from enum import Enum
8-
from typing import Annotated, Any
8+
from typing import Annotated, Any, Literal
99

1010
from pydantic import BaseModel, Field
1111

@@ -234,6 +234,17 @@ class AuthenticationConfig(BaseModel):
234234
)
235235

236236

237+
class QuotaSqliteConfig(BaseModel):
    """Settings for the SQLite-backed per-client request quota store.

    Numeric fields are range-checked at validation time so that a nonsensical
    configuration (negative limit, zero or negative window) is rejected when
    the config is loaded rather than surfacing later at request time.
    """

    db_path: str = Field(default="./quotas.db", description="Path to the SQLite DB file")
    # A limit of zero is a valid (if extreme) setting; negative limits are not.
    quota_requests_per_day: int = Field(
        default=1000,
        ge=0,
        description="Maximum requests per client per window",
    )
    # The window must be strictly positive to describe a real time span.
    quota_window_seconds: int = Field(
        default=86400,
        gt=0,
        description="Quota window length in seconds",
    )
241+
242+
243+
class QuotaConfig(BaseModel):
    """Quota section of the server config: a backend selector plus that backend's settings."""

    # Only the literal value "sqlite" passes validation.
    type: Literal["sqlite"] = Field(
        description="Quota backend type: must be 'sqlite'",
    )
    config: QuotaSqliteConfig
246+
247+
237248
class ServerConfig(BaseModel):
238249
port: int = Field(
239250
default=8321,
@@ -253,6 +264,10 @@ class ServerConfig(BaseModel):
253264
default=None,
254265
description="Authentication configuration for the server",
255266
)
267+
quota: QuotaConfig | None = Field(
268+
default=None,
269+
description=("Per client request quota configuration. If unset or null, quotas are disabled."),
270+
)
256271

257272

258273
class StackRunConfig(BaseModel):

llama_stack/distribution/server/auth.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ async def __call__(self, scope, receive, send):
113113
"namespaces": [token],
114114
}
115115

116+
scope["authenticated_client_id"] = token
117+
116118
# Store attributes in request scope
117119
scope["user_attributes"] = user_attributes
118120
logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes")
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
import json
8+
from datetime import datetime, timezone
9+
10+
from starlette.types import ASGIApp, Receive, Scope, Send
11+
12+
from llama_stack.log import get_logger
13+
from llama_stack.providers.utils.kvstore.api import KVStore
14+
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
15+
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
16+
17+
logger = get_logger(name=__name__, category="quota")
18+
19+
20+
class QuotaMiddleware:
    """
    ASGI middleware enforcing per-client request quotas over a fixed time window.

    The client identity is read from ``scope["authenticated_client_id"]``; if it
    is missing the request is answered with HTTP 500. Request counts are kept in
    a KV store (on-disk SQLite by default), one counter per client per window.
    Once a client's count for the current window exceeds the limit, the request
    is answered with HTTP 429 instead of being forwarded.

    Non-HTTP scopes are passed through untouched.
    """

    def __init__(
        self,
        app: ASGIApp,
        kv_config: KVStoreConfig | None = None,
        default_requests_per_day: int = 1000,
        window_seconds: int = 86400,
    ):
        """
        :param app: downstream ASGI application.
        :param kv_config: KV store configuration; defaults to SQLite at ``./quotas.db``.
        :param default_requests_per_day: maximum requests allowed per client per window.
        :param window_seconds: length of one quota window in seconds; must be positive.
        :raises ValueError: if ``window_seconds`` is not positive.
        """
        if window_seconds <= 0:
            raise ValueError(f"window_seconds must be a positive integer, got {window_seconds!r}")

        self.app = app
        # if no config passed, default to on disk SQLite
        self._kv_config = kv_config or SqliteKVStoreConfig(db_path="./quotas.db")
        # The store is created lazily on first use (see _get_kv).
        self._kv: KVStore | None = None
        self.default_limit = default_requests_per_day
        self.window = window_seconds

    async def _get_kv(self) -> KVStore:
        """Return the KV store, creating it from the stored config on first call."""
        if self._kv is None:
            self._kv = await kvstore_impl(self._kv_config)
        return self._kv

    def _window_key(self, client_id: str) -> str:
        """Build the counter key for ``client_id`` in the current quota window.

        Windows are fixed, epoch-aligned buckets of ``self.window`` seconds, so a
        new key (and therefore a fresh counter) starts whenever the bucket index
        changes. With the default 86400-second window the bucket is the UTC day.
        """
        now_epoch = int(datetime.now(timezone.utc).timestamp())
        return f"quota:{client_id}:{now_epoch // self.window}"

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        """Count the request against the client's quota, then forward or reject it."""
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        client_id = scope.get("authenticated_client_id")
        if not client_id:
            logger.error(
                "QuotaMiddleware requires an authenticated client_id but none was found in the scope. "
                "This likely means AuthenticationMiddleware is not installed or failed."
            )
            return await self._send_error(
                send, 500, "Quota system misconfigured: missing authenticated client identity"
            )

        key = self._window_key(client_id)

        try:
            kv = await self._get_kv()
            # NOTE: this read-modify-write is not atomic; requests from the same
            # client that interleave between the get and the set may undercount.
            prev = await kv.get(key) or "0"
            count = int(prev) + 1
            await kv.set(key, str(count))
            # No TTL is set on the entry: expiry is not relied upon here because
            # each window uses its own key.
        except Exception:
            logger.exception("Error accessing KV store for quota")
            return await self._send_error(send, 500, "Quota service error")

        if count > self.default_limit:
            logger.warning(
                "Quota exceeded for client %s: %d/%d",
                client_id,
                count,
                self.default_limit,
            )
            return await self._send_error(send, 429, "Quota exceeded")

        # Pass through to downstream application
        return await self.app(scope, receive, send)

    async def _send_error(self, send: Send, status: int, message: str):
        """Send a complete JSON error response ``{"error": {"message": ...}}`` with ``status``."""
        await send(
            {
                "type": "http.response.start",
                "status": status,
                "headers": [[b"content-type", b"application/json"]],
            }
        )
        body = json.dumps({"error": {"message": message}}).encode()
        await send({"type": "http.response.body", "body": body})

llama_stack/distribution/server/server.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
5050
TelemetryAdapter,
5151
)
52+
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
5253
from llama_stack.providers.utils.telemetry.tracing import (
5354
CURRENT_TRACE_CONTEXT,
5455
end_trace,
@@ -58,6 +59,7 @@
5859

5960
from .auth import AuthenticationMiddleware
6061
from .endpoints import get_all_api_endpoints
62+
from .quota import QuotaMiddleware
6163

6264
REPO_ROOT = Path(__file__).parent.parent.parent.parent
6365

@@ -421,6 +423,28 @@ def main(args: argparse.Namespace | None = None):
421423
if config.server.auth:
422424
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
423425
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
426+
else:
427+
if config.server.quota:
428+
logger.error(
429+
"Quota enforcement requires authentication to be enabled, but no auth config is present. "
430+
"Disable quotas or configure authentication."
431+
)
432+
raise RuntimeError("Quota middleware requires authentication middleware to be active.")
433+
434+
# Enforce per-client quota (only if configured and require authentication)
435+
if config.server.quota:
436+
logger.info("Enabling per-client quota middleware")
437+
438+
quota_conf = config.server.quota.config
439+
440+
kv_config = SqliteKVStoreConfig(db_path=quota_conf.db_path)
441+
442+
app.add_middleware(
443+
QuotaMiddleware,
444+
kv_config=kv_config,
445+
default_requests_per_day=quota_conf.quota_requests_per_day,
446+
window_seconds=quota_conf.quota_window_seconds,
447+
)
424448

425449
try:
426450
impls = asyncio.run(construct_stack(config))

tests/unit/server/test_quota.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
import os
8+
9+
import pytest
10+
from fastapi import FastAPI, Request
11+
from fastapi.testclient import TestClient
12+
from starlette.middleware.base import BaseHTTPMiddleware
13+
14+
from llama_stack.distribution.server.quota import QuotaMiddleware
15+
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
16+
17+
18+
@pytest.fixture(autouse=True)
def clean_sqlite_db():
    """
    Remove the quota test DB file before and after each test.

    Cleaning up on both sides guarantees a test never sees leftover state and
    that the last test in the run does not leave the file behind on disk.
    """
    db_path = "./quotas_test.db"

    def _remove_db():
        # EAFP: the file legitimately may not exist (e.g. a test never touched the store).
        try:
            os.remove(db_path)
        except FileNotFoundError:
            pass

    _remove_db()
    yield
    _remove_db()
26+
27+
28+
class InjectClientIDMiddleware(BaseHTTPMiddleware):
    """
    Test stand-in for AuthenticationMiddleware: writes a fixed
    'authenticated_client_id' into the scope of every request it handles.
    """

    def __init__(self, app, client_id="client1"):
        super().__init__(app)
        self.client_id = client_id

    async def dispatch(self, request: Request, call_next):
        # Set the identity before the wrapped app runs so it can read it from the scope.
        request.scope["authenticated_client_id"] = self.client_id
        response = await call_next(request)
        return response
40+
41+
42+
@pytest.fixture(scope="function")
def app(request):
    """
    Build a FastAPI app wrapped by QuotaMiddleware (limit of 2 requests) and,
    outermost, InjectClientIDMiddleware.

    The injected client_id is derived from the test name, so every test counts
    against its own quota key.
    """
    inner_app = FastAPI()

    @inner_app.get("/test")
    async def test_endpoint():
        return {"message": "ok"}

    quota_layer = QuotaMiddleware(
        inner_app,
        kv_config=SqliteKVStoreConfig(db_path="./quotas_test.db"),
        default_requests_per_day=2,
        window_seconds=60,
    )

    # Identity injection wraps the quota layer so it runs first on each request.
    return InjectClientIDMiddleware(
        quota_layer,
        client_id=f"client_{request.node.name}",
    )
68+
69+
70+
def test_quota_allows_up_to_limit(app):
    """Two consecutive requests both reach the endpoint and return its payload."""
    client = TestClient(app)

    for _ in range(2):
        resp = client.get("/test")
        assert resp.status_code == 200
        assert resp.json() == {"message": "ok"}
80+
81+
82+
def test_quota_blocks_after_limit(app):
    """The third request is rejected with 429 after two successful ones."""
    client = TestClient(app)

    # The first two requests must succeed; otherwise a 429 on the third
    # would not demonstrate that the limit itself triggered the rejection.
    for _ in range(2):
        assert client.get("/test").status_code == 200

    # Exceed limit: 3rd request should be throttled
    resp3 = client.get("/test")
    assert resp3.status_code == 429
    assert resp3.json()["error"]["message"] == "Quota exceeded"
91+
92+
93+
def test_missing_authenticated_client_id_returns_500():
    """
    QuotaMiddleware used without anything setting 'authenticated_client_id'
    must answer with a 500 and a misconfiguration message.
    """
    inner_app = FastAPI()

    @inner_app.get("/test")
    async def test_endpoint():
        return {"message": "ok"}

    # No identity-injecting middleware is wrapped around the quota layer here.
    quota_only_app = QuotaMiddleware(
        inner_app,
        kv_config=SqliteKVStoreConfig(db_path="./quotas_test.db"),
        default_requests_per_day=2,
        window_seconds=60,
    )

    resp = TestClient(quota_only_app).get("/test")

    assert resp.status_code == 500
    assert "Quota system misconfigured" in resp.json()["error"]["message"]

0 commit comments

Comments
 (0)