Skip to content

Commit 507c73f

Browse files
authored
FFT::PoissonHybrid: Fix cases with nx or ny being 1 (#4829)
They used to work, but were broken in #4671.
1 parent 26cd0aa commit 507c73f

File tree

5 files changed

+51
-18
lines changed

5 files changed

+51
-18
lines changed

Src/FFT/AMReX_FFT_Helper.H

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ struct Info
6868
//! batch size.
6969
bool twod_mode = false;
7070

71+
//! We might have a special twod_mode: nx or ny == 1 && nz > 1.
72+
bool oned_mode = false;
73+
7174
//! Batched FFT size. Only support in R2C, not R2X.
7275
int batch_size = 1;
7376

@@ -77,6 +80,7 @@ struct Info
7780
Info& setDomainStrategy (DomainStrategy s) { domain_strategy = s; return *this; }
7881
Info& setPencilThreshold (int t) { pencil_threshold = t; return *this; }
7982
Info& setTwoDMode (bool x) { twod_mode = x; return *this; }
83+
Info& setOneDMode (bool x) { oned_mode = x; return *this; }
8084
Info& setBatchSize (int bsize) { batch_size = bsize; return *this; }
8185
Info& setNumProcs (int n) { nprocs = n; return *this; }
8286
};

Src/FFT/AMReX_FFT_Poisson.H

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ public:
172172
bc[2].second != Boundary::periodic);
173173
Info info{};
174174
info.setTwoDMode(true);
175+
if (m_geom.Domain().length(0) == 1 || m_geom.Domain().length(1) == 1) {
176+
info.setOneDMode(true);
177+
}
175178
if (periodic_xy) {
176179
m_r2c = std::make_unique<R2C<typename MF::value_type>>(m_geom.Domain(),
177180
info);

Src/FFT/AMReX_FFT_R2C.H

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -502,10 +502,12 @@ R2C<T,D,C>::R2C (Box const& domain, Info const& info)
502502
}
503503
}
504504

