Skip to content

Commit 4328ede

Browse files
Fixes in MultiFab setitem (#440)
Previously, `setitem` was only checking for numpy arrays, which was causing errors on a GPU when a cupy array was passed in. This fix makes it more flexible, so now lists can also be used for example. Also, when using GPU, the input value must be converted to a cupy array before the copy can be done. A minor cleanup of `_get_field` is also done, simplifying the code. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 77db770 commit 4328ede

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

src/amrex/extensions/MultiFab.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -395,13 +395,7 @@ def _get_field(self, mfi):
395395
# The transpose is taken since the Python array interface to Array4 in
396396
# self.array(mfi) is in C ordering.
397397
# Note: transposing creates a view and not a copy.
398-
import inspect
399-
400-
amr = inspect.getmodule(self)
401-
if amr.Config.have_gpu:
402-
device_arr = self.array(mfi).to_cupy(copy=False, order="F")
403-
else:
404-
device_arr = self.array(mfi).to_numpy(copy=False, order="F")
398+
device_arr = self.array(mfi).to_xp(copy=False, order="F")
405399
return device_arr
406400

407401

@@ -603,15 +597,26 @@ def __setitem__(self, index, value):
603597
value : scalar or array
604598
Input value to assign to the specified slice of the MultiFab
605599
"""
600+
# When filling an array that is a cupy array, the RHS must also be a cupy array.
601+
# This checks if amr was built with GPU, and if so, it must convert the input
602+
# value to a cupy array. Otherwise, it will use a numpy array.
603+
import inspect
604+
605+
amr = inspect.getmodule(self)
606+
if amr.Config.have_gpu:
607+
import cupy as xp
608+
else:
609+
xp = np
610+
606611
index = _process_index(self, index)
607612

608-
if isinstance(value, np.ndarray):
613+
if not np.isscalar(value):
609614
# Expand the shape of the input array to match the shape of the global array
610615
# (it needs to be 4-D).
611616
# This converts value to an array if needed, and the [...] grabs a view so
612617
# that the shape change below doesn't affect value.
613-
value3d = np.array(value)[...]
614-
global_shape = list(value3d.shape)
618+
value4d = xp.array(value)[...]
619+
global_shape = list(value4d.shape)
615620
# The shape of 1 is added for the extra dimensions and when index is an integer
616621
# (in which case the dimension was not in the input array).
617622
if (index[0].stop - index[0].start) == 1:
@@ -622,15 +627,15 @@ def __setitem__(self, index, value):
622627
global_shape[2:2] = [1]
623628
if (index[3].stop - index[3].start) == 1 or len(global_shape) < 4:
624629
global_shape[3:3] = [1]
625-
value3d.shape = global_shape
630+
value4d.shape = global_shape
626631

627632
for mfi in self:
628633
block_slices, global_slices = _get_intersect_slice(self, mfi, index, True)
629634
if global_slices is not None:
630635
mf_arr = _get_field(self, mfi)
631-
if isinstance(value, np.ndarray):
636+
if not np.isscalar(value):
632637
# The data is copied from host to device automatically if needed
633-
mf_arr[block_slices] = value3d[global_slices]
638+
mf_arr[block_slices] = value4d[global_slices]
634639
else:
635640
mf_arr[block_slices] = value
636641

0 commit comments

Comments
 (0)