Skip to content

Commit 8f12c3f

Browse files
committed
Merge branch 'jgfouca/bfb_unit_no_f90_shoc' into master (PR #6779)
Following the pattern of a recent P3 PR, this PR changes the bfb unit tests to be based on baseline files instead of fortran calls. Unlike in that PR, we are actually able to completely remove all dependence on fortran code here because SHOC doesn't need any f90 init routines. Change list: 1) Main change: For all the BFB SHOC unit tests, use baseline files instead of calling fortran! 2) Remove all of the CXX->f90 bridge code 3) Cleanup of shoc_init situation. SHOC has no global data, so you only need to call it if you need to compute npbl 4) Reorg by moving some less-interesting testing stuff into shoc/tests/infra. Now, the files presented in the public interface (just shoc directory) is a lot cleaner. 5) All shoc tests (bfb and run_and_cmp) accept the same set of arguments (-c for compare, -g for generate, -b for baseline dir, and -n for no baselines). 6) Some of the tests that just confirm that DIFFs are caught do need to do thread count spreads. 7) Remove the ability of shoc.F90 to call CXX impls.. 8) gen-boiler can now generate _host functions for function that have array data. 9) No longer need to worry about transposing array data, we never cross language boundary anymore. There were a lot of f90 functions that had a property test but were never actually used in the CXX. I removed all of these: * shoc_impli_sfc_fluxes_tests.cpp * shoc_impli_srf_stress_tests.cpp * shoc_impli_srf_tke_tests.cpp * shoc_energy_total_fixer_tests.cpp * shoc_energy_dse_fixer_tests.cpp * shoc_energy_threshold_fixer_tests.cpp * shoc_fterm_input_third_moms_tests.cpp * shoc_fterm_diag_third_moms_tests.cpp * shoc_omega_diag_third_moms_tests.cpp * shoc_xy_diag_third_moms_tests.cpp * shoc_aa_diag_third_moms_tests.cpp * shoc_w3_diag_third_moms_tests.cpp [BFB] Moved over from: E3SM-Project/scream#3121 P3 DIFFs are expected because I caught a few places that were transposing views that didn't need to do that anymore. SHOC fails are expected because the baselines haven't been generated for the containers yet. Follow on work: Port P3 init code from fortran; this is the last of the fortran being used by either p3 or shoc Remove all references to "fortran" or "Fortran". These names are all obsolete.
2 parents 406c2f3 + c4829c4 commit 8f12c3f

File tree

90 files changed

+2898
-7318
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

90 files changed

+2898
-7318
lines changed

components/eam/src/physics/cam/shoc.F90

Lines changed: 0 additions & 529 deletions
Large diffs are not rendered by default.

components/eamxx/scripts/gen_boiler.py

Lines changed: 215 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -138,21 +138,21 @@
138138
)),
139139

140140
("cxx_f2c_bind_decl" , (
141-
lambda phys, sub, gb: f"{phys}_functions_f90.hpp",
141+
lambda phys, sub, gb: f"tests/infra/{phys}_test_data.hpp",
142142
lambda phys, sub, gb: expect_exists(phys, sub, gb, "cxx_f2c_bind_decl"),
143-
lambda phys, sub, gb: get_cxx_close_block_regex(comment="end _f function decls"), # reqs special comment
144-
lambda phys, sub, gb: get_cxx_function_begin_regex(sub + "_f"), # cxx_f decl
143+
lambda phys, sub, gb: get_plain_comment_regex(comment="end _host function decls"), # reqs special comment
144+
lambda phys, sub, gb: get_cxx_function_begin_regex(sub + "_host"), # cxx_host decl
145145
lambda phys, sub, gb: re.compile(r".*;\s*$"), # ;
146-
lambda *x : "The f90 to cxx function declaration(<name>_f)"
146+
lambda *x : "The f90 to cxx function declaration(<name>_host)"
147147
)),
148148

