Skip to content

Commit 50a73bd

Browse files
authored
[TRTLLM-11878][feat] Gen-only sync transfer v2 and manager v2 (#12882)
1 parent a56a8d2 commit 50a73bd

8 files changed

Lines changed: 299 additions & 4 deletions

File tree

tensorrt_llm/_torch/disaggregation/transceiver.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,37 @@ def respond_and_send_async(self, req: LlmRequest):
310310
)
311311
self._send_reqs[rid] = req
312312

313+
@nvtx_range("KvCacheTransceiverV2.request_and_receive_sync")
def request_and_receive_sync(self, req: LlmRequest):
    """Receive the KV cache for ``req`` and wait for the transfer to finish.

    A receive session is created through ``self._transfer_worker``, tracked in
    ``self._recv_sessions`` / ``self._recv_reqs`` for the duration of the
    call, and always closed and untracked before returning.

    ``req.state`` is updated as the transfer progresses:

    * ``DISAGG_GENERATION_TRANS_IN_PROGRESS`` once the call starts,
    * ``DISAGG_GENERATION_TRANS_COMPLETE`` when ``wait_complete`` reports
      ``WaitResult.COMPLETED`` (aux data is applied first when
      ``self._need_aux_transfer(req)`` is true),
    * ``DISAGG_TRANS_ERROR`` for any other wait result or on an exception.

    If a receive session is already registered for this request, a warning is
    logged and the call returns without touching ``req``.

    Args:
        req: The request whose KV cache should be received.

    Raises:
        Exception: Any exception raised while creating the session, receiving,
            waiting, or applying aux data is re-raised after ``req.state`` has
            been set to ``DISAGG_TRANS_ERROR``.
    """
    rid = get_unique_rid(req)
    if rid in self._recv_sessions:
        logger.warning(
            f"request_and_receive_sync: rid={rid} already has a recv session, skipping"
        )
        return

    req.state = LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS
    session = None
    try:
        session = self._transfer_worker.create_rx_session(req)
        self._recv_sessions[rid] = session
        self._recv_reqs[rid] = req
        session.receive(self._create_kv_slice(req))
        result = session.wait_complete(blocking=True)

        if result == WaitResult.COMPLETED:
            if self._need_aux_transfer(req):
                self._apply_aux(session, req)
            req.state = LlmRequestState.DISAGG_GENERATION_TRANS_COMPLETE
        else:
            # Surface the unexpected wait result; otherwise the error state
            # is set with no trace of why the transfer did not complete.
            logger.warning(
                f"request_and_receive_sync: rid={rid} transfer finished with result={result}"
            )
            req.state = LlmRequestState.DISAGG_TRANS_ERROR
    except Exception:
        req.state = LlmRequestState.DISAGG_TRANS_ERROR
        raise
    finally:
        # Nested try/finally: even if close() raises, the bookkeeping entries
        # must be removed, otherwise a stale rid would make every later call
        # for this request hit the "already has a recv session" early return.
        try:
            if session is not None:
                session.close()
        finally:
            self._recv_sessions.pop(rid, None)
            self._recv_reqs.pop(rid, None)
315344

316345
@nvtx_range("KvCacheTransceiverV2.request_and_receive_async")
317346
def request_and_receive_async(self, req: LlmRequest):

tensorrt_llm/_torch/pyexecutor/scheduler/scheduler_v2.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,19 @@ def _schedule_loop(self, active_requests, inflight_request_ids):
230230

231231
req_state_value = req.state_value
232232

