Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/11330.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Migrate kernel `live_stat` GraphQL resolver from Valkey to Prometheus while preserving the legacy wire shape
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from ai.backend.common.clients.prometheus.preset import LabelMatcher, MetricPreset
from ai.backend.common.clients.prometheus.querier import ContainerMetricQuerier
from ai.backend.common.clients.prometheus.types import ValueType
from ai.backend.common.exception import UnreachableError
from ai.backend.common.metrics.types import (
CONTAINER_UTILIZATION_METRIC_LABEL_NAME,
CONTAINER_UTILIZATION_METRIC_NAME,
Expand Down Expand Up @@ -152,5 +151,3 @@ def _get_template(self, metric_type: MetricType) -> str:
return _RATE_TEMPLATE
case MetricType.DIFF:
return _DIFF_TEMPLATE
case _:
raise UnreachableError(f"Unknown metric type: {metric_type}")
36 changes: 25 additions & 11 deletions src/ai/backend/manager/api/gql_legacy/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
KernelId,
SessionId,
)
from ai.backend.manager.api.gql_legacy.stat_converter import LegacyLiveStatConverter
from ai.backend.manager.data.kernel.types import KernelStatus
from ai.backend.manager.defs import DEFAULT_ROLE
from ai.backend.manager.models.group import groups
Expand All @@ -42,6 +43,7 @@
QueryFilterParser,
)
from ai.backend.manager.models.user import UserRole, users
from ai.backend.manager.services.metric.actions.live_stat import ContainerLiveStatAction

from .base import (
BigInt,
Expand All @@ -67,6 +69,19 @@
)


async def _batch_load_kernel_live_stat(
    ctx: GraphQueryContext,
    kernel_ids: Sequence[KernelId],
) -> list[dict[str, Any] | None]:
    """Batch-load legacy-shaped live statistics for the given kernels.

    Used as a dataloader batch function: the returned list has exactly one
    element per entry of ``kernel_ids``, in the same order.

    Args:
        ctx: GraphQL query context; its metric processor runs the
            container live-stat action.
        kernel_ids: Kernel ids to look up.

    Returns:
        For each requested id, the converted legacy stat mapping, or
        ``None`` when the converter produced no entry (or a ``None`` entry)
        for that id. An empty request yields an empty list without issuing
        any query.
    """
    if not kernel_ids:
        return []

    processor = ctx.processors.metric.query_container_live_stat
    outcome = await processor.wait_for_complete(
        ContainerLiveStatAction(kernel_ids=list(kernel_ids))
    )
    stats_by_kernel = LegacyLiveStatConverter.convert(outcome.stats)
    # Re-align the results with the request order; ids missing from the
    # converted mapping resolve to None.
    return list(map(stats_by_kernel.get, kernel_ids))


class KernelNode(graphene.ObjectType): # type: ignore[misc]
class Meta:
interfaces = (AsyncNode,)
Expand Down Expand Up @@ -190,17 +205,10 @@ async def resolve_image(self, info: graphene.ResolveInfo) -> ImageNode | None:
async def resolve_live_stat(self, info: graphene.ResolveInfo) -> dict[str, Any] | None:
graph_ctx: GraphQueryContext = info.context
loader = graph_ctx.dataloader_manager.get_loader_by_func(
graph_ctx, self.batch_load_live_stat
graph_ctx, _batch_load_kernel_live_stat
)
return cast(dict[str, Any] | None, await loader.load(self.row_id))

@classmethod
async def batch_load_live_stat(
cls, ctx: GraphQueryContext, kernel_ids: Sequence[KernelId]
) -> list[dict[str, Any] | None]:
kernel_ids_str = [str(kid) for kid in kernel_ids]
return await ctx.valkey_stat.get_session_statistics_batch(kernel_ids_str)

Comment on lines -193 to -203
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like code that doesn't really need to be changed is being modified, and this makes it difficult to read the code.


class KernelConnection(Connection):
class Meta:
Expand Down Expand Up @@ -313,7 +321,9 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow | None) -> ComputeConta
# we can leave last_stat value for legacy support, as an alias to last_stat
async def resolve_live_stat(self, info: graphene.ResolveInfo) -> Mapping[str, Any] | None:
graph_ctx: GraphQueryContext = info.context
loader = graph_ctx.dataloader_manager.get_loader(graph_ctx, "KernelStatistics.by_kernel")
loader = graph_ctx.dataloader_manager.get_loader_by_func(
graph_ctx, _batch_load_kernel_live_stat
)
return cast(Mapping[str, Any] | None, await loader.load(self.id))

