Skip to content

Commit c98fae0

Browse files
liangwen12yearWen Liang
authored and
Wen Liang
committed
feat(quota): add server‑side per‑client request quotas (requires auth)
Unrestricted usage can lead to runaway costs and fragmented client-side workarounds. This commit introduces a native quota mechanism to the server, giving operators a unified, centrally managed throttle for per-client requests—without needing extra proxies or custom client logic. This helps contain cloud-compute expenses, enables fine-grained usage control, and simplifies deployment and monitoring of Llama Stack services.

Quotas are fully opt-in: they have no effect unless explicitly configured, and they require authentication to be enabled. 'sqlite' is the only supported quota `type` at this time; any other `type` will be rejected. The only supported `period` is 'day'.

Highlights:
- Adds `QuotaMiddleware` to enforce per-client request quotas:
  - Uses `Authorization: Bearer <client_id>` (from AuthenticationMiddleware)
  - Tracks usage via a SQLite-based KV store
  - Returns 429 when the quota is exceeded
- Extends `ServerConfig` with a `quota` section (type + config)
- Enforces strict coupling: quotas require authentication or the server will fail to start

Behavior changes:
- Quotas are disabled by default unless explicitly configured
- SQLite defaults to `./quotas.db` if no DB path is set
- The server requires authentication when quotas are enabled

To enable per-client request quotas in `run.yaml`, add:

```
server:
  port: 8321
  auth:
    provider_type: "custom"
    config:
      endpoint: "https://auth.example.com/validate"
  quota:
    type: sqlite
    config:
      db_path: ./quotas.db
      limit:
        max_requests: 1000
        period: day
```

Signed-off-by: Wen Liang <[email protected]>
1 parent 6371bb1 commit c98fae0

File tree

5 files changed

+276
-0
lines changed

5 files changed

+276
-0
lines changed

llama_stack/distribution/datatypes.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,29 @@ class AuthenticationConfig(BaseModel):
234234
)
235235

236236

237+
class QuotaPeriod(str, Enum):
    """Length of a quota window; only ``day`` is defined."""

    DAY = "day"
239+
240+
241+
class QuotaLimit(BaseModel):
    """How many requests a client may make within one quota period."""

    max_requests: int = Field(default=1000, description="Maximum requests per period")
    period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
244+
245+
246+
class QuotaType(str, Enum):
    """Backend used to persist quota counters; only ``sqlite`` is defined."""

    SQLITE = "sqlite"
248+
249+
250+
class QuotaSqliteConfig(BaseModel):
    """Settings for the SQLite quota backend: database location and the limit to enforce."""

    db_path: str = Field(default="./quotas.db", description="Path to the SQLite DB file")
    # `limit` has no default, so it must be supplied whenever this section is present.
    limit: QuotaLimit = Field(description="Quota limit configuration (requests + period)")
253+
254+
255+
class QuotaConfig(BaseModel):
    """Top-level quota section: the backend type plus that backend's configuration."""

    type: QuotaType = Field(description="Quota backend type. Only 'sqlite' is supported at this time")
    config: QuotaSqliteConfig
258+
259+
237260
class ServerConfig(BaseModel):
238261
port: int = Field(
239262
default=8321,
@@ -257,6 +280,10 @@ class ServerConfig(BaseModel):
257280
default=False,
258281
description="Disable IPv6 support",
259282
)
283+
quota: QuotaConfig | None = Field(
284+
default=None,
285+
description="Per client quota request configuration",
286+
)
260287

261288

262289
class StackRunConfig(BaseModel):

llama_stack/distribution/server/auth.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ async def __call__(self, scope, receive, send):
113113
"namespaces": [token],
114114
}
115115

