diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
index e3c0a4ed3153..f47b82f44ac1 100755
--- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
+++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
@@ -4423,6 +4423,89 @@ TEST_F(KVCacheManagerTest, PinAndUnpinBlocksById)
     EXPECT_EQ(freeAfterUnpin, totalBlocks);
 }
 
+// Regression test for NVBug 6018647: storeBlocks(pin=true) on a zero-ref block
+// that sits in the eviction free queue must call claimBlock() before incRefCount().
+// Without the fix, unpinBlocksById inserts the block into the free queue a second
+// time, creating a ghost entry that inflates the free count and can cause hangs.
+TEST_F(KVCacheManagerTest, StoreBlocksForReuseWithPinDoesNotCreateGhostFreeBlocks)
+{
+    using namespace tensorrt_llm::batch_manager::kv_cache_manager;
+    auto constexpr numLayers = 2;
+    auto constexpr numKvHeads = 2;
+    auto constexpr sizePerHead = 16;
+    auto constexpr tokensPerBlock = 4;
+    auto constexpr blocksInPrimaryPool = 6;
+    auto constexpr blocksInSecondaryPool = 0;
+    auto constexpr maxNumSequences = 8;
+    auto const stream = std::make_shared<tr::CudaStream>();
+    auto constexpr onboardBlocks = true;
+    auto constexpr beamWidth = 1;
+    auto const maxAttentionWindow = tokensPerBlock * blocksInPrimaryPool;
+
+    BlocksPerWindow const blocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}};
+
+    KVCacheManager kvCacheManager(numLayers, numKvHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
+        beamWidth, std::vector<SizeType32>{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF,
+        0, stream, maxAttentionWindow, true /* enableBlockReuse */, onboardBlocks);
+    kvCacheManager.allocatePools(false);
+
+    auto const totalBlocks = kvCacheManager.getMaxNumBlocks();
+
+    // 8 tokens = 2 blocks (tokensPerBlock=4).
+    auto inputTokens = std::make_shared<VecTokens>(VecTokens{0, 1, 2, 3, 4, 5, 6, 7});
+    tr::SamplingConfig const samplingConfig{beamWidth};
+    bool constexpr isStreaming{false};
+
+    // Step 1: Add seq A (requestId=0). Tree is empty, no reuse.
+    LlmRequest::RequestIdType requestIdA{0};
+    auto llmRequestA = std::make_shared<LlmRequest>(requestIdA, 0, inputTokens, samplingConfig, isStreaming);
+    kvCacheManager.addSequence(requestIdA, static_cast<SizeType32>(inputTokens->size()), beamWidth, llmRequestA);
+
+    // Step 2: Add seq B (requestId=1) with same tokens. Tree still empty, allocates different blocks.
+    LlmRequest::RequestIdType requestIdB{1};
+    auto llmRequestB = std::make_shared<LlmRequest>(requestIdB, 0, inputTokens, samplingConfig, isStreaming);
+    kvCacheManager.addSequence(requestIdB, static_cast<SizeType32>(inputTokens->size()), beamWidth, llmRequestB);
+
+    // Both sequences allocated, 4 blocks consumed.
+    auto const freeAfterBothAlloc = kvCacheManager.getNumFreeBlocks();
+    EXPECT_EQ(freeAfterBothAlloc, totalBlocks - 4);
+
+    // Step 3-4: Simulate prefill completion for both.
+    tensorrt_llm::testing::KvCacheManagerTestUtil::simulatePrefillCompletion(*llmRequestA);
+    tensorrt_llm::testing::KvCacheManagerTestUtil::simulatePrefillCompletion(*llmRequestB);
+
+    // Step 5: Store A's blocks in the radix tree.
+    kvCacheManager.storeContextBlocks(*llmRequestA);
+
+    // Step 6: Remove seq A. Its blocks are stored in tree, refCount -> 0, released to free queue.
+    (void) kvCacheManager.removeSequence(requestIdA, llmRequestA);
+    auto const freeAfterRemoveA = kvCacheManager.getNumFreeBlocks();
+    // A's 2 blocks + the 2 that were already free = totalBlocks - 2 (B's blocks).
+    EXPECT_EQ(freeAfterRemoveA, totalBlocks - 2);
+
+    // Step 7: storeBlocksForReuse with pin=true on seq B.
+    // storeBlocks finds A's tree blocks (refCount=0, in free queue) as matches and pins them.
+    // Without the fix: incRefCount alone, block stays in free queue -> ghost on unpin.
+    // With the fix: claimBlock first, block removed from free queue -> correct lifecycle.
+    auto pinnedBlockIds = kvCacheManager.storeBlocksForReuse(requestIdB, llmRequestB, /*pinBlocks=*/true);
+    EXPECT_FALSE(pinnedBlockIds.empty());
+
+    // Step 8: Unpin the blocks.
+    kvCacheManager.unpinBlocksById(pinnedBlockIds);
+    auto const freeAfterUnpin = kvCacheManager.getNumFreeBlocks();
+    // A's blocks should be in the free queue exactly once. B's 2 blocks still allocated.
+    // With the bug, ghost entries would inflate this beyond (totalBlocks - 2).
+    EXPECT_EQ(freeAfterUnpin, totalBlocks - 2);
+    EXPECT_LE(freeAfterUnpin, totalBlocks);
+
+    // Step 9: Remove seq B. All blocks should now be free.
+    (void) kvCacheManager.removeSequence(requestIdB, llmRequestB);
+    auto const freeAfterAll = kvCacheManager.getNumFreeBlocks();
+    EXPECT_EQ(freeAfterAll, totalBlocks);
+    // Ghost entries would make free count exceed total blocks.
+    EXPECT_LE(freeAfterAll, totalBlocks);
+}
+
 TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamBlocking)
 {
     auto constexpr numLayers = 12;
diff --git a/jenkins/L0_Test.groovy b/jenkins/L0_Test.groovy
index 15b096d97d22..41bfe694abc7 100644
--- a/jenkins/L0_Test.groovy
+++ b/jenkins/L0_Test.groovy
@@ -3374,7 +3374,7 @@ def launchTestJobs(pipeline, testFilter)
         "GB200-12_GPUs-3_Nodes-PyTorch-Disagg-PerfSanity-CTX1-NODE1-GPU4-GEN1-NODE2-GPU8-Post-Merge",
         "auto:gb200-flex",
         "l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node2_gpu8",
-        3,
+        8,
         12,
         3
     )
