Skip to content

Commit 572571a

Browse files
liangwen12yearWen Liang
authored and
Wen Liang
committed
feat(quota): add server‑side per‑client request quotas
Usage without limits can lead to runaway costs and fragmented client‑side workarounds. By building a native quota mechanism into the server, operators gain a single, centrally managed throttle for per‑client requests—no extra proxies or bespoke client logic required. This helps contain cloud‑compute expenses, provides fine‑grained control over usage, and simplifies deployment and monitoring of Llama Stack services. Quotas remain opt‑in and fully configurable, ensuring zero impact unless explicitly enabled. - Add `QuotaMiddleware` (llama_stack/distribution/server/quota.py) • Reads `Authorization: Bearer <client_id>` • Tracks daily counts in Redis • Enforces `quota_requests_per_day` over a `quota_window_seconds` window • Returns HTTP 429 when exceeded - Extend `ServerConfig` with three new fields: • quota_redis_url • quota_requests_per_day • quota_window_seconds - Wire middleware into server startup (`server.py`) and CLI entrypoint (`llama_stack/cli/stack/run.py`), gated on `quota_redis_url`. - Add CLI flags `--quota-redis-url`, `--quota-requests-per-day`, and `--quota-window-seconds` and ensure they override YAML config. - Leave quotas disabled by default when `quota_redis_url` is unset. To enable per‑client request quotas, add these three settings under the `server:` section of your `run.yaml`. Set the `quota_redis_url` to your Redis connection string to activate per‑client quotas; leave it blank or omit it to disable quotas. Use `quota_requests_per_day` to define the maximum number of requests each client may make in the window, and `quota_window_seconds` to specify the length of that window in seconds (for example, 86400 for 24 hours). ``` server: port: 8321 quota_redis_url: redis://localhost:6379/0 quota_requests_per_day: 1000 quota_window_seconds: 86400 ``` Signed-off-by: Wen Liang <[email protected]>
1 parent 9f27578 commit 572571a

File tree

10 files changed

+268
-2
lines changed

10 files changed

+268
-2
lines changed

docs/source/distributions/building_distro.md

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,10 +269,19 @@ After this step is successful, you should be able to find the built container im
269269
### Running your Stack server
270270
Now, let's start the Llama Stack Distribution Server. You will need the YAML configuration file which was written out at the end by the `llama stack build` step.
271271

272-
```
272+
```bash
273273
llama stack run -h
274-
usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-ipv6] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE] [--tls-certfile TLS_CERTFILE]
274+
usage: llama stack run [-h]
275+
[--port PORT]
276+
[--image-name IMAGE_NAME]
277+
[--disable-ipv6]
278+
[--env KEY=VALUE]
279+
[--tls-keyfile TLS_KEYFILE]
280+
[--tls-certfile TLS_CERTFILE]
275281
[--image-type {conda,container,venv}]
282+
[--quota-redis-url QUOTA_REDIS_URL]
283+
[--quota-requests-per-day QUOTA_REQUESTS_PER_DAY]
284+
[--quota-window-seconds QUOTA_WINDOW_SECONDS]
276285
config
277286

278287
Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.
@@ -293,6 +302,12 @@ options:
293302
Path to TLS certificate file for HTTPS (default: None)
294303
--image-type {conda,container,venv}
295304
Image Type used during the build. This can be either conda or container or venv. (default: conda)
305+
--quota-redis-url QUOTA_REDIS_URL
306+
Redis URL for quota tracking; omit to disable quotas.
307+
--quota-requests-per-day QUOTA_REQUESTS_PER_DAY
308+
Max requests each client may make per window (default: 1000).
309+
--quota-window-seconds QUOTA_WINDOW_SECONDS
310+
Quota window length in seconds (default: 86400 = 24 h).
296311

