Skip to content

Commit 628cf81

Browse files
seedspiritclaude
andcommitted
refactor(BA-6039): rework container live-stat with explicit per-query pipeline
Replace the single-query/positional-merge live-stat pipeline with six explicit PromQL queries (instant / rate_current / max / rate_max / avg / rate_avg) and a partitioned KernelLiveStatBatchResult. Newly populates stats.max/stats.avg in the legacy converter, and drops the stats.rate * UTILIZATION_METRIC_INTERVAL undo-hack. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 54e5cf4 commit 628cf81

13 files changed

Lines changed: 517 additions & 748 deletions

File tree

src/ai/backend/common/clients/prometheus/client.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
ContainerMetricOptionalLabel,
1414
ContainerMetricResponseInfo,
1515
ContainerMetricResult,
16-
KernelLiveStatValues,
17-
KernelMetricValuesByKernel,
16+
KernelLiveStatBatchResult,
1817
MetricResultValue,
1918
)
2019
from ai.backend.common.clients.prometheus.preset import LabelMatcher
2120
from ai.backend.common.dto.clients.prometheus.request import QueryTimeRange
2221
from ai.backend.common.dto.clients.prometheus.response import (
2322
LabelValueResponse,
23+
PrometheusQueryData,
2424
PrometheusResponse,
2525
)
2626
from ai.backend.common.exception import (
@@ -82,15 +82,40 @@ async def fetch_container_metric(
8282
async def fetch_container_live_stats(
8383
self,
8484
kernel_ids: Sequence[KernelId],
85-
) -> KernelLiveStatValues:
85+
) -> KernelLiveStatBatchResult:
8686
queries = self._fixed_query_builder.get_container_live_stat_queries(kernel_ids)
87-
merged = KernelMetricValuesByKernel(values_by_kernel={})
88-
for preset in queries.to_list():
89-
response = await self._query_instant(preset)
90-
merged = merged.merged_with(
91-
KernelMetricValuesByKernel.from_prometheus_response(response)
92-
)
93-
return KernelLiveStatValues.with_capacity_sentinels(merged.values_by_kernel)
87+
88+
instant_res = await self._query_instant(queries.instant)
89+
rate_current_res = await self._query_instant(queries.rate_current)
90+
# rate_max/rate_avg wrap rate() first because cpu_util/net_rx/net_tx are cumulative counters
91+
# aggregating their raw values would just track the running total.
92+
max_res = await self._query_instant(queries.max)
93+
rate_max_res = await self._query_instant(queries.rate_max)
94+
avg_res = await self._query_instant(queries.avg)
95+
rate_avg_res = await self._query_instant(queries.rate_avg)
96+
97+
# The max/rate_max and avg/rate_avg queries read the same "current"
98+
# series, so we merge each pair to cover all data points regardless of
99+
# individual query result types.
100+
return KernelLiveStatBatchResult.from_responses(
101+
instant=instant_res,
102+
rate_current=rate_current_res,
103+
max=self._merge_prometheus_responses(
104+
max_res, rate_max_res, final_result_type=max_res.data.result_type
105+
),
106+
avg=self._merge_prometheus_responses(
107+
avg_res, rate_avg_res, final_result_type=avg_res.data.result_type
108+
),
109+
)
110+
111+
def _merge_prometheus_responses(
112+
self, first: PrometheusResponse, second: PrometheusResponse, *, final_result_type: str
113+
) -> PrometheusResponse:
114+
data = PrometheusQueryData(
115+
result_type=final_result_type,
116+
result=[*first.data.result, *second.data.result],
117+
)
118+
return first.model_copy(update={"data": data})
94119

95120
async def execute_preset(
96121
self,

src/ai/backend/common/clients/prometheus/fixed_query_builder.py

Lines changed: 64 additions & 172 deletions
Original file line numberDiff line numberDiff line change
@@ -6,41 +6,25 @@
66
from ai.backend.common.clients.prometheus.metric_types import (
77
DIFF_METRICS,
88
RATE_METRICS,
9-
STATS_AVG_GAUGE_METRIC_PATTERNS,
10-
STATS_AVG_GAUGE_METRICS,
11-
STATS_AVG_OVER_RATE_METRICS,
12-
STATS_MAX_GAUGE_METRIC_PATTERNS,
13-
STATS_MAX_GAUGE_METRICS,
14-
STATS_MAX_OVER_RATE_METRICS,
15-
STATS_RATE_COUNTER_METRICS,
16-
STATS_RATE_GAUGE_METRICS,
179
ContainerLiveStatQueries,
1810
ContainerMetricOptionalLabel,
1911
MetricType,
2012
)
21-
from ai.backend.common.clients.prometheus.preset import LabelMatcher, MetricPreset
13+
from ai.backend.common.clients.prometheus.preset import MetricPreset
2214
from ai.backend.common.clients.prometheus.querier import ContainerMetricQuerier
2315
from ai.backend.common.clients.prometheus.types import ValueType
2416
from ai.backend.common.metrics.types import (
2517
CONTAINER_UTILIZATION_METRIC_LABEL_NAME,
2618
CONTAINER_UTILIZATION_METRIC_NAME,
27-
UTILIZATION_METRIC_INTERVAL,
2819
)
2920
from ai.backend.common.types import KernelId
3021

31-
_LIVE_STAT_GROUP_BY: Final[frozenset[str]] = frozenset({
32-
"kernel_id",
33-
"container_metric_name",
34-
"value_type",
35-
})
36-
3722
_GAUGE_TEMPLATE: Final[str] = (
3823
f"sum by ({{group_by}})({CONTAINER_UTILIZATION_METRIC_NAME}{{{{{{labels}}}}}})"
3924
)
4025
_RATE_TEMPLATE: Final[str] = (
4126
"sum by ({group_by})(rate("
4227
f"{CONTAINER_UTILIZATION_METRIC_NAME}{{{{{{labels}}}}}}[{{window}}]))"
43-
f" / {UTILIZATION_METRIC_INTERVAL}"
4428
)
4529
_DIFF_TEMPLATE: Final[str] = (
4630
"sum by ({group_by})(rate("
@@ -54,60 +38,43 @@ class LabelValuesQuery:
5438
metric_match: str
5539

5640

57-
@dataclass(frozen=True)
58-
class _LiveStatQuerySpec:
59-
template: str
60-
metric_name_filter: frozenset[str] | None = None
61-
value_type_filter: ValueType | None = None
62-
63-
64-
@dataclass(frozen=True)
65-
class _StatsBucket:
66-
"""Window-stats bucket spec (gauge metrics + rate metrics for a single stat)."""
67-
68-
value_type: ValueType
69-
gauge_metrics: frozenset[str]
70-
rate_metrics: frozenset[str]
71-
gauge_metric_patterns: frozenset[str] = frozenset()
72-
73-
7441
def _regex_union(values: Sequence[str]) -> str:
7542
return "|".join(re.escape(value).replace(r"\-", "-") for value in values)
7643

7744

78-
def _metric_name_regex(
79-
metric_names: frozenset[str],
80-
metric_patterns: frozenset[str] = frozenset(),
81-
) -> str:
82-
exact_parts = [re.escape(value) for value in sorted(metric_names)]
83-
return "|".join([*exact_parts, *sorted(metric_patterns)])
45+
def _value_type_regex(value_types: Sequence[ValueType]) -> str:
46+
return _regex_union([value_type.value for value_type in value_types])
8447

8548

86-
_GAUGE_LIVE_STAT_SPEC: Final[_LiveStatQuerySpec] = _LiveStatQuerySpec(
87-
template=_GAUGE_TEMPLATE,
49+
_LIVE_STAT_INSTANT_TEMPLATE: Final[str] = (
50+
f"sum by (kernel_id,container_metric_name,value_type)({CONTAINER_UTILIZATION_METRIC_NAME}"
51+
'{{kernel_id=~"{kernel_ids}",value_type=~"{value_types}"}})'
8852
)
89-
_DIFF_LIVE_STAT_SPEC: Final[_LiveStatQuerySpec] = _LiveStatQuerySpec(
90-
template=_DIFF_TEMPLATE,
91-
metric_name_filter=DIFF_METRICS,
92-
value_type_filter=ValueType.CURRENT,
53+
_LIVE_STAT_RATE_METRICS: Final[frozenset[str]] = RATE_METRICS | DIFF_METRICS
54+
55+
_LIVE_STAT_RATE_CURRENT_TEMPLATE: Final[str] = (
56+
f"sum by (kernel_id,container_metric_name)(rate("
57+
f"{CONTAINER_UTILIZATION_METRIC_NAME}"
58+
'{{kernel_id=~"{kernel_ids}",container_metric_name=~"{metric_names}",value_type="{value_type}"}}'
59+
"[{window}]))"
9360
)
94-
_RATE_LIVE_STAT_SPEC: Final[_LiveStatQuerySpec] = _LiveStatQuerySpec(
95-
template=_RATE_TEMPLATE,
96-
metric_name_filter=RATE_METRICS,
97-
value_type_filter=ValueType.CURRENT,
61+
_LIVE_STAT_MAX_TEMPLATE: Final[str] = (
62+
"max_over_time(("
63+
f"sum by (kernel_id,container_metric_name)({CONTAINER_UTILIZATION_METRIC_NAME}"
64+
'{{kernel_id=~"{kernel_ids}",value_type="{value_type}"}}'
65+
"))[{window}:])"
9866
)
99-
100-
_MAX_STATS_BUCKET: Final[_StatsBucket] = _StatsBucket(
101-
value_type=ValueType.MAX,
102-
gauge_metrics=STATS_MAX_GAUGE_METRICS,
103-
rate_metrics=STATS_MAX_OVER_RATE_METRICS,
104-
gauge_metric_patterns=STATS_MAX_GAUGE_METRIC_PATTERNS,
67+
_LIVE_STAT_RATE_MAX_TEMPLATE: Final[str] = (
68+
f"max_over_time(({_LIVE_STAT_RATE_CURRENT_TEMPLATE})[{{window}}:])"
69+
)
70+
_LIVE_STAT_AVG_TEMPLATE: Final[str] = (
71+
"avg_over_time(("
72+
f"sum by (kernel_id,container_metric_name)({CONTAINER_UTILIZATION_METRIC_NAME}"
73+
'{{kernel_id=~"{kernel_ids}",value_type="{value_type}"}}'
74+
"))[{window}:])"
10575
)
106-
_AVG_STATS_BUCKET: Final[_StatsBucket] = _StatsBucket(
107-
value_type=ValueType.AVG,
108-
gauge_metrics=STATS_AVG_GAUGE_METRICS,
109-
rate_metrics=STATS_AVG_OVER_RATE_METRICS,
110-
gauge_metric_patterns=STATS_AVG_GAUGE_METRIC_PATTERNS,
76+
_LIVE_STAT_RATE_AVG_TEMPLATE: Final[str] = (
77+
f"avg_over_time(({_LIVE_STAT_RATE_CURRENT_TEMPLATE})[{{window}}:])"
11178
)
11279

11380

@@ -160,126 +127,51 @@ def get_container_live_stat_queries(
160127
self,
161128
kernel_ids: Sequence[KernelId],
162129
) -> ContainerLiveStatQueries:
163-
return ContainerLiveStatQueries(
164-
gauge=self._build_filtered_preset(kernel_ids, _GAUGE_LIVE_STAT_SPEC),
165-
diff=self._build_filtered_preset(kernel_ids, _DIFF_LIVE_STAT_SPEC),
166-
rate=self._build_filtered_preset(kernel_ids, _RATE_LIVE_STAT_SPEC),
167-
max=self._build_window_stats_preset(kernel_ids, _MAX_STATS_BUCKET),
168-
avg=self._build_window_stats_preset(kernel_ids, _AVG_STATS_BUCKET),
169-
rate_stats=self._build_rate_stats_preset(kernel_ids),
170-
)
171-
172-
def _build_rate_stats_preset(
173-
self,
174-
kernel_ids: Sequence[KernelId],
175-
) -> MetricPreset:
176130
kernel_id_regex = _regex_union([str(kid) for kid in kernel_ids])
177-
group_by = ",".join(sorted(_LIVE_STAT_GROUP_BY))
178-
parts: list[str] = []
179-
if STATS_RATE_GAUGE_METRICS:
180-
gauge_regex = _regex_union(sorted(STATS_RATE_GAUGE_METRICS))
181-
selector = self._utilization_selector(kernel_id_regex, gauge_regex)
182-
parts.append(self._labelled_sum(selector, group_by, ValueType.RATE))
183-
if STATS_RATE_COUNTER_METRICS:
184-
counter_regex = _regex_union(sorted(STATS_RATE_COUNTER_METRICS))
185-
base = self._utilization_selector(kernel_id_regex, counter_regex)
186-
selector = f"rate({base}[{self._timewindow}])"
187-
parts.append(self._labelled_sum(selector, group_by, ValueType.RATE))
188-
return MetricPreset(template=" or ".join(parts))
189131

190-
def _labelled_sum(self, selector: str, group_by: str, stat_label: ValueType) -> str:
191-
return (
192-
f"label_replace(sum by ({group_by})({selector}),"
193-
f'"value_type","{stat_label}","value_type",".*")'
132+
instant_query = _LIVE_STAT_INSTANT_TEMPLATE.format(
133+
kernel_ids=kernel_id_regex,
134+
value_types=_value_type_regex([
135+
ValueType.CURRENT,
136+
ValueType.CAPACITY,
137+
]),
194138
)
195-
196-
def _build_window_stats_preset(
197-
self,
198-
kernel_ids: Sequence[KernelId],
199-
bucket: _StatsBucket,
200-
) -> MetricPreset:
201-
kernel_id_regex = _regex_union([str(kid) for kid in kernel_ids])
202-
group_by = ",".join(sorted(_LIVE_STAT_GROUP_BY))
203-
return MetricPreset(
204-
template=self._render_stats_query(
205-
bucket,
206-
kernel_id_regex=kernel_id_regex,
207-
group_by=group_by,
208-
)
139+
rate_current_query = _LIVE_STAT_RATE_CURRENT_TEMPLATE.format(
140+
kernel_ids=kernel_id_regex,
141+
metric_names=_regex_union(sorted(_LIVE_STAT_RATE_METRICS)),
142+
value_type=ValueType.CURRENT.value,
143+
window=self._timewindow,
209144
)
210-
211-
def _build_filtered_preset(
212-
self,
213-
kernel_ids: Sequence[KernelId],
214-
spec: _LiveStatQuerySpec,
215-
) -> MetricPreset:
216-
labels: dict[str, LabelMatcher] = {
217-
"kernel_id": LabelMatcher.regex(_regex_union([str(kid) for kid in kernel_ids]))
218-
}
219-
if spec.metric_name_filter is not None:
220-
labels["container_metric_name"] = LabelMatcher.regex(
221-
_regex_union(sorted(spec.metric_name_filter))
222-
)
223-
if spec.value_type_filter is not None:
224-
labels["value_type"] = LabelMatcher.exact(spec.value_type_filter.value)
225-
226-
return MetricPreset(
227-
template=spec.template,
228-
group_by=_LIVE_STAT_GROUP_BY,
229-
labels=labels,
145+
max_query = _LIVE_STAT_MAX_TEMPLATE.format(
146+
kernel_ids=kernel_id_regex,
147+
value_type=ValueType.CURRENT.value,
230148
window=self._timewindow,
231149
)
232-
233-
def _render_stats_query(
234-
self,
235-
bucket: _StatsBucket,
236-
*,
237-
kernel_id_regex: str,
238-
group_by: str,
239-
) -> str:
240-
stat_fn = f"{bucket.value_type}_over_time"
241-
parts: list[str] = []
242-
if bucket.gauge_metrics or bucket.gauge_metric_patterns:
243-
gauge_regex = _metric_name_regex(bucket.gauge_metrics, bucket.gauge_metric_patterns)
244-
selector = self._utilization_selector(kernel_id_regex, gauge_regex)
245-
parts.append(self._window_stat_subquery(stat_fn, selector, group_by, bucket.value_type))
246-
if bucket.rate_metrics:
247-
rate_regex = _regex_union(sorted(bucket.rate_metrics))
248-
base = self._utilization_selector(kernel_id_regex, rate_regex)
249-
selector = f"rate({base}[{self._timewindow}])"
250-
parts.append(self._window_stat_subquery(stat_fn, selector, group_by, bucket.value_type))
251-
return " or ".join(parts)
252-
253-
def _utilization_selector(self, kernel_id_regex: str, metric_name_regex: str) -> str:
254-
labels = self._live_stat_current_labels(
255-
kernel_id_regex=kernel_id_regex,
256-
metric_name_regex=metric_name_regex,
150+
rate_max_query = _LIVE_STAT_RATE_MAX_TEMPLATE.format(
151+
kernel_ids=kernel_id_regex,
152+
metric_names=_regex_union(sorted(_LIVE_STAT_RATE_METRICS)),
153+
value_type=ValueType.CURRENT.value,
154+
window=self._timewindow,
257155
)
258-
return f"{CONTAINER_UTILIZATION_METRIC_NAME}{{{labels}}}"
259-
260-
def _window_stat_subquery(
261-
self,
262-
stat_fn: str,
263-
selector: str,
264-
group_by: str,
265-
stat_label: ValueType,
266-
) -> str:
267-
return (
268-
f"label_replace("
269-
f"{stat_fn}((sum by ({group_by})({selector}))[{self._timewindow}:]),"
270-
f'"value_type","{stat_label}","value_type",".*")'
156+
avg_query = _LIVE_STAT_AVG_TEMPLATE.format(
157+
kernel_ids=kernel_id_regex,
158+
value_type=ValueType.CURRENT.value,
159+
window=self._timewindow,
160+
)
161+
rate_avg_query = _LIVE_STAT_RATE_AVG_TEMPLATE.format(
162+
kernel_ids=kernel_id_regex,
163+
metric_names=_regex_union(sorted(_LIVE_STAT_RATE_METRICS)),
164+
value_type=ValueType.CURRENT.value,
165+
window=self._timewindow,
271166
)
272167

273-
def _live_stat_current_labels(
274-
self,
275-
*,
276-
kernel_id_regex: str,
277-
metric_name_regex: str,
278-
) -> str:
279-
return (
280-
f'kernel_id=~"{kernel_id_regex}"'
281-
f',container_metric_name=~"{metric_name_regex}"'
282-
f',value_type="{ValueType.CURRENT}"'
168+
return ContainerLiveStatQueries(
169+
instant=MetricPreset(template=instant_query),
170+
rate_current=MetricPreset(template=rate_current_query),
171+
max=MetricPreset(template=max_query),
172+
rate_max=MetricPreset(template=rate_max_query),
173+
avg=MetricPreset(template=avg_query),
174+
rate_avg=MetricPreset(template=rate_avg_query),
283175
)
284176

285177
def _get_template(self, metric_type: MetricType) -> str:

0 commit comments

Comments
 (0)