149149
("cxx_f2c_bind_impl" , (
150-
lambda phys, sub, gb: f"{phys}_functions_f90.cpp",
150+
lambda phys, sub, gb: f"tests/infra/{phys}_test_data.cpp",
151151
lambda phys, sub, gb: expect_exists(phys, sub, gb, "cxx_f2c_bind_impl"),
152152
lambda phys, sub, gb: get_namespace_close_regex(phys), # insert at end of namespace
153-
lambda phys, sub, gb: get_cxx_function_begin_regex(sub + "_f"), # cxx_f
153+
lambda phys, sub, gb: get_cxx_function_begin_regex(sub + "_host"), # cxx_f
154154
lambda phys, sub, gb: get_cxx_close_block_regex(at_line_start=True), # terminating }
155-
lambda *x : "The f90 to cxx function implementation(<name>_f)"
155+
lambda *x : "The f90 to cxx function implementation(<name>_host)"
156156
)),
157157

158158
("cxx_func_decl", (
@@ -455,6 +455,12 @@ def get_cxx_struct_begin_regex(struct):
455455
struct_regex_str = fr"^\s*struct\s+{struct}([\W]|$)"
456456
return re.compile(struct_regex_str)
457457

458+
###############################################################################
459+
def get_plain_comment_regex(comment):
460+
###############################################################################
461+
comment_regex_str = fr"^\s*//\s*{comment}"
462+
return re.compile(comment_regex_str)
463+
458464
###############################################################################
459465
def get_data_struct_name(sub):
460466
###############################################################################
@@ -1169,6 +1175,21 @@ def split_by_type(arg_data):
11691175

11701176
return reals, ints, logicals
11711177

1178+
###############################################################################
1179+
def split_by_scalar_vs_view(arg_data):
1180+
###############################################################################
1181+
"""
1182+
Take arg data and split into two lists of names based on scalar/not-scalar: [scalars] [non-scalars]
1183+
"""
1184+
scalars, non_scalars = [], []
1185+
for name, _, _, dims in arg_data:
1186+
if dims is not None:
1187+
non_scalars.append(name)
1188+
else:
1189+
scalars.append(name)
1190+
1191+
return scalars, non_scalars
1192+
11721193
###############################################################################
11731194
def gen_cxx_data_args(physics, arg_data):
11741195
###############################################################################
@@ -1441,6 +1462,30 @@ def check_existing_piece(lines, begin_regex, end_regex):
14411462

14421463
return None if begin_idx is None else (begin_idx, end_idx+1)
14431464

1465+
###############################################################################
1466+
def get_data_by_name(arg_data, arg_name, data_idx):
1467+
###############################################################################
1468+
for name, a, b, c in arg_data:
1469+
if name == arg_name:
1470+
return [name, a, b, c][data_idx]
1471+
1472+
expect(False, f"Name {arg_name} not found")
1473+
1474+
###############################################################################
1475+
def get_rank_map(arg_data, arg_names):
1476+
###############################################################################
1477+
# Create map of rank -> [args]
1478+
rank_map = {}
1479+
for arg in arg_names:
1480+
dims = get_data_by_name(arg_data, arg, ARG_DIMS)
1481+
rank = len(dims)
1482+
if rank in rank_map:
1483+
rank_map[rank].append(arg)
1484+
else:
1485+
rank_map[rank] = [arg]
1486+
1487+
return rank_map
1488+
14441489
#
14451490
# Main classes
14461491
#
@@ -1505,10 +1550,10 @@ def _get_db(self, phys):
15051550
db = parse_origin(origin_file.open(encoding="utf-8").read(), self._subs)
15061551
self._db[phys].update(db)
15071552
if self._verbose:
1508-
print("For physics {}, found:")
1553+
print(f"For physics {phys}, found:")
15091554
for sub in self._subs:
15101555
if sub in db:
1511-
print(" For subroutine {}, found args:")
1556+
print(f" For subroutine {sub}, found args:")
15121557
for name, argtype, intent, dims in db[sub]:
15131558
print(" name:{} type:{} intent:{} dims:({})".\
15141559
format(name, argtype, intent, ",".join(dims) if dims else "scalar"))
@@ -1729,7 +1774,7 @@ def gen_cxx_f2c_bind_decl(self, phys, sub, force_arg_data=None):
17291774
arg_data = force_arg_data if force_arg_data else self._get_arg_data(phys, sub)
17301775
arg_decls = gen_arg_cxx_decls(arg_data)
17311776

1732-
return f"void {sub}_f({', '.join(arg_decls)});"
1777+
return f"void {sub}_host({', '.join(arg_decls)});"
17331778

17341779
###########################################################################
17351780
def gen_cxx_f2c_bind_impl(self, phys, sub, force_arg_data=None):
@@ -1809,8 +1854,166 @@ def gen_cxx_f2c_bind_impl(self, phys, sub, force_arg_data=None):
18091854

18101855
impl = ""
18111856
if has_arrays(arg_data):
1812-
# TODO
1813-
impl += " // TODO"
1857+
#
1858+
# Steps:
1859+
# 1) Set up typedefs
1860+
# 2) Sync to device
1861+
# 3) Unpack view array
1862+
# 4) Get nk_pack and policy
1863+
# 5) Get subviews
1864+
# 6) Call fn
1865+
# 7) Sync back to host
1866+
#
1867+
inputs, inouts, outputs = split_by_intent(arg_data)
1868+
reals, ints, bools = split_by_type(arg_data)
1869+
scalars, views = split_by_scalar_vs_view(arg_data)
1870+
all_inputs = inputs + inouts
1871+
all_outputs = inouts + outputs
1872+
1873+
vreals = list(sorted(set(reals) & set(views)))
1874+
vints = list(sorted(set(ints) & set(views)))
1875+
vbools = list(sorted(set(bools) & set(views)))
1876+
1877+
sreals = list(sorted(set(reals) & set(scalars)))
1878+
sints = list(sorted(set(ints) & set(scalars)))
1879+
sbools = list(sorted(set(bools) & set(scalars)))
1880+
1881+
ivreals = list(sorted(set(vreals) & set(all_inputs)))
1882+
ivints = list(sorted(set(vints) & set(all_inputs)))
1883+
ivbools = list(sorted(set(vbools) & set(all_inputs)))
1884+
1885+
ovreals = list(sorted(set(vreals) & set(all_outputs)))
1886+
ovints = list(sorted(set(vints) & set(all_outputs)))
1887+
ovbools = list(sorted(set(vbools) & set(all_outputs)))
1888+
1889+
isreals = list(sorted(set(sreals) & set(all_inputs)))
1890+
isints = list(sorted(set(sints) & set(all_inputs)))
1891+
isbools = list(sorted(set(sbools) & set(all_inputs)))
1892+
1893+
osreals = list(sorted(set(sreals) & set(all_outputs)))
1894+
osints = list(sorted(set(sints) & set(all_outputs)))
1895+
osbools = list(sorted(set(sbools) & set(all_outputs)))
1896+
1897+
#
1898+
# 1) Set up typedefs
1899+
#
1900+
1901+
# set up basics
1902+
impl += "#if 0\n" # There's no way to guarantee this code compiles
1903+
impl += " using SHF = Functions<Real, DefaultDevice>;\n"
1904+
impl += " using Scalar = typename SHF::Scalar;\n"
1905+
impl += " using Spack = typename SHF::Spack;\n"
1906+
impl += " using KT = typename SHF::KT;\n"
1907+
impl += " using ExeSpace = typename KT::ExeSpace;\n"
1908+
impl += " using MemberType = typename SHF::MemberType;\n\n"
1909+
1910+
prefix_list = ["", "i", "b"]
1911+
type_list = ["Real", "Int", "bool"]
1912+
ktype_list = ["Spack", "Int", "bool"]
1913+
1914+
# make necessary view types. Anything that's an array needs a view type
1915+
for view_group, prefix_char, typename in zip([vreals, vints, vbools], prefix_list, type_list):
1916+
if view_group:
1917+
rank_map = get_rank_map(arg_data, view_group)
1918+
for rank in rank_map:
1919+
if typename == "Real" and rank > 1:
1920+
# Probably this should be packed data
1921+
impl += f" using {prefix_char}view_{rank}d = typename SHF::view_{rank}d<Spack>;\n"
1922+
else:
1923+
impl += f" using {prefix_char}view_{rank}d = typename SHF::view_{rank}d<{typename}>;\n"
1924+
1925+
impl += "\n"
1926+
1927+
#
1928+
# 2) Sync to device. Do ALL views, not just inputs
1929+
#
1930+
1931+
for input_group, prefix_char, typename in zip([vreals, vints, vbools], prefix_list, type_list):
1932+
if input_group:
1933+
rank_map = get_rank_map(arg_data, input_group)
1934+
1935+
for rank, arg_list in rank_map.items():
1936+
impl += f" static constexpr Int {prefix_char}num_arrays_{rank} = {len(arg_list)};\n"
1937+
impl += f" std::vector<{prefix_char}view_{rank}d> {prefix_char}temp_d_{rank}({prefix_char}num_arrays_{rank});\n"
1938+
for rank_itr in range(rank):
1939+
dims = [get_data_by_name(arg_data, arg_name, ARG_DIMS)[rank_itr] for arg_name in arg_list]
1940+
impl += f" std::vector<int> {prefix_char}dim_{rank}_{rank_itr}_sizes = {{{', '.join(dims)}}};\n"
1941+
dim_vectors = [f"{prefix_char}dim_{rank}_{rank_itr}_sizes" for rank_itr in range(rank)]
1942+
funcname = "ekat::host_to_device" if (typename == "Real" and rank > 1) else "ScreamDeepCopy::copy_to_device"
1943+
impl += f" {funcname}({{{', '.join(arg_list)}}}, {', '.join(dim_vectors)}, {prefix_char}temp_d_{rank});\n\n"
1944+
1945+
#
1946+
# 3) Unpack view array
1947+
#
1948+
1949+
for input_group, prefix_char, typename in zip([vreals, vints, vbools], prefix_list, type_list):
1950+
if input_group:
1951+
rank_map = get_rank_map(arg_data, input_group)
1952+
1953+
for rank, arg_list in rank_map.items():
1954+
impl += f" {prefix_char}view_{rank}d\n"
1955+
for idx, input_item in enumerate(arg_list):
1956+
impl += f" {input_item}_d({prefix_char}temp_d_{rank}[{idx}]){';' if idx == len(arg_list) - 1 else ','}\n"
1957+
impl += "\n"
1958+
1959+
1960+
#
1961+
# 4) Get nk_pack and policy, launch kernel
1962+
#
1963+
impl += " const Int nk_pack = ekat::npack<Spack>(nlev);\n"
1964+
impl += " const auto policy = ekat::ExeSpaceUtils<ExeSpace>::get_default_team_policy(shcol, nk_pack);\n"
1965+
impl += " Kokkos::parallel_for(policy, KOKKOS_LAMBDA(const MemberType& team) {\n"
1966+
impl += " const Int i = team.league_rank();\n\n"
1967+
1968+
#
1969+
# 5) Get subviews
1970+
#
1971+
for view_group, prefix_char, typename in zip([vreals, vints, vbools], prefix_list, type_list):
1972+
if view_group:
1973+
for view_arg in view_group:
1974+
dims = get_data_by_name(arg_data, view_arg, ARG_DIMS)
1975+
if "shcol" in dims:
1976+
if len(dims) == 1:
1977+
impl += f" const Scalar {view_arg}_s = {view_arg}_d(i);\n"
1978+
else:
1979+
impl += f" const auto {view_arg}_s = ekat::subview({view_arg}_d, i);\n"
1980+
1981+
impl += "\n"
1982+
1983+
#
1984+
# 6) Call fn
1985+
#
1986+
kernel_arg_names = []
1987+
for arg_name in arg_names:
1988+
if arg_name in views:
1989+
if "shcol" in dims:
1990+
kernel_arg_names.append(f"{arg_name}_s")
1991+
else:
1992+
kernel_arg_names.append(f"{arg_name}_d")
1993+
else:
1994+
kernel_arg_names.append(arg_name)
1995+
1996+
impl += f" SHF::{sub}({', '.join(kernel_arg_names)});\n"
1997+
impl += " });\n"
1998+
1999+
#
2000+
# 7) Sync back to host
2001+
#
2002+
for output_group, prefix_char, typename in zip([ovreals, ovints, ovbools], prefix_list, type_list):
2003+
if output_group:
2004+
rank_map = get_rank_map(arg_data, output_group)
2005+
2006+
for rank, arg_list in rank_map.items():
2007+
impl += f" std::vector<{prefix_char}view_{rank}d> {prefix_char}tempout_d_{rank}({prefix_char}num_arrays_{rank});\n"
2008+
for rank_itr in range(rank):
2009+
dims = [get_data_by_name(arg_data, arg_name, ARG_DIMS)[rank_itr] for arg_name in arg_list]
2010+
impl += f" std::vector<int> {prefix_char}dim_{rank}_{rank_itr}_out_sizes = {{{', '.join(dims)}}};\n"
2011+
dim_vectors = [f"{prefix_char}dim_{rank}_{rank_itr}_out_sizes" for rank_itr in range(rank)]
2012+
funcname = "ekat::device_to_host" if (typename == "Real" and rank > 1) else "ScreamDeepCopy::copy_to_host"
2013+
impl += f" {funcname}({{{', '.join(arg_list)}}}, {', '.join(dim_vectors)}, {prefix_char}tempout_d_{rank});\n\n"
2014+
2015+
impl += "#endif\n"
2016+
18142017
else:
18152018
inputs, inouts, outputs = split_by_intent(arg_data)
18162019
reals, ints, logicals = split_by_type(arg_data)

components/eamxx/src/physics/p3/tests/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ if (NOT SCREAM_P3_SMALL_KERNELS AND NOT SCREAM_ONLY_GENERATE_BASELINES)
8989
CreateUnitTest(p3_sk_tests "p3_main_unit_tests.cpp"
9090
LIBS p3_sk p3_test_infra
9191
EXE_ARGS "--args ${BASELINE_FILE_ARG}"
92-
THREADS 1 ${SCREAM_TEST_MAX_THREADS} ${SCREAM_TEST_THREAD_INC}
92+
THREADS ${P3_THREADS}
9393
LABELS "p3_sk;physics;baseline_cmp")
9494
endif()
9595

components/eamxx/src/physics/p3/tests/infra/p3_data.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ struct P3Data {
1616
using KT = KokkosTypes<HostDevice>;
1717
using Scalar = Real;
1818

19-
using Array1 = typename KT::template lview<Scalar*>;
20-
using Array2 = typename KT::template lview<Scalar**>;
21-
using Array3 = typename KT::template lview<Scalar***>;
19+
using Array1 = typename KT::template view_1d<Scalar>;
20+
using Array2 = typename KT::template view_2d<Scalar>;
21+
using Array3 = typename KT::template view_3d<Scalar>;
2222

2323
bool do_predict_nc;
2424
bool do_prescribed_CCN;

components/eamxx/src/physics/p3/tests/infra/p3_test_data.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,7 +1299,7 @@ Int p3_main_host(
12991299
}
13001300
}
13011301

1302-
ekat::host_to_device(ptr_array, dim1_sizes, dim2_sizes, temp_d, true);
1302+
ekat::host_to_device(ptr_array, dim1_sizes, dim2_sizes, temp_d);
13031303

13041304
int counter = 0;
13051305
view_2d
@@ -1452,7 +1452,7 @@ Int p3_main_host(
14521452
rho_qi, qv2qi_depos_tend,
14531453
liq_ice_exchange, vap_liq_exchange, vap_ice_exchange, precip_liq_flux, precip_ice_flux, precip_liq_surf, precip_ice_surf
14541454
},
1455-
dim1_sizes_out, dim2_sizes_out, inout_views, true);
1455+
dim1_sizes_out, dim2_sizes_out, inout_views);
14561456

