Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 39 additions & 14 deletions csrc/fused_moe/noAuxTcKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx
int64_t const numGroup, int64_t const topkGroup,
int64_t const topk, int64_t const numExperts,
int64_t const numExpertsPerGroup,
double const routedScalingFactor) {
double const routedScalingFactor,
int16_t* routingReplayOut) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
Expand Down Expand Up @@ -213,6 +214,12 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx
topkValues[laneIdx] = static_cast<OutputT>(finalScore);
topkIndices[laneIdx] = expertIdx;
}

// Routing replay: record all top-K selected expert IDs per token.
// Layout: [num_tokens, topk] β€” same indexing as topkIndices.
if (laneIdx < topk && routingReplayOut != nullptr) {
routingReplayOut[blockIdx.x * topk + laneIdx] = static_cast<int16_t>(expertIdx);
}
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
Expand All @@ -224,7 +231,8 @@ template <typename InputT, typename BiasT, typename OutputT, typename IdxT>
void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk_indices,
int64_t const num_tokens, int64_t const num_experts, int64_t const n_group,
int64_t const topk_group, int64_t const topk, double const routed_scaling_factor,
bool const launch_with_pdl, cudaStream_t const stream) {
bool const launch_with_pdl, cudaStream_t const stream,
int16_t* routing_replay_out) {
// Check if we can use the optimized deepseek_v3_topk_kernel
bool const is_single_group = (n_group == 1) && (num_experts <= NumKimiK2Experts);

Expand Down Expand Up @@ -262,7 +270,7 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk

cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values, topk_indices, bias,
num_tokens, n_group, topk_group, topk, num_experts, num_experts / n_group,
routed_scaling_factor);
routed_scaling_factor, routing_replay_out);
sync_check_cuda_error(stream);
} else {
// TODO: call the generic path (previous implementation) or signal unsupported config.
Expand All @@ -279,7 +287,7 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk
InputT * scores, BiasT * bias, OutputT * topk_values, IdxT * topk_indices, \
int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, \
int64_t const topk_group, int64_t const topk, double const routed_scaling_factor, \
bool const launch_with_pdl, cudaStream_t const stream);
bool const launch_with_pdl, cudaStream_t const stream, int16_t* routing_replay_out);

