Skip to content

Commit 841b810

Browse files
committed
Refactor: separate the roles of Decomposition & FFTLayout
1 parent 7d3aad8 commit 841b810

File tree

4 files changed

+107
-45
lines changed

4 files changed

+107
-45
lines changed

include/openpfc/core/decomposition.hpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ using Int3 = pfc::types::Int3;
143143
struct Decomposition {
144144

145145
const pfc::World &m_global_world; ///< The World object.
146-
const std::array<int, 3> &m_grid; ///< The number of parts in each dimension.
146+
const std::array<int, 3> m_grid; ///< The number of parts in each dimension.
147147
const std::vector<pfc::World> m_subworlds; ///< The sub-worlds for each part.
148148

149149
Decomposition(const World &world, const Int3 grid)
@@ -187,7 +187,13 @@ inline const auto create(const World &world, const Int3 &grid) noexcept {
187187
};
188188

189189
inline const auto create(const World &world, const int &nparts) noexcept {
190-
return create(world, heffte::proc_setup_min_surface(to_indices(world), nparts));
190+
auto indices = to_indices(world);
191+
auto grid = heffte::proc_setup_min_surface(indices, nparts);
192+
return create(world, grid);
193+
}
194+
195+
inline const auto get_num_domains(const Decomposition &decomposition) noexcept {
196+
return get_subworlds(decomposition).size();
191197
}
192198

193199
} // namespace decomposition

include/openpfc/fft.hpp

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,22 @@ struct FFTLayout {
4040
* @param decomposition The Decomposition object defining the domain
4141
* decomposition.
4242
* @param r2c_direction The direction of real-to-complex symmetry.
43-
* @param num_domains The number of domains for the FFT layout.
4443
* @return An FFTLayout object containing the layout information.
4544
*/
46-
const FFTLayout create(const Decomposition &decomposition, int r2c_direction,
47-
int num_domains);
45+
const FFTLayout create(const Decomposition &decomposition, int r2c_direction);
4846

49-
inline auto get_real_box(const FFTLayout &layout, int i) {
47+
inline const auto &get_real_box(const FFTLayout &layout, int i) {
5048
return layout.m_real_boxes.at(i);
5149
}
5250

53-
inline auto get_complex_box(const FFTLayout &layout, int i) {
51+
inline const auto &get_complex_box(const FFTLayout &layout, int i) {
5452
return layout.m_complex_boxes.at(i);
5553
}
5654

55+
inline auto get_r2c_direction(const FFTLayout &layout) {
56+
return layout.m_r2c_direction;
57+
}
58+
5759
} // namespace layout
5860

5961
using pfc::types::Int3;
@@ -164,24 +166,36 @@ inline const auto get_outbox(const FFT &fft) noexcept {
164166
return get_fft_object(fft).outbox();
165167
}
166168

169+
using heffte::plan_options;
170+
using layout::FFTLayout;
171+
167172
/**
168-
* @brief Creates an FFT object based on the given decomposition and MPI
169-
* communicator.
173+
* @brief Creates an FFT object based on the given FFTLayout and rank ID.
174+
*
175+
* @param fft_layout The FFTLayout object defining the FFT configuration.
176+
* @param rank_id The rank ID of the current process in the MPI communicator.
177+
* @param options Plan options for configuring the FFT behavior.
178+
* @return An FFT object containing the FFT configuration and data.
179+
*/
180+
FFT create(const FFTLayout &fft_layout, int rank_id, plan_options options);
181+
182+
/**
183+
* @brief Creates an FFT object based on the given decomposition and rank ID.
170184
*
171185
* @param decomposition The Decomposition object defining the domain
172186
* decomposition.
173-
* @param comm The MPI communicator for parallel computations.
174-
* @param options Optional plan options for configuring the FFT behavior.
187+
* @param rank_id The rank ID of the current process in the MPI communicator.
175188
* @return An FFT object containing the FFT configuration and data.
176189
*/
177-
FFT create(const Decomposition &decomposition, MPI_Comm comm,
178-
heffte::plan_options options);
190+
FFT create(const Decomposition &decomposition, int rank_id);
191+
179192
/**
180193
* @brief Creates an FFT object based on the given decomposition.
181194
*
182195
* @param decomposition The Decomposition object defining the domain
183196
* decomposition.
184197
* @return An FFT object containing the FFT configuration and data.
198+
* @throws std::logic_error if the number of domains in the decomposition does not match the MPI communicator size.
185199
*/
186200
FFT create(const Decomposition &decomposition);
187201

include/openpfc/ui.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -652,10 +652,11 @@ template <class ConcreteModel> class App {
652652
std::cout << "World: " << world << std::endl;
653653

654654
int num_ranks = m_worker.get_num_ranks();
655+
int rank_id = m_worker.get_rank();
655656
auto decomp = decomposition::create(world, num_ranks);
656-
auto plan_options =
657-
ui::from_json<heffte::plan_options>(m_settings["plan_options"]);
658-
auto fft = fft::create(decomp, m_comm, plan_options);
657+
auto options = ui::from_json<heffte::plan_options>(m_settings["plan_options"]);
658+
auto fft_layout = fft::layout::create(decomp, 0);
659+
auto fft = fft::create(fft_layout, rank_id, options);
659660
Time time(ui::from_json<Time>(m_settings));
660661
ConcreteModel model(world);
661662
model.set_fft(fft);

src/openpfc/fft.cpp

Lines changed: 70 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,53 @@ namespace pfc {
77
namespace fft {
88
namespace layout {
99

10-
const FFTLayout create(const Decomposition &decomposition, int r2c_direction,
11-
int num_domains) {
12-
if (num_domains <= 0) {
13-
throw std::logic_error("Cannot construct domain decomposition: !(nprocs > 0)");
10+
#include <array>
11+
#include <iostream>
12+
#include <sstream>
13+
14+
// Helper function to print std::array
15+
template <typename T, std::size_t N>
16+
std::ostream &operator<<(std::ostream &os, const std::array<T, N> &arr) {
17+
os << "{";
18+
for (std::size_t i = 0; i < N; ++i) {
19+
os << arr[i];
20+
if (i < N - 1) {
21+
os << ", ";
22+
}
1423
}
15-
auto world = get_world(decomposition);
16-
auto [N1r, N2r, N3r] = get_size(world); // real size
17-
auto [N1c, N2c, N3c] = get_size(world); // complex size
24+
os << "}";
25+
return os;
26+
}
27+
28+
using heffte::split_world;
29+
30+
const auto get_real_indices(const Decomposition &decomposition) {
31+
auto world = get_global_world(decomposition);
32+
auto [N1, N2, N3] = get_size(world);
33+
return heffte::box3d<int>({0, 0, 0}, {N1 - 1, N2 - 1, N3 - 1});
34+
}
35+
36+
const auto get_complex_indices(const Decomposition &decomposition,
37+
int r2c_direction) {
38+
auto [N1, N2, N3] = get_size(get_global_world(decomposition));
1839
if (r2c_direction == 0) {
19-
N1c = N1c / 2 + 1;
40+
return heffte::box3d<int>({0, 0, 0}, {N1 / 2, N2 - 1, N3 - 1});
2041
} else if (r2c_direction == 1) {
21-
N2c = N2c / 2 + 1;
42+
return heffte::box3d<int>({0, 0, 0}, {N1 - 1, N2 / 2, N3 - 1});
2243
} else if (r2c_direction == 2) {
23-
N3c = N3c / 2 + 1;
44+
return heffte::box3d<int>({0, 0, 0}, {N1 - 1, N2 - 1, N3 / 2});
2445
} else {
2546
throw std::logic_error("Invalid r2c_direction: " +
2647
std::to_string(r2c_direction));
2748
}
28-
box3di real_indices({0, 0, 0}, {N1r - 1, N2r - 1, N3r - 1});
29-
box3di complex_indices({0, 0, 0}, {N1c - 1, N2c - 1, N3c - 1});
30-
Int3 grid = heffte::proc_setup_min_surface(real_indices, num_domains);
31-
std::vector<box3di> real_boxes = heffte::split_world(real_indices, grid);
32-
std::vector<box3di> complex_boxes = heffte::split_world(complex_indices, grid);
49+
}
50+
51+
const FFTLayout create(const Decomposition &decomposition, int r2c_direction) {
52+
auto real_indices = get_real_indices(decomposition);
53+
auto complex_indices = get_complex_indices(decomposition, r2c_direction);
54+
auto grid = get_grid(decomposition);
55+
auto real_boxes = split_world(real_indices, grid);
56+
auto complex_boxes = split_world(complex_indices, grid);
3357
return FFTLayout{decomposition, r2c_direction, real_boxes, complex_boxes};
3458
}
3559

@@ -49,25 +73,42 @@ int get_mpi_size(MPI_Comm comm) {
4973
return size;
5074
}
5175

52-
FFT create(const Decomposition &decomposition, MPI_Comm comm,
53-
heffte::plan_options options) {
54-
int rank = get_mpi_rank(comm);
55-
int mpi_num_ranks = get_mpi_size(comm);
56-
if (mpi_num_ranks <= 0) {
57-
throw std::logic_error("Cannot construct domain decomposition: !(nprocs > 0)");
58-
}
59-
auto r2c_dir = 0;
60-
auto fft_layout = fft::layout::create(decomposition, r2c_dir, mpi_num_ranks);
61-
auto inbox = get_real_box(fft_layout, rank);
62-
auto outbox = get_complex_box(fft_layout, rank);
63-
using fft_r2c = heffte::fft3d_r2c<heffte::backend::fftw>;
76+
using heffte::plan_options;
77+
using layout::FFTLayout;
78+
using fft_r2c = heffte::fft3d_r2c<heffte::backend::fftw>;
79+
80+
FFT create(const FFTLayout &fft_layout, int rank_id, plan_options options) {
81+
auto inbox = get_real_box(fft_layout, rank_id);
82+
auto outbox = get_complex_box(fft_layout, rank_id);
83+
auto r2c_dir = get_r2c_direction(fft_layout);
84+
auto comm = get_comm();
6485
return FFT(fft_r2c(inbox, outbox, r2c_dir, comm, options));
6586
}
6687

88+
FFT create(const Decomposition &decomposition, int rank_id) {
89+
auto options = heffte::default_options<heffte::backend::fftw>();
90+
auto r2c_dir = 0;
91+
auto fft_layout = layout::create(decomposition, r2c_dir);
92+
return create(fft_layout, rank_id, options);
93+
}
94+
6795
FFT create(const Decomposition &decomposition) {
6896
auto comm = get_comm();
69-
auto options = heffte::default_options<heffte::backend::fftw>();
70-
return create(decomposition, comm, options);
97+
auto mpi_comm_size = get_mpi_size(comm);
98+
auto rank_id = get_mpi_rank(comm);
99+
auto decomposition_size = get_num_domains(decomposition);
100+
if (mpi_comm_size != decomposition_size) {
101+
throw std::logic_error(
102+
"Mismatch between MPI communicator size and domain decomposition size: " +
103+
std::to_string(mpi_comm_size) + " != " + std::to_string(decomposition_size) +
104+
". This indicates that the number of MPI ranks does not match the number of "
105+
"domains in the decomposition. To resolve this issue, you can manually "
106+
"specify the rank by calling fft::create(decomposition, rank_id) instead.");
107+
}
108+
// If the MPI communicator size matches the decomposition size, we can safely
109+
// assume that the decomposition is meant to span the whole communicator,
110+
// one domain per rank.
111+
return create(decomposition, rank_id);
71112
}
72113

73114
} // namespace fft

0 commit comments

Comments
 (0)