-
Notifications
You must be signed in to change notification settings - Fork 6.5k
Add more testing for chunked prefill #27506
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a9d4d99
5053d8a
0216b40
3cde89b
4f2b4c1
187f44f
378abc7
6aed448
cf82cc2
06b0716
5b04f87
452dbd5
afc6643
bb91eae
097d69e
b04c139
26839be
b3df1f2
3703d72
efd2797
c70113a
2fe37e9
22ad99c
62cbabc
691c637
7ced1c5
286c048
2561d57
fd20c46
f8e6bad
c492526
2ff58c0
4448e42
460105f
40e631d
9d6ddbc
9193494
20c79fd
a5dcdc3
ceba6f8
f5938ca
1e0dca5
cd2c137
79f0600
3a6b52e
aa11ba3
d53662d
ea5201d
e220c0f
b2de384
bc26ba4
08c7f4c
770e8b1
3a76105
bb0b8c8
6a1b5c8
5c83123
9d6632e
a497908
3c7823d
be38f82
7ba758a
fdcc461
3f3caf5
c2a0db1
ad13e0f
5e2257d
b2453b1
b5fe41b
5776a41
addee53
1554087
5e0014b
c1a3778
527e770
3a7f61e
6499e8a
81b3f1c
1e4b829
237f0c5
46f3ddc
587fc97
d77f931
af12dc3
beb4bf2
27e0eee
f380209
685f8bd
852ad69
aadd330
178ad5f
138f8f6
bafedaa
8dd2b30
6769347
a2e6710
1d3bcdc
3591ca4
a27b96e
8eb018c
526a82e
ae103fe
ca07bdf
e683412
cc00a75
43fb008
60980ff
74637e1
736928c
dc09e5b
2a6b82b
7c93fcb
d32afce
3972571
7016c11
0eb7ce0
259f5a2
30d0dfd
3d39694
f180498
a51f87f
867913f
6b36d3f
85b75b7
372df05
7776bd2
815eaee
ff1f8bd
3a3b9dc
1f0f9ba
6927845
5e674e3
87bd5a0
389d427
d5cc957
91b57fa
f4bc61e
3532003
b772f04
51046cd
4e11770
349b497
29bd8eb
2399517
95f3aa1
06f4bca
cbfc6cf
7f8290f
59ad8e6
3b15001
0e4c5df
6ade010
10a0f19
6308d60
ec49263
46c3db4
b381a79
97cfc2e
141c6e2
e4b9e04
953bd4b
7a9b214
a24924c
c864d67
ed020ee
7601b1f
e485564
5c518b7
3f9e255
04304bd
e8eb94d
cd81386
38c6da7
0b28725
61eed00
237b232
aa8a849
90dec59
8926b68
78be392
2b255d7
1a6ee32
7d57cae
cfb0e74
9bad8d3
8ead099
ee0b68f
43f0b63
1cfadc2
3501f01
e00be15
11bc498
c7bdeb8
57513a3
0b9c47b
16a2b65
9f368d5
574c03c
84e654b
7b90b00
f2875c4
4664fa0
57c872d
308fce8
2265d52
2edb7cf
6e062e0
3cc45ef
c2548be
b7763c0
3404574
bcaed43
f551ab2
f632d14
98ef65c
bc3c10b
ac9d8e8
5cb8847
da8fe11
b42ebb6
5f45cb1
c19fb6c
a80e474
b67412e
cfa3f15
2244491
fb917d0
a81e0b6
14de978
e5fb69c
31c3964
d3090fb
8d2af96
27fbff8
9ebf28e
e6b8b18
cd5573b
015200a
19e2353
a59441f
722fd4b
7c3b7b3
a3996e1
c1d8ff1
4eb5bc6
8c8f08d
f240a51
9b01721
17f139b
6e4b4dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import time | ||
| from types import SimpleNamespace | ||
| from typing import ClassVar, List, Optional | ||
|
|
||
| from sglang.srt.utils import kill_process_tree | ||
| from sglang.test.run_eval import run_eval | ||
| from sglang.test.server_fixtures.disaggregation_fixture import ( | ||
| PDDisaggregationServerBase, | ||
| ) | ||
| from sglang.test.test_utils import ( | ||
| DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, | ||
| DEFAULT_URL_FOR_TEST, | ||
| CustomTestCase, | ||
| popen_launch_server, | ||
| try_cached_model, | ||
| ) | ||
|
|
||
| DEFAULT_MODEL: str = "Qwen/Qwen3-0.6B" | ||
|
|
||
| DEFAULT_CHUNKED_PREFILL_SIZE: int = 256 | ||
| DEFAULT_NUM_EXAMPLES: int = 100 | ||
| DEFAULT_NUM_SHOTS: int = 10 | ||
| LONG_PROMPT_NUM_SHOTS: int = 24 | ||
| DEFAULT_NUM_THREADS: int = 128 | ||
| DEFAULT_MAX_TOKENS: int = 512 | ||
| DEFAULT_SEED: int = 42 | ||
|
|
||
| KV_CANARY_ARGS: List[str] = [ | ||
| "--kv-canary", | ||
| "raise", | ||
| "--kv-canary-real-data", | ||
| "partial", | ||
| "--kv-canary-sweep-interval", | ||
| "100", | ||
| "--disable-piecewise-cuda-graph", | ||
| ] | ||
|
|
||
|
|
||
| class ChunkedGsm8kMixin: | ||
| __test__ = False | ||
| use_kv_canary: ClassVar[bool] = True | ||
| model: ClassVar[str] = DEFAULT_MODEL | ||
| feature_args: ClassVar[List[str]] = [] | ||
|
|
||
| chunked_prefill_size: ClassVar[int] = DEFAULT_CHUNKED_PREFILL_SIZE | ||
| num_shots: ClassVar[int] = DEFAULT_NUM_SHOTS | ||
| num_examples: ClassVar[int] = DEFAULT_NUM_EXAMPLES | ||
| num_threads: ClassVar[int] = DEFAULT_NUM_THREADS | ||
| max_tokens: ClassVar[int] = DEFAULT_MAX_TOKENS | ||
| gsm8k_threshold: ClassVar[float] | ||
|
|
||
| def build_prefill_side_args(self) -> List[str]: | ||
| canary = list(KV_CANARY_ARGS) if self.use_kv_canary else [] | ||
| return ( | ||
| ["--chunked-prefill-size", str(self.chunked_prefill_size)] | ||
| + list(self.feature_args) | ||
| + canary | ||
| ) | ||
|
|
||
| def test_mixed_prefix_gsm8k_chunked(self): | ||
| fixture_name = type(self).__name__ | ||
|
|
||
| args = SimpleNamespace( | ||
| base_url=self.base_url, | ||
| model=self.model, | ||
| eval_name="mixed_prefix_gsm8k", | ||
| api="chat_completion", | ||
| max_tokens=self.max_tokens, | ||
| num_examples=self.num_examples, | ||
| num_threads=self.num_threads, | ||
| num_shots=self.num_shots, | ||
| mixed_prefix_gsm8k_secondary_pool_size=15, | ||
| mixed_prefix_gsm8k_seed=DEFAULT_SEED, | ||
| gsm8k_data_path=None, | ||
| temperature=0.0, | ||
| ) | ||
| tic = time.perf_counter() | ||
| metrics = run_eval(args) | ||
| metrics["elapsed_sec"] = time.perf_counter() - tic | ||
| print(f"[{fixture_name}] {metrics} threshold={self.gsm8k_threshold:.4f}") | ||
|
|
||
| score = metrics.get("score") | ||
| self.assertIsNotNone(score, "run_eval returned no score") | ||
| self.assertGreaterEqual(score, self.gsm8k_threshold) | ||
|
|
||
|
|
||
| class ChunkedTestBase(ChunkedGsm8kMixin, CustomTestCase): | ||
| __test__ = False | ||
|
|
||
| base_url: ClassVar[str] = DEFAULT_URL_FOR_TEST | ||
| launch_timeout: ClassVar[int] = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH | ||
|
|
||
| process: ClassVar[Optional[object]] = None | ||
|
|
||
| @classmethod | ||
| def setUpClass(cls): | ||
| cls.process = popen_launch_server( | ||
| cls.model, | ||
| cls.base_url, | ||
| timeout=cls.launch_timeout, | ||
| other_args=cls("test_mixed_prefix_gsm8k_chunked").build_prefill_side_args(), | ||
| ) | ||
|
Comment on lines
+98
to
+104
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| @classmethod | ||
| def tearDownClass(cls): | ||
| if cls.process is not None: | ||
| kill_process_tree(cls.process.pid) | ||
|
|
||
|
|
||
| class ChunkedTestPDBase(ChunkedGsm8kMixin, PDDisaggregationServerBase): | ||
| __test__ = False | ||
| decode_feature_args: ClassVar[List[str]] = [] | ||
|
|
||
| @classmethod | ||
| def setUpClass(cls): | ||
| cls.extra_prefill_args = cls( | ||
| "test_mixed_prefix_gsm8k_chunked" | ||
| ).build_prefill_side_args() | ||
|
Comment on lines
+117
to
+120
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| canary = list(KV_CANARY_ARGS) if cls.use_kv_canary else [] | ||
| cls.extra_decode_args = canary + list(cls.decode_feature_args) | ||
| PDDisaggregationServerBase.setUpClass() | ||
| cls.model = try_cached_model(cls.model) | ||
| cls.launch_all() | ||
|
|
||
| @classmethod | ||
| def tearDownClass(cls): | ||
| PDDisaggregationServerBase.tearDownClass() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,19 +19,34 @@ def _http_post_and_await_recv_msg( | |
| predicate: Callable[[Any], bool], | ||
| description: str, | ||
| timeout_s: float = RECV_MSG_ARRIVAL_TIMEOUT_S, | ||
| ) -> None: | ||
| _submit_post(ctx, path=path, json=json) | ||
| ctx._tokenizer_recv_proxy.wait_until_arrived( | ||
| predicate, | ||
| timeout_s=timeout_s, | ||
| description=description, | ||
| ) | ||
|
|
||
|
|
||
| def _http_post_fire_and_forget( | ||
| ctx: "ScriptedContext", | ||
| *, | ||
| path: str, | ||
| json: Optional[Dict[str, Any]], | ||
| ) -> None: | ||
| _submit_post(ctx, path=path, json=json) | ||
|
|
||
|
|
||
| def _submit_post( | ||
| ctx: "ScriptedContext", | ||
| *, | ||
| path: str, | ||
| json: Optional[Dict[str, Any]], | ||
| ) -> None: | ||
| server_args = ctx.scheduler.server_args | ||
| url = f"http://{server_args.host}:{server_args.port}{path}" | ||
|
|
||
| async def _post() -> None: | ||
| try: | ||
| await ctx._http_poster.post(url, json) | ||
| except Exception: # noqa: BLE001 — fire-and-forget background POST | ||
| logger.exception("scripted_runtime: POST %s failed", path) | ||
| await ctx._http_poster.post(url, json) | ||
|
Comment on lines
49
to
+50
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removing the async def _post() -> None:\n try:\n await ctx._http_poster.post(url, json)\n except Exception:\n logger.exception(\"scripted_runtime: POST %s failed\", path) |
||
|
|
||
| ctx._http_poster.submit_coro(_post()) | ||
| ctx._tokenizer_recv_proxy.wait_until_arrived( | ||
| predicate, | ||
| timeout_s=timeout_s, | ||
| description=description, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| import unittest | ||
|
|
||
| from sglang.test.chunked_prefill_test_utils import ChunkedTestPDBase | ||
|
|
||
|
|
||
| class TestChunkedFeatureDisagg(ChunkedTestPDBase): | ||
| __test__ = True | ||
| use_kv_canary = False | ||
| gsm8k_threshold = 0.50 | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| import unittest | ||
|
|
||
| from sglang.test.chunked_prefill_test_utils import ChunkedTestBase | ||
| from sglang.test.test_utils import DEFAULT_MLA_MODEL_NAME_FOR_TEST | ||
|
|
||
|
|
||
| class TestChunkedFeatureDPAttention(ChunkedTestBase): | ||
| __test__ = True | ||
| use_kv_canary = False | ||
| model = DEFAULT_MLA_MODEL_NAME_FOR_TEST | ||
| gsm8k_threshold = 0.50 | ||
| feature_args = [ | ||
| "--trust-remote-code", | ||
| "--tp", | ||
| "2", | ||
| "--enable-dp-attention", | ||
| "--dp", | ||
| "2", | ||
| ] | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instantiating the test case class manually inside
setUpClassjust to callbuild_prefill_side_argsis an anti-pattern. Sincebuild_prefill_side_argsonly accesses class-level variables, it should be defined as a@classmethod.