diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py
index 4e6ca8116c8..7761afef7a1 100644
--- a/distributed/comm/ucx.py
+++ b/distributed/comm/ucx.py
@@ -6,10 +6,10 @@
 .. _UCX: https://github.com/openucx/ucx
 """
 import logging
+import struct
 import weakref
 
 import dask
-import numpy as np
 
 from .addressing import parse_host_port, unparse_host_port
 from .core import Comm, Connector, Listener, CommClosedError
@@ -33,7 +33,8 @@
 # required to ensure Dask configuration gets propagated to UCX, which needs
 # variables to be set before being imported.
 ucp = None
-cuda_array = None
+host_array = None
+device_array = None
 
 
 def synchronize_stream(stream=0):
@@ -46,7 +47,7 @@ def synchronize_stream(stream=0):
 
 
 def init_once():
-    global ucp, cuda_array
+    global ucp, host_array, device_array
     if ucp is not None:
         return
 
@@ -59,34 +60,42 @@ def init_once():
 
     ucp.init(options=ucx_config, env_takes_precedence=True)
 
+    # Find the function, `host_array()`, to use when allocating new host arrays
+    try:
+        import numpy
+
+        host_array = lambda n: numpy.empty((n,), dtype="u1")
+    except ImportError:
+        host_array = lambda n: bytearray(n)
+
     # Find the function, `cuda_array()`, to use when allocating new CUDA arrays
     try:
         import rmm
 
         if hasattr(rmm, "DeviceBuffer"):
-            cuda_array = lambda n: rmm.DeviceBuffer(size=n)
+            device_array = lambda n: rmm.DeviceBuffer(size=n)
         else:  # pre-0.11.0
             import numba.cuda
 
-            def rmm_cuda_array(n):
-                a = rmm.device_array(n, dtype=np.uint8)
+            def rmm_device_array(n):
+                a = rmm.device_array(n, dtype="u1")
                 weakref.finalize(a, numba.cuda.current_context)
                 return a
 
-            cuda_array = rmm_cuda_array
+            device_array = rmm_device_array
     except ImportError:
         try:
             import numba.cuda
 
-            def numba_cuda_array(n):
-                a = numba.cuda.device_array((n,), dtype=np.uint8)
+            def numba_device_array(n):
+                a = numba.cuda.device_array((n,), dtype="u1")
                 weakref.finalize(a, numba.cuda.current_context)
                 return a
 
-            cuda_array = numba_cuda_array
+            device_array = numba_device_array
         except ImportError:
 
-            def cuda_array(n):
+            def device_array(n):
                 raise RuntimeError(
                     "In order to send/recv CUDA arrays, Numba or RMM is required"
                 )
@@ -169,19 +178,25 @@ async def write(
                 frames = await to_frames(
                     msg, serializers=serializers, on_error=on_error
                 )
+                nframes = len(frames)
+                cuda_frames = tuple(
+                    hasattr(f, "__cuda_array_interface__") for f in frames
+                )
+                sizes = tuple(nbytes(f) for f in frames)
                 send_frames = [
-                    each_frame for each_frame in frames if len(each_frame) > 0
+                    each_frame
+                    for each_frame, each_size in zip(frames, sizes)
+                    if each_size
                 ]
 
                 # Send meta data
-                cuda_frames = np.array(
-                    [hasattr(f, "__cuda_array_interface__") for f in frames],
-                    dtype=np.bool,
-                )
-                await self.ep.send(np.array([len(frames)], dtype=np.uint64))
-                await self.ep.send(cuda_frames)
+
+                # Send # of frames (uint64)
+                await self.ep.send(struct.pack("Q", nframes))
+                # Send which frames are CUDA (bool) and
+                # how large each frame is (uint64)
                 await self.ep.send(
-                    np.array([nbytes(f) for f in frames], dtype=np.uint64)
+                    struct.pack(nframes * "?" + nframes * "Q", *cuda_frames, *sizes)
                 )
 
                 # Send frames
@@ -191,12 +206,12 @@ async def write(
                 # syncing the default stream will wait for other non-blocking CUDA streams.
                 # Note this is only sufficient if the memory being sent is not currently in use on
                 # non-blocking CUDA streams.
-                if cuda_frames.any():
+                if any(cuda_frames):
                     synchronize_stream(0)
 
                 for each_frame in send_frames:
                     await self.ep.send(each_frame)
-                return sum(map(nbytes, send_frames))
+                return sum(sizes)
             except (ucp.exceptions.UCXBaseException):
                 self.abort()
                 raise CommClosedError("While writing, the connection was closed")
@@ -211,22 +226,28 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
 
             try:
                 # Recv meta data
-                nframes = np.empty(1, dtype=np.uint64)
+
+                # Recv # of frames (uint64)
+                nframes_fmt = "Q"
+                nframes = host_array(struct.calcsize(nframes_fmt))
                 await self.ep.recv(nframes)
-                is_cudas = np.empty(nframes[0], dtype=np.bool)
-                await self.ep.recv(is_cudas)
-                sizes = np.empty(nframes[0], dtype=np.uint64)
-                await self.ep.recv(sizes)
+                (nframes,) = struct.unpack(nframes_fmt, nframes)
+
+                # Recv which frames are CUDA (bool) and
+                # how large each frame is (uint64)
+                header_fmt = nframes * "?" + nframes * "Q"
+                header = host_array(struct.calcsize(header_fmt))
+                await self.ep.recv(header)
+                header = struct.unpack(header_fmt, header)
+                cuda_frames, sizes = header[:nframes], header[nframes:]
            except (ucp.exceptions.UCXBaseException, CancelledError):
                self.abort()
                raise CommClosedError("While reading, the connection was closed")
            else:
                # Recv frames
                frames = [
-                    cuda_array(each_size)
-                    if is_cuda
-                    else np.empty(each_size, dtype=np.uint8)
-                    for is_cuda, each_size in zip(is_cudas.tolist(), sizes.tolist())
+                    device_array(each_size) if is_cuda else host_array(each_size)
+                    for is_cuda, each_size in zip(cuda_frames, sizes)
                 ]
                 recv_frames = [
                     each_frame for each_frame in frames if len(each_frame) > 0
@@ -234,7 +255,7 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
 
                 # It is necessary to first populate `frames` with CUDA arrays and synchronize
                 # the default stream before starting receiving to ensure buffers have been allocated
-                if is_cudas.any():
+                if any(cuda_frames):
                     synchronize_stream(0)
 
                 for each_frame in recv_frames:
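
The metadata exchange in this patch replaces NumPy arrays with struct-packed headers: write() sends a "Q" (uint64) frame count, then one "?" flag and one "Q" size per frame, and read() sizes its receive buffer with struct.calcsize() and unpacks the same format. The sketch below is illustrative only and not part of the patch; the sample frames and variable names are invented for the demo, and it exercises just the pack/unpack round trip without any UCX endpoints.

    import struct

    # Hypothetical frames: two non-empty host buffers and one empty frame.
    frames = [b"header-bytes", b"", b"payload" * 3]

    # Sender side (mirrors the new metadata path in UCX.write).
    nframes = len(frames)
    cuda_frames = tuple(hasattr(f, "__cuda_array_interface__") for f in frames)
    sizes = tuple(len(f) for f in frames)
    count_msg = struct.pack("Q", nframes)
    header_msg = struct.pack(nframes * "?" + nframes * "Q", *cuda_frames, *sizes)

    # Receiver side (mirrors the new metadata path in UCX.read).
    (recv_nframes,) = struct.unpack("Q", count_msg)
    recv_fmt = recv_nframes * "?" + recv_nframes * "Q"
    assert len(header_msg) == struct.calcsize(recv_fmt)  # buffer size known before recv
    header = struct.unpack(recv_fmt, header_msg)
    recv_cuda, recv_sizes = header[:recv_nframes], header[recv_nframes:]

    assert recv_cuda == cuda_frames == (False, False, False)
    assert recv_sizes == sizes == (12, 0, 21)

Packing the header with struct keeps metadata handling in the standard library, which is what allows host_array to fall back to bytearray when NumPy is not installed.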