Skip to content

Commit 8095984

Browse files
committed
feat: Add legacy converter
1 parent d187109 commit 8095984

6 files changed

Lines changed: 438 additions & 23 deletions

File tree

changes/11330.enhance.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Migrate kernel `live_stat` GraphQL resolver from Valkey to Prometheus while preserving the legacy wire shape

src/ai/backend/common/clients/prometheus/fixed_query_builder.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from ai.backend.common.clients.prometheus.preset import LabelMatcher, MetricPreset
1414
from ai.backend.common.clients.prometheus.querier import ContainerMetricQuerier
1515
from ai.backend.common.clients.prometheus.types import ValueType
16-
from ai.backend.common.exception import UnreachableError
1716
from ai.backend.common.metrics.types import (
1817
CONTAINER_UTILIZATION_METRIC_LABEL_NAME,
1918
CONTAINER_UTILIZATION_METRIC_NAME,
@@ -152,5 +151,3 @@ def _get_template(self, metric_type: MetricType) -> str:
152151
return _RATE_TEMPLATE
153152
case MetricType.DIFF:
154153
return _DIFF_TEMPLATE
155-
case _:
156-
raise UnreachableError(f"Unknown metric type: {metric_type}")

src/ai/backend/manager/api/gql_legacy/kernel.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
KernelId,
2525
SessionId,
2626
)
27+
from ai.backend.manager.api.gql_legacy.stat_converter import LegacyLiveStatConverter
2728
from ai.backend.manager.data.kernel.types import KernelStatus
2829
from ai.backend.manager.defs import DEFAULT_ROLE
2930
from ai.backend.manager.models.group import groups
@@ -42,6 +43,7 @@
4243
QueryFilterParser,
4344
)
4445
from ai.backend.manager.models.user import UserRole, users
46+
from ai.backend.manager.services.metric.actions.live_stat import ContainerLiveStatAction
4547

4648
from .base import (
4749
BigInt,
@@ -67,6 +69,19 @@
6769
)
6870

6971

72+
async def _batch_load_kernel_live_stat(
73+
ctx: GraphQueryContext,
74+
kernel_ids: Sequence[KernelId],
75+
) -> list[dict[str, Any] | None]:
76+
if not kernel_ids:
77+
return []
78+
action_result = await ctx.processors.metric.query_container_live_stat.wait_for_complete(
79+
ContainerLiveStatAction(kernel_ids=list(kernel_ids))
80+
)
81+
converted = LegacyLiveStatConverter.convert(action_result.stats)
82+
return [converted.get(kid) for kid in kernel_ids]
83+
84+
7085
class KernelNode(graphene.ObjectType): # type: ignore[misc]
7186
class Meta:
7287
interfaces = (AsyncNode,)
@@ -190,17 +205,10 @@ async def resolve_image(self, info: graphene.ResolveInfo) -> ImageNode | None:
190205
async def resolve_live_stat(self, info: graphene.ResolveInfo) -> dict[str, Any] | None:
191206
graph_ctx: GraphQueryContext = info.context
192207
loader = graph_ctx.dataloader_manager.get_loader_by_func(
193-
graph_ctx, self.batch_load_live_stat
208+
graph_ctx, _batch_load_kernel_live_stat
194209
)
195210
return cast(dict[str, Any] | None, await loader.load(self.row_id))
196211

197-
@classmethod
198-
async def batch_load_live_stat(
199-
cls, ctx: GraphQueryContext, kernel_ids: Sequence[KernelId]
200-
) -> list[dict[str, Any] | None]:
201-
kernel_ids_str = [str(kid) for kid in kernel_ids]
202-
return await ctx.valkey_stat.get_session_statistics_batch(kernel_ids_str)
203-
204212

205213
class KernelConnection(Connection):
206214
class Meta:
@@ -313,7 +321,9 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow | None) -> ComputeConta
313321
# we can leave last_stat value for legacy support, as an alias to last_stat
314322
async def resolve_live_stat(self, info: graphene.ResolveInfo) -> Mapping[str, Any] | None:
315323
graph_ctx: GraphQueryContext = info.context
316-
loader = graph_ctx.dataloader_manager.get_loader(graph_ctx, "KernelStatistics.by_kernel")
324+
loader = graph_ctx.dataloader_manager.get_loader_by_func(
325+
graph_ctx, _batch_load_kernel_live_stat
326+
)
317327
return cast(Mapping[str, Any] | None, await loader.load(self.id))
318328

