77#include " allreduce/allreduce_allpair_packet.hpp"
88#include " allreduce/common.hpp"
99#include " collective_utils.hpp"
10- #include " debug.h "
10+ #include " logger.hpp "
1111
1212namespace mscclpp {
1313namespace collective {
@@ -27,22 +27,30 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand
2727 size_t scratchBaseOffset = (flag % numScratchBuff) ? (scratchBufferSize / numScratchBuff) : 0 ;
2828 size_t channelScratchOffset = scratchBaseOffset;
2929
30- const int nBlocksPerPeer = gridDim .x / nPeers;
31- const int localBlockIdx = blockIdx .x % nBlocksPerPeer;
32- const int tid = threadIdx .x + localBlockIdx * blockDim .x ;
33- const int peerIdx = blockIdx .x / nBlocksPerPeer;
34- size_t srcOffset = channelDataOffset;
30+ const int tid = threadIdx .x + blockIdx .x * blockDim .x ;
3531 size_t scratchOffset = channelScratchOffset + rank * nelems * sizeof (LL8Packet);
3632 void * scratchBuff = (void *)((char *)scratch + channelScratchOffset);
3733 uint32_t * src = (uint32_t *)((char *)buff);
3834 uint32_t * dst = (uint32_t *)((char *)resultBuff);
3935
40- // step 1: write data to each peer's scratch buffer
41- memoryChannels[peerIdx].putPackets <LL8Packet>(scratchOffset, srcOffset, nelems * sizeof (uint32_t ), tid,
42- blockDim .x * nBlocksPerPeer, flag);
36+ const int warpId = threadIdx .x / WARP_SIZE;
37+ const int lane = threadIdx .x % WARP_SIZE;
38+ const int nWarpsPerBlock = blockDim .x / WARP_SIZE;
39+ // Assign one warp in every block to each peer. Each peer warp sends the
40+ // same block-owned stripe, so nBlocks only partitions data and no longer
41+ // needs to be grouped by nPeers.
42+ if (warpId < nPeers) {
43+ memoryChannels[warpId].putPackets <LL8Packet>(scratchOffset, channelDataOffset, nelems * sizeof (uint32_t ),
44+ lane + blockIdx .x * WARP_SIZE, gridDim .x * WARP_SIZE, flag);
45+ }
46+ // Safe for in-place allreduce: all peer warps must finish reading src for
47+ // this block's stripe before any warp writes reduced data back to dst/src.
48+ __syncthreads ();
4349
44- // step 2: Reduce Data
45- for (size_t idx = threadIdx .x + blockIdx .x * blockDim .x ; idx < nelems; idx += blockDim .x * gridDim .x ) {
50+ // Split the same sent stream across all warps for reduction. warpId selects
51+ // which strided subset to reduce while lane preserves coalesced packet reads.
52+ for (size_t idx = lane + blockIdx .x * WARP_SIZE + warpId * WARP_SIZE * gridDim .x ; idx < nelems;
53+ idx += nWarpsPerBlock * WARP_SIZE * gridDim .x ) {
4654 uint32_t data = src[idx];
4755 using AccRaw = std::conditional_t <std::is_same_v<T, AccumT>, uint32_t ,
4856 mscclpp::VectorType<AccumT, sizeof (uint32_t ) / sizeof (T)>>;
@@ -59,14 +67,14 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand
5967 if (threadIdx .x == 0 ) {
6068 ((uint32_t *)flags)[blockIdx .x ] = flag + 1 ;
6169 }
62- if (blockIdx . x == 0 && threadIdx . x >= gridDim .x && threadIdx . x < flagSize / sizeof (uint32_t )) {
63- ((uint32_t *)flags)[threadIdx . x ] = flag + 1 ;
70+ if (tid >= gridDim .x && tid < flagSize / sizeof (uint32_t )) {
71+ ((uint32_t *)flags)[tid ] = flag + 1 ;
6472 }
6573}
6674
6775inline std::pair<int , int > getDefaultBlockNumAndThreadNum (size_t inputSize, int worldSize) {
6876 if (inputSize < worldSize * sizeof (int )) {
69- return {worldSize - 1 , 32 };
77+ return {worldSize - 1 , (worldSize - 1 ) * WARP_SIZE };
7078 }
7179 return {(worldSize - 1 ) * 4 , 512 };
7280}
@@ -80,11 +88,6 @@ struct AllpairAdapter {
8088 int nThreadsPerBlock = 0 ) {
8189 using ChannelType = DeviceHandle<MemoryChannel>;
8290 const size_t nelems = inputSize / sizeof (T);
83- // Round nBlocks to multiple of nPeers so every block maps to a valid peer.
84- const int nPeers = worldSize - 1 ;
85- if (nPeers > 0 ) {
86- nBlocks = (nBlocks / nPeers) * nPeers;
87- }
8891 allreduceAllPairs<OpType, T, AccumT><<<nBlocks, nThreadsPerBlock, 0 , stream>>> (
8992 (T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank,
9093 nRanksPerNode, worldSize, nelems, numScratchBuff, flags, flagSize);
@@ -110,9 +113,17 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr<voi
110113 if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0 ) {
111114 blockAndThreadNum = getDefaultBlockNumAndThreadNum (inputSize, algoCtx->workSize );
112115 }
113- // nBlocks must be at least nPeers for allpair — each block maps to one peer.
116+ if (blockAndThreadNum.first > maxBlockNum_) {
117+ WARN (ALGO, " Requested block number " , blockAndThreadNum.first , " exceeds the maximum supported block number " ,
118+ maxBlockNum_, " ." );
119+ return CommResult::CommInvalidArgument;
120+ }
114121 const int nPeers = algoCtx->nRanksPerNode - 1 ;
115- if (nPeers > 0 && blockAndThreadNum.first < nPeers) {
122+ // The kernel maps peer sends by warpId, so every peer needs a full warp.
123+ if (blockAndThreadNum.second % WARP_SIZE != 0 || blockAndThreadNum.second / WARP_SIZE < nPeers) {
124+ WARN (ALGO,
125+ " Allpair packet requires at least one full warp per peer, but got nThreadsPerBlock=" , blockAndThreadNum.second ,
126+ " and nPeers=" , nPeers, " ." );
116127 return CommResult::CommInvalidArgument;
117128 }
118129 size_t sendBytes;
@@ -122,7 +133,8 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr<voi
122133
123134 AllreduceFunc allreduce = dispatch<AllpairAdapter>(op, dtype, accumDtype);
124135 if (!allreduce) {
125- WARN (" Unsupported operation or data type for allreduce: op=%d, dtype=%d" , op, static_cast <int >(dtype));
136+ WARN (ALGO, " Unsupported operation or data type for allreduce: op=" , static_cast <int >(op),
137+ " , dtype=" , static_cast <int >(dtype));
126138 return CommResult::CommInvalidArgument;
127139 }
128140 cudaError_t error =
@@ -131,7 +143,7 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr<voi
131143 algoCtx->workSize , inputSize, stream, (void *)flagBuffer_, (uint32_t )flagBufferSize_,
132144 this ->nSegmentsForScratchBuffer_ , blockAndThreadNum.first , blockAndThreadNum.second );
133145 if (error != cudaSuccess) {
134- WARN (" AllreducePacket failed with error: %s " , cudaGetErrorString (error));
146+ WARN (ALGO, " AllreducePacket failed with error: " , cudaGetErrorString (error));
135147 return CommResult::CommUnhandledCudaError;
136148 }
137149 return CommResult::CommSuccess;
@@ -189,4 +201,4 @@ std::shared_ptr<Algorithm> AllreduceAllpairPacket::build() {
189201 });
190202}
191203} // namespace collective
192- } // namespace mscclpp
204+ } // namespace mscclpp
0 commit comments