INSTANTIATE_NOAUX_TC(float, float, float, int32_t);
INSTANTIATE_NOAUX_TC(float, half, float, int32_t);
Expand All @@ -305,7 +313,7 @@ namespace flashinfer::trtllm_dsv3_fused_routing {

void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_group, int64_t topk,
double routed_scaling_factor, TensorView topk_values, TensorView topk_indices,
bool launch_with_pdl) {
bool launch_with_pdl, Optional<TensorView> routing_replay_out) {
auto data_type = scores.dtype();
auto bias_type = bias.dtype();

Expand Down Expand Up @@ -342,6 +350,23 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g
TVM_FFI_ICHECK(encode_dlpack_dtype(topk_indices.dtype()) == int32_code)
<< "topk_indices must have the same dtype as scores";

int16_t* replay_ptr = nullptr;
if (routing_replay_out.has_value()) {
auto replay = routing_replay_out.value();
TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA)
<< "routing_replay_out must be a CUDA tensor";
TVM_FFI_ICHECK(replay.device().device_id == scores.device().device_id)
<< "routing_replay_out must be on the same device as scores";
TVM_FFI_ICHECK(replay.dim() == 2)
<< "routing_replay_out must be a 2D Tensor [num_tokens, topk]";
TVM_FFI_ICHECK(replay.sizes()[0] == num_tokens)
<< "routing_replay_out dim0 must equal num_tokens";
TVM_FFI_ICHECK(replay.sizes()[1] == topk) << "routing_replay_out dim1 must equal topk";
TVM_FFI_ICHECK(encode_dlpack_dtype(replay.dtype()) == int16_code)
<< "routing_replay_out must be int16 dtype";
replay_ptr = reinterpret_cast<int16_t*>(replay.data_ptr());
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

auto stream = get_stream(scores.device());
using namespace tensorrt_llm::kernels;
switch (encode_dlpack_dtype(data_type)) {
Expand All @@ -353,22 +378,22 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g
reinterpret_cast<half*>(scores.data_ptr()), reinterpret_cast<half*>(bias.data_ptr()),
reinterpret_cast<half*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
case float32_code:
invokeNoAuxTc<half, float, half, int32_t>(
reinterpret_cast<half*>(scores.data_ptr()), reinterpret_cast<float*>(bias.data_ptr()),
reinterpret_cast<half*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
case bfloat16_code:
invokeNoAuxTc<half, __nv_bfloat16, half, int32_t>(
reinterpret_cast<half*>(scores.data_ptr()),
reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()),
reinterpret_cast<half*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
default:
throw std::invalid_argument(
Expand All @@ -384,22 +409,22 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g
reinterpret_cast<float*>(bias.data_ptr()),
reinterpret_cast<float*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
case float16_code:
invokeNoAuxTc<float, half, float, int32_t>(
reinterpret_cast<float*>(scores.data_ptr()), reinterpret_cast<half*>(bias.data_ptr()),
reinterpret_cast<float*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
case bfloat16_code:
invokeNoAuxTc<float, __nv_bfloat16, float, int32_t>(
reinterpret_cast<float*>(scores.data_ptr()),
reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()),
reinterpret_cast<float*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
default:
throw std::invalid_argument(
Expand All @@ -416,23 +441,23 @@ void NoAuxTc(TensorView scores, TensorView bias, int64_t n_group, int64_t topk_g
reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()),
reinterpret_cast<__nv_bfloat16*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
case float16_code:
invokeNoAuxTc<__nv_bfloat16, half, __nv_bfloat16, int32_t>(
reinterpret_cast<__nv_bfloat16*>(scores.data_ptr()),
reinterpret_cast<half*>(bias.data_ptr()),
reinterpret_cast<__nv_bfloat16*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
case float32_code:
invokeNoAuxTc<__nv_bfloat16, float, __nv_bfloat16, int32_t>(
reinterpret_cast<__nv_bfloat16*>(scores.data_ptr()),
reinterpret_cast<float*>(bias.data_ptr()),
reinterpret_cast<__nv_bfloat16*>(topk_values.data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.data_ptr()), num_tokens, num_experts, n_group,
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream);
topk_group, topk, routed_scaling_factor, launch_with_pdl, stream, replay_ptr);
break;
default:
throw std::invalid_argument(
Expand Down
45 changes: 41 additions & 4 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ class FusedMoeLauncher {

int64_t intermediate_size_factor{2};

// Optional routing replay output: [num_tokens, top_k] int16 tensor
Optional<TensorView> routing_replay_out;

public:
// Constructor that initializes all TensorView members
FusedMoeLauncher(const Optional<TensorView>& routing_logits,
Expand All @@ -160,6 +163,10 @@ class FusedMoeLauncher {
activation_type{ActivationType::Swiglu},
intermediate_size_factor{2} {}

void set_routing_replay_out(const Optional<TensorView>& replay_out) {
routing_replay_out = replay_out;
}

protected:
// Initialize common data necessary for later.
// May throw exception from TVM_FFI_ICHECK.
Expand Down Expand Up @@ -375,6 +382,11 @@ class FusedMoeLauncher {
tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
cudaStream_t routing_stream = get_stream(hidden_states.device());

int16_t* replay_ptr = nullptr;
if (routing_replay_out.has_value()) {
replay_ptr = reinterpret_cast<int16_t*>(routing_replay_out.value().data_ptr());
}

routing_runner.run(
args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k,
args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts,
Expand All @@ -389,7 +401,7 @@ class FusedMoeLauncher {
static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype,
use_routing_scales_on_input, use_deep_seek_fp8,
static_cast<RoutingMethodType>(routing_method_type), routing_stream);
static_cast<RoutingMethodType>(routing_method_type), routing_stream, replay_ptr);

check_moe();
prepare_moe(moe_tactic);
Expand Down Expand Up @@ -1076,6 +1088,11 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
cudaStream_t routing_stream = get_stream(hidden_states.device());
tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);

int16_t* replay_ptr = nullptr;
if (routing_replay_out.has_value()) {
replay_ptr = reinterpret_cast<int16_t*>(routing_replay_out.value().data_ptr());
}

// Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr
bool use_precomputed = expert_indices.ndim() == 2 && expert_indices.size(0) > 0;
// When using pre-computed routing, pass nullptr as routing_logits to tell the
Expand All @@ -1094,7 +1111,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype,
use_routing_scales_on_input, use_deep_seek_fp8,
static_cast<RoutingMethodType>(routing_method_type), routing_stream);
static_cast<RoutingMethodType>(routing_method_type), routing_stream, replay_ptr);

check_moe();
prepare_moe(moe_tactic);
Expand Down Expand Up @@ -1543,6 +1560,11 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
cudaStream_t routing_stream = get_stream(hidden_states.device());

int16_t* replay_ptr = nullptr;
if (routing_replay_out.has_value()) {
replay_ptr = reinterpret_cast<int16_t*>(routing_replay_out.value().data_ptr());
}

routing_runner.run(
args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k,
args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts,
Expand All @@ -1557,7 +1579,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype,
use_routing_scales_on_input, use_deep_seek_fp8,
static_cast<RoutingMethodType>(routing_method_type), routing_stream);
static_cast<RoutingMethodType>(routing_method_type), routing_stream, replay_ptr);

check_moe();
prepare_moe(moe_tactic);
Expand Down Expand Up @@ -1782,7 +1804,8 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor,
int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize,
bool enable_pdl, Array<int64_t> config_index, Fp8QuantizationType quantization_type) {
bool enable_pdl, Array<int64_t> config_index, Fp8QuantizationType quantization_type,
Optional<TensorView> routing_replay_out) {
Comment thread
TomerBN-Nvidia marked this conversation as resolved.
// Basic type validation
auto dtype = hidden_states.dtype();

Expand Down Expand Up @@ -1843,6 +1866,19 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
auto const num_tokens = hidden_states.size(0);
auto const hidden_size = hidden_states.size(1);

if (routing_replay_out.has_value()) {
auto replay = routing_replay_out.value();
TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA)
<< "routing_replay_out must be a CUDA tensor";
TVM_FFI_ICHECK(replay.device().device_id == hidden_states.device().device_id)
<< "routing_replay_out must be on the same device as hidden_states";
TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]";
TVM_FFI_ICHECK(replay.size(0) == num_tokens) << "routing_replay_out dim0 must equal num_tokens";
TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k";
TVM_FFI_ICHECK(encode_dlpack_dtype(replay.dtype()) == int16_code)
<< "routing_replay_out must be int16 dtype";
}
Comment on lines +1869 to +1880
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fail fast when replay is requested for unsupported routing methods.

routing_replay_out is validated for shape/dtype/device, but there is no guard that it is only used with RoutingMethodType::DeepSeekV3. Today this can silently accept replay buffers for unsupported routing methods and produce no replay writes.

πŸ”§ Suggested guard
   if (routing_replay_out.has_value()) {
+    TVM_FFI_ICHECK_EQ(static_cast<RoutingMethodType>(routing_method_type),
+                      RoutingMethodType::DeepSeekV3)
+        << "routing_replay_out is currently supported only for DeepSeekV3 routing";
     auto replay = routing_replay_out.value();
     TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA)
         << "routing_replay_out must be a CUDA tensor";

Based on learnings: In csrc/trtllm_fused_moe_runner.cu, routingReplayOut is intentionally wired only into the DeepSeekV3 routing path.

πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1869 - 1880, The
validation for routing_replay_out currently checks shape/dtype/device but does
not ensure it's only used with the DeepSeekV3 routing implementation; add a
fail-fast guard that checks the active routing method (compare against
RoutingMethodType::DeepSeekV3) before accepting routing_replay_out and raise a
clear error (TVM_FFI_ICHECK or equivalent) if routing_replay_out.has_value()
while the routing method is not DeepSeekV3 so we don't silently accept
unsupported replay buffers; place this check adjacent to the existing
routing_replay_out block (same scope) referencing routing_replay_out and
RoutingMethodType::DeepSeekV3.


auto supported_tile_nums = Fp8BlockScaleLauncher::getSupportedTileNums(quantization_type);
std::set<int32_t> selected_tile_nums =
computeSelectedTileN(supported_tile_nums, num_tokens, top_k, local_num_experts);
Expand Down Expand Up @@ -1875,6 +1911,7 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
quantization_type);
launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight,
weight_layout);
launcher->set_routing_replay_out(routing_replay_out);

launchers_map[curr_tile_N] = std::move(launcher);
}
Expand Down
6 changes: 6 additions & 0 deletions csrc/trtllm_fused_moe_routing_deepseek.cu
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,12 @@ __global__ void routingMainKernel(KernelParams params) {
params.mPtrTopKIds == nullptr) {
params.mPtrTopKWeights[idxTopK] = finalScore;
}

// Routing replay: record all top-K selected expert IDs per token.
// Layout: [num_tokens, topK] β€” same indexing as mPtrTopKPacked.
if (laneIdx < params.mTopK && params.mPtrRoutingReplayOut != nullptr) {
params.mPtrRoutingReplayOut[idxTopK] = static_cast<int16_t>(expertIdx);
}
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion csrc/trtllm_fused_moe_runner.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit,
int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias,
bool useRoutingScalesOnInput, bool useDeepSeekFp8,
RoutingMethodType routingMethodType, cudaStream_t stream) {
RoutingMethodType routingMethodType, cudaStream_t stream,
int16_t* routingReplayOut) {
if (routingMethodType == RoutingMethodType::DeepSeekV3) {
Comment thread
TomerBN-Nvidia marked this conversation as resolved.
FLASHINFER_CHECK(topK <= 22, "For DeepSeek routing method, must have topK <= 22");
FLASHINFER_CHECK(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup <= 4");
Expand Down Expand Up @@ -98,6 +99,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mNumLocalExperts = localNumExperts;
routingData.mRouteScale = routedScalingFactor;
routingData.mUseRoutingSoftmax = false;
routingData.mPtrRoutingReplayOut = routingReplayOut;
moe::dev::routing::routingDeepSeek::run(routingData, stream);
} else if (routingMethodType == RoutingMethodType::Llama4) {
FLASHINFER_CHECK(topK == 1, "For Llama routing method, must have topK == 1");
Expand Down
1 change: 1 addition & 0 deletions csrc/tvm_ffi_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ constexpr int64_t float16_code = encode_dlpack_dtype(dl_float16);
constexpr int64_t bfloat16_code = encode_dlpack_dtype(dl_bfloat16);
constexpr int64_t float32_code = encode_dlpack_dtype(dl_float32);
constexpr int64_t uint8_code = encode_dlpack_dtype(dl_uint8);
constexpr int64_t int16_code = encode_dlpack_dtype(dl_int16);
constexpr int64_t int32_code = encode_dlpack_dtype(dl_int32);
constexpr int64_t int64_code = encode_dlpack_dtype(dl_int64);
constexpr int64_t float8_e4m3fn_code = encode_dlpack_dtype(dl_float8_e4m3fn);
Expand Down
Loading