
Commit 6cb3974

optimize custom allreduce kernel (#2904)
1 parent f65c13b commit 6cb3974

File tree

9 files changed (+244, -80 lines)


sgl-kernel/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post12"
+version = "0.0.2.post13"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"

sgl-kernel/setup.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def update_wheel_platform_tag():
     "-U__CUDA_NO_HALF2_OPERATORS__",
 ]
 cxx_flags = ["-O3"]
-libraries = ["c10", "torch", "torch_python"]
+libraries = ["c10", "torch", "torch_python", "cuda"]
 extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
 ext_modules = [
     CUDAExtension(

sgl-kernel/src/sgl-kernel/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,9 +1,11 @@
 from sgl_kernel.ops import (
     custom_dispose,
     custom_reduce,
+    get_graph_buffer_ipc_meta,
     init_custom_reduce,
     int8_scaled_mm,
     moe_align_block_size,
+    register_graph_buffers,
     sampling_scaling_penalties,
 )
 
@@ -14,4 +16,6 @@
     "custom_reduce",
     "int8_scaled_mm",
     "sampling_scaling_penalties",
+    "get_graph_buffer_ipc_meta",
+    "register_graph_buffers",
 ]

sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu

Lines changed: 8 additions & 2 deletions
@@ -2,10 +2,14 @@
 
 // trt_reduce
 using fptr_t = int64_t;
-fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
-                      const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
+fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
+                      const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
+                      const std::vector<fptr_t>& barrier_out);
 void dispose(fptr_t _fa);
 void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
+std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
+void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
+                            const std::vector<std::vector<int64_t>>& offsets);
 
 // moe_align_block_size
 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
 
@@ -25,6 +29,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
   m.def("dispose", &dispose, "dispose custom allreduce meta");
   m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
+  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "custom all reduce get graph ipc meta");
+  m.def("register_graph_buffers", &register_graph_buffers, "custom all reduce register graph buffers");
   // moe_align_block_size
   m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
   // sampling_scaling_penalties
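For context on the two new bindings: judging from the signatures, get_graph_buffer_ipc_meta presumably returns serialized CUDA IPC handles plus per-buffer offsets, and register_graph_buffers lets each rank map its peers' graph-capture buffers. The sketch below is illustrative only, not the extension's implementation; serialize_ipc_handle and open_peer_buffer are hypothetical helpers, and only the CUDA IPC runtime calls are real APIs.

#include <cuda_runtime.h>
#include <cstring>
#include <vector>

// Pack the IPC handle of a device allocation into int64 words so it can travel through the
// std::vector<int64_t> metadata exposed by the binding above.
std::vector<int64_t> serialize_ipc_handle(void* device_ptr) {
  cudaIpcMemHandle_t handle;
  cudaIpcGetMemHandle(&handle, device_ptr);
  std::vector<int64_t> words((sizeof(handle) + sizeof(int64_t) - 1) / sizeof(int64_t), 0);
  std::memcpy(words.data(), &handle, sizeof(handle));
  return words;
}

// On a peer rank: reopen the allocation and apply the per-buffer offset so the mapped pointer
// refers to the same graph buffer the sending rank described.
void* open_peer_buffer(const std::vector<int64_t>& words, int64_t offset) {
  cudaIpcMemHandle_t handle;
  std::memcpy(&handle, words.data(), sizeof(handle));
  void* base = nullptr;
  cudaIpcOpenMemHandle(&base, handle, cudaIpcMemLazyEnablePeerAccess);
  return static_cast<char*>(base) + offset;
}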

sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu

Lines changed: 83 additions & 56 deletions
@@ -126,10 +126,10 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const
   __syncthreads();
 }
 
+template <bool start, bool need_fence = false>
 __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
-                                         size_t const world_size, int const tidx, int const bidx, int const grid_size,
-                                         bool start = true, bool need_fence = false) {
-  if (!start) {
+                                         size_t const world_size, int const tidx, int const bidx, int const grid_size) {
+  if constexpr (!start) {
     __syncthreads();
   }
   // After this function, the block of id == bidx of each GPU has reached the barrier
@@ -141,22 +141,16 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag
   // Block broadcast its flag (local_rank on emitting dimension) to all receivers
   uint32_t flag_block_offset = world_size + bidx * world_size;
 
-  if (flag % 2 == 1) {
-    flag_block_offset += (grid_size + 1) * world_size;
-  }
+  flag_block_offset += (grid_size + 1) * world_size * (flag % 2);
 
-  if (need_fence) {
-    st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
-  } else {
-    st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
-  }
-  // Blocks check that corresponding blocks on other GPUs have also set the flag
   uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx;
-
-  if (need_fence) {
+  // Blocks check that corresponding blocks on other GPUs have also set the flag
+  if constexpr (need_fence) {
+    st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
     while (ld_flag_acquire(peer_barrier_d) != flag) {
     }
   } else {
+    st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
     while (ld_flag_volatile(peer_barrier_d) != flag) {
     }
   }
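Two things change in block_barrier here: start and need_fence become template parameters resolved with if constexpr, and the parity branch on the barrier flag is folded into arithmetic. A minimal host-side sketch (values are illustrative, not taken from a real launch) confirming the branchless offset matches the branchy form it replaces:

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t world_size = 8, grid_size = 36, bidx = 3;
  for (uint32_t flag = 1; flag <= 4; ++flag) {
    const uint32_t base = world_size + bidx * world_size;
    // Old form: branch on the flag parity to pick the second half of the double-buffered flag area.
    uint32_t branchy = base;
    if (flag % 2 == 1) {
      branchy += (grid_size + 1) * world_size;
    }
    // New form: fold the parity into a multiply, removing the branch.
    const uint32_t branchless = base + (grid_size + 1) * world_size * (flag % 2);
    assert(branchy == branchless);
  }
  std::printf("offset forms agree\n");
  return 0;
}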
@@ -165,7 +159,7 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag
   __syncthreads();
 }
 
-template <typename T, int RANKS_PER_NODE> /* COPY_INPUT = false, PUSH_MODE = false */
+template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
 static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
   // Suppose that two GPUs participate in the AR exchange, and we start four blocks.
   // The message is partitioned into chunks as detailed below:
@@ -193,6 +187,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
 
   int const bidx = blockIdx.x;
   int const tidx = threadIdx.x;
+  int const grid_size = gridDim.x;
 
   // The number of elements packed into one for comms
   static constexpr int NUM_ELTS = 16 / sizeof(T);
@@ -201,26 +196,31 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
   using PackedStruct = typename PackedOn16Bytes<T>::Type;
 
   // The source pointers. Distributed round-robin for the different warps.
-  T const* buffers[RANKS_PER_NODE];
-
+  auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
+  T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
   // Start and end offsets of the thread
   size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
   size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
-#pragma unroll
-  for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
-    int rank = (params.local_rank + ii) % RANKS_PER_NODE;
-    buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
-  }
 
-  multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
+  if constexpr (COPY_INPUT) {
+    T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
+    // Copy from local buffer to shareable buffer
+    for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
+      *reinterpret_cast<int4*>(&local_shared_buffer[iter_offset]) =
+          *reinterpret_cast<int4 const*>(&local_input_buffer[iter_offset]);
+    }
+  }
+  // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
+  block_barrier<true>(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
+                      grid_size);
 
   // Each block accumulates the values from the different GPUs on the same node.
   for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
     // Iterate over the different ranks/devices on the node to load the values.
     PackedStruct vals[RANKS_PER_NODE];
 #pragma unroll
     for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
-      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][iter_offset]);
+      vals[ii].packed = *reinterpret_cast<int4 const*>(&((T*)peer_comm_buffer_ptrs[ii])[iter_offset]);
     }
 
     // Sum the values from the different ranks.
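The COPY_INPUT branch above stages the caller's input into the IPC-shared buffer with 16-byte packed transfers. A standalone sketch of that copy idiom (assumption: this toy kernel and launch_packed_copy are illustrative, not part of the diff; both pointers must be 16-byte aligned and n a multiple of NUM_ELTS, as in the kernel's chunked layout):

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstddef>

template <typename T>
__global__ void packed_copy(T* __restrict__ dst, const T* __restrict__ src, size_t n) {
  constexpr int NUM_ELTS = 16 / sizeof(T);  // elements moved per 128-bit transaction
  size_t idx = (static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x) * NUM_ELTS;
  size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x * NUM_ELTS;
  for (; idx < n; idx += stride) {
    // One int4 load/store moves 16 bytes, the same vector width the allreduce kernels use.
    *reinterpret_cast<int4*>(&dst[idx]) = *reinterpret_cast<const int4*>(&src[idx]);
  }
}

// Hypothetical launcher for half precision, mirroring the packed element count used above.
void launch_packed_copy(half* dst, const half* src, size_t n, cudaStream_t stream) {
  constexpr int kThreads = 512;
  int blocks = static_cast<int>((n / (16 / sizeof(half)) + kThreads - 1) / kThreads);
  if (blocks == 0) return;  // nothing to copy
  packed_copy<half><<<blocks, kThreads, 0, stream>>>(dst, src, n);
}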
@@ -229,16 +229,15 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
 #pragma unroll
     for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
       // Always reduce from rank 0 to ensure stable reduce order.
-      int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
-      sums.packed = add128b(sums, vals[ii]);
+      sums.packed = add128b(sums, vals[rank]);
     }
 
     // Store to the destination buffer.
     *reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
   }
 }
 
-template <typename T, int RANKS_PER_NODE>
+template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
 static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
   // Suppose that two GPUs participate in the AR exchange, and we start two blocks.
   // The message is partitioned into chunks as detailed below:
@@ -286,20 +285,24 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
   static constexpr int PACKED_ELTS = 16 / sizeof(T);
   using PackedType = typename PackedOn16Bytes<T>::Type;
 
-  T* local_shared_buffer = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[params.local_rank]);
+  T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
+  auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
+  T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
   T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);
 
   size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
   size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank);
 
   T* buffers[RANKS_PER_NODE];
+  T* buffers_unorder[RANKS_PER_NODE];
   int ranks[RANKS_PER_NODE];
 #pragma unroll
   for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
     // A mapping of the ranks to scatter reads as much as possible
     int rank = (params.local_rank + ii) % RANKS_PER_NODE;
     ranks[ii] = rank;
-    buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
+    buffers[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[rank]);
+    buffers_unorder[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[ii]);
   }
 
 #if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
@@ -308,8 +311,22 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
 #endif
 #endif
 
-  block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
-                grid_size);
+  if constexpr (COPY_INPUT) {
+    // Copy all blocks from local buffer to shareable buffer
+    for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
+#pragma unroll
+      for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+        size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
+        if (offset_rank >= params.elts_total) {
+          continue;
+        }
+        *reinterpret_cast<int4*>(&local_shared_buffer[offset_rank]) =
+            *reinterpret_cast<int4 const*>(&local_input_buffer[offset_rank]);
+      }
+    }
+  }
+  block_barrier<true>(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
+                      grid_size);
 
   // Each block accumulates the values from the different GPUs on the same node.
   for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
@@ -319,7 +336,7 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
     PackedType vals[RANKS_PER_NODE];
 #pragma unroll
     for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
-      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][responsible_block_offset]);
+      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers_unorder[ii][responsible_block_offset]);
     }
 
     // Sum the values from the different ranks.
@@ -328,16 +345,19 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
 #pragma unroll
     for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
       // Always reduce from rank 0 to ensure stable reduce order.
-      int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
-      sums.packed = add128b(sums, vals[ii]);
+      sums.packed = add128b(sums, vals[rank]);
     }
 
-    // Store to the local buffer.
-    *reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
+    // Store to the local buffer or tmp buffer
+    if constexpr (COPY_INPUT) {
+      *reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
+    } else {
+      *reinterpret_cast<int4*>(&params.tmp_result_buffers[params.local_rank][responsible_block_offset]) = sums.packed;
+    }
   }
 
-  block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
-                grid_size, false, true);
+  block_barrier<false, true>(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx,
+                             bidx, grid_size);
 
   // Gather all needed elts from other intra-node ranks
   for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
@@ -348,8 +368,13 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc
       if (offset_rank >= params.elts_total) {
        continue;
       }
-
-      *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) = *reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
+      if constexpr (COPY_INPUT) {
+        *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
+            *reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
+      } else {
+        *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
+            *reinterpret_cast<int4*>(&params.tmp_result_buffers[ranks[ii]][offset_rank]);
+      }
     }
   }
 #if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
@@ -417,48 +442,50 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template <typename T, int RANKS_PER_NODE>
+template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
 void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block,
                        cudaStream_t stream) {
   switch (algo) {
     case AllReduceStrategyType::ONESHOT: {
-      oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+      oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
       break;
     }
     case AllReduceStrategyType::TWOSHOT: {
-      twoShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+      twoShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
       break;
     }
   }
 }
 
-template <typename T>
-void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
-  void* buffer = reinterpret_cast<void*>(param.peer_comm_buffer_ptrs[param.rank]);
-  void* local_inp_buffer = param.local_input_buffer_ptr;
-  CHECK_CUDA_SUCCESS(
-      cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream));
-
-  CHECK_CUDA_SUCCESS(cudaGetLastError());
-
+template <typename T, bool COPY_INPUT>
+void dispatchARKernelsCopyInput(AllReduceStrategyType strat, AllReduceParams& param, cudaStream_t stream) {
   size_t elts_per_thread = 16 / sizeof(T);
   auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
   switch (param.ranks_per_node) {
     case 2:
-      dispatchARKernels<T, 2>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 2, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 4:
-      dispatchARKernels<T, 4>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 4, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 6:
-      dispatchARKernels<T, 6>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 6, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 8:
-      dispatchARKernels<T, 8>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 8, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    default:
      break;
  }
+}
+
+template <typename T>
+void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
+  if (param.is_capturing) {
+    dispatchARKernelsCopyInput<T, false>(strat, param, stream);
+  } else {
+    dispatchARKernelsCopyInput<T, true>(strat, param, stream);
+  }
   CHECK_CUDA_SUCCESS(cudaGetLastError());
 }
 
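The restructured dispatch above lowers a runtime flag (param.is_capturing) into the compile-time COPY_INPUT parameter, so each kernel instantiation contains only the code path it needs. A generic sketch of that pattern follows (assumption: kernel_impl and run_reduce are hypothetical stand-ins, not functions from this diff, and n is assumed to be positive):

#include <cuda_runtime.h>
#include <cstddef>

template <bool COPY_INPUT>
__global__ void kernel_impl(float* shared_buf, const float* input, float* output, size_t n) {
  size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i >= n) return;
  if constexpr (COPY_INPUT) {
    shared_buf[i] = input[i];  // eager path: stage the caller's input into the shared buffer
  }
  output[i] = shared_buf[i];   // both specializations then operate on the shared buffer
}

void run_reduce(float* shared_buf, const float* input, float* output, size_t n, bool is_capturing,
                cudaStream_t stream) {
  dim3 block(256), grid(static_cast<unsigned>((n + 255) / 256));
  if (is_capturing) {
    // During CUDA graph capture the inputs are expected to live in buffers that were
    // registered via register_graph_buffers, so the copy-free specialization is launched.
    kernel_impl<false><<<grid, block, 0, stream>>>(shared_buf, input, output, n);
  } else {
    kernel_impl<true><<<grid, block, 0, stream>>>(shared_buf, input, output, n);
  }
}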
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh

Lines changed: 7 additions & 1 deletion
@@ -36,6 +36,10 @@ enum class AllReduceStrategyType : int8_t {
   AUTO = 3,
 };
 
+struct RankData {
+  void* ptrs[MAX_RANKS_PER_NODE];
+};
+
 struct AllReduceParams {
   size_t elts_size;
   size_t elts_total;
 
@@ -46,9 +50,11 @@ struct AllReduceParams {
   uint32_t barrier_flag;
   uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
   uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
-  void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
+  uint32_t* tmp_result_buffers[MAX_RANKS_PER_NODE];
+  RankData* peer_comm_buffer_ptrs;
   void* local_input_buffer_ptr;
   void* local_output_buffer_ptr;
+  bool is_capturing;
 };
 
 inline size_t GetMaxRequiredWorkspaceSize(int world_size) {