Skip to content

Commit 21201ef

Browse files
authored
Add kv_canary PP self-test fixture and SWA divergence coverage (#27410)
1 parent b3e4c20 commit 21201ef

8 files changed

Lines changed: 208 additions & 21 deletions

File tree

python/sglang/srt/kv_canary/runner/swa_divergence.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,13 @@ def find_last(cls, text: str) -> Optional[tuple["SwaDivergenceLog", str]]:
149149
return None
150150
return cls(**json.loads(last_match.group(1))), last_match.group(0)
151151

152+
@classmethod
153+
def find_all(cls, text: str) -> list[tuple["SwaDivergenceLog", str]]:
154+
return [
155+
(cls(**json.loads(match.group(1))), match.group(0))
156+
for match in _SWA_DIVERGENCE_LINE_RE.finditer(text)
157+
]
158+
152159

153160
def compute_swa_out_of_window_tokens(
154161
*,

python/sglang/test/kv_canary/e2e_base.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,17 @@ def send_parallel_requests(
126126
assert_all_success: bool = True,
127127
max_new_tokens: int = 2048,
128128
timeout: float = 240.0,
129+
ignore_eos: Optional[bool] = None,
129130
) -> list[dict]:
130131
"""Fan out n parallel /generate requests; return list of response dicts."""
132+
if ignore_eos is None:
133+
ignore_eos = self.model_mode == "swa"
131134
results = post_parallel_generate(
132135
url=self.base_url + "/generate",
133136
prompts=self.make_prompts(n),
134137
max_new_tokens=max_new_tokens,
135138
timeout=timeout,
139+
ignore_eos=ignore_eos,
136140
)
137141
if assert_all_success:
138142
for result in results:
@@ -155,44 +159,54 @@ def assert_swa_divergence_observed(
155159
"""Assert that the SWA path was genuinely exercised.
156160
157161
Three signals must all hold:
158-
- ``swa_out_of_window_tokens >= 1``: at least one prefix token has been clipped
159-
out of the sliding window (its SWA mapping is 0). Any prompt longer than the
160-
SWA window produces this — proves the SWA window slide actually ran.
162+
- ``swa_out_of_window_tokens >= 1``: at least one token has slid out of the
163+
sliding window (its SWA mapping is 0). This only appears once a request decodes
164+
past the window, so the window evicts — proves the SWA window slide actually ran.
161165
- ``swa_full_idx_divergence >= 1``: SWA pool has actually remapped at least one
162166
slot to a non-identity index (i.e. real slot reuse / eviction occurred). The
163167
workload must drive SWA pool pressure for this to fire — required because the
164168
"pool reuse" path is the one production hits under sustained long-context
165169
traffic, and we must keep it covered.
166170
- ``verify_swa < verify_full``: SWA verify kernel processed fewer tokens than
167171
FULL — proves both kernel groups ran and the window short-circuited SWA.
172+
173+
The first two signals are checked as the *peak* across all sampled forwards, not
174+
only the last sample. The divergence reporter snapshots one live forward batch per
175+
interval; under PP it snapshots a single micro-batch, which may hold only in-window
176+
requests even when another micro-batch diverged. "Was the SWA path ever exercised?"
177+
is a max-over-samples question, so a trailing in-window sample must not mask an
178+
earlier diverging one. ``verify_swa``/``verify_full`` are monotonic running totals,
179+
so the lag check reads the last sample.
168180
"""
169-
last_parsed = None
170-
last_line: str = ""
181+
samples: list[tuple[SwaDivergenceLog, str]] = []
171182
for _ in range(max_retries):
172183
time.sleep(flush_wait_seconds)
173-
log_text = self._captured_log_text()
174-
found = SwaDivergenceLog.find_last(log_text)
175-
if found is not None:
176-
last_parsed, last_line = found
184+
samples = SwaDivergenceLog.find_all(self._captured_log_text())
185+
if samples:
177186
break
178187

179-
if last_parsed is None:
188+
if not samples:
180189
raise AssertionError(
181190
"No kv_canary swa_divergence line found in server log after "
182191
f"{max_retries} retries (wait={flush_wait_seconds}s each). "
183192
f"Log tail:\n{self._captured_log_text()[-2000:]}"
184193
)
185194

186-
if last_parsed.swa_out_of_window_tokens < min_swa_out_of_window_tokens:
195+
peak_out_of_window = max(p.swa_out_of_window_tokens for p, _ in samples)
196+
peak_full_idx_divergence = max(p.swa_full_idx_divergence for p, _ in samples)
197+
last_parsed, last_line = samples[-1]
198+
199+
if peak_out_of_window < min_swa_out_of_window_tokens:
187200
raise AssertionError(
188-
f"SWA path not exercised: swa_out_of_window_tokens={last_parsed.swa_out_of_window_tokens} "
189-
f"< min={min_swa_out_of_window_tokens}. Line: {last_line}"
201+
f"SWA path not exercised: peak swa_out_of_window_tokens={peak_out_of_window} "
202+
f"< min={min_swa_out_of_window_tokens} across {len(samples)} samples. "
203+
f"Last line: {last_line}"
190204
)
191-
if last_parsed.swa_full_idx_divergence < min_swa_full_idx_divergence:
205+
if peak_full_idx_divergence < min_swa_full_idx_divergence:
192206
raise AssertionError(
193-
f"SWA pool reuse not exercised: swa_full_idx_divergence={last_parsed.swa_full_idx_divergence} "
194-
f"< min={min_swa_full_idx_divergence}. The workload did not drive enough SWA pool pressure "
195-
f"to force slot remap. Line: {last_line}"
207+
f"SWA pool reuse not exercised: peak swa_full_idx_divergence={peak_full_idx_divergence} "
208+
f"< min={min_swa_full_idx_divergence} across {len(samples)} samples. The workload "
209+
f"did not drive enough SWA pool pressure to force slot remap. Last line: {last_line}"
196210
)
197211
if require_verify_lag and not (
198212
last_parsed.verify_swa < last_parsed.verify_full
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from __future__ import annotations
2+
3+
from typing import ClassVar
4+
5+
from sglang.test.kv_canary.consts import SWA_POOL_SERVER_ARGS
6+
from sglang.test.kv_canary.e2e_base import CanaryE2EBase
7+
8+
PP_SIZE: int = 2
9+
10+
11+
class CanaryPPFixture(CanaryE2EBase):
12+
13+
model_mode: ClassVar[str] = "swa"
14+
workload_n_batches: ClassVar[int] = 2
15+
16+
@classmethod
17+
def setUpClass(cls) -> None:
18+
cls.extra_server_args = (
19+
"--pp-size",
20+
str(PP_SIZE),
21+
"--disable-cuda-graph",
22+
*SWA_POOL_SERVER_ARGS,
23+
*cls.extra_server_args,
24+
)
25+
super().setUpClass()

python/sglang/test/kv_canary/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def post_parallel_generate(
3333
prompts: list[str],
3434
max_new_tokens: int,
3535
timeout: float,
36+
ignore_eos: bool = False,
3637
) -> list[dict]:
3738
def _send(prompt: str) -> dict:
3839
try:
@@ -43,6 +44,7 @@ def _send(prompt: str) -> dict:
4344
"sampling_params": {
4445
"max_new_tokens": max_new_tokens,
4546
"temperature": 0.0,
47+
"ignore_eos": ignore_eos,
4648
},
4749
},
4850
timeout=timeout,
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from __future__ import annotations
2+
3+
import unittest
4+
5+
from sglang.srt.kv_canary.config import CanaryMode
6+
from sglang.test.ci.ci_register import register_cuda_ci
7+
from sglang.test.kv_canary.pp_fixture import CanaryPPFixture
8+
9+
register_cuda_ci(est_time=220, stage="extra-a", runner_config="2-gpu-large")
10+
11+
12+
class TestPPBaselineSwa(CanaryPPFixture):
13+
14+
kv_canary_mode = CanaryMode.LOG
15+
16+
def test_no_violation(self) -> None:
17+
for _ in range(self.workload_n_batches):
18+
self.send_parallel_requests()
19+
self.assert_no_violation(wait_seconds=2.0)
20+
self.maybe_assert_swa_divergence_observed()
21+
22+
23+
if __name__ == "__main__":
24+
unittest.main()
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from __future__ import annotations
2+
3+
import unittest
4+
from typing import ClassVar
5+
6+
from sglang.srt.kv_canary.config import CanaryMode
7+
from sglang.srt.kv_canary.perturb.config import TargetGroupKind
8+
from sglang.test.ci.ci_register import register_cuda_ci
9+
from sglang.test.kv_canary.pp_fixture import CanaryPPFixture
10+
11+
register_cuda_ci(est_time=220, stage="extra-a", runner_config="2-gpu-large")
12+
13+
14+
class TestPPPerturbSwaSwa(CanaryPPFixture):
15+
16+
kv_canary_mode = CanaryMode.LOG
17+
target_group: ClassVar[TargetGroupKind] = TargetGroupKind.SWA
18+
extra_server_args = ("--kv-canary-real-data", "partial")
19+
20+
@classmethod
21+
def setUpClass(cls) -> None:
22+
cls.extra_env = {
23+
"SGLANG_KV_CANARY_PERTURB_REAL_KV_USED_PROB": "0.1",
24+
"SGLANG_KV_CANARY_PERTURB_TARGET_GROUP": str(cls.target_group),
25+
"SGLANG_KV_CANARY_PERTURB_WARMUP_STEPS": "0",
26+
}
27+
super().setUpClass()
28+
29+
def test_real_kv_used_perturbation_reports_real_kv_hash_violation(self) -> None:
30+
for _ in range(self.workload_n_batches):
31+
self.send_parallel_requests()
32+
self.assert_per_forward_violation_reported(
33+
fail_reason="verify_real_kv_hash",
34+
target_group=self.target_group,
35+
)
36+
self.maybe_assert_swa_divergence_observed()
37+
38+
39+
if __name__ == "__main__":
40+
unittest.main()

test/registered/kv_canary/test_self_unit_e2e_base.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,58 @@ def test_assert_swa_divergence_observed_passes_when_above_threshold(self) -> Non
6969
max_retries=1,
7070
)
7171

72-
def test_assert_swa_divergence_observed_uses_latest_line(self) -> None:
73-
log = _GOOD_LINE + "\n" + _LATER_LINE + "\n"
74-
harness, patcher = self._make_harness(log)
72+
def test_assert_swa_divergence_observed_uses_peak_out_of_window(self) -> None:
73+
diverged = SwaDivergenceLog(
74+
forward_ct=120,
75+
verify_full=10000,
76+
verify_swa=4200,
77+
swa_full_idx_divergence=512,
78+
swa_out_of_window_tokens=8192,
79+
).format()
80+
trailing_zero = SwaDivergenceLog(
81+
forward_ct=240,
82+
verify_full=20000,
83+
verify_swa=8400,
84+
swa_full_idx_divergence=0,
85+
swa_out_of_window_tokens=0,
86+
).format()
87+
harness, patcher = self._make_harness(diverged + "\n" + trailing_zero + "\n")
7588
with patcher:
7689
harness.assert_swa_divergence_observed(
77-
min_swa_full_idx_divergence=1000,
90+
min_swa_out_of_window_tokens=1,
91+
min_swa_full_idx_divergence=1,
7892
require_verify_lag=True,
7993
flush_wait_seconds=0.0,
8094
max_retries=1,
8195
)
8296

97+
def test_assert_swa_divergence_observed_checks_verify_lag_on_latest_line(
98+
self,
99+
) -> None:
100+
lagging = SwaDivergenceLog(
101+
forward_ct=120,
102+
verify_full=10000,
103+
verify_swa=4200,
104+
swa_full_idx_divergence=512,
105+
swa_out_of_window_tokens=8192,
106+
).format()
107+
no_lag = SwaDivergenceLog(
108+
forward_ct=240,
109+
verify_full=20000,
110+
verify_swa=20000,
111+
swa_full_idx_divergence=1024,
112+
swa_out_of_window_tokens=16384,
113+
).format()
114+
harness, patcher = self._make_harness(lagging + "\n" + no_lag + "\n")
115+
with patcher:
116+
with self.assertRaisesRegex(AssertionError, "verify_swa=20000"):
117+
harness.assert_swa_divergence_observed(
118+
min_swa_full_idx_divergence=1,
119+
require_verify_lag=True,
120+
flush_wait_seconds=0.0,
121+
max_retries=1,
122+
)
123+
83124
def test_assert_swa_divergence_observed_raises_when_below_threshold(self) -> None:
84125
zero_mapping_line = SwaDivergenceLog(
85126
forward_ct=100,

test/registered/kv_canary/test_self_unit_runner_swa_divergence.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,40 @@ def test_swa_divergence_report_emits_swa_full_idx_divergence_from_compute(
410410
self.assertEqual(parsed.verify_swa, 3)
411411

412412

413+
class TestSwaDivergenceLogFindAll(CustomTestCase):
414+
def test_find_all_returns_every_sample_in_order(self) -> None:
415+
text = "\n".join(
416+
SwaDivergenceLog(
417+
forward_ct=ct,
418+
verify_full=100 * ct,
419+
verify_swa=10 * ct,
420+
swa_full_idx_divergence=ct,
421+
swa_out_of_window_tokens=0,
422+
).format()
423+
for ct in (20, 40, 60)
424+
)
425+
parsed = SwaDivergenceLog.find_all(text)
426+
self.assertEqual([p.forward_ct for p, _ in parsed], [20, 40, 60])
427+
428+
def test_find_all_peak_survives_trailing_zero_sample(self) -> None:
429+
text = "\n".join(
430+
SwaDivergenceLog(
431+
forward_ct=ct,
432+
verify_full=1,
433+
verify_swa=0,
434+
swa_full_idx_divergence=1,
435+
swa_out_of_window_tokens=oow,
436+
).format()
437+
for ct, oow in ((20, 0), (40, 4080), (60, 0))
438+
)
439+
parsed = SwaDivergenceLog.find_all(text)
440+
self.assertEqual(max(p.swa_out_of_window_tokens for p, _ in parsed), 4080)
441+
self.assertEqual(parsed[-1][0].swa_out_of_window_tokens, 0)
442+
443+
def test_find_all_returns_empty_list_when_no_lines(self) -> None:
444+
self.assertEqual(SwaDivergenceLog.find_all("nothing here\n"), [])
445+
446+
413447
class TestCanaryManagerSwaDivergenceWiring(CanaryManagerTestCase):
414448
def test_swa_divergence_report_is_none_when_env_disabled(self) -> None:
415449
with envs.SGLANG_KV_CANARY_SWA_DIVERGENCE_STATS_INTERVAL.override(

0 commit comments

Comments
 (0)