Skip to content

Commit bee6f0b

Browse files
committed
Send/recv host and device frames in a message each
To cut down on the number of send/recv operations and to transmit larger amounts of data at a time, this condenses all frames into a host buffer and a device buffer, which are sent as two separate transmissions.
1 parent 070a19f commit bee6f0b

File tree

1 file changed

+68
-16
lines changed

1 file changed

+68
-16
lines changed

distributed/comm/ucx.py

+68-16
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,34 @@ async def write(
201201
hasattr(f, "__cuda_array_interface__") for f in frames
202202
)
203203
sizes = tuple(nbytes(f) for f in frames)
204-
send_frames = [
205-
each_frame
206-
for each_frame, each_size in zip(frames, sizes)
207-
if each_size
208-
]
204+
host_frames = host_array(
205+
sum(
206+
each_size
207+
for is_cuda, each_size in zip(cuda_frames, sizes)
208+
if not is_cuda
209+
)
210+
)
211+
device_frames = device_array(
212+
sum(
213+
each_size
214+
for is_cuda, each_size in zip(cuda_frames, sizes)
215+
if is_cuda
216+
)
217+
)
218+
219+
# Pack frames
220+
host_frames_view = memoryview(host_frames)
221+
device_frames_view = as_device_array(device_frames)
222+
for each_frame, is_cuda, each_size in zip(frames, cuda_frames, sizes):
223+
if each_size:
224+
if is_cuda:
225+
each_frame_view = as_device_array(each_frame)
226+
device_frames_view[:each_size] = each_frame_view[:]
227+
device_frames_view = device_frames_view[each_size:]
228+
else:
229+
each_frame_view = memoryview(each_frame).cast("B")
230+
host_frames_view[:each_size] = each_frame_view[:]
231+
host_frames_view = host_frames_view[each_size:]
209232

210233
# Send meta data
211234

@@ -227,8 +250,10 @@ async def write(
227250
if any(cuda_frames):
228251
synchronize_stream(0)
229252

230-
for each_frame in send_frames:
231-
await self.ep.send(each_frame)
253+
if nbytes(host_frames):
254+
await self.ep.send(host_frames)
255+
if nbytes(device_frames):
256+
await self.ep.send(device_frames)
232257
return sum(sizes)
233258
except (ucp.exceptions.UCXBaseException):
234259
self.abort()
@@ -263,21 +288,48 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
263288
raise CommClosedError("While reading, the connection was closed")
264289
else:
265290
# Recv frames
266-
frames = [
267-
device_array(each_size) if is_cuda else host_array(each_size)
268-
for is_cuda, each_size in zip(cuda_frames, sizes)
269-
]
270-
recv_frames = [
271-
each_frame for each_frame in frames if len(each_frame) > 0
272-
]
291+
host_frames = host_array(
292+
sum(
293+
each_size
294+
for is_cuda, each_size in zip(cuda_frames, sizes)
295+
if not is_cuda
296+
)
297+
)
298+
device_frames = device_array(
299+
sum(
300+
each_size
301+
for is_cuda, each_size in zip(cuda_frames, sizes)
302+
if is_cuda
303+
)
304+
)
273305

274306
# It is necessary to first populate `frames` with CUDA arrays and synchronize
275307
# the default stream before starting receiving to ensure buffers have been allocated
276308
if any(cuda_frames):
277309
synchronize_stream(0)
278310

279-
for each_frame in recv_frames:
280-
await self.ep.recv(each_frame)
311+
if nbytes(host_frames):
312+
await self.ep.recv(host_frames)
313+
if nbytes(device_frames):
314+
await self.ep.recv(device_frames)
315+
316+
frames = [
317+
device_array(each_size) if is_cuda else host_array(each_size)
318+
for is_cuda, each_size in zip(cuda_frames, sizes)
319+
]
320+
host_frames_view = memoryview(host_frames)
321+
device_frames_view = as_device_array(device_frames)
322+
for each_frame, is_cuda, each_size in zip(frames, cuda_frames, sizes):
323+
if each_size:
324+
if is_cuda:
325+
each_frame_view = as_device_array(each_frame)
326+
each_frame_view[:] = device_frames_view[:each_size]
327+
device_frames_view = device_frames_view[each_size:]
328+
else:
329+
each_frame_view = memoryview(each_frame)
330+
each_frame_view[:] = host_frames_view[:each_size]
331+
host_frames_view = host_frames_view[each_size:]
332+
281333
msg = await from_frames(
282334
frames, deserialize=self.deserialize, deserializers=deserializers
283335
)

0 commit comments

Comments
 (0)