245245
246246# physics map. maps the name of a physics packages containing the original fortran subroutines to:
247247# (path-to-origin, path-to-cxx-src, init-code)
248- ORIGIN_FILES , CXX_ROOT , INIT_CODE , FINALIZE_CODE , COLS_DIMNAME = range (5 )
248+ ORIGIN_FILES , CXX_ROOT , INIT_CODE , FINALIZE_CODE , COLS_DIMNAME , UNPACKED = range (6 )
249249PHYSICS = {
250250 "p3" : (
251251 ("components/eam/src/physics/cam/micro_p3.F90" ,),
252252 "components/eamxx/src/physics/p3" ,
253253 "p3_init();" ,
254254 "" ,
255- "its:ite"
255+ "its:ite" ,
256+ False
256257 ),
257258 "shoc" : (
258259 ("components/eam/src/physics/cam/shoc.F90" ,),
259260 "components/eamxx/src/physics/shoc" ,
260261 "shoc_init(d.nlev, true);" ,
261262 "" ,
262- "shcol"
263+ "shcol" ,
264+ False
263265 ),
264266 "dp" : (
265267 (
271273 "components/eamxx/src/physics/dp" ,
272274 "dp_init(d.plev, true);" ,
273275 ""
274- ""
276+ "" ,
277+ False
275278 ),
276279 "gw" : (
277280 (
285288 "components/eamxx/src/physics/gw" ,
286289 "gw_common_init(); // Might need more specific init" ,
287290 "gw_finalize_cxx();" ,
288- "ncol"
291+ "ncol" ,
292+ True
289293 ),
290294}
291295
@@ -715,60 +719,82 @@ def get_cxx_type(arg_datum):
715719 arg_cxx_type = get_cxx_scalar_type (arg_type )
716720 return f"{ arg_cxx_type } { '*' if is_ptr else '' } "
717721
718- KOKKOS_TYPE_MAP = {"real" : "Spack" , "integer" : "Int" , "logical" : "bool" }
719722###############################################################################
720- def get_kokkos_type (arg_datum ):
723+ def get_kokkos_type (arg_datum , col_dim , unpacked = False ):
721724###############################################################################
722725 """
723726 Based on arg datum, give c++ kokkos type
724727
725728 Note: We can only guess at the correct types, especially whether an argument
726729 should be packed data or not!
727730 """
731+ expect (col_dim is not None , "Kokkos must know col_dim name" )
732+
728733 is_const = arg_datum [ARG_INTENT ] == "in"
729- is_view = arg_datum [ARG_DIMS ] is not None
734+ dims = arg_datum [ARG_DIMS ]
735+ # Remove the cols dim, we parallelize over cols
736+ if dims is not None and col_dim in dims :
737+ dims = list (dims )
738+ dims .remove (col_dim )
739+
740+ is_view = bool (dims )
730741 arg_type = arg_datum [ARG_TYPE ]
731742 if is_custom_type (arg_type ):
732743 kokkos_type = arg_type .split ("::" )[- 1 ]
733744 else :
734- kokkos_type = KOKKOS_TYPE_MAP [arg_type ]
745+ kokkos_type = CXX_TYPE_MAP [arg_type ]
746+ if kokkos_type == "Real" and not unpacked :
747+ kokkos_type = "Spack"
735748
736749 base_type = f"{ 'const ' if is_const else '' } { kokkos_type } "
737750
738- # We assume 1d even if the f90 array is 2d since we assume c++ will spawn a kernel
739- # over one of the dimensions
740- return f"const uview_1d<{ base_type } >&" if is_view else f"{ base_type } &"
751+ return f"const uview_{ len (dims )} d<{ base_type } >&" if is_view else f"{ base_type } &"
741752
742753###############################################################################
743- def gen_arg_cxx_decls (arg_data , kokkos = False ):
754+ def gen_arg_cxx_decls (arg_data , kokkos = False , unpacked = False , col_dim = None ):
744755###############################################################################
745756 """
746757 Get all arg decls for a set of arg data. kokkos flag tells us to use C++/Kokkos
747758 types instead of C types.
748759 """
749760 arg_names = [item [ARG_NAME ] for item in arg_data ]
750- get_type = get_kokkos_type if kokkos else get_cxx_type
751- arg_types = [get_type (item ) for item in arg_data ]
752- arg_sig_list = [f"{ arg_type } { arg_name } " for arg_name , arg_type in zip (arg_names , arg_types )]
761+ if kokkos :
762+ arg_types = [get_kokkos_type (item , col_dim , unpacked = unpacked ) for item in arg_data ]
763+ else :
764+ arg_types = [get_cxx_type (item ) for item in arg_data ]
765+
766+ arg_sig_list = [(f"{ arg_type } { arg_name } " , arg_datum [ARG_INTENT ])
767+ for arg_name , arg_type , arg_datum in zip (arg_names , arg_types , arg_data )]
768+
769+ # For kokkos functions, we will almost always want the team and we don't want
770+ # the col_dim
771+ if kokkos :
772+ arg_sig_list .insert (0 , ("const MemberType& team" , "in" ))
773+ for arg_sig , arg_intent in arg_sig_list :
774+ if arg_sig .split ()[- 1 ] == col_dim :
775+ expect (arg_intent == "in" , f"col_dim { col_dim } wasn't an input, { arg_intent } ?" )
776+ arg_sig_list .remove ((arg_sig , arg_intent ))
777+ break
778+
779+ result = []
753780
754781 # For permanent sigs, we want them to look nice. We may want to order these
755782 # by intent and scalar vs array, but for now, just mimic the fortran order.
756783 if kokkos :
757- list_with_comments = []
758784 intent_map = {"in" : "Inputs" , "inout" : "Inputs/Outputs" , "out" : "Outputs" }
759785 curr = None
760- for arg_sig , arg_datum in zip (arg_sig_list , arg_data ):
761- intent = arg_datum [ARG_INTENT ]
762- if intent != curr :
763- fullname = intent_map [intent ]
764- list_with_comments .append (f"// { fullname } " )
765- curr = intent
786+ for arg_sig , arg_intent in arg_sig_list :
787+ if arg_intent != curr :
788+ fullname = intent_map [arg_intent ]
789+ result .append (f"// { fullname } " )
790+ curr = arg_intent
766791
767- list_with_comments .append (arg_sig )
792+ result .append (arg_sig )
768793
769- arg_sig_list = list_with_comments
794+ else :
795+ result = [arg_sig for arg_sig , _ in arg_sig_list ]
770796
771- return arg_sig_list
797+ return result
772798
773799###############################################################################
774800def split_by_intent (arg_data ):
@@ -983,8 +1009,12 @@ def convert_to_cxx_dim(dim, add_underscore=False, from_d=False):
9831009 uns = "_" if add_underscore else ""
9841010 obj = "d." if from_d else ""
9851011
1012+ # null case, could not determine anything
1013+ if not tokens :
1014+ return ""
1015+
9861016 # case 1, single token
987- if len (tokens ) == 1 :
1017+ elif len (tokens ) == 1 :
9881018 expect (not tokens [0 ].startswith ("-" ), f"Received weird negative fortran dim: '{ dim } '" )
9891019 return obj + tokens [0 ] + uns
9901020
@@ -1032,7 +1062,7 @@ def convert_to_cxx_dim(dim, add_underscore=False, from_d=False):
10321062 return f"{ obj } { second_token } { uns } - { obj } { first_token } { uns } "
10331063
10341064 else :
1035- expect (False , f"Received weird fortran range with more than 2 tokens: '{ dim } '" )
1065+ expect (False , f"Received weird fortran range with more than 2 tokens: '{ tokens } '" )
10361066
10371067###############################################################################
10381068def gen_struct_api (struct_name , arg_data ):
@@ -1204,7 +1234,7 @@ def get_htd_dth_call(arg_data, rank, arg_list, typename, is_output=False, f2c=Fa
12041234 return result
12051235
12061236###############################################################################
1207- def gen_glue_impl (phys , sub , arg_data , arg_names , col_dim , f2c = False , unpack = False ):
1237+ def gen_glue_impl (phys , sub , arg_data , arg_names , col_dim , f2c = False , unpacked = False ):
12081238###############################################################################
12091239 """
12101240 Generate code that takes a TestData struct and unpacks it to call the CXX
@@ -1303,7 +1333,7 @@ def gen_glue_impl(phys, sub, arg_data, arg_names, col_dim, f2c=False, unpack=Fal
13031333 #
13041334 # 4) Get nk_pack and policy, unpack scalars, and launch kernel
13051335 #
1306- if unpack :
1336+ if unpacked :
13071337 impl += f" const auto policy = ekat::TeamPolicyFactory<ExeSpace>::get_default_team_policy({ obj } { col_dim } , { obj } nlev);\n \n "
13081338 else :
13091339 impl += f" const Int nk_pack = ekat::npack<Spack>({ obj } nlev);\n "
@@ -1419,7 +1449,7 @@ def __init__(self,
14191449 source_repo = get_git_toplevel_dir (),
14201450 target_repo = get_git_toplevel_dir (),
14211451 f2c = False ,
1422- unpack = False ,
1452+ unpacked = False ,
14231453 col_dim = None ,
14241454 dry_run = False ,
14251455 verbose = False ):
@@ -1443,7 +1473,7 @@ def __init__(self,
14431473 self ._kernel = kernel
14441474 self ._source_repo = Path (source_repo ).resolve ()
14451475 self ._target_repo = Path (target_repo ).resolve ()
1446- self ._unpack = unpack
1476+ self ._unpacked = unpacked if unpacked else get_physics_data ( physics , UNPACKED )
14471477 self ._col_dim = get_physics_data (physics , COLS_DIMNAME ) if col_dim is None else col_dim
14481478 self ._dry_run = dry_run
14491479 self ._verbose = verbose
@@ -1510,7 +1540,7 @@ def gen_cxx_c2f_bind_decl(self, phys, sub, force_arg_data=None):
15101540 In C, generate the C to F90 brige declaration. The definition will be in fortran
15111541 """
15121542 arg_data = force_arg_data if force_arg_data else self ._get_arg_data (phys , sub )
1513- arg_decls = gen_arg_cxx_decls (arg_data )
1543+ arg_decls = gen_arg_cxx_decls (arg_data , unpacked = self . _unpacked )
15141544 result = f"void { sub } _bridge_f({ ', ' .join (arg_decls )} );\n "
15151545 return result
15161546
@@ -1612,7 +1642,7 @@ def gen_cxx_t2cxx_glue_impl(self, phys, sub, force_arg_data=None):
16121642 arg_names = [item [ARG_NAME ] for item in arg_data ]
16131643 decl = self .gen_cxx_t2cxx_glue_decl (phys , sub , force_arg_data = force_arg_data ).rstrip (";" )
16141644
1615- impl = gen_glue_impl (phys , sub , arg_data , arg_names , self ._col_dim , unpack = self ._unpack )
1645+ impl = gen_glue_impl (phys , sub , arg_data , arg_names , self ._col_dim , unpacked = self ._unpacked )
16161646
16171647 result = \
16181648f"""{ decl }
@@ -1655,7 +1685,7 @@ def gen_cxx_f2c_bind_impl(self, phys, sub, force_arg_data=None):
16551685 """
16561686 arg_data = force_arg_data if force_arg_data else self ._get_arg_data (phys , sub )
16571687 arg_names = [item [ARG_NAME ] for item in arg_data ]
1658- arg_decls = gen_arg_cxx_decls (arg_data )
1688+ arg_decls = gen_arg_cxx_decls (arg_data , unpacked = self . _unpacked )
16591689 decl = f"void { sub } _bridge_c({ ', ' .join (arg_decls )} )"
16601690
16611691 impl = gen_glue_impl (phys , sub , arg_data , arg_names , self ._col_dim , f2c = True )
@@ -1675,7 +1705,7 @@ def gen_cxx_func_decl(self, phys, sub, force_arg_data=None):
16751705 In CXX, generate the function declaration in the main physics header
16761706 """
16771707 arg_data = force_arg_data if force_arg_data else self ._get_arg_data (phys , sub )
1678- arg_decls = gen_arg_cxx_decls (arg_data , kokkos = True )
1708+ arg_decls = gen_arg_cxx_decls (arg_data , kokkos = True , unpacked = self . _unpacked , col_dim = self . _col_dim )
16791709
16801710 arg_decls_str = ("\n " .join ([item if item .startswith ("//" ) else f"{ item } ," for item in arg_decls ])).rstrip ("," )
16811711
0 commit comments