
Commit 95a3534

Author: Yifu Wang (committed)
[IntraNodeComm] fix an issue where input check fails when running all-reduce on sub groups
ghstack-source-id: 218c718
Pull Request resolved: pytorch#130492
1 parent 46c5266 commit 95a3534
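
Note (not part of the commit): the previous checkInput validated that the input tensor's CUDA device index equals rank_, which only holds when the communicator's ranks map one-to-one onto devices starting at device 0. For a sub group, say global ranks 4-7 on an 8-GPU node, the rank within the group is 0-3 while the inputs live on devices 4-7, so the check always failed. The fix records at::cuda::current_device() as deviceIdx_ during rendezvous and validates the input against that. A minimal C++ sketch of the mismatch, with the sub-group layout assumed purely for illustration:

// Hedged illustration (assumed group layout, not from the commit): group rank
// vs. device index on a sub group of an 8-GPU node.
#include <cassert>
#include <vector>

int main() {
  // Sub group formed from global ranks 4..7; each member's input tensor
  // lives on CUDA device 4..7, but its rank within the group is 0..3.
  std::vector<int> subGroupDevices = {4, 5, 6, 7};
  for (int groupRank = 0; groupRank < 4; ++groupRank) {
    int inputDevice = subGroupDevices[groupRank];
    // Pre-fix check: input device must equal the rank within the group.
    bool oldCheck = (inputDevice == groupRank); // fails: 4 != 0, 5 != 1, ...
    // Post-fix check: input device must equal the device recorded at
    // rendezvous time (at::cuda::current_device() on this process).
    bool newCheck = (inputDevice == subGroupDevices[groupRank]);
    assert(!oldCheck && newCheck);
  }
  return 0;
}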

File tree

3 files changed: +13, -7 lines

torch/csrc/distributed/c10d/intra_node_comm.cpp

Lines changed: 4 additions & 4 deletions
@@ -280,8 +280,8 @@ bool IntraNodeComm::rendezvous() {
     return false;
   }
 
-  auto deviceIdx = at::cuda::current_device();
-  c10::cuda::CUDAGuard guard(deviceIdx);
+  deviceIdx_ = at::cuda::current_device();
+  c10::cuda::CUDAGuard guard(deviceIdx_);
 
   // First hand shake: exchange hostname and device bus ID
   struct DevInfo {
@@ -292,7 +292,7 @@ bool IntraNodeComm::rendezvous() {
   DevInfo devInfo{};
   gethostname(devInfo.hostname, sizeof(devInfo.hostname));
   cudaDeviceProp prop{};
-  AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceIdx));
+  AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceIdx_));
   snprintf(
       devInfo.busId,
       sizeof(devInfo.busId),
@@ -334,7 +334,7 @@ bool IntraNodeComm::rendezvous() {
   auto groupName = "IntraNodeComm" + std::to_string(intraNodeCommIdx++);
   set_group_info(groupName, rank_, worldSize_, store_);
   auto allocator = get_allocator(c10::DeviceType::CUDA);
-  symmetricMemoryPtr_ = allocator->alloc(bufferSize_, deviceIdx, groupName);
+  symmetricMemoryPtr_ = allocator->alloc(bufferSize_, deviceIdx_, groupName);
   symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_);
   TORCH_CHECK(symmetricMemory_->get_signal_pad_size() >= kP2pStateSize);

torch/csrc/distributed/c10d/intra_node_comm.cu

Lines changed: 8 additions & 3 deletions
@@ -441,13 +441,18 @@ static inline size_t alignUp(uint32_t a, uint32_t b) {
   return divUp(a, b) * b;
 }
 
-static void checkInput(const at::Tensor& input, size_t rank) {
+static void checkInput(const at::Tensor& input, int deviceIdx) {
   TORCH_CHECK(
       input.dtype() == at::kBFloat16,
       "oneShotAllReduce only supports bf16 for now");
   TORCH_CHECK(input.is_non_overlapping_and_dense());
   TORCH_CHECK(input.device().is_cuda());
-  TORCH_CHECK(static_cast<size_t>(input.get_device()) == rank);
+  TORCH_CHECK(
+      input.get_device() == deviceIdx,
+      "IntraNodeComm: expect input to be on device ",
+      deviceIdx,
+      ", got device ",
+      input.get_device());
 }
 
 static void getLaunchConfig(
@@ -507,7 +512,7 @@ void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank) {
 at::Tensor IntraNodeComm::oneShotAllReduce(
     const at::Tensor& input,
     at::cuda::CUDAStream& stream) {
-  checkInput(input, rank_);
+  checkInput(input, deviceIdx_);
 
   const size_t numelPerWarp =
       kBytesPerThread / input.element_size() * kWarpSize;

torch/csrc/distributed/c10d/intra_node_comm.hpp

Lines changed: 1 addition & 0 deletions
@@ -101,6 +101,7 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
    * Members initialized after rendezvous
    */
   bool isInitialized_ = false;
+  int deviceIdx_;
   Topology topology_ = Topology::UNKNOWN;
   void* symmetricMemoryPtr_ = nullptr;
   c10::intrusive_ptr<SymmetricMemory> symmetricMemory_ = nullptr;
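
Outside the diff, a self-contained sketch of the corrected flow (FakeTensor and IntraNodeCommSketch are assumed stand-ins for at::Tensor and the real class, not PyTorch API): the device index is captured once during rendezvous, and the all-reduce input is later validated against it rather than against the group rank.

// Hedged sketch; names and types below are illustrative only.
#include <iostream>
#include <sstream>
#include <stdexcept>

struct FakeTensor {
  int device; // CUDA device index the data lives on
};

class IntraNodeCommSketch {
 public:
  // Mirrors rendezvous(): remember which device this process is bound to.
  void rendezvous(int currentDevice) { deviceIdx_ = currentDevice; }

  // Mirrors the fixed checkInput(): compare against the recorded device,
  // not against the rank within the group.
  void checkInput(const FakeTensor& input) const {
    if (input.device != deviceIdx_) {
      std::ostringstream oss;
      oss << "IntraNodeComm: expect input to be on device " << deviceIdx_
          << ", got device " << input.device;
      throw std::runtime_error(oss.str());
    }
  }

 private:
  int deviceIdx_ = -1;
};

int main() {
  // Sub-group member with group rank 0 but bound to CUDA device 4 (assumed).
  IntraNodeCommSketch comm;
  comm.rendezvous(/*currentDevice=*/4);
  comm.checkInput(FakeTensor{4}); // passes; a rank-based check (0) would throw
  std::cout << "input check passed\n";
  return 0;
}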
