@@ -126,10 +126,10 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const
   __syncthreads();
 }
 
+template <bool start, bool need_fence = false>
 __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
-                                         size_t const world_size, int const tidx, int const bidx, int const grid_size,
-                                         bool start = true, bool need_fence = false) {
-  if (!start) {
+                                         size_t const world_size, int const tidx, int const bidx, int const grid_size) {
+  if constexpr (!start) {
     __syncthreads();
   }
   // After this function, the block of id == bidx of each GPU has reached the barrier
@@ -141,22 +141,16 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag
   // Block broadcast its flag (local_rank on emitting dimension) to all receivers
   uint32_t flag_block_offset = world_size + bidx * world_size;
 
-  if (flag % 2 == 1) {
-    flag_block_offset += (grid_size + 1) * world_size;
-  }
+  flag_block_offset += (grid_size + 1) * world_size * (flag % 2);
 
-  if (need_fence) {
-    st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
-  } else {
-    st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
-  }
-  // Blocks check that corresponding blocks on other GPUs have also set the flag
   uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx;
-
-  if (need_fence) {
+  // Blocks check that corresponding blocks on other GPUs have also set the flag
+  if constexpr (need_fence) {
+    st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
     while (ld_flag_acquire(peer_barrier_d) != flag) {
     }
   } else {
+    st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
     while (ld_flag_volatile(peer_barrier_d) != flag) {
     }
   }
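
For orientation on the block_barrier hunk above: the new `(flag % 2)` term is the old parity branch folded into arithmetic, and moving `start`/`need_fence` into template parameters lets `if constexpr` discard the unused store/load flavor at compile time. Below is a small host-side sketch of the slot indexing implied by the offset math, under the assumed layout of `(grid_size + 1)` rows of `world_size` flag slots, double-buffered by flag parity; the helper name is hypothetical and not code from the patch.

// Sketch (assumption, not from the patch): index of the flag slot block `bidx`
// exchanges with peer rank `peer`, double-buffered by flag parity so back-to-back
// barrier rounds never reuse a slot before every peer has observed it.
#include <cstddef>
#include <cstdint>

inline size_t barrier_slot(size_t world_size, int grid_size, int bidx, uint32_t flag, size_t peer) {
  size_t offset = world_size + static_cast<size_t>(bidx) * world_size;        // this block's row; row 0 is reserved
  offset += (static_cast<size_t>(grid_size) + 1) * world_size * (flag % 2);   // parity double-buffer
  return offset + peer;                                                       // slot for a given peer rank
}
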
@@ -165,7 +159,7 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag
   __syncthreads();
 }
 
-template <typename T, int RANKS_PER_NODE> /* COPY_INPUT = false, PUSH_MODE = false */
+template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
 static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
   // Suppose that two GPUs participate in the AR exchange, and we start four blocks.
   // The message is partitioned into chunks as detailed below:
@@ -193,6 +187,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
 
   int const bidx = blockIdx.x;
   int const tidx = threadIdx.x;
+  int const grid_size = gridDim.x;
 
   // The number of elements packed into one for comms
   static constexpr int NUM_ELTS = 16 / sizeof(T);
@@ -201,26 +196,31 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
   using PackedStruct = typename PackedOn16Bytes<T>::Type;
 
   // The source pointers. Distributed round-robin for the different warps.
-  T const* buffers[RANKS_PER_NODE];
-
+  auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
+  T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
   // Start and end offsets of the thread
   size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
   size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
-#pragma unroll
-  for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
-    int rank = (params.local_rank + ii) % RANKS_PER_NODE;
-    buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
-  }
 
-  multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
+  if constexpr (COPY_INPUT) {
+    T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
+    // Copy from local buffer to shareable buffer
+    for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
+      *reinterpret_cast<int4*>(&local_shared_buffer[iter_offset]) =
+          *reinterpret_cast<int4 const*>(&local_input_buffer[iter_offset]);
+    }
+  }
+  // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
+  block_barrier<true>(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
+                      grid_size);
 
   // Each block accumulates the values from the different GPUs on the same node.
   for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
     // Iterate over the different ranks/devices on the node to load the values.
     PackedStruct vals[RANKS_PER_NODE];
 #pragma unroll
     for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
-      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][iter_offset]);
+      vals[ii].packed = *reinterpret_cast<int4 const*>(&((T*)peer_comm_buffer_ptrs[ii])[iter_offset]);
     }
 
     // Sum the values from the different ranks.
@@ -229,16 +229,15 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
 #pragma unroll
     for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
       // Always reduce from rank 0 to ensure stable reduce order.
-      int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
-      sums.packed = add128b(sums, vals[ii]);
+      sums.packed = add128b(sums, vals[rank]);
     }
 
     // Store to the destination buffer.
     *reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
   }
 }
 
-template <typename T, int RANKS_PER_NODE>
+template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
 static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
   // Suppose that two GPUs participate in the AR exchange, and we start two blocks.
   // The message is partitioned into chunks as detailed below:
@@ -286,20 +285,24 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
   static constexpr int PACKED_ELTS = 16 / sizeof(T);
   using PackedType = typename PackedOn16Bytes<T>::Type;
 
-  T* local_shared_buffer = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[params.local_rank]);
+  T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
+  auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
+  T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
   T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);
 
   size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
   size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank);
 
   T* buffers[RANKS_PER_NODE];
+  T* buffers_unorder[RANKS_PER_NODE];
   int ranks[RANKS_PER_NODE];
 #pragma unroll
   for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
     // A mapping of the ranks to scatter reads as much as possible
     int rank = (params.local_rank + ii) % RANKS_PER_NODE;
     ranks[ii] = rank;
-    buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
+    buffers[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[rank]);
+    buffers_unorder[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[ii]);
   }
 
 #if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
@@ -308,8 +311,22 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
 #endif
 #endif
 
-  block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
-                grid_size);
+  if constexpr (COPY_INPUT) {
+    // Copy all blocks from local buffer to shareable buffer
+    for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
+#pragma unroll
+      for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+        size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
+        if (offset_rank >= params.elts_total) {
+          continue;
+        }
+        *reinterpret_cast<int4*>(&local_shared_buffer[offset_rank]) =
+            *reinterpret_cast<int4 const*>(&local_input_buffer[offset_rank]);
+      }
+    }
+  }
+  block_barrier<true>(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
+                      grid_size);
 
   // Each block accumulates the values from the different GPUs on the same node.
   for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
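
Context for the copy loop added in the hunk above: with COPY_INPUT enabled, each block scatters its chunk of the local input into the shared buffer, which the two-shot kernel treats as RANKS_PER_NODE consecutive sections of elts_per_rank elements; the bounds check skips the padded tail of the last section. Below is a minimal stand-alone sketch of that section indexing (hypothetical helper, same layout assumptions as the kernel).

// Sketch (assumption, not from the patch): element `local_offset` of the section
// owned by `rank` lives at rank * elts_per_rank + local_offset; offsets past
// elts_total fall in the padded tail and are skipped.
#include <cstddef>
#include <optional>

inline std::optional<size_t> section_offset(size_t elts_per_rank, size_t elts_total, int rank, size_t local_offset) {
  size_t offset_rank = static_cast<size_t>(rank) * elts_per_rank + local_offset;
  if (offset_rank >= elts_total) {
    return std::nullopt;  // nothing to copy, reduce, or gather here
  }
  return offset_rank;
}
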
@@ -319,7 +336,7 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
     PackedType vals[RANKS_PER_NODE];
 #pragma unroll
     for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
-      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][responsible_block_offset]);
+      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers_unorder[ii][responsible_block_offset]);
     }
 
     // Sum the values from the different ranks.
@@ -328,16 +345,19 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
 #pragma unroll
     for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
       // Always reduce from rank 0 to ensure stable reduce order.
-      int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
-      sums.packed = add128b(sums, vals[ii]);
+      sums.packed = add128b(sums, vals[rank]);
     }
 
-    // Store to the local buffer.
-    *reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
+    // Store to the local buffer or tmp buffer
+    if constexpr (COPY_INPUT) {
+      *reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
+    } else {
+      *reinterpret_cast<int4*>(&params.tmp_result_buffers[params.local_rank][responsible_block_offset]) = sums.packed;
+    }
   }
 
-  block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
-                grid_size, false, true);
+  block_barrier<false, true>(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx,
+                             bidx, grid_size);
 
   // Gather all needed elts from other intra-node ranks
   for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
@@ -348,8 +368,13 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
       if (offset_rank >= params.elts_total) {
        continue;
      }
-
-      *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) = *reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
+      if constexpr (COPY_INPUT) {
+        *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
+            *reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
+      } else {
+        *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
+            *reinterpret_cast<int4*>(&params.tmp_result_buffers[ranks[ii]][offset_rank]);
+      }
     }
   }
 #if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
@@ -417,48 +442,50 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template <typename T, int RANKS_PER_NODE>
+template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
 void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block,
                        cudaStream_t stream) {
   switch (algo) {
     case AllReduceStrategyType::ONESHOT: {
-      oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+      oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
       break;
     }
     case AllReduceStrategyType::TWOSHOT: {
-      twoShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+      twoShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
       break;
     }
   }
 }
 
-template <typename T>
-void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
-  void* buffer = reinterpret_cast<void*>(param.peer_comm_buffer_ptrs[param.rank]);
-  void* local_inp_buffer = param.local_input_buffer_ptr;
-  CHECK_CUDA_SUCCESS(
-      cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream));
-
-  CHECK_CUDA_SUCCESS(cudaGetLastError());
-
+template <typename T, bool COPY_INPUT>
+void dispatchARKernelsCopyInput(AllReduceStrategyType strat, AllReduceParams& param, cudaStream_t stream) {
   size_t elts_per_thread = 16 / sizeof(T);
   auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
   switch (param.ranks_per_node) {
     case 2:
-      dispatchARKernels<T, 2>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 2, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 4:
-      dispatchARKernels<T, 4>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 4, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 6:
-      dispatchARKernels<T, 6>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 6, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 8:
-      dispatchARKernels<T, 8>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 8, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    default:
      break;
  }
+}
+
+template <typename T>
+void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
+  if (param.is_capturing) {
+    dispatchARKernelsCopyInput<T, false>(strat, param, stream);
+  } else {
+    dispatchARKernelsCopyInput<T, true>(strat, param, stream);
+  }
   CHECK_CUDA_SUCCESS(cudaGetLastError());
 }
 
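
On the new dispatch path at the end of the diff: invokeOneOrTwoShotAllReduceKernel now selects the kernel variant from param.is_capturing, presumably set while a CUDA graph is being captured; that variant skips the in-kernel input copy and, in the two-shot kernel, stages partial sums in tmp_result_buffers instead of the shared buffer. Below is a hypothetical caller-side sketch of how such a flag could be derived; this is an illustration only, and the real call site is outside this diff.

// Hypothetical caller-side sketch (not part of this patch): derive is_capturing
// from the CUDA runtime so graph capture takes the no-copy kernel path.
#include <cuda_runtime.h>

static bool streamIsCapturing(cudaStream_t stream) {
  cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
  // Treat a failed query as "not capturing" so the default copy path is used.
  return cudaStreamIsCapturing(stream, &status) == cudaSuccess && status != cudaStreamCaptureStatusNone;
}

// Example usage before launching:
//   param.is_capturing = streamIsCapturing(stream);
//   invokeOneOrTwoShotAllReduceKernel<half>(param, strat, stream);
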