Skip to content

Commit 2eb52d9

Browse files
committed
Merge branch 'jgfouca/port_gw_prof' into master (PR E3SM-Project#7629)
Port GW gw_prof to Kokkos/C++ [BFB]
2 parents 70c6ab0 + 11a7602 commit 2eb52d9

File tree

5 files changed

+165
-29
lines changed

5 files changed

+165
-29
lines changed

components/eamxx/src/physics/gw/gw_functions.hpp

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,19 @@ struct Functions
143143
Real tndmax; // = huge(1._r8)
144144
};
145145

146+
KOKKOS_INLINE_FUNCTION
147+
static void midpoint_interp(
148+
const MemberType& team,
149+
const uview_1d<const Real>& in,
150+
const uview_1d<Real>& interp)
151+
{
152+
EKAT_KERNEL_REQUIRE(in.size() == interp.size() + 1);
153+
Kokkos::parallel_for(
154+
Kokkos::TeamVectorRange(team, 0, in.extent(0)-1), [&] (const int k) {
155+
interp(k) = (in(k)+in(k+1)) / 2;
156+
});
157+
}
158+
146159
//
147160
// --------- Init/Finalize Functions ---------
148161
//
@@ -201,20 +214,27 @@ struct Functions
201214
const uview_1d<Real>& utgw,
202215
const uview_1d<Real>& vtgw);
203216

217+
/*
218+
* Compute profiles of background state quantities for the multiple
219+
* gravity wave drag parameterization.
220+
*
221+
* The parameterization is assumed to operate only where water vapor
222+
* concentrations are negligible in determining the density.
223+
*/
204224
KOKKOS_FUNCTION
205225
static void gw_prof(
206226
// Inputs
227+
const MemberType& team,
207228
const Int& pver,
208-
const Int& ncol,
209-
const Spack& cpair,
210-
const uview_1d<const Spack>& t,
211-
const uview_1d<const Spack>& pmid,
212-
const uview_1d<const Spack>& pint,
229+
const Real& cpair,
230+
const uview_1d<const Real>& t,
231+
const uview_1d<const Real>& pmid,
232+
const uview_1d<const Real>& pint,
213233
// Outputs
214-
const uview_1d<Spack>& rhoi,
215-
const uview_1d<Spack>& ti,
216-
const uview_1d<Spack>& nm,
217-
const uview_1d<Spack>& ni);
234+
const uview_1d<Real>& rhoi,
235+
const uview_1d<Real>& ti,
236+
const uview_1d<Real>& nm,
237+
const uview_1d<Real>& ni);
218238

219239
KOKKOS_FUNCTION
220240
static void momentum_energy_conservation(

components/eamxx/src/physics/gw/impl/gw_gw_prof_impl.hpp

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
#include "gw_functions.hpp" // for ETI only but harmless for GPU
55

6+
#include <ekat_subview_utils.hpp>
7+
68
namespace scream {
79
namespace gw {
810

@@ -14,21 +16,59 @@ namespace gw {
1416
template<typename S, typename D>
1517
KOKKOS_FUNCTION
1618
void Functions<S,D>::gw_prof(
17-
// Inputs
18-
const Int& pver,
19-
const Int& ncol,
20-
const Spack& cpair,
21-
const uview_1d<const Spack>& t,
22-
const uview_1d<const Spack>& pmid,
23-
const uview_1d<const Spack>& pint,
24-
// Outputs
25-
const uview_1d<Spack>& rhoi,
26-
const uview_1d<Spack>& ti,
27-
const uview_1d<Spack>& nm,
28-
const uview_1d<Spack>& ni)
19+
// Inputs
20+
const MemberType& team,
21+
const Int& pver,
22+
const Real& cpair,
23+
const uview_1d<const Real>& t,
24+
const uview_1d<const Real>& pmid,
25+
const uview_1d<const Real>& pint,
26+
// Outputs
27+
const uview_1d<Real>& rhoi,
28+
const uview_1d<Real>& ti,
29+
const uview_1d<Real>& nm,
30+
const uview_1d<Real>& ni)
2931
{
30-
// TODO
31-
// Note, argument types may need tweaking. Generator is not always able to tell what needs to be packed
32+
// Minimum value of Brunt-Vaisalla frequency squared.
33+
static constexpr Real n2min = 1.e-8;
34+
35+
//-----------------------------------------------------------------------
36+
// Determine the interface densities and Brunt-Vaisala frequencies.
37+
//-----------------------------------------------------------------------
38+
39+
// The top interface values are calculated assuming an isothermal
40+
// atmosphere above the top level.
41+
Kokkos::single(Kokkos::PerTeam(team), [&] {
42+
ti(0) = t(0);
43+
rhoi(0) = pint(0) / (C::Rair*ti(0));
44+
ni(0) = sqrt(C::gravit*C::gravit / (cpair*ti(0)));
45+
});
46+
47+
// Interior points use centered differences.
48+
midpoint_interp(team, t, ekat::subview(ti, Kokkos::pair<int, int>{1, pver}));
49+
team.team_barrier();
50+
Kokkos::parallel_for(
51+
Kokkos::TeamVectorRange(team, 1, pver), [&] (const int k) {
52+
rhoi(k) = pint(k) / (C::Rair*ti(k));
53+
const Real dtdp = (t(k)-t(k-1)) / (pmid(k)-pmid(k-1));
54+
const Real n2 = C::gravit*C::gravit/ti(k) * (1/cpair - rhoi(k)*dtdp);
55+
ni(k) = std::sqrt(ekat::impl::max(n2min, n2));
56+
});
57+
58+
// Bottom interface uses bottom level temperature, density; next interface
59+
// B-V frequency.
60+
team.team_barrier();
61+
Kokkos::single(Kokkos::PerTeam(team), [&] {
62+
ti(pver) = t(pver-1);
63+
rhoi(pver) = pint(pver) / (C::Rair*ti(pver));
64+
ni(pver) = ni(pver-1);
65+
});
66+
67+
//------------------------------------------------------------------------
68+
// Determine the midpoint Brunt-Vaisala frequencies.
69+
//------------------------------------------------------------------------
70+
team.team_barrier();
71+
midpoint_interp(team, ni, nm);
3272
}
3373

3474
} // namespace gw

components/eamxx/src/physics/gw/tests/gw_gw_prof_tests.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,24 +57,29 @@ struct UnitWrap::UnitTest<D>::TestGwProf : public UnitWrap::UnitTest<D>::Base {
5757

5858
// Get data from test
5959
for (auto& d : test_data) {
60-
gw_prof(d);
60+
if (this->m_baseline_action == GENERATE) {
61+
gw_prof_f(d);
62+
}
63+
else {
64+
gw_prof(d);
65+
}
6166
}
6267

6368
// Verify BFB results, all data should be in C layout
6469
if (SCREAM_BFB_TESTING && this->m_baseline_action == COMPARE) {
6570
for (Int i = 0; i < num_runs; ++i) {
6671
GwProfData& d_baseline = baseline_data[i];
6772
GwProfData& d_test = test_data[i];
73+
REQUIRE(d_baseline.total(d_baseline.rhoi) == d_test.total(d_test.rhoi));
74+
REQUIRE(d_baseline.total(d_baseline.rhoi) == d_test.total(d_test.ti));
75+
REQUIRE(d_baseline.total(d_baseline.rhoi) == d_test.total(d_test.ni));
6876
for (Int k = 0; k < d_baseline.total(d_baseline.rhoi); ++k) {
69-
REQUIRE(d_baseline.total(d_baseline.rhoi) == d_test.total(d_test.rhoi));
7077
REQUIRE(d_baseline.rhoi[k] == d_test.rhoi[k]);
71-
REQUIRE(d_baseline.total(d_baseline.rhoi) == d_test.total(d_test.ti));
7278
REQUIRE(d_baseline.ti[k] == d_test.ti[k]);
73-
REQUIRE(d_baseline.total(d_baseline.rhoi) == d_test.total(d_test.ni));
7479
REQUIRE(d_baseline.ni[k] == d_test.ni[k]);
7580
}
81+
REQUIRE(d_baseline.total(d_baseline.nm) == d_test.total(d_test.nm));
7682
for (Int k = 0; k < d_baseline.total(d_baseline.nm); ++k) {
77-
REQUIRE(d_baseline.total(d_baseline.nm) == d_test.total(d_test.nm));
7883
REQUIRE(d_baseline.nm[k] == d_test.nm[k]);
7984
}
8085

components/eamxx/src/physics/gw/tests/infra/gw_test_data.cpp

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,14 +237,84 @@ void gwd_compute_tendencies_from_stress_divergence(GwdComputeTendenciesFromStres
237237
gw_finalize_cxx(d.init);
238238
}
239239

240-
void gw_prof(GwProfData& d)
240+
void gw_prof_f(GwProfData& d)
241241
{
242242
gw_init(d.init);
243243
d.transition<ekat::TransposeDirection::c2f>();
244244
gw_prof_c(d.ncol, d.cpair, d.t, d.pmid, d.pint, d.rhoi, d.ti, d.nm, d.ni);
245245
d.transition<ekat::TransposeDirection::f2c>();
246246
}
247247

248+
void gw_prof(GwProfData& d)
249+
{
250+
gw_init_cxx(d.init);
251+
252+
// create device views and copy
253+
std::vector<view2dr_d> two_d_reals_in(7);
254+
255+
ekat::host_to_device({d.t, d.pmid, d.nm, d.pint, d.rhoi, d.ti, d.ni},
256+
std::vector<int>(7, d.ncol),
257+
std::vector<int>{ // dim2 sizes
258+
d.init.pver,
259+
d.init.pver,
260+
d.init.pver,
261+
d.init.pver + 1,
262+
d.init.pver + 1,
263+
d.init.pver + 1,
264+
d.init.pver + 1},
265+
two_d_reals_in);
266+
267+
const auto t = two_d_reals_in[0];
268+
const auto pmid = two_d_reals_in[1];
269+
const auto nm = two_d_reals_in[2];
270+
const auto pint = two_d_reals_in[3];
271+
const auto rhoi = two_d_reals_in[4];
272+
const auto ti = two_d_reals_in[5];
273+
const auto ni = two_d_reals_in[6];
274+
275+
auto policy = ekat::TeamPolicyFactory<ExeSpace>::get_default_team_policy(d.ncol, d.init.pver);
276+
277+
// unpack init because we do not want the lambda to capture it
278+
const int pver = d.init.pver;
279+
const Real cpair = d.cpair;
280+
281+
Kokkos::parallel_for(policy, KOKKOS_LAMBDA(const MemberType& team) {
282+
const int col = team.league_rank();
283+
284+
// Get single-column subviews of all inputs, shouldn't need any i-indexing
285+
// after this.
286+
const auto t_c = ekat::subview(t, col);
287+
const auto pmid_c = ekat::subview(pmid, col);
288+
const auto nm_c = ekat::subview(nm, col);
289+
const auto pint_c = ekat::subview(pint, col);
290+
const auto rhoi_c = ekat::subview(rhoi, col);
291+
const auto ti_c = ekat::subview(ti, col);
292+
const auto ni_c = ekat::subview(ni, col);
293+
294+
GWF::gw_prof(
295+
team,
296+
pver,
297+
cpair,
298+
t_c,
299+
pmid_c,
300+
pint_c,
301+
rhoi_c,
302+
ti_c,
303+
nm_c,
304+
ni_c
305+
);
306+
});
307+
308+
// Get outputs back
309+
std::vector<view2dr_d> two_d_reals_out = {rhoi, ti, nm, ni};
310+
ekat::device_to_host({d.rhoi, d.ti, d.nm, d.ni},
311+
std::vector<int>(4, d.ncol),
312+
std::vector<int>{d.init.pver+1, d.init.pver+1, d.init.pver, d.init.pver+1},
313+
two_d_reals_out);
314+
315+
gw_finalize_cxx(d.init);
316+
}
317+
248318
void momentum_energy_conservation_f(MomentumEnergyConservationData& d)
249319
{
250320
gw_init(d.init);

components/eamxx/src/physics/gw/tests/infra/gw_test_data.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,7 @@ struct GwOroSrcData : public PhysicsTestData {
854854
// Glue functions to call fortran from from C++ with the Data struct
855855
void gwd_compute_tendencies_from_stress_divergence(GwdComputeTendenciesFromStressDivergenceData& d);
856856
void gwd_compute_tendencies_from_stress_divergence_f(GwdComputeTendenciesFromStressDivergenceData& d);
857+
void gw_prof_f(GwProfData& d);
857858
void gw_prof(GwProfData& d);
858859
void momentum_energy_conservation_f(MomentumEnergyConservationData& d);
859860
void momentum_energy_conservation(MomentumEnergyConservationData& d);

0 commit comments

Comments
 (0)