Skip to content

Commit 525e3fe

Browse files
authored
feat(monitoring): add llm oriented metrics to grafana dashboard (#768)
* feat(monitoring): add llm oriented metrics to grafana dashboard * chore(monitoring): move metrics definition to utils file * Update unit coverage badge * chore(monitoring): lint code --------- Co-authored-by: leoguillaume <leoguillaume@users.noreply.github.com>
1 parent a30a8ec commit 525e3fe

7 files changed

Lines changed: 235 additions & 11 deletions

File tree

.github/PULL_REQUEST_TEMPLATE/pull_request_template.md renamed to .github/PULL_REQUEST_TEMPLATE.md

File renamed without changes.

.github/badges/coverage.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"schemaVersion":1,"label":"coverage","message":"50.19%","color":"red"}
1+
{"schemaVersion":1,"label":"coverage","message":"49.58%","color":"red"}

api/clients/model/_basemodelprovider.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,6 @@ def _get_usage(self, request_content: RequestContent, response_data: dict | list
113113
tokenizer = getattr(global_context, "tokenizer", None)
114114
if tokenizer and request_content.endpoint in tokenizer.USAGE_ENDPOINTS:
115115
try:
116-
completion_tokens = 0
117116
prompt_tokens = tokenizer.get_prompt_tokens(endpoint=request_content.endpoint, body=request_content.json)
118117
completion_tokens = tokenizer.get_completion_tokens(endpoint=request_content.endpoint, response_data=response_data)
119118
total_tokens = prompt_tokens + completion_tokens
@@ -149,10 +148,10 @@ def _format_request(self, request_content: RequestContent) -> RequestContent:
149148
Format a request to a provider model. This method can be overridden by a subclass to add additional headers or parameters. It formats the requested endpoint using the ENDPOINT_TABLE attribute.
150149
151150
Args:
152-
content(RequestContent): The request content to format.
151+
request_content(RequestContent): The request content to format.
153152
154153
Returns:
155-
content(RequestContent): The formatted request content.
154+
request_content(RequestContent): The formatted request content.
156155
"""
157156
if "model" in request_content.json:
158157
request_content.json["model"] = self.model_name
@@ -221,7 +220,8 @@ def _format_response(self, request_content: RequestContent, response: httpx.Resp
221220

222221
return response
223222

224-
async def _ensure_timeseries_exists(self, redis_client: AsyncRedis, key: str) -> None:
223+
@staticmethod
224+
async def _ensure_timeseries_exists(redis_client: AsyncRedis, key: str) -> None:
225225
"""
226226
Ensure a time series exists with proper retention configuration.
227227

api/endpoints/monitoring.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,43 @@
33
from fastapi import Depends, FastAPI
44
import prometheus_client
55
from prometheus_client import CollectorRegistry, multiprocess
6-
from prometheus_fastapi_instrumentator import Instrumentator
6+
from prometheus_fastapi_instrumentator import Instrumentator, metrics
77
from starlette.responses import Response
88

99
from api.helpers._accesscontroller import AccessController
1010
from api.schemas.admin.roles import PermissionType
11+
from api.utils.monitoring import (
12+
inference_output_tokens_per_second,
13+
inference_requests_duration_seconds,
14+
inference_requests_total,
15+
inference_tokens_total,
16+
inference_ttft_milliseconds,
17+
)
1118
from api.utils.variables import RouterName
1219

1320

14-
def setup_prometheus(app: FastAPI, include_in_schema: bool = True) -> None:
15-
app.instrumentator = Instrumentator().instrument(app=app)
21+
def setup_prometheus(app: FastAPI, metric_namespace: str = "ogl", include_in_schema: bool = True) -> None:
22+
app.instrumentator = (
23+
Instrumentator()
24+
.instrument(app=app)
25+
.add(
26+
metrics.default(metric_namespace=metric_namespace),
27+
inference_output_tokens_per_second(metric_namespace=metric_namespace),
28+
inference_requests_total(metric_namespace=metric_namespace),
29+
inference_requests_duration_seconds(metric_namespace=metric_namespace),
30+
inference_ttft_milliseconds(metric_namespace=metric_namespace),
31+
inference_tokens_total(metric_namespace=metric_namespace),
32+
)
33+
.expose(app)
34+
)
1635

1736
@app.get(
1837
path="/metrics",
1938
tags=[RouterName.MONITORING.title()],
2039
dependencies=[Depends(dependency=AccessController(permissions=[PermissionType.READ_METRIC]))],
2140
include_in_schema=include_in_schema,
2241
)
23-
def metrics() -> Response:
42+
def get_metrics() -> Response:
2443
if os.environ.get("PROMETHEUS_MULTIPROC_DIR"):
2544
registry = CollectorRegistry()
2645
multiprocess.MultiProcessCollector(registry)

api/helpers/load_balancing/_leastbusyloadbalancingstrategy.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ def __init__(self, redis_client: AsyncRedis | Redis, load_balancing_metric: Metr
2020
Get a provider to handle the request based on the specified routing strategy.
2121
2222
Args:
23-
candidates (list[int]): The list of provider candidates (provider IDs) to choose from
2423
redis_client (AsyncRedis): Redis client instance, required for least busy strategy
2524
load_balancing_metric (Metric): The type of metric to use for performance evaluation
2625

api/utils/hooks_decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ async def update_budget(usage: Usage):
189189
# Update the budget
190190
update_stmt = update(User).where(User.id == user_id).values(budget=new_budget, updated=func.now()).returning(User.budget)
191191

192-
result = await postgres_session.execute(update_stmt)
192+
await postgres_session.execute(update_stmt)
193193

194194
except Exception as e:
195195
logger.exception(f"Failed to update budget for user {user_id}: {e}")

api/utils/monitoring.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
from collections.abc import Callable
2+
3+
from prometheus_client import Counter, Histogram
4+
from prometheus_fastapi_instrumentator.metrics import Info
5+
6+
from api.utils.context import request_context
7+
8+
9+
def _build_metric_name(namespace: str, name: str) -> str:
10+
return f"{namespace}_{name}" if namespace else name
11+
12+
13+
def inference_requests_total(metric_namespace: str = "") -> Callable[[Info], None]:
    """Build an instrumentation callback that counts LLM requests.

    A ``Counter`` labelled by ``endpoint``, ``model`` and ``status_code`` is
    created when this factory is called. Each time the returned callback is
    invoked, the counter is incremented by one, provided the current request
    context carries both a router name and an endpoint.

    Args:
        metric_namespace: Optional prefix for the metric name; an empty
            string means no prefix.

    Returns:
        A callback accepting the instrumentator ``Info`` object of a request.
    """
    import logging  # function-scope import keeps this block self-contained

    logger = logging.getLogger(__name__)
    metric_name = _build_metric_name(metric_namespace, "inference_requests_total")
    metric = Counter(
        metric_name,
        "Total number of LLM requests.",
        labelnames=("endpoint", "model", "status_code"),
    )

    def instrumentation(info: Info) -> None:
        try:
            context = request_context.get()
            model = context.router_name
            endpoint = context.endpoint
            if model and endpoint:
                metric.labels(endpoint=endpoint, model=model, status_code=info.modified_status).inc()
        except Exception:
            # Best-effort: a metric failure must never break the request,
            # but leave a trace so missing data points can be diagnosed.
            logger.debug("Failed to record metric %s", metric_name, exc_info=True)

    return instrumentation
32+
33+
34+
def inference_requests_duration_seconds(metric_namespace: str = "") -> Callable[[Info], None]:
    """Build an instrumentation callback that records LLM request durations.

    A ``Histogram`` labelled by ``endpoint``, ``model`` and ``status_code`` is
    created when this factory is called. Each time the returned callback is
    invoked, it observes ``context.latency / 1000``, provided the current
    request context carries a router name, an endpoint and a non-``None``
    latency.

    Args:
        metric_namespace: Optional prefix for the metric name; an empty
            string means no prefix.

    Returns:
        A callback accepting the instrumentator ``Info`` object of a request.
    """
    import logging  # function-scope import keeps this block self-contained

    logger = logging.getLogger(__name__)
    metric_name = _build_metric_name(metric_namespace, "inference_requests_duration_seconds")
    metric = Histogram(
        metric_name,
        "Duration of LLM requests in seconds.",
        labelnames=("endpoint", "model", "status_code"),
        # Bucket upper bounds in seconds, from 50 ms up to 5 minutes.
        buckets=(
            0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.75,
            1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5,
            6, 7, 8, 9, 10,
            15, 20, 25, 30,
            45, 60, 75, 90, 105, 120,
            150, 180, 210, 240, 270, 300,
        ),
    )

    def instrumentation(info: Info) -> None:
        try:
            context = request_context.get()
            model = context.router_name
            endpoint = context.endpoint
            latency = context.latency
            if model and endpoint and latency is not None:
                # NOTE(review): latency is presumably in milliseconds, hence
                # the division to match the "seconds" unit of the metric.
                metric.labels(
                    endpoint=endpoint,
                    model=model,
                    status_code=info.modified_status,
                ).observe(latency / 1000)
        except Exception:
            # Best-effort: a metric failure must never break the request,
            # but leave a trace so missing data points can be diagnosed.
            logger.debug("Failed to record metric %s", metric_name, exc_info=True)

    return instrumentation
97+
98+
99+
def inference_ttft_milliseconds(metric_namespace: str = "") -> Callable[[Info], None]:
    """Build an instrumentation callback that records time to first token.

    A ``Histogram`` labelled by ``endpoint``, ``model`` and ``status_code`` is
    created when this factory is called. Each time the returned callback is
    invoked, it observes ``context.ttft`` as-is, provided the current request
    context carries a router name, an endpoint and a non-``None`` TTFT.

    Args:
        metric_namespace: Optional prefix for the metric name; an empty
            string means no prefix.

    Returns:
        A callback accepting the instrumentator ``Info`` object of a request.
    """
    import logging  # function-scope import keeps this block self-contained

    logger = logging.getLogger(__name__)
    metric_name = _build_metric_name(metric_namespace, "inference_ttft_milliseconds")
    metric = Histogram(
        metric_name,
        "Time to first token for streaming LLM responses in milliseconds.",
        labelnames=("endpoint", "model", "status_code"),
        # Bucket upper bounds in milliseconds, from 5 ms up to 5 minutes.
        buckets=(
            5, 10, 20, 30, 50, 75, 100,
            150, 200, 300, 500, 750, 1000,
            1500, 2000, 3000, 5000, 7500, 10000,
            15000, 20000, 25000, 30000,
            45000, 60000, 75000, 90000, 105000, 120000,
            135000, 150000, 165000, 180000,
            210000, 240000, 270000, 300000,
        ),
    )

    def instrumentation(info: Info) -> None:
        try:
            context = request_context.get()
            model = context.router_name
            endpoint = context.endpoint
            ttft = context.ttft
            if model and endpoint and ttft is not None:
                metric.labels(endpoint=endpoint, model=model, status_code=info.modified_status).observe(ttft)
        except Exception:
            # Best-effort: a metric failure must never break the request,
            # but leave a trace so missing data points can be diagnosed.
            logger.debug("Failed to record metric %s", metric_name, exc_info=True)

    return instrumentation
158+
159+
160+
def inference_output_tokens_per_second(metric_namespace: str = "") -> Callable[[Info], None]:
    """Build an instrumentation callback that records output generation speed.

    A ``Histogram`` labelled by ``endpoint`` and ``model`` is created when this
    factory is called. Each time the returned callback is invoked, it observes
    ``usage.completion_tokens / (latency / 1000)``, provided the current
    request context carries a router name, an endpoint, a usage object with a
    non-zero completion token count, and a non-zero latency.

    Args:
        metric_namespace: Optional prefix for the metric name; an empty
            string means no prefix.

    Returns:
        A callback accepting the instrumentator ``Info`` object of a request.
    """
    import logging  # function-scope import keeps this block self-contained

    logger = logging.getLogger(__name__)
    metric_name = _build_metric_name(metric_namespace, "inference_output_tokens_per_second")
    metric = Histogram(
        metric_name,
        "Output generation speed in tokens per second (completion tokens / request duration, TTFT included).",
        labelnames=("endpoint", "model"),
        buckets=(5, 10, 20, 30, 50, 75, 85, 90, 95, 100, 105, 110, 115, 125, 150, 175, 200, 250, 300, 400, 500, 750, 1000),
    )

    def instrumentation(info: Info) -> None:
        try:
            context = request_context.get()
            model = context.router_name
            endpoint = context.endpoint
            usage = context.usage
            latency = context.latency
            # The truthiness check on latency also rules out a zero latency,
            # which would otherwise make the division below fail.
            if model and endpoint and usage and latency and usage.completion_tokens:
                # NOTE(review): latency is presumably in milliseconds, hence
                # the division to get a per-second rate.
                metric.labels(endpoint=endpoint, model=model).observe(usage.completion_tokens / (latency / 1000))
        except Exception:
            # Best-effort: a metric failure must never break the request,
            # but leave a trace so missing data points can be diagnosed.
            logger.debug("Failed to record metric %s", metric_name, exc_info=True)

    return instrumentation
182+
183+
184+
def inference_tokens_total(metric_namespace: str = "") -> Callable[[Info], None]:
    """Build an instrumentation callback that counts consumed tokens.

    A ``Counter`` labelled by ``endpoint``, ``model`` and ``type`` is created
    when this factory is called. Each time the returned callback is invoked,
    it adds ``usage.prompt_tokens`` under ``type="prompt"`` and
    ``usage.completion_tokens`` under ``type="completion"``, skipping any
    count that is zero or missing. Nothing is recorded unless the current
    request context carries a router name, an endpoint and a usage object.

    Args:
        metric_namespace: Optional prefix for the metric name; an empty
            string means no prefix.

    Returns:
        A callback accepting the instrumentator ``Info`` object of a request.
    """
    import logging  # function-scope import keeps this block self-contained

    logger = logging.getLogger(__name__)
    metric_name = _build_metric_name(metric_namespace, "inference_tokens_total")
    metric = Counter(
        metric_name,
        "Total number of tokens consumed (prompt and completion).",
        labelnames=("endpoint", "model", "type"),
    )

    def instrumentation(info: Info) -> None:
        try:
            context = request_context.get()
            model = context.router_name
            endpoint = context.endpoint
            usage = context.usage
            if model and endpoint and usage is not None:
                if usage.prompt_tokens:
                    metric.labels(endpoint=endpoint, model=model, type="prompt").inc(usage.prompt_tokens)
                if usage.completion_tokens:
                    metric.labels(endpoint=endpoint, model=model, type="completion").inc(usage.completion_tokens)
        except Exception:
            # Best-effort: a metric failure must never break the request,
            # but leave a trace so missing data points can be diagnosed.
            logger.debug("Failed to record metric %s", metric_name, exc_info=True)

    return instrumentation

0 commit comments

Comments
 (0)