297312
```
298313

llama_stack/cli/stack/run.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,23 @@ def _add_arguments(self):
7575
help="Image Type used during the build. This can be either conda or container or venv.",
7676
choices=[e.value for e in ImageType],
7777
)
78+
self.parser.add_argument(
79+
"--quota-redis-url",
80+
type=str,
81+
help="Redis URL for quota tracking (enables quotas)",
82+
)
83+
self.parser.add_argument(
84+
"--quota-requests-per-day",
85+
type=int,
86+
default=None,
87+
help="Max requests per client per day",
88+
)
89+
self.parser.add_argument(
90+
"--quota-window-seconds",
91+
type=int,
92+
default=None,
93+
help="Time window for the daily quota, in seconds",
94+
)
7895

7996
# If neither image type nor image name is provided, but at the same time
8097
# the current environment has conda breadcrumbs, then assume what the user
@@ -144,6 +161,10 @@ def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
144161

145162
# Build the server args from the current args passed to the CLI
146163
server_args = argparse.Namespace()
164+
# Propagate quota flags into server_main
165+
server_args.quota_redis_url = args.quota_redis_url
166+
server_args.quota_requests_per_day = args.quota_requests_per_day
167+
server_args.quota_window_seconds = args.quota_window_seconds
147168
for arg in vars(args):
148169
# If this is a function, avoid passing it
149170
# "args" contains:

llama_stack/distribution/datatypes.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,20 @@ class ServerConfig(BaseModel):
253253
default=None,
254254
description="Authentication configuration for the server",
255255
)
256+
quota_redis_url: str | None = Field(
257+
default=None,
258+
description="Redis URL for quota tracking (e.g. redis://localhost:6379/0). If unset, quotas are disabled.",
259+
)
260+
quota_requests_per_day: int = Field(
261+
default=1000,
262+
description="Default maximum number of requests allowed per client per day",
263+
ge=1,
264+
)
265+
quota_window_seconds: int = Field(
266+
default=86400,
267+
description="Time window in seconds for the daily quota (default: 24h)",
268+
ge=1,
269+
)
256270

257271

258272
class StackRunConfig(BaseModel):
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
import json
8+
from datetime import datetime, timezone
9+
10+
import redis.asyncio as aioredis
11+
from starlette.types import ASGIApp, Receive, Scope, Send
12+
13+
from llama_stack.log import get_logger
14+
15+
logger = get_logger(name=__name__, category="quota")
16+
17+
18+
class QuotaMiddleware:
19+
"""
20+
ASGI middleware enforcing per-client daily request quotas.
21+
22+
Expects Authorization: Bearer <client_id> header.
23+
Tracks counts in Redis; returns HTTP 429 when limit is exceeded.
24+
"""
25+
26+
def __init__(
27+
self,
28+
app: ASGIApp,
29+
redis_url: str = "redis://localhost:6379/0",
30+
default_requests_per_day: int = 1000,
31+
window_seconds: int = 86400,
32+
):
33+
self.app = app
34+
self.redis = aioredis.from_url(redis_url, encoding="utf-8", decode_responses=True)
35+
self.default_limit = default_requests_per_day
36+
self.window = window_seconds
37+
38+
async def __call__(self, scope: Scope, receive: Receive, send: Send):
39+
if scope["type"] == "http":
40+
# Extract API key from Authorization header
41+
headers = dict(scope.get("headers", []))
42+
auth = headers.get(b"authorization", b"").decode()
43+
if not auth or not auth.startswith("Bearer "):
44+
return await self._send_error(send, 401, "Missing or invalid API key")
45+
46+
client_id = auth.split("Bearer ", 1)[1].strip()
47+
key = f"quota:{client_id}:{datetime.now(timezone.utc).date().isoformat()}"
48+
49+
try:
50+
count = await self.redis.incr(key)
51+
if count == 1:
52+
await self.redis.expire(key, self.window)
53+
except Exception:
54+
logger.exception("Error accessing Redis for quota")
55+
return await self._send_error(send, 500, "Quota service error")
56+
57+
if count > self.default_limit:
58+
logger.warning(
59+
"Quota exceeded for client %s: %d/%d",
60+
client_id,
61+
count,
62+
self.default_limit,
63+
)
64+
return await self._send_error(send, 429, "Quota exceeded")
65+
66+
# Pass through to downstream app
67+
return await self.app(scope, receive, send)
68+
69+
async def _send_error(self, send: Send, status: int, message: str):
70+
await send(
71+
{
72+
"type": "http.response.start",
73+
"status": status,
74+
"headers": [[b"content-type", b"application/json"]],
75+
}
76+
)
77+
body = json.dumps({"error": {"message": message}}).encode()
78+
await send({"type": "http.response.body", "body": body})

llama_stack/distribution/server/server.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
)
5858

5959
from .auth import AuthenticationMiddleware
60+
from .quota import QuotaMiddleware
6061
from .endpoints import get_all_api_endpoints
6162

6263
REPO_ROOT = Path(__file__).parent.parent.parent.parent
@@ -401,6 +402,13 @@ def main(args: argparse.Namespace | None = None):
401402
config = replace_env_vars(config_contents)
402403
config = StackRunConfig(**config)
403404

405+
if getattr(args, "quota_redis_url", None):
406+
config.server.quota_redis_url = args.quota_redis_url
407+
if getattr(args, "quota_requests_per_day", None) is not None:
408+
config.server.quota_requests_per_day = args.quota_requests_per_day
409+
if getattr(args, "quota_window_seconds", None) is not None:
410+
config.server.quota_window_seconds = args.quota_window_seconds
411+
404412
# now that the logger is initialized, print the line about which type of config we are using.
405413
logger.info(log_line)
406414

@@ -422,6 +430,18 @@ def main(args: argparse.Namespace | None = None):
422430
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
423431
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
424432

433+
434+
# Add per-client quota enforcement
435+
# per‑client quota enforcement (only if configured)
436+
if config.server.quota_redis_url:
437+
logger.info("Enabling per-client quota middleware")
438+
app.add_middleware(
439+
QuotaMiddleware,
440+
redis_url=config.server.quota_redis_url,
441+
default_requests_per_day=config.server.quota_requests_per_day,
442+
window_seconds=config.server.quota_window_seconds,
443+
)
444+
425445
try:
426446
impls = asyncio.run(construct_stack(config))
427447
except InvalidProviderError as e:

llama_stack/templates/dev/run.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,3 +431,6 @@ tool_groups:
431431
provider_id: rag-runtime
432432
server:
433433
port: 8321
434+
quota_redis_url: ""
435+
quota_requests_per_day: 1000
436+
quota_window_seconds: 86400

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ dependencies = [
4040
"pillow",
4141
"h11>=0.16.0",
4242
"kubernetes",
43+
"redis>=4.4.0",
4344
]
4445

4546
[project.optional-dependencies]

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# uv export --frozen --no-hashes --no-emit-project --output-file=requirements.txt
33
annotated-types==0.7.0
44
anyio==4.8.0
5+
async-timeout==5.0.1 ; python_full_version < '3.11.3'
56
attrs==25.1.0
67
blobfile==3.0.0
78
cachetools==5.5.2
@@ -49,6 +50,7 @@ python-dateutil==2.9.0.post0
4950
python-dotenv==1.0.1
5051
pytz==2025.1
5152
pyyaml==6.0.2
53+
redis==6.0.0
5254
referencing==0.36.2
5355
regex==2024.11.6
5456
requests==2.32.3

tests/unit/server/test_quota.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
# tests/unit/server/test_quota.py
8+
9+
import pytest
10+
import redis.asyncio as aioredis
11+
from fastapi import FastAPI
12+
from fastapi.testclient import TestClient
13+
14+
from llama_stack.distribution.server.quota import QuotaMiddleware
15+
16+
17+
@pytest.fixture(autouse=True)
18+
def fake_redis(monkeypatch):
19+
"""
20+
Replace aioredis.from_url with a fake in-memory Redis for quota tests.
21+
"""
22+
23+
class FakeRedis:
24+
def __init__(self):
25+
self._store = {}
26+
27+
async def incr(self, key):
28+
v = self._store.get(key, 0) + 1
29+
self._store[key] = v
30+
return v
31+
32+
async def expire(self, key, seconds):
33+
# no-op TTL for tests
34+
return True
35+
36+
def fake_from_url(url, encoding="utf-8", decode_responses=True):
37+
# Return our FakeRedis instance synchronously
38+
return FakeRedis()
39+
40+
monkeypatch.setattr(aioredis, "from_url", fake_from_url)
41+
42+
43+
@pytest.fixture
44+
def app():
45+
"""
46+
Create a FastAPI app with QuotaMiddleware mounted.
47+
Use a small limit (2 requests) and short window (60s) for testing.
48+
"""
49+
app = FastAPI()
50+
app.add_middleware(
51+
QuotaMiddleware,
52+
redis_url="redis://localhost:6379/0",
53+
default_requests_per_day=2,
54+
window_seconds=60,
55+
)
56+
57+
@app.get("/test")
58+
def test_endpoint():
59+
return {"message": "ok"}
60+
61+
return app
62+
63+
64+
def test_quota_allows_up_to_limit(app):
65+
client = TestClient(app)
66+
headers = {"Authorization": "Bearer client1"}
67+
68+
# First two requests should pass
69+
resp1 = client.get("/test", headers=headers)
70+
assert resp1.status_code == 200
71+
assert resp1.json() == {"message": "ok"}
72+
73+
resp2 = client.get("/test", headers=headers)
74+
assert resp2.status_code == 200
75+
assert resp2.json() == {"message": "ok"}
76+
77+
78+
def test_quota_blocks_after_limit(app):
79+
client = TestClient(app)
80+
headers = {"Authorization": "Bearer client1"}
81+
82+
# Exceed the limit: 3rd request should be throttled
83+
client.get("/test", headers=headers)
84+
client.get("/test", headers=headers)
85+
resp3 = client.get("/test", headers=headers)
86+
assert resp3.status_code == 429
87+
assert resp3.json()["error"]["message"] == "Quota exceeded"
88+
89+
90+
def test_missing_auth_header_returns_401(app):
91+
client = TestClient(app)
92+
93+
# No Authorization header → 401
94+
resp = client.get("/test")
95+
assert resp.status_code == 401
96+
assert "Missing or invalid API key" in resp.json()["error"]["message"]

uv.lock

Lines changed: 16 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)