|
1 | 1 | #include <ATen/cuda/CUDAContext.h> |
| 2 | +#include <c10/cuda/CUDAGuard.h> |
2 | 3 | #include <cuda_runtime.h> |
3 | 4 | #include <torch/torch.h> |
4 | 5 | #include <torch/csrc/distributed/c10d/Backend.hpp> |
@@ -962,11 +963,10 @@ void MooncakeBackend::advanceSendOp(SendOpStateData& opData) { |
962 | 963 | if (isCpu_) { |
963 | 964 | std::memcpy(opData.sendBuf, opData.op.tensor.data_ptr(), opData.numBytes); |
964 | 965 | } else { |
965 | | - // Worker thread doesn't have CUDA context, so we need to set the device |
966 | | - // and use a stream from the pool |
| 966 | + // Worker thread needs to set CUDA device context |
967 | 967 | int deviceIndex = opData.op.tensor.device().index(); |
968 | | - at::cuda::CUDAGuard guard(deviceIndex); |
969 | | - auto stream = at::cuda::getStreamFromPool(false, deviceIndex); |
| 968 | + c10::cuda::CUDAGuard guard(deviceIndex); |
| 969 | + auto stream = at::cuda::getCurrentCUDAStream(deviceIndex); |
970 | 970 | auto err = cudaMemcpyAsync(opData.sendBuf, opData.op.tensor.data_ptr(), |
971 | 971 | opData.numBytes, cudaMemcpyDeviceToDevice, stream); |
972 | 972 | TORCH_CHECK(!err, "P2P send cudaMemcpyAsync failed: ", cudaGetErrorString(err)); |
@@ -1144,11 +1144,10 @@ void MooncakeBackend::advanceRecvOp(RecvOpStateData& opData) { |
1144 | 1144 | if (isCpu_) { |
1145 | 1145 | std::memcpy(opData.op.tensor.data_ptr(), opData.recvBuf, opData.numBytes); |
1146 | 1146 | } else { |
1147 | | - // Worker thread doesn't have CUDA context, so we need to set the device |
1148 | | - // and use a stream from the pool |
| 1147 | + // Worker thread needs to set CUDA device context |
1149 | 1148 | int deviceIndex = opData.op.tensor.device().index(); |
1150 | | - at::cuda::CUDAGuard guard(deviceIndex); |
1151 | | - auto stream = at::cuda::getStreamFromPool(false, deviceIndex); |
| 1149 | + c10::cuda::CUDAGuard guard(deviceIndex); |
| 1150 | + auto stream = at::cuda::getCurrentCUDAStream(deviceIndex); |
1152 | 1151 | auto err = cudaMemcpyAsync(opData.op.tensor.data_ptr(), opData.recvBuf, |
1153 | 1152 | opData.numBytes, cudaMemcpyDeviceToDevice, stream); |
1154 | 1153 | TORCH_CHECK(!err, "P2P recv cudaMemcpyAsync failed: ", cudaGetErrorString(err)); |
|
0 commit comments