Skip to content

Commit 7d1ad28

Browse files
HyeockJinKim and claude authored
feat(BA-5065): expose /metrics endpoint for Custom inference runtime variant (#9984)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 9efee39 commit 7d1ad28

3 files changed

Lines changed: 235 additions & 0 deletions

File tree

changes/9984.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add metrics collection support for the Custom inference runtime variant via the `/metrics` endpoint.

src/ai/backend/appproxy/worker/metrics.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,8 @@ async def gather_inference_measures(
288288
)
289289
)
290290
return measures
291+
case RuntimeVariant.CUSTOM:
292+
return await gather_prometheus_inference_measures(client_pool, circuit.route_info)
291293
case _:
292294
return None
293295

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
"""
2+
Tests for Custom inference runtime variant metrics collection.
3+
4+
Verifies that RuntimeVariant.CUSTOM:
5+
- Returns all metrics without prefix filtering (same as HUGGINGFACE_TGI)
6+
- Handles multiple replicas correctly
7+
- Returns a list (not None) so metrics are collected
8+
9+
Contrast with NIM/CMD which fall through to the catch-all case and return None.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
from datetime import UTC, datetime
15+
from decimal import Decimal
16+
from typing import Any
17+
from uuid import UUID, uuid4
18+
19+
import pytest
20+
21+
from ai.backend.appproxy.common.types import (
22+
AppMode,
23+
FrontendMode,
24+
ProxyProtocol,
25+
RouteInfo,
26+
)
27+
from ai.backend.appproxy.worker.metrics import gather_inference_measures
28+
from ai.backend.appproxy.worker.types import (
29+
Circuit,
30+
InferenceAppInfo,
31+
Measurement,
32+
PortFrontendInfo,
33+
)
34+
from ai.backend.common.types import ModelServiceStatus, RuntimeVariant
35+
36+
# Prometheus exposition text whose metric names deliberately use several
# different prefixes, so the tests can check that the CUSTOM variant reports
# every metric instead of filtering by prefix.
SAMPLE_CUSTOM_METRICS = (
    "# HELP custom_metric_a A custom metric\n"
    "# TYPE custom_metric_a gauge\n"
    "custom_metric_a 42\n"
    "# HELP vllm:num_requests_running vLLM metric included in custom endpoint\n"
    "# TYPE vllm:num_requests_running gauge\n"
    "vllm:num_requests_running 5\n"
    "# HELP app_requests_total Total requests processed\n"
    "# TYPE app_requests_total gauge\n"
    "app_requests_total 100\n"
)
49+
50+
51+
def _make_route(
    kernel_host: str = "10.0.0.1",
    kernel_port: int = 8080,
    health_status: ModelServiceStatus | None = ModelServiceStatus.HEALTHY,
) -> RouteInfo:
    """Build an HTTP ``RouteInfo`` for a single replica.

    The route and session ids are freshly generated on every call; the
    remaining fields are fixed test defaults (full traffic ratio, no previous
    health check, zero consecutive failures).
    """
    route_id = uuid4()
    session_id = uuid4()
    return RouteInfo(
        route_id=route_id,
        session_id=session_id,
        session_name=None,
        kernel_host=kernel_host,
        kernel_port=kernel_port,
        protocol=ProxyProtocol.HTTP,
        traffic_ratio=1.0,
        health_status=health_status,
        last_health_check=None,
        consecutive_failures=0,
    )
68+
69+
70+
def _make_custom_circuit(routes: list[RouteInfo], port: int = 10300) -> Circuit:
    """Build a port-fronted inference ``Circuit`` using ``RuntimeVariant.CUSTOM``.

    ``routes`` becomes the circuit's ``route_info`` and also supplies its
    ``session_ids``; ``port`` is used both for the frontend info and for the
    circuit port.
    """
    variant = RuntimeVariant.CUSTOM
    # Creation and update timestamps are the same fixed instant.
    fixed_timestamp = datetime(2024, 7, 16, 5, 45, 45, tzinfo=UTC)
    session_ids = [route.session_id for route in routes]
    return Circuit(
        id=uuid4(),
        app="custom-model",
        protocol=ProxyProtocol.HTTP,
        worker=UUID("00000000-0000-0000-0000-000000000000"),
        app_mode=AppMode.INFERENCE,
        frontend_mode=FrontendMode.PORT,
        frontend=PortFrontendInfo(port),
        port=port,
        app_info=InferenceAppInfo(
            endpoint_id=uuid4(),
            runtime_variant=variant,
        ),
        subdomain=None,
        runtime_variant=variant,
        envs={},
        arguments=None,
        open_to_public=False,
        allowed_client_ips=None,
        user_id=uuid4(),
        access_key="TESTKEY",
        endpoint_id=None,
        route_info=routes,
        session_ids=session_ids,
        created_at=fixed_timestamp,
        updated_at=fixed_timestamp,
    )
98+
99+
100+
def _make_unsupported_circuit(
    variant: RuntimeVariant, routes: list[RouteInfo], port: int = 10301
) -> Circuit:
    """Build a port-fronted inference ``Circuit`` for an arbitrary runtime ``variant``.

    ``variant`` is applied both to the circuit and to its ``InferenceAppInfo``;
    ``routes`` becomes ``route_info`` and also supplies ``session_ids``.
    """
    # Creation and update timestamps are the same fixed instant.
    fixed_timestamp = datetime(2024, 7, 16, 5, 45, 45, tzinfo=UTC)
    session_ids = [route.session_id for route in routes]
    return Circuit(
        id=uuid4(),
        app="model",
        protocol=ProxyProtocol.HTTP,
        worker=UUID("00000000-0000-0000-0000-000000000000"),
        app_mode=AppMode.INFERENCE,
        frontend_mode=FrontendMode.PORT,
        frontend=PortFrontendInfo(port),
        port=port,
        app_info=InferenceAppInfo(
            endpoint_id=uuid4(),
            runtime_variant=variant,
        ),
        subdomain=None,
        runtime_variant=variant,
        envs={},
        arguments=None,
        open_to_public=False,
        allowed_client_ips=None,
        user_id=uuid4(),
        access_key="TESTKEY",
        endpoint_id=None,
        route_info=routes,
        session_ids=session_ids,
        created_at=fixed_timestamp,
        updated_at=fixed_timestamp,
    )
130+
131+
132+
class TestCustomRuntimeVariantMetrics:
    """Exercise ``gather_inference_measures`` for ``RuntimeVariant.CUSTOM`` circuits."""

    async def test_custom_variant_collects_all_metrics_without_filtering(
        self, mock_metrics_client_pool: Any
    ) -> None:
        """Every metric in the scraped payload is reported, whatever its prefix."""
        replica = _make_route(kernel_host="10.0.0.1", kernel_port=8080)
        circuit = _make_custom_circuit([replica])

        endpoint = f"http://{replica.current_kernel_host}:{replica.kernel_port}"
        responses = {endpoint: SAMPLE_CUSTOM_METRICS}

        async with mock_metrics_client_pool(responses) as (pool, _):
            measures = await gather_inference_measures(pool, circuit)

        assert measures is not None, "CUSTOM variant must return a list, not None"
        assert len(measures) > 0, "CUSTOM variant must return at least one measure"

        seen_keys = {m.key for m in measures}
        # Each metric name in SAMPLE_CUSTOM_METRICS has a different prefix; all must survive.
        for expected_key in (
            "custom_metric_a",
            "vllm:num_requests_running",
            "app_requests_total",
        ):
            assert expected_key in seen_keys

    async def test_custom_variant_handles_multiple_replicas(
        self, mock_metrics_client_pool: Any
    ) -> None:
        """Values scraped from two replicas are combined into one per-app measure."""
        first = _make_route(kernel_host="10.0.0.1", kernel_port=8080)
        second = _make_route(kernel_host="10.0.0.2", kernel_port=8081)
        circuit = _make_custom_circuit([first, second])

        # Both replicas serve the same payload.
        replica_payload = (
            "# HELP custom_metric_a A custom metric\n"
            "# TYPE custom_metric_a gauge\n"
            "custom_metric_a 10\n"
        )
        responses = {
            f"http://{replica.current_kernel_host}:{replica.kernel_port}": replica_payload
            for replica in (first, second)
        }

        async with mock_metrics_client_pool(responses) as (pool, _):
            measures = await gather_inference_measures(pool, circuit)

        assert measures is not None
        matching = [m for m in measures if m.key == "custom_metric_a"]
        assert len(matching) == 1

        aggregated = matching[0]
        # The per-app value is the sum over both replicas: 10 + 10 = 20.
        assert isinstance(aggregated.per_app, Measurement)
        assert aggregated.per_app.value == Decimal(20)
        # One per-replica entry is expected for each route.
        assert len(aggregated.per_replica) == 2

    async def test_custom_variant_returns_empty_list_when_all_routes_fail(
        self, mock_metrics_client_pool: Any
    ) -> None:
        """An unreachable sole route yields an empty list rather than ``None``."""
        replica = _make_route(kernel_host="10.0.0.1", kernel_port=8080)
        circuit = _make_custom_circuit([replica])

        endpoint = f"http://{replica.current_kernel_host}:{replica.kernel_port}"
        responses = {endpoint: ConnectionError("Connection refused")}

        async with mock_metrics_client_pool(responses) as (pool, _):
            measures = await gather_inference_measures(pool, circuit)

        # The CUSTOM branch is expected to hand back a list even when the scrape
        # fails, so the result must be empty but not None.
        assert measures is not None
        assert measures == []

    @pytest.mark.parametrize(
        "variant",
        [
            pytest.param(RuntimeVariant.NIM, id="nim"),
            pytest.param(RuntimeVariant.CMD, id="cmd"),
        ],
    )
    async def test_unsupported_variants_return_none(
        self, variant: RuntimeVariant, mock_metrics_client_pool: Any
    ) -> None:
        """NIM and CMD circuits still produce ``None`` (behavior unchanged)."""
        replica = _make_route()
        circuit = _make_unsupported_circuit(variant, [replica])

        endpoint = f"http://{replica.current_kernel_host}:{replica.kernel_port}"
        responses = {endpoint: SAMPLE_CUSTOM_METRICS}

        async with mock_metrics_client_pool(responses) as (pool, _):
            measures = await gather_inference_measures(pool, circuit)

        assert measures is None, f"{variant} must return None (unsupported variant)"

0 commit comments

Comments
 (0)