Skip to content

Commit 061f648

Browse files
committed
Update trtllm-gen batched GEMM artifact path & checksum, update csrc/trtllm_batched_gemm_runner.cu access to BatchedGemmOptions.mNumStages as it was split to A and B
Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
1 parent 37c7620 commit 061f648

2 files changed

Lines changed: 4 additions & 4 deletions

File tree

csrc/trtllm_batched_gemm_runner.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ std::vector<int64_t> prioritizePredefinedConfigs(
6868
if (n /* out_dim */ == 0 && k /* in_dim */ == 0) {
6969
auto pred = [](BatchedGemmConfig const& config) {
7070
BatchedGemmOptions const& options = config.mOptions;
71-
return options.mNumStages == 4 && options.mNumStagesMma == 2 && options.mTileK == 256 &&
72-
options.mTileScheduler == TileScheduler::Persistent;
71+
return options.mNumStagesA == 4 && options.mNumStagesB == 4 && options.mNumStagesMma == 2 &&
72+
options.mTileK == 256 && options.mTileScheduler == TileScheduler::Persistent;
7373
};
7474
prioritizedIndices = bubbleUpConfig(sortedIndices, pred);
7575
}

flashinfer/artifacts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ class ArtifactPath:
137137

138138
TRTLLM_GEN_FMHA: str = "e7afc4134bb53eaab63fb85163d5943fb190621c/fmha/trtllm-gen/"
139139
TRTLLM_GEN_BMM: str = (
140-
"b55211623be7f5697c5262ffd8361fc06c147bc9/batched_gemm-b3c1646-c111d7c/"
140+
"39a9d28268f43475a757d5700af135e1e58c9849/batched_gemm-5ee61af-2b9855b/"
141141
)
142142
TRTLLM_GEN_GEMM: str = (
143143
"b117d5a6b2dd2228aa966a938eac398cf336d8c0/gemm-b3c1646-1fddea2/"
@@ -158,7 +158,7 @@ class CheckSumHash:
158158
"5bd87798e560a63e883902fc5468146ffff0d3551bf337d2f81bd02893e9dc39"
159159
)
160160
TRTLLM_GEN_BMM: str = (
161-
"0af823880730c4f0b3832d2208fab035946694b83444410b9309db5613d60195"
161+
"db06db7f36a2a9395a2041ff6ac016fe664874074413a2ed90797f91ef17e0f6"
162162
)
163163
DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
164164
TRTLLM_GEN_GEMM: str = (

0 commit comments

Comments
 (0)