11#include " diagnostics/zonal_avg.hpp"
2-
3- #include " share/field/field_utils.hpp"
4-
52#include < ekat_math_utils.hpp>
3+ #include < ekat_team_policy_utils.hpp>
64
75namespace scream {
86
9- void ZonalAvgDiag::compute_zonal_sum (const Field &result, const Field &field, const Field &weight,
10- const Field &lat, const ekat::Comm *comm) {
11- auto result_layout = result.get_header ().get_identifier ().get_layout ();
12- const int num_zonal_bins = result_layout.dim (0 );
13- const int ncols = field.get_header ().get_identifier ().get_layout ().dim (0 );
14- const Real lat_delta = sp (180.0 ) / num_zonal_bins;
7+ // Utility to compute the weighted sum of a field over the columns of each zonal bin.
8+ // For each bin b, this is equivalent to result(b,...k) = einsum('i,i...k->...k', weight, field), with i restricted to the columns listed for bin b.
9+ // The implementation is such that:
10+ // - all Field objects must be allocated
11+ // - the first dimension of field and weight, and the second dimension of
12+ // bin_to_cols, are for the columns (COL)
13+ // - the first dimension of result and bin_to_cols is for the zonal bins (CMP,"bin"); entry (b,0) of bin_to_cols is the number of columns in bin b, and entries (b,1), (b,2), ... are their column indices
14+ // - field and result must have the same rank, at most 3
15+ void compute_zonal_sum (const Field &result, const Field &field, const Field &weight,
16+ const Field &bin_to_cols, const ekat::Comm *comm) {
17+ auto result_layout = result.get_header ().get_identifier ().get_layout ();
18+ auto bin_to_cols_layout = bin_to_cols.get_header ().get_identifier ().get_layout ();
19+ const int num_zonal_bins = bin_to_cols_layout.dim (0 );
20+ const int max_ncols_per_bin = bin_to_cols_layout.dim (1 );
1521
1622 auto weight_view = weight.get_view <const Real *>();
17- auto lat_view = lat .get_view <const Real *>();
23+ auto bin_to_cols_view = bin_to_cols .get_view <const Int * *>();
1824 using KT = ekat::KokkosTypes<DefaultDevice>;
1925 using TeamPolicy = Kokkos::TeamPolicy<Field::device_t ::execution_space>;
2026 using TeamMember = typename TeamPolicy::member_type;
@@ -23,67 +29,60 @@ void ZonalAvgDiag::compute_zonal_sum(const Field &result, const Field &field, co
2329 case 1 : {
2430 auto field_view = field.get_view <const Real *>();
2531 auto result_view = result.get_view <Real *>();
26- TeamPolicy team_policy = TPF::get_default_team_policy (num_zonal_bins, ncols );
32+ TeamPolicy team_policy = TPF::get_default_team_policy (num_zonal_bins, max_ncols_per_bin );
2733 Kokkos::parallel_for (
28- " compute_zonal_sum_" + field.name (), team_policy, KOKKOS_LAMBDA (const TeamMember &tm) {
29- const int lat_i = tm.league_rank ();
30- const Real lat_lower = sp (-90.0 ) + lat_i * lat_delta;
31- const Real lat_upper = lat_lower + lat_delta;
34+ " compute_zonal_sum_" + field.name (), team_policy,
35+ KOKKOS_LAMBDA (const TeamMember &tm) {
36+ const int bin_i = tm.league_rank ();
3237 Kokkos::parallel_reduce (
33- Kokkos::TeamVectorRange (tm, ncols),
34- [&](int i, Real &val) {
35- // TODO: check if tenary is ok here (if not, multiply by flag)
36- int flag = (lat_lower <= lat_view (i)) && (lat_view (i) < lat_upper);
37- val += flag ? weight_view (i) * field_view (i) : sp (0.0 );
38+ Kokkos::TeamVectorRange (tm, 1 , 1 +bin_to_cols_view (bin_i,0 )),
39+ [&](int lcol_j, Real &val) {
40+ const int col_i = bin_to_cols_view (bin_i, lcol_j);
41+ val += weight_view (col_i) * field_view (col_i);
3842 },
39- result_view (lat_i ));
43+ result_view (bin_i ));
4044 });
4145 } break ;
4246 case 2 : {
4347 const int d1 = result_layout.dim (1 );
4448 auto field_view = field.get_view <const Real **>();
4549 auto result_view = result.get_view <Real **>();
46- TeamPolicy team_policy = TPF::get_default_team_policy (num_zonal_bins * d1, ncols );
50+ TeamPolicy team_policy = TPF::get_default_team_policy (num_zonal_bins * d1, max_ncols_per_bin );
4751 Kokkos::parallel_for (
48- " compute_zonal_sum_" + field.name (), team_policy, KOKKOS_LAMBDA (const TeamMember &tm) {
52+ " compute_zonal_sum_" + field.name (), team_policy,
53+ KOKKOS_LAMBDA (const TeamMember &tm) {
4954 const int idx = tm.league_rank ();
5055 const int d1_i = idx / num_zonal_bins;
51- const int lat_i = idx % num_zonal_bins;
52- const Real lat_lower = sp (-90.0 ) + lat_i * lat_delta;
53- const Real lat_upper = lat_lower + lat_delta;
56+ const int bin_i = idx % num_zonal_bins;
5457 Kokkos::parallel_reduce (
55- Kokkos::TeamVectorRange (tm, ncols),
56- [&](int i, Real &val) {
57- int flag = (lat_lower <= lat_view (i)) && (lat_view (i) < lat_upper);
58- // TODO: check if tenary is ok here (if not, multiply by flag)
59- val += flag ? weight_view (i) * field_view (i, d1_i) : sp (0.0 );
58+ Kokkos::TeamVectorRange (tm, 1 , 1 +bin_to_cols_view (bin_i,0 )),
59+ [&](int lcol_j, Real &val) {
60+ const int col_i = bin_to_cols_view (bin_i, lcol_j);
61+ val += weight_view (col_i) * field_view (col_i, d1_i);
6062 },
61- result_view (lat_i , d1_i));
63+ result_view (bin_i , d1_i));
6264 });
6365 } break ;
6466 case 3 : {
6567 const int d1 = result_layout.dim (1 );
6668 const int d2 = result_layout.dim (2 );
6769 auto field_view = field.get_view <const Real ***>();
6870 auto result_view = result.get_view <Real ***>();
69- TeamPolicy team_policy = TPF::get_default_team_policy (num_zonal_bins * d1 * d2, ncols );
71+ TeamPolicy team_policy = TPF::get_default_team_policy (num_zonal_bins * d1 * d2, max_ncols_per_bin );
7072 Kokkos::parallel_for (
7173 " compute_zonal_sum_" + field.name (), team_policy, KOKKOS_LAMBDA (const TeamMember &tm) {
7274 const int idx = tm.league_rank ();
7375 const int d1_i = idx / (num_zonal_bins * d2);
7476 const int idx2 = idx % (num_zonal_bins * d2);
7577 const int d2_i = idx2 / num_zonal_bins;
76- const int lat_i = idx2 % num_zonal_bins;
77- const Real lat_lower = sp (-90.0 ) + lat_i * lat_delta;
78- const Real lat_upper = lat_lower + lat_delta;
78+ const int bin_i = idx2 % num_zonal_bins;
7979 Kokkos::parallel_reduce (
80- Kokkos::TeamVectorRange (tm, ncols),
81- [&](int i, Real &val) {
82- int flag = (lat_lower <= lat_view (i)) && (lat_view (i) < lat_upper);
83- // TODO: check if tenary is ok here (if not, multiply by flag)
84- val += flag ? weight_view (i) * field_view (i, d1_i, d2_i) : sp (0.0 );
80+ Kokkos::TeamVectorRange (tm, 1 , 1 +bin_to_cols_view (bin_i,0 )),
81+ [&](int lcol_j, Real &val) {
82+ const int col_i = bin_to_cols_view (bin_i, lcol_j);
83+ val += weight_view (col_i) * field_view (col_i, d1_i, d2_i);
8584 },
86- result_view (lat_i , d1_i, d2_i));
85+ result_view (bin_i , d1_i, d2_i));
8786 });
8887 } break ;
8988 default :
@@ -152,6 +151,75 @@ void ZonalAvgDiag::initialize_impl(const RunType /*run_type*/) {
152151 m_diagnostic_output = Field (diagnostic_id);
153152 m_diagnostic_output.allocate_view ();
154153
154+ // allocate column counter
155+ FieldLayout ncols_per_bin_layout ({CMP}, {m_num_zonal_bins}, {" bin" });
156+ FieldIdentifier ncols_per_bin_id (" number of columns per bin" ,
157+ ncols_per_bin_layout, FieldIdentifier::Units::nondimensional (),
158+ field_id.get_grid_name (), DataType::IntType);
159+ Field ncols_per_bin (ncols_per_bin_id);
160+ ncols_per_bin.allocate_view ();
161+ ncols_per_bin.deep_copy (0 );
162+
163+ // count how many columns are in each zonal bin
164+ using KT = ekat::KokkosTypes<DefaultDevice>;
165+ using TeamPolicy = Kokkos::TeamPolicy<Field::device_t ::execution_space>;
166+ using TeamMember = typename TeamPolicy::member_type;
167+ using TPF = ekat::TeamPolicyFactory<typename KT::ExeSpace>;
168+ const int ncols = field_layout.dim (COL);
169+ const Real lat_delta = sp (180.0 ) / m_num_zonal_bins;
170+ auto lat_view = m_lat.get_view <const Real *>();
171+ auto ncols_per_bin_view = ncols_per_bin.get_view <Int *>();
172+ const int num_zonal_bins = m_num_zonal_bins; // for use inside lambdas
173+ TeamPolicy team_policy = TPF::get_default_team_policy (m_num_zonal_bins, ncols);
174+ Kokkos::parallel_for (" count_columns_per_zonal_bin_" + field.name (),
175+ team_policy, KOKKOS_LAMBDA (const TeamMember &tm) {
176+ const int bin_i = tm.league_rank ();
177+ const Real lat_lower = sp (-90.0 ) + bin_i * lat_delta;
178+ const Real lat_upper = (bin_i < num_zonal_bins-1 )
179+ ? lat_lower + lat_delta : sp (90.0 + 0.5 *lat_delta);
180+ Kokkos::parallel_reduce (Kokkos::TeamVectorRange (tm, ncols),
181+ [&](int col_i, Int &val) {
182+ if ((lat_lower <= lat_view (col_i)) && (lat_view (col_i) < lat_upper))
183+ val++;
184+ },
185+ ncols_per_bin_view (bin_i));
186+ });
187+
188+ // determine maximum number of columns per bin & allocate bin to column map
189+ using RangePolicy = Kokkos::RangePolicy<Field::device_t ::execution_space>;
190+ Int max_ncols_per_bin = 0 ;
191+ Kokkos::parallel_reduce (RangePolicy (0 , m_num_zonal_bins),
192+ KOKKOS_LAMBDA (int bin_i, Int &val) {
193+ val = ncols_per_bin_view (bin_i) > val ? ncols_per_bin_view (bin_i) : val;
194+ },
195+ Kokkos::Max<Int>(max_ncols_per_bin));
196+ FieldLayout bin_to_cols_layout = ncols_per_bin_layout.append_dim ({COL}, {1 +max_ncols_per_bin});
197+ FieldIdentifier bin_to_cols_id (" columns in each zonal bin" ,
198+ bin_to_cols_layout, FieldIdentifier::Units::nondimensional (),
199+ field_id.get_grid_name (), DataType::IntType);
200+ m_bin_to_cols = Field (bin_to_cols_id);
201+ m_bin_to_cols.allocate_view ();
202+
203+ // compute bin to column map, where the (i,j)-th entry is such that
204+ // - for j=0, the entry is the number of columns in the i-th zonal bin
205+ // - for j>0, the entry is a column index in the i-th zonal bin
206+ auto bin_to_cols_view = m_bin_to_cols.get_view <Int **>();
207+ Kokkos::parallel_for (" assign_columns_to_zonal_bins_" + field.name (),
208+ RangePolicy (0 , m_num_zonal_bins), KOKKOS_LAMBDA (int bin_i) {
209+ const Real lat_lower = sp (-90.0 ) + bin_i * lat_delta;
210+ const Real lat_upper = (bin_i < num_zonal_bins-1 )
211+ ? lat_lower + lat_delta : sp (90.0 + 0.5 *lat_delta);
212+ bin_to_cols_view (bin_i, 0 ) = 0 ;
213+ for (int col_i=0 ; col_i < ncols; col_i++)
214+ {
215+ if ((lat_lower <= lat_view (col_i)) && (lat_view (col_i) < lat_upper))
216+ {
217+ bin_to_cols_view (bin_i, 0 ) += 1 ;
218+ bin_to_cols_view (bin_i, bin_to_cols_view (bin_i,0 )) = col_i;
219+ }
220+ }
221+ });
222+
155223 // allocate zonal area
156224 const FieldIdentifier &area_id = m_scaled_area.get_header ().get_identifier ();
157225 FieldLayout zonal_area_layout ({CMP}, {m_num_zonal_bins}, {" bin" });
@@ -166,27 +234,26 @@ void ZonalAvgDiag::initialize_impl(const RunType /*run_type*/) {
166234 Field ones (ones_id);
167235 ones.allocate_view ();
168236 ones.deep_copy (1.0 );
169- compute_zonal_sum (zonal_area, m_scaled_area, ones, m_lat , &m_comm);
237+ compute_zonal_sum (zonal_area, m_scaled_area, ones, m_bin_to_cols , &m_comm);
170238
171239 // scale area by 1 / zonal area
172- using RangePolicy = Kokkos::RangePolicy<Field::device_t ::execution_space>;
173- const Real lat_delta = sp (180.0 ) / m_num_zonal_bins;
174- const int ncols = field_layout.dim (0 );
175- const int nbins = m_num_zonal_bins;
176- auto lat_view = m_lat.get_view <const Real *>();
177240 auto zonal_area_view = zonal_area.get_view <const Real *>();
178241 auto scaled_area_view = m_scaled_area.get_view <Real *>();
179- Kokkos::parallel_for (
180- " scale_area_by_zonal_area_" + field.name (), RangePolicy (0 , ncols),
181- KOKKOS_LAMBDA (const int &i) {
182- const int lat_i = ekat::impl::min (static_cast <int >((lat_view (i) + sp (90.0 )) / lat_delta),nbins-1 );
183- scaled_area_view (i) /= zonal_area_view (lat_i);
242+ Kokkos::parallel_for (" scale_area_by_zonal_area_" + field.name (),
243+ team_policy, KOKKOS_LAMBDA (const TeamMember &tm) {
244+ const int bin_i = tm.league_rank ();
245+ Kokkos::parallel_for (
246+ Kokkos::TeamVectorRange (tm, 1 , 1 +bin_to_cols_view (bin_i,0 )),
247+ [&](int lcol_j) {
248+ const int col_i = bin_to_cols_view (bin_i, lcol_j);
249+ scaled_area_view (col_i) /= zonal_area_view (bin_i);
250+ });
184251 });
185252}
186253
// Computes the diagnostic: passes the first input field to compute_zonal_sum,
// weighted by m_scaled_area, with the result written to m_diagnostic_output.
// In this diff the fourth argument changes from the latitude field (m_lat) to
// the precomputed bin-to-columns map (m_bin_to_cols) built in initialize_impl.
// initialize_impl divides m_scaled_area by the total area of each column's bin,
// so this weighted sum is presumably the zonal average (TODO confirm against
// the parts of compute_zonal_sum not visible here, e.g. the use of m_comm).
187254void ZonalAvgDiag::compute_diagnostic_impl () {
188255 const auto &field = get_fields_in ().front ();
189- compute_zonal_sum (m_diagnostic_output, field, m_scaled_area, m_lat , &m_comm);
256+ compute_zonal_sum (m_diagnostic_output, field, m_scaled_area, m_bin_to_cols , &m_comm);
190257}
191258
192259} // namespace scream
0 commit comments