Skip to content

Commit 72621e7

Browse files
Binyang2014Copilot
andauthored
add nBlocks check for allreduce_allpair_packet algo (#807)
- Fix the correctness issue for allreduce_allpair_packet algo. Make sure no overwrite for input buffer. Use same tb for send/reduce/write-back. - Check if nBlocks/nthreads validate for packet algorithm. - Add more logs - Modify flag update logic, make it work for the case: nthreadPerNBlock < nflags --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent c107131 commit 72621e7

3 files changed

Lines changed: 42 additions & 25 deletions

File tree

src/ext/collectives/allreduce/allreduce_allpair_packet.cu

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include "allreduce/allreduce_allpair_packet.hpp"
88
#include "allreduce/common.hpp"
99
#include "collective_utils.hpp"
10-
#include "debug.h"
10+
#include "logger.hpp"
1111

1212
namespace mscclpp {
1313
namespace collective {
@@ -27,22 +27,30 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand
2727
size_t scratchBaseOffset = (flag % numScratchBuff) ? (scratchBufferSize / numScratchBuff) : 0;
2828
size_t channelScratchOffset = scratchBaseOffset;
2929

30-
const int nBlocksPerPeer = gridDim.x / nPeers;
31-
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
32-
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
33-
const int peerIdx = blockIdx.x / nBlocksPerPeer;
34-
size_t srcOffset = channelDataOffset;
30+
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
3531
size_t scratchOffset = channelScratchOffset + rank * nelems * sizeof(LL8Packet);
3632
void* scratchBuff = (void*)((char*)scratch + channelScratchOffset);
3733
uint32_t* src = (uint32_t*)((char*)buff);
3834
uint32_t* dst = (uint32_t*)((char*)resultBuff);
3935

40-
// step 1: write data to each peer's scratch buffer
41-
memoryChannels[peerIdx].putPackets<LL8Packet>(scratchOffset, srcOffset, nelems * sizeof(uint32_t), tid,
42-
blockDim.x * nBlocksPerPeer, flag);
36+
const int warpId = threadIdx.x / WARP_SIZE;
37+
const int lane = threadIdx.x % WARP_SIZE;
38+
const int nWarpsPerBlock = blockDim.x / WARP_SIZE;
39+
// Assign one warp in every block to each peer. Each peer warp sends the
40+
// same block-owned stripe, so nBlocks only partitions data and no longer
41+
// needs to be grouped by nPeers.
42+
if (warpId < nPeers) {
43+
memoryChannels[warpId].putPackets<LL8Packet>(scratchOffset, channelDataOffset, nelems * sizeof(uint32_t),
44+
lane + blockIdx.x * WARP_SIZE, gridDim.x * WARP_SIZE, flag);
45+
}
46+
// Safe for in-place allreduce: all peer warps must finish reading src for
47+
// this block's stripe before any warp writes reduced data back to dst/src.
48+
__syncthreads();
4349

44-
// step 2: Reduce Data
45-
for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nelems; idx += blockDim.x * gridDim.x) {
50+
// Split the same sent stream across all warps for reduction. warpId selects
51+
// which strided subset to reduce while lane preserves coalesced packet reads.
52+
for (size_t idx = lane + blockIdx.x * WARP_SIZE + warpId * WARP_SIZE * gridDim.x; idx < nelems;
53+
idx += nWarpsPerBlock * WARP_SIZE * gridDim.x) {
4654
uint32_t data = src[idx];
4755
using AccRaw = std::conditional_t<std::is_same_v<T, AccumT>, uint32_t,
4856
mscclpp::VectorType<AccumT, sizeof(uint32_t) / sizeof(T)>>;
@@ -59,14 +67,14 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand
5967
if (threadIdx.x == 0) {
6068
((uint32_t*)flags)[blockIdx.x] = flag + 1;
6169
}
62-
if (blockIdx.x == 0 && threadIdx.x >= gridDim.x && threadIdx.x < flagSize / sizeof(uint32_t)) {
63-
((uint32_t*)flags)[threadIdx.x] = flag + 1;
70+
if (tid >= gridDim.x && tid < flagSize / sizeof(uint32_t)) {
71+
((uint32_t*)flags)[tid] = flag + 1;
6472
}
6573
}
6674

6775
inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize, int worldSize) {
6876
if (inputSize < worldSize * sizeof(int)) {
69-
return {worldSize - 1, 32};
77+
return {worldSize - 1, (worldSize - 1) * WARP_SIZE};
7078
}
7179
return {(worldSize - 1) * 4, 512};
7280
}
@@ -80,11 +88,6 @@ struct AllpairAdapter {
8088
int nThreadsPerBlock = 0) {
8189
using ChannelType = DeviceHandle<MemoryChannel>;
8290
const size_t nelems = inputSize / sizeof(T);
83-
// Round nBlocks to multiple of nPeers so every block maps to a valid peer.
84-
const int nPeers = worldSize - 1;
85-
if (nPeers > 0) {
86-
nBlocks = (nBlocks / nPeers) * nPeers;
87-
}
8891
allreduceAllPairs<OpType, T, AccumT><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
8992
(T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank,
9093
nRanksPerNode, worldSize, nelems, numScratchBuff, flags, flagSize);
@@ -110,9 +113,17 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr<voi
110113
if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
111114
blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, algoCtx->workSize);
112115
}
113-
// nBlocks must be at least nPeers for allpair — each block maps to one peer.
116+
if (blockAndThreadNum.first > maxBlockNum_) {
117+
WARN(ALGO, "Requested block number ", blockAndThreadNum.first, " exceeds the maximum supported block number ",
118+
maxBlockNum_, ".");
119+
return CommResult::CommInvalidArgument;
120+
}
114121
const int nPeers = algoCtx->nRanksPerNode - 1;
115-
if (nPeers > 0 && blockAndThreadNum.first < nPeers) {
122+
// The kernel maps peer sends by warpId, so every peer needs a full warp.
123+
if (blockAndThreadNum.second % WARP_SIZE != 0 || blockAndThreadNum.second / WARP_SIZE < nPeers) {
124+
WARN(ALGO,
125+
"Allpair packet requires at least one full warp per peer, but got nThreadsPerBlock=", blockAndThreadNum.second,
126+
" and nPeers=", nPeers, ".");
116127
return CommResult::CommInvalidArgument;
117128
}
118129
size_t sendBytes;
@@ -122,7 +133,8 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr<voi
122133

123134
AllreduceFunc allreduce = dispatch<AllpairAdapter>(op, dtype, accumDtype);
124135
if (!allreduce) {
125-
WARN("Unsupported operation or data type for allreduce: op=%d, dtype=%d", op, static_cast<int>(dtype));
136+
WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
137+
", dtype=", static_cast<int>(dtype));
126138
return CommResult::CommInvalidArgument;
127139
}
128140
cudaError_t error =
@@ -131,7 +143,7 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr<voi
131143
algoCtx->workSize, inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_,
132144
this->nSegmentsForScratchBuffer_, blockAndThreadNum.first, blockAndThreadNum.second);
133145
if (error != cudaSuccess) {
134-
WARN("AllreducePacket failed with error: %s", cudaGetErrorString(error));
146+
WARN(ALGO, "AllreducePacket failed with error: ", cudaGetErrorString(error));
135147
return CommResult::CommUnhandledCudaError;
136148
}
137149
return CommResult::CommSuccess;
@@ -189,4 +201,4 @@ std::shared_ptr<Algorithm> AllreduceAllpairPacket::build() {
189201
});
190202
}
191203
} // namespace collective
192-
} // namespace mscclpp
204+
} // namespace mscclpp

src/ext/collectives/allreduce/allreduce_packet.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,11 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr<void> ctx_
235235
if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
236236
blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, ctx->workSize, ctx->nRanksPerNode, dtype);
237237
}
238+
if (blockAndThreadNum.first > maxBlockNum_) {
239+
WARN(ALGO, "Requested block number ", blockAndThreadNum.first, " exceeds the maximum supported block number ",
240+
maxBlockNum_, ".");
241+
return CommResult::CommInvalidArgument;
242+
}
238243

239244
size_t sendBytes;
240245
CUdeviceptr sendBasePtr;

src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class AllreduceAllpairPacket : public AlgorithmBuilder {
2929
void* scratchBuffer_;
3030
size_t scratchBufferSize_;
3131
const int nSegmentsForScratchBuffer_ = 2;
32-
const int maxBlockNum_ = 28;
32+
const int maxBlockNum_ = 64;
3333
std::vector<Connection> conns_;
3434
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores_;
3535
std::vector<RegisteredMemory> registeredMemories_;

0 commit comments

Comments
 (0)