Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Src/FFT/AMReX_FFT_Helper.H
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ struct Info
//! batch size.
bool twod_mode = false;

//! We might have a special twod_mode: nx or ny == 1 && nz > 1.
bool oned_mode = false;

//! Batched FFT size. Only support in R2C, not R2X.
int batch_size = 1;

Expand All @@ -77,6 +80,7 @@ struct Info
Info& setDomainStrategy (DomainStrategy s) { domain_strategy = s; return *this; }
Info& setPencilThreshold (int t) { pencil_threshold = t; return *this; }
Info& setTwoDMode (bool x) { twod_mode = x; return *this; }
Info& setOneDMode (bool x) { oned_mode = x; return *this; }
Info& setBatchSize (int bsize) { batch_size = bsize; return *this; }
Info& setNumProcs (int n) { nprocs = n; return *this; }
};
Expand Down
3 changes: 3 additions & 0 deletions Src/FFT/AMReX_FFT_Poisson.H
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ public:
bc[2].second != Boundary::periodic);
Info info{};
info.setTwoDMode(true);
if (m_geom.Domain().length(0) == 1 || m_geom.Domain().length(1) == 1) {
info.setOneDMode(true);
}
if (periodic_xy) {
m_r2c = std::make_unique<R2C<typename MF::value_type>>(m_geom.Domain(),
info);
Expand Down
12 changes: 7 additions & 5 deletions Src/FFT/AMReX_FFT_R2C.H
Original file line number Diff line number Diff line change
Expand Up @@ -502,10 +502,12 @@ R2C<T,D,C>::R2C (Box const& domain, Info const& info)
}
}

if (m_info.twod_mode) {
m_slab_decomp = true;
} else if (m_info.domain_strategy == DomainStrategy::slab && (m_real_domain.length(1) > 1)) {
m_slab_decomp = true;
if (!m_info.oned_mode) {
if (m_info.twod_mode) {
m_slab_decomp = true;
} else if (m_info.domain_strategy == DomainStrategy::slab && (m_real_domain.length(1) > 1)) {
m_slab_decomp = true;
}
}

#endif
Expand Down Expand Up @@ -538,7 +540,7 @@ R2C<T,D,C>::R2C (Box const& domain, Info const& info)

#if (AMREX_SPACEDIM >= 2)
DistributionMapping cdmy;
if ((m_real_domain.length(1) > 1) && !m_slab_decomp)
if ((m_real_domain.length(1) > 1) && !m_slab_decomp && !m_info.oned_mode)
{
auto cbay = amrex::decompose(m_spectral_domain_y, nprocs,
{AMREX_D_DECL(false,true,true)}, true);
Expand Down
2 changes: 1 addition & 1 deletion Src/FFT/AMReX_FFT_R2X.H
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ R2X<T>::R2X (Box const& domain,
} // else: x-fft: r2r(m_rx)

#if (AMREX_SPACEDIM >= 2)
if (domain.length(1) > 1) {
if (domain.length(1) > 1 && !m_info.oned_mode) {
if (! m_cx.empty()) {
// copy(m_cx->m_cy)
m_dom_cy = Box(IntVect(0), IntVect(AMREX_D_DECL(m_dom_cx.bigEnd(1),
Expand Down
44 changes: 32 additions & 12 deletions Tests/FFT/Poisson/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,9 @@ std::pair<Real,Real> check_convergence
return {bnorm, rnorm};
}

int main (int argc, char* argv[])
void run_test (AMREX_D_DECL(int n_cell_x, int n_cell_y, int n_cell_z))
{
amrex::Initialize(argc, argv);
{
BL_PROFILE("main");

AMREX_D_TERM(int n_cell_x = 64;,
int n_cell_y = 48;,
int n_cell_z = 128);

AMREX_D_TERM(int max_grid_size_x = 32;,
int max_grid_size_y = 32;,
int max_grid_size_z = 32);
Expand Down Expand Up @@ -168,12 +161,16 @@ int main (int argc, char* argv[])
std::pair<FFT::Boundary,FFT::Boundary>{FFT::Boundary::even,
FFT::Boundary::odd}};

int ncasesx = ncases;
int ncasesy = (AMREX_SPACEDIM > 1) ? ncases : 1;
int ncasesz = (AMREX_SPACEDIM > 2) ? ncases : 1;
if (n_cell_x == 1) { ncasesx = 1; }
if (n_cell_y == 1) { ncasesy = 1; }
if (n_cell_z == 1) { ncasesz = 1; }
int icase = 0;
for (int zcase = 0; zcase < ncasesz; ++zcase) {
for (int ycase = 0; ycase < ncasesy; ++ycase) {
for (int xcase = 0; xcase < ncases ; ++xcase) {
for (int xcase = 0; xcase < ncasesx; ++xcase) {
++icase;
Array<std::pair<FFT::Boundary,FFT::Boundary>,AMREX_SPACEDIM>
fft_bc{AMREX_D_DECL(bcs[xcase],bcs[ycase],bcs[zcase])};
Expand All @@ -200,7 +197,7 @@ int main (int argc, char* argv[])
#ifdef AMREX_USE_FLOAT
auto eps = 2.e-3f;
#else
auto eps = 1.e-11;
auto eps = 2.e-10;
#endif
AMREX_ALWAYS_ASSERT(rnorm < eps*bnorm);
}}}
Expand All @@ -211,7 +208,7 @@ int main (int argc, char* argv[])
icase = 0;
for (int zcase = 1; zcase < ncasesz; ++zcase) { // skip periodic z-direction
for (int ycase = 0; ycase < ncasesy; ++ycase) {
for (int xcase = 0; xcase < ncases ; ++xcase) {
for (int xcase = 0; xcase < ncasesx; ++xcase) {
++icase;
Array<std::pair<FFT::Boundary,FFT::Boundary>,AMREX_SPACEDIM>
fft_bc{bcs[xcase], bcs[ycase], bcs[zcase]};
Expand Down Expand Up @@ -241,10 +238,33 @@ int main (int argc, char* argv[])
#ifdef AMREX_USE_FLOAT
auto eps = 2.e-3f;
#else
auto eps = 1.e-11;
auto eps = 2.e-10;
#endif
AMREX_ALWAYS_ASSERT(rnorm < eps*bnorm);
}}}
#endif
}
}

int main (int argc, char* argv[])
{
amrex::Initialize(argc, argv);
{
BL_PROFILE("main");

AMREX_D_TERM(int n_cell_x = 64;,
int n_cell_y = 48;,
int n_cell_z = 128);

ParmParse pp;
AMREX_D_TERM(pp.query("n_cell_x", n_cell_x);,
pp.query("n_cell_y", n_cell_y);,
pp.query("n_cell_z", n_cell_z));

run_test(AMREX_D_DECL(n_cell_x, n_cell_y, n_cell_z));
#if (AMREX_SPACEDIM == 3)
run_test(n_cell_x, 1, n_cell_z);
run_test( 1, n_cell_y, n_cell_z);
#endif
}

Expand Down
Loading