Skip to content

Commit c758ae6

Browse files
committed
Send/recv host and device frames in a message each
To cut down on the number of send/recv operations and also to transmit larger amounts of data at a time, this condenses all frames into a host buffer and a device buffer, which are sent as two separate transmissions.
1 parent e5a842e commit c758ae6

File tree

1 file changed

+68
-16
lines changed

1 file changed

+68
-16
lines changed

distributed/comm/ucx.py

+68-16
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,34 @@ async def write(
201201
hasattr(f, "__cuda_array_interface__") for f in frames
202202
)
203203
sizes = tuple(nbytes(f) for f in frames)
204-
send_frames = [
205-
each_frame
206-
for each_frame, each_size in zip(frames, sizes)
207-
if each_size
208-
]
204+
host_frames = host_array(
205+
sum(
206+
each_size
207+
for is_cuda, each_size in zip(cuda_frames, sizes)
208+
if not is_cuda
209+
)
210+
)
211+
device_frames = device_array(
212+
sum(
213+
each_size
214+
for is_cuda, each_size in zip(cuda_frames, sizes)
215+
if is_cuda
216+
)
217+
)
218+
219+
# Pack frames
220+
host_frames_view = memoryview(host_frames)
221+
device_frames_view = as_device_array(device_frames)
222+
for each_frame, is_cuda, each_size in zip(frames, cuda_frames, sizes):
223+
if each_size:
224+
if is_cuda:
225+
each_frame_view = as_device_array(each_frame)
226+
device_frames_view[:each_size] = each_frame_view[:]
227+
device_frames_view = device_frames_view[each_size:]
228+
else:
229+
each_frame_view = memoryview(each_frame).cast("B")
230+
host_frames_view[:each_size] = each_frame_view[:]
231+
host_frames_view = host_frames_view[each_size:]
209232

210233
# Send meta data
211234
await self.ep.send(struct.pack("Q", nframes))
@@ -223,8 +246,10 @@ async def write(
223246
if any(cuda_frames):
224247
synchronize_stream(0)
225248

226-
for each_frame in send_frames:
227-
await self.ep.send(each_frame)
249+
if nbytes(host_frames):
250+
await self.ep.send(host_frames)
251+
if nbytes(device_frames):
252+
await self.ep.send(device_frames)
228253
return sum(sizes)
229254
except (ucp.exceptions.UCXBaseException):
230255
self.abort()
@@ -255,21 +280,48 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
255280
raise CommClosedError("While reading, the connection was closed")
256281
else:
257282
# Recv frames
258-
frames = [
259-
device_array(each_size) if is_cuda else host_array(each_size)
260-
for is_cuda, each_size in zip(cuda_frames, sizes)
261-
]
262-
recv_frames = [
263-
each_frame for each_frame in frames if len(each_frame) > 0
264-
]
283+
host_frames = host_array(
284+
sum(
285+
each_size
286+
for is_cuda, each_size in zip(cuda_frames, sizes)
287+
if not is_cuda
288+
)
289+
)
290+
device_frames = device_array(
291+
sum(
292+
each_size
293+
for is_cuda, each_size in zip(cuda_frames, sizes)
294+
if is_cuda
295+
)
296+
)
265297

266298
# It is necessary to first populate `frames` with CUDA arrays and synchronize
267299
# the default stream before starting receiving to ensure buffers have been allocated
268300
if any(cuda_frames):
269301
synchronize_stream(0)
270302

271-
for each_frame in recv_frames:
272-
await self.ep.recv(each_frame)
303+
if nbytes(host_frames):
304+
await self.ep.recv(host_frames)
305+
if nbytes(device_frames):
306+
await self.ep.recv(device_frames)
307+
308+
frames = [
309+
device_array(each_size) if is_cuda else host_array(each_size)
310+
for is_cuda, each_size in zip(cuda_frames, sizes)
311+
]
312+
host_frames_view = memoryview(host_frames)
313+
device_frames_view = as_device_array(device_frames)
314+
for each_frame, is_cuda, each_size in zip(frames, cuda_frames, sizes):
315+
if each_size:
316+
if is_cuda:
317+
each_frame_view = as_device_array(each_frame)
318+
each_frame_view[:] = device_frames_view[:each_size]
319+
device_frames_view = device_frames_view[each_size:]
320+
else:
321+
each_frame_view = memoryview(each_frame)
322+
each_frame_view[:] = host_frames_view[:each_size]
323+
host_frames_view = host_frames_view[each_size:]
324+
273325
msg = await from_frames(
274326
frames, deserialize=self.deserialize, deserializers=deserializers
275327
)

0 commit comments

Comments
 (0)