Skip to content

Commit 086e485

Browse files
committed
commit
1 parent 212a3f6 commit 086e485

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

paddle/fluid/distributed/collective/deep_ep/kernels/internode_ll.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -658,13 +658,13 @@ void dispatch(void* packed_recv_x,
658658
cudaStream_t stream,
659659
int phases,
660660
int num_per_channel) {
661-
constexpr int kNumMaxTopK = 9;
661+
constexpr int kNumMaxTopK = 16;
662662
constexpr int NUM_WARPS = 32;
663663
const int dev_id = 0;
664664
int sm_count;
665665
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
666666
const int num_warp_groups = cell_div(num_experts, sm_count);
667-
const auto num_sms = max(sm_count, cell_div(num_experts, num_warp_groups));
667+
668668
EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
669669

670670
// Workspace checks
@@ -681,6 +681,7 @@ void dispatch(void* packed_recv_x,
681681
kNumPerChannels,
682682
{DISPATCH_NUM_WARP_GROUPS(num_warp_groups, kNumWarpGroups, { // 1
683683
constexpr int kNumWarpsPerGroup = NUM_WARPS / kNumWarpGroups; // 32
684+
// because of many `warp_id < num_topk` in kernel.
684685
EP_STATIC_ASSERT(
685686
kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup,
686687
"Too many top-k selections");
@@ -695,7 +696,7 @@ void dispatch(void* packed_recv_x,
695696
kHidden,
696697
kNumPerChannels>;
697698
SETUP_LAUNCH_CONFIG(
698-
num_sms, kNumWarpGroups * kNumWarpsPerGroup * 32, stream);
699+
sm_count, kNumWarpGroups * kNumWarpsPerGroup * 32, stream);
699700
LAUNCH_KERNEL(&cfg,
700701
dispatch_func,
701702
packed_recv_x,
@@ -993,14 +994,13 @@ void combine(void* combined_x,
993994
cudaStream_t stream,
994995
int phases,
995996
bool zero_copy) {
996-
constexpr int kNumMaxTopk = 9;
997+
constexpr int kNumMaxTopk = 16;
997998
constexpr int NUM_WARPS = 32;
998999

9991000
const int dev_id = 0;
10001001
int sm_count;
10011002
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
10021003
const int num_warp_groups = cell_div(num_experts, sm_count);
1003-
const auto num_sms = max(sm_count, cell_div(num_experts, num_warp_groups));
10041004

10051005
// Check workspace
10061006
auto atomic_clean_flag = reinterpret_cast<int*>(workspace);
@@ -1015,7 +1015,7 @@ void combine(void* combined_x,
10151015
auto combine_func =
10161016
combine<kNumWarpGroups, kNumWarpsPerGroup, kHidden, kNumMaxTopk>;
10171017
SETUP_LAUNCH_CONFIG(
1018-
num_sms, kNumWarpGroups * kNumWarpsPerGroup * 32, stream);
1018+
sm_count, kNumWarpGroups * kNumWarpsPerGroup * 32, stream);
10191019
LAUNCH_KERNEL(&cfg,
10201020
combine_func,
10211021
combined_x,

0 commit comments

Comments
 (0)