@@ -201,11 +201,34 @@ async def write(
201
201
hasattr (f , "__cuda_array_interface__" ) for f in frames
202
202
)
203
203
sizes = tuple (nbytes (f ) for f in frames )
204
- send_frames = [
205
- each_frame
206
- for each_frame , each_size in zip (frames , sizes )
207
- if each_size
208
- ]
204
+ host_frames = host_array (
205
+ sum (
206
+ each_size
207
+ for is_cuda , each_size in zip (cuda_frames , sizes )
208
+ if not is_cuda
209
+ )
210
+ )
211
+ device_frames = device_array (
212
+ sum (
213
+ each_size
214
+ for is_cuda , each_size in zip (cuda_frames , sizes )
215
+ if is_cuda
216
+ )
217
+ )
218
+
219
+ # Pack frames
220
+ host_frames_view = memoryview (host_frames )
221
+ device_frames_view = as_device_array (device_frames )
222
+ for each_frame , is_cuda , each_size in zip (frames , cuda_frames , sizes ):
223
+ if each_size :
224
+ if is_cuda :
225
+ each_frame_view = as_device_array (each_frame )
226
+ device_frames_view [:each_size ] = each_frame_view [:]
227
+ device_frames_view = device_frames_view [each_size :]
228
+ else :
229
+ each_frame_view = memoryview (each_frame ).cast ("B" )
230
+ host_frames_view [:each_size ] = each_frame_view [:]
231
+ host_frames_view = host_frames_view [each_size :]
209
232
210
233
# Send meta data
211
234
await self .ep .send (struct .pack ("Q" , nframes ))
@@ -223,8 +246,10 @@ async def write(
223
246
if any (cuda_frames ):
224
247
synchronize_stream (0 )
225
248
226
- for each_frame in send_frames :
227
- await self .ep .send (each_frame )
249
+ if nbytes (host_frames ):
250
+ await self .ep .send (host_frames )
251
+ if nbytes (device_frames ):
252
+ await self .ep .send (device_frames )
228
253
return sum (sizes )
229
254
except (ucp .exceptions .UCXBaseException ):
230
255
self .abort ()
@@ -255,21 +280,48 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
255
280
raise CommClosedError ("While reading, the connection was closed" )
256
281
else :
257
282
# Recv frames
258
- frames = [
259
- device_array (each_size ) if is_cuda else host_array (each_size )
260
- for is_cuda , each_size in zip (cuda_frames , sizes )
261
- ]
262
- recv_frames = [
263
- each_frame for each_frame in frames if len (each_frame ) > 0
264
- ]
283
+ host_frames = host_array (
284
+ sum (
285
+ each_size
286
+ for is_cuda , each_size in zip (cuda_frames , sizes )
287
+ if not is_cuda
288
+ )
289
+ )
290
+ device_frames = device_array (
291
+ sum (
292
+ each_size
293
+ for is_cuda , each_size in zip (cuda_frames , sizes )
294
+ if is_cuda
295
+ )
296
+ )
265
297
266
298
# It is necessary to first allocate the CUDA receive buffer (`device_frames`) and synchronize
267
299
# the default stream before starting receiving to ensure buffers have been allocated
268
300
if any (cuda_frames ):
269
301
synchronize_stream (0 )
270
302
271
- for each_frame in recv_frames :
272
- await self .ep .recv (each_frame )
303
+ if nbytes (host_frames ):
304
+ await self .ep .recv (host_frames )
305
+ if nbytes (device_frames ):
306
+ await self .ep .recv (device_frames )
307
+
308
+ frames = [
309
+ device_array (each_size ) if is_cuda else host_array (each_size )
310
+ for is_cuda , each_size in zip (cuda_frames , sizes )
311
+ ]
312
+ host_frames_view = memoryview (host_frames )
313
+ device_frames_view = as_device_array (device_frames )
314
+ for each_frame , is_cuda , each_size in zip (frames , cuda_frames , sizes ):
315
+ if each_size :
316
+ if is_cuda :
317
+ each_frame_view = as_device_array (each_frame )
318
+ each_frame_view [:] = device_frames_view [:each_size ]
319
+ device_frames_view = device_frames_view [each_size :]
320
+ else :
321
+ each_frame_view = memoryview (each_frame )
322
+ each_frame_view [:] = host_frames_view [:each_size ]
323
+ host_frames_view = host_frames_view [each_size :]
324
+
273
325
msg = await from_frames (
274
326
frames , deserialize = self .deserialize , deserializers = deserializers
275
327
)
0 commit comments