|
138 | 138 | )), |
139 | 139 |
|
140 | 140 | ("cxx_f2c_bind_decl" , ( |
141 | | - lambda phys, sub, gb: f"{phys}_functions_f90.hpp", |
| 141 | + lambda phys, sub, gb: f"tests/infra/{phys}_test_data.hpp", |
142 | 142 | lambda phys, sub, gb: expect_exists(phys, sub, gb, "cxx_f2c_bind_decl"), |
143 | | - lambda phys, sub, gb: get_cxx_close_block_regex(comment="end _f function decls"), # reqs special comment |
144 | | - lambda phys, sub, gb: get_cxx_function_begin_regex(sub + "_f"), # cxx_f decl |
| 143 | + lambda phys, sub, gb: get_plain_comment_regex(comment="end _host function decls"), # reqs special comment |
| 144 | + lambda phys, sub, gb: get_cxx_function_begin_regex(sub + "_host"), # cxx_host decl |
145 | 145 | lambda phys, sub, gb: re.compile(r".*;\s*$"), # ; |
146 | | - lambda *x : "The f90 to cxx function declaration(<name>_f)" |
| 146 | + lambda *x : "The f90 to cxx function declaration(<name>_host)" |
147 | 147 | )), |
148 | 148 |
|
149 | 149 | ("cxx_f2c_bind_impl" , ( |
150 | | - lambda phys, sub, gb: f"{phys}_functions_f90.cpp", |
| 150 | + lambda phys, sub, gb: f"tests/infra/{phys}_test_data.cpp", |
151 | 151 | lambda phys, sub, gb: expect_exists(phys, sub, gb, "cxx_f2c_bind_impl"), |
152 | 152 | lambda phys, sub, gb: get_namespace_close_regex(phys), # insert at end of namespace |
153 | | - lambda phys, sub, gb: get_cxx_function_begin_regex(sub + "_f"), # cxx_f |
| 153 | + lambda phys, sub, gb: get_cxx_function_begin_regex(sub + "_host"), # cxx_f |
154 | 154 | lambda phys, sub, gb: get_cxx_close_block_regex(at_line_start=True), # terminating } |
155 | | - lambda *x : "The f90 to cxx function implementation(<name>_f)" |
| 155 | + lambda *x : "The f90 to cxx function implementation(<name>_host)" |
156 | 156 | )), |
157 | 157 |
|
158 | 158 | ("cxx_func_decl", ( |
@@ -455,6 +455,12 @@ def get_cxx_struct_begin_regex(struct): |
455 | 455 | struct_regex_str = fr"^\s*struct\s+{struct}([\W]|$)" |
456 | 456 | return re.compile(struct_regex_str) |
457 | 457 |
|
###############################################################################
def get_plain_comment_regex(comment):
###############################################################################
    """
    Return a compiled regex that matches, from the start of a line, optional
    whitespace, a "//" comment marker, optional whitespace, and then comment.

    comment is inserted into the pattern as-is (it is not escaped), so any
    regex metacharacters it contains are interpreted as regex syntax.
    """
    pattern_template = r"^\s*//\s*{}"
    return re.compile(pattern_template.format(comment))
| 463 | + |
458 | 464 | ############################################################################### |
459 | 465 | def get_data_struct_name(sub): |
460 | 466 | ############################################################################### |
@@ -1169,6 +1175,21 @@ def split_by_type(arg_data): |
1169 | 1175 |
|
1170 | 1176 | return reals, ints, logicals |
1171 | 1177 |
|
###############################################################################
def split_by_scalar_vs_view(arg_data):
###############################################################################
    """
    Partition argument names by whether the argument carries dims.

    Each arg_data entry is unpacked as four items: the name comes first and
    the dims come last. A name is filed as a scalar when its dims is None and
    as a non-scalar otherwise. Input order is preserved within each list.

    Returns (scalars, non_scalars)
    """
    scalars = []
    non_scalars = []
    for name, _, _, dims in arg_data:
        # Only None marks a scalar; any other dims value counts as a view
        destination = scalars if dims is None else non_scalars
        destination.append(name)

    return scalars, non_scalars
| 1192 | + |
1172 | 1193 | ############################################################################### |
1173 | 1194 | def gen_cxx_data_args(physics, arg_data): |
1174 | 1195 | ############################################################################### |
@@ -1441,6 +1462,30 @@ def check_existing_piece(lines, begin_regex, end_regex): |
1441 | 1462 |
|
1442 | 1463 | return None if begin_idx is None else (begin_idx, end_idx+1) |
1443 | 1464 |
|
###############################################################################
def get_data_by_name(arg_data, arg_name, data_idx):
###############################################################################
    """
    Return item number data_idx of the first arg_data entry whose name equals
    arg_name.

    Every entry visited is unpacked as exactly four items with the name first.
    If no entry matches, expect is called with a failing condition and a
    message naming arg_name.
    """
    for entry in arg_data:
        name, second, third, fourth = entry
        if name != arg_name:
            continue

        fields = [name, second, third, fourth]
        return fields[data_idx]

    expect(False, f"Name {arg_name} not found")
| 1473 | + |
###############################################################################
def get_rank_map(arg_data, arg_names):
###############################################################################
    """
    Group arg_names by rank, where an arg's rank is the length of its dims
    entry as looked up in arg_data via get_data_by_name.

    Returns a dict of rank -> [arg names]; names keep their arg_names order
    """
    args_by_rank = {}
    for arg_name in arg_names:
        rank = len(get_data_by_name(arg_data, arg_name, ARG_DIMS))
        # First arg seen for a rank creates that rank's list
        args_by_rank.setdefault(rank, []).append(arg_name)

    return args_by_rank
| 1488 | + |
1444 | 1489 | # |
1445 | 1490 | # Main classes |
1446 | 1491 | # |
@@ -1505,10 +1550,10 @@ def _get_db(self, phys): |
1505 | 1550 | db = parse_origin(origin_file.open(encoding="utf-8").read(), self._subs) |
1506 | 1551 | self._db[phys].update(db) |
1507 | 1552 | if self._verbose: |
1508 | | - print("For physics {}, found:") |
| 1553 | + print(f"For physics {phys}, found:") |
1509 | 1554 | for sub in self._subs: |
1510 | 1555 | if sub in db: |
1511 | | - print(" For subroutine {}, found args:") |
| 1556 | + print(f" For subroutine {sub}, found args:") |
1512 | 1557 | for name, argtype, intent, dims in db[sub]: |
1513 | 1558 | print(" name:{} type:{} intent:{} dims:({})".\ |
1514 | 1559 | format(name, argtype, intent, ",".join(dims) if dims else "scalar")) |
@@ -1729,7 +1774,7 @@ def gen_cxx_f2c_bind_decl(self, phys, sub, force_arg_data=None): |
1729 | 1774 | arg_data = force_arg_data if force_arg_data else self._get_arg_data(phys, sub) |
1730 | 1775 | arg_decls = gen_arg_cxx_decls(arg_data) |
1731 | 1776 |
|
1732 | | - return f"void {sub}_f({', '.join(arg_decls)});" |
| 1777 | + return f"void {sub}_host({', '.join(arg_decls)});" |
1733 | 1778 |
|
1734 | 1779 | ########################################################################### |
1735 | 1780 | def gen_cxx_f2c_bind_impl(self, phys, sub, force_arg_data=None): |
@@ -1809,8 +1854,166 @@ def gen_cxx_f2c_bind_impl(self, phys, sub, force_arg_data=None): |
1809 | 1854 |
|
1810 | 1855 | impl = "" |
1811 | 1856 | if has_arrays(arg_data): |
1812 | | - # TODO |
1813 | | - impl += " // TODO" |
| 1857 | + # |
| 1858 | + # Steps: |
| 1859 | + # 1) Set up typedefs |
| 1860 | + # 2) Sync to device |
| 1861 | + # 3) Unpack view array |
| 1862 | + # 4) Get nk_pack and policy |
| 1863 | + # 5) Get subviews |
| 1864 | + # 6) Call fn |
| 1865 | + # 7) Sync back to host |
| 1866 | + # |
| 1867 | + inputs, inouts, outputs = split_by_intent(arg_data) |
| 1868 | + reals, ints, bools = split_by_type(arg_data) |
| 1869 | + scalars, views = split_by_scalar_vs_view(arg_data) |
| 1870 | + all_inputs = inputs + inouts |
| 1871 | + all_outputs = inouts + outputs |
| 1872 | + |
| 1873 | + vreals = list(sorted(set(reals) & set(views))) |
| 1874 | + vints = list(sorted(set(ints) & set(views))) |
| 1875 | + vbools = list(sorted(set(bools) & set(views))) |
| 1876 | + |
| 1877 | + sreals = list(sorted(set(reals) & set(scalars))) |
| 1878 | + sints = list(sorted(set(ints) & set(scalars))) |
| 1879 | + sbools = list(sorted(set(bools) & set(scalars))) |
| 1880 | + |
| 1881 | + ivreals = list(sorted(set(vreals) & set(all_inputs))) |
| 1882 | + ivints = list(sorted(set(vints) & set(all_inputs))) |
| 1883 | + ivbools = list(sorted(set(vbools) & set(all_inputs))) |
| 1884 | + |
| 1885 | + ovreals = list(sorted(set(vreals) & set(all_outputs))) |
| 1886 | + ovints = list(sorted(set(vints) & set(all_outputs))) |
| 1887 | + ovbools = list(sorted(set(vbools) & set(all_outputs))) |
| 1888 | + |
| 1889 | + isreals = list(sorted(set(sreals) & set(all_inputs))) |
| 1890 | + isints = list(sorted(set(sints) & set(all_inputs))) |
| 1891 | + isbools = list(sorted(set(sbools) & set(all_inputs))) |
| 1892 | + |
| 1893 | + osreals = list(sorted(set(sreals) & set(all_outputs))) |
| 1894 | + osints = list(sorted(set(sints) & set(all_outputs))) |
| 1895 | + osbools = list(sorted(set(sbools) & set(all_outputs))) |
| 1896 | + |
| 1897 | + # |
| 1898 | + # 1) Set up typedefs |
| 1899 | + # |
| 1900 | + |
| 1901 | + # set up basics |
| 1902 | + impl += "#if 0\n" # There's no way to guarantee this code compiles |
| 1903 | + impl += " using SHF = Functions<Real, DefaultDevice>;\n" |
| 1904 | + impl += " using Scalar = typename SHF::Scalar;\n" |
| 1905 | + impl += " using Spack = typename SHF::Spack;\n" |
| 1906 | + impl += " using KT = typename SHF::KT;\n" |
| 1907 | + impl += " using ExeSpace = typename KT::ExeSpace;\n" |
| 1908 | + impl += " using MemberType = typename SHF::MemberType;\n\n" |
| 1909 | + |
| 1910 | + prefix_list = ["", "i", "b"] |
| 1911 | + type_list = ["Real", "Int", "bool"] |
| 1912 | + ktype_list = ["Spack", "Int", "bool"] |
| 1913 | + |
| 1914 | + # make necessary view types. Anything that's an array needs a view type |
| 1915 | + for view_group, prefix_char, typename in zip([vreals, vints, vbools], prefix_list, type_list): |
| 1916 | + if view_group: |
| 1917 | + rank_map = get_rank_map(arg_data, view_group) |
| 1918 | + for rank in rank_map: |
| 1919 | + if typename == "Real" and rank > 1: |
| 1920 | + # Probably this should be packed data |
| 1921 | + impl += f" using {prefix_char}view_{rank}d = typename SHF::view_{rank}d<Spack>;\n" |
| 1922 | + else: |
| 1923 | + impl += f" using {prefix_char}view_{rank}d = typename SHF::view_{rank}d<{typename}>;\n" |
| 1924 | + |
| 1925 | + impl += "\n" |
| 1926 | + |
| 1927 | + # |
| 1928 | + # 2) Sync to device. Do ALL views, not just inputs |
| 1929 | + # |
| 1930 | + |
| 1931 | + for input_group, prefix_char, typename in zip([vreals, vints, vbools], prefix_list, type_list): |
| 1932 | + if input_group: |
| 1933 | + rank_map = get_rank_map(arg_data, input_group) |
| 1934 | + |
| 1935 | + for rank, arg_list in rank_map.items(): |
| 1936 | + impl += f" static constexpr Int {prefix_char}num_arrays_{rank} = {len(arg_list)};\n" |
| 1937 | + impl += f" std::vector<{prefix_char}view_{rank}d> {prefix_char}temp_d_{rank}({prefix_char}num_arrays_{rank});\n" |
| 1938 | + for rank_itr in range(rank): |
| 1939 | + dims = [get_data_by_name(arg_data, arg_name, ARG_DIMS)[rank_itr] for arg_name in arg_list] |
| 1940 | + impl += f" std::vector<int> {prefix_char}dim_{rank}_{rank_itr}_sizes = {{{', '.join(dims)}}};\n" |
| 1941 | + dim_vectors = [f"{prefix_char}dim_{rank}_{rank_itr}_sizes" for rank_itr in range(rank)] |
| 1942 | + funcname = "ekat::host_to_device" if (typename == "Real" and rank > 1) else "ScreamDeepCopy::copy_to_device" |
| 1943 | + impl += f" {funcname}({{{', '.join(arg_list)}}}, {', '.join(dim_vectors)}, {prefix_char}temp_d_{rank});\n\n" |
| 1944 | + |
| 1945 | + # |
| 1946 | + # 3) Unpack view array |
| 1947 | + # |
| 1948 | + |
| 1949 | + for input_group, prefix_char, typename in zip([vreals, vints, vbools], prefix_list, type_list): |
| 1950 | + if input_group: |
| 1951 | + rank_map = get_rank_map(arg_data, input_group) |
| 1952 | + |
| 1953 | + for rank, arg_list in rank_map.items(): |
| 1954 | + impl += f" {prefix_char}view_{rank}d\n" |
| 1955 | + for idx, input_item in enumerate(arg_list): |
| 1956 | + impl += f" {input_item}_d({prefix_char}temp_d_{rank}[{idx}]){';' if idx == len(arg_list) - 1 else ','}\n" |
| 1957 | + impl += "\n" |
| 1958 | + |
| 1959 | + |
| 1960 | + # |
| 1961 | + # 4) Get nk_pack and policy, launch kernel |
| 1962 | + # |
| 1963 | + impl += " const Int nk_pack = ekat::npack<Spack>(nlev);\n" |
| 1964 | + impl += " const auto policy = ekat::ExeSpaceUtils<ExeSpace>::get_default_team_policy(shcol, nk_pack);\n" |
| 1965 | + impl += " Kokkos::parallel_for(policy, KOKKOS_LAMBDA(const MemberType& team) {\n" |
| 1966 | + impl += " const Int i = team.league_rank();\n\n" |
| 1967 | + |
| 1968 | + # |
| 1969 | + # 5) Get subviews |
| 1970 | + # |
| 1971 | + for view_group, prefix_char, typename in zip([vreals, vints, vbools], prefix_list, type_list): |
| 1972 | + if view_group: |
| 1973 | + for view_arg in view_group: |
| 1974 | + dims = get_data_by_name(arg_data, view_arg, ARG_DIMS) |
| 1975 | + if "shcol" in dims: |
| 1976 | + if len(dims) == 1: |
| 1977 | + impl += f" const Scalar {view_arg}_s = {view_arg}_d(i);\n" |
| 1978 | + else: |
| 1979 | + impl += f" const auto {view_arg}_s = ekat::subview({view_arg}_d, i);\n" |
| 1980 | + |
| 1981 | + impl += "\n" |
| 1982 | + |
| 1983 | + # |
| 1984 | + # 6) Call fn |
| 1985 | + # |
| 1986 | + kernel_arg_names = [] |
| 1987 | + for arg_name in arg_names: |
| 1988 | + if arg_name in views: |
| 1989 | + if "shcol" in dims: |
| 1990 | + kernel_arg_names.append(f"{arg_name}_s") |
| 1991 | + else: |
| 1992 | + kernel_arg_names.append(f"{arg_name}_d") |
| 1993 | + else: |
| 1994 | + kernel_arg_names.append(arg_name) |
| 1995 | + |
| 1996 | + impl += f" SHF::{sub}({', '.join(kernel_arg_names)});\n" |
| 1997 | + impl += " });\n" |
| 1998 | + |
| 1999 | + # |
| 2000 | + # 7) Sync back to host |
| 2001 | + # |
| 2002 | + for output_group, prefix_char, typename in zip([ovreals, ovints, ovbools], prefix_list, type_list): |
| 2003 | + if output_group: |
| 2004 | + rank_map = get_rank_map(arg_data, output_group) |
| 2005 | + |
| 2006 | + for rank, arg_list in rank_map.items(): |
| 2007 | + impl += f" std::vector<{prefix_char}view_{rank}d> {prefix_char}tempout_d_{rank}({prefix_char}num_arrays_{rank});\n" |
| 2008 | + for rank_itr in range(rank): |
| 2009 | + dims = [get_data_by_name(arg_data, arg_name, ARG_DIMS)[rank_itr] for arg_name in arg_list] |
| 2010 | + impl += f" std::vector<int> {prefix_char}dim_{rank}_{rank_itr}_out_sizes = {{{', '.join(dims)}}};\n" |
| 2011 | + dim_vectors = [f"{prefix_char}dim_{rank}_{rank_itr}_out_sizes" for rank_itr in range(rank)] |
| 2012 | + funcname = "ekat::device_to_host" if (typename == "Real" and rank > 1) else "ScreamDeepCopy::copy_to_host" |
| 2013 | + impl += f" {funcname}({{{', '.join(arg_list)}}}, {', '.join(dim_vectors)}, {prefix_char}tempout_d_{rank});\n\n" |
| 2014 | + |
| 2015 | + impl += "#endif\n" |
| 2016 | + |
1814 | 2017 | else: |
1815 | 2018 | inputs, inouts, outputs = split_by_intent(arg_data) |
1816 | 2019 | reals, ints, logicals = split_by_type(arg_data) |
|
0 commit comments