Skip to content

Commit e5a842e

Browse files
committed
Define as_cuda_array
Provides a function to let us coerce our underlying `__cuda_array_interface__` objects into something that behaves more like an array. Prefers CuPy if possible, but will fallback to Numba if its not available.
1 parent 98d82dd commit e5a842e

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

distributed/comm/ucx.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
ucp = None
3636
host_array = None
3737
device_array = None
38+
as_device_array = None
3839

3940

4041
def synchronize_stream(stream=0):
@@ -47,7 +48,7 @@ def synchronize_stream(stream=0):
4748

4849

4950
def init_once():
50-
global ucp, host_array, device_array
51+
global ucp, host_array, device_array, as_device_array
5152
if ucp is not None:
5253
return
5354

@@ -100,6 +101,23 @@ def device_array(n):
100101
"In order to send/recv CUDA arrays, Numba or RMM is required"
101102
)
102103

104+
# Find the function, `as_device_array()`
105+
try:
106+
import cupy
107+
108+
as_device_array = lambda a: cupy.asarray(a)
109+
except ImportError:
110+
try:
111+
import numba.cuda
112+
113+
as_device_array = lambda a: numba.cuda.as_cuda_array(a)
114+
except ImportError:
115+
116+
def as_device_array(n):
117+
raise RuntimeError(
118+
"In order to send/recv CUDA arrays, CuPy or Numba is required"
119+
)
120+
103121
pool_size_str = dask.config.get("rmm.pool-size")
104122
if pool_size_str is not None:
105123
pool_size = parse_bytes(pool_size_str)

0 commit comments

Comments
 (0)