Skip to content

Commit a494e12

Browse files
committed
Houston, we have a batched 3D FFT
1 parent fe8ea29 commit a494e12

1 file changed

Lines changed: 16 additions & 9 deletions

File tree

src/pgen/turbulence.cpp

Lines changed: 16 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -607,11 +607,13 @@ void UserMeshWorkBeforeOutput(Mesh *pmesh, ParameterInput *pin,
607607

608608
// vectors with the correct sizes to store the input and output data
609609
// std::vector<Real> input(fft.size_inbox());
610-
parthenon::HostArray1D<Real> input("fft input", fft.size_inbox());
611-
parthenon::HostArray1D<Real> inverse("fft inverse", fft.size_inbox());
612-
parthenon::HostArray1D<std::complex<Real>> output("fft output", fft.size_outbox());
610+
int n_comp = 3;
611+
parthenon::HostArray1D<Real> input("fft input", n_comp * fft.size_inbox());
612+
parthenon::HostArray1D<Real> inverse("fft inverse", n_comp * fft.size_inbox());
613+
parthenon::HostArray1D<std::complex<Real>> output("fft output",
614+
n_comp * fft.size_outbox());
613615
parthenon::HostArray1D<std::complex<Real>> workspace("fft workspace",
614-
fft.size_workspace());
616+
n_comp * fft.size_workspace());
615617
PARTHENON_REQUIRE_THROWS(pmesh->DefaultNumPartitions() == 1,
616618
"Only pack_size=-1 currently supported for heffte.")
617619
auto &md = pmesh->mesh_data.GetOrAdd("base", 0);
@@ -628,11 +630,13 @@ void UserMeshWorkBeforeOutput(Mesh *pmesh, ParameterInput *pin,
628630
const int jj = j - jb.s;
629631
const int ii = i - ib.s;
630632
const int idx = (kk * nx2b + jj) * nx1b + ii;
631-
input(idx) = p(IDN, k, j, i);
633+
input(idx) = p(IV1, k, j, i);
634+
input(idx + fft.size_inbox()) = p(IV2, k, j, i);
635+
input(idx + 2 * fft.size_inbox()) = p(IV3, k, j, i);
632636
realsum += SQR(p(IDN, k, j, i) - 1.0);
633637
}
634638

635-
fft.forward(input.data(), output.data(), workspace.data());
639+
fft.forward(n_comp, input.data(), output.data(), workspace.data());
636640

637641
auto k_max = std::sqrt(SQR(gnx1 / 2) + SQR(gnx2 / 2) + SQR(gnx3 / 2));
638642
Real mysum = 0;
@@ -648,10 +652,13 @@ void UserMeshWorkBeforeOutput(Mesh *pmesh, ParameterInput *pin,
648652
auto k_mag =
649653
static_cast<int>(std::floor(std::sqrt(SQR(k_x) + SQR(k_y) + SQR(k_z))));
650654

651-
auto val = SQR(
652-
std::abs(output[((k - outbox.low[2]) * outbox.size[1] + (j - outbox.low[1])) *
655+
const auto outidx = ((k - outbox.low[2]) * outbox.size[1] + (j - outbox.low[1])) *
653656
outbox.size[0] +
654-
i - outbox.low[0]]));
657+
i - outbox.low[0];
658+
659+
auto val = SQR(std::abs(output[outidx])) +
660+
SQR(std::abs(output[outidx + fft.size_outbox()])) +
661+
SQR(std::abs(output[outidx + 2 * fft.size_outbox()]));
655662
// account for Hermitian symmetry of r2c transform
656663
if ((k_x > 0) && (2 * k_x != gnx1)) {
657664
val *= 2.0;

0 commit comments

Comments
 (0)