Skip to content

Commit 1368717

Browse files
authored
Add more testing for chunked prefill (#27506)
1 parent 609f5f5 commit 1368717

42 files changed

Lines changed: 7928 additions & 21 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
from __future__ import annotations
2+
3+
import time
4+
from types import SimpleNamespace
5+
from typing import ClassVar, List, Optional
6+
7+
from sglang.srt.utils import kill_process_tree
8+
from sglang.test.run_eval import run_eval
9+
from sglang.test.server_fixtures.disaggregation_fixture import (
10+
PDDisaggregationServerBase,
11+
)
12+
from sglang.test.test_utils import (
13+
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
14+
DEFAULT_URL_FOR_TEST,
15+
CustomTestCase,
16+
popen_launch_server,
17+
try_cached_model,
18+
)
19+
20+
DEFAULT_MODEL: str = "Qwen/Qwen3-0.6B"
21+
22+
DEFAULT_CHUNKED_PREFILL_SIZE: int = 256
23+
DEFAULT_NUM_EXAMPLES: int = 100
24+
DEFAULT_NUM_SHOTS: int = 10
25+
LONG_PROMPT_NUM_SHOTS: int = 24
26+
DEFAULT_NUM_THREADS: int = 128
27+
DEFAULT_MAX_TOKENS: int = 512
28+
DEFAULT_SEED: int = 42
29+
30+
KV_CANARY_ARGS: List[str] = [
31+
"--kv-canary",
32+
"raise",
33+
"--kv-canary-real-data",
34+
"partial",
35+
"--kv-canary-sweep-interval",
36+
"100",
37+
"--disable-piecewise-cuda-graph",
38+
]
39+
40+
41+
class ChunkedGsm8kMixin:
42+
__test__ = False
43+
use_kv_canary: ClassVar[bool] = True
44+
model: ClassVar[str] = DEFAULT_MODEL
45+
feature_args: ClassVar[List[str]] = []
46+
47+
chunked_prefill_size: ClassVar[int] = DEFAULT_CHUNKED_PREFILL_SIZE
48+
num_shots: ClassVar[int] = DEFAULT_NUM_SHOTS
49+
num_examples: ClassVar[int] = DEFAULT_NUM_EXAMPLES
50+
num_threads: ClassVar[int] = DEFAULT_NUM_THREADS
51+
max_tokens: ClassVar[int] = DEFAULT_MAX_TOKENS
52+
gsm8k_threshold: ClassVar[float]
53+
54+
def build_prefill_side_args(self) -> List[str]:
55+
canary = list(KV_CANARY_ARGS) if self.use_kv_canary else []
56+
return (
57+
["--chunked-prefill-size", str(self.chunked_prefill_size)]
58+
+ list(self.feature_args)
59+
+ canary
60+
)
61+
62+
def test_mixed_prefix_gsm8k_chunked(self):
63+
fixture_name = type(self).__name__
64+
65+
args = SimpleNamespace(
66+
base_url=self.base_url,
67+
model=self.model,
68+
eval_name="mixed_prefix_gsm8k",
69+
api="chat_completion",
70+
max_tokens=self.max_tokens,
71+
num_examples=self.num_examples,
72+
num_threads=self.num_threads,
73+
num_shots=self.num_shots,
74+
mixed_prefix_gsm8k_secondary_pool_size=15,
75+
mixed_prefix_gsm8k_seed=DEFAULT_SEED,
76+
gsm8k_data_path=None,
77+
temperature=0.0,
78+
)
79+
tic = time.perf_counter()
80+
metrics = run_eval(args)
81+
metrics["elapsed_sec"] = time.perf_counter() - tic
82+
print(f"[{fixture_name}] {metrics} threshold={self.gsm8k_threshold:.4f}")
83+
84+
score = metrics.get("score")
85+
self.assertIsNotNone(score, "run_eval returned no score")
86+
self.assertGreaterEqual(score, self.gsm8k_threshold)
87+
88+
89+
class ChunkedTestBase(ChunkedGsm8kMixin, CustomTestCase):
90+
__test__ = False
91+
92+
base_url: ClassVar[str] = DEFAULT_URL_FOR_TEST
93+
launch_timeout: ClassVar[int] = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
94+
95+
process: ClassVar[Optional[object]] = None
96+
97+
@classmethod
98+
def setUpClass(cls):
99+
cls.process = popen_launch_server(
100+
cls.model,
101+
cls.base_url,
102+
timeout=cls.launch_timeout,
103+
other_args=cls("test_mixed_prefix_gsm8k_chunked").build_prefill_side_args(),
104+
)
105+
106+
@classmethod
107+
def tearDownClass(cls):
108+
if cls.process is not None:
109+
kill_process_tree(cls.process.pid)
110+
111+
112+
class ChunkedTestPDBase(ChunkedGsm8kMixin, PDDisaggregationServerBase):
113+
__test__ = False
114+
decode_feature_args: ClassVar[List[str]] = []
115+
116+
@classmethod
117+
def setUpClass(cls):
118+
cls.extra_prefill_args = cls(
119+
"test_mixed_prefix_gsm8k_chunked"
120+
).build_prefill_side_args()
121+
canary = list(KV_CANARY_ARGS) if cls.use_kv_canary else []
122+
cls.extra_decode_args = canary + list(cls.decode_feature_args)
123+
PDDisaggregationServerBase.setUpClass()
124+
cls.model = try_cached_model(cls.model)
125+
cls.launch_all()
126+
127+
@classmethod
128+
def tearDownClass(cls):
129+
PDDisaggregationServerBase.tearDownClass()

python/sglang/test/scripted_runtime/context/api.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def start_req(
6464
return_logprob: bool = False,
6565
logprob_start_len: Optional[int] = None,
6666
top_logprobs_num: Optional[int] = None,
67+
stop_token_ids: Optional[List[int]] = None,
68+
temperature: Optional[float] = None,
6769
lora_path: Optional[str] = None,
6870
) -> "ScriptedReqHandle":
6971
return self._req_starter.start_req(
@@ -77,6 +79,8 @@ def start_req(
7779
return_logprob=return_logprob,
7880
logprob_start_len=logprob_start_len,
7981
top_logprobs_num=top_logprobs_num,
82+
stop_token_ids=stop_token_ids,
83+
temperature=temperature,
8084
lora_path=lora_path,
8185
)
8286

@@ -89,8 +93,8 @@ def continue_generation(self, *, torch_empty_cache: bool = False) -> None:
8993
def abort_all(self) -> None:
9094
return lifecycle.abort_all(self)
9195

92-
def abort(self, handle: "ScriptedReqHandle") -> None:
93-
return lifecycle.abort(self, rid=handle.rid)
96+
def abort(self, handle: "ScriptedReqHandle", *, await_arrival: bool = True) -> None:
97+
return lifecycle.abort(self, rid=handle.rid, await_arrival=await_arrival)
9498

9599
def flush_cache(self) -> None:
96100
return lifecycle.flush_cache(self)

python/sglang/test/scripted_runtime/context/http_post.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,34 @@ def _http_post_and_await_recv_msg(
1919
predicate: Callable[[Any], bool],
2020
description: str,
2121
timeout_s: float = RECV_MSG_ARRIVAL_TIMEOUT_S,
22+
) -> None:
23+
_submit_post(ctx, path=path, json=json)
24+
ctx._tokenizer_recv_proxy.wait_until_arrived(
25+
predicate,
26+
timeout_s=timeout_s,
27+
description=description,
28+
)
29+
30+
31+
def _http_post_fire_and_forget(
32+
ctx: "ScriptedContext",
33+
*,
34+
path: str,
35+
json: Optional[Dict[str, Any]],
36+
) -> None:
37+
_submit_post(ctx, path=path, json=json)
38+
39+
40+
def _submit_post(
41+
ctx: "ScriptedContext",
42+
*,
43+
path: str,
44+
json: Optional[Dict[str, Any]],
2245
) -> None:
2346
server_args = ctx.scheduler.server_args
2447
url = f"http://{server_args.host}:{server_args.port}{path}"
2548

2649
async def _post() -> None:
27-
try:
28-
await ctx._http_poster.post(url, json)
29-
except Exception: # noqa: BLE001 — fire-and-forget background POST
30-
logger.exception("scripted_runtime: POST %s failed", path)
50+
await ctx._http_poster.post(url, json)
3151

3252
ctx._http_poster.submit_coro(_post())
33-
ctx._tokenizer_recv_proxy.wait_until_arrived(
34-
predicate,
35-
timeout_s=timeout_s,
36-
description=description,
37-
)

python/sglang/test/scripted_runtime/context/lifecycle.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,24 @@
1010
)
1111
from sglang.test.scripted_runtime.context.http_post import (
1212
_http_post_and_await_recv_msg,
13+
_http_post_fire_and_forget,
1314
)
1415

1516
if TYPE_CHECKING:
1617
from sglang.test.scripted_runtime.context.api import ScriptedContext
1718

1819

1920
def _await_control(
20-
ctx: "ScriptedContext", *, path: str, json, expect_type: type
21+
ctx: "ScriptedContext",
22+
*,
23+
path: str,
24+
json,
25+
expect_type: type,
26+
await_arrival: bool = True,
2127
) -> None:
28+
if not await_arrival:
29+
_http_post_fire_and_forget(ctx, path=path, json=json)
30+
return
2231
_http_post_and_await_recv_msg(
2332
ctx,
2433
path=path,
@@ -57,12 +66,13 @@ def abort_all(ctx: "ScriptedContext") -> None:
5766
)
5867

5968

60-
def abort(ctx: "ScriptedContext", *, rid: str) -> None:
69+
def abort(ctx: "ScriptedContext", *, rid: str, await_arrival: bool = True) -> None:
6170
_await_control(
6271
ctx,
6372
path="/abort_request",
6473
json={"rid": rid, "abort_all": False},
6574
expect_type=AbortReq,
75+
await_arrival=await_arrival,
6676
)
6777

6878

python/sglang/test/scripted_runtime/context/queries.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,20 @@ def find_req_by_rid(ctx: "ScriptedContext", rid: str) -> Optional["Req"]:
8080

8181
def is_finished(ctx: "ScriptedContext", rid: str) -> bool:
8282
req = find_req_by_rid(ctx, rid)
83-
if req is None:
84-
return rid in ctx._seen_rids
85-
return req.finished()
83+
if req is not None:
84+
return req.finished()
85+
if rid in ctx._seen_rids:
86+
return True
87+
# Fallback: if the req ran in a forward batch (recorded in _batch_log) but
88+
# is now absent from all active scheduler sets, it must have finished.
89+
# This catches requests that completed without ever being observed via
90+
# find_req_by_rid (e.g. when Python short-circuit evaluation prevents the
91+
# query while another request is still running).
92+
log = ctx._scheduler_hook._batch_log
93+
if any(rid in record.rids for record in log):
94+
ctx._seen_rids.add(rid)
95+
return True
96+
return False
8697

8798

8899
def is_chunking(ctx: "ScriptedContext", rid: str) -> bool:

python/sglang/test/scripted_runtime/context/req_starter.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import uuid
4-
from typing import TYPE_CHECKING, Optional
4+
from typing import TYPE_CHECKING, List, Optional
55

66
from sglang.test.scripted_runtime.context.http_post import (
77
_http_post_and_await_recv_msg,
@@ -30,6 +30,8 @@ def start_req(
3030
return_logprob: bool = False,
3131
logprob_start_len: Optional[int] = None,
3232
top_logprobs_num: Optional[int] = None,
33+
stop_token_ids: Optional[List[int]] = None,
34+
temperature: Optional[float] = None,
3335
lora_path: Optional[str] = None,
3436
) -> ScriptedReqHandle:
3537
ctx = self._ctx
@@ -39,6 +41,10 @@ def start_req(
3941
self._req_counter += 1
4042

4143
sampling_params = {"max_new_tokens": max_new_tokens, "ignore_eos": ignore_eos}
44+
if stop_token_ids is not None:
45+
sampling_params["stop_token_ids"] = stop_token_ids
46+
if temperature is not None:
47+
sampling_params["temperature"] = temperature
4248
payload = {
4349
"input_ids": [prompt_token] * prompt_len,
4450
"sampling_params": sampling_params,

python/sglang/test/scripted_runtime/req_handle.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from dataclasses import dataclass
44
from typing import TYPE_CHECKING, Optional
55

6+
from sglang.test.scripted_runtime.context.radix import _node_lock_ref
7+
68
if TYPE_CHECKING:
79
from sglang.srt.managers.schedule_batch import Req
810
from sglang.test.scripted_runtime.context.api import ScriptedContext
@@ -47,5 +49,10 @@ def kv_pages(self) -> int:
4749

4850
@property
4951
def lock_refs(self) -> int:
50-
node = self.req.last_node
51-
return node.lock_ref if node is not None else 0
52+
req = self.req
53+
if req is None:
54+
return 0
55+
node = req.last_node
56+
if node is None:
57+
return 0
58+
return _node_lock_ref(node)

python/sglang/test/scripted_runtime/scheduler_hook.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ def _drive_engine_through_warmup(ctx: ScriptedContext) -> Generator:
9595
def _reset_engine_state(ctx: ScriptedContext) -> Generator:
9696
scheduler = ctx.scheduler
9797

98+
if scheduler._engine_paused:
99+
ctx.continue_generation()
100+
98101
ctx._release_exhausted_pools()
99102
ctx.abort_all()
100103
for _ in range(RESET_DRAIN_MAX_STEPS):

python/sglang/test/scripted_runtime_chunked_helpers.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,16 @@ def run_until_finished(handle, *, max_steps: int = DEFAULT_MAX_STEPS):
4141

4242

4343
def run_until_all_finished(handles: List[Any], *, max_steps: int = DEFAULT_MAX_STEPS):
44+
done = [False] * len(handles)
4445
for _ in range(max_steps):
45-
if all(h.finished for h in handles):
46+
for i, h in enumerate(handles):
47+
done[i] = done[i] or h.finished
48+
if all(done):
4649
return
4750
yield
4851
raise AssertionError(
4952
f"run_until_all_finished: not all reqs finished after {max_steps} "
50-
f"steps (finished={[h.finished for h in handles]})"
53+
f"steps (finished={done})"
5154
)
5255

5356

@@ -65,6 +68,12 @@ def warmup_radix(t, prompt_tokens: List[int], *, max_steps: int = DEFAULT_MAX_ST
6568

6669
BALLAST_MAX_NEW_TOKENS: int = 30000
6770

71+
SMALL_KV_POOL_MAX_TOTAL_TOKENS: int = 4096
72+
73+
SMALL_KV_POOL_BALLAST_MAX_NEW_TOKENS: int = 512
74+
75+
SMALL_KV_POOL_BALLAST_PROMPT_LEN: int = 1536
76+
6877

6978
def exhaust_row_pool(t, *, leave_rows: int, max_steps: int = DEFAULT_MAX_STEPS):
7079
target: int = t.scheduler.req_to_token_pool.available_size() - leave_rows

test/manual/chunked_prefill/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)