@@ -658,13 +658,13 @@ void dispatch(void* packed_recv_x,
658658 cudaStream_t stream,
659659 int phases,
660660 int num_per_channel) {
661- constexpr int kNumMaxTopK = 9 ;
661+ constexpr int kNumMaxTopK = 16 ;
662662 constexpr int NUM_WARPS = 32 ;
663663 const int dev_id = 0 ;
664664 int sm_count;
665665 cudaDeviceGetAttribute (&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
666666 const int num_warp_groups = cell_div (num_experts, sm_count);
667- const auto num_sms = max (sm_count, cell_div (num_experts, num_warp_groups));
667+
668668 EP_HOST_ASSERT (num_topk <= kNumMaxTopK );
669669
670670 // Workspace checks
@@ -681,6 +681,7 @@ void dispatch(void* packed_recv_x,
681681 kNumPerChannels ,
682682 {DISPATCH_NUM_WARP_GROUPS (num_warp_groups, kNumWarpGroups , { // 1
683683 constexpr int kNumWarpsPerGroup = NUM_WARPS / kNumWarpGroups ; // 32
684+ // because of many `warp_id < num_topk` in kernel.
684685 EP_STATIC_ASSERT (
685686 kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup ,
686687 " Too many top-k selections" );
@@ -695,7 +696,7 @@ void dispatch(void* packed_recv_x,
695696 kHidden ,
696697 kNumPerChannels >;
697698 SETUP_LAUNCH_CONFIG (
698- num_sms , kNumWarpGroups * kNumWarpsPerGroup * 32 , stream);
699+ sm_count , kNumWarpGroups * kNumWarpsPerGroup * 32 , stream);
699700 LAUNCH_KERNEL (&cfg,
700701 dispatch_func,
701702 packed_recv_x,
@@ -993,14 +994,13 @@ void combine(void* combined_x,
993994 cudaStream_t stream,
994995 int phases,
995996 bool zero_copy) {
996- constexpr int kNumMaxTopk = 9 ;
997+ constexpr int kNumMaxTopk = 16 ;
997998 constexpr int NUM_WARPS = 32 ;
998999
9991000 const int dev_id = 0 ;
10001001 int sm_count;
10011002 cudaDeviceGetAttribute (&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
10021003 const int num_warp_groups = cell_div (num_experts, sm_count);
1003- const auto num_sms = max (sm_count, cell_div (num_experts, num_warp_groups));
10041004
10051005 // Check workspace
10061006 auto atomic_clean_flag = reinterpret_cast <int *>(workspace);
@@ -1015,7 +1015,7 @@ void combine(void* combined_x,
10151015 auto combine_func =
10161016 combine<kNumWarpGroups , kNumWarpsPerGroup , kHidden , kNumMaxTopk >;
10171017 SETUP_LAUNCH_CONFIG (
1018- num_sms , kNumWarpGroups * kNumWarpsPerGroup * 32 , stream);
1018+ sm_count , kNumWarpGroups * kNumWarpsPerGroup * 32 , stream);
10191019 LAUNCH_KERNEL (&cfg,
10201020 combine_func,
10211021 combined_x,
0 commit comments