Skip to content

Commit 042516d

Browse files
committed
updated zonal_avg test to support multiple MPI ranks
1 parent 74ac297 commit 042516d

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

components/eamxx/src/diagnostics/tests/zonal_avg_test.cpp

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,14 @@
77

88
namespace scream {
99

10-
std::shared_ptr<GridsManager> create_gm(const ekat::Comm &comm, const int ncols, const int nlevs) {
11-
const int num_global_cols = ncols * comm.size();
12-
10+
std::shared_ptr<GridsManager> create_gm(const ekat::Comm &comm, const int ngcols, const int nlevs) {
1311
using vos_t = std::vector<std::string>;
1412
ekat::ParameterList gm_params;
1513
gm_params.set("grids_names", vos_t{"Point Grid"});
1614
auto &pl = gm_params.sublist("Point Grid");
1715
pl.set<std::string>("type", "point_grid");
1816
pl.set("aliases", vos_t{"Physics"});
19-
pl.set<int>("number_of_global_columns", num_global_cols);
17+
pl.set<int>("number_of_global_columns", ngcols);
2018
pl.set<int>("number_of_vertical_levels", nlevs);
2119

2220
auto gm = create_mesh_free_grids_manager(comm, gm_params);
@@ -41,11 +39,12 @@ TEST_CASE("zonal_avg") {
4139
// Create a grids manager - single column for these tests
4240
constexpr int nlevs = 3;
4341
constexpr int dim3 = 4;
44-
const int ngcols = 6 * comm.size();
45-
const int nlats = 4; // needs to be <= ngcols
42+
const int ncols = 6;
43+
const int nlats = 4; // needs to be <= ncols
4644

47-
auto gm = create_gm(comm, ngcols, nlevs);
48-
auto grid = gm->get_grid("Physics");
45+
const int ngcols = ncols * comm.size();
46+
auto gm = create_gm(comm, ngcols, nlevs);
47+
auto grid = gm->get_grid("Physics");
4948

5049
Field area = grid->get_geometry_data("area");
5150
auto area_view_h = area.get_view<const Real *, Host>();
@@ -56,18 +55,19 @@ TEST_CASE("zonal_avg") {
5655
auto lat_view_h = lat.get_view<Real *, Host>();
5756
const Real lat_delta = sp(180.0) / nlats;
5857
std::vector<Real> zonal_areas(nlats, 0.0);
59-
for (int i = 0; i < ngcols; i++) {
58+
for (int i = 0; i < ncols; i++) {
6059
lat_view_h(i) = sp(-90.0) + (i % nlats + sp(0.5)) * lat_delta;
6160
zonal_areas[i % nlats] += area_view_h[i];
6261
}
62+
comm.all_reduce(zonal_areas.data(), zonal_areas.size(), MPI_SUM);
6363
lat_view_h(0) = sp(-90.0); // move column to be directly at southern pole
6464
lat_view_h(nlats-1) = sp(90.0); // move column to be directly at northern pole
6565
lat.sync_to_dev();
6666

6767
// Input (randomized) qc
68-
FieldLayout scalar1d_layout{{COL}, {ngcols}};
69-
FieldLayout scalar2d_layout{{COL, LEV}, {ngcols, nlevs}};
70-
FieldLayout scalar3d_layout{{COL, CMP, LEV}, {ngcols, dim3, nlevs}};
68+
FieldLayout scalar1d_layout{{COL}, {ncols}};
69+
FieldLayout scalar2d_layout{{COL, LEV}, {ncols, nlevs}};
70+
FieldLayout scalar3d_layout{{COL, CMP, LEV}, {ncols, dim3, nlevs}};
7171

7272
FieldIdentifier qc1_id("qc", scalar1d_layout, kg / kg, grid->name());
7373
FieldIdentifier qc2_fid("qc", scalar2d_layout, kg / kg, grid->name());
@@ -137,10 +137,12 @@ TEST_CASE("zonal_avg") {
137137
// calculate the zonal average
138138
auto qc1_view_h = qc1.get_view<const Real *, Host>();
139139
auto diag0_view_h = diag0_field.get_view<Real *, Host>();
140-
for (int i = 0; i < ngcols; i++) {
140+
for (int i = 0; i < ncols; i++) {
141141
const int nlat = i % nlats;
142142
diag0_view_h(nlat) += area_view_h(i) / zonal_areas[nlat] * qc1_view_h(i);
143143
}
144+
comm.all_reduce(diag0_field.template get_internal_view_data<Real, Host>(),
145+
diag0_layout.size(), MPI_SUM);
144146
diag0_field.sync_to_dev();
145147

146148
// Compare
@@ -180,14 +182,16 @@ TEST_CASE("zonal_avg") {
180182
diag3m_field.allocate_view();
181183
auto qc3_view_h = qc3.get_view<Real ***, Host>();
182184
auto diag3m_view_h = diag3m_field.get_view<Real ***, Host>();
183-
for (int i = 0; i < ngcols; i++) {
185+
for (int i = 0; i < ncols; i++) {
184186
const int nlat = i % nlats;
185187
for (int j = 0; j < dim3; j++) {
186188
for (int k = 0; k < nlevs; k++) {
187189
diag3m_view_h(nlat, j, k) += area_view_h(i) / zonal_areas[nlat] * qc3_view_h(i, j, k);
188190
}
189191
}
190192
}
193+
comm.all_reduce(diag3m_field.template get_internal_view_data<Real, Host>(),
194+
diag3m_layout.size(), MPI_SUM);
191195
diag3m_field.sync_to_dev();
192196
diag3->set_required_field(qc3);
193197
diag3->initialize(t0, RunType::Initial);

0 commit comments

Comments
 (0)