async def resolve_last_stat(self, info: graphene.ResolveInfo) -> Mapping[str, Any] | None:
Expand Down Expand Up @@ -606,7 +616,9 @@ class Meta:
# we can leave last_stat value for legacy support, as an alias to last_stat
async def resolve_live_stat(self, info: graphene.ResolveInfo) -> Mapping[str, Any] | None:
graph_ctx: GraphQueryContext = info.context
loader = graph_ctx.dataloader_manager.get_loader(graph_ctx, "KernelStatistics.by_kernel")
loader = graph_ctx.dataloader_manager.get_loader_by_func(
graph_ctx, _batch_load_kernel_live_stat
)
return cast(Mapping[str, Any] | None, await loader.load(self.id))

async def resolve_last_stat(self, info: graphene.ResolveInfo) -> Mapping[str, Any] | None:
Expand All @@ -632,7 +644,9 @@ async def _resolve_legacy_metric(
if value is None:
return convert_type(0)
return convert_type(value)
loader = graph_ctx.dataloader_manager.get_loader(graph_ctx, "KernelStatistics.by_kernel")
loader = graph_ctx.dataloader_manager.get_loader_by_func(
graph_ctx, _batch_load_kernel_live_stat
)
kstat = await loader.load(self.id)
if kstat is None:
return convert_type(0)
Expand Down
150 changes: 150 additions & 0 deletions src/ai/backend/manager/api/gql_legacy/stat_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
from collections.abc import Iterable
from typing import Final

from ai.backend.common.clients.prometheus.metric_types import KernelLiveStatBatchResult
from ai.backend.common.clients.prometheus.types import MetricValue as PrometheusMetricValue
from ai.backend.common.clients.prometheus.types import ValueType
from ai.backend.common.metrics.types import UTILIZATION_METRIC_INTERVAL
from ai.backend.common.types import KernelId, MetricValue

# Metric-name classification used only while adapting Prometheus samples back
# into the legacy live_stat dict that Graphene/WebUI still expects.
# - Names in `_RATE_STAT_METRICS` additionally get `stats.rate` populated.
# - Names in `_DIFF_STAT_METRICS` additionally get `stats.diff` populated.
# For both sets, the last CURRENT sample is preferred over the first one as
# `current` when more than one CURRENT sample is present.
_RATE_STAT_METRICS: Final[frozenset[str]] = frozenset({"net_rx", "net_tx"})
_DIFF_STAT_METRICS: Final[frozenset[str]] = frozenset({"cpu_util"})

# Per-metric unit hint emitted by the agent (source of truth:
# src/ai/backend/agent/docker/intrinsic.py).
# Metric names missing from this table are resolved by `_resolve_unit_hint`
# through naming-convention fallbacks.
_METRIC_UNIT_HINTS: Final[dict[str, str]] = {
    "cpu_used": "msec",
    "cpu_util": "percent",
    "mem": "bytes",
    "net_rx": "bps",
    "net_tx": "bps",
    "io_read": "bytes",
    "io_write": "bytes",
    "io_scratch_size": "bytes",
}


def _make_default_metric_value(unit_hint: str) -> MetricValue:
    """Build a fresh legacy ``MetricValue`` with every numeric field zeroed.

    Args:
        unit_hint: Unit label stored verbatim under ``"unit_hint"``.

    Returns:
        A new mapping whose numeric fields are the string ``"0"`` and whose
        ``"stats.version"`` is ``None``; callers overwrite the fields they
        have real samples for.
    """
    zero = "0"
    return MetricValue({
        "current": zero,
        "capacity": zero,
        "pct": zero,
        "unit_hint": unit_hint,
        "stats.min": zero,
        "stats.max": zero,
        "stats.sum": zero,
        "stats.avg": zero,
        "stats.diff": zero,
        "stats.rate": zero,
        "stats.version": None,
    })


def _resolve_unit_hint(metric_name: str) -> str:
    """Return the unit hint string for ``metric_name``.

    Resolution order:
      1. An explicit entry in ``_METRIC_UNIT_HINTS``.
      2. Naming conventions: ``*_util`` -> ``"percent"``; ``mem`` / ``*_mem``
         and ``io_*`` -> ``"bytes"``; ``net_*`` -> ``"bps"``.
      3. The metric name itself, when nothing above matches.
    """
    try:
        return _METRIC_UNIT_HINTS[metric_name]
    except KeyError:
        pass  # not registered; fall through to the naming conventions

    if metric_name.endswith("_util"):
        return "percent"
    is_memory = metric_name == "mem" or metric_name.endswith("_mem")
    if is_memory or metric_name.startswith("io_"):
        return "bytes"
    if metric_name.startswith("net_"):
        return "bps"
    return metric_name


class LegacyLiveStatConverter:
    """Translate a `KernelLiveStatBatchResult` into the legacy per-kernel
    `dict[metric_name, MetricValue]` mapping consumed by GQL/WebUI.

    Merge order from upstream is gauge -> diff -> rate, so for RATE/DIFF
    metrics the same `(name, CURRENT)` tuple appears twice: `currents[0]`
    is the raw gauge sample and `currents[-1]` is the rate/diff query
    result.

    NOTE: this index-based selection is planned to be replaced by a
    type-based approach in a follow-up change; doing so requires modifying
    the upstream response merge logic as well.

    Only `current`, `capacity`, `pct`, `stats.rate` and `stats.diff` are
    ever overwritten here; the other `stats.*` fields (including
    `stats.max` / `stats.avg`) are not populated and keep their defaults.
    """

    @classmethod
    def convert(
        cls, result: KernelLiveStatBatchResult
    ) -> dict[KernelId, dict[str, MetricValue] | None]:
        """Convert every kernel entry of ``result``.

        Returns a mapping keyed by kernel id; a kernel whose entry carries
        no values maps to ``None``.
        """
        return {
            kernel_id: (cls._convert_one_kernel(entry.values) if entry.values else None)
            for kernel_id, entry in result.entries.items()
        }

    @classmethod
    def _convert_one_kernel(cls, values: Iterable[PrometheusMetricValue]) -> dict[str, MetricValue]:
        """Group one kernel's samples by metric name and convert each group.

        Metric names appear in the result in first-seen order, and samples
        within a group keep their incoming order.
        """
        samples_by_name: dict[str, list[PrometheusMetricValue]] = {}
        for sample in values:
            if sample.metric_name not in samples_by_name:
                samples_by_name[sample.metric_name] = []
            samples_by_name[sample.metric_name].append(sample)

        return {
            metric_name: cls._convert_metric_samples(metric_name, bucket)
            for metric_name, bucket in samples_by_name.items()
        }

    @staticmethod
    def _convert_metric_samples(
        metric_name: str, samples: list[PrometheusMetricValue]
    ) -> MetricValue:
        """Fold all samples of a single metric into one legacy `MetricValue`."""

        def values_of(kind: ValueType) -> list:
            # Sample values of the requested value type, in incoming order.
            return [sample.value for sample in samples if sample.value_type is kind]

        # `_resolve_unit_hint` falls back to naming conventions and finally
        # the metric_name itself for unregistered plugin metrics.
        metric = _make_default_metric_value(unit_hint=_resolve_unit_hint(metric_name))

        currents = values_of(ValueType.CURRENT)
        capacities = values_of(ValueType.CAPACITY)
        pcts = values_of(ValueType.PCT)

        rate_like = metric_name in _RATE_STAT_METRICS
        diff_like = metric_name in _DIFF_STAT_METRICS

        if currents:
            # RATE/DIFF: prefer the rate/diff query result over the raw gauge,
            # mirroring the legacy `current_hook=stats.rate|diff` behavior.
            use_newest = len(currents) > 1 and (rate_like or diff_like)
            metric["current"] = currents[-1] if use_newest else currents[0]
        if capacities:
            metric["capacity"] = capacities[-1]

        if currents and rate_like:
            # RATE template applies `/ UTILIZATION_METRIC_INTERVAL`; undo it
            # here to recover the per-second magnitude legacy `stats.rate` had.
            # TODO: separate the rate query from the gauge query so this
            # hack-multiply isn't needed.
            try:
                scaled = float(currents[-1]) * UTILIZATION_METRIC_INTERVAL
                metric["stats.rate"] = f"{scaled:.6f}"
            except ValueError:
                # Non-numeric sample: pass the raw value through unchanged.
                metric["stats.rate"] = currents[-1]
        if currents and diff_like:
            # Per-second rate, not the legacy per-5s delta — GQL consumers
            # only read `cpu_util.pct`, so magnitude mismatch is acceptable.
            metric["stats.diff"] = currents[-1]

        # An explicit PCT sample always wins over a derived percentage.
        if pcts:
            metric["pct"] = pcts[-1]
            return metric

        # Derive pct from current/capacity when no PCT sample was emitted.
        try:
            numerator = float(metric["current"])
            raw_capacity = metric["capacity"]
            if raw_capacity is None:
                return metric
            denominator = float(raw_capacity)
            if denominator > 0:
                metric["pct"] = f"{numerator / denominator * 100:.2f}"
        except ValueError:
            # Unparseable current/capacity: leave the default pct in place.
            pass

        return metric
9 changes: 0 additions & 9 deletions src/ai/backend/manager/api/gql_legacy/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,6 @@ async def batch_load_by_kernel_impl(
session_ids_str = [str(sess_id) for sess_id in session_ids]
return await valkey_stat_client.get_session_statistics_batch(session_ids_str)

@classmethod
async def batch_load_by_kernel(
cls,
ctx: GraphQueryContext,
session_ids: Sequence[SessionId],
) -> Sequence[Mapping[str, Any] | None]:
"""wrapper of `KernelStatistics.batch_load_by_kernel_impl()` for aiodataloader"""
return await cls.batch_load_by_kernel_impl(ctx.valkey_stat, session_ids)

@classmethod
async def batch_load_inference_metrics_by_kernel(
cls,
Expand Down
Loading
Loading