77
88namespace scream {
99
10- std::shared_ptr<GridsManager> create_gm (const ekat::Comm &comm, const int ncols, const int nlevs) {
11- const int num_global_cols = ncols * comm.size ();
12-
10+ std::shared_ptr<GridsManager> create_gm (const ekat::Comm &comm, const int ngcols, const int nlevs) {
1311 using vos_t = std::vector<std::string>;
1412 ekat::ParameterList gm_params;
1513 gm_params.set (" grids_names" , vos_t {" Point Grid" });
1614 auto &pl = gm_params.sublist (" Point Grid" );
1715 pl.set <std::string>(" type" , " point_grid" );
1816 pl.set (" aliases" , vos_t {" Physics" });
19- pl.set <int >(" number_of_global_columns" , num_global_cols );
17+ pl.set <int >(" number_of_global_columns" , ngcols );
2018 pl.set <int >(" number_of_vertical_levels" , nlevs);
2119
2220 auto gm = create_mesh_free_grids_manager (comm, gm_params);
@@ -41,11 +39,12 @@ TEST_CASE("zonal_avg") {
4139 // Create a grids manager - single column for these tests
4240 constexpr int nlevs = 3 ;
4341 constexpr int dim3 = 4 ;
44- const int ngcols = 6 * comm. size () ;
45- const int nlats = 4 ; // needs to be <= ngcols
42+ const int ncols = 6 ;
43+ const int nlats = 4 ; // needs to be <= ncols
4644
47- auto gm = create_gm (comm, ngcols, nlevs);
48- auto grid = gm->get_grid (" Physics" );
45+ const int ngcols = ncols * comm.size ();
46+ auto gm = create_gm (comm, ngcols, nlevs);
47+ auto grid = gm->get_grid (" Physics" );
4948
5049 Field area = grid->get_geometry_data (" area" );
5150 auto area_view_h = area.get_view <const Real *, Host>();
@@ -56,18 +55,19 @@ TEST_CASE("zonal_avg") {
5655 auto lat_view_h = lat.get_view <Real *, Host>();
5756 const Real lat_delta = sp (180.0 ) / nlats;
5857 std::vector<Real> zonal_areas (nlats, 0.0 );
59- for (int i = 0 ; i < ngcols ; i++) {
58+ for (int i = 0 ; i < ncols ; i++) {
6059 lat_view_h (i) = sp (-90.0 ) + (i % nlats + sp (0.5 )) * lat_delta;
6160 zonal_areas[i % nlats] += area_view_h[i];
6261 }
62+ comm.all_reduce (zonal_areas.data (), zonal_areas.size (), MPI_SUM);
6363 lat_view_h (0 ) = sp (-90.0 ); // move column to be directly at southern pole
6464 lat_view_h (nlats-1 ) = sp (90.0 ); // move column to be directly at northern pole
6565 lat.sync_to_dev ();
6666
6767 // Input (randomized) qc
68- FieldLayout scalar1d_layout{{COL}, {ngcols }};
69- FieldLayout scalar2d_layout{{COL, LEV}, {ngcols , nlevs}};
70- FieldLayout scalar3d_layout{{COL, CMP, LEV}, {ngcols , dim3, nlevs}};
68+ FieldLayout scalar1d_layout{{COL}, {ncols }};
69+ FieldLayout scalar2d_layout{{COL, LEV}, {ncols , nlevs}};
70+ FieldLayout scalar3d_layout{{COL, CMP, LEV}, {ncols , dim3, nlevs}};
7171
7272 FieldIdentifier qc1_id (" qc" , scalar1d_layout, kg / kg, grid->name ());
7373 FieldIdentifier qc2_fid (" qc" , scalar2d_layout, kg / kg, grid->name ());
@@ -137,10 +137,12 @@ TEST_CASE("zonal_avg") {
137137 // calculate the zonal average
138138 auto qc1_view_h = qc1.get_view <const Real *, Host>();
139139 auto diag0_view_h = diag0_field.get_view <Real *, Host>();
140- for (int i = 0 ; i < ngcols ; i++) {
140+ for (int i = 0 ; i < ncols ; i++) {
141141 const int nlat = i % nlats;
142142 diag0_view_h (nlat) += area_view_h (i) / zonal_areas[nlat] * qc1_view_h (i);
143143 }
144+ comm.all_reduce (diag0_field.template get_internal_view_data <Real, Host>(),
145+ diag0_layout.size (), MPI_SUM);
144146 diag0_field.sync_to_dev ();
145147
146148 // Compare
@@ -180,14 +182,16 @@ TEST_CASE("zonal_avg") {
180182 diag3m_field.allocate_view ();
181183 auto qc3_view_h = qc3.get_view <Real ***, Host>();
182184 auto diag3m_view_h = diag3m_field.get_view <Real ***, Host>();
183- for (int i = 0 ; i < ngcols ; i++) {
185+ for (int i = 0 ; i < ncols ; i++) {
184186 const int nlat = i % nlats;
185187 for (int j = 0 ; j < dim3; j++) {
186188 for (int k = 0 ; k < nlevs; k++) {
187189 diag3m_view_h (nlat, j, k) += area_view_h (i) / zonal_areas[nlat] * qc3_view_h (i, j, k);
188190 }
189191 }
190192 }
193+ comm.all_reduce (diag3m_field.template get_internal_view_data <Real, Host>(),
194+ diag3m_layout.size (), MPI_SUM);
191195 diag3m_field.sync_to_dev ();
192196 diag3->set_required_field (qc3);
193197 diag3->initialize (t0, RunType::Initial);
0 commit comments