505-
if (m_info.twod_mode) {
506-
m_slab_decomp = true;
507-
} else if (m_info.domain_strategy == DomainStrategy::slab && (m_real_domain.length(1) > 1)) {
508-
m_slab_decomp = true;
505+
if (!m_info.oned_mode) {
506+
if (m_info.twod_mode) {
507+
m_slab_decomp = true;
508+
} else if (m_info.domain_strategy == DomainStrategy::slab && (m_real_domain.length(1) > 1)) {
509+
m_slab_decomp = true;
510+
}
509511
}
510512

511513
#endif
@@ -538,7 +540,7 @@ R2C<T,D,C>::R2C (Box const& domain, Info const& info)
538540

539541
#if (AMREX_SPACEDIM >= 2)
540542
DistributionMapping cdmy;
541-
if ((m_real_domain.length(1) > 1) && !m_slab_decomp)
543+
if ((m_real_domain.length(1) > 1) && !m_slab_decomp && !m_info.oned_mode)
542544
{
543545
auto cbay = amrex::decompose(m_spectral_domain_y, nprocs,
544546
{AMREX_D_DECL(false,true,true)}, true);

Src/FFT/AMReX_FFT_R2X.H

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ R2X<T>::R2X (Box const& domain,
187187
} // else: x-fft: r2r(m_rx)
188188

189189
#if (AMREX_SPACEDIM >= 2)
190-
if (domain.length(1) > 1) {
190+
if (domain.length(1) > 1 && !m_info.oned_mode) {
191191
if (! m_cx.empty()) {
192192
// copy(m_cx->m_cy)
193193
m_dom_cy = Box(IntVect(0), IntVect(AMREX_D_DECL(m_dom_cx.bigEnd(1),

Tests/FFT/Poisson/main.cpp

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,16 +108,9 @@ std::pair<Real,Real> check_convergence
108108
return {bnorm, rnorm};
109109
}
110110

111-
int main (int argc, char* argv[])
111+
void run_test (AMREX_D_DECL(int n_cell_x, int n_cell_y, int n_cell_z))
112112
{
113-
amrex::Initialize(argc, argv);
114113
{
115-
BL_PROFILE("main");
116-
117-
AMREX_D_TERM(int n_cell_x = 64;,
118-
int n_cell_y = 48;,
119-
int n_cell_z = 128);
120-
121114
AMREX_D_TERM(int max_grid_size_x = 32;,
122115
int max_grid_size_y = 32;,
123116
int max_grid_size_z = 32);
@@ -168,12 +161,20 @@ int main (int argc, char* argv[])
168161
std::pair<FFT::Boundary,FFT::Boundary>{FFT::Boundary::even,
169162
FFT::Boundary::odd}};
170163

164+
int ncasesx = ncases;
171165
int ncasesy = (AMREX_SPACEDIM > 1) ? ncases : 1;
172166
int ncasesz = (AMREX_SPACEDIM > 2) ? ncases : 1;
167+
if (n_cell_x == 1) { ncasesx = 1; }
168+
#if (AMREX_SPACEDIM > 1)
169+
if (n_cell_y == 1) { ncasesy = 1; }
170+
#endif
171+
#if (AMREX_SPACEDIM == 3)
172+
if (n_cell_z == 1) { ncasesz = 1; }
173+
#endif
173174
int icase = 0;
174175
for (int zcase = 0; zcase < ncasesz; ++zcase) {
175176
for (int ycase = 0; ycase < ncasesy; ++ycase) {
176-
for (int xcase = 0; xcase < ncases ; ++xcase) {
177+
for (int xcase = 0; xcase < ncasesx; ++xcase) {
177178
++icase;
178179
Array<std::pair<FFT::Boundary,FFT::Boundary>,AMREX_SPACEDIM>
179180
fft_bc{AMREX_D_DECL(bcs[xcase],bcs[ycase],bcs[zcase])};
@@ -200,7 +201,7 @@ int main (int argc, char* argv[])
200201
#ifdef AMREX_USE_FLOAT
201202
auto eps = 2.e-3f;
202203
#else
203-
auto eps = 1.e-11;
204+
auto eps = 2.e-10;
204205
#endif
205206
AMREX_ALWAYS_ASSERT(rnorm < eps*bnorm);
206207
}}}
@@ -211,7 +212,7 @@ int main (int argc, char* argv[])
211212
icase = 0;
212213
for (int zcase = 1; zcase < ncasesz; ++zcase) { // skip periodic z-direction
213214
for (int ycase = 0; ycase < ncasesy; ++ycase) {
214-
for (int xcase = 0; xcase < ncases ; ++xcase) {
215+
for (int xcase = 0; xcase < ncasesx; ++xcase) {
215216
++icase;
216217
Array<std::pair<FFT::Boundary,FFT::Boundary>,AMREX_SPACEDIM>
217218
fft_bc{bcs[xcase], bcs[ycase], bcs[zcase]};
@@ -241,10 +242,33 @@ int main (int argc, char* argv[])
241242
#ifdef AMREX_USE_FLOAT
242243
auto eps = 2.e-3f;
243244
#else
244-
auto eps = 1.e-11;
245+
auto eps = 2.e-10;
245246
#endif
246247
AMREX_ALWAYS_ASSERT(rnorm < eps*bnorm);
247248
}}}
249+
#endif
250+
}
251+
}
252+
253+
int main (int argc, char* argv[])
254+
{
255+
amrex::Initialize(argc, argv);
256+
{
257+
BL_PROFILE("main");
258+
259+
AMREX_D_TERM(int n_cell_x = 64;,
260+
int n_cell_y = 48;,
261+
int n_cell_z = 128);
262+
263+
ParmParse pp;
264+
AMREX_D_TERM(pp.query("n_cell_x", n_cell_x);,
265+
pp.query("n_cell_y", n_cell_y);,
266+
pp.query("n_cell_z", n_cell_z));
267+
268+
run_test(AMREX_D_DECL(n_cell_x, n_cell_y, n_cell_z));
269+
#if (AMREX_SPACEDIM == 3)
270+
run_test(n_cell_x, 1, n_cell_z);
271+
run_test( 1, n_cell_y, n_cell_z);
248272
#endif
249273
}
250274

0 commit comments

Comments
 (0)