@@ -3387,6 +3387,15 @@ def launchTestJobs(pipeline, testFilter)
         16,
         4
     )
+    // 5 Nodes
+    multiNodesSBSAConfigs += buildStageConfigs(
+        "GB200-20_GPUs-5_Nodes-PyTorch-Disagg-PerfSanity-CTX1-NODE1-GPU4-GEN1-NODE4-GPU16-Post-Merge",
+        "auto:gb200-flex",
+        "l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node4_gpu16",
+        1,
+        20,
+        5
+    )
     // 6 Nodes
     multiNodesSBSAConfigs += buildStageConfigs(
         "GB200-24_GPUs-6_Nodes-PyTorch-Disagg-PerfSanity-CTX2-NODE1-GPU4-GEN1-NODE4-GPU16-Post-Merge",
@@ -3396,6 +3405,32 @@ def launchTestJobs(pipeline, testFilter)
         24,
         6
     )
+    multiNodesSBSAConfigs += buildStageConfigs(
+        "GB200-24_GPUs-6_Nodes-PyTorch-Disagg-PerfSanity-CTX1-NODE2-GPU8-GEN1-NODE4-GPU16-Post-Merge",
+        "auto:gb200-flex",
+        "l0_gb200_multi_nodes_perf_sanity_ctx1_node2_gpu8_gen1_node4_gpu16",
+        1,
+        24,
+        6
+    )
+    // 9 Nodes
+    multiNodesSBSAConfigs += buildStageConfigs(
+        "GB200-36_GPUs-9_Nodes-PyTorch-Disagg-PerfSanity-CTX1-NODE1-GPU4-GEN1-NODE8-GPU32-Post-Merge",
+        "auto:gb200-flex",
+        "l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node8_gpu32",
+        7,
+        36,
+        9
+    )
+    // 10 Nodes
+    multiNodesSBSAConfigs += buildStageConfigs(
+        "GB200-40_GPUs-10_Nodes-PyTorch-Disagg-PerfSanity-CTX1-NODE2-GPU8-GEN1-NODE8-GPU32-Post-Merge",
+        "auto:gb200-flex",
+        "l0_gb200_multi_nodes_perf_sanity_ctx1_node2_gpu8_gen1_node8_gpu32",
+        1,
+        40,
+        10
+    )
     fullSet += multiNodesSBSAConfigs.keySet()
 
     if (env.targetArch == AARCH64_TRIPLE) {
diff --git a/tensorrt_llm/_torch/models/modeling_gemma3vl.py b/tensorrt_llm/_torch/models/modeling_gemma3vl.py
index e9737023bc77..a99d62a48b23 100644
--- a/tensorrt_llm/_torch/models/modeling_gemma3vl.py
+++ b/tensorrt_llm/_torch/models/modeling_gemma3vl.py
@@ -261,6 +261,10 @@ def post_config(self):
         self.config = self.llm.config
         self.model_config.pretrained_config = self.llm.config
 
+    @property
+    def vocab_size_padded(self) -> int:
+        return self.llm.vocab_size_padded
+
     def infer_max_seq_len(self) -> int:
         return self.llm.infer_max_seq_len()
 
diff --git a/tensorrt_llm/_torch/models/modeling_hyperclovax.py b/tensorrt_llm/_torch/models/modeling_hyperclovax.py
index 26a103676ae1..0825212300bc 100644
--- a/tensorrt_llm/_torch/models/modeling_hyperclovax.py
+++ b/tensorrt_llm/_torch/models/modeling_hyperclovax.py
@@ -1067,6 +1067,10 @@ def load_weights(self, weights):
         if not DISAGG:
             self.mm_encoder.load_weights(weights)
 
+    @property
+    def vocab_size_padded(self) -> int:
+        return self.llm.vocab_size_padded
+
     def infer_max_seq_len(self) -> int:
         return self.llm.infer_max_seq_len()
 
diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py
index 81b248bd94ea..ccf0ef6592cf 100644
--- a/tensorrt_llm/_torch/models/modeling_llava_next.py
+++ b/tensorrt_llm/_torch/models/modeling_llava_next.py
@@ -664,6 +664,10 @@ def post_config(self):
         self.config = self.llm.config
         self.model_config.pretrained_config = self.llm.config
 
+    @property
+    def vocab_size_padded(self) -> int:
+        return self.llm.vocab_size_padded
+
     def infer_max_seq_len(self) -> int:
         return self.llm.infer_max_seq_len()
 
diff --git a/tensorrt_llm/_torch/models/modeling_mistral.py b/tensorrt_llm/_torch/models/modeling_mistral.py
index d07a82d7db08..573318b41130 100644
--- a/tensorrt_llm/_torch/models/modeling_mistral.py
+++ b/tensorrt_llm/_torch/models/modeling_mistral.py
@@ -666,6 +666,10 @@ def draft_model(self):
     def load_draft_weights(self):
         return self.llm.load_draft_weights
 
+    @property
+    def vocab_size_padded(self) -> int:
+        return self.llm.vocab_size_padded
+
     def infer_max_seq_len(self) -> int:
         return self.llm.infer_max_seq_len()
 
diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_nano.py b/tensorrt_llm/_torch/models/modeling_nemotron_nano.py
index 2737c6d2c0a2..86c1089987c7 100644
--- a/tensorrt_llm/_torch/models/modeling_nemotron_nano.py
+++ b/tensorrt_llm/_torch/models/modeling_nemotron_nano.py
@@ -1285,6 +1285,10 @@ def load_weights(self, weights):
         weight_mapper.init_model_and_config(self.llm, self.model_config)
         self.llm.load_weights(filtered_weights, weight_mapper=weight_mapper)
 
+    @property
+    def vocab_size_padded(self) -> int:
+        return self.llm.vocab_size_padded
+
     def infer_max_seq_len(self) -> int:
         return self.llm.infer_max_seq_len()
 
diff --git a/tensorrt_llm/_torch/models/modeling_phi4mm.py b/tensorrt_llm/_torch/models/modeling_phi4mm.py
index 268ef6ce5f5a..10749350172e 100644
--- a/tensorrt_llm/_torch/models/modeling_phi4mm.py
+++ b/tensorrt_llm/_torch/models/modeling_phi4mm.py
@@ -1030,6 +1030,10 @@ def load_weights(self, weights):
             [_IMAGE_SPECIAL_TOKEN_ID, _AUDIO_SPECIAL_TOKEN_ID],
             device=self.device)
 
