Skip to content

Refactor Array4: Compute strides on the fly #4420

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: development
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions Src/AmrCore/AMReX_MFInterp_C.H
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Real mf_compute_slopes_x (int i, int j, int k, Array4<Real const> const& u, int
Real dc = Real(0.5) * (u(i+1,j,k,nu) - u(i-1,j,k,nu));
if (i == domain.smallEnd(0) && (bc.lo(0) == BCType::ext_dir ||
bc.lo(0) == BCType::hoextrap)) {
if (i+2 < u.end.x) {
if (i+2 < (u.begin.x+u.len.x)) {
dc = -Real(16./15.)*u(i-1,j,k,nu) + Real(0.5)*u(i,j,k,nu)
+ Real(2./3.)*u(i+1,j,k,nu) - Real(0.1)*u(i+2,j,k,nu);
} else {
Expand All @@ -40,7 +40,7 @@ Real mf_compute_slopes_y (int i, int j, int k, Array4<Real const> const& u, int
Real dc = Real(0.5) * (u(i,j+1,k,nu) - u(i,j-1,k,nu));
if (j == domain.smallEnd(1) && (bc.lo(1) == BCType::ext_dir ||
bc.lo(1) == BCType::hoextrap)) {
if (j+2 < u.end.y) {
if (j+2 < (u.begin.y+u.len.y)) {
dc = -Real(16./15.)*u(i,j-1,k,nu) + Real(0.5)*u(i,j,k,nu)
+ Real(2./3.)*u(i,j+1,k,nu) - Real(0.1)*u(i,j+2,k,nu);
} else {
Expand All @@ -66,7 +66,7 @@ Real mf_compute_slopes_z (int i, int j, int k, Array4<Real const> const& u, int
Real dc = Real(0.5) * (u(i,j,k+1,nu) - u(i,j,k-1,nu));
if (k == domain.smallEnd(2) && (bc.lo(2) == BCType::ext_dir ||
bc.lo(2) == BCType::hoextrap)) {
if (k+2 < u.end.z) {
if (k+2 < (u.begin.z+u.len.z)) {
dc = -Real(16./15.)*u(i,j,k-1,nu) + Real(0.5)*u(i,j,k,nu)
+ Real(2./3.)*u(i,j,k+1,nu) - Real(0.1)*u(i,j,k+2,nu);
} else {
Expand All @@ -93,7 +93,7 @@ Real mf_cell_quadratic_compute_slopes_xx (int i, int j, int k,
Real xx = u(i-1,j,k,nu) - 2.0_rt * u(i,j,k,nu) + u(i+1,j,k,nu);
if (i == domain.smallEnd(0) && (bc.lo(0) == BCType::ext_dir ||
bc.lo(0) == BCType::hoextrap)) {
if (i+2 < u.end.x) {
if (i+2 < (u.begin.x+u.len.x)) {
xx = 0._rt;
}
}
Expand All @@ -114,7 +114,7 @@ Real mf_cell_quadratic_compute_slopes_yy (int i, int j, int k,
Real yy = u(i,j-1,k,nu) - 2.0_rt * u(i,j,k,nu) + u(i,j+1,k,nu);
if (j == domain.smallEnd(1) && (bc.lo(1) == BCType::ext_dir ||
bc.lo(1) == BCType::hoextrap)) {
if (j+2 < u.end.y) {
if (j+2 < (u.begin.y+u.len.y)) {
yy = 0._rt;
}
}
Expand All @@ -135,7 +135,7 @@ Real mf_cell_quadratic_compute_slopes_zz (int i, int j, int k,
Real zz = u(i,j,k-1,nu) - 2.0_rt * u(i,j,k,nu) + u(i,j,k+1,nu);
if (k == domain.smallEnd(2) && (bc.lo(2) == BCType::ext_dir ||
bc.lo(2) == BCType::hoextrap)) {
if (k+2 < u.end.z) {
if (k+2 < (u.begin.z+u.len.z)) {
zz = 0._rt;
}
}
Expand All @@ -157,7 +157,7 @@ Real mf_cell_quadratic_compute_slopes_xy (int i, int j, int k,
- u(i-1,j+1,k,nu) + u(i+1,j+1,k,nu) );
if (i == domain.smallEnd(0) && (bc.lo(0) == BCType::ext_dir ||
bc.lo(0) == BCType::hoextrap)) {
if (i+2 < u.end.x) {
if (i+2 < (u.begin.x+u.len.x)) {
xy = 0._rt;
}
}
Expand All @@ -169,7 +169,7 @@ Real mf_cell_quadratic_compute_slopes_xy (int i, int j, int k,
}
if (j == domain.smallEnd(1) && (bc.lo(1) == BCType::ext_dir ||
bc.lo(1) == BCType::hoextrap)) {
if (j+2 < u.end.y) {
if (j+2 < (u.begin.y+u.len.y)) {
xy = 0._rt;
}
}
Expand All @@ -191,7 +191,7 @@ Real mf_cell_quadratic_compute_slopes_xz (int i, int j, int k,
- u(i-1,j,k+1,nu) + u(i+1,j,k+1,nu) );
if (i == domain.smallEnd(0) && (bc.lo(0) == BCType::ext_dir ||
bc.lo(0) == BCType::hoextrap)) {
if (i+2 < u.end.x) {
if (i+2 < (u.begin.x+u.len.x)) {
xz = 0._rt;
}
}
Expand All @@ -203,7 +203,7 @@ Real mf_cell_quadratic_compute_slopes_xz (int i, int j, int k,
}
if (k == domain.smallEnd(2) && (bc.lo(2) == BCType::ext_dir ||
bc.lo(2) == BCType::hoextrap)) {
if (k+2 < u.end.z) {
if (k+2 < (u.begin.z+u.len.z)) {
xz = 0._rt;
}
}
Expand All @@ -225,7 +225,7 @@ Real mf_cell_quadratic_compute_slopes_yz (int i, int j, int k,
- u(i,j+1,k-1,nu) + u(i,j+1,k+1,nu) );
if (j == domain.smallEnd(1) && (bc.lo(1) == BCType::ext_dir ||
bc.lo(1) == BCType::hoextrap)) {
if (j+2 < u.end.y) {
if (j+2 < (u.begin.y+u.len.y)) {
yz = 0._rt;
}
}
Expand All @@ -237,7 +237,7 @@ Real mf_cell_quadratic_compute_slopes_yz (int i, int j, int k,
}
if (k == domain.smallEnd(2) && (bc.lo(2) == BCType::ext_dir ||
bc.lo(2) == BCType::hoextrap)) {
if (k+2 < u.end.z) {
if (k+2 < (u.begin.z+u.len.z)) {
yz = 0._rt;
}
}
Expand Down
94 changes: 53 additions & 41 deletions Src/Base/AMReX_Array4.H
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,8 @@ namespace amrex {
struct Array4
{
T* AMREX_RESTRICT p;
Long jstride = 0;
Long kstride = 0;
Long nstride = 0;
Dim3 begin{1,1,1};
Dim3 end{0,0,0}; // end is hi + 1
Dim3 len{0,0,0};
int ncomp=0;

AMREX_GPU_HOST_DEVICE
Expand All @@ -74,22 +71,16 @@ namespace amrex {
AMREX_GPU_HOST_DEVICE
constexpr Array4 (Array4<std::remove_const_t<T>> const& rhs) noexcept
: p(rhs.p),
jstride(rhs.jstride),
kstride(rhs.kstride),
nstride(rhs.nstride),
begin(rhs.begin),
end(rhs.end),
len(rhs.len),
ncomp(rhs.ncomp)
{}

AMREX_GPU_HOST_DEVICE
constexpr Array4 (T* a_p, Dim3 const& a_begin, Dim3 const& a_end, int a_ncomp) noexcept
: p(a_p),
jstride(a_end.x-a_begin.x),
kstride(jstride*(a_end.y-a_begin.y)),
nstride(kstride*(a_end.z-a_begin.z)),
begin(a_begin),
end(a_end),
len{a_end.x-a_begin.x, a_end.y-a_begin.y, a_end.z-a_begin.z},
ncomp(a_ncomp)
{}

Expand All @@ -99,12 +90,9 @@ namespace amrex {
std::remove_const_t<U>>,int> = 0>
AMREX_GPU_HOST_DEVICE
constexpr Array4 (Array4<U> const& rhs, int start_comp) noexcept
: p((T*)(rhs.p+start_comp*rhs.nstride)),
jstride(rhs.jstride),
kstride(rhs.kstride),
nstride(rhs.nstride),
: p((T*)(rhs.p+start_comp*rhs.nstride())),
begin(rhs.begin),
end(rhs.end),
len(rhs.len),
ncomp(rhs.ncomp-start_comp)
{}

Expand All @@ -114,25 +102,31 @@ namespace amrex {
std::remove_const_t<U>>,int> = 0>
AMREX_GPU_HOST_DEVICE
constexpr Array4 (Array4<U> const& rhs, int start_comp, int num_comps) noexcept
: p((T*)(rhs.p+start_comp*rhs.nstride)),
jstride(rhs.jstride),
kstride(rhs.kstride),
nstride(rhs.nstride),
: p((T*)(rhs.p+start_comp*rhs.nstride())),
begin(rhs.begin),
end(rhs.end),
len(rhs.len),
ncomp(num_comps)
{}

//! True when this Array4 refers to data, i.e. its pointer `p` is not null.
AMREX_GPU_HOST_DEVICE
explicit operator bool() const noexcept
{
    return !(p == nullptr);
}

//! Element offset between neighbouring j indices: the x extent `len.x`,
//! widened to Long. Computed on demand instead of being stored.
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
Long jstride () const noexcept
{
    return static_cast<Long>(len.x);
}

//! Element offset between neighbouring k indices: `len.x * len.y`.
//! Both factors are widened to Long before multiplying.
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
Long kstride () const noexcept
{
    const Long nx = static_cast<Long>(len.x);
    const Long ny = static_cast<Long>(len.y);
    return nx * ny;
}

//! Element offset between neighbouring components: `len.x * len.y * len.z`.
//! Every factor is widened to Long before multiplying.
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
Long nstride () const noexcept
{
    const Long plane = static_cast<Long>(len.x) * static_cast<Long>(len.y);
    return plane * static_cast<Long>(len.z);
}

template <class U=T, std::enable_if_t<!std::is_void_v<U>,int> = 0>
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
U& operator() (int i, int j, int k) const noexcept {
#if defined(AMREX_DEBUG) || defined(AMREX_BOUND_CHECK)
index_assert(i,j,k,0);
#endif
return p[(i-begin.x)+(j-begin.y)*jstride+(k-begin.z)*kstride];
return p[(i-begin.x)+Long(len.x)*((j-begin.y)+Long(len.y)*(k-begin.z))];
}

template <class U=T, std::enable_if_t<!std::is_void_v<U>,int> = 0>
Expand All @@ -141,7 +135,7 @@ namespace amrex {
#if defined(AMREX_DEBUG) || defined(AMREX_BOUND_CHECK)
index_assert(i,j,k,n);
#endif
return p[(i-begin.x)+(j-begin.y)*jstride+(k-begin.z)*kstride+n*nstride];
return p[(i-begin.x)+Long(len.x)*((j-begin.y)+Long(len.y)*((k-begin.z)+Long(len.z)*n))];
}

template <class U=T, std::enable_if_t<!std::is_void_v<U>,int> = 0>
Expand All @@ -150,7 +144,7 @@ namespace amrex {
#if defined(AMREX_DEBUG) || defined(AMREX_BOUND_CHECK)
index_assert(i,j,k,0);
#endif
return p + ((i-begin.x)+(j-begin.y)*jstride+(k-begin.z)*kstride);
return p + ((i-begin.x)+Long(len.x)*((j-begin.y)+Long(len.y)*(k-begin.z)));
}

template <class U=T, std::enable_if_t<!std::is_void_v<U>,int> = 0>
Expand All @@ -159,7 +153,7 @@ namespace amrex {
#if defined(AMREX_DEBUG) || defined(AMREX_BOUND_CHECK)
index_assert(i,j,k,n);
#endif
return p + ((i-begin.x)+(j-begin.y)*jstride+(k-begin.z)*kstride+n*nstride);
return p + ((i-begin.x)+Long(len.x)*((j-begin.y)+Long(len.y)*((k-begin.z)+Long(len.z)*n)));
}

template <class U=T, std::enable_if_t<!std::is_void_v<U>,int> = 0>
Expand Down Expand Up @@ -241,22 +235,22 @@ namespace amrex {

[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
std::size_t size () const noexcept {
return this->nstride * this->ncomp;
return this->nstride() * this->ncomp;
}

//! Number of components addressable through this Array4 (the `ncomp` member).
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
int nComp () const noexcept
{
    return this->ncomp;
}

[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
bool contains (int i, int j, int k) const noexcept {
return (i>=begin.x && i<end.x && j>=begin.y && j<end.y && k>=begin.z && k<end.z);
return (i>=begin.x && i<(begin.x+len.x) && j>=begin.y && j<(begin.y+len.y) && k>=begin.z && k<(begin.z+len.z));
}

[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
bool contains (IntVect const& iv) const noexcept {
return AMREX_D_TERM( iv[0]>=begin.x && iv[0]<end.x,
&& iv[1]>=begin.y && iv[1]<end.y,
&& iv[2]>=begin.z && iv[2]<end.z);
return AMREX_D_TERM( iv[0]>=begin.x && iv[0]<(begin.x+len.x),
&& iv[1]>=begin.y && iv[1]<(begin.y+len.y),
&& iv[2]>=begin.z && iv[2]<(begin.z+len.z));
}

[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
Expand All @@ -272,21 +266,21 @@ namespace amrex {
#endif
void index_assert (int i, int j, int k, int n) const
{
if (i<begin.x || i>=end.x || j<begin.y || j>=end.y || k<begin.z || k>=end.z
if (i<begin.x || i>=(begin.x+len.x) || j<begin.y || j>=(begin.y+len.y) || k<begin.z || k>=(begin.z+len.z)
|| n < 0 || n >= ncomp) {
AMREX_IF_ON_DEVICE((
AMREX_DEVICE_PRINTF(" (%d,%d,%d,%d) is out of bound (%d:%d,%d:%d,%d:%d,0:%d)\n",
i, j, k, n, begin.x, end.x-1, begin.y, end.y-1,
begin.z, end.z-1, ncomp-1);
i, j, k, n, begin.x, (begin.x+len.x)-1, begin.y, (begin.y+len.y)-1,
begin.z, (begin.z+len.z)-1, ncomp-1);
amrex::Abort();
))
AMREX_IF_ON_HOST((
std::stringstream ss;
ss << " (" << i << "," << j << "," << k << "," << n
<< ") is out of bound ("
<< begin.x << ":" << end.x-1 << ","
<< begin.y << ":" << end.y-1 << ","
<< begin.z << ":" << end.z-1 << ","
<< begin.x << ":" << (begin.x+len.x)-1 << ","
<< begin.y << ":" << (begin.y+len.y)-1 << ","
<< begin.z << ":" << (begin.z+len.z)-1 << ","
<< "0:" << ncomp-1 << ")";
amrex::Abort(ss.str());
))
Expand All @@ -296,15 +290,19 @@ namespace amrex {

[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
CellData<T> cellData (int i, int j, int k) const noexcept {
return CellData<T>{this->ptr(i,j,k), nstride, ncomp};
return CellData<T>{this->ptr(i,j,k), nstride(), ncomp};
}
};

template <class Tto, class Tfrom>
[[nodiscard]] AMREX_GPU_HOST_DEVICE
Array4<Tto> ToArray4 (Array4<Tfrom> const& a_in) noexcept
{
return Array4<Tto>((Tto*)(a_in.p), a_in.begin, a_in.end, a_in.ncomp);
return Array4<Tto>((Tto*)(a_in.p), a_in.begin,
Dim3{a_in.begin.x + a_in.len.x,
a_in.begin.y + a_in.len.y,
a_in.begin.z + a_in.len.z},
a_in.ncomp);
}

template <class T>
Expand All @@ -318,14 +316,28 @@ namespace amrex {
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
Dim3 ubound (Array4<T> const& a) noexcept
{
return Dim3{a.end.x-1,a.end.y-1,a.end.z-1};
return Dim3{a.begin.x+a.len.x-1,a.begin.y+a.len.y-1,a.begin.z+a.len.z-1};
}

//! Returns a copy of the inclusive lower index corner `a.begin`.
//! NOTE(review): this free function is reachable by ADL under the name
//! `begin` but yields a Dim3, not an iterator -- confirm no generic code
//! relies on std::begin-style lookup for Array4.
template <class T>
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
Dim3 begin (Array4<T> const& a) noexcept
{
    return Dim3{a.begin.x, a.begin.y, a.begin.z};
}

//! Returns the exclusive upper index corner, computed per direction as
//! `a.begin + a.len`.
//! NOTE(review): like the sibling `begin`, this is visible to ADL under the
//! name `end` and returns a Dim3 rather than an iterator -- confirm that is
//! intended.
template <class T>
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
Dim3 end (Array4<T> const& a) noexcept
{
    Dim3 past_hi = a.begin;
    past_hi.x += a.len.x;
    past_hi.y += a.len.y;
    past_hi.z += a.len.z;
    return past_hi;
}

template <class T>
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
Dim3 length (Array4<T> const& a) noexcept
{
return Dim3{a.end.x-a.begin.x,a.end.y-a.begin.y,a.end.z-a.begin.z};
return a.len;
}

template <typename T>
Expand Down
16 changes: 8 additions & 8 deletions Src/Base/AMReX_BaseFab.H
Original file line number Diff line number Diff line change
Expand Up @@ -2013,32 +2013,32 @@ template<class T>
BaseFab<T>::BaseFab (Array4<T> const& a) noexcept
: dptr(a.p),
domain(IntVect(AMREX_D_DECL(a.begin.x,a.begin.y,a.begin.z)),
IntVect(AMREX_D_DECL(a.end.x-1,a.end.y-1,a.end.z-1))),
nvar(a.ncomp), truesize(a.ncomp*a.nstride)
IntVect(AMREX_D_DECL(a.begin.x+a.len.x-1,a.begin.y+a.len.y-1,a.begin.z+a.len.z-1))),
nvar(a.ncomp), truesize(a.ncomp*a.nstride())
{}

template<class T>
BaseFab<T>::BaseFab (Array4<T> const& a, IndexType t) noexcept
: dptr(a.p),
domain(IntVect(AMREX_D_DECL(a.begin.x,a.begin.y,a.begin.z)),
IntVect(AMREX_D_DECL(a.end.x-1,a.end.y-1,a.end.z-1)), t),
nvar(a.ncomp), truesize(a.ncomp*a.nstride)
IntVect(AMREX_D_DECL(a.begin.x+a.len.x-1,a.begin.y+a.len.y-1,a.begin.z+a.len.z-1)), t),
nvar(a.ncomp), truesize(a.ncomp*a.nstride())
{}

template<class T>
BaseFab<T>::BaseFab (Array4<T const> const& a) noexcept
: dptr(const_cast<T*>(a.p)),
domain(IntVect(AMREX_D_DECL(a.begin.x,a.begin.y,a.begin.z)),
IntVect(AMREX_D_DECL(a.end.x-1,a.end.y-1,a.end.z-1))),
nvar(a.ncomp), truesize(a.ncomp*a.nstride)
IntVect(AMREX_D_DECL(a.begin.x+a.len.x-1,a.begin.y+a.len.y-1,a.begin.z+a.len.z-1))),
nvar(a.ncomp), truesize(a.ncomp*a.nstride())
{}

template<class T>
BaseFab<T>::BaseFab (Array4<T const> const& a, IndexType t) noexcept
: dptr(const_cast<T*>(a.p)),
domain(IntVect(AMREX_D_DECL(a.begin.x,a.begin.y,a.begin.z)),
IntVect(AMREX_D_DECL(a.end.x-1,a.end.y-1,a.end.z-1)), t),
nvar(a.ncomp), truesize(a.ncomp*a.nstride)
IntVect(AMREX_D_DECL(a.begin.x+a.len.x-1,a.begin.y+a.len.y-1,a.begin.z+a.len.z-1)), t),
nvar(a.ncomp), truesize(a.ncomp*a.nstride())
{}

template <class T>
Expand Down
2 changes: 1 addition & 1 deletion Src/Base/AMReX_Box.H
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public:
AMREX_GPU_HOST_DEVICE
explicit BoxND (Array4<T> const& a) noexcept
: smallend(a.begin),
bigend(IntVectND<dim>(a.end) - 1)
bigend(IntVectND<dim>(ubound(a)))
{}

// dtor, copy-ctor, copy-op=, move-ctor, and move-op= are compiler generated.
Expand Down
Loading
Loading