Skip to content

Commit 828d272

Browse files
committed
update missing args
1 parent 5591429 commit 828d272

2 files changed

Lines changed: 5 additions & 4 deletions

File tree

csrc/trtllm_batched_gemm_runner.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,9 @@ void TrtllmGenBatchedGemmRunner::run(
258258
// FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere
259259
bmm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));
260260

261-
auto const err = bmm.run(config, workspace, gemmData, static_cast<void*>(stream),
262-
multiProcessorCount, enable_pdl, globalTrtllmGenBatchedGemmModuleCache);
261+
auto const err =
262+
bmm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount,
263+
enable_pdl, nullptr, globalTrtllmGenBatchedGemmModuleCache);
263264

264265
FLASHINFER_CHECK(err == 0,
265266
"Error occurred when running GEMM!"

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,8 +1443,8 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher {
14431443
static_cast<int*>(num_tokens_per_expert.data_ptr()),
14441444
static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()),
14451445
static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
1446-
static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype,
1447-
use_routing_scales_on_input, use_deep_seek_fp8,
1446+
static_cast<int*>(num_non_exiting_ctas.data_ptr()), mDtypeScore, args->mDtypeElt,
1447+
mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8,
14481448
static_cast<RoutingMethodType>(routing_method_type), routing_stream);
14491449

14501450
check_moe();

0 commit comments

Comments
 (0)