@@ -201,11 +201,34 @@ async def write(
201
201
hasattr (f , "__cuda_array_interface__" ) for f in frames
202
202
)
203
203
sizes = tuple (nbytes (f ) for f in frames )
204
- send_frames = [
205
- each_frame
206
- for each_frame , each_size in zip (frames , sizes )
207
- if each_size
208
- ]
204
+ host_frames = host_array (
205
+ sum (
206
+ each_size
207
+ for is_cuda , each_size in zip (cuda_frames , sizes )
208
+ if not is_cuda
209
+ )
210
+ )
211
+ device_frames = device_array (
212
+ sum (
213
+ each_size
214
+ for is_cuda , each_size in zip (cuda_frames , sizes )
215
+ if is_cuda
216
+ )
217
+ )
218
+
219
+ # Pack frames
220
+ host_frames_view = memoryview (host_frames )
221
+ device_frames_view = as_device_array (device_frames )
222
+ for each_frame , is_cuda , each_size in zip (frames , cuda_frames , sizes ):
223
+ if each_size :
224
+ if is_cuda :
225
+ each_frame_view = as_device_array (each_frame )
226
+ device_frames_view [:each_size ] = each_frame_view [:]
227
+ device_frames_view = device_frames_view [each_size :]
228
+ else :
229
+ each_frame_view = memoryview (each_frame ).cast ("B" )
230
+ host_frames_view [:each_size ] = each_frame_view [:]
231
+ host_frames_view = host_frames_view [each_size :]
209
232
210
233
# Send meta data
211
234
@@ -227,8 +250,10 @@ async def write(
227
250
if any (cuda_frames ):
228
251
synchronize_stream (0 )
229
252
230
- for each_frame in send_frames :
231
- await self .ep .send (each_frame )
253
+ if nbytes (host_frames ):
254
+ await self .ep .send (host_frames )
255
+ if nbytes (device_frames ):
256
+ await self .ep .send (device_frames )
232
257
return sum (sizes )
233
258
except (ucp .exceptions .UCXBaseException ):
234
259
self .abort ()
@@ -263,21 +288,48 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
263
288
raise CommClosedError ("While reading, the connection was closed" )
264
289
else :
265
290
# Recv frames
266
- frames = [
267
- device_array (each_size ) if is_cuda else host_array (each_size )
268
- for is_cuda , each_size in zip (cuda_frames , sizes )
269
- ]
270
- recv_frames = [
271
- each_frame for each_frame in frames if len (each_frame ) > 0
272
- ]
291
+ host_frames = host_array (
292
+ sum (
293
+ each_size
294
+ for is_cuda , each_size in zip (cuda_frames , sizes )
295
+ if not is_cuda
296
+ )
297
+ )
298
+ device_frames = device_array (
299
+ sum (
300
+ each_size
301
+ for is_cuda , each_size in zip (cuda_frames , sizes )
302
+ if is_cuda
303
+ )
304
+ )
273
305
274
306
# It is necessary to first allocate the CUDA receive buffer (`device_frames`) and synchronize
275
307
# the default stream before receiving, to ensure the buffer has been allocated
276
308
if any (cuda_frames ):
277
309
synchronize_stream (0 )
278
310
279
- for each_frame in recv_frames :
280
- await self .ep .recv (each_frame )
311
+ if nbytes (host_frames ):
312
+ await self .ep .recv (host_frames )
313
+ if nbytes (device_frames ):
314
+ await self .ep .recv (device_frames )
315
+
316
+ frames = [
317
+ device_array (each_size ) if is_cuda else host_array (each_size )
318
+ for is_cuda , each_size in zip (cuda_frames , sizes )
319
+ ]
320
+ host_frames_view = memoryview (host_frames )
321
+ device_frames_view = as_device_array (device_frames )
322
+ for each_frame , is_cuda , each_size in zip (frames , cuda_frames , sizes ):
323
+ if each_size :
324
+ if is_cuda :
325
+ each_frame_view = as_device_array (each_frame )
326
+ each_frame_view [:] = device_frames_view [:each_size ]
327
+ device_frames_view = device_frames_view [each_size :]
328
+ else :
329
+ each_frame_view = memoryview (each_frame )
330
+ each_frame_view [:] = host_frames_view [:each_size ]
331
+ host_frames_view = host_frames_view [each_size :]
332
+
281
333
msg = await from_frames (
282
334
frames , deserialize = self .deserialize , deserializers = deserializers
283
335
)
0 commit comments