Skip to content

Commit 9a21d89

Browse files
authored
Tiny add metrics for prefill delayer (sgl-project#16603)
1 parent 4c9ac85 commit 9a21d89

4 files changed

Lines changed: 93 additions & 9 deletions

File tree

python/sglang/srt/managers/prefill_delayer.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,47 @@
11
import logging
2-
from typing import Optional
2+
import time
3+
from dataclasses import dataclass, field
4+
from typing import TYPE_CHECKING, Optional
35

46
import torch
57

68
from sglang.srt.environ import envs
79
from sglang.srt.utils import get_bool_env_var
810

11+
if TYPE_CHECKING:
12+
from sglang.srt.metrics.collector import SchedulerMetricsCollector
13+
914
_DEBUG_LOG = get_bool_env_var("SGLANG_PREFILL_DELAYER_DEBUG_LOG")
1015

1116
logger = logging.getLogger(__name__)
1217

1318

19+
@dataclass
20+
class _DelayInfo:
21+
delayed_count: int = 0
22+
start_time: float = field(default_factory=time.perf_counter)
23+
24+
1425
class PrefillDelayer:
15-
def __init__(self, dp_size, attn_tp_size, tp_worker, server_args):
26+
def __init__(
27+
self,
28+
dp_size,
29+
attn_tp_size,
30+
tp_worker,
31+
server_args,
32+
metrics_collector: Optional["SchedulerMetricsCollector"] = None,
33+
):
1634
self.global_info = torch.empty(
1735
(dp_size, attn_tp_size, 1),
1836
dtype=torch.int64,
1937
device="cpu",
2038
)
2139
self.cpu_group = tp_worker.get_tp_group().cpu_group
2240

23-
self.curr_delayed_count = 0
2441
self.max_delay_passes = envs.SGLANG_PREFILL_DELAYER_MAX_DELAY_PASSES.get()
42+
self._metrics_collector = metrics_collector
43+
44+
self._curr_delay_info: Optional[_DelayInfo] = None
2545

2646
assert (
2747
server_args.enable_dp_attention
@@ -43,18 +63,31 @@ def _negotiate_should_allow_prefill(self, local_prefillable: bool) -> bool:
4363
)
4464

4565
if global_mixed_prefillable:
46-
self.curr_delayed_count += 1
47-
if self.curr_delayed_count < self.max_delay_passes:
66+
if self._curr_delay_info is None:
67+
self._curr_delay_info = _DelayInfo()
68+
self._curr_delay_info.delayed_count += 1
69+
if self._curr_delay_info.delayed_count < self.max_delay_passes:
4870
return False
4971

50-
if _DEBUG_LOG and global_mixed_prefillable:
72+
is_timeout = global_mixed_prefillable
73+
if _DEBUG_LOG and is_timeout:
5174
logger.info(
5275
f"PrefillDelayer timeout thus not forbid prefill (prefillable: {global_prefillable.sum()})"
5376
)
5477

55-
self.curr_delayed_count = 0
78+
self._record_metrics(is_timeout=is_timeout)
79+
self._curr_delay_info = None
5680
return True
5781

82+
def _record_metrics(self, is_timeout: bool) -> None:
83+
if self._curr_delay_info is not None and self._metrics_collector is not None:
84+
wait_seconds = time.perf_counter() - self._curr_delay_info.start_time
85+
self._metrics_collector.observe_prefill_delayer_wait(
86+
forward_passes=self._curr_delay_info.delayed_count,
87+
wait_seconds=wait_seconds,
88+
is_timeout=is_timeout,
89+
)
90+
5891
def _gather_info(self, local_prefillable: bool):
5992
local_info = torch.tensor(
6093
[int(local_prefillable)],

python/sglang/srt/managers/scheduler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,9 @@ def init_schedule_policy(self):
806806
attn_tp_size=self.attn_tp_size,
807807
tp_worker=self.tp_worker,
808808
server_args=self.server_args,
809+
metrics_collector=(
810+
self.metrics_collector if self.enable_metrics else None
811+
),
809812
)
810813
# Enable preemption for priority scheduling.
811814
self.try_preemption = self.enable_priority_scheduling

python/sglang/srt/metrics/collector.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,25 @@ def __init__(
761761
labelnames=list(labels.keys()) + ["category", "num_prefill_ranks"],
762762
)
763763

764+
max_delay_passes = envs.SGLANG_PREFILL_DELAYER_MAX_DELAY_PASSES.get()
765+
self.prefill_delayer_wait_forward_passes = Histogram(
766+
name="sglang:prefill_delayer_wait_forward_passes",
767+
documentation="Histogram of forward passes waited by prefill delayer.",
768+
labelnames=labels.keys(),
769+
buckets=[5, 20, max_delay_passes - 1],
770+
)
771+
self.prefill_delayer_wait_seconds = Histogram(
772+
name="sglang:prefill_delayer_wait_seconds",
773+
documentation="Histogram of wait time in seconds by prefill delayer.",
774+
labelnames=labels.keys(),
775+
buckets=[5, 20, 100, 500],
776+
)
777+
self.prefill_delayer_timeouts_total = Counter(
778+
name="sglang:prefill_delayer_timeouts_total",
779+
documentation="Total number of prefill delayer timeouts.",
780+
labelnames=labels.keys(),
781+
)
782+
764783
def _log_gauge(self, gauge, data: Union[int, float]) -> None:
765784
# Convenience function for logging to gauge.
766785
gauge.labels(**self.labels).set(data)
@@ -781,6 +800,14 @@ def observe_per_stage_req_latency(self, stage: str, latency: float) -> None:
781800
def observe_queue_time(self, latency: float) -> None:
782801
self._log_histogram(self.queue_time, latency)
783802

803+
def observe_prefill_delayer_wait(
804+
self, forward_passes: int, wait_seconds: float, is_timeout: bool
805+
) -> None:
806+
self._log_histogram(self.prefill_delayer_wait_forward_passes, forward_passes)
807+
self._log_histogram(self.prefill_delayer_wait_seconds, wait_seconds)
808+
if is_timeout:
809+
self.prefill_delayer_timeouts_total.labels(**self.labels).inc(1)
810+
784811
def increment_retracted_reqs(
785812
self,
786813
num_retracted_reqs: int,

test/srt/test_prefill_delayer.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import unittest
33
from types import SimpleNamespace
44

5+
import requests
6+
57
from sglang.bench_serving import run_benchmark
68
from sglang.srt.environ import envs
79
from sglang.srt.utils import kill_process_tree
@@ -88,6 +90,7 @@ def _run_throughput_test(
8890
**other_benchmark_args,
8991
)
9092
res = run_benchmark(args)
93+
_print_prefill_delayer_metrics(base_url, expect_metrics=prefill_delayer)
9194
finally:
9295
kill_process_tree(process.pid)
9396

@@ -137,6 +140,7 @@ def _run_accuracy_test(self, prefill_delayer: bool):
137140

138141
def _launch_server(*, model, base_url, prefill_delayer: bool, other_args):
139142
os.environ["SGLANG_PREFILL_DELAYER_DEBUG_LOG"] = "1"
143+
world_size = os.environ.get("SGLANG_TEST_WORLD_SIZE", "8")
140144

141145
with envs.SGLANG_SCHEDULER_DECREASE_PREFILL_IDLE.override(
142146
prefill_delayer
@@ -148,18 +152,35 @@ def _launch_server(*, model, base_url, prefill_delayer: bool, other_args):
148152
other_args=[
149153
"--trust-remote-code",
150154
"--tp",
151-
"8",
155+
world_size,
152156
"--enable-dp-attention",
153157
"--dp",
154-
"8",
158+
world_size,
155159
"--chunked-prefill-size",
156160
"131072",
157161
"--mem-fraction-static",
158162
"0.6",
163+
"--enable-metrics",
159164
*(other_args or []),
160165
],
161166
)
162167

163168

169+
def _print_prefill_delayer_metrics(base_url: str, expect_metrics: bool):
170+
metrics_response = requests.get(f"{base_url}/metrics")
171+
assert metrics_response.status_code == 200
172+
metrics_text = metrics_response.text
173+
prefill_delayer_metrics = [
174+
line for line in metrics_text.split("\n") if "prefill_delayer" in line
175+
]
176+
print("=== PrefillDelayer Metrics ===")
177+
for line in prefill_delayer_metrics:
178+
print(line)
179+
if expect_metrics:
180+
assert "sglang:prefill_delayer_wait_forward_passes" in metrics_text
181+
assert "sglang:prefill_delayer_wait_seconds" in metrics_text
182+
assert "sglang:prefill_delayer_timeouts_total" in metrics_text
183+
184+
164185
if __name__ == "__main__":
165186
unittest.main()

0 commit comments

Comments
 (0)