-
Notifications
You must be signed in to change notification settings - Fork 937
feat: support routing replay in trtllm_fp8_block_scale_moe and fused_topk_deepseek #2685
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -134,6 +134,9 @@ class FusedMoeLauncher { | |
|
|
||
| int64_t intermediate_size_factor{2}; | ||
|
|
||
| // Optional routing replay output: [num_tokens, top_k] int16 tensor | ||
| Optional<TensorView> routing_replay_out; | ||
|
|
||
| public: | ||
| // Constructor that initializes all TensorView members | ||
| FusedMoeLauncher(const Optional<TensorView>& routing_logits, | ||
|
|
@@ -160,6 +163,10 @@ class FusedMoeLauncher { | |
| activation_type{ActivationType::Swiglu}, | ||
| intermediate_size_factor{2} {} | ||
|
|
||
| void set_routing_replay_out(const Optional<TensorView>& replay_out) { | ||
| routing_replay_out = replay_out; | ||
| } | ||
|
|
||
| protected: | ||
| // Initialize common data necessary for later. | ||
| // May throw exception from TVM_FFI_ICHECK. | ||
|
|
@@ -375,6 +382,11 @@ class FusedMoeLauncher { | |
| tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); | ||
| cudaStream_t routing_stream = get_stream(hidden_states.device()); | ||
|
|
||
| int16_t* replay_ptr = nullptr; | ||
| if (routing_replay_out.has_value()) { | ||
| replay_ptr = reinterpret_cast<int16_t*>(routing_replay_out.value().data_ptr()); | ||
| } | ||
|
|
||
| routing_runner.run( | ||
| args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k, | ||
| args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts, | ||
|
|
@@ -389,7 +401,7 @@ class FusedMoeLauncher { | |
| static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()), | ||
| static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, | ||
| use_routing_scales_on_input, use_deep_seek_fp8, | ||
| static_cast<RoutingMethodType>(routing_method_type), routing_stream); | ||
| static_cast<RoutingMethodType>(routing_method_type), routing_stream, replay_ptr); | ||
|
|
||
| check_moe(); | ||
| prepare_moe(moe_tactic); | ||
|
|
@@ -1076,6 +1088,11 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { | |
| cudaStream_t routing_stream = get_stream(hidden_states.device()); | ||
| tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); | ||
|
|
||
| int16_t* replay_ptr = nullptr; | ||
| if (routing_replay_out.has_value()) { | ||
| replay_ptr = reinterpret_cast<int16_t*>(routing_replay_out.value().data_ptr()); | ||
| } | ||
|
|
||
| // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr | ||
| bool use_precomputed = expert_indices.ndim() == 2 && expert_indices.size(0) > 0; | ||
| // When using pre-computed routing, pass nullptr as routing_logits to tell the | ||
|
|
@@ -1094,7 +1111,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { | |
| static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()), | ||
| static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, | ||
| use_routing_scales_on_input, use_deep_seek_fp8, | ||
| static_cast<RoutingMethodType>(routing_method_type), routing_stream); | ||
| static_cast<RoutingMethodType>(routing_method_type), routing_stream, replay_ptr); | ||
|
|
||
| check_moe(); | ||
| prepare_moe(moe_tactic); | ||
|
|
@@ -1543,6 +1560,11 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { | |
| tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); | ||
| cudaStream_t routing_stream = get_stream(hidden_states.device()); | ||
|
|
||
| int16_t* replay_ptr = nullptr; | ||
| if (routing_replay_out.has_value()) { | ||
| replay_ptr = reinterpret_cast<int16_t*>(routing_replay_out.value().data_ptr()); | ||
| } | ||
|
|
||
| routing_runner.run( | ||
| args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k, | ||
| args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts, | ||
|
|
@@ -1557,7 +1579,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { | |
| static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()), | ||
| static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, | ||
| use_routing_scales_on_input, use_deep_seek_fp8, | ||
| static_cast<RoutingMethodType>(routing_method_type), routing_stream); | ||
| static_cast<RoutingMethodType>(routing_method_type), routing_stream, replay_ptr); | ||
|
|
||
| check_moe(); | ||
| prepare_moe(moe_tactic); | ||
|
|
@@ -1782,7 +1804,8 @@ Array<Tensor> trtllm_fp8_block_scale_moe( | |
| Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size, | ||
| int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor, | ||
| int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, | ||
| bool enable_pdl, Array<int64_t> config_index, Fp8QuantizationType quantization_type) { | ||
| bool enable_pdl, Array<int64_t> config_index, Fp8QuantizationType quantization_type, | ||
| Optional<TensorView> routing_replay_out) { | ||
|
TomerBN-Nvidia marked this conversation as resolved.
|
||
| // Basic type validation | ||
| auto dtype = hidden_states.dtype(); | ||
|
|
||
|
|
@@ -1843,6 +1866,19 @@ Array<Tensor> trtllm_fp8_block_scale_moe( | |
| auto const num_tokens = hidden_states.size(0); | ||
| auto const hidden_size = hidden_states.size(1); | ||
|
|
||
| if (routing_replay_out.has_value()) { | ||
| auto replay = routing_replay_out.value(); | ||
| TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA) | ||
| << "routing_replay_out must be a CUDA tensor"; | ||
| TVM_FFI_ICHECK(replay.device().device_id == hidden_states.device().device_id) | ||
| << "routing_replay_out must be on the same device as hidden_states"; | ||
| TVM_FFI_ICHECK(replay.ndim() == 2) << "routing_replay_out must be 2D [num_tokens, top_k]"; | ||
| TVM_FFI_ICHECK(replay.size(0) == num_tokens) << "routing_replay_out dim0 must equal num_tokens"; | ||
| TVM_FFI_ICHECK(replay.size(1) == top_k) << "routing_replay_out dim1 must equal top_k"; | ||
| TVM_FFI_ICHECK(encode_dlpack_dtype(replay.dtype()) == int16_code) | ||
| << "routing_replay_out must be int16 dtype"; | ||
| } | ||
|
Comment on lines
+1869
to
+1880
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Fail fast when replay is requested for unsupported routing methods.
🔧 Suggested guard if (routing_replay_out.has_value()) {
+ TVM_FFI_ICHECK_EQ(static_cast<RoutingMethodType>(routing_method_type),
+ RoutingMethodType::DeepSeekV3)
+ << "routing_replay_out is currently supported only for DeepSeekV3 routing";
auto replay = routing_replay_out.value();
TVM_FFI_ICHECK(replay.device().device_type == kDLCUDA)
<< "routing_replay_out must be a CUDA tensor";
Based on learnings.
🤖 Prompt for AI Agents |
||
|
|
||
| auto supported_tile_nums = Fp8BlockScaleLauncher::getSupportedTileNums(quantization_type); | ||
| std::set<int32_t> selected_tile_nums = | ||
| computeSelectedTileN(supported_tile_nums, num_tokens, top_k, local_num_experts); | ||
|
|
@@ -1875,6 +1911,7 @@ Array<Tensor> trtllm_fp8_block_scale_moe( | |
| quantization_type); | ||
| launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, | ||
| weight_layout); | ||
| launcher->set_routing_replay_out(routing_replay_out); | ||
|
|
||
| launchers_map[curr_tile_N] = std::move(launcher); | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.