Skip to content

Commit f83add6

Browse files
committed
Debug [skip ci]
1 parent 5464d62 commit f83add6

1 file changed

Lines changed: 7 additions & 8 deletions

File tree

mooncake-ep/src/mooncake_backend.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <ATen/cuda/CUDAContext.h>
2+
#include <c10/cuda/CUDAGuard.h>
23
#include <cuda_runtime.h>
34
#include <torch/torch.h>
45
#include <torch/csrc/distributed/c10d/Backend.hpp>
@@ -962,11 +963,10 @@ void MooncakeBackend::advanceSendOp(SendOpStateData& opData) {
962963
if (isCpu_) {
963964
std::memcpy(opData.sendBuf, opData.op.tensor.data_ptr(), opData.numBytes);
964965
} else {
965-
// Worker thread doesn't have CUDA context, so we need to set the device
966-
// and use a stream from the pool
966+
// Worker thread needs to set CUDA device context
967967
int deviceIndex = opData.op.tensor.device().index();
968-
at::cuda::CUDAGuard guard(deviceIndex);
969-
auto stream = at::cuda::getStreamFromPool(false, deviceIndex);
968+
c10::cuda::CUDAGuard guard(deviceIndex);
969+
auto stream = at::cuda::getCurrentCUDAStream(deviceIndex);
970970
auto err = cudaMemcpyAsync(opData.sendBuf, opData.op.tensor.data_ptr(),
971971
opData.numBytes, cudaMemcpyDeviceToDevice, stream);
972972
TORCH_CHECK(!err, "P2P send cudaMemcpyAsync failed: ", cudaGetErrorString(err));
@@ -1144,11 +1144,10 @@ void MooncakeBackend::advanceRecvOp(RecvOpStateData& opData) {
11441144
if (isCpu_) {
11451145
std::memcpy(opData.op.tensor.data_ptr(), opData.recvBuf, opData.numBytes);
11461146
} else {
1147-
// Worker thread doesn't have CUDA context, so we need to set the device
1148-
// and use a stream from the pool
1147+
// Worker thread needs to set CUDA device context
11491148
int deviceIndex = opData.op.tensor.device().index();
1150-
at::cuda::CUDAGuard guard(deviceIndex);
1151-
auto stream = at::cuda::getStreamFromPool(false, deviceIndex);
1149+
c10::cuda::CUDAGuard guard(deviceIndex);
1150+
auto stream = at::cuda::getCurrentCUDAStream(deviceIndex);
11521151
auto err = cudaMemcpyAsync(opData.op.tensor.data_ptr(), opData.recvBuf,
11531152
opData.numBytes, cudaMemcpyDeviceToDevice, stream);
11541153
TORCH_CHECK(!err, "P2P recv cudaMemcpyAsync failed: ", cudaGetErrorString(err));

0 commit comments

Comments
 (0)