Skip to content

Commit 207bf39

Browse files
committed
Refactor: FFT use fft::create everywhere
1 parent fc5ebc2 commit 207bf39

File tree

15 files changed

+134
-124
lines changed

15 files changed

+134
-124
lines changed

apps/aluminumNew/Aluminum.hpp

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -162,26 +162,18 @@ class Aluminum : public Model {
162162
}
163163

164164
void prepare_operators(double dt) {
165-
const Decomposition &decomp = get_decomposition();
166-
const FFT &fft = get_fft();
167-
World w = decomposition::get_world(decomp);
168-
auto spacing = get_spacing(w);
169-
auto size = get_size(w);
170-
auto dx = spacing[0];
171-
auto dy = spacing[1];
172-
auto dz = spacing[2];
173-
auto Lx = size[0];
174-
auto Ly = size[1];
175-
auto Lz = size[2];
176-
177-
std::array<int, 3> low = get_outbox(fft).low;
178-
std::array<int, 3> high = get_outbox(fft).high;
165+
auto &fft = get_fft();
166+
auto world = get_world();
167+
auto [dx, dy, dz] = get_spacing(world);
168+
auto [Lx, Ly, Lz] = get_size(world);
169+
auto low = get_outbox(fft).low;
170+
auto high = get_outbox(fft).high;
179171

180172
int idx = 0;
181-
const double pi = std::atan(1.0) * 4.0;
182-
const double fx = 2.0 * pi / (dx * Lx);
183-
const double fy = 2.0 * pi / (dy * Ly);
184-
const double fz = 2.0 * pi / (dz * Lz);
173+
double pi = std::atan(1.0) * 4.0;
174+
double fx = 2.0 * pi / (dx * Lx);
175+
double fy = 2.0 * pi / (dy * Ly);
176+
double fz = 2.0 * pi / (dz * Lz);
185177

186178
for (int k = low[2]; k <= high[2]; k++) {
187179
for (int j = low[1]; j <= high[1]; j++) {

apps/tungsten.cpp

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -102,20 +102,13 @@ class Tungsten : public Model {
102102
}
103103

104104
void prepare_operators(double dt) {
105-
const FFT &fft = get_fft();
106-
const Decomposition &decomp = get_decomposition();
107-
World w = decomposition::get_world(decomp);
108-
auto spacing = get_spacing(w);
109-
auto size = get_size(w);
110-
auto dx = spacing[0];
111-
auto dy = spacing[1];
112-
auto dz = spacing[2];
113-
auto Lx = size[0];
114-
auto Ly = size[1];
115-
auto Lz = size[2];
116-
117-
std::array<int, 3> low = get_outbox(fft).low;
118-
std::array<int, 3> high = get_outbox(fft).high;
105+
auto &fft = get_fft();
106+
auto &world = get_world();
107+
auto [dx, dy, dz] = get_spacing(world);
108+
auto [Lx, Ly, Lz] = get_size(world);
109+
110+
auto low = get_outbox(fft).low;
111+
auto high = get_outbox(fft).high;
119112

120113
int idx = 0;
121114
const double pi = std::atan(1.0) * 4.0;

examples/02_domain_decomposition.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ int main(int argc, char *argv[]) {
3333
int comm_rank = 0, comm_size = 2;
3434
auto world1 = world::create({32, 4, 4});
3535
auto decomp1 = make_decomposition(world1, comm_rank, comm_size);
36-
// cout << decomp1 << endl;
36+
cout << decomp1 << endl;
3737

3838
// In practice, we let MPI communicator to decide the number of subdomains.
3939
MPI_Init(&argc, &argv);

examples/04_diffusion_model.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,8 @@ class Diffusion : public Model {
8383
if (rank0) cout << "Allocate space" << endl;
8484

8585
// Get references to world, fft and domain decomposition
86-
const Decomposition &decomp = get_decomposition();
87-
const World &w = decomposition::get_world(decomp);
88-
FFT &fft = get_fft();
86+
auto &world = get_world();
87+
auto &fft = get_fft();
8988

9089
// Allocate space for the main variable and it's fourier transform
9190
psi.resize(fft.size_inbox());
@@ -104,30 +103,33 @@ class Diffusion : public Model {
104103
World is defining the global dimensions of the problem as well as origin and
105104
chosen discretization parameters.
106105
*/
107-
if (rank0) cout << "World: " << w << endl;
106+
if (rank0) cout << "World: " << world << endl;
108107

109108
/*
110109
Upper and lower limits for this particular MPI rank, in both inbox and
111-
outbox, are given by domain decomposition object
110+
outbox, are given by fft object
112111
*/
113-
Int3 i_low = get_inbox(fft).low;
114-
Int3 i_high = get_inbox(fft).high;
115-
Int3 o_low = get_outbox(fft).low;
116-
Int3 o_high = get_outbox(fft).high;
112+
auto i_low = get_inbox(fft).low;
113+
auto i_high = get_inbox(fft).high;
114+
auto o_low = get_outbox(fft).low;
115+
auto o_high = get_outbox(fft).high;
117116

118117
/*
119118
Typically initial conditions are constructed elsewhere. However, to keep
120119
things as simple as possible, the initial condition can be also constructed
121120
here.
122121
*/
123122
if (rank0) cout << "Create initial condition" << endl;
123+
124+
auto size = get_size(world);
125+
auto origin = get_origin(world);
126+
auto spacing = get_spacing(world);
127+
124128
int idx = 0;
125129
double D = 1.0;
126130
for (int k = i_low[2]; k <= i_high[2]; k++) {
127131
for (int j = i_low[1]; j <= i_high[1]; j++) {
128132
for (int i = i_low[0]; i <= i_high[0]; i++) {
129-
auto origin = get_origin(w);
130-
auto spacing = get_spacing(w);
131133
double x = origin[0] + i * spacing[0];
132134
double y = origin[1] + j * spacing[1];
133135
double z = origin[2] + k * spacing[2];
@@ -144,8 +146,6 @@ class Diffusion : public Model {
144146
if (rank0) cout << "Prepare operators" << endl;
145147
idx = 0;
146148
double pi = std::atan(1.0) * 4.0;
147-
auto spacing = get_spacing(w);
148-
auto size = get_size(w);
149149
double fx = 2.0 * pi / (spacing[0] * size[0]);
150150
double fy = 2.0 * pi / (spacing[1] * size[1]);
151151
double fz = 2.0 * pi / (spacing[2] * size[2]);

examples/05_simulator.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,8 @@ class Diffusion : public Model {
9898
}
9999

100100
void prepare_operators(double dt) {
101-
const Decomposition &decomp = get_decomposition();
102-
const World &w = decomposition::get_world(decomp);
103-
FFT &fft = get_fft();
101+
auto &w = get_world();
102+
auto &fft = get_fft();
104103
std::array<int, 3> low = get_outbox(fft).low;
105104
std::array<int, 3> high = get_outbox(fft).high;
106105

examples/09_parallel_fft_high_level.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ int main(int argc, char *argv[]) {
3535
// Create output array to store FFT results. If requested array is of type T =
3636
// complex<double>, then array will be constructed using complex indices so
3737
// that it matches the Fourier-space, i.e. first dimension is floor(Lx/2) + 1.
38-
Array<complex<double>, 3> output(get_outbox_size(fft));
38+
Array<complex<double>, 3> output(get_outbox(fft).size);
3939

4040
std::cout << "input: " << input << std::endl; // this is {4, 3, 2}
4141
std::cout << "output: " << output << std::endl; // this is {3, 3, 2}

examples/12_cahn_hilliard.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ int main(int argc, char **argv) {
141141
// file_count
142142
writer.set_uri(sprintf("cahn_hilliard_%04i.vti", file_count));
143143
writer.set_field_name("concentration");
144-
writer.set_domain(get_size(world), get_inbox_size(fft), get_inbox_offset(fft));
144+
writer.set_domain(get_size(world), get_inbox(fft).size, get_inbox(fft).low);
145145
writer.set_origin(get_origin(world));
146146
writer.set_spacing(get_spacing(world));
147147
writer.initialize();

examples/diffusion_model_with_custom_initial_condition.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,22 +49,23 @@ class Gaussian : public FieldModifier {
4949
if (m.rank0) {
5050
cout << "Applying custom initial condition at time " << t << endl;
5151
}
52-
const World &w = m.get_world();
53-
const Decomposition &d = m.get_decomposition();
54-
Field &f = m.get_field();
55-
auto low = get_inbox(d).low;
56-
auto high = get_inbox(d).high;
52+
auto &world = m.get_world();
53+
auto &field = m.get_field();
54+
auto &fft = m.get_fft();
55+
auto origin = get_origin(world);
56+
auto spacing = get_spacing(world);
57+
58+
auto low = get_inbox(fft).low;
59+
auto high = get_inbox(fft).high;
5760
long int idx = 0;
5861

5962
for (int k = low[2]; k <= high[2]; k++) {
6063
for (int j = low[1]; j <= high[1]; j++) {
6164
for (int i = low[0]; i <= high[0]; i++) {
62-
auto origin = get_origin(w);
63-
auto spacing = get_spacing(w);
6465
double x = origin[0] + i * spacing[0];
6566
double y = origin[1] + j * spacing[1];
6667
double z = origin[2] + k * spacing[2];
67-
f[idx] = exp(-(x * x + y * y + z * z) / (4.0 * m_D));
68+
field[idx] = exp(-(x * x + y * y + z * z) / (4.0 * m_D));
6869
idx += 1;
6970
}
7071
}

include/openpfc/array.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ template <typename T, size_t D> class Array {
3434
* @param fft
3535
*/
3636
Array(const FFT &fft, std::true_type)
37-
: index(get_outbox_size(fft), get_outbox_offset(fft)) {}
37+
: index(get_outbox(fft).size, get_outbox(fft).low) {}
3838

3939
/**
4040
* @brief Construct a new Array object from FFT, using inbox, if
@@ -43,7 +43,7 @@ template <typename T, size_t D> class Array {
4343
* @param decomp
4444
*/
4545
Array(const FFT &fft, std::false_type)
46-
: index(get_inbox_size(fft), get_inbox_offset(fft)) {}
46+
: index(get_inbox(fft).size, get_inbox(fft).low) {}
4747

4848
// Custom type trait to check if a type is complex
4949
template <typename U> struct is_complex : std::false_type {};

include/openpfc/fft.hpp

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -46,33 +46,34 @@ struct FFTLayout {
4646
const FFTLayout create(const Decomposition &decomposition, int r2c_direction,
4747
int num_domains);
4848

49+
inline auto get_real_box(const FFTLayout &layout, int i) {
50+
return layout.m_real_boxes.at(i);
51+
}
52+
53+
inline auto get_complex_box(const FFTLayout &layout, int i) {
54+
return layout.m_complex_boxes.at(i);
55+
}
56+
4957
} // namespace layout
5058

51-
using heffte::box3d;
52-
using pfc::types::Bool3;
5359
using pfc::types::Int3;
5460
using pfc::types::Real3;
55-
using Box3D = heffte::box3d<int>; ///< Type alias for 3D integer box.
5661

57-
inline heffte::fft3d_r2c<heffte::backend::fftw>
58-
make_fft(const Box3D &inbox, const Box3D &outbox, const int &r2c_direction,
59-
MPI_Comm comm, heffte::plan_options plan_options) {
60-
return heffte::fft3d_r2c<heffte::backend::fftw>(inbox, outbox, r2c_direction, comm,
61-
plan_options);
62-
}
62+
using Decomposition = pfc::decomposition::Decomposition<pfc::csys::CartesianTag>;
63+
using ComplexVector = std::vector<std::complex<double>>;
64+
using fft_r2c = heffte::fft3d_r2c<heffte::backend::fftw>;
65+
using box3di = heffte::box3d<int>; ///< Type alias for 3D integer box.
6366

6467
/**
6568
* @brief FFT class for performing forward and backward Fast Fourier
6669
* Transformations.
6770
*/
6871
struct FFT {
6972

70-
using ComplexVector = std::vector<std::complex<double>>;
73+
// const Decomposition m_decomposition; /**< The Decomposition object. */
74+
// const box3di m_inbox, m_outbox; /**< Local inbox and outbox boxes. */
7175

72-
const Decomposition m_decomposition; /**< The Decomposition object. */
73-
const heffte::box3d<int> m_inbox, m_outbox; /**< Local inbox and outbox boxes. */
74-
const int m_r2c_direction; /**< Real-to-complex symmetry direction. */
75-
const heffte::fft3d_r2c<heffte::backend::fftw> m_fft; /**< HeFFTe FFT object. */
76+
const fft_r2c m_fft; /**< HeFFTe FFT object. */
7677
ComplexVector m_wrk; /**< Workspace vector for FFT computations. */
7778
double m_fft_time = 0.0; /**< Recorded FFT computation time. */
7879

@@ -86,21 +87,15 @@ struct FFT {
8687
* @param plan_options Optional plan options for configuring the FFT behavior.
8788
* @param world The World object providing the domain size information.
8889
*/
89-
FFT(const Decomposition &decomposition, MPI_Comm comm,
90-
heffte::plan_options plan_options)
91-
: m_decomposition(decomposition), m_inbox(decomposition.m_inbox),
92-
m_outbox(decomposition.m_outbox), m_r2c_direction(0),
93-
m_fft(make_fft(m_inbox, m_outbox, m_r2c_direction, comm, plan_options)),
94-
m_wrk(std::vector<std::complex<double>>(m_fft.size_workspace())){};
90+
FFT(fft_r2c fft) : m_fft(std::move(fft)), m_wrk(m_fft.size_workspace()) {}
9591

9692
/**
9793
* @brief Performs the forward FFT transformation.
9894
*
9995
* @param in Input vector of real values.
10096
* @param out Output vector of complex values.
10197
*/
102-
void forward(const std::vector<double> &in,
103-
std::vector<std::complex<double>> &out) {
98+
void forward(const std::vector<double> &in, ComplexVector &out) {
10499
m_fft_time -= MPI_Wtime();
105100
m_fft.forward(in.data(), out.data(), m_wrk.data());
106101
m_fft_time += MPI_Wtime();
@@ -112,8 +107,7 @@ struct FFT {
112107
* @param in Input vector of complex values.
113108
* @param out Output vector of real values.
114109
*/
115-
void backward(const std::vector<std::complex<double>> &in,
116-
std::vector<double> &out) {
110+
void backward(const ComplexVector &in, std::vector<double> &out) {
117111
m_fft_time -= MPI_Wtime();
118112
m_fft.backward(in.data(), out.data(), m_wrk.data(), heffte::scale::full);
119113
m_fft_time += MPI_Wtime();
@@ -136,7 +130,7 @@ struct FFT {
136130
*
137131
* @return Reference to the Decomposition object.
138132
*/
139-
const Decomposition &get_decomposition() { return m_decomposition; }
133+
// const Decomposition &get_decomposition() { return m_decomposition; }
140134

141135
/**
142136
* @brief Returns the size of the inbox used for FFT computations.
@@ -160,24 +154,36 @@ struct FFT {
160154
size_t size_workspace() const { return m_fft.size_workspace(); }
161155
};
162156

163-
inline const Box3D &get_inbox(const FFT &fft) noexcept { return fft.m_inbox; }
164-
inline const Box3D &get_outbox(const FFT &fft) noexcept { return fft.m_outbox; }
165-
inline const Int3 &get_inbox_size(const FFT &fft) noexcept {
166-
return get_inbox(fft).size;
167-
}
168-
inline const Int3 &get_inbox_offset(const FFT &fft) noexcept {
169-
return get_inbox(fft).low;
170-
}
171-
inline const Int3 &get_outbox_size(const FFT &fft) noexcept {
172-
return get_outbox(fft).size;
157+
inline const auto &get_fft_object(const FFT &fft) noexcept { return fft.m_fft; }
158+
159+
inline const auto get_inbox(const FFT &fft) noexcept {
160+
return get_fft_object(fft).inbox();
173161
}
174-
inline const Int3 &get_outbox_offset(const FFT &fft) noexcept {
175-
return get_outbox(fft).low;
162+
163+
inline const auto get_outbox(const FFT &fft) noexcept {
164+
return get_fft_object(fft).outbox();
176165
}
177166

178-
FFT create(const Decomposition &decomposition);
167+
/**
168+
* @brief Creates an FFT object based on the given decomposition and MPI
169+
* communicator.
170+
*
171+
* @param decomposition The Decomposition object defining the domain
172+
* decomposition.
173+
* @param comm The MPI communicator for parallel computations.
174+
* @param options Optional plan options for configuring the FFT behavior.
175+
* @return An FFT object containing the FFT configuration and data.
176+
*/
179177
FFT create(const Decomposition &decomposition, MPI_Comm comm,
180178
heffte::plan_options options);
179+
/**
180+
* @brief Creates an FFT object based on the given decomposition.
181+
*
182+
* @param decomposition The Decomposition object defining the domain
183+
* decomposition.
184+
* @return An FFT object containing the FFT configuration and data.
185+
*/
186+
FFT create(const Decomposition &decomposition);
181187

182188
} // namespace fft
183189

0 commit comments

Comments
 (0)