Skip to content
Open
11 changes: 10 additions & 1 deletion apps/nccl/src/allreduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,15 @@ ncclResult_t AllreduceNvlsWithCopy::allreduceKernelFunc(const std::shared_ptr<ms
const void* input, void* output, size_t count,
mscclpp::DataType dtype, cudaStream_t stream,
std::unordered_map<std::string, std::shared_ptr<void>>&) {
// NVLS kernels require the per-rank byte size (total size / nRanksPerNode) to be 16-byte aligned
size_t elemSize = getDataTypeSize(dtype);
size_t nRanksPerNode = ctx->nRanksPerNode;
size_t alignmentBytes = 16;

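// Element-count granularity at which each rank's share is a multiple of 16 bytes; count is rounded up to it.
// NOTE(review): alignedCount can exceed count -- confirm input/output hold at least alignedCount elements.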
size_t elementsPerRankAlign = (alignmentBytes * nRanksPerNode) / elemSize;
size_t alignedCount = ((count + elementsPerRankAlign - 1) / elementsPerRankAlign) * elementsPerRankAlign;

AllreduceFunc allreduce = dispatch<NvlsWithCopyAdapter>(ncclSum, dtype);
if (!allreduce) {
WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));
Expand All @@ -454,7 +463,7 @@ ncclResult_t AllreduceNvlsWithCopy::allreduceKernelFunc(const std::shared_ptr<ms
cudaError_t error =
allreduce(input, this->scratchBuffer_.get(), output, this->memoryChannelsDeviceHandle_.get(), nullptr,
ctx->switchChannelDeviceHandles.get(), nullptr, 0, 0, this->scratchBufferSize_, ctx->rank,
ctx->nRanksPerNode, ctx->workSize, count, stream, nullptr, nullptr, nullptr, 0);
ctx->nRanksPerNode, ctx->workSize, alignedCount, stream, nullptr, nullptr, nullptr, 0);
if (error != cudaSuccess) {
WARN("AllreduceNvlsWithCopy failed with error: %s", cudaGetErrorString(error));
return ncclUnhandledCudaError;
Expand Down
4 changes: 3 additions & 1 deletion test/torch/correctness_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def _init_dist():
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ.get("LOCAL_RANK", os.environ["RANK"]))
dist.init_process_group(backend=backend, rank=rank, world_size=world_size, device_id=local_rank)
dist.init_process_group(
backend=backend, rank=rank, world_size=world_size, device_id=torch.device(f"cuda:{local_rank}")
)
torch.cuda.set_device(local_rank)


Expand Down