optimize custom allreduce kernel (#2904)
yizhang2077 authored Jan 15, 2025
1 parent f65c13b commit 6cb3974
Showing 9 changed files with 244 additions and 80 deletions.
2 changes: 1 addition & 1 deletion sgl-kernel/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post12"
+version = "0.0.2.post13"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"
2 changes: 1 addition & 1 deletion sgl-kernel/setup.py
@@ -40,7 +40,7 @@ def update_wheel_platform_tag():
     "-U__CUDA_NO_HALF2_OPERATORS__",
 ]
 cxx_flags = ["-O3"]
-libraries = ["c10", "torch", "torch_python"]
+libraries = ["c10", "torch", "torch_python", "cuda"]
 extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
 ext_modules = [
     CUDAExtension(
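Note on the new "cuda" entry in libraries: it links the extension against the CUDA driver library (libcuda) in addition to the runtime. My best guess at the reason (it is not stated in the commit) is that the graph-buffer IPC machinery added below calls CUDA driver API entry points, which live in libcuda rather than libcudart. An illustrative, hypothetical example of such a call:

// Illustration only, not code from this commit: resolving the allocation
// range behind an arbitrary device pointer, the kind of lookup needed when
// exporting IPC handles for graph buffers, is a driver-API call.
#include <cuda.h>

bool allocation_range(void* ptr, CUdeviceptr* base, size_t* size) {
  return cuMemGetAddressRange(base, size, (CUdeviceptr)ptr) == CUDA_SUCCESS;
}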
4 changes: 4 additions & 0 deletions sgl-kernel/src/sgl-kernel/__init__.py
@@ -1,9 +1,11 @@
 from sgl_kernel.ops import (
     custom_dispose,
     custom_reduce,
+    get_graph_buffer_ipc_meta,
     init_custom_reduce,
     int8_scaled_mm,
     moe_align_block_size,
+    register_graph_buffers,
     sampling_scaling_penalties,
 )

@@ -14,4 +16,6 @@
     "custom_reduce",
     "int8_scaled_mm",
     "sampling_scaling_penalties",
+    "get_graph_buffer_ipc_meta",
+    "register_graph_buffers",
 ]
10 changes: 8 additions & 2 deletions sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -2,10 +2,14 @@
 
 // trt_reduce
 using fptr_t = int64_t;
-fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
-                      const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
+fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
+                      const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
+                      const std::vector<fptr_t>& barrier_out);
 void dispose(fptr_t _fa);
 void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
+std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
+void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>& handles,
+                            const std::vector<std::vector<int64_t>>& offsets);
 
 // moe_align_block_size
 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
@@ -25,6 +29,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
   m.def("dispose", &dispose, "dispose custom allreduce meta");
   m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
+  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "custom all reduce get graph ipc meta");
+  m.def("register_graph_buffers", &register_graph_buffers, "custom all reduce register graph buffers");
   // moe_align_block_size
   m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
   // sampling_scaling_penalties
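Note: the two new bindings exist so the allreduce can run inside a CUDA graph. Buffers recorded during capture are not known at init time, so after capture each rank exports IPC metadata for them and every rank registers the gathered metadata of all peers. A host-side sketch of the intended call order, reusing only the declarations above (the cross-rank metadata exchange happens in sglang's Python layer and is elided here; variable names are illustrative):

// Sketch, not code from this commit. `fa` comes from init_custom_ar;
// all_handles/all_offsets must end up holding every rank's metadata in rank order.
void finish_graph_registration(fptr_t fa, int64_t my_rank,
                               std::vector<std::vector<int64_t>>& all_handles,
                               std::vector<std::vector<int64_t>>& all_offsets) {
  // Export this rank's IPC handles and offsets for buffers seen during capture.
  auto [handles, offsets] = get_graph_buffer_ipc_meta(fa);
  all_handles[my_rank] = std::move(handles);
  all_offsets[my_rank] = std::move(offsets);
  // ... all-gather all_handles/all_offsets across ranks out-of-band ...
  register_graph_buffers(fa, all_handles, all_offsets);
}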
139 changes: 83 additions & 56 deletions sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
@@ -126,10 +126,10 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const
   __syncthreads();
 }
 
+template <bool start, bool need_fence = false>
 __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
-                                         size_t const world_size, int const tidx, int const bidx, int const grid_size,
-                                         bool start = true, bool need_fence = false) {
-  if (!start) {
+                                         size_t const world_size, int const tidx, int const bidx, int const grid_size) {
+  if constexpr (!start) {
     __syncthreads();
   }
   // After this function, the block of id == bidx of each GPU has reached the barrier
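Note: turning start and need_fence into template parameters means every call site picks its barrier variant at compile time, so if constexpr drops the untaken branch from this hot spin-wait path instead of leaving a runtime test. The two instantiations the kernels below actually use (argument names mirror the function parameters):

// Entry barrier: no leading __syncthreads(), plain volatile flag traffic.
block_barrier<true>(signals, flag, local_rank, world_size, tidx, bidx, grid_size);
// Exit barrier: syncs the block first, then uses release/acquire flag ordering
// so data written before the barrier is visible to peers that pass it.
block_barrier<false, true>(signals, flag, local_rank, world_size, tidx, bidx, grid_size);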
@@ -141,22 +141,16 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
   // Block broadcast its flag (local_rank on emitting dimension) to all receivers
   uint32_t flag_block_offset = world_size + bidx * world_size;
 
-  if (flag % 2 == 1) {
-    flag_block_offset += (grid_size + 1) * world_size;
-  }
+  flag_block_offset += (grid_size + 1) * world_size * (flag % 2);
 
-  if (need_fence) {
-    st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
-  } else {
-    st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
-  }
-  // Blocks check that corresponding blocks on other GPUs have also set the flag
   uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx;
-
-  if (need_fence) {
+  // Blocks check that corresponding blocks on other GPUs have also set the flag
+  if constexpr (need_fence) {
+    st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
     while (ld_flag_acquire(peer_barrier_d) != flag) {
     }
   } else {
+    st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
    while (ld_flag_volatile(peer_barrier_d) != flag) {
    }
  }
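Note: the branchless offset is equivalent to the deleted if, since it adds (grid_size + 1) * world_size exactly when the flag is odd. The parity term double-buffers the flag area, so two consecutive barriers (barrier_flag is bumped between them) never spin on the same slots. A worked example with made-up sizes:

// Worked example, values invented purely for illustration.
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  const size_t world_size = 8, grid_size = 36, bidx = 3, local_rank = 2;
  for (uint32_t flag = 6; flag <= 7; ++flag) {
    size_t off = world_size + bidx * world_size;       // this block's flag row
    off += (grid_size + 1) * world_size * (flag % 2);  // parity picks region 0 or 1
    std::printf("flag=%u -> slot %zu\n", flag, off + local_rank);
  }
  return 0;  // prints slot 34 for the even flag, slot 330 for the odd one
}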
@@ -165,7 +159,7 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
   __syncthreads();
 }
 
-template <typename T, int RANKS_PER_NODE> /* COPY_INPUT = false, PUSH_MODE = false */
+template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
 static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
   // Suppose that two GPUs participate in the AR exchange, and we start four blocks.
   // The message is partitioned into chunks as detailed below:
@@ -193,6 +187,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
 
   int const bidx = blockIdx.x;
   int const tidx = threadIdx.x;
+  int const grid_size = gridDim.x;
 
   // The number of elements packed into one for comms
   static constexpr int NUM_ELTS = 16 / sizeof(T);
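Note: each thread moves one 16-byte packet per step, so NUM_ELTS is 8 for half-precision types and 4 for float; a 512-thread block therefore covers 4096 fp16 elements per loop iteration. A compile-time check of that packing math (illustrative, not from this commit):

#include <cuda_fp16.h>
static_assert(16 / sizeof(__half) == 8, "one int4 access moves 8 fp16 elements");
static_assert(16 / sizeof(float) == 4, "one int4 access moves 4 fp32 elements");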
@@ -201,26 +196,31 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
   using PackedStruct = typename PackedOn16Bytes<T>::Type;
 
   // The source pointers. Distributed round-robin for the different warps.
-  T const* buffers[RANKS_PER_NODE];
 
+  auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
+  T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
   // Start and end offsets of the thread
   size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
   size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
-#pragma unroll
-  for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
-    int rank = (params.local_rank + ii) % RANKS_PER_NODE;
-    buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
-  }
-
-  multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
+  if constexpr (COPY_INPUT) {
+    T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
+    // Copy from local buffer to shareable buffer
+    for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
+      *reinterpret_cast<int4*>(&local_shared_buffer[iter_offset]) =
+          *reinterpret_cast<int4 const*>(&local_input_buffer[iter_offset]);
+    }
+  }
+  // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
+  block_barrier<true>(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
+                      grid_size);
 
   // Each block accumulates the values from the different GPUs on the same node.
   for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
     // Iterate over the different ranks/devices on the node to load the values.
     PackedStruct vals[RANKS_PER_NODE];
 #pragma unroll
     for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
-      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][iter_offset]);
+      vals[ii].packed = *reinterpret_cast<int4 const*>(&((T*)peer_comm_buffer_ptrs[ii])[iter_offset]);
     }
 
     // Sum the values from the different ranks.
@@ -229,16 +229,15 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
 #pragma unroll
     for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
       // Always reduce from rank 0 to ensure stable reduce order.
-      int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
-      sums.packed = add128b(sums, vals[ii]);
+      sums.packed = add128b(sums, vals[rank]);
     }
 
     // Store to the destination buffer.
     *reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
   }
 }

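Note: for intuition, the one-shot kernel above parallelizes the following CPU reference: every rank reads every peer's whole buffer and reduces it locally, and (matching the vals[rank] change) accumulates in ascending rank order, so all ranks get bitwise-identical results. A sketch, not part of this commit:

#include <cstddef>
#include <vector>

// What each rank computes for its own output buffer in the one-shot scheme.
std::vector<float> one_shot_reference(const std::vector<std::vector<float>>& inputs) {
  std::vector<float> out(inputs[0].size(), 0.0f);
  for (size_t rank = 0; rank < inputs.size(); ++rank)  // fixed ascending order
    for (size_t i = 0; i < out.size(); ++i)
      out[i] += inputs[rank][i];
  return out;
}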
-template <typename T, int RANKS_PER_NODE>
+template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true>
 static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
   // Suppose that two GPUs participate in the AR exchange, and we start two blocks.
   // The message is partitioned into chunks as detailed below:
@@ -286,20 +285,24 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
   static constexpr int PACKED_ELTS = 16 / sizeof(T);
   using PackedType = typename PackedOn16Bytes<T>::Type;
 
-  T* local_shared_buffer = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[params.local_rank]);
+  T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
+  auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs;
+  T* local_shared_buffer = reinterpret_cast<T*>(peer_comm_buffer_ptrs[params.local_rank]);
   T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);
 
   size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
   size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank);
 
   T* buffers[RANKS_PER_NODE];
+  T* buffers_unorder[RANKS_PER_NODE];
   int ranks[RANKS_PER_NODE];
 #pragma unroll
   for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
     // A mapping of the ranks to scatter reads as much as possible
     int rank = (params.local_rank + ii) % RANKS_PER_NODE;
     ranks[ii] = rank;
-    buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
+    buffers[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[rank]);
+    buffers_unorder[ii] = reinterpret_cast<T*>(peer_comm_buffer_ptrs[ii]);
   }
 
 #if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
@@ -308,8 +311,22 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
 #endif
 #endif
 
-  block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
-                grid_size);
+  if constexpr (COPY_INPUT) {
+    // Copy all blocks from local buffer to shareable buffer
+    for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
+#pragma unroll
+      for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
+        size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
+        if (offset_rank >= params.elts_total) {
+          continue;
+        }
+        *reinterpret_cast<int4*>(&local_shared_buffer[offset_rank]) =
+            *reinterpret_cast<int4 const*>(&local_input_buffer[offset_rank]);
+      }
+    }
+  }
+  block_barrier<true>(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
+                      grid_size);
 
   // Each block accumulates the values from the different GPUs on the same node.
   for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
@@ -319,7 +336,7 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
     PackedType vals[RANKS_PER_NODE];
 #pragma unroll
     for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
-      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][responsible_block_offset]);
+      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers_unorder[ii][responsible_block_offset]);
     }
 
     // Sum the values from the different ranks.
@@ -328,16 +345,19 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
 #pragma unroll
     for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
       // Always reduce from rank 0 to ensure stable reduce order.
-      int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
-      sums.packed = add128b(sums, vals[ii]);
+      sums.packed = add128b(sums, vals[rank]);
     }
 
-    // Store to the local buffer.
-    *reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
+    // Store to the local buffer or tmp buffer
+    if constexpr (COPY_INPUT) {
+      *reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
+    } else {
+      *reinterpret_cast<int4*>(&params.tmp_result_buffers[params.local_rank][responsible_block_offset]) = sums.packed;
+    }
   }
 
-  block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
-                grid_size, false, true);
+  block_barrier<false, true>(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx,
+                             bidx, grid_size);
 
   // Gather all needed elts from other intra-node ranks
   for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
@@ -348,8 +368,13 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
       if (offset_rank >= params.elts_total) {
         continue;
       }
-
-      *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) = *reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
+      if constexpr (COPY_INPUT) {
+        *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
+            *reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
+      } else {
+        *reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) =
+            *reinterpret_cast<int4*>(&params.tmp_result_buffers[ranks[ii]][offset_rank]);
+      }
     }
   }
 #if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12))
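Note: the two-shot kernel is a reduce-scatter followed by an all-gather. Each rank first reduces only the slice it owns; the gather loop above then pulls every owner's reduced slice, from the peers' shared buffers when COPY_INPUT is set, or from their tmp_result_buffers under graph capture. A CPU sketch of the data movement, not part of this commit (it assumes the element count divides evenly across ranks):

#include <cstddef>
#include <vector>

void two_shot_reference(std::vector<std::vector<float>>& bufs) {
  const size_t n = bufs.size(), slice = bufs[0].size() / n;
  // Shot 1, reduce-scatter: rank r reduces elements [r * slice, (r + 1) * slice).
  std::vector<std::vector<float>> reduced(n, std::vector<float>(slice, 0.0f));
  for (size_t r = 0; r < n; ++r)
    for (size_t src = 0; src < n; ++src)  // ascending order keeps sums stable
      for (size_t i = 0; i < slice; ++i)
        reduced[r][i] += bufs[src][r * slice + i];
  // Shot 2, all-gather: every rank copies each owner's reduced slice back.
  for (size_t r = 0; r < n; ++r)
    for (size_t owner = 0; owner < n; ++owner)
      for (size_t i = 0; i < slice; ++i)
        bufs[r][owner * slice + i] = reduced[owner][i];
}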
@@ -417,48 +442,50 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& param, size_t elts_per_thread) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template <typename T, int RANKS_PER_NODE>
+template <typename T, int RANKS_PER_NODE, bool COPY_INPUT>
 void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block,
                        cudaStream_t stream) {
   switch (algo) {
     case AllReduceStrategyType::ONESHOT: {
-      oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+      oneShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
       break;
     }
     case AllReduceStrategyType::TWOSHOT: {
-      twoShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
+      twoShotAllReduceKernel<T, RANKS_PER_NODE, COPY_INPUT><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
       break;
     }
   }
 }
 
-template <typename T>
-void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
-  void* buffer = reinterpret_cast<void*>(param.peer_comm_buffer_ptrs[param.rank]);
-  void* local_inp_buffer = param.local_input_buffer_ptr;
-  CHECK_CUDA_SUCCESS(
-      cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream));
-
-  CHECK_CUDA_SUCCESS(cudaGetLastError());
-
+template <typename T, bool COPY_INPUT>
+void dispatchARKernelsCopyInput(AllReduceStrategyType strat, AllReduceParams& param, cudaStream_t stream) {
   size_t elts_per_thread = 16 / sizeof(T);
   auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
   switch (param.ranks_per_node) {
     case 2:
-      dispatchARKernels<T, 2>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 2, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
       break;
     case 4:
-      dispatchARKernels<T, 4>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 4, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
       break;
     case 6:
-      dispatchARKernels<T, 6>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 6, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
       break;
    case 8:
-      dispatchARKernels<T, 8>(strat, param, blocks_per_grid, threads_per_block, stream);
+      dispatchARKernels<T, 8, COPY_INPUT>(strat, param, blocks_per_grid, threads_per_block, stream);
       break;
     default:
      break;
  }
}
 
+template <typename T>
+void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
+  if (param.is_capturing) {
+    dispatchARKernelsCopyInput<T, false>(strat, param, stream);
+  } else {
+    dispatchARKernelsCopyInput<T, true>(strat, param, stream);
+  }
+  CHECK_CUDA_SUCCESS(cudaGetLastError());
+}
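Note: this host-side split is the core of the optimization. Outside graph capture, COPY_INPUT=true lets the kernels stage inputs into the shared buffers themselves, replacing the cudaMemcpyAsync deleted above; during capture, COPY_INPUT=false makes them read the registered peer buffers in place and park partial results in tmp_result_buffers, so no capture-hostile copy is recorded. One plausible way a caller could set is_capturing (a sketch; sglang actually drives this from Python, and params, strat, stream are assumed to be in scope):

// Sketch, not from this commit: take the no-copy path whenever the stream
// is being captured into a CUDA graph.
cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status));
params.is_capturing = (status == cudaStreamCaptureStatusActive);
invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);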

8 changes: 7 additions & 1 deletion sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
@@ -36,6 +36,10 @@ enum class AllReduceStrategyType : int8_t {
   AUTO = 3,
 };
 
+struct RankData {
+  void* ptrs[MAX_RANKS_PER_NODE];
+};
+
 struct AllReduceParams {
   size_t elts_size;
   size_t elts_total;
@@ -46,9 +50,11 @@ struct AllReduceParams {
   uint32_t barrier_flag;
   uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
   uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
-  void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
+  uint32_t* tmp_result_buffers[MAX_RANKS_PER_NODE];
+  RankData* peer_comm_buffer_ptrs;
   void* local_input_buffer_ptr;
   void* local_output_buffer_ptr;
+  bool is_capturing;
 };
 
 inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
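Note: peer_comm_buffer_ptrs changing from an inline array to a RankData* adds one level of indirection that the kernels resolve at run time (params.peer_comm_buffer_ptrs->ptrs). The table presumably lives in the rank_data tensor passed to init_custom_ar, so its contents can be rewritten later, for example pointed at graph-registered buffers, without changing the kernel arguments baked into a captured graph. A hedged sketch of that setup, with hypothetical host code:

// Sketch, not from this commit: the device-resident RankData can be rewritten
// between graph replays; the captured kernel only stores the RankData* and
// dereferences ->ptrs when it runs.
RankData host_table = {};  // to be filled with IPC-opened peer buffer pointers
RankData* device_table = nullptr;
CHECK_CUDA_SUCCESS(cudaMalloc(&device_table, sizeof(RankData)));
CHECK_CUDA_SUCCESS(cudaMemcpy(device_table, &host_table, sizeof(RankData), cudaMemcpyHostToDevice));

AllReduceParams params;
params.peer_comm_buffer_ptrs = device_table;  // kernels read ->ptrs[] indirectly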
(Diffs for the remaining three changed files are not rendered in this view.)