116+
scope["authenticated_client_id"] = token
117+
116118
# Store attributes in request scope
117119
scope["user_attributes"] = user_attributes
118120
logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes")
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
import json
8+
import time
9+
from datetime import datetime, timedelta, timezone
10+
11+
from starlette.types import ASGIApp, Receive, Scope, Send
12+
13+
from llama_stack.log import get_logger
14+
from llama_stack.providers.utils.kvstore.api import KVStore
15+
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
16+
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
17+
18+
logger = get_logger(name=__name__, category="quota")
19+
20+
21+
class QuotaMiddleware:
22+
"""
23+
ASGI middleware enforcing per-client request quotas over a defined period.
24+
25+
Expects Authorization: Bearer <client_id> header.
26+
Tracks counts in a KV store (SQLite); returns HTTP 429 when limit is exceeded.
27+
"""
28+
29+
def __init__(
30+
self,
31+
app: ASGIApp,
32+
kv_config: KVStoreConfig,
33+
max_requests: int = 1000,
34+
window_seconds: int = 86400,
35+
):
36+
self.app = app
37+
self.kv_config = kv_config
38+
self.kv: KVStore | None = None
39+
self.max_requests = max_requests
40+
self.window_seconds = window_seconds
41+
42+
if isinstance(self.kv_config, SqliteKVStoreConfig):
43+
logger.warning(
44+
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
45+
f"window_seconds={self.window_seconds}"
46+
)
47+
48+
async def _get_kv(self) -> KVStore:
49+
if self.kv is None:
50+
self.kv = await kvstore_impl(self.kv_config)
51+
return self.kv
52+
53+
async def __call__(self, scope: Scope, receive: Receive, send: Send):
54+
if scope["type"] == "http":
55+
client_id = scope.get("authenticated_client_id")
56+
if not client_id:
57+
logger.error(
58+
"QuotaMiddleware requires an authenticated client_id but none was found in the scope. "
59+
"This likely means AuthenticationMiddleware is not installed or failed."
60+
)
61+
return await self._send_error(
62+
send, 500, "Quota system misconfigured: missing authenticated client identity"
63+
)
64+
65+
current_window = int(time.time() // self.window_seconds)
66+
key = f"quota:{client_id}:{current_window}"
67+
68+
try:
69+
kv = await self._get_kv()
70+
prev = await kv.get(key) or "0"
71+
count = int(prev) + 1
72+
73+
if int(prev) == 0:
74+
# Set with expiration datetime when it is the first request in the window.
75+
expiration = datetime.now(timezone.utc) + timedelta(seconds=self.window_seconds)
76+
await kv.set(key, str(count), expiration=expiration)
77+
else:
78+
await kv.set(key, str(count))
79+
except Exception:
80+
logger.exception("Failed to access KV store for quota")
81+
return await self._send_error(send, 500, "Quota service error")
82+
83+
if count > self.max_requests:
84+
logger.warning(
85+
"Quota exceeded for client %s: %d/%d",
86+
client_id,
87+
count,
88+
self.max_requests,
89+
)
90+
return await self._send_error(send, 429, "Quota exceeded")
91+
92+
return await self.app(scope, receive, send)
93+
94+
async def _send_error(self, send: Send, status: int, message: str):
95+
await send(
96+
{
97+
"type": "http.response.start",
98+
"status": status,
99+
"headers": [[b"content-type", b"application/json"]],
100+
}
101+
)
102+
body = json.dumps({"error": {"message": message}}).encode()
103+
await send({"type": "http.response.body", "body": body})

llama_stack/distribution/server/server.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
5050
TelemetryAdapter,
5151
)
52+
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
5253
from llama_stack.providers.utils.telemetry.tracing import (
5354
CURRENT_TRACE_CONTEXT,
5455
end_trace,
@@ -58,6 +59,7 @@
5859

5960
from .auth import AuthenticationMiddleware
6061
from .endpoints import get_all_api_endpoints
62+
from .quota import QuotaMiddleware
6163

6264
REPO_ROOT = Path(__file__).parent.parent.parent.parent
6365

@@ -411,6 +413,34 @@ def main(args: argparse.Namespace | None = None):
411413
if config.server.auth:
412414
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
413415
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
416+
else:
417+
if config.server.quota:
418+
logger.error(
419+
"Quota enforcement requires authentication to be enabled, but no auth config is present. "
420+
"Disable quotas or configure authentication."
421+
)
422+
raise RuntimeError("Quota middleware requires authentication middleware to be active.")
423+
424+
# Enforce per-client quota (only if configured and require authentication)
425+
if config.server.quota:
426+
logger.info("Enabling per-client quota middleware")
427+
428+
quota_conf = config.server.quota.config
429+
430+
kv_config = SqliteKVStoreConfig(db_path=quota_conf.db_path)
431+
432+
window_seconds_map = {
433+
"day": 86400,
434+
}
435+
436+
window_seconds = window_seconds_map[quota_conf.limit.period.value]
437+
438+
app.add_middleware(
439+
QuotaMiddleware,
440+
kv_config=kv_config,
441+
max_requests=quota_conf.limit.max_requests,
442+
window_seconds=window_seconds,
443+
)
414444

415445
try:
416446
impls = asyncio.run(construct_stack(config))

tests/unit/server/test_quota.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
import os
8+
9+
import pytest
10+
from fastapi import FastAPI, Request
11+
from fastapi.testclient import TestClient
12+
from starlette.middleware.base import BaseHTTPMiddleware
13+
14+
from llama_stack.distribution.server.quota import QuotaMiddleware
15+
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
16+
17+
18+
@pytest.fixture(autouse=True)
def clean_sqlite_db():
    """
    Remove the quotas_test.db file before each test to ensure no leftover state on disk.
    """
    db_path = "./quotas_test.db"
    try:
        os.remove(db_path)
    except FileNotFoundError:
        # No database left over from an earlier test; nothing to clean.
        pass
26+
27+
28+
class InjectClientIDMiddleware(BaseHTTPMiddleware):
    """
    Middleware that injects 'authenticated_client_id' to mimic AuthenticationMiddleware.
    """

    def __init__(self, app, client_id="client1"):
        """
        :param app: the ASGI application to wrap
        :param client_id: value written to ``scope["authenticated_client_id"]`` on every request
        """
        super().__init__(app)
        self.client_id = client_id

    async def dispatch(self, request: Request, call_next):
        """Stamp the configured client id onto the request scope, then continue the chain."""
        request.scope["authenticated_client_id"] = self.client_id
        return await call_next(request)
40+
41+
42+
@pytest.fixture(scope="function")
def app(request):
    """
    Create a FastAPI app with both InjectClientIDMiddleware and QuotaMiddleware.
    Each test gets a unique client_id for safety.

    The quota is configured as 2 requests per 60-second window, backed by
    ./quotas_test.db.
    """
    inner_app = FastAPI()

    @inner_app.get("/test")
    async def test_endpoint():
        return {"message": "ok"}

    # Use the test name to create a unique client_id per test
    client_id = f"client_{request.node.name}"

    # InjectClientIDMiddleware is outermost so the client id is in the scope
    # before QuotaMiddleware looks for it.
    quota_app = QuotaMiddleware(
        inner_app,
        kv_config=SqliteKVStoreConfig(db_path="./quotas_test.db"),
        max_requests=2,
        window_seconds=60,
    )
    return InjectClientIDMiddleware(quota_app, client_id=client_id)
68+
69+
70+
def test_quota_allows_up_to_limit(app):
    """Both requests within the limit of 2 reach the endpoint and succeed."""
    client = TestClient(app)

    resp1 = client.get("/test")
    assert resp1.status_code == 200
    assert resp1.json() == {"message": "ok"}

    resp2 = client.get("/test")
    assert resp2.status_code == 200
    assert resp2.json() == {"message": "ok"}
80+
81+
82+
def test_quota_blocks_after_limit(app):
    """The third request in a window limited to 2 is rejected with HTTP 429."""
    client = TestClient(app)

    # The first two requests must succeed, otherwise the 429 below could come
    # from an earlier rejection and would not prove the limit boundary.
    assert client.get("/test").status_code == 200
    assert client.get("/test").status_code == 200

    # Exceed limit: 3rd request should be throttled
    resp3 = client.get("/test")
    assert resp3.status_code == 429
    assert resp3.json()["error"]["message"] == "Quota exceeded"
91+
92+
93+
def test_missing_authenticated_client_id_returns_500():
    """
    Confirm 500 error when QuotaMiddleware runs without authenticated_client_id.
    """
    inner_app = FastAPI()

    @inner_app.get("/test")
    async def test_endpoint():
        return {"message": "ok"}

    # No InjectClientIDMiddleware here, so the scope never gets a client id.
    test_app = QuotaMiddleware(
        inner_app,
        kv_config=SqliteKVStoreConfig(db_path="./quotas_test.db"),
        max_requests=2,
        window_seconds=60,
    )

    client = TestClient(test_app)

    resp = client.get("/test")
    assert resp.status_code == 500
    assert "Quota system misconfigured" in resp.json()["error"]["message"]

0 commit comments

Comments
 (0)