233-
# Disagg gen init bypasses normal state gating (same as C++ / V1 scheduler)
233+
# Disagg gen init bypasses normal state gating (same as C++ / V1 scheduler),
234+
# but the V2 scheduler owns inline KV allocation so we must allocate here.
235+
# V1 defers allocation to prepare_resources; V2 prepare_resources is a no-op
236+
# for the primary manager, so allocation must happen in the scheduling loop.
234237
if req_state_value == self._disagg_gen_init_state_value:
238+
if not self.kv_cache_manager.prepare_context(req):
239+
req_it += 1
240+
continue
241+
if not self.kv_cache_manager.resize_context(
242+
req, req.context_remaining_length + get_draft_token_length(req)
243+
):
244+
req_it += 1
245+
continue
235246
disagg_candidates.append(req)
236247
req_it += 1
237248
continue

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def launch_disaggregated_llm(
146146
max_workers: int = 16,
147147
enable_perf=False,
148148
extra_env: Optional[Dict[str, str]] = None,
149+
gen_extra_env: Optional[Dict[str, str]] = None,
149150
):
150151
temp_dir = tempfile.TemporaryDirectory()
151152
disaggregated_serving_config_path = os.path.join(
@@ -299,6 +300,8 @@ def _apply_perf_flags(cfg: Optional[Dict[str, Any]]):
299300

300301
for i, port in enumerate(gen_ports):
301302
env = base_env.copy()
303+
if gen_extra_env:
304+
env.update(gen_extra_env)
302305
env["TRTLLM_USE_UCX_KVCACHE"] = "1"
303306
# Need to set UCX_TLS to ^ib to avoid hangs on CI B200 cluster.
304307
env["UCX_TLS"] = "^ib"
@@ -633,6 +636,50 @@ def test_auto_dtype(self, ctx_disable_overlap_scheduler,
633636
self.MODEL_PATH) as llm:
634637
run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
635638

639+
@skip_pre_hopper
@pytest.mark.skip_less_device(2)
def test_kv_cache_v2_nixl_python(self):
    """Run GSM8K through disaggregated serving with KV cache manager V2.

    Both workers set ``use_kv_cache_manager_v2=True`` and
    ``enable_block_reuse=False`` and use the NIXL backend with the PYTHON
    transceiver runtime. The context worker sets
    ``disable_overlap_scheduler=True``; the generation worker sets it to
    ``False``.
    """

    def make_worker_config(disable_overlap_scheduler):
        # Build fresh dicts on every call so the two workers never share
        # nested config objects.
        return {
            "disable_overlap_scheduler": disable_overlap_scheduler,
            "kv_cache_config": {
                "enable_block_reuse": False,
                "use_kv_cache_manager_v2": True,
            },
            "cache_transceiver_config": {
                "backend": "NIXL",
                "transceiver_runtime": "PYTHON",
            },
        }

    serving_config = {
        "hostname": "localhost",
        "port": 8000,
        "backend": "pytorch",
        "context_servers": {
            "num_instances": 1,
            "urls": ["localhost:8001"],
        },
        "generation_servers": {
            "num_instances": 1,
            "urls": ["localhost:8002"],
        },
    }

    with launch_disaggregated_llm(
        serving_config,
        make_worker_config(True),
        make_worker_config(False),
        self.MODEL_PATH,
    ) as llm:
        run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
682+
636683
@pytest.mark.skip_less_device(2)
637684
def test_ngram(self):
638685
speculative_decoding_config = {
@@ -952,6 +999,52 @@ def test_nixl_backend(self):
952999
self.MODEL_PATH) as llm:
9531000
run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
9541001

1002+
@pytest.mark.skip_less_device(2)
@pytest.mark.skip_less_device_memory(60000)
@skip_no_hopper
def test_gen_only_sync(self):
    """Check GSM8K accuracy on the gen-only synchronous KV transfer path.

    The servers are launched with TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP=1,
    which is intended to make the gen worker call request_and_receive_sync
    rather than the async receive path. Accuracy is expected to match the
    standard async path.
    """

    def make_worker_config():
        # Context and generation workers use the same settings; return a
        # fresh dict each time so the two configs stay independent objects.
        return {
            "disable_overlap_scheduler": True,
            "cache_transceiver_config": {
                "backend": "NIXL",
                "transceiver_runtime": "PYTHON",
                "max_tokens_in_buffer": 4096,
            },
        }

    serving_config = {
        "hostname": "localhost",
        "backend": "pytorch",
        "context_servers": {"num_instances": 1},
        "generation_servers": {"num_instances": 1},
    }
    # Apply to both servers: gen worker uses sync receive path.
    sync_env = {"TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP": "1"}

    with launch_disaggregated_llm(
        serving_config,
        make_worker_config(),
        make_worker_config(),
        self.MODEL_PATH,
        extra_env=sync_env,
    ) as llm:
        run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
1047+
9551048
@pytest.mark.skip_less_device(8)
9561049
@parametrize_with_ids("overlap_scheduler", [True, False])
9571050
@parametrize_with_ids("mtp_nextn", [0, 2])
@@ -1141,6 +1234,51 @@ def test_guided_decoding(self, backend: str, mtp_nextn: int, mocker):
11411234
self.MODEL_PATH) as llm:
11421235
run_accuracy_test(llm, self.MODEL_NAME, ["JsonModeEval"])
11431236

1237+
@pytest.mark.skip_less_device(2)
@pytest.mark.skip_less_device_memory(60000)
@skip_pre_hopper
def test_kv_cache_v2_nixl_python(self):
    """Run GSM8K through disaggregated serving with KV cache manager V2.

    Context and generation workers share the same settings:
    ``disable_overlap_scheduler=True``, ``use_kv_cache_manager_v2=True``,
    ``enable_block_reuse=False``, NIXL backend, PYTHON transceiver runtime.
    """

    def make_worker_config():
        # A new dict per worker keeps the two configs independent objects.
        return {
            "disable_overlap_scheduler": True,
            "kv_cache_config": {
                "enable_block_reuse": False,
                "use_kv_cache_manager_v2": True,
            },
            "cache_transceiver_config": {
                "backend": "NIXL",
                "transceiver_runtime": "PYTHON",
            },
        }

    serving_config = {
        "hostname": "localhost",
        "port": 8000,
        "backend": "pytorch",
        "context_servers": {
            "num_instances": 1,
            "urls": ["localhost:8001"],
        },
        "generation_servers": {
            "num_instances": 1,
            "urls": ["localhost:8002"],
        },
    }

    with launch_disaggregated_llm(
        serving_config,
        make_worker_config(),
        make_worker_config(),
        self.MODEL_PATH,
    ) as llm:
        run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
