Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions src/amrex/extensions/MultiFab.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,13 +395,7 @@ def _get_field(self, mfi):
# The transpose is taken since the Python array interface to Array4 in
# self.array(mfi) is in C ordering.
# Note: transposing creates a view and not a copy.
import inspect

amr = inspect.getmodule(self)
if amr.Config.have_gpu:
device_arr = self.array(mfi).to_cupy(copy=False, order="F")
else:
device_arr = self.array(mfi).to_numpy(copy=False, order="F")
device_arr = self.array(mfi).to_xp(copy=False, order="F")
return device_arr


Expand Down Expand Up @@ -603,15 +597,26 @@ def __setitem__(self, index, value):
value : scalar or array
Input value to assign to the specified slice of the MultiFab
"""
# When filling an array that is a cupy array, the RHS must also be a cupy array.
# This checks if amr was built with GPU, and if so, it must convert the input
# value to a cupy array. Otherwise, it will use a numpy array.
import inspect

amr = inspect.getmodule(self)
if amr.Config.have_gpu:
import cupy as xp
else:
xp = np

index = _process_index(self, index)

if isinstance(value, np.ndarray):
if not np.isscalar(value):
# Expand the shape of the input array to match the shape of the global array
# (it needs to be 4-D).
# This converts value to an array if needed, and the [...] grabs a view so
# that the shape change below doesn't affect value.
value3d = np.array(value)[...]
global_shape = list(value3d.shape)
value4d = xp.array(value)[...]
global_shape = list(value4d.shape)
# The shape of 1 is added for the extra dimensions and when index is an integer
# (in which case the dimension was not in the input array).
if (index[0].stop - index[0].start) == 1:
Expand All @@ -622,15 +627,15 @@ def __setitem__(self, index, value):
global_shape[2:2] = [1]
if (index[3].stop - index[3].start) == 1 or len(global_shape) < 4:
global_shape[3:3] = [1]
value3d.shape = global_shape
value4d.shape = global_shape

for mfi in self:
block_slices, global_slices = _get_intersect_slice(self, mfi, index, True)
if global_slices is not None:
mf_arr = _get_field(self, mfi)
if isinstance(value, np.ndarray):
if not np.isscalar(value):
# The data is copied from host to device automatically if needed
mf_arr[block_slices] = value3d[global_slices]
mf_arr[block_slices] = value4d[global_slices]
else:
mf_arr[block_slices] = value

Expand Down
Loading