319329
async def resolve_last_stat(self, info: graphene.ResolveInfo) -> Mapping[str, Any] | None:
@@ -606,7 +616,9 @@ class Meta:
606616
# we can leave last_stat value for legacy support, as an alias to last_stat
607617
async def resolve_live_stat(self, info: graphene.ResolveInfo) -> Mapping[str, Any] | None:
608618
graph_ctx: GraphQueryContext = info.context
609-
loader = graph_ctx.dataloader_manager.get_loader(graph_ctx, "KernelStatistics.by_kernel")
619+
loader = graph_ctx.dataloader_manager.get_loader_by_func(
620+
graph_ctx, _batch_load_kernel_live_stat
621+
)
610622
return cast(Mapping[str, Any] | None, await loader.load(self.id))
611623

612624
async def resolve_last_stat(self, info: graphene.ResolveInfo) -> Mapping[str, Any] | None:
@@ -632,7 +644,9 @@ async def _resolve_legacy_metric(
632644
if value is None:
633645
return convert_type(0)
634646
return convert_type(value)
635-
loader = graph_ctx.dataloader_manager.get_loader(graph_ctx, "KernelStatistics.by_kernel")
647+
loader = graph_ctx.dataloader_manager.get_loader_by_func(
648+
graph_ctx, _batch_load_kernel_live_stat
649+
)
636650
kstat = await loader.load(self.id)
637651
if kstat is None:
638652
return convert_type(0)
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
from collections.abc import Iterable
2+
from typing import Final
3+
4+
from ai.backend.common.clients.prometheus.metric_types import KernelLiveStatBatchResult
5+
from ai.backend.common.clients.prometheus.types import MetricValue as PrometheusMetricValue
6+
from ai.backend.common.clients.prometheus.types import ValueType
7+
from ai.backend.common.metrics.types import UTILIZATION_METRIC_INTERVAL
8+
from ai.backend.common.types import KernelId, MetricValue
9+
10+
# Metric-name classification used only while adapting Prometheus samples back
11+
# into the legacy live_stat dict that Graphene/WebUI still expects.
12+
_RATE_STAT_METRICS: Final[frozenset[str]] = frozenset({"net_rx", "net_tx"})
13+
_DIFF_STAT_METRICS: Final[frozenset[str]] = frozenset({"cpu_util"})
14+
15+
# Per-metric unit hint emitted by the agent (source of truth:
16+
# src/ai/backend/agent/docker/intrinsic.py).
17+
_METRIC_UNIT_HINTS: Final[dict[str, str]] = {
18+
"cpu_used": "msec",
19+
"cpu_util": "percent",
20+
"mem": "bytes",
21+
"net_rx": "bps",
22+
"net_tx": "bps",
23+
"io_read": "bytes",
24+
"io_write": "bytes",
25+
"io_scratch_size": "bytes",
26+
}
27+
28+
29+
def _make_default_metric_value(unit_hint: str) -> MetricValue:
30+
return MetricValue({
31+
"current": "0",
32+
"capacity": "0",
33+
"pct": "0",
34+
"unit_hint": unit_hint,
35+
"stats.min": "0",
36+
"stats.max": "0",
37+
"stats.sum": "0",
38+
"stats.avg": "0",
39+
"stats.diff": "0",
40+
"stats.rate": "0",
41+
"stats.version": None,
42+
})
43+
44+
45+
def _resolve_unit_hint(metric_name: str) -> str:
46+
if metric_name in _METRIC_UNIT_HINTS:
47+
return _METRIC_UNIT_HINTS[metric_name]
48+
if metric_name.endswith("_util"):
49+
return "percent"
50+
if metric_name == "mem" or metric_name.endswith("_mem"):
51+
return "bytes"
52+
if metric_name.startswith("io_"):
53+
return "bytes"
54+
if metric_name.startswith("net_"):
55+
return "bps"
56+
return metric_name
57+
58+
59+
class LegacyLiveStatConverter:
60+
"""Adapt `KernelLiveStatBatchResult` into the legacy
61+
`dict[metric_name, MetricValue]` shape consumed by GQL/WebUI.
62+
63+
Merge order from upstream is gauge -> diff -> rate, so for
64+
RATE/DIFF metrics the same `(name, CURRENT)` tuple appears twice;
65+
`currents[0]` is the raw gauge sample, `currents[-1]` is the
66+
rate/diff query result.
67+
68+
`stats.max` / `stats.avg` are not populated
69+
"""
70+
71+
@classmethod
72+
def convert(
73+
cls, result: KernelLiveStatBatchResult
74+
) -> dict[KernelId, dict[str, MetricValue] | None]:
75+
out: dict[KernelId, dict[str, MetricValue] | None] = {}
76+
for kernel_id, entry in result.entries.items():
77+
if not entry.values:
78+
out[kernel_id] = None
79+
continue
80+
out[kernel_id] = cls._convert_one_kernel(entry.values)
81+
return out
82+
83+
@classmethod
84+
def _convert_one_kernel(cls, values: Iterable[PrometheusMetricValue]) -> dict[str, MetricValue]:
85+
grouped: dict[str, list[PrometheusMetricValue]] = {}
86+
for v in values:
87+
grouped.setdefault(v.metric_name, []).append(v)
88+
89+
per_metric: dict[str, MetricValue] = {}
90+
for name, samples in grouped.items():
91+
per_metric[name] = cls._convert_metric_samples(name, samples)
92+
return per_metric
93+
94+
@staticmethod
95+
def _convert_metric_samples(
96+
metric_name: str, samples: list[PrometheusMetricValue]
97+
) -> MetricValue:
98+
# `_resolve_unit_hint` falls back to naming conventions and finally
99+
# the metric_name itself for unregistered plugin metrics.
100+
unit_hint = _resolve_unit_hint(metric_name)
101+
out = _make_default_metric_value(unit_hint=unit_hint)
102+
103+
currents = [s.value for s in samples if s.value_type is ValueType.CURRENT]
104+
capacities = [s.value for s in samples if s.value_type is ValueType.CAPACITY]
105+
pcts = [s.value for s in samples if s.value_type is ValueType.PCT]
106+
107+
is_rate_metric = metric_name in _RATE_STAT_METRICS
108+
is_diff_metric = metric_name in _DIFF_STAT_METRICS
109+
110+
if currents:
111+
# RATE/DIFF: prefer the rate/diff query result over the raw gauge,
112+
# mirroring the legacy `current_hook=stats.rate|diff` behavior.
113+
if (is_rate_metric or is_diff_metric) and len(currents) > 1:
114+
out["current"] = currents[-1]
115+
else:
116+
out["current"] = currents[0]
117+
if capacities:
118+
out["capacity"] = capacities[-1]
119+
120+
if is_rate_metric and currents:
121+
# RATE template applies `/ UTILIZATION_METRIC_INTERVAL`; undo it
122+
# here to recover the per-second magnitude legacy `stats.rate` had.
123+
# TODO: separate the rate query from the gauge query so this
124+
# hack-multiply isn't needed.
125+
try:
126+
rate_value = float(currents[-1]) * UTILIZATION_METRIC_INTERVAL
127+
out["stats.rate"] = f"{rate_value:.6f}"
128+
except ValueError:
129+
out["stats.rate"] = currents[-1]
130+
if is_diff_metric and currents:
131+
# Per-second rate, not the legacy per-5s delta — GQL consumers
132+
# only read `cpu_util.pct`, so magnitude mismatch is acceptable.
133+
out["stats.diff"] = currents[-1]
134+
135+
# Derive pct from current/capacity when no PCT sample was emitted.
136+
if pcts:
137+
out["pct"] = pcts[-1]
138+
else:
139+
try:
140+
current_value = float(out["current"])
141+
capacity = out["capacity"]
142+
if capacity is None:
143+
return out
144+
capacity_value = float(capacity)
145+
if capacity_value > 0:
146+
out["pct"] = f"{current_value / capacity_value * 100:.2f}"
147+
except ValueError:
148+
pass
149+
150+
return out

src/ai/backend/manager/api/gql_legacy/statistics.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,6 @@ async def batch_load_by_kernel_impl(
2828
session_ids_str = [str(sess_id) for sess_id in session_ids]
2929
return await valkey_stat_client.get_session_statistics_batch(session_ids_str)
3030

31-
@classmethod
32-
async def batch_load_by_kernel(
33-
cls,
34-
ctx: GraphQueryContext,
35-
session_ids: Sequence[SessionId],
36-
) -> Sequence[Mapping[str, Any] | None]:
37-
"""wrapper of `KernelStatistics.batch_load_by_kernel_impl()` for aiodataloader"""
38-
return await cls.batch_load_by_kernel_impl(ctx.valkey_stat, session_ids)
39-
4031
@classmethod
4132
async def batch_load_inference_metrics_by_kernel(
4233
cls,

0 commit comments

Comments
 (0)