@@ -290,9 +290,9 @@ __device__ void vectorized_dispatch(uint8_t const* src_ptr, int bytes_per_token,
290290}
291291
292292__global__ void moeA2APrepareDispatchKernel (int * send_counters, int * local_token_counter,
293- int ep_size, uint32_t * flag_val_ptr) {
293+ int ep_size, uint32_t * flag_val_ptr, bool enable_pdl ) {
294294#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
295- cudaGridDependencySynchronize ();
295+ if (enable_pdl) cudaGridDependencySynchronize ();
296296 cudaTriggerProgrammaticLaunchCompletion ();
297297#endif
298298 int idx = blockIdx .x * blockDim .x + threadIdx .x ;
@@ -322,7 +322,7 @@ __global__ void moeA2ADispatchKernel(
322322 const DispatchKernelPointers ptrs, // Struct containing all kernel pointers
323323 int num_payloads, // Number of payloads
324324 int max_tokens_per_rank, // Maximum tokens per rank
325- int local_num_tokens, int rank_id, int ep_size, int num_experts_per_rank) {
325+ int local_num_tokens, int rank_id, int ep_size, int num_experts_per_rank, bool enable_pdl ) {
326326 int thread_idx = ThreadingPolicy::offset ();
327327 int local_token_idx = ThreadingPolicy::token_idx ();
328328
@@ -332,14 +332,14 @@ __global__ void moeA2ADispatchKernel(
332332 // synchronization. Other threads should return.
333333 if (local_token_idx > 0 ) return ;
334334#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
335- cudaGridDependencySynchronize ();
335+ if (enable_pdl) cudaGridDependencySynchronize ();
336336#endif
337337 } else {
338338 // Threads that do not have a token to process should return.
339339 if (local_token_idx >= local_num_tokens) return ;
340340
341341#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
342- cudaGridDependencySynchronize ();
342+ if (enable_pdl) cudaGridDependencySynchronize ();
343343#endif
344344
345345 // Prepare per-policy shared-memory tiles for this token
@@ -491,9 +491,9 @@ __global__ void moeA2ADispatchKernel(
491491}
492492
// Host-side launcher for moeA2APrepareDispatchKernel.
//
// The launch helper receives, in order: a kernel name string, the PDL flag, the kernel, then
// 1, params.ep_size, 0 and params.stream. Sibling launch sites in this file pass
// (grid_size, kBlockSize, shared_bytes, stream) in those same positions, so this reads as
// one block of ep_size threads with no dynamic shared memory.
//
// params.enable_pdl is forwarded twice on purpose:
//   * to launchWithPdlWhenEnabled (presumably to select a programmatic-dependent-launch
//     configuration -- the helper is defined elsewhere; confirm there), and
//   * as the kernel's trailing argument, where it gates cudaGridDependencySynchronize().
void moe_a2a_prepare_dispatch_launch(MoeA2ADispatchParams const& params) {
  launchWithPdlWhenEnabled("moeA2APrepareDispatchKernel", params.enable_pdl,
      moeA2APrepareDispatchKernel, 1, params.ep_size, 0, params.stream, params.send_counters,
      params.local_token_counter, params.ep_size, params.flag_val, params.enable_pdl);
}
498498
499499// ============================================================================
@@ -552,10 +552,10 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) {
552552 int shared_bytes = 2 * params.top_k * (int )sizeof (int );
553553 SWITCH_TOP_K (params.top_k , TOP_K, {
554554 auto kernel_fn = moeA2ADispatchKernel<BlockPolicy, TOP_K>;
555- launchWithPdlWhenEnabled (" moeA2ADispatchKernel" , kernel_fn, grid_size, kBlockSize ,
556- shared_bytes, params.stream , params.token_selected_experts , kernel_ptrs,
555+ launchWithPdlWhenEnabled (" moeA2ADispatchKernel" , params. enable_pdl , kernel_fn, grid_size ,
556+ kBlockSize , shared_bytes, params.stream , params.token_selected_experts , kernel_ptrs,
557557 params.num_payloads , params.max_tokens_per_rank , params.local_num_tokens , params.ep_rank ,
558- params.ep_size , params.num_experts_per_rank );
558+ params.ep_size , params.num_experts_per_rank , params. enable_pdl );
559559 })
560560 } else {
561561 int grid_size = ceilDiv (params.local_num_tokens , kWarpsPerBlock );
@@ -567,10 +567,10 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) {
567567 int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int )sizeof (int );
568568 SWITCH_TOP_K (params.top_k , TOP_K, {
569569 auto kernel_fn = moeA2ADispatchKernel<WarpPolicy, TOP_K>;
570- launchWithPdlWhenEnabled (" moeA2ADispatchKernel" , kernel_fn, grid_size, kBlockSize ,
571- shared_bytes, params.stream , params.token_selected_experts , kernel_ptrs,
570+ launchWithPdlWhenEnabled (" moeA2ADispatchKernel" , params. enable_pdl , kernel_fn, grid_size ,
571+ kBlockSize , shared_bytes, params.stream , params.token_selected_experts , kernel_ptrs,
572572 params.num_payloads , params.max_tokens_per_rank , params.local_num_tokens , params.ep_rank ,
573- params.ep_size , params.num_experts_per_rank );
573+ params.ep_size , params.num_experts_per_rank , params. enable_pdl );
574574 })
575575 }
576576}
@@ -919,9 +919,10 @@ template <typename ThreadingPolicy, bool LOW_PRECISION, typename SrcT>
919919__global__ void moeA2APrepareCombineKernel (uint8_t * recv_buffer_bytes, void const * payload,
920920 int elements_per_token, int ep_size,
921921 int max_tokens_per_rank, uint32_t * flag_val_ptr,
922- int const * recv_counters, int stride_per_token) {
922+ int const * recv_counters, int stride_per_token,
923+ bool enable_pdl) {
923924#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
924- cudaGridDependencySynchronize ();
925+ if (enable_pdl) cudaGridDependencySynchronize ();
925926 cudaTriggerProgrammaticLaunchCompletion ();
926927#endif
927928
@@ -977,7 +978,7 @@ template <typename T, typename ThreadingPolicy, int TOP_K>
977978__global__ void moeA2ACombineKernel (
978979 const CombineKernelPointers ptrs, // Combine-specific struct, src_data_ptrs[0] is output
979980 int max_tokens_per_rank, int elements_per_token, int stride_per_token, int local_num_tokens,
980- int rank_id, int ep_size) {
981+ int rank_id, int ep_size, bool enable_pdl ) {
981982 int local_token_idx = ThreadingPolicy::token_idx ();
982983 int const size_per_token = elements_per_token * static_cast <int >(sizeof (T));
983984
@@ -992,7 +993,7 @@ __global__ void moeA2ACombineKernel(
992993 }
993994
994995#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
995- cudaGridDependencySynchronize ();
996+ if (enable_pdl) cudaGridDependencySynchronize ();
996997#endif
997998
998999#if !DISABLE_SYNC_FOR_PROFILING
@@ -1108,9 +1109,10 @@ void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params) {
11081109 params.one_block_per_token
11091110 ? moeA2APrepareCombineKernel<BlockPolicy, LOW_PRECISION, SrcT>
11101111 : moeA2APrepareCombineKernel<WarpPolicy, LOW_PRECISION, SrcT>;
1111- launchWithPdlWhenEnabled (" moeA2APrepareCombineKernel" , kernel_fn, grid, kBlockSize , 0 ,
1112- params.stream , recv_buffer_bytes, payload, params.elements_per_token , params.ep_size ,
1113- params.max_tokens_per_rank , params.flag_val , params.recv_counters , stride_per_token);
1112+ launchWithPdlWhenEnabled (" moeA2APrepareCombineKernel" , params.enable_pdl , kernel_fn, grid,
1113+ kBlockSize , 0 , params.stream , recv_buffer_bytes, payload, params.elements_per_token ,
1114+ params.ep_size , params.max_tokens_per_rank , params.flag_val , params.recv_counters ,
1115+ stride_per_token, params.enable_pdl );
11141116 });
11151117 });
11161118}
@@ -1184,9 +1186,10 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params) {
11841186 SWITCH_POLICY (params.one_block_per_token , Policy, {
11851187 SWITCH_TOP_K (params.top_k , TOP_K, {
11861188 auto kernel_fn = moeA2ACombineKernel<TKernelType, Policy, TOP_K>;
1187- launchWithPdlWhenEnabled (" moeA2ACombineKernel" , kernel_fn, grid, kBlockSize , 0 ,
1188- params.stream , kernel_ptrs, params.max_tokens_per_rank , params.elements_per_token ,
1189- stride_per_token, params.local_num_tokens , params.ep_rank , params.ep_size );
1189+ launchWithPdlWhenEnabled (" moeA2ACombineKernel" , params.enable_pdl , kernel_fn, grid,
1190+ kBlockSize , 0 , params.stream , kernel_ptrs, params.max_tokens_per_rank ,
1191+ params.elements_per_token , stride_per_token, params.local_num_tokens , params.ep_rank ,
1192+ params.ep_size , params.enable_pdl );
11901193 });
11911194 });
11921195 });
@@ -1196,9 +1199,9 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params) {
11961199__global__ void moeA2ASanitizeExpertIdsKernel (int32_t * expert_ids_ptr,
11971200 int32_t const * recv_counters_ptr, int ep_size,
11981201 int max_tokens_per_rank, int top_k,
1199- int32_t invalid_id) {
1202+ int32_t invalid_id, bool enable_pdl ) {
12001203#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
1201- cudaGridDependencySynchronize ();
1204+ if (enable_pdl) cudaGridDependencySynchronize ();
12021205 cudaTriggerProgrammaticLaunchCompletion ();
12031206#endif
12041207 int tid = blockIdx .x * blockDim .x + threadIdx .x ;
@@ -1218,13 +1221,13 @@ __global__ void moeA2ASanitizeExpertIdsKernel(int32_t* expert_ids_ptr,
12181221
// Host-side launcher for moeA2ASanitizeExpertIdsKernel.
//
// Parameters:
//   expert_ids, recv_counters  - device pointers forwarded unchanged to the kernel.
//   invalid_id                 - forwarded to the kernel as its `invalid_id` argument.
//   ep_size, max_tokens_per_rank, top_k - forwarded to the kernel; ep_size and
//                                max_tokens_per_rank also size the launch.
//   stream                     - CUDA stream the kernel is launched on.
//   enable_pdl                 - passed both to launchWithPdlWhenEnabled and to the kernel,
//                                where it gates cudaGridDependencySynchronize().
//
// Launch shape: 256-thread blocks, ceilDiv(ep_size * max_tokens_per_rank, 256) blocks,
// no dynamic shared memory.
// NOTE(review): total_tokens is computed in `int`; assumes ep_size * max_tokens_per_rank
// fits in 32 bits and is non-zero -- confirm against callers.
void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv_counters,
                                        int32_t invalid_id, int ep_size, int max_tokens_per_rank,
                                        int top_k, cudaStream_t stream, bool enable_pdl) {
  constexpr int kBlockSize = 256;
  int total_tokens = ep_size * max_tokens_per_rank;
  int grid = ceilDiv(total_tokens, kBlockSize);
  // Kernel argument order differs from this function's parameter order: the kernel takes
  // invalid_id after top_k, followed by the trailing enable_pdl flag.
  launchWithPdlWhenEnabled("moeA2ASanitizeExpertIdsKernel", enable_pdl,
      moeA2ASanitizeExpertIdsKernel, grid, kBlockSize, 0, stream, expert_ids, recv_counters,
      ep_size, max_tokens_per_rank, top_k, invalid_id, enable_pdl);
}
12291232
12301233} // namespace tensorrt_llm::kernels::moe_alltoall
0 commit comments