Skip to content

Commit b3f6738

Browse files
authored
Make FFT safe for slabs (#4268)
Support FFT on domains that have one cell in some dimensions. It also supports Poisson solves on slab domains. However, for FFT::PoissonHybrid that treats the z-direction in a special way, the z-direction must have more than one cell.
1 parent bdb4be3 commit b3f6738

File tree

7 files changed

+926
-155
lines changed

7 files changed

+926
-155
lines changed

Src/Base/AMReX_Periodicity.H

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ public:
3232
//! Cell-centered domain Box "infinitely" long in non-periodic directions.
3333
[[nodiscard]] Box Domain () const noexcept;
3434

35+
[[nodiscard]] IntVect const& intVect () const { return period; }
36+
3537
[[nodiscard]] std::vector<IntVect> shiftIntVect (IntVect const& nghost = IntVect(0)) const;
3638

3739
static const Periodicity& NonPeriodic () noexcept;

Src/FFT/AMReX_FFT.cpp

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,235 @@ void hip_execute (rocfft_plan plan, void **in, void **out)
118118
}
119119
#endif
120120

121+
SubHelper::SubHelper (Box const& domain)
122+
{
123+
#if (AMREX_SPACEDIM == 1)
124+
amrex::ignore_unused(domain);
125+
#elif (AMREX_SPACEDIM == 2)
126+
if (domain.length(0) == 1) {
127+
m_case = case_1n;
128+
}
129+
#else
130+
if (domain.length(0) == 1 && domain.length(1) == 1) {
131+
m_case = case_11n;
132+
} else if (domain.length(0) == 1 && domain.length(2) == 1) {
133+
m_case = case_1n1;
134+
} else if (domain.length(0) == 1) {
135+
m_case = case_1nn;
136+
} else if (domain.length(1) == 1) {
137+
m_case = case_n1n;
138+
}
139+
#endif
140+
}
141+
142+
Box SubHelper::make_box (Box const& box) const
143+
{
144+
return Box(make_iv(box.smallEnd()), make_iv(box.bigEnd()), box.ixType());
145+
}
146+
147+
Periodicity SubHelper::make_periodicity (Periodicity const& period) const
148+
{
149+
return Periodicity(make_iv(period.intVect()));
150+
}
151+
152+
bool SubHelper::ghost_safe (IntVect const& ng) const
153+
{
154+
#if (AMREX_SPACEDIM == 1)
155+
amrex::ignore_unused(ng,this);
156+
return true;
157+
#elif (AMREX_SPACEDIM == 2)
158+
if (m_case == case_1n) {
159+
return (ng[0] == 0);
160+
} else {
161+
return true;
162+
}
163+
#else
164+
if (m_case == case_11n) {
165+
return (ng[0] == 0) && (ng[1] == 0);
166+
} else if (m_case == case_1n1) {
167+
return (ng[0] == 0);
168+
} else if (m_case == case_1nn) {
169+
return (ng[0] == 0);
170+
} else if (m_case == case_n1n) {
171+
return (ng[1] == 0);
172+
} else {
173+
return true;
174+
}
175+
#endif
176+
}
177+
178+
IntVect SubHelper::make_iv (IntVect const& iv) const
179+
{
180+
return this->make_array(iv);
181+
}
182+
183+
IntVect SubHelper::make_safe_ghost (IntVect const& ng) const
184+
{
185+
#if (AMREX_SPACEDIM == 1)
186+
amrex::ignore_unused(this);
187+
return ng;
188+
#elif (AMREX_SPACEDIM == 2)
189+
if (m_case == case_1n) {
190+
return IntVect{0,ng[1]};
191+
} else {
192+
return ng;
193+
}
194+
#else
195+
if (m_case == case_11n) {
196+
return IntVect{0,0,ng[2]};
197+
} else if (m_case == case_1n1) {
198+
return IntVect{0,ng[1],ng[2]};
199+
} else if (m_case == case_1nn) {
200+
return IntVect{0,ng[1],ng[2]};
201+
} else if (m_case == case_n1n) {
202+
return IntVect{ng[0],0,ng[2]};
203+
} else {
204+
return ng;
205+
}
206+
#endif
207+
}
208+
209+
BoxArray SubHelper::inverse_boxarray (BoxArray const& ba) const
210+
{ // sub domain order -> original domain order
211+
#if (AMREX_SPACEDIM == 1)
212+
amrex::ignore_unused(this);
213+
return ba;
214+
#elif (AMREX_SPACEDIM == 2)
215+
AMREX_ALWAYS_ASSERT(m_case == case_1n);
216+
BoxList bl = ba.boxList();
217+
// sub domain order: y, x
218+
for (auto& b : bl) {
219+
auto const& lo = b.smallEnd();
220+
auto const& hi = b.bigEnd();
221+
b.setSmall(IntVect(lo[1],lo[0]));
222+
b.setBig (IntVect(hi[1],hi[0]));
223+
}
224+
return BoxArray(std::move(bl));
225+
#else
226+
BoxList bl = ba.boxList();
227+
if (m_case == case_11n) {
228+
// sub domain order: z, x, y
229+
for (auto& b : bl) {
230+
auto const& lo = b.smallEnd();
231+
auto const& hi = b.bigEnd();
232+
b.setSmall(IntVect(lo[1],lo[2],lo[0]));
233+
b.setBig (IntVect(hi[1],hi[2],hi[0]));
234+
}
235+
} else if (m_case == case_1n1) {
236+
// sub domain order: y, x, z
237+
for (auto& b : bl) {
238+
auto const& lo = b.smallEnd();
239+
auto const& hi = b.bigEnd();
240+
b.setSmall(IntVect(lo[1],lo[0],lo[2]));
241+
b.setBig (IntVect(hi[1],hi[0],hi[2]));
242+
}
243+
} else if (m_case == case_1nn) {
244+
// sub domain order: y, z, x
245+
for (auto& b : bl) {
246+
auto const& lo = b.smallEnd();
247+
auto const& hi = b.bigEnd();
248+
b.setSmall(IntVect(lo[2],lo[0],lo[1]));
249+
b.setBig (IntVect(hi[2],hi[0],hi[1]));
250+
}
251+
} else if (m_case == case_n1n) {
252+
// sub domain order: x, z, y
253+
for (auto& b : bl) {
254+
auto const& lo = b.smallEnd();
255+
auto const& hi = b.bigEnd();
256+
b.setSmall(IntVect(lo[0],lo[2],lo[1]));
257+
b.setBig (IntVect(hi[0],hi[2],hi[1]));
258+
}
259+
} else {
260+
amrex::Abort("SubHelper::inverse_boxarray: how did this happen?");
261+
}
262+
return BoxArray(std::move(bl));
263+
#endif
264+
}
265+
266+
IntVect SubHelper::inverse_order (IntVect const& order) const
267+
{
268+
#if (AMREX_SPACEDIM == 1)
269+
amrex::ignore_unused(this);
270+
return order;
271+
#elif (AMREX_SPACEDIM == 2)
272+
amrex::ignore_unused(this);
273+
return IntVect(order[1],order[0]);
274+
#else
275+
auto translate = [&] (int index) -> int
276+
{
277+
int r = index;
278+
if (m_case == case_11n) {
279+
// sub domain order: z, x, y
280+
if (index == 0) {
281+
r = 2;
282+
} else if (index == 1) {
283+
r = 0;
284+
} else {
285+
r = 1;
286+
}
287+
} else if (m_case == case_1n1) {
288+
// sub domain order: y, x, z
289+
if (index == 0) {
290+
r = 1;
291+
} else if (index == 1) {
292+
r = 0;
293+
} else {
294+
r = 2;
295+
}
296+
} else if (m_case == case_1nn) {
297+
// sub domain order: y, z, x
298+
if (index == 0) {
299+
r = 1;
300+
} else if (index == 1) {
301+
r = 2;
302+
} else {
303+
r = 0;
304+
}
305+
} else if (m_case == case_n1n) {
306+
// sub domain order: x, z, y
307+
if (index == 0) {
308+
r = 0;
309+
} else if (index == 1) {
310+
r = 2;
311+
} else {
312+
r = 1;
313+
}
314+
}
315+
return r;
316+
};
317+
318+
IntVect iv;
319+
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
320+
iv[idim] = translate(order[idim]);
321+
}
322+
return iv;
323+
#endif
324+
}
325+
326+
GpuArray<int,3> SubHelper::xyz_order () const
327+
{
328+
#if (AMREX_SPACEDIM == 1)
329+
amrex::ignore_unused(this);
330+
return GpuArray<int,3>{0,1,2};
331+
#elif (AMREX_SPACEDIM == 2)
332+
if (m_case == case_1n) {
333+
return GpuArray<int,3>{1,0,2};
334+
} else {
335+
return GpuArray<int,3>{0,1,2};
336+
}
337+
#else
338+
if (m_case == case_11n) {
339+
return GpuArray<int,3>{1,2,0};
340+
} else if (m_case == case_1n1) {
341+
return GpuArray<int,3>{1,0,2};
342+
} else if (m_case == case_1nn) {
343+
return GpuArray<int,3>{2,0,1};
344+
} else if (m_case == case_n1n) {
345+
return GpuArray<int,3>{0,2,1};
346+
} else {
347+
return GpuArray<int,3>{0,1,2};
348+
}
349+
#endif
350+
}
351+
121352
}

