@@ -607,11 +607,13 @@ void UserMeshWorkBeforeOutput(Mesh *pmesh, ParameterInput *pin,
607607
608608 // vectors with the correct sizes to store the input and output data
609609 // std::vector<Real> input(fft.size_inbox());
610- parthenon::HostArray1D<Real> input (" fft input" , fft.size_inbox ());
611- parthenon::HostArray1D<Real> inverse (" fft inverse" , fft.size_inbox ());
612- parthenon::HostArray1D<std::complex <Real>> output (" fft output" , fft.size_outbox ());
610+ int n_comp = 3 ;
611+ parthenon::HostArray1D<Real> input (" fft input" , n_comp * fft.size_inbox ());
612+ parthenon::HostArray1D<Real> inverse (" fft inverse" , n_comp * fft.size_inbox ());
613+ parthenon::HostArray1D<std::complex <Real>> output (" fft output" ,
614+ n_comp * fft.size_outbox ());
613615 parthenon::HostArray1D<std::complex <Real>> workspace (" fft workspace" ,
614- fft.size_workspace ());
616+ n_comp * fft.size_workspace ());
615617 PARTHENON_REQUIRE_THROWS (pmesh->DefaultNumPartitions () == 1 ,
616618 " Only pack_size=-1 currently supported for heffte." )
617619 auto &md = pmesh->mesh_data .GetOrAdd (" base" , 0 );
@@ -628,11 +630,13 @@ void UserMeshWorkBeforeOutput(Mesh *pmesh, ParameterInput *pin,
628630 const int jj = j - jb.s ;
629631 const int ii = i - ib.s ;
630632 const int idx = (kk * nx2b + jj) * nx1b + ii;
631- input (idx) = p (IDN, k, j, i);
633+ input (idx) = p (IV1, k, j, i);
634+ input (idx + fft.size_inbox ()) = p (IV2, k, j, i);
635+ input (idx + 2 * fft.size_inbox ()) = p (IV3, k, j, i);
632636 realsum += SQR (p (IDN, k, j, i) - 1.0 );
633637 }
634638
635- fft.forward (input.data (), output.data (), workspace.data ());
639+ fft.forward (n_comp, input.data (), output.data (), workspace.data ());
636640
637641 auto k_max = std::sqrt (SQR (gnx1 / 2 ) + SQR (gnx2 / 2 ) + SQR (gnx3 / 2 ));
638642 Real mysum = 0 ;
@@ -648,10 +652,13 @@ void UserMeshWorkBeforeOutput(Mesh *pmesh, ParameterInput *pin,
648652 auto k_mag =
649653 static_cast <int >(std::floor (std::sqrt (SQR (k_x) + SQR (k_y) + SQR (k_z))));
650654
651- auto val = SQR (
652- std::abs (output[((k - outbox.low [2 ]) * outbox.size [1 ] + (j - outbox.low [1 ])) *
655+ const auto outidx = ((k - outbox.low [2 ]) * outbox.size [1 ] + (j - outbox.low [1 ])) *
653656 outbox.size [0 ] +
654- i - outbox.low [0 ]]));
657+ i - outbox.low [0 ];
658+
659+ auto val = SQR (std::abs (output[outidx])) +
660+ SQR (std::abs (output[outidx + fft.size_outbox ()])) +
661+ SQR (std::abs (output[outidx + 2 * fft.size_outbox ()]));
655662 // account for Hermitian symmetry of r2c transform
656663 if ((k_x > 0 ) && (2 * k_x != gnx1)) {
657664 val *= 2.0 ;
0 commit comments