1281+
11441282

11451283
@pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
11461284
class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness):
@@ -1193,6 +1331,52 @@ def test_auto_dtype(self, block_reuse):
11931331
self.MODEL_PATH) as llm:
11941332
run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
11951333

1334+
@pytest.mark.skip_less_device(2)
@skip_pre_hopper
def test_kv_cache_v2_nixl_python(self):
    """Run MMLU and GSM8K through disaggregated serving with KV cache manager V2.

    Context and generation workers share the same settings:
    ``disable_overlap_scheduler=True``, ``cuda_graph_config=None``,
    ``use_kv_cache_manager_v2=True``, ``enable_block_reuse=False``, NIXL
    backend, PYTHON transceiver runtime.
    """

    def make_worker_config():
        # A new dict per worker keeps the two configs independent objects.
        return {
            "disable_overlap_scheduler": True,
            "cuda_graph_config": None,
            "kv_cache_config": {
                "enable_block_reuse": False,
                "use_kv_cache_manager_v2": True,
            },
            "cache_transceiver_config": {
                "backend": "NIXL",
                "transceiver_runtime": "PYTHON",
            },
        }

    serving_config = {
        "hostname": "localhost",
        "port": 8000,
        "backend": "pytorch",
        "context_servers": {
            "num_instances": 1,
            "urls": ["localhost:8001"],
        },
        "generation_servers": {
            "num_instances": 1,
            "urls": ["localhost:8002"],
        },
    }

    with launch_disaggregated_llm(
        serving_config,
        make_worker_config(),
        make_worker_config(),
        self.MODEL_PATH,
    ) as llm:
        run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
1379+
11961380

11971381
@skip_pre_blackwell
11981382
@pytest.mark.skip_less_device_memory(80000)

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,14 +390,18 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with
390390
accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1dp2cp2]
391391
accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2]
392392
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
393+
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_gen_only_sync
394+
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_kv_cache_v2_nixl_python
393395
accuracy/test_dwdp_disaggregated_serving.py::TestDwdpDeepSeekV3Lite::test_dwdp_accuracy
394396
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
395397
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
398+
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_kv_cache_v2_nixl_python
396399
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
397400
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False]
398401
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True]
399402
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-True]
400403
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
404+
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_kv_cache_v2_nixl_python
401405
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True]
402406
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]
403407
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]

tests/integration/test_lists/test-db/l0_dgx_b300.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ l0_dgx_b300:
9292
- disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-fp8]
9393
- accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend
9494
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
95+
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_kv_cache_v2_nixl_python
96+
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_kv_cache_v2_nixl_python
97+
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_kv_cache_v2_nixl_python
9598
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp4] TIMEOUT (180)
9699
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_pp4_mtp] TIMEOUT (180)
97100
- condition:

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,13 @@ l0_dgx_h100:
3131
- accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_chunked_prefill
3232
- accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend
3333
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
34+
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_gen_only_sync
35+
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_kv_cache_v2_nixl_python
3436
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
37+
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_kv_cache_v2_nixl_python
3538
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
3639
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
40+
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_kv_cache_v2_nixl_python
3741
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-False]
3842
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True]
3943
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-True-False]

tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,18 @@ def make_encoder_request(request_id, encoder_output_len, lora_task_id=None):
9999
return req
100100

101101

102-
def make_disagg_request(request_id, context_remaining_length=1, num_draft_tokens=0):
    """Build a Mock request whose ``state_value`` is ``DISAGG_GEN_INIT``.

    Args:
        request_id: Value stored as both ``request_id`` and ``py_request_id``.
        context_remaining_length: Value stored on ``context_remaining_length``.
        num_draft_tokens: Number of draft tokens; drives ``has_draft_tokens``
            and the length of the zero-filled ``py_draft_tokens`` list.

    Returns:
        The configured ``Mock`` instance.
    """
    has_drafts = num_draft_tokens > 0
    mock_request = Mock()
    mock_request.configure_mock(
        request_id=request_id,
        py_request_id=request_id,
        state_value=DISAGG_GEN_INIT,
        is_context_init_state=False,
        is_generation_in_progress_state=False,
        is_first_context_chunk=True,
        context_remaining_length=context_remaining_length,
        num_draft_tokens=num_draft_tokens,
        has_draft_tokens=has_drafts,
        py_draft_tokens=[0] * num_draft_tokens if has_drafts else [],
    )
    return mock_request
111115

112116

0 commit comments

Comments
 (0)