Skip to content

Commit 08ee18b

Browse files
authored
Add check to filter invalid nblock/nthread candidates (#811)
Add check for invalid nblock/nthread candidate
1 parent 9e177b3 commit 08ee18b

3 files changed

Lines changed: 46 additions & 14 deletions

File tree

examples/torch-integration/customized_comm_with_tuning.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -70,12 +70,12 @@ class CustomizedComm:
7070
_TUNE_N_WARMUP = 5
7171
_TUNE_N_GRAPH_LAUNCHES = 10
7272
_TUNE_N_OPS_PER_GRAPH = 100
73-
_CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 64, 128]
73+
_CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 56, 64, 128]
7474
_CANDIDATE_NTHREADS = [512, 768, 1024]
7575
_NBLOCKS_LIMIT = {
7676
"default_allreduce_nvls_packet": 16,
7777
"default_allreduce_packet": 56,
78-
"default_allreduce_allpair_packet": 56,
78+
"default_allreduce_allpair_packet": 64,
7979
"default_allreduce_fullmesh": 64,
8080
"default_allgather_fullmesh2": 32,
8181
}

src/ext/collectives/allgather/allgather_fullmesh.cu

Lines changed: 20 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,11 @@
88
namespace mscclpp {
99
namespace collective {
1010

11+
namespace {
12+
constexpr int kMaxBlocks = 56;
13+
constexpr int kMaxThreadsPerBlock = 1024;
14+
} // namespace
15+
1116
template <bool IsOutOfPlace>
1217
__global__ void __launch_bounds__(1024, 1)
1318
allgatherFullmesh(void* buff, void* scratch, void* resultBuff, DeviceHandle<MemoryChannel>* memoryChannels,
@@ -116,12 +121,19 @@ CommResult AllgatherFullmesh::allgatherKernelFunc(const std::shared_ptr<void> ct
116121
int rank = ctx->rank;
117122
const size_t nElem = inputSize / sizeof(int);
118123
std::pair<int, int> numBlocksAndThreads = {nBlocks, nThreadsPerBlock};
119-
if (numBlocksAndThreads.first > 56) {
120-
WARN("AllgatherFullmesh: number of blocks exceeds maximum supported blocks, which is 56");
121-
return mscclpp::CommResult::CommInvalidArgument;
122-
}
123124
if (numBlocksAndThreads.first == 0 || numBlocksAndThreads.second == 0) {
124-
numBlocksAndThreads = {56, 1024};
125+
numBlocksAndThreads = {kMaxBlocks, kMaxThreadsPerBlock};
126+
}
127+
if (numBlocksAndThreads.first > kMaxBlocks || numBlocksAndThreads.second > kMaxThreadsPerBlock) {
128+
WARN(
129+
"AllgatherFullmesh: number of blocks must be no more than %d and threads per block must be no more than %d; "
130+
"got nBlocks=%d, nThreadsPerBlock=%d",
131+
kMaxBlocks, kMaxThreadsPerBlock, numBlocksAndThreads.first, numBlocksAndThreads.second);
132+
return CommResult::CommInvalidArgument;
133+
}
134+
if (numBlocksAndThreads.second % WARP_SIZE != 0) {
135+
WARN("AllgatherFullmesh: threads per block must be a multiple of warp size %d", WARP_SIZE);
136+
return CommResult::CommInvalidArgument;
125137
}
126138
if ((char*)input == (char*)output + rank * inputSize) {
127139
allgatherFullmesh<false><<<numBlocksAndThreads.first, numBlocksAndThreads.second, 0, stream>>>(
@@ -142,15 +154,13 @@ CommResult AllgatherFullmesh::allgatherKernelFunc(const std::shared_ptr<void> ct
142154

143155
std::shared_ptr<void> AllgatherFullmesh::initAllgatherContext(std::shared_ptr<Communicator> comm, const void* input,
144156
void*, size_t inputSize, DataType) {
145-
constexpr int nChannelsPerConnection = 56;
146-
147157
auto ctx = std::make_shared<AlgorithmCtx>();
148158
ctx->rank = comm->bootstrap()->getRank();
149159
ctx->workSize = comm->bootstrap()->getNranks();
150160
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
151161

152162
// setup semaphores
153-
ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection);
163+
ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, kMaxBlocks);
154164

155165
// register the memory for the broadcast operation
156166
RegisteredMemory localMemory = comm->registerMemory((void*)input, inputSize, Transport::CudaIpc);
@@ -159,7 +169,7 @@ std::shared_ptr<void> AllgatherFullmesh::initAllgatherContext(std::shared_ptr<Co
159169

160170
// setup channels
161171
ctx->memoryChannels =
162-
setupMemoryChannels(this->conns_, ctx->memorySemaphores, remoteMemories, localMemory, nChannelsPerConnection);
172+
setupMemoryChannels(this->conns_, ctx->memorySemaphores, remoteMemories, localMemory, kMaxBlocks);
163173
ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels);
164174

165175
// keep registered memories reference
@@ -196,4 +206,4 @@ std::shared_ptr<Algorithm> AllgatherFullmesh::build() {
196206
});
197207
}
198208
} // namespace collective
199-
} // namespace mscclpp
209+
} // namespace mscclpp

src/ext/collectives/allgather/allgather_fullmesh_2.cu

Lines changed: 24 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,11 @@ __global__ void __launch_bounds__(1024, 1)
1818
const size_t lid = tid % WARP_SIZE;
1919
const size_t wid = tid / WARP_SIZE;
2020

21-
const size_t nThread = blockDim.x * gridDim.x;
21+
// Round down to multiple of warp size
22+
const size_t nThread = (blockDim.x * gridDim.x) / WARP_SIZE * WARP_SIZE;
23+
if (tid >= nThread) {
24+
return;
25+
}
2226
const size_t nWarp = nThread / WARP_SIZE;
2327
const size_t nPeer = nRanksPerNode - 1;
2428
const size_t chanOffset = nPeer * blockIdx.x;
@@ -135,6 +139,24 @@ CommResult AllgatherFullmesh2::allgatherKernelFunc(const std::shared_ptr<void> c
135139
numBlocksAndThreads.first = 35;
136140
}
137141
}
142+
const int nPeer = ctx->nRanksPerNode - 1;
143+
const int nWarp = numBlocksAndThreads.first * numBlocksAndThreads.second / WARP_SIZE;
144+
if (numBlocksAndThreads.first > nChannelsPerConnection_ || numBlocksAndThreads.first <= 0 ||
145+
numBlocksAndThreads.second <= 0) {
146+
WARN(
147+
"AllgatherFullmesh2: number of blocks must be a positive multiple of peer count and no more than %d, threads "
148+
"per block must be positive; got nBlocks=%d, nThreadsPerBlock=%d, nPeers=%d",
149+
nChannelsPerConnection_, numBlocksAndThreads.first, numBlocksAndThreads.second, nPeer);
150+
return CommResult::CommInvalidArgument;
151+
}
152+
if (nWarp < nPeer) {
153+
WARN(
154+
"AllgatherFullmesh2: total number of warps must be no less than peer count; got nBlocks=%d, "
155+
"nThreadsPerBlock=%d, "
156+
"nPeers=%d",
157+
numBlocksAndThreads.first, numBlocksAndThreads.second, nPeer);
158+
return CommResult::CommInvalidArgument;
159+
}
138160

139161
size_t channelOutOffset = *static_cast<size_t*>(ctx->extras["channel_out_offset"].get());
140162
if ((char*)input == (char*)output + rank * inputSize) {
@@ -226,4 +248,4 @@ std::shared_ptr<Algorithm> AllgatherFullmesh2::build() {
226248
}
227249

228250
} // namespace collective
229-
} // namespace mscclpp
251+
} // namespace mscclpp

0 commit comments

Comments
 (0)