Src/FFT/AMReX_FFT_Helper.H

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
#include <AMReX_DataAllocator.H>
88
#include <AMReX_DistributionMapping.H>
99
#include <AMReX_Enum.H>
10+
#include <AMReX_FabArray.H>
1011
#include <AMReX_Gpu.H>
1112
#include <AMReX_GpuComplex.H>
1213
#include <AMReX_Math.H>
14+
#include <AMReX_Periodicity.H>
1315

1416
#if defined(AMREX_USE_CUDA)
1517
# include <cufft.h>
@@ -1447,6 +1449,83 @@ struct RotateBwd
14471449
}
14481450
};
14491451

1452+
namespace detail
1453+
{
1454+
struct SubHelper
1455+
{
1456+
explicit SubHelper (Box const& domain);
1457+
1458+
[[nodiscard]] Box make_box (Box const& box) const;
1459+
1460+
[[nodiscard]] Periodicity make_periodicity (Periodicity const& period) const;
1461+
1462+
[[nodiscard]] bool ghost_safe (IntVect const& ng) const;
1463+
1464+
// This rearranges the order.
1465+
[[nodiscard]] IntVect make_iv (IntVect const& iv) const;
1466+
1467+
// This keeps the order, but zero out the values in the hidden dimension.
1468+
[[nodiscard]] IntVect make_safe_ghost (IntVect const& ng) const;
1469+
1470+
[[nodiscard]] BoxArray inverse_boxarray (BoxArray const& ba) const;
1471+
1472+
[[nodiscard]] IntVect inverse_order (IntVect const& order) const;
1473+
1474+
template <typename T>
1475+
[[nodiscard]] T make_array (T const& a) const
1476+
{
1477+
#if (AMREX_SPACEDIM == 1)
1478+
amrex::ignore_unused(this);
1479+
return a;
1480+
#elif (AMREX_SPACEDIM == 2)
1481+
if (m_case == case_1n) {
1482+
return T{a[1],a[0]};
1483+
} else {
1484+
return a;
1485+
}
1486+
#else
1487+
if (m_case == case_11n) {
1488+
return T{a[2],a[0],a[1]};
1489+
} else if (m_case == case_1n1) {
1490+
return T{a[1],a[0],a[2]};
1491+
} else if (m_case == case_1nn) {
1492+
return T{a[1],a[2],a[0]};
1493+
} else if (m_case == case_n1n) {
1494+
return T{a[0],a[2],a[1]};
1495+
} else {
1496+
return a;
1497+
}
1498+
#endif
1499+
}
1500+
1501+
[[nodiscard]] GpuArray<int,3> xyz_order () const;
1502+
1503+
template <typename FA>
1504+
FA make_alias_mf (FA const& mf)
1505+
{
1506+
BoxList bl = mf.boxArray().boxList();
1507+
for (auto& b : bl) {
1508+
b = make_box(b);
1509+
}
1510+
auto const& ng = make_iv(mf.nGrowVect());
1511+
FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), 1, ng, MFInfo{}.SetAlloc(false));
1512+
using FAB = typename FA::fab_type;
1513+
for (MFIter mfi(submf, MFItInfo().DisableDeviceSync()); mfi.isValid(); ++mfi) {
1514+
submf.setFab(mfi, FAB(mfi.fabbox(), 1, mf[mfi].dataPtr()));
1515+
}
1516+
return submf;
1517+
}
1518+
1519+
#if (AMREX_SPACEDIM == 2)
1520+
enum Case { case_1n, case_other };
1521+
int m_case = case_other;
1522+
#elif (AMREX_SPACEDIM == 3)
1523+
enum Case { case_11n, case_1n1, case_1nn, case_n1n, case_other };
1524+
int m_case = case_other;
1525+
#endif
1526+
};
1527+
}
1528+
14501529
}
14511530

14521531
#endif

0 commit comments

Comments
 (0)