88namespace mscclpp {
99namespace collective {
1010
11+ namespace {
12+ constexpr int kMaxBlocks = 56 ;
13+ constexpr int kMaxThreadsPerBlock = 1024 ;
14+ } // namespace
15+
1116template <bool IsOutOfPlace>
1217__global__ void __launch_bounds__ (1024 , 1 )
1318 allgatherFullmesh(void * buff, void * scratch, void * resultBuff, DeviceHandle<MemoryChannel>* memoryChannels,
@@ -116,12 +121,19 @@ CommResult AllgatherFullmesh::allgatherKernelFunc(const std::shared_ptr<void> ct
116121 int rank = ctx->rank ;
117122 const size_t nElem = inputSize / sizeof (int );
118123 std::pair<int , int > numBlocksAndThreads = {nBlocks, nThreadsPerBlock};
119- if (numBlocksAndThreads.first > 56 ) {
120- WARN (" AllgatherFullmesh: number of blocks exceeds maximum supported blocks, which is 56" );
121- return mscclpp::CommResult::CommInvalidArgument;
122- }
123124 if (numBlocksAndThreads.first == 0 || numBlocksAndThreads.second == 0 ) {
124- numBlocksAndThreads = {56 , 1024 };
125+ numBlocksAndThreads = {kMaxBlocks , kMaxThreadsPerBlock };
126+ }
127+ if (numBlocksAndThreads.first > kMaxBlocks || numBlocksAndThreads.second > kMaxThreadsPerBlock ) {
128+ WARN (
129+ " AllgatherFullmesh: number of blocks must be no more than %d and threads per block must be no more than %d; "
130+ " got nBlocks=%d, nThreadsPerBlock=%d" ,
131+ kMaxBlocks , kMaxThreadsPerBlock , numBlocksAndThreads.first , numBlocksAndThreads.second );
132+ return CommResult::CommInvalidArgument;
133+ }
134+ if (numBlocksAndThreads.second % WARP_SIZE != 0 ) {
135+ WARN (" AllgatherFullmesh: threads per block must be a multiple of warp size %d" , WARP_SIZE);
136+ return CommResult::CommInvalidArgument;
125137 }
126138 if ((char *)input == (char *)output + rank * inputSize) {
127139 allgatherFullmesh<false ><<<numBlocksAndThreads.first, numBlocksAndThreads.second, 0 , stream>>> (
@@ -142,15 +154,13 @@ CommResult AllgatherFullmesh::allgatherKernelFunc(const std::shared_ptr<void> ct
142154
143155std::shared_ptr<void > AllgatherFullmesh::initAllgatherContext (std::shared_ptr<Communicator> comm, const void * input,
144156 void *, size_t inputSize, DataType) {
145- constexpr int nChannelsPerConnection = 56 ;
146-
147157 auto ctx = std::make_shared<AlgorithmCtx>();
148158 ctx->rank = comm->bootstrap ()->getRank ();
149159 ctx->workSize = comm->bootstrap ()->getNranks ();
150160 ctx->nRanksPerNode = comm->bootstrap ()->getNranksPerNode ();
151161
152162 // setup semaphores
153- ctx->memorySemaphores = setupMemorySemaphores (comm, this ->conns_ , nChannelsPerConnection );
163+ ctx->memorySemaphores = setupMemorySemaphores (comm, this ->conns_ , kMaxBlocks );
154164
155165 // register the memory for the broadcast operation
156166 RegisteredMemory localMemory = comm->registerMemory ((void *)input, inputSize, Transport::CudaIpc);
@@ -159,7 +169,7 @@ std::shared_ptr<void> AllgatherFullmesh::initAllgatherContext(std::shared_ptr<Co
159169
160170 // setup channels
161171 ctx->memoryChannels =
162- setupMemoryChannels (this ->conns_ , ctx->memorySemaphores , remoteMemories, localMemory, nChannelsPerConnection );
172+ setupMemoryChannels (this ->conns_ , ctx->memorySemaphores , remoteMemories, localMemory, kMaxBlocks );
163173 ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles (ctx->memoryChannels );
164174
165175 // keep registered memories reference
@@ -196,4 +206,4 @@ std::shared_ptr<Algorithm> AllgatherFullmesh::build() {
196206 });
197207}
198208} // namespace collective
199- } // namespace mscclpp
209+ } // namespace mscclpp
0 commit comments