@@ -194,11 +194,34 @@ async def write(
194
194
hasattr (f , "__cuda_array_interface__" ) for f in frames
195
195
)
196
196
sizes = tuple (nbytes (f ) for f in frames )
197
- send_frames = [
198
- each_frame
199
- for each_frame , each_size in zip (frames , sizes )
200
- if each_size
201
- ]
197
+ host_frames = host_array (
198
+ sum (
199
+ each_size
200
+ for is_cuda , each_size in zip (cuda_frames , sizes )
201
+ if not is_cuda
202
+ )
203
+ )
204
+ device_frames = device_array (
205
+ sum (
206
+ each_size
207
+ for is_cuda , each_size in zip (cuda_frames , sizes )
208
+ if is_cuda
209
+ )
210
+ )
211
+
212
+ # Pack frames
213
+ host_frames_view = memoryview (host_frames )
214
+ device_frames_view = as_numba_device_array (device_frames )
215
+ for each_frame , is_cuda , each_size in zip (frames , cuda_frames , sizes ):
216
+ if each_size :
217
+ if is_cuda :
218
+ each_frame_view = as_numba_device_array (each_frame )
219
+ device_frames_view [:each_size ] = each_frame_view [:]
220
+ device_frames_view = device_frames_view [each_size :]
221
+ else :
222
+ each_frame_view = memoryview (each_frame ).cast ("B" )
223
+ host_frames_view [:each_size ] = each_frame_view [:]
224
+ host_frames_view = host_frames_view [each_size :]
202
225
203
226
# Send meta data
204
227
await self .ep .send (struct .pack ("Q" , nframes ))
@@ -216,8 +239,10 @@ async def write(
216
239
if any (cuda_frames ):
217
240
synchronize_stream (0 )
218
241
219
- for each_frame in send_frames :
220
- await self .ep .send (each_frame )
242
+ if nbytes (host_frames ):
243
+ await self .ep .send (host_frames )
244
+ if nbytes (device_frames ):
245
+ await self .ep .send (device_frames )
221
246
return sum (sizes )
222
247
except (ucp .exceptions .UCXBaseException ):
223
248
self .abort ()
@@ -248,21 +273,48 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
248
273
raise CommClosedError ("While reading, the connection was closed" )
249
274
else :
250
275
# Recv frames
251
- frames = [
252
- device_array (each_size ) if is_cuda else host_array (each_size )
253
- for is_cuda , each_size in zip (cuda_frames , sizes )
254
- ]
255
- recv_frames = [
256
- each_frame for each_frame in frames if len (each_frame ) > 0
257
- ]
276
+ host_frames = host_array (
277
+ sum (
278
+ each_size
279
+ for is_cuda , each_size in zip (cuda_frames , sizes )
280
+ if not is_cuda
281
+ )
282
+ )
283
+ device_frames = device_array (
284
+ sum (
285
+ each_size
286
+ for is_cuda , each_size in zip (cuda_frames , sizes )
287
+ if is_cuda
288
+ )
289
+ )
258
290
259
291
# It is necessary to first populate `frames` with CUDA arrays and synchronize
260
292
# the default stream before starting receiving to ensure buffers have been allocated
261
293
if any (cuda_frames ):
262
294
synchronize_stream (0 )
263
295
264
- for each_frame in recv_frames :
265
- await self .ep .recv (each_frame )
296
+ if nbytes (host_frames ):
297
+ await self .ep .recv (host_frames )
298
+ if nbytes (device_frames ):
299
+ await self .ep .recv (device_frames )
300
+
301
+ frames = [
302
+ device_array (each_size ) if is_cuda else host_array (each_size )
303
+ for is_cuda , each_size in zip (cuda_frames , sizes )
304
+ ]
305
+ host_frames_view = memoryview (host_frames )
306
+ device_frames_view = as_numba_device_array (device_frames )
307
+ for each_frame , is_cuda , each_size in zip (frames , cuda_frames , sizes ):
308
+ if each_size :
309
+ if is_cuda :
310
+ each_frame_view = as_numba_device_array (each_frame )
311
+ each_frame_view [:] = device_frames_view [:each_size ]
312
+ device_frames_view = device_frames_view [each_size :]
313
+ else :
314
+ each_frame_view = memoryview (each_frame )
315
+ each_frame_view [:] = host_frames_view [:each_size ]
316
+ host_frames_view = host_frames_view [each_size :]
317
+
266
318
msg = await from_frames (
267
319
frames , deserialize = self .deserialize , deserializers = deserializers
268
320
)
0 commit comments