+    @property
+    def vocab_size_padded(self) -> int:
+        return self.llm.vocab_size_padded
+
     def infer_max_seq_len(self) -> int:
         return self.llm.infer_max_seq_len()
 
diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py
index 453f28cac417..a48340ab872b 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py
@@ -897,6 +897,10 @@ def init_mrope_embedding(self, model_config: ModelConfig[PretrainedConfig]):
     def load_weights(self, weights, weight_mapper: BaseWeightMapper):
         pass
 
+    @property
+    def vocab_size_padded(self) -> int:
+        return self.llm.vocab_size_padded
+
     def infer_max_seq_len(self) -> int:
         return self.llm.infer_max_seq_len()
 
diff --git a/tensorrt_llm/_torch/models/modeling_qwen3vl.py b/tensorrt_llm/_torch/models/modeling_qwen3vl.py
index 5a418af67299..702930e68458 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen3vl.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen3vl.py
@@ -952,6 +952,10 @@ def post_config(self):
         self.model_config.pretrained_config = self.llm.config
         self.config = self.model_config.pretrained_config
 
+    @property
+    def vocab_size_padded(self) -> int:
+        return self.llm.vocab_size_padded
+
     def infer_max_seq_len(self) -> int:
         return self.llm.infer_max_seq_len()
 
diff --git a/tensorrt_llm/_torch/models/modeling_vila.py b/tensorrt_llm/_torch/models/modeling_vila.py
index 8b634229237d..accd5a98998a 100644
--- a/tensorrt_llm/_torch/models/modeling_vila.py
+++ b/tensorrt_llm/_torch/models/modeling_vila.py
@@ -1243,6 +1243,10 @@ def load_weights(self, weights):
             _resize_token_embeddings(self.llm, len(self.tokenizer))
         self.vocab_size = len(self.tokenizer)
 
+    @property
+    def vocab_size_padded(self) -> int:
+        return self.llm.vocab_size_padded
+
     def infer_max_seq_len(self) -> int:
         return self.llm.infer_max_seq_len()
 
diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node2_gpu8.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node2_gpu8.yml
index d52b57bcdfef..962889097d3d 100644
--- a/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node2_gpu8.yml
+++ b/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node2_gpu8.yml
@@ -17,11 +17,11 @@ l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node2_gpu8:
   - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_1k1k_con1024_ctx1_dep4_gen1_dep8_eplb0_mtp0_ccb-UCX] TIMEOUT (120)
   - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_1k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
   - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_8k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
-  # - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_1k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120) # failed
-  # - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_32k4k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120) # failed
-  # - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_8k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120) # failed
-  # - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_1k1k_con4096_ctx1_dep4_gen1_dep8_eplb0_mtp0_ccb-UCX] TIMEOUT (120) # failed
-  # - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_8k1k_con4_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120) # failed
+  - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_1k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
+  - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_32k4k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
+  - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_8k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
+  - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_1k1k_con4096_ctx1_dep4_gen1_dep8_eplb0_mtp0_ccb-UCX] TIMEOUT (120)
+  - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_8k1k_con4_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
   # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_1k1k_con1024_ctx1_dep4_gen1_dep8_eplb0_mtp0_ccb-UCX] TIMEOUT (120)
   # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_1k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
   # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_8k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node4_gpu16.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node4_gpu16.yml
index 22949684ea05..c5e15ec1e3f1 100644
--- a/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node4_gpu16.yml
+++ b/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node4_gpu16.yml
@@ -15,6 +15,6 @@ l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node4_gpu16:
       backend: pytorch
   tests:
   - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb0_mtp1_ccb-UCX] TIMEOUT (120)
-  - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb384_mtp0_ccb-UCX] TIMEOUT (120)
+  # - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb384_mtp0_ccb-UCX] TIMEOUT (120)
   # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb0_mtp1_ccb-UCX] TIMEOUT (120)
   # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_kimi-k2-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb384_mtp0_ccb-UCX] TIMEOUT (120)
diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node8_gpu32.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node8_gpu32.yml
index b834a05e70ba..f28216f8c761 100644
--- a/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node8_gpu32.yml
+++ b/tests/integration/test_lists/test-db/l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node8_gpu32.yml
@@ -21,8 +21,8 @@ l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node8_gpu32:
   - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_32k4k_con256_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
   - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb256_mtp3_ccb-UCX] TIMEOUT (120)
   - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_8k1k_con4096_ctx1_dep4_gen1_dep32_eplb256_mtp0_ccb-UCX] TIMEOUT (120)
-  - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb384_mtp0_ccb-UCX] TIMEOUT (120)
-  - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb416_mtp3_ccb-UCX] TIMEOUT (120)
+  # - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb384_mtp0_ccb-UCX] TIMEOUT (120)
+  # - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb416_mtp3_ccb-UCX] TIMEOUT (120)
   # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_1k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
   # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
   # - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-v32-fp4_1k1k_con1024_ctx1_dep4_gen1_dep32_eplb256_mtp3_ccb-UCX] TIMEOUT (120)
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index 0cb74ed5e700..201de58dc314 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -329,6 +329,8 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_fp8[throughput_laten
 accuracy/test_llm_api_pytorch.py::TestPhi4::test_auto_dtype SKIP (https://nvbugs/6040098)
 unittest/bindings/test_transfer_agent_bindings.py::TestMooncakeFunctionalTransfer::test_mooncake_wait_in_progress_on_zero_timeout SKIP (https://nvbugs/6043312)
 perf/test_perf.py::test_perf[deepseek_r1_distill_qwen_32b-bench-_autodeploy-float16-input_output_len:1024,1024-reqs:512] SKIP (https://nvbugs/6044213)
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-triton-auto] SKIP (https://nvbugs/6026678)
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-triton-auto] SKIP (https://nvbugs/6026678)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-trtllm-fp8] SKIP (https://nvbugs/6050481)
 examples/test_visual_gen.py::test_vbench_dimension_score_wan22_a14b_nvfp4 SKIP (https://nvbugs/6050483)
 visual_gen/test_visual_gen_benchmark.py::test_online_benchmark[openai-videos] SKIP (https://nvbugs/6050483)
@@ -339,3 +341,10 @@ visual_gen/test_visual_gen_benchmark.py::test_offline_benchmark SKIP (https://nv
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/6050487)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/6050489)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/6050489)
+disaggregated/test_disaggregated.py::test_disaggregated_overlap_gen_first[ctx_pp1-TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6057459)
+disaggregated/test_disaggregated.py::test_disaggregated_overlap_gen_first[ctx_pp4-TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6057460)
+perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_32k4k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] SKIP (https://nvbugs/5844149)
+perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_32k4k_con2048_ctx1_dep4_gen1_dep32_eplb288_mtp1_ccb-UCX] SKIP (https://nvbugs/5844149)
+perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_32k4k_con256_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX] SKIP (https://nvbugs/5844149)
+perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_128k8k_con128_ctx1_pp8_gen1_dep16_eplb0_mtp2_ccb-UCX] SKIP (https://nvbugs/6060119)
+perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_128k8k_con64_ctx1_pp8_gen1_dep32_eplb0_mtp3_ccb-UCX] SKIP (https://nvbugs/6060119)
diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_vlm_guided_decoding.py b/tests/unittest/llmapi/apps/_test_openai_chat_vlm_guided_decoding.py
new file mode 100644
index 000000000000..b0ecd47b7a6b
--- /dev/null
+++ b/tests/unittest/llmapi/apps/_test_openai_chat_vlm_guided_decoding.py
@@ -0,0 +1,113 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Regression test for: VLM wrapper classes missing vocab_size_padded, causing
+# AttributeError at server startup when guided decoding is configured.
+# https://github.com/NVIDIA/TensorRT-LLM/pull/12284
+
+import json
+import os
+import sys
+import tempfile
+
+import jsonschema
+import openai
+import pytest
+import yaml
+from utils.llm_data import llm_models_root
+
+from .openai_server import RemoteOpenAIServer
+
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+from test_llm import get_model_path
+
+pytestmark = pytest.mark.threadleak(enabled=False)
+
+_MODEL_NAME = "Qwen3/Qwen3-VL-8B-Instruct"
+_IMAGE_URL = str(llm_models_root() / "multimodals" / "test_data" / "seashore.png")
+
+_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "subject": {"type": "string", "description": "The main subject visible in the image."},
+        "setting": {"type": "string", "description": "The setting or environment of the image."},
+    },
+    "required": ["subject", "setting"],
+    "additionalProperties": False,
+}
+
+
+@pytest.fixture(scope="module")
+def temp_extra_llm_api_options_file():
+    temp_dir = tempfile.gettempdir()
+    temp_file_path = os.path.join(temp_dir, "extra_llm_api_options_vlm_guided.yaml")
+    try:
+        extra_llm_api_options_dict = {
+            "guided_decoding_backend": "xgrammar",
+            "kv_cache_config": {
+                "free_gpu_memory_fraction": 0.8,
+            },
+            "max_num_tokens": 4096,
+        }
+        with open(temp_file_path, "w") as f:
+            yaml.dump(extra_llm_api_options_dict, f)
+        yield temp_file_path
+    finally:
+        if os.path.exists(temp_file_path):
+            os.remove(temp_file_path)
+
+
+@pytest.fixture(scope="module")
+def server(temp_extra_llm_api_options_file: str):
+    model_path = get_model_path(_MODEL_NAME)
+    args = ["--extra_llm_api_options", temp_extra_llm_api_options_file]
+    with RemoteOpenAIServer(model_path, cli_args=args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server: RemoteOpenAIServer):
+    return server.get_client()
+
+
+def test_vlm_guided_decoding_json_schema(client: openai.OpenAI):
+    response = client.chat.completions.create(
+        model=_MODEL_NAME,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image_url", "image_url": {"url": _IMAGE_URL}},
+                    {"type": "text", "text": "Describe the main subject of this image."},
+                ],
+            }
+        ],
+        max_completion_tokens=256,
+        temperature=0.0,
+        response_format={
+            "type": "json_schema",
+            "json_schema": {
+                "name": "image_description",
+                "strict": True,
+                "schema": _SCHEMA,
+            },
+        },
+    )
+
+    content = response.choices[0].message.content
+    assert content is not None
+
+    parsed = json.loads(content)
+    jsonschema.validate(parsed, _SCHEMA)