Skip to content

Commit 1b540a3

Browse files
committed
Send/recv host and device frames in a message each
1 parent 5bbd53e commit 1b540a3

File tree

1 file changed

+68
-16
lines changed

1 file changed

+68
-16
lines changed

distributed/comm/ucx.py

+68-16
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,34 @@ async def write(
194194
hasattr(f, "__cuda_array_interface__") for f in frames
195195
)
196196
sizes = tuple(nbytes(f) for f in frames)
197-
send_frames = [
198-
each_frame
199-
for each_frame, each_size in zip(frames, sizes)
200-
if each_size
201-
]
197+
host_frames = host_array(
198+
sum(
199+
each_size
200+
for is_cuda, each_size in zip(cuda_frames, sizes)
201+
if not is_cuda
202+
)
203+
)
204+
device_frames = device_array(
205+
sum(
206+
each_size
207+
for is_cuda, each_size in zip(cuda_frames, sizes)
208+
if is_cuda
209+
)
210+
)
211+
212+
# Pack frames
213+
host_frames_view = memoryview(host_frames)
214+
device_frames_view = as_numba_device_array(device_frames)
215+
for each_frame, is_cuda, each_size in zip(frames, cuda_frames, sizes):
216+
if each_size:
217+
if is_cuda:
218+
each_frame_view = as_numba_device_array(each_frame)
219+
device_frames_view[:each_size] = each_frame_view[:]
220+
device_frames_view = device_frames_view[each_size:]
221+
else:
222+
each_frame_view = memoryview(each_frame).cast("B")
223+
host_frames_view[:each_size] = each_frame_view[:]
224+
host_frames_view = host_frames_view[each_size:]
202225

203226
# Send meta data
204227
await self.ep.send(struct.pack("Q", nframes))
@@ -216,8 +239,10 @@ async def write(
216239
if any(cuda_frames):
217240
synchronize_stream(0)
218241

219-
for each_frame in send_frames:
220-
await self.ep.send(each_frame)
242+
if nbytes(host_frames):
243+
await self.ep.send(host_frames)
244+
if nbytes(device_frames):
245+
await self.ep.send(device_frames)
221246
return sum(sizes)
222247
except (ucp.exceptions.UCXBaseException):
223248
self.abort()
@@ -248,21 +273,48 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
248273
raise CommClosedError("While reading, the connection was closed")
249274
else:
250275
# Recv frames
251-
frames = [
252-
device_array(each_size) if is_cuda else host_array(each_size)
253-
for is_cuda, each_size in zip(cuda_frames, sizes)
254-
]
255-
recv_frames = [
256-
each_frame for each_frame in frames if len(each_frame) > 0
257-
]
276+
host_frames = host_array(
277+
sum(
278+
each_size
279+
for is_cuda, each_size in zip(cuda_frames, sizes)
280+
if not is_cuda
281+
)
282+
)
283+
device_frames = device_array(
284+
sum(
285+
each_size
286+
for is_cuda, each_size in zip(cuda_frames, sizes)
287+
if is_cuda
288+
)
289+
)
258290

259291
# It is necessary to first allocate the `device_frames` CUDA buffer and synchronize
260292
# the default stream before starting receiving to ensure buffers have been allocated
261293
if any(cuda_frames):
262294
synchronize_stream(0)
263295

264-
for each_frame in recv_frames:
265-
await self.ep.recv(each_frame)
296+
if nbytes(host_frames):
297+
await self.ep.recv(host_frames)
298+
if nbytes(device_frames):
299+
await self.ep.recv(device_frames)
300+
301+
frames = [
302+
device_array(each_size) if is_cuda else host_array(each_size)
303+
for is_cuda, each_size in zip(cuda_frames, sizes)
304+
]
305+
host_frames_view = memoryview(host_frames)
306+
device_frames_view = as_numba_device_array(device_frames)
307+
for each_frame, is_cuda, each_size in zip(frames, cuda_frames, sizes):
308+
if each_size:
309+
if is_cuda:
310+
each_frame_view = as_numba_device_array(each_frame)
311+
each_frame_view[:] = device_frames_view[:each_size]
312+
device_frames_view = device_frames_view[each_size:]
313+
else:
314+
each_frame_view = memoryview(each_frame)
315+
each_frame_view[:] = host_frames_view[:each_size]
316+
host_frames_view = host_frames_view[each_size:]
317+
266318
msg = await from_frames(
267319
frames, deserialize=self.deserialize, deserializers=deserializers
268320
)

0 commit comments

Comments
 (0)