14571457
return elapsed_microsec;
14581458
}

components/eamxx/src/physics/p3/tests/p3_main_unit_tests.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,15 +445,13 @@ void run_bfb_p3_main()
445445

446446
// Get data from cxx
447447
for (auto& d : isds_cxx) {
448-
d.template transpose<ekat::TransposeDirection::c2f>();
449448
p3_main_host(
450449
d.qc, d.nc, d.qr, d.nr, d.th_atm, d.qv, d.dt, d.qi, d.qm, d.ni,
451450
d.bm, d.pres, d.dz, d.nc_nuceat_tend, d.nccn_prescribed, d.ni_activated, d.inv_qc_relvar, d.it, d.precip_liq_surf,
452451
d.precip_ice_surf, d.its, d.ite, d.kts, d.kte, d.diag_eff_radius_qc, d.diag_eff_radius_qi, d.diag_eff_radius_qr,
453452
d.rho_qi, d.do_predict_nc, d.do_prescribed_CCN, d.dpres, d.inv_exner, d.qv2qi_depos_tend,
454453
d.precip_liq_flux, d.precip_ice_flux, d.cld_frac_r, d.cld_frac_l, d.cld_frac_i,
455454
d.liq_ice_exchange, d.vap_liq_exchange, d.vap_ice_exchange, d.qv_prev, d.t_prev);
456-
d.template transpose<ekat::TransposeDirection::f2c>();
457455
}
458456

