Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4423,6 +4423,89 @@ TEST_F(KVCacheManagerTest, PinAndUnpinBlocksById)
EXPECT_EQ(freeAfterUnpin, totalBlocks);
}

// Regression test for NVBug 6018647: storeBlocks(pin=true) on a zero-ref block
// that sits in the eviction free queue must call claimBlock() before incRefCount().
// Without the fix, unpinBlocksById inserts the block into the free queue a second
// time, creating a ghost entry that inflates the free count and can cause hangs.
TEST_F(KVCacheManagerTest, StoreBlocksForReuseWithPinDoesNotCreateGhostFreeBlocks)
{
using namespace tensorrt_llm::batch_manager::kv_cache_manager;
auto constexpr numLayers = 2;
auto constexpr numKvHeads = 2;
auto constexpr sizePerHead = 16;
auto constexpr tokensPerBlock = 4;
auto constexpr blocksInPrimaryPool = 6;
auto constexpr blocksInSecondaryPool = 0;
auto constexpr maxNumSequences = 8;
auto const stream = std::make_shared<tr::CudaStream>();
auto constexpr onboardBlocks = true;
auto constexpr beamWidth = 1;
auto const maxAttentionWindow = tokensPerBlock * blocksInPrimaryPool;

BlocksPerWindow const blocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}};

KVCacheManager kvCacheManager(numLayers, numKvHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
beamWidth, std::vector<BlockManager::SizeType32>{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF,
0, stream, maxAttentionWindow, true /* enableBlockReuse */, onboardBlocks);
kvCacheManager.allocatePools(false);

auto const totalBlocks = kvCacheManager.getMaxNumBlocks();

// 8 tokens = 2 blocks (tokensPerBlock=4).
auto inputTokens = std::make_shared<VecTokens>(VecTokens{0, 1, 2, 3, 4, 5, 6, 7});
tr::SamplingConfig const samplingConfig{beamWidth};
bool constexpr isStreaming{false};

// Step 1: Add seq A (requestId=0). Tree is empty, no reuse.
LlmRequest::RequestIdType requestIdA{0};
auto llmRequestA = std::make_shared<LlmRequest>(requestIdA, 0, inputTokens, samplingConfig, isStreaming);
kvCacheManager.addSequence(requestIdA, static_cast<SizeType32>(inputTokens->size()), beamWidth, llmRequestA);

// Step 2: Add seq B (requestId=1) with same tokens. Tree still empty, allocates different blocks.
LlmRequest::RequestIdType requestIdB{1};
auto llmRequestB = std::make_shared<LlmRequest>(requestIdB, 0, inputTokens, samplingConfig, isStreaming);
kvCacheManager.addSequence(requestIdB, static_cast<SizeType32>(inputTokens->size()), beamWidth, llmRequestB);

// Both sequences allocated, 4 blocks consumed.
auto const freeAfterBothAlloc = kvCacheManager.getNumFreeBlocks();
EXPECT_EQ(freeAfterBothAlloc, totalBlocks - 4);

// Step 3-4: Simulate prefill completion for both.
tensorrt_llm::testing::KvCacheManagerTestUtil::simulatePrefillCompletion(*llmRequestA);
tensorrt_llm::testing::KvCacheManagerTestUtil::simulatePrefillCompletion(*llmRequestB);

// Step 5: Store A's blocks in the radix tree.
kvCacheManager.storeContextBlocks(*llmRequestA);

// Step 6: Remove seq A. Its blocks are stored in tree, refCount -> 0, released to free queue.
(void) kvCacheManager.removeSequence(requestIdA, llmRequestA);
auto const freeAfterRemoveA = kvCacheManager.getNumFreeBlocks();
// A's 2 blocks + the 2 that were already free = totalBlocks - 2 (B's blocks).
EXPECT_EQ(freeAfterRemoveA, totalBlocks - 2);

// Step 7: storeBlocksForReuse with pin=true on seq B.
// storeBlocks finds A's tree blocks (refCount=0, in free queue) as matches and pins them.
// Without the fix: incRefCount alone, block stays in free queue -> ghost on unpin.
// With the fix: claimBlock first, block removed from free queue -> correct lifecycle.
auto pinnedBlockIds = kvCacheManager.storeBlocksForReuse(requestIdB, llmRequestB, /*pinBlocks=*/true);
EXPECT_FALSE(pinnedBlockIds.empty());

// Step 8: Unpin the blocks.
kvCacheManager.unpinBlocksById(pinnedBlockIds);
auto const freeAfterUnpin = kvCacheManager.getNumFreeBlocks();
// A's blocks should be in the free queue exactly once. B's 2 blocks still allocated.
// With the bug, ghost entries would inflate this beyond (totalBlocks - 2).
EXPECT_EQ(freeAfterUnpin, totalBlocks - 2);
EXPECT_LE(freeAfterUnpin, totalBlocks);

// Step 9: Remove seq B. All blocks should now be free.
(void) kvCacheManager.removeSequence(requestIdB, llmRequestB);
auto const freeAfterAll = kvCacheManager.getNumFreeBlocks();
EXPECT_EQ(freeAfterAll, totalBlocks);
// Ghost entries would make free count exceed total blocks.
EXPECT_LE(freeAfterAll, totalBlocks);
}

TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamBlocking)
{
auto constexpr numLayers = 12;
Expand Down
37 changes: 36 additions & 1 deletion jenkins/L0_Test.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -3374,7 +3374,7 @@ def launchTestJobs(pipeline, testFilter)
"GB200-12_GPUs-3_Nodes-PyTorch-Disagg-PerfSanity-CTX1-NODE1-GPU4-GEN1-NODE2-GPU8-Post-Merge",
"auto:gb200-flex",
"l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node2_gpu8",
3,
8,
12,
3
)
Expand All @@ -3387,6 +3387,15 @@ def launchTestJobs(pipeline, testFilter)
16,
4
)
// 5 Nodes
multiNodesSBSAConfigs += buildStageConfigs(
"GB200-20_GPUs-5_Nodes-PyTorch-Disagg-PerfSanity-CTX1-NODE1-GPU4-GEN1-NODE4-GPU16-Post-Merge",
"auto:gb200-flex",
"l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node4_gpu16",
1,
20,
5
)
// 6 Nodes
multiNodesSBSAConfigs += buildStageConfigs(
"GB200-24_GPUs-6_Nodes-PyTorch-Disagg-PerfSanity-CTX2-NODE1-GPU4-GEN1-NODE4-GPU16-Post-Merge",
Expand All @@ -3396,6 +3405,32 @@ def launchTestJobs(pipeline, testFilter)
24,
6
)
multiNodesSBSAConfigs += buildStageConfigs(
"GB200-24_GPUs-6_Nodes-PyTorch-Disagg-PerfSanity-CTX1-NODE2-GPU8-GEN1-NODE4-GPU16-Post-Merge",
"auto:gb200-flex",
"l0_gb200_multi_nodes_perf_sanity_ctx1_node2_gpu8_gen1_node4_gpu16",
1,
24,
6
)
// 9 Nodes
multiNodesSBSAConfigs += buildStageConfigs(
"GB200-36_GPUs-9_Nodes-PyTorch-Disagg-PerfSanity-CTX1-NODE1-GPU4-GEN1-NODE8-GPU32-Post-Merge",
"auto:gb200-flex",
"l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node8_gpu32",
7,
36,
9
)
// 10 Nodes
multiNodesSBSAConfigs += buildStageConfigs(
"GB200-40_GPUs-10_Nodes-PyTorch-Disagg-PerfSanity-CTX1-NODE2-GPU8-GEN1-NODE8-GPU32-Post-Merge",
"auto:gb200-flex",
"l0_gb200_multi_nodes_perf_sanity_ctx1_node2_gpu8_gen1_node8_gpu32",
1,
40,
10
)
fullSet += multiNodesSBSAConfigs.keySet()

if (env.targetArch == AARCH64_TRIPLE) {
Expand Down
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/models/modeling_gemma3vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,10 @@ def post_config(self):
self.config = self.llm.config
self.model_config.pretrained_config = self.llm.config

@property
def vocab_size_padded(self) -> int:
return self.llm.vocab_size_padded

def infer_max_seq_len(self) -> int:
return self.llm.infer_max_seq_len()

Expand Down
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/models/modeling_hyperclovax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,10 @@ def load_weights(self, weights):
if not DISAGG:
self.mm_encoder.load_weights(weights)

@property
def vocab_size_padded(self) -> int:
return self.llm.vocab_size_padded

def infer_max_seq_len(self) -> int:
return self.llm.infer_max_seq_len()

Expand Down
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/models/modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,10 @@ def post_config(self):
self.config = self.llm.config
self.model_config.pretrained_config = self.llm.config

@property
def vocab_size_padded(self) -> int:
return self.llm.vocab_size_padded

def infer_max_seq_len(self) -> int:
return self.llm.infer_max_seq_len()

Expand Down
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/models/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,10 @@ def draft_model(self):
def load_draft_weights(self):
return self.llm.load_draft_weights

@property
def vocab_size_padded(self) -> int:
return self.llm.vocab_size_padded

def infer_max_seq_len(self) -> int:
return self.llm.infer_max_seq_len()

Expand Down
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/models/modeling_nemotron_nano.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,10 @@ def load_weights(self, weights):
weight_mapper.init_model_and_config(self.llm, self.model_config)
self.llm.load_weights(filtered_weights, weight_mapper=weight_mapper)

@property
def vocab_size_padded(self) -> int:
return self.llm.vocab_size_padded

def infer_max_seq_len(self) -> int:
return self.llm.infer_max_seq_len()

Expand Down
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/models/modeling_phi4mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,10 @@ def load_weights(self, weights):
[_IMAGE_SPECIAL_TOKEN_ID, _AUDIO_SPECIAL_TOKEN_ID],
device=self.device)

@property
def vocab_size_padded(self) -> int:
return self.llm.vocab_size_padded

def infer_max_seq_len(self) -> int:
return self.llm.infer_max_seq_len()

Expand Down
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/models/modeling_qwen2vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,10 @@ def init_mrope_embedding(self, model_config: ModelConfig[PretrainedConfig]):
def load_weights(self, weights, weight_mapper: BaseWeightMapper):
pass

@property
def vocab_size_padded(self) -> int:
return self.llm.vocab_size_padded

def infer_max_seq_len(self) -> int:
return self.llm.infer_max_seq_len()

Expand Down
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/models/modeling_qwen3vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,10 @@ def post_config(self):
self.model_config.pretrained_config = self.llm.config
self.config = self.model_config.pretrained_config

@property
def vocab_size_padded(self) -> int:
return self.llm.vocab_size_padded

def infer_max_seq_len(self) -> int:
return self.llm.infer_max_seq_len()

Expand Down
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/models/modeling_vila.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,6 +1243,10 @@ def load_weights(self, weights):
_resize_token_embeddings(self.llm, len(self.tokenizer))
self.vocab_size = len(self.tokenizer)

@property
def vocab_size_padded(self) -> int:
return self.llm.vocab_size_padded

def infer_max_seq_len(self) -> int:
return self.llm.infer_max_seq_len()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node2_gpu8:
- perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_1k1k_con1024_ctx1_dep4_gen1_dep8_eplb0_mtp0_ccb-UCX] TIMEOUT (120)
- perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_1k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
- perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_8k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
# - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_1k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120) # failed
# - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_32k4k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120) # failed
# - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_8k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120) # failed
# - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_1k1k_con4096_ctx1_dep4_gen1_dep8_eplb0_mtp0_ccb-UCX] TIMEOUT (120) # failed
# - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_8k1k_con4_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120) # failed
- perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_1k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
- perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_32k4k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
- perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_8k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
- perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_1k1k_con4096_ctx1_dep4_gen1_dep8_eplb0_mtp0_ccb-UCX] TIMEOUT (120)
- perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_8k1k_con4_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
# - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_1k1k_con1024_ctx1_dep4_gen1_dep8_eplb0_mtp0_ccb-UCX] TIMEOUT (120)
# - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_1k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
# - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_8k1k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node4_gpu16:
backend: pytorch
tests:
- perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb0_mtp1_ccb-UCX] TIMEOUT (120)
- perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb384_mtp0_ccb-UCX] TIMEOUT (120)
# - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb384_mtp0_ccb-UCX] TIMEOUT (120)
# - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb0_mtp1_ccb-UCX] TIMEOUT (120)
# - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_kimi-k2-thinking-fp4_8k1k_con4096_ctx1_dep4_gen1_dep16_eplb384_mtp0_ccb-UCX] TIMEOUT (120)
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ l0_gb200_multi_nodes_perf_sanity_ctx1_node1_gpu4_gen1_node8_gpu32:
- perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_32k4k_con256_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
- perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb256_mtp3_ccb-UCX] TIMEOUT (120)
- perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_8k1k_con4096_ctx1_dep4_gen1_dep32_eplb256_mtp0_ccb-UCX] TIMEOUT (120)
- perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb384_mtp0_ccb-UCX] TIMEOUT (120)
- perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb416_mtp3_ccb-UCX] TIMEOUT (120)
# - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_1k1k_con2048_ctx1_dep4_gen1_dep32_eplb384_mtp0_ccb-UCX] TIMEOUT (120)
# - perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_kimi-k2-thinking-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb416_mtp3_ccb-UCX] TIMEOUT (120)
# - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_1k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
# - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-r1-fp4_8k1k_con1024_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX] TIMEOUT (120)
# - perf/test_perf_sanity.py::test_e2e[disagg_upload-e2e-gb200_deepseek-v32-fp4_1k1k_con1024_ctx1_dep4_gen1_dep32_eplb256_mtp3_ccb-UCX] TIMEOUT (120)
Expand Down
9 changes: 9 additions & 0 deletions tests/integration/test_lists/waives.txt
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,8 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_fp8[throughput_laten
accuracy/test_llm_api_pytorch.py::TestPhi4::test_auto_dtype SKIP (https://nvbugs/6040098)
unittest/bindings/test_transfer_agent_bindings.py::TestMooncakeFunctionalTransfer::test_mooncake_wait_in_progress_on_zero_timeout SKIP (https://nvbugs/6043312)
perf/test_perf.py::test_perf[deepseek_r1_distill_qwen_32b-bench-_autodeploy-float16-input_output_len:1024,1024-reqs:512] SKIP (https://nvbugs/6044213)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-triton-auto] SKIP (https://nvbugs/6026678)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-triton-auto] SKIP (https://nvbugs/6026678)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-trtllm-fp8] SKIP (https://nvbugs/6050481)
examples/test_visual_gen.py::test_vbench_dimension_score_wan22_a14b_nvfp4 SKIP (https://nvbugs/6050483)
visual_gen/test_visual_gen_benchmark.py::test_online_benchmark[openai-videos] SKIP (https://nvbugs/6050483)
Expand All @@ -339,3 +341,10 @@ visual_gen/test_visual_gen_benchmark.py::test_offline_benchmark SKIP (https://nv
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/6050487)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/6050489)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/6050489)
disaggregated/test_disaggregated.py::test_disaggregated_overlap_gen_first[ctx_pp1-TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6057459)
disaggregated/test_disaggregated.py::test_disaggregated_overlap_gen_first[ctx_pp4-TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6057460)
perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_32k4k_con1_ctx1_dep4_gen1_tep8_eplb0_mtp3_ccb-UCX] SKIP (https://nvbugs/5844149)
perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_32k4k_con2048_ctx1_dep4_gen1_dep32_eplb288_mtp1_ccb-UCX] SKIP (https://nvbugs/5844149)
perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_32k4k_con256_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX] SKIP (https://nvbugs/5844149)
perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_128k8k_con128_ctx1_pp8_gen1_dep16_eplb0_mtp2_ccb-UCX] SKIP (https://nvbugs/6060119)
perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_128k8k_con64_ctx1_pp8_gen1_dep32_eplb0_mtp3_ccb-UCX] SKIP (https://nvbugs/6060119)
Loading
Loading