Skip to content

Commit f4696b7

Browse files
authored
update (#436)
1 parent c095c62 commit f4696b7

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

csrc/cuda/scatter_cuda.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
6363
CHECK_CUDA(index);
6464
if (optional_out.has_value())
6565
CHECK_CUDA(optional_out.value());
66-
cudaSetDevice(src.get_device());
66+
c10::cuda::MaybeSetDevice(src.get_device());
6767

6868
CHECK_INPUT(src.dim() == index.dim());
6969
for (auto i = 0; i < index.dim() - 1; i++)

csrc/cuda/segment_coo_cuda.cu

+2-2
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
157157
CHECK_CUDA(index);
158158
if (optional_out.has_value())
159159
CHECK_CUDA(optional_out.value());
160-
cudaSetDevice(src.get_device());
160+
c10::cuda::MaybeSetDevice(src.get_device());
161161

162162
CHECK_INPUT(src.dim() >= index.dim());
163163

@@ -330,7 +330,7 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
330330
CHECK_CUDA(index);
331331
if (optional_out.has_value())
332332
CHECK_CUDA(optional_out.value());
333-
cudaSetDevice(src.get_device());
333+
c10::cuda::MaybeSetDevice(src.get_device());
334334
335335
CHECK_INPUT(src.dim() >= index.dim());
336336

csrc/cuda/segment_csr_cuda.cu

+2-2
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
102102
CHECK_CUDA(indptr);
103103
if (optional_out.has_value())
104104
CHECK_CUDA(optional_out.value());
105-
cudaSetDevice(src.get_device());
105+
c10::cuda::MaybeSetDevice(src.get_device());
106106

107107
CHECK_INPUT(src.dim() >= indptr.dim());
108108

@@ -222,7 +222,7 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
222222
CHECK_CUDA(indptr);
223223
if (optional_out.has_value())
224224
CHECK_CUDA(optional_out.value());
225-
cudaSetDevice(src.get_device());
225+
c10::cuda::MaybeSetDevice(src.get_device());
226226

227227
CHECK_INPUT(src.dim() >= indptr.dim());
228228

0 commit comments

Comments
 (0)