459457
if (SCREAM_BFB_TESTING && this->m_baseline_action == COMPARE) {

components/eamxx/src/physics/p3/tests/p3_run_and_cmp.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -263,15 +263,15 @@ int main (int argc, char** argv) {
263263
for (int i = 1; i < argc-1; ++i) {
264264
if (ekat::argv_matches(argv[i], "-g", "--generate")) { generate = true; no_baseline = false; }
265265
if (ekat::argv_matches(argv[i], "-c", "--compare")) { no_baseline = false; }
266-
if (ekat::argv_matches(argv[i], "-t", "--tol")) {
266+
if (ekat::argv_matches(argv[i], "-b", "--baseline-file")) {
267267
expect_another_arg(i, argc);
268268
++i;
269-
tol = std::atof(argv[i]);
269+
baseline_fn = argv[i];
270270
}
271-
if (ekat::argv_matches(argv[i], "-b", "--baseline-file")) {
271+
if (ekat::argv_matches(argv[i], "-t", "--tol")) {
272272
expect_another_arg(i, argc);
273273
++i;
274-
baseline_fn = argv[i];
274+
tol = std::atof(argv[i]);
275275
}
276276
if (ekat::argv_matches(argv[i], "-s", "--steps")) {
277277
expect_another_arg(i, argc);

components/eamxx/src/physics/shoc/CMakeLists.txt

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,7 @@
11
set(SHOC_SRCS
2-
shoc_f90.cpp
3-
shoc_ic_cases.cpp
4-
shoc_iso_c.f90
5-
shoc_iso_f.f90
6-
${SCREAM_BASE_DIR}/../eam/src/physics/cam/shoc.F90
72
eamxx_shoc_process_interface.cpp
83
)
94

10-
if (NOT SCREAM_LIB_ONLY)
11-
list(APPEND SHOC_SRCS
12-
shoc_functions_f90.cpp
13-
shoc_main_wrap.cpp
14-
) # Add f90 bridges needed for testing
15-
endif()
16-
175
set(SHOC_HEADERS
186
shoc.hpp
197
eamxx_shoc_process_interface.hpp

0 commit comments

Comments
 (0)