Skip to content

Commit 1fb28f2

Browse files
Add a copy flag to the np.array <-> context array conversion functions
1 parent 4649fe4 commit 1fb28f2

File tree

4 files changed

+50
-24
lines changed

4 files changed

+50
-24
lines changed

xobjects/context.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -375,12 +375,30 @@ def get_installed_c_source_paths(self) -> List[str]:
375375
return sources
376376

377377
@abstractmethod
378-
def nparray_to_context_array(self, arr):
379-
return arr
378+
def nparray_to_context_array(self, arr, copy=False):
379+
"""Obtain an array on the context, given a numpy array.
380+
381+
Args:
382+
arr: numpy array
383+
copy: if True, always create a copy in the context. If False,
384+
try to avoid copy if possible (not guaranteed).
385+
386+
Returns:
387+
array on the context
388+
"""
380389

381390
@abstractmethod
382-
def nparray_from_context_array(self, dev_arr):
383-
return dev_arr
391+
def nparray_from_context_array(self, dev_arr, copy=False):
392+
"""Obtain a numpy array, given an array on the context.
393+
394+
Args:
395+
arr: array on the context
396+
copy: if True, always create a copy in the context. If False,
397+
try to avoid copy if possible (not guaranteed).
398+
399+
Returns:
400+
Numpy array
401+
"""
384402

385403
@property
386404
@abstractmethod

xobjects/context_cpu.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -543,33 +543,38 @@ def cffi_module_for_c_types(c_types, containing_dir="."):
543543

544544
return None
545545

546-
def nparray_to_context_array(self, arr):
546+
def nparray_to_context_array(self, arr, copy=False):
547547
"""
548548
Moves a numpy array to the device memory. No action is performed by
549549
this function in the CPU context. The method is provided
550550
so that the CPU context has an identical API to the GPU ones.
551551
552552
Args:
553553
arr (numpy.ndarray): Array to be transferred
554+
copy (bool): If True, a copy of the array is made.
554555
555556
Returns:
556-
numpy.ndarray: The same array (no copy!).
557-
557+
numpy.ndarray: Numpy array with the same data, original or a copy.
558558
"""
559+
if copy:
560+
arr = np.copy(arr)
559561
return arr
560562

561-
def nparray_from_context_array(self, dev_arr):
563+
def nparray_from_context_array(self, dev_arr, copy=False):
562564
"""
563565
Moves an array to the device to a numpy array. No action is performed by
564566
this function in the CPU context. The method is provided so that the CPU
565567
context has an identical API to the GPU ones.
566568
567569
Args:
568570
dev_arr (numpy.ndarray): Array to be transferred
569-
Returns:
570-
numpy.ndarray: The same array (no copy!)
571+
copy (bool): If True, a copy of the array is made.
571572
573+
Returns:
574+
numpy.ndarray: Numpy array with the same data, original or a copy.
572575
"""
576+
if copy:
577+
dev_arr = np.copy(dev_arr)
573578
return dev_arr
574579

575580
@property
@@ -579,7 +584,6 @@ def nplike_lib(self):
579584
through ``nplike_lib`` to keep compatibility with the other contexts.
580585
581586
"""
582-
583587
return np
584588

585589
@property
@@ -589,7 +593,6 @@ def splike_lib(self):
589593
through ``splike_lib`` to keep compatibility with the other contexts.
590594
591595
"""
592-
593596
return sp
594597

595598
def synchronize(self):

xobjects/context_cupy.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -487,29 +487,32 @@ def build_kernels(
487487
def __str__(self):
488488
return f"{type(self).__name__}:{cupy.cuda.get_device_id()}"
489489

490-
def nparray_to_context_array(self, arr):
490+
def nparray_to_context_array(self, arr, copy=False):
491491
"""
492492
Copies a numpy array to the device memory.
493493
494494
Args:
495495
arr (numpy.ndarray): Array to be transferred
496+
copy (bool): This parameter is ignored for CUDA, as the data lives
497+
on a different device.
496498
497499
Returns:
498500
cupy.ndarray:The same array copied to the device.
499-
500501
"""
501502
dev_arr = cupy.array(arr)
502503
return dev_arr
503504

504-
def nparray_from_context_array(self, dev_arr):
505+
def nparray_from_context_array(self, dev_arr, copy=False):
505506
"""
506507
Copies an array to the device to a numpy array.
507508
508509
Args:
509510
dev_arr (cupy.ndarray): Array to be transferred.
511+
copy (bool): This parameter is ignored for CUDA, as the data lives
512+
on a different device.
513+
510514
Returns:
511515
numpy.ndarray: The same data copied to a numpy array.
512-
513516
"""
514517
return dev_arr.get()
515518

xobjects/context_pyopencl.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -253,28 +253,30 @@ def __str__(self):
253253
device_id = self.platform.get_devices().index(self.device)
254254
return f"{type(self).__name__}:{platform_id}.{device_id}"
255255

256-
def nparray_to_context_array(self, arr):
257-
"""
258-
Copies a numpy array to the device memory.
256+
def nparray_to_context_array(self, arr, copy=False):
257+
"""Copies a numpy array to the device memory.
258+
259259
Args:
260260
arr (numpy.ndarray): Array to be transferred
261+
copy (bool): This parameter is ignored for OpenCL, as the data lives
262+
on a different device.
261263
262264
Returns:
263265
pyopencl.array.Array:The same array copied to the device.
264-
265266
"""
266267
dev_arr = cla.to_device(self.queue, arr)
267268
return dev_arr
268269

269-
def nparray_from_context_array(self, dev_arr):
270-
"""
271-
Copies an array to the device to a numpy array.
270+
def nparray_from_context_array(self, dev_arr, copy=False):
271+
"""Copies an array to the device to a numpy array.
272272
273273
Args:
274274
dev_arr (pyopencl.array.Array): Array to be transferred.
275+
copy (bool): This parameter is ignored for OpenCL, as the data lives
276+
on a different device.
277+
275278
Returns:
276279
numpy.ndarray: The same data copied to a numpy array.
277-
278280
"""
279281
return dev_arr.get()
280282

0 commit comments

Comments
 (0)