Skip to content

Commit 379d0e5

Browse files
authored
Fix for allgather_fullmesh algo (#813)
1 parent 08ee18b commit 379d0e5

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

src/ext/collectives/allgather/allgather_fullmesh_2.cu

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,14 +17,14 @@ __global__ void __launch_bounds__(1024, 1)
17 17     const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
18 18     const size_t lid = tid % WARP_SIZE;
19 19     const size_t wid = tid / WARP_SIZE;
   20 +   const size_t nPeer = nRanksPerNode - 1;
20 21
21    -   // Round down to multiple of warp size
22    -   const size_t nThread = (blockDim.x * gridDim.x) / WARP_SIZE * WARP_SIZE;
   22 +   // Round down to multiple of peer count.
   23 +   const size_t nThread = (blockDim.x * gridDim.x) / WARP_SIZE / nPeer * nPeer * WARP_SIZE;
23 24     if (tid >= nThread) {
24 25       return;
25 26     }
26 27     const size_t nWarp = nThread / WARP_SIZE;
27    -   const size_t nPeer = nRanksPerNode - 1;
28 28     const size_t chanOffset = nPeer * blockIdx.x;
29 29     auto memChans = memoryChannels + chanOffset;
30 30

0 commit comments

Comments
 (0)