From 6abc14ba064bc4698b72f371c60f17ee23350e6d Mon Sep 17 00:00:00 2001 From: Gabriel Kosmacher <73120774+kennykos@users.noreply.github.com> Date: Wed, 4 Mar 2026 17:01:24 -0600 Subject: [PATCH 01/10] Remove workload pykokkos style --- pykokkos/core/parsers/parser.py | 7 +------ pykokkos/core/translators/members.py | 16 ++-------------- pykokkos/core/translators/static.py | 11 +++-------- 3 files changed, 6 insertions(+), 28 deletions(-) diff --git a/pykokkos/core/parsers/parser.py b/pykokkos/core/parsers/parser.py index f2913c61..b910dc4f 100644 --- a/pykokkos/core/parsers/parser.py +++ b/pykokkos/core/parsers/parser.py @@ -14,7 +14,6 @@ class PyKokkosStyles(Enum): """ functor = auto() - workload = auto() workunit = auto() classtype = auto() fused = auto() @@ -61,7 +60,6 @@ def __init__(self, path: Optional[str], pk_import: Optional[str] = None): self.functors: Dict[str, PyKokkosEntity] = {} self.workunits: Dict[str, PyKokkosEntity] = {} - self.workloads = self.get_entities(PyKokkosStyles.workload) self.classtypes = self.get_entities(PyKokkosStyles.classtype) self.functors = self.get_entities(PyKokkosStyles.functor) self.workunits = self.get_entities(PyKokkosStyles.workunit) @@ -74,7 +72,6 @@ def __init__(self, path: Optional[str], pk_import: Optional[str] = None): self.tree = ast.Module(body=[]) self.path = None self.pk_import = pk_import - self.workloads: Dict[str, PyKokkosEntity] = {} self.classtypes: Dict[str, PyKokkosEntity] = {} self.functors: Dict[str, PyKokkosEntity] = {} self.workunits: Dict[str, PyKokkosEntity] = {} @@ -132,9 +129,7 @@ def get_entities(self, style: PyKokkosStyles) -> Dict[str, PyKokkosEntity]: entities: Dict[str, PyKokkosEntity] = {} check_entity: Callable[[ast.stmt], bool] - if style is PyKokkosStyles.workload: - return entities - elif style is PyKokkosStyles.functor: + if style is PyKokkosStyles.functor: check_entity = self.is_functor elif style is PyKokkosStyles.workunit: check_entity = self.is_workunit diff --git a/pykokkos/core/translators/members.py b/pykokkos/core/translators/members.py index 3d8e8e64..c7a5c908 100644 --- a/pykokkos/core/translators/members.py +++ b/pykokkos/core/translators/members.py @@ -56,13 +56,7 @@ def extract(self, entity: PyKokkosEntity, classtypes: List[PyKokkosEntity]) -> N source: Tuple[List[str], int] = entity.source pk_import: str = entity.pk_import - if entity.style is PyKokkosStyles.workload: - self.pk_mains = self.get_decorated_functions(AST, Decorator.KokkosMain) - self.fields = self.get_fields(AST, source, pk_import) - self.views = self.get_views(AST, source, pk_import) - self.random_pool = self.get_random_pool(AST, source, pk_import) - - elif entity.style is PyKokkosStyles.functor: + if entity.style is PyKokkosStyles.functor: self.fields = self.get_fields(AST, source, pk_import) self.views = self.get_views(AST, source, pk_import) self.random_pool = self.get_random_pool(AST, source, pk_import) @@ -104,7 +98,7 @@ def extract(self, entity: PyKokkosEntity, classtypes: List[PyKokkosEntity]) -> N if n in self.view_template_params: t.template_params.extend(self.view_template_params[n]) - if entity.style in (PyKokkosStyles.workload, PyKokkosStyles.functor): + if entity.style == PyKokkosStyles.functor: self.pk_workunits = self.get_decorated_functions(AST, Decorator.WorkUnit) self.pk_functions = self.get_decorated_functions( AST, Decorator.KokkosFunction @@ -120,12 +114,6 @@ def extract(self, entity: PyKokkosEntity, classtypes: List[PyKokkosEntity]) -> N self.classtype_methods = self.get_classtype_methods(classtypes) - if entity.style is PyKokkosStyles.workload: - name: str = f"pk_functor_{entity.name}" - self.reduction_result_queue, self.timer_result_queue = self.get_queues( - source, name, pk_import - ) - if len(self.pk_mains) > 1: print("ERROR: Only one pk.main function can be translated") sys.exit(1) diff --git a/pykokkos/core/translators/static.py b/pykokkos/core/translators/static.py index 3feb9294..7d524008 100644 --- a/pykokkos/core/translators/static.py +++ b/pykokkos/core/translators/static.py @@ -511,14 +511,9 @@ def generate_bindings( """ bindings: List[str] - if entity.style is PyKokkosStyles.workload: - bindings = bind_main( - functor_name, self.pk_members, source, self.pk_import, self.module_file - ) - else: - bindings = bind_workunits( - functor_name, self.pk_members, workunits, self.module_file - ) + bindings = bind_workunits( + functor_name, self.pk_members, workunits, self.module_file + ) return bindings From 9d194e909dc61d5f5e73f5dc2b5e0bddeb6abed2 Mon Sep 17 00:00:00 2001 From: Gabriel Kosmacher <73120774+kennykos@users.noreply.github.com> Date: Wed, 4 Mar 2026 17:09:58 -0600 Subject: [PATCH 02/10] Remove pk.execute parallel dispatch Only used for workload dispatches --- pykokkos/interface/__init__.py | 1 - pykokkos/interface/parallel_dispatch.py | 7 ------- 2 files changed, 8 deletions(-) diff --git a/pykokkos/interface/__init__.py b/pykokkos/interface/__init__.py index a14c2ffa..8098b86f 100644 --- a/pykokkos/interface/__init__.py +++ b/pykokkos/interface/__init__.py @@ -67,7 +67,6 @@ from .mathematical_special_functions import cyl_bessel_j0, cyl_bessel_j1 from .memory_space import MemorySpace, get_default_memory_space from .parallel_dispatch import ( - execute, flush, parallel_for, parallel_reduce, diff --git a/pykokkos/interface/parallel_dispatch.py b/pykokkos/interface/parallel_dispatch.py index 3b79d755..c746d858 100644 --- a/pykokkos/interface/parallel_dispatch.py +++ b/pykokkos/interface/parallel_dispatch.py @@ -370,12 +370,5 @@ def parallel_scan(*args, **kwargs) -> Union[float, int]: return reduce_body("scan", *args, **kwargs) -def execute(space: ExecutionSpace, workload: object) -> None: - if space is ExecutionSpace.Default: - runtime_singleton.runtime.run_workload(km.get_default_space(), workload) - else: - runtime_singleton.runtime.run_workload(space, workload) - - def flush(): runtime_singleton.runtime.flush_trace() From f413090df765d9e88216ff762eabf95cc8b9bdd0 Mon Sep 17 00:00:00 2001 From: Gabriel Kosmacher <73120774+kennykos@users.noreply.github.com> Date: Wed, 4 Mar 2026 17:18:53 -0600 Subject: [PATCH 03/10] Remove support for kokkos main tranlation Only used by workloads --- pykokkos/core/translators/bindings.py | 45 +- pykokkos/core/translators/members.py | 40 -- pykokkos/core/visitors/__init__.py | 1 - pykokkos/core/visitors/kokkosmain_visitor.py | 535 ------------------- 4 files changed, 1 insertion(+), 620 deletions(-) delete mode 100644 pykokkos/core/visitors/kokkosmain_visitor.py diff --git a/pykokkos/core/translators/bindings.py b/pykokkos/core/translators/bindings.py index cbb4d837..591134f1 100644 --- a/pykokkos/core/translators/bindings.py +++ b/pykokkos/core/translators/bindings.py @@ -4,7 +4,7 @@ from pykokkos.core import cppast from pykokkos.core.keywords import Keywords -from pykokkos.core.visitors import cpp_view_type, KokkosMainVisitor, visitors_util +from pykokkos.core.visitors import cpp_view_type, visitors_util from pykokkos.interface.data_types import DataType from .members import PyKokkosMembers @@ -564,49 +564,6 @@ def bind_workunits( return bindings -def translate_mains( - source: Tuple[List[str], int], - functor: str, - members: PyKokkosMembers, - pk_import: str, -) -> List[str]: - """ - Translate all PyKokkos main functions - - :param source: the python source code of the workload - :param functor: the name of the functor - :param members: an object containing the fields and views - :returns: a list of strings of translated source code - """ - - node_visitor = KokkosMainVisitor( - {}, - source, - members.views, - members.pk_workunits, - members.fields, - members.pk_functions, - members.classtype_methods, - functor, - pk_import, - debug=True, - ) - - translation: List[str] = [] - - for main in members.pk_mains.values(): - try: - translation.append(node_visitor.visit(main)) - except NotImplementedError: - print(f"Translation of {main.name} failed") - sys.exit(1) - - members.reduction_result_queue = node_visitor.reduction_result_queue - members.timer_result_queue = node_visitor.timer_result_queue - - return translation - - def bind_main_single( functor: str, members: PyKokkosMembers, diff --git a/pykokkos/core/translators/members.py b/pykokkos/core/translators/members.py index c7a5c908..a5f5cf2b 100644 --- a/pykokkos/core/translators/members.py +++ b/pykokkos/core/translators/members.py @@ -8,7 +8,6 @@ from pykokkos.core.parsers import PyKokkosEntity, PyKokkosStyles from pykokkos.core.visitors import ( ConstructorVisitor, - KokkosMainVisitor, ParameterVisitor, visitors_util, ) @@ -154,45 +153,6 @@ def get_views( return views - def get_queues( - self, source: Tuple[List[str], int], name: str, pk_import: str - ) -> Tuple[List[str], List[str]]: - """ - Get all fields assigned to a reduction result or timer result - - :param source: the python source code of the workload - :param name: the name of the workload - :param pk_import: the identifier used to access the PyKokkos package - :returns: two lists, one for the reduction results and one for the timer results - """ - - views = copy.deepcopy( - self.views - ) # Needed since KokkosMainVisitor modifies views - - # Copied from translate_mains() in bindings.py - node_visitor = KokkosMainVisitor( - {}, - source, - views, - self.pk_workunits, - self.fields, - self.pk_functions, - self.classtype_methods, - name, - pk_import, - debug=True, - ) - - for main in self.pk_mains.values(): - try: - node_visitor.visit(main) - except NotImplementedError: - print(f"Translation of {main.name} failed") - sys.exit(1) - - return (node_visitor.reduction_result_queue, node_visitor.timer_result_queue) - def get_real_views(self): """ Get all the views that contain a pk.real datatype diff --git a/pykokkos/core/visitors/__init__.py b/pykokkos/core/visitors/__init__.py index ea01e9a8..d2c5e252 100644 --- a/pykokkos/core/visitors/__init__.py +++ b/pykokkos/core/visitors/__init__.py @@ -2,7 +2,6 @@ from .constructor_visitor import ConstructorVisitor from .debug_transformer import DebugTransformer from .kokkosfunction_visitor import KokkosFunctionVisitor -from .kokkosmain_visitor import KokkosMainVisitor from .parameter_visitor import ParameterVisitor from .pykokkos_visitor import PyKokkosVisitor from .visitors_util import cpp_view_type, parse_view_template_params diff --git a/pykokkos/core/visitors/kokkosmain_visitor.py b/pykokkos/core/visitors/kokkosmain_visitor.py deleted file mode 100644 index 315a6a43..00000000 --- a/pykokkos/core/visitors/kokkosmain_visitor.py +++ /dev/null @@ -1,535 +0,0 @@ -import ast -from typing import List, Dict, Optional, Set, Union -from ast import FunctionDef - -from pykokkos.core import cppast -from pykokkos.core.keywords import Keywords -from pykokkos.interface import BinOp, BinSort, View, Iterate, TeamPolicy - -from . import visitors_util -from .pykokkos_visitor import PyKokkosVisitor - - -class KokkosMainVisitor(PyKokkosVisitor): - def __init__( - self, - env, - src, - views: Dict[str, View], - work_units: Dict[str, FunctionDef], - fields: Dict[cppast.DeclRefExpr, cppast.PrimitiveType], - kokkos_functions: Dict[str, FunctionDef], - dependency_methods: Dict[str, List[str]], - functor: str, - pk_import: str, - restrict_views: Set[str] = set(), - debug=False, - ): - super().__init__( - env, - src, - views, - work_units, - fields, - kokkos_functions, - dependency_methods, - pk_import, - restrict_views, - debug, - ) - - self.functor: str = functor - self.reduction_result_queue: List[str] = [] - self.timer_result_queue: List[str] = [] - - def visit_FunctionDef(self, node: ast.FunctionDef) -> str: - run_body: str = "" - serializer = cppast.Serializer() - for statement in node.body: - run_body += serializer.serialize(self.visit(statement)) - - return run_body - - def visit_Assign(self, node: ast.Assign) -> cppast.Stmt: - target = node.targets[0] - - if isinstance(node.value, ast.Call): - name: str = visitors_util.get_node_name(node.value.func) - - # Create Timer object - if name == "Timer": - decltype = cppast.ClassType("Kokkos::Timer") - declname = cppast.DeclRefExpr("timer") - return cppast.DeclStmt(cppast.VarDecl(decltype, declname, None)) - - # Call Timer.seconds() - if name == "seconds": - target_name: str = visitors_util.get_node_name(target) - if target_name not in self.timer_result_queue: - self.timer_result_queue.append(target_name) - - call = cppast.CallStmt(self.visit(node.value)) - target_ref = cppast.DeclRefExpr(target_name) - target_view_ref = cppast.DeclRefExpr(f"timer_result_{target_name}") - subscript = cppast.ArraySubscriptExpr( - target_view_ref, [cppast.IntegerLiteral(0)] - ) - assign_op = cppast.BinaryOperatorKind.Assign - - # Holds the result of the reduction temporarily - temp_ref = cppast.DeclRefExpr("pk_acc") - target_assign = cppast.AssignOperator([target_ref], temp_ref, assign_op) - view_assign = cppast.AssignOperator([subscript], target_ref, assign_op) - - return cppast.CompoundStmt([call, target_assign, view_assign]) - - if name in ("BinSort", "BinOp1D", "BinOp3D"): - args: List = node.value.args - # if not isinstance(args[0], ast.Attribute): - # self.error(node.value, "First argument has to be a view") - - view = cppast.DeclRefExpr(visitors_util.get_node_name(args[0])) - if view not in self.views: - self.error(args[0], "Undefined view") - - view_type: cppast.ClassType = self.views[view] - is_subview: bool = view_type is None - if is_subview: - parent_view_name: str = self.subviews[view.declname] - - # Need to remove "pk_d_" from the start of the - # view name to get the type of the parent - if parent_view_name.startswith("pk_d_"): - parent_view_name = parent_view_name.replace("pk_d_", "", 1) - parent_view = cppast.DeclRefExpr(parent_view_name) - view_type = self.views[parent_view] - - view_type_str: str = visitors_util.cpp_view_type(view_type) - - if name != "BinSort": - dimension: int = 1 if name == "BinOp1D" else 3 - cpp_type = cppast.DeclRefExpr( - BinOp.get_type(dimension, view_type_str) - ) - - # Do not translate the first argument (view) - constructor = cppast.CallExpr( - cpp_type, [self.visit(a) for a in args[1:]] - ) - - else: - bin_op_type: str = ( - f"decltype({visitors_util.get_node_name(args[1])})" - ) - - binsort_args: List[cppast.DeclRefExpr] = [ - self.visit(a) for a in args - ] - cpp_type = cppast.DeclRefExpr( - BinSort.get_type( - f"decltype({binsort_args[0].declname})", - bin_op_type, - Keywords.DefaultExecSpace.value, - ) - ) - constructor = cppast.CallExpr(cpp_type, binsort_args) - - cpp_target: cppast.DeclRefExpr = self.visit(target) - auto_type = cppast.ClassType("auto") - - return cppast.DeclStmt( - cppast.VarDecl(auto_type, cpp_target, constructor) - ) - - if name in ("get_bin_count", "get_bin_offsets", "get_permute_vector"): - if not isinstance(target, ast.Attribute) or target.value.id != "self": - self.error( - node, "Views defined in pk.main must be an instance variable" - ) - - cpp_target: str = visitors_util.get_node_name(target) - cpp_device_target = f"pk_d_{cpp_target}" - cpp_target_ref = cppast.DeclRefExpr(cpp_device_target) - sorter: cppast.DeclRefExpr = self.visit(node.value.func.value) - - initial_target_ref = cppast.DeclRefExpr( - f"_pk_{cpp_target_ref.declname}" - ) - - function = cppast.MemberCallExpr(sorter, cppast.DeclRefExpr(name), []) - - # Add to the dict of declarations made in pk.main - if name == "get_permute_vector": - # This occurs when a workload is executed multiple times - # Initially the view has not been defined in the workload, - # so it needs to be classified as a pkmain_view. - if cpp_target in self.views: - self.views[cpp_target_ref].add_template_param( - cppast.PrimitiveType(cppast.BuiltinType.INT) - ) - - return cppast.AssignOperator( - [cpp_target_ref], function, cppast.BinaryOperatorKind.Assign - ) - # return f"{cpp_target} = {sorter}.{name}();" - - self.pkmain_views[cpp_target_ref] = cppast.ClassType("View1D") - else: - self.pkmain_views[cpp_target_ref] = None - - auto_type = cppast.ClassType("auto") - decl = cppast.DeclStmt( - cppast.VarDecl(auto_type, initial_target_ref, function) - ) - - # resize the workload's vector to match the generated vector - resize_call = cppast.CallStmt( - cppast.CallExpr( - cppast.DeclRefExpr("Kokkos::resize"), - [ - cpp_target_ref, - cppast.MemberCallExpr( - initial_target_ref, - cppast.DeclRefExpr("extent"), - [cppast.IntegerLiteral(0)], - ), - ], - ) - ) - - copy_call = cppast.CallStmt( - cppast.CallExpr( - cppast.DeclRefExpr("Kokkos::deep_copy"), - [cpp_target_ref, initial_target_ref], - ) - ) - - # Assign to the functor after resizing - functor = cppast.DeclRefExpr("pk_f") - functor_access = cppast.MemberExpr(functor, cpp_target) - functor_assign = cppast.AssignOperator( - [functor_access], cpp_target_ref, cppast.BinaryOperatorKind.Assign - ) - - return cppast.CompoundStmt( - [decl, resize_call, copy_call, functor_assign] - ) - - # Assign result of parallel_reduce - if type(target) not in {ast.Name, ast.Subscript} and target.value.id == "self": - target_name: str = visitors_util.get_node_name(target) - if target_name not in self.reduction_result_queue: - self.reduction_result_queue.append(target_name) - - call = cppast.CallStmt(self.visit(node.value)) - target_ref = cppast.DeclRefExpr(target_name) - target_view_ref = cppast.DeclRefExpr(f"reduction_result_{target_name}") - subscript = cppast.ArraySubscriptExpr( - target_view_ref, [cppast.IntegerLiteral(0)] - ) - assign_op = cppast.BinaryOperatorKind.Assign - - # Holds the result of the reduction temporarily - temp_ref = cppast.DeclRefExpr("pk_acc") - target_assign = cppast.AssignOperator([target_ref], temp_ref, assign_op) - view_assign = cppast.AssignOperator([subscript], target_ref, assign_op) - - return cppast.CompoundStmt([call, target_assign, view_assign]) - - return super().visit_Assign(node) - - def visit_Attribute(self, node: ast.Attribute) -> cppast.DeclRefExpr: - name: str = visitors_util.get_node_name(node) - if name in self.work_units: - return cppast.DeclRefExpr(name) - - if node.value.id == "self": - if name in self.views: - return name - - return cppast.DeclRefExpr(name) - - return super().visit_Attribute(node) - - def visit_Lambda(self, node: ast.Lambda) -> cppast.Expr: - # NOTE: should handle args, kwonlyargs, varargs, kwargs, defaults - return self.visit(node.body) - - def visit_Subscript( - self, node: ast.Subscript - ) -> Union[cppast.ArraySubscriptExpr, cppast.CallExpr]: - call: Union[cppast.ArraySubscriptExpr, cppast.CallExpr] = ( - super().visit_Subscript(node) - ) - if isinstance(call, cppast.CallExpr): - view_name: str = call.function.declname - call._function._declname = f"pk_d_{view_name}" - - return call - - def visit_Call(self, node: ast.Call) -> Union[cppast.Expr, cppast.Stmt]: - name: str = visitors_util.get_node_name(node.func) - args: List[cppast.Expr] = [self.visit(a) for a in node.args] - - # Add pk_d_ before each view name to match mirror view names - s = cppast.Serializer() - for i in range(len(args)): - if args[i] in self.views: - if self.views[args[i]] is not None: - view: str = s.serialize(args[i]) - args[i] = cppast.DeclRefExpr(f"pk_d_{view}") - - # Nested parallelism - if name == "TeamPolicy": - function = cppast.DeclRefExpr(f"Kokkos::{name}") - if len(args) == 2: - args.append(cppast.IntegerLiteral(1)) - - policy = cppast.ConstructExpr(function, args) - - return policy - - if name in dir(TeamPolicy): - team_policy = self.visit(node.func.value) - return cppast.MemberCallExpr(team_policy, cppast.DeclRefExpr(name), args) - - elif name in ["RangePolicy", "MDRangePolicy"]: - rank = len(node.args[0].elts) - if rank == 0: - self.error(node.value, "RangePolicy dimension must be greater than 0") - if rank != len(node.args[1].elts): - self.error(node.value, "RangePolicy dimension mismatch") - - iter_outer = Iterate.Default - iter_inner = Iterate.Default - for keyword in node.keywords: - if keyword.arg == "rank": - explicit_rank = keyword.value.args[0].value - if explicit_rank != rank: - self.error(node.value, "RangePolicy dimension mismatch") - - iter_outer = getattr(Iterate, keyword.value.args[1].attr) - iter_inner = getattr(Iterate, keyword.value.args[2].attr) - - function = cppast.DeclRefExpr( - f"Kokkos::{name}>" - ) - policy = cppast.ConstructExpr(cppast.DeclRefExpr(f"Kokkos::{name}"), args) - if name == "MDRangePolicy": - policy.add_template_param( - cppast.DeclRefExpr( - f"Kokkos::Rank<{rank},{iter_outer.value},{iter_inner.value}>" - ) - ) - - return policy - - if name == "seconds": - fence = cppast.CallStmt( - cppast.CallExpr(cppast.DeclRefExpr("Kokkos::fence"), []) - ) - temp_decl = cppast.DeclRefExpr("pk_acc") - seconds = cppast.MemberCallExpr( - cppast.DeclRefExpr("timer"), cppast.DeclRefExpr("seconds"), [] - ) - result = cppast.AssignOperator( - [temp_decl], seconds, cppast.BinaryOperatorKind.Assign - ) - - return cppast.CompoundStmt([fence, result]) - - function = cppast.DeclRefExpr(f"Kokkos::{name}") - if name == "parallel_for": - arg_start: int = 0 # Accounts for the optional kernel name - kernel_name: Optional[cppast.StringLiteral] = None - if isinstance(args[0], cppast.StringLiteral): - kernel_name = args[0] - arg_start = 1 - - policy: cppast.ConstructExpr = args[arg_start] - policy = self.add_space_to_policy(policy) - - if isinstance(node.args[arg_start + 1], ast.Lambda): - decl: str = "KOKKOS_LAMBDA (" - tid = cppast.DeclRefExpr(node.args[arg_start + 1].args.args[0].arg) - - # if target exists - if len(args) == arg_start + 3: - target = cppast.ArraySubscriptExpr(args[arg_start + 2], [tid]) - args[arg_start + 1] = cppast.AssignOperator( - [target], args[arg_start + 1], cppast.BinaryOperatorKind.Assign - ) - - serializer = cppast.Serializer() - decl += f"int {tid.declname}) {{" - decl += serializer.serialize(args[arg_start + 1]) + ";}\n" - - call_args: List[cppast.Expr] = [policy, decl] - if kernel_name is not None: - call_args.insert(0, kernel_name) - - return cppast.CallExpr(function, call_args) - - else: - work_unit: str = args[arg_start + 1].declname - policy = self.add_workunit_to_policy(policy, work_unit) - - call_args: List[cppast.Expr] = [policy, cppast.DeclRefExpr("pk_f")] - if kernel_name is not None: - call_args.insert(0, kernel_name) - - return cppast.CallExpr(function, call_args) - - if name in ("parallel_reduce", "parallel_scan"): - arg_start: int = 0 # Accounts for the optional kernel name - kernel_name: Optional[cppast.StringLiteral] = None - if isinstance(args[0], cppast.StringLiteral): - kernel_name = args[0] - arg_start = 1 - - initial_value: cppast.Expr - if len(args) == arg_start + 3: - initial_value = args[arg_start + 2] - else: - initial_value = cppast.IntegerLiteral(0) - - acc_decl = cppast.DeclRefExpr("pk_acc") - init_var = cppast.BinaryOperator( - acc_decl, initial_value, cppast.BinaryOperatorKind.Assign - ) - - policy: cppast.ConstructExpr = args[arg_start] - policy = self.add_space_to_policy(policy) - - if isinstance(node.args[arg_start + 1], ast.Lambda): - decl: str = "KOKKOS_LAMBDA (" - tid = cppast.DeclRefExpr(node.args[arg_start + 1].args.args[0].arg) - acc = cppast.DeclRefExpr(node.args[arg_start + 1].args.args[1].arg) - - # assign to accumulator - args[arg_start + 1] = cppast.AssignOperator( - [acc], args[arg_start + 1], cppast.BinaryOperatorKind.Assign - ) - - serializer = cppast.Serializer() - decl += f"int {tid.declname}, double& {acc.declname}) {{" - decl += serializer.serialize(args[arg_start + 1]) + ";}\n" - - call_args: List[cppast.Expr] = [policy, decl, acc_decl] - if kernel_name is not None: - call_args.insert(0, kernel_name) - - call = cppast.CallExpr(function, call_args) - - else: - work_unit: str = args[arg_start + 1].declname - policy = self.add_workunit_to_policy(policy, work_unit) - - call_args: List[cppast.Expr] = [ - policy, - cppast.DeclRefExpr("pk_f"), - acc_decl, - ] - if kernel_name is not None: - call_args.insert(0, kernel_name) - - return cppast.CallExpr(function, call_args) - - return cppast.BinaryOperator( - init_var, call, cppast.BinaryOperatorKind.Comma - ) - - if name in dir(BinSort): - sorter: str = visitors_util.get_node_name(node.func.value) - sorter_ref = cppast.DeclRefExpr(sorter) - function = cppast.DeclRefExpr(name) - - return cppast.MemberCallExpr(sorter_ref, function, args) - - if name == "shmem_size": - if len(args) != 1: - self.error(node.func, "shmem_size() accepts only a single argument") - - func: ast.Attribute = node.func - cpp_view_type: str = self.get_scratch_view_type(func.value) - - if cpp_view_type is None: - self.error(func, "Wrong call to shmem_size()") - - view_type = cppast.ClassType(cpp_view_type) - call_expr = cppast.MemberCallExpr(view_type, name, args) - call_expr.is_static = True - - return call_expr - - return super().visit_Call(node) - - def visit_Constant(self, node: ast.Constant) -> cppast.DeclRefExpr: - if isinstance(node.value, str) and node.value == "auto": - return cppast.DeclRefExpr("Kokkos::AUTO") - - return super().visit_Constant(node) - - def add_space_to_policy( - self, policy: Union[cppast.ConstructExpr, cppast.MemberCallExpr] - ) -> Union[cppast.ConstructExpr, cppast.MemberCallExpr]: - """ - Add the execution space to the execution policy - - :param policy: the execution policy (could also be an integer) - :returns: the execution policy - """ - - # Replace the number of threads with a RangePolicy - if type(policy) not in (cppast.ConstructExpr, cppast.MemberCallExpr): - begin = cppast.IntegerLiteral(0) - policy = cppast.ConstructExpr( - cppast.DeclRefExpr("Kokkos::RangePolicy"), [begin, policy] - ) - - space = cppast.DeclRefExpr(Keywords.DefaultExecSpace.value) - policy_constructor = self.get_policy_constructor(policy) - policy_constructor.add_template_param(space) - - return policy - - def add_workunit_to_policy( - self, policy: Union[cppast.ConstructExpr, cppast.MemberCallExpr], work_unit: str - ) -> Union[cppast.ConstructExpr, cppast.MemberCallExpr]: - """ - Add the workunit tag to the execution policy - - :param policy: the execution policy (could also be an integer) - :param work_unit: the tag of the workunit - :returns: the execution policy - """ - - policy_constructor = self.get_policy_constructor(policy) - policy_constructor.add_template_param( - cppast.DeclRefExpr(f"{self.functor}::{work_unit}_tag") - ) - - return policy - - def get_policy_constructor( - self, policy: Union[cppast.ConstructExpr, cppast.MemberCallExpr] - ) -> cppast.ConstructExpr: - """ - Get the call to the policy constructor from the policy object - - :param: the policy object - :returns: the call to the constructor - """ - - if isinstance(policy, cppast.MemberCallExpr): - return policy.base - else: - return policy - - def generate_subview(self, node: ast.Assign, view_name: str) -> cppast.DeclStmt: - """ - Generate a subview in pk.main. This involves adding the - "pk_d_" prefix to the parent view. - """ - - return super().generate_subview(node, f"pk_d_{view_name}") From 2c69e12c6593cebcaf15f6cc44fbdb0b39d0632a Mon Sep 17 00:00:00 2001 From: gkk345 <73120774+kennykos@users.noreply.github.com> Date: Wed, 4 Mar 2026 17:24:11 -0600 Subject: [PATCH 04/10] Remove run workload debug function --- pykokkos/core/run_debug.py | 40 -------------------------------------- pykokkos/core/runtime.py | 22 +-------------------- 2 files changed, 1 insertion(+), 61 deletions(-) diff --git a/pykokkos/core/run_debug.py b/pykokkos/core/run_debug.py index 5cf87ed0..95a50616 100644 --- a/pykokkos/core/run_debug.py +++ b/pykokkos/core/run_debug.py @@ -17,46 +17,6 @@ import pykokkos.kokkos_manager as km -def run_workload_debug(workload: object) -> None: - """ - Run a workload in Python - - :param workload: the workload object - """ - - workload_source: str = inspect.getsource(type(workload)) - tree: ast.Module = ast.parse(workload_source) - classdef: ast.ClassDef = tree.body[0] - - def get_annotated_functions(decorator: Decorator) -> Dict[str, ast.FunctionDef]: - visitor = ast.NodeVisitor() - functions: Dict[str, ast.FunctionDef] = {} - - def visit_FunctionDef(node): - if node.decorator_list: - try: - node_decorator: str = node.decorator_list[0].id - except AttributeError: - node_decorator: str = node.decorator_list[0].attr - - if decorator.value == node_decorator: - functions[node.name] = node - - visitor.visit_FunctionDef = visit_FunctionDef - for method in classdef.body: - visitor.visit(method) - - return functions - - for name in get_annotated_functions(Decorator.KokkosMain): - kokkos_main = getattr(workload, name) - kokkos_main() - - for name in get_annotated_functions(Decorator.KokkosCallback): - kokkos_callback = getattr(workload, name) - kokkos_callback() - - def call_workunit( operation: str, workunit: Callable[..., None], diff --git a/pykokkos/core/runtime.py b/pykokkos/core/runtime.py index c5f76906..ed7bc262 100644 --- a/pykokkos/core/runtime.py +++ b/pykokkos/core/runtime.py @@ -39,7 +39,7 @@ from .compiler import Compiler from .module_setup import EntityMetadata, get_metadata, ModuleSetup -from .run_debug import run_workload_debug, run_workunit_debug +from .run_debug import run_workunit_debug def _calculate_aligned_scratch_size( @@ -124,26 +124,6 @@ def __init__(self): self.fusion_strategy: Optional[str] = os.getenv("PK_FUSION") - def run_workload(self, space: ExecutionSpace, workload: object) -> None: - """ - Run the workload - - :param space: the execution space of the workload - :param workload: the workload object - """ - - if self.is_debug(space): - run_workload_debug(workload) - return - - module_setup: ModuleSetup = self.get_module_setup(workload, space) - members: PyKokkosMembers = self.compiler.compile_object( - module_setup, space, km.is_uvm_enabled(), None, None, None, set() - ) - - self.execute(workload, module_setup, members, space) - self.run_callbacks(workload, members) - def precompile_workunit( self, workunit: Callable[..., None], From 50cef318b2d4a6817989f42af221e36b1686a0bf Mon Sep 17 00:00:00 2001 From: Gabriel Kosmacher <73120774+kennykos@users.noreply.github.com> Date: Wed, 4 Mar 2026 17:41:45 -0600 Subject: [PATCH 05/10] Remove workload functions form runtime --- pykokkos/core/cpp_setup.py | 2 +- pykokkos/core/module_setup.py | 4 +- pykokkos/core/runtime.py | 173 ++-------------------------------- 3 files changed, 12 insertions(+), 167 deletions(-) diff --git a/pykokkos/core/cpp_setup.py b/pykokkos/core/cpp_setup.py index 350f607b..a85d8291 100644 --- a/pykokkos/core/cpp_setup.py +++ b/pykokkos/core/cpp_setup.py @@ -187,7 +187,7 @@ def generate_cmake( Copy CMakeLists.txt template and prepare CMake configuration variables :param output_dir: the base directory - :param space: the execution space of the workload + :param space: the execution space of the workunit :param enable_uvm: whether to enable CudaUVMSpace :param compiler: what compiler to use :returns: tuple of (cmake_args, module_name) diff --git a/pykokkos/core/module_setup.py b/pykokkos/core/module_setup.py index b05aabb5..23c28925 100644 --- a/pykokkos/core/module_setup.py +++ b/pykokkos/core/module_setup.py @@ -45,7 +45,7 @@ def get_metadata(entity: Union[Callable[..., None], object]) -> EntityMetadata: """ Gets the name and filepath of an entity - :param entity: the workload or workunit function object + :param entity: the workunit function object :returns: an EntityMetadata object """ @@ -88,7 +88,7 @@ def __init__( """ ModuleSetup constructor - :param entity: the functor/workunit/workload or list of workunits for fusion + :param entity: the functor/workunit or list of workunits for fusion :param types_signature: hash/string to identify workunit signature against types :param restricted_views: a set of view names that do not alias any other views """ diff --git a/pykokkos/core/runtime.py b/pykokkos/core/runtime.py index ed7bc262..ccc5ae3c 100644 --- a/pykokkos/core/runtime.py +++ b/pykokkos/core/runtime.py @@ -112,14 +112,14 @@ def apply_scratch_spec(workunit: Callable, policy: TeamPolicy, **kwargs) -> None class Runtime: """ - Executes (and optionally compiles) PyKokkos workloads + Executes (and optionally compiles) PyKokkos workunits/functors """ def __init__(self): self.compiler: Compiler = Compiler() self.tracer: Tracer = Tracer() - # cache module_setup objects using a workload/workunit and space tuple + # cache module_setup objects using a workunit and space tuple self.module_setups: Dict[Tuple, ModuleSetup] = {} self.fusion_strategy: Optional[str] = os.getenv("PK_FUSION") @@ -390,7 +390,7 @@ def execute( """ Imports the module containing the bindings and executes the necessary function - :param entity: the workload or workunit object + :param entity: the workunit object :param module_path: the path to the compiled module :param members: a collection of PyKokkos related members :param space: the execution space @@ -398,7 +398,7 @@ def execute( :param name: the name of the kernel :param operation: the name of the operation "for", "reduce", or "scan" :param kwargs: the keyword arguments passed to the workunit - :returns: the result of the operation (None for "for" and workloads) + :returns: the result of the operation (None for "for") """ module_path: str @@ -460,7 +460,7 @@ def get_arguments( """ Get the arguments for a wrapper function, including fields, views, etc - :param entity: the workload or workunit object + :param entity: the workunit object :param members: a collection of PyKokkos related members :param space: the execution space :param policy: the execution policy of the operation @@ -513,7 +513,7 @@ def call_wrapper( """ Call the wrapper in the imported module - :param entity: the workload or workunit object + :param entity: the workunit object :param members: a collection of PyKokkos related members :param args: the arguments to be passed to the wrapper :param module: the imported module @@ -572,112 +572,6 @@ def get_precision(self, members: PyKokkosMembers, args: Dict[str, Any]) -> str: return precision - def get_result_arguments(self, members: PyKokkosMembers) -> Dict[str, Any]: - """ - Get the views that are passed as arguments to hold the results for workloads - - :param members: a collection of PyKokkos related members - :returns: a dictionary of argument name to value - """ - - args: Dict[str, Any] = {} - - for result in members.reduction_result_queue: - name: str = f"reduction_result_{result}" - result_view = View([1], DataType.double, MemorySpace.HostSpace) - args[name] = result_view.array - - for result in members.timer_result_queue: - name: str = f"timer_result_{result}" - result_view = View([1], DataType.double, MemorySpace.HostSpace) - args[name] = result_view.array - - return args - - def get_policy_arguments(self, policy: ExecutionPolicy) -> Dict[str, Any]: - """ - Get the arguments that are used for to hold the results for workloads - - :param policy: the execution policy of the operation - :returns: a dictionary of argument name to value - """ - - args: Dict[str, Any] = {} - - args["pk_exec_space_instance"] = policy.space.instance - - if isinstance(policy, RangePolicy): - args["pk_threads_begin"] = policy.begin - args["pk_threads_end"] = policy.end - elif isinstance(policy, TeamPolicy): - args["pk_league_size"] = policy.league_size - args["pk_team_size"] = policy.team_size - args["pk_vector_length"] = policy.vector_length - - # Add scratch size information if it was set, otherwise use -1 to indicate not set - if policy.scratch_size_level is not None: - args["pk_scratch_size_level"] = policy.scratch_size_level - # Extract the actual size value from PerTeam/PerThread wrapper if present - from pykokkos.interface.hierarchical import PerTeam, PerThread - - if isinstance(policy.scratch_size_value, PerTeam): - # PerTeam wrapper - extract the value and set flag - args["pk_scratch_size_is_per_team"] = True - args["pk_scratch_size_value"] = policy.scratch_size_value.value - elif isinstance(policy.scratch_size_value, PerThread): - # PerThread wrapper - extract the value and set flag - args["pk_scratch_size_is_per_team"] = False - args["pk_scratch_size_value"] = policy.scratch_size_value.value - elif isinstance(policy.scratch_size_value, (int, np.integer)): - # Direct size value (workunit case without wrapper) - args["pk_scratch_size_is_per_team"] = ( - True # Default to PerTeam for simple int - ) - args["pk_scratch_size_value"] = int(policy.scratch_size_value) - else: - # Unknown type, treat as PerTeam with value from variable - args["pk_scratch_size_is_per_team"] = True - args["pk_scratch_size_value"] = policy.scratch_size_value - else: - # No scratch size set, use -1 as indicator - args["pk_scratch_size_level"] = -1 - args["pk_scratch_size_value"] = 0 - args["pk_scratch_size_is_per_team"] = True - - return args - - def get_fields(self, members: Dict[str, type]) -> Dict[str, Any]: - """ - Gets all the primitive type fields from the workload object - - :param workload: the dictionary containing all members - :returns: a dict mapping from field name to value - """ - - fields: Dict[str, Any] = {} - for key, value in members.items(): - if type(value) in ( - int, - float, - bool, - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, - np.float32, - np.double, - np.float64, - ): - fields[key] = value - if isinstance(value, Future): - fields[key] = value.value - - return fields - def _convert_functor_arrays(self, members: Dict[str, Any]) -> None: """ Convert numpy/cupy arrays in functor members to Views (similar to convert_arrays for kwargs) @@ -703,55 +597,6 @@ def _convert_functor_arrays(self, members: Dict[str, Any]) -> None: elif cp_available and isinstance(v, cp.ndarray): members[k] = array(v) - def get_views(self, members: Dict[str, type]) -> Dict[str, Any]: - """ - Gets all the views from the workload object - - :param workload: the dictionary containing all members - :returns: a dict mapping from view name to object - """ - - views: Dict[str, Any] = {} - for key, value in members.items(): - if isinstance(value, ViewType): - views[key] = value.array - - return views - - def retrieve_results( - self, workload: object, members: PyKokkosMembers, args: Dict[str, Any] - ) -> None: - """ - Get the results for workloads - - :param workload: the workload object - :param members: a collection of PyKokkos related members - :param args: the arguments passed to the wrapper, including views that hold results - """ - - for result in members.reduction_result_queue: - name: str = f"reduction_result_{result}" - view: View = args[name] - setattr(workload, result, view[0]) - - for result in members.timer_result_queue: - name: str = f"timer_result_{result}" - view: View = args[name] - setattr(workload, result, view[0]) - - def run_callbacks(self, workload: object, members: PyKokkosMembers) -> None: - """ - Run all methods in the workload that are annotated with @pk.callback - - :param workload: the workload object - :param members: a collection of PyKokkos related members in workload - """ - - callbacks = members.pk_callbacks - for name in callbacks: - callback = getattr(workload, name.declname) - callback() - def get_module_setup( self, entity: Union[object, Callable[..., None]], @@ -762,7 +607,7 @@ def get_module_setup( """ Get the compiled module setup information unique to an entity + space - :param entity: the workload or workunit object + :param entity: the workunit object :param space: the execution space :param types_signature: Hash/identifer string for workunit module against data types :param restrict_signature: Hash/identifer string for views that do not alias any other views @@ -795,10 +640,10 @@ def get_module_setup_id( """ Get a unique module setup id for an entity + space combination. For workunits, the idenitifier is just the - workunit and execution space. For workloads and functors, we + workunit and execution space. For functors, we need the type of the class as well as the file containing it. - :param entity: the workload or workunit object + :param entity: the workunit object :param space: the execution space :param types_signature: optional identifier/hash string for types of parameters against workunit module From 497babceb4cb80a97e16c46a00289ab4befcc035 Mon Sep 17 00:00:00 2001 From: Gabriel Kosmacher <73120774+kennykos@users.noreply.github.com> Date: Wed, 4 Mar 2026 17:47:03 -0600 Subject: [PATCH 06/10] Remove main and callback decorators --- pykokkos/core/translators/members.py | 9 --------- pykokkos/core/translators/static.py | 2 -- pykokkos/interface/__init__.py | 2 -- pykokkos/interface/decorators.py | 18 ------------------ 4 files changed, 31 deletions(-) diff --git a/pykokkos/core/translators/members.py b/pykokkos/core/translators/members.py index a5f5cf2b..8e81659b 100644 --- a/pykokkos/core/translators/members.py +++ b/pykokkos/core/translators/members.py @@ -30,8 +30,6 @@ def __init__(self): self.pk_workunits: Dict[cppast.DeclRefExpr, ast.FunctionDef] = {} self.pk_functions: Dict[cppast.DeclRefExpr, ast.FunctionDef] = {} - self.pk_mains: Dict[cppast.DeclRefExpr, ast.FunctionDef] = {} - self.pk_callbacks: Dict[cppast.DeclRefExpr, ast.FunctionDef] = {} self.classtype_methods: Dict[cppast.DeclRefExpr, List[cppast.DeclRefExpr]] = {} @@ -102,9 +100,6 @@ def extract(self, entity: PyKokkosEntity, classtypes: List[PyKokkosEntity]) -> N self.pk_functions = self.get_decorated_functions( AST, Decorator.KokkosFunction ) - self.pk_callbacks = self.get_decorated_functions( - AST, Decorator.KokkosCallback - ) else: self.pk_workunits[cppast.DeclRefExpr(AST.name)] = AST self.pk_functions = self.get_decorated_functions( @@ -113,10 +108,6 @@ def extract(self, entity: PyKokkosEntity, classtypes: List[PyKokkosEntity]) -> N self.classtype_methods = self.get_classtype_methods(classtypes) - if len(self.pk_mains) > 1: - print("ERROR: Only one pk.main function can be translated") - sys.exit(1) - def get_fields( self, classdef: ast.ClassDef, source: Tuple[List[str], int], pk_import: str ) -> Dict[cppast.DeclRefExpr, cppast.PrimitiveType]: diff --git a/pykokkos/core/translators/static.py b/pykokkos/core/translators/static.py index 7d524008..7569c769 100644 --- a/pykokkos/core/translators/static.py +++ b/pykokkos/core/translators/static.py @@ -172,8 +172,6 @@ def check_symbols(self, classtypes: List[PyKokkosEntity], path: str) -> None: symbols_pass = SymbolsPass(self.pk_members, self.pk_import, path) error_messages: List[str] = [] - for AST in self.pk_members.pk_mains.values(): - error_messages.extend(symbols_pass.check_symbols(AST)) for AST in self.pk_members.pk_workunits.values(): error_messages.extend(symbols_pass.check_symbols(AST)) for AST in self.pk_members.pk_functions.values(): diff --git a/pykokkos/interface/__init__.py b/pykokkos/interface/__init__.py index 8098b86f..6e99e309 100644 --- a/pykokkos/interface/__init__.py +++ b/pykokkos/interface/__init__.py @@ -38,12 +38,10 @@ complex128, ) from .decorators import ( - callback, classtype, Decorator, function, functor, - main, workunit, ) from .execution_policy import ( diff --git a/pykokkos/interface/decorators.py b/pykokkos/interface/decorators.py index 2a753f92..330c2c74 100644 --- a/pykokkos/interface/decorators.py +++ b/pykokkos/interface/decorators.py @@ -7,8 +7,6 @@ class Decorator(Enum): WorkUnit = "workunit" KokkosClasstype = "classtype" KokkosFunction = "function" - KokkosMain = "main" - KokkosCallback = "callback" Space = "space" @staticmethod @@ -27,14 +25,6 @@ def is_kokkos_classtype(decorator: str) -> bool: def is_kokkos_function(decorator: str) -> bool: return decorator == Decorator.KokkosFunction.value - @staticmethod - def is_kokkos_main(decorator: str) -> bool: - return decorator == Decorator.KokkosMain.value - - @staticmethod - def is_kokkos_callback(decorator: str) -> bool: - return decorator == Decorator.KokkosCallback.value - @staticmethod def is_space(decorator: str) -> bool: return decorator == Decorator.Space.value @@ -74,11 +64,3 @@ def classtype(func): def function(func): return func - - -def main(func): - return func - - -def callback(func): - return func From 4d737d28cf8c4dc1c045603a22920b45fdc7e0c1 Mon Sep 17 00:00:00 2001 From: Gabriel Kosmacher <73120774+kennykos@users.noreply.github.com> Date: Thu, 5 Mar 2026 09:17:24 -0600 Subject: [PATCH 07/10] Remove workload funcs from bindings and runtime --- pykokkos/core/runtime.py | 105 +++++++++++++++++- pykokkos/core/translators/bindings.py | 111 ++------------------ pykokkos/core/visitors/debug_transformer.py | 32 ------ 3 files changed, 106 insertions(+), 142 deletions(-) delete mode 100644 pykokkos/core/visitors/debug_transformer.py diff --git a/pykokkos/core/runtime.py b/pykokkos/core/runtime.py index ccc5ae3c..1232dacc 100644 --- a/pykokkos/core/runtime.py +++ b/pykokkos/core/runtime.py @@ -112,7 +112,7 @@ def apply_scratch_spec(workunit: Callable, policy: TeamPolicy, **kwargs) -> None class Runtime: """ - Executes (and optionally compiles) PyKokkos workunits/functors + Executes (and optionally compiles) PyKokkos workunits """ def __init__(self): @@ -420,10 +420,6 @@ def execute( result = self.call_wrapper(entity, members, args, module) - is_workunit_or_functor: bool = isinstance(entity, (Callable, list)) - if not is_workunit_or_functor: - self.retrieve_results(entity, members, args) - return result def import_module(self, module_name: str, module_path: str): @@ -572,6 +568,90 @@ def get_precision(self, members: PyKokkosMembers, args: Dict[str, Any]) -> str: return precision + def get_policy_arguments(self, policy: ExecutionPolicy) -> Dict[str, Any]: + """ + Get the arguments that are used for to hold the results for workunits + + :param policy: the execution policy of the operation + :returns: a dictionary of argument name to value + """ + + args: Dict[str, Any] = {} + + args["pk_exec_space_instance"] = policy.space.instance + + if isinstance(policy, RangePolicy): + args["pk_threads_begin"] = policy.begin + args["pk_threads_end"] = policy.end + elif isinstance(policy, TeamPolicy): + args["pk_league_size"] = policy.league_size + args["pk_team_size"] = policy.team_size + args["pk_vector_length"] = policy.vector_length + + # Add scratch size information if it was set, otherwise use -1 to indicate not set + if policy.scratch_size_level is not None: + args["pk_scratch_size_level"] = policy.scratch_size_level + # Extract the actual size value from PerTeam/PerThread wrapper if present + from pykokkos.interface.hierarchical import PerTeam, PerThread + + if isinstance(policy.scratch_size_value, PerTeam): + # PerTeam wrapper - extract the value and set flag + args["pk_scratch_size_is_per_team"] = True + args["pk_scratch_size_value"] = policy.scratch_size_value.value + elif isinstance(policy.scratch_size_value, PerThread): + # PerThread wrapper - extract the value and set flag + args["pk_scratch_size_is_per_team"] = False + args["pk_scratch_size_value"] = policy.scratch_size_value.value + elif isinstance(policy.scratch_size_value, (int, np.integer)): + # Direct size value (workunit case without wrapper) + args["pk_scratch_size_is_per_team"] = ( + True # Default to PerTeam for simple int + ) + args["pk_scratch_size_value"] = int(policy.scratch_size_value) + else: + # Unknown type, treat as PerTeam with value from variable + args["pk_scratch_size_is_per_team"] = True + args["pk_scratch_size_value"] = policy.scratch_size_value + else: + # No scratch size set, use -1 as indicator + args["pk_scratch_size_level"] = -1 + args["pk_scratch_size_value"] = 0 + args["pk_scratch_size_is_per_team"] = True + + return args + + def get_fields(self, members: Dict[str, type]) -> Dict[str, Any]: + """ + Gets all the primitive type fields from the workunit object + + :param members: the dictionary containing all members + :returns: a dict mapping from field name to value + """ + + fields: Dict[str, Any] = {} + for key, value in members.items(): + if type(value) in ( + int, + float, + bool, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.float32, + np.double, + np.float64, + ): + fields[key] = value + if isinstance(value, Future): + fields[key] = value.value + + return fields + def _convert_functor_arrays(self, members: Dict[str, Any]) -> None: """ Convert numpy/cupy arrays in functor members to Views (similar to convert_arrays for kwargs) @@ -597,6 +677,21 @@ def _convert_functor_arrays(self, members: Dict[str, Any]) -> None: elif cp_available and isinstance(v, cp.ndarray): members[k] = array(v) + def get_views(self, members: Dict[str, type]) -> Dict[str, Any]: + """ + Gets all the views from the workunit object + + :param workunit: the dictionary containing all members + :returns: a dict mapping from view name to object + """ + + views: Dict[str, Any] = {} + for key, value in members.items(): + if isinstance(value, ViewType): + views[key] = value.array + + return views + def get_module_setup( self, entity: Union[object, Callable[..., None]], diff --git a/pykokkos/core/translators/bindings.py b/pykokkos/core/translators/bindings.py index 591134f1..07d16985 100644 --- a/pykokkos/core/translators/bindings.py +++ b/pykokkos/core/translators/bindings.py @@ -4,22 +4,22 @@ from pykokkos.core import cppast from pykokkos.core.keywords import Keywords -from pykokkos.core.visitors import cpp_view_type, visitors_util +from pykokkos.core.visitors import cpp_view_type, KokkosMainVisitor, visitors_util from pykokkos.interface.data_types import DataType from .members import PyKokkosMembers -def is_hierarchical(workunit: Optional[cppast.MethodDecl]) -> bool: +def is_hierarchical(workunit: cppast.MethodDecl) -> bool: """ Checks if a workunit uses hierarchical parallelism by checking if it has a TeamMember instead of a thread ID - :param workunit: the workunit definition or None for a workload + :param workunit: the workunit definition :returns: true if hierarchical false otherwise """ if workunit is None: - return False + raise ValueError("workunit definition must be passed") # Iterate over each parameter (skipping the tag) for p in workunit.params[1:]: @@ -256,7 +256,7 @@ def get_return_type(operation: str, workunit: cppast.MethodDecl) -> str: """ Get the return type of a binding - :param operation: the type of the operation (for, reduce, scan, or workload) + :param operation: the type of the operation (for, reduce, or scan) :param workunit: the workunit for which the binding is being generated :returns: the return type as a string """ @@ -384,7 +384,7 @@ def generate_wrapper( Generate the wrapper that calls the kernel and its binding :param members: an object containing the fields and views - :param operation: the type of the operation (for, reduce, scan, or workload) + :param operation: the type of the operation (for, reduce, or scan) :param workunit: the workunit for which the binding is being generated :param wrapper: the name of the wrapper :param kernel: the name of the kernel @@ -562,102 +562,3 @@ def bind_workunits( bindings.append(bind_wrappers(module, wrapper_names)) return bindings - - -def bind_main_single( - functor: str, - members: PyKokkosMembers, - source: Tuple[List[str], int], - pk_import: str, - precision: Optional[DataType], -) -> Tuple[str, str]: - """ - Generates the kernel and its python binding - - :param functor: the functor class name - :param members: an object containing the fields and views - :param source: the python source code of the workload - :param pk_import: the pykokkos import alias - :param precision: the precision for which to generate a binding - :returns: a tuple of strings containing the wrapper name, and the kernel and wrapper - """ - - wrapper_name: str = "wrapper" - kernel_name: str = "run" - - real: Optional[str] = None - if precision is not None: - real = visitors_util.view_dtypes[precision.name].value - wrapper_name += f"_{real}" - kernel_name += f"_{real}" - functor += f"<{Keywords.DefaultExecSpace.value},{real}>" - else: - functor += f"<{Keywords.DefaultExecSpace.value}>" - - main: List[str] = translate_mains(source, functor, members, pk_import) - params: Dict[str, str] = get_kernel_params(members, False, real) - - # fall back to the old hard-coded default - # for now--this includes cases where an - # accumulator is not even defined - acc_type = "double" - - for element in source[0]: - # TODO: support more types - if "pk.Acc" in element: - if "pk.int64" in element: - acc_type = "int64_t" - elif "pk.double" in element: - acc_type = "double" - - signature: str = generate_kernel_signature("void", kernel_name, params) - instantiation: str = generate_functor_instance(functor, members) - acc: str = f"{acc_type} {Keywords.Accumulator.value} = 0;" - body: str = "".join(main) - copy_back: str = generate_copy_back(members) - # fence: str = generate_fence_call() - - kernel: str = f"{signature} {{ {instantiation} {acc} {body} {copy_back} }}" - wrapper: str = generate_wrapper( - members, "workload", None, wrapper_name, kernel_name, real - ) - binding: str = f"{kernel} {wrapper}" - - return wrapper_name, binding - - -def bind_main( - functor: str, - members: PyKokkosMembers, - source: Tuple[List[str], int], - pk_import: str, - module: str, -) -> List[str]: - """ - Generates the kernel and its python binding - - :param functor: the functor class name - :param members: an object containing the fields and views - :param source: the python source code of the workload - :param pk_import: the pykokkos import alias - :param module: the name of the generated module - :returns: a list of strings containing the kernel, wrapper, and binding - """ - - bindings: List[str] = [] - wrapper_names: List[str] = [] - if members.has_real: - for d in DataType: - if d is DataType.real: - continue - w, b = bind_main_single(functor, members, source, pk_import, d) - bindings.append(b) - wrapper_names.append(w) - else: - w, b = bind_main_single(functor, members, source, pk_import, None) - bindings.append(b) - wrapper_names.append(w) - - bindings.append(bind_wrappers(module, wrapper_names)) - - return bindings diff --git a/pykokkos/core/visitors/debug_transformer.py b/pykokkos/core/visitors/debug_transformer.py deleted file mode 100644 index e011ed49..00000000 --- a/pykokkos/core/visitors/debug_transformer.py +++ /dev/null @@ -1,32 +0,0 @@ -import ast - -from pykokkos.interface import Decorator - - -# in the Debug ExecSpace we need to wrap all instances of the accumulator -# variable in reduction workloads to work around python's lack of reference -# types for primative numbers -class DebugTransformer(ast.NodeTransformer): - def __init__(self): - self.inside_reduction = False - - def visit_FunctionDef(self, node): - if ( - node.decorator_list - and Decorator.is_work_unit(node.decorator_list[0].id) - and len(node.args.args) == 3 - ): - self.inside_reduction = True - self.acc = node.args.args[-1].arg - node.body = list(map(self.visit, node.body)) - self.inside_reduction = False - return node - - def visit_Name(self, node): - if self.inside_reduction and node.id == self.acc: - return ast.Subscript( - ast.Name(self.acc, ast.Load()), - ast.Index(ast.Constant(0, None)), - node.ctx, - ) - return node From 10c217c1e337f5512ca1a9bfa2f929add9c59182 Mon Sep 17 00:00:00 2001 From: Gabriel Kosmacher <73120774+kennykos@users.noreply.github.com> Date: Thu, 5 Mar 2026 09:19:23 -0600 Subject: [PATCH 08/10] Remove old imports --- pykokkos/core/translators/bindings.py | 2 +- pykokkos/core/translators/static.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pykokkos/core/translators/bindings.py b/pykokkos/core/translators/bindings.py index 07d16985..b3bc30fd 100644 --- a/pykokkos/core/translators/bindings.py +++ b/pykokkos/core/translators/bindings.py @@ -4,7 +4,7 @@ from pykokkos.core import cppast from pykokkos.core.keywords import Keywords -from pykokkos.core.visitors import cpp_view_type, KokkosMainVisitor, visitors_util +from pykokkos.core.visitors import cpp_view_type, visitors_util from pykokkos.interface.data_types import DataType from .members import PyKokkosMembers diff --git a/pykokkos/core/translators/static.py b/pykokkos/core/translators/static.py index 7569c769..b6197918 100644 --- a/pykokkos/core/translators/static.py +++ b/pykokkos/core/translators/static.py @@ -15,7 +15,7 @@ WorkunitVisitor, ) -from .bindings import bind_main, bind_workunits +from .bindings import bind_workunits from .functor import generate_functor from .functor_cast import generate_cast from .members import PyKokkosMembers From 5d66de669574635c1a985fdd5fa80b85f7244ade Mon Sep 17 00:00:00 2001 From: Gabriel Kosmacher <73120774+kennykos@users.noreply.github.com> Date: Thu, 5 Mar 2026 09:25:49 -0600 Subject: [PATCH 09/10] Clarify docs, mark functor funcs to be depreciated --- pykokkos/core/translators/members.py | 17 ++++++++++++----- pykokkos/core/visitors/constructor_visitor.py | 2 ++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/pykokkos/core/translators/members.py b/pykokkos/core/translators/members.py index 8e81659b..b12ac7b5 100644 --- a/pykokkos/core/translators/members.py +++ b/pykokkos/core/translators/members.py @@ -115,9 +115,12 @@ def get_fields( Get all fields (or instance variables) in classdef by parsing the constructor :param classdef: the classdef being parsed - :param source: the python source code of the workload + :param source: the python source code of the functor :param pk_import: the identifier used to access the PyKokkos package :returns: a dictionary mapping from field name to type + + + NOTE: **used by workloads & functors; depreciate with functions** """ visitor = ConstructorVisitor(source, "fields", pk_import, True) @@ -133,9 +136,11 @@ def get_views( Get all views defined in classdef by parsing the constructor :param classdef: the classdef to be parsed - :param source: the python source code of the workload + :param source: the python source code of the functor :param pk_import: the identifier used to access the PyKokkos package :returns: a dictionary mapping from view name to type (only dimensionality and type) + + NOTE: **used by workloads & functors; depreciate with functions** """ visitor = ConstructorVisitor(source, "views", pk_import, True) @@ -179,7 +184,7 @@ def get_params( Gets all fields and views passed as parameters to the workunit :param functiondef: the functiondef to be parsed - :param source: the python source code of the workload + :param source: the python source code of the workunit :param param_begin: where workunit argument begins (excluding tid/acc) """ @@ -198,7 +203,7 @@ def get_view_template_params( Get the template parameters for all views defined in the constructor :param node: the classdef or functiondef to be parsed - :param source: the python source code of the workload + :param source: the python source code of the workunit :returns: a dictionary mapping from view name to a list of template parameters """ @@ -268,9 +273,11 @@ def get_random_pool( Gets the type of the random pool if it exists :param classdef: the classdef to be parsed - :param source: the python source code of the workload + :param source: the python source code of the functor :param pk_import: the identifier used to access the PyKokkos package :returns: the type of the random pool if it exists + + NOTE: **used by workloads & functors; depreciate with functions** """ visitor = ConstructorVisitor(source, "randpool", pk_import, True) diff --git a/pykokkos/core/visitors/constructor_visitor.py b/pykokkos/core/visitors/constructor_visitor.py index 69c94f01..baa4318e 100644 --- a/pykokkos/core/visitors/constructor_visitor.py +++ b/pykokkos/core/visitors/constructor_visitor.py @@ -22,6 +22,8 @@ def __init__( :param member_type: specifies which members to retrieve, "fields", "views", or "typeinfo" :param pk_import: the identifier used to access the PyKokkos package :param debug: if true, prints the python AST when an error is encountered + + NOTE: **used by workloads & functors; depreciate with functions** """ if member_type not in ("fields", "views", "typeinfo", "randpool"): From cd5804d2b5f766c7eef6b6a03f5d7a7db34b5213 Mon Sep 17 00:00:00 2001 From: Gabriel Kosmacher <73120774+kennykos@users.noreply.github.com> Date: Thu, 5 Mar 2026 09:29:15 -0600 Subject: [PATCH 10/10] Remove workload functionality --- pykokkos/core/module_setup.py | 21 +++++++------------ pykokkos/core/parsers/parser.py | 5 +---- pykokkos/core/translators/members.py | 6 +++--- pykokkos/core/translators/static.py | 12 +++++------ pykokkos/core/visitors/__init__.py | 1 - pykokkos/core/visitors/constructor_visitor.py | 2 +- pykokkos/core/visitors/parameter_visitor.py | 2 +- 7 files changed, 19 insertions(+), 30 deletions(-) diff --git a/pykokkos/core/module_setup.py b/pykokkos/core/module_setup.py index 23c28925..0d5cdae0 100644 --- a/pykokkos/core/module_setup.py +++ b/pykokkos/core/module_setup.py @@ -52,22 +52,15 @@ def get_metadata(entity: Union[Callable[..., None], object]) -> EntityMetadata: name: str filepath: str - if isinstance(entity, Callable): - # Workunit/functor - is_functor: bool = hasattr(entity, "__self__") - if is_functor: - entity_type: type = get_functor(entity) - name = entity_type.__name__ - filepath = inspect.getfile(entity_type) - else: - name = entity.__name__ - filepath = inspect.getfile(entity) - - else: - # Workload - entity_type: type = type(entity) + # Workunit/functor + is_functor: bool = hasattr(entity, "__self__") + if is_functor: + entity_type: type = get_functor(entity) name = entity_type.__name__ filepath = inspect.getfile(entity_type) + else: + name = entity.__name__ + filepath = inspect.getfile(entity) return EntityMetadata(entity, name, filepath) diff --git a/pykokkos/core/parsers/parser.py b/pykokkos/core/parsers/parser.py index b910dc4f..adfa2a9b 100644 --- a/pykokkos/core/parsers/parser.py +++ b/pykokkos/core/parsers/parser.py @@ -36,7 +36,7 @@ class PyKokkosEntity: class Parser: """ - Parse a PyKokkos workload and its dependencies + Parse a PyKokkos workunit and its dependencies """ def __init__(self, path: Optional[str], pk_import: Optional[str] = None): @@ -55,7 +55,6 @@ def __init__(self, path: Optional[str], pk_import: Optional[str] = None): self.tree = ast.parse("".join(self.lines)) self.path: Optional[str] = path self.pk_import: str = self.get_import() - self.workloads: Dict[str, PyKokkosEntity] = {} self.classtypes: Dict[str, PyKokkosEntity] = {} self.functors: Dict[str, PyKokkosEntity] = {} self.workunits: Dict[str, PyKokkosEntity] = {} @@ -109,8 +108,6 @@ def get_entity(self, name: str) -> PyKokkosEntity: :returns: the PyKokkosEntity representation of the entity """ - if name in self.workloads: - return self.workloads[name] if name in self.functors: return self.functors[name] if name in self.workunits: diff --git a/pykokkos/core/translators/members.py b/pykokkos/core/translators/members.py index b12ac7b5..45e48882 100644 --- a/pykokkos/core/translators/members.py +++ b/pykokkos/core/translators/members.py @@ -120,7 +120,7 @@ def get_fields( :returns: a dictionary mapping from field name to type - NOTE: **used by workloads & functors; depreciate with functions** + NOTE: **used by workloads & functors; depreciate with functions** """ visitor = ConstructorVisitor(source, "fields", pk_import, True) @@ -140,7 +140,7 @@ def get_views( :param pk_import: the identifier used to access the PyKokkos package :returns: a dictionary mapping from view name to type (only dimensionality and type) - NOTE: **used by workloads & functors; depreciate with functions** + NOTE: **used by workloads & functors; depreciate with functions** """ visitor = ConstructorVisitor(source, "views", pk_import, True) @@ -277,7 +277,7 @@ def get_random_pool( :param pk_import: the identifier used to access the PyKokkos package :returns: the type of the random pool if it exists - NOTE: **used by workloads & functors; depreciate with functions** + NOTE: **used by workloads & functors; depreciate with functions** """ visitor = ConstructorVisitor(source, "randpool", pk_import, True) diff --git a/pykokkos/core/translators/static.py b/pykokkos/core/translators/static.py index b6197918..b7615d59 100644 --- a/pykokkos/core/translators/static.py +++ b/pykokkos/core/translators/static.py @@ -34,7 +34,7 @@ def generate_include_guard_end() -> str: class StaticTranslator: """ - Translates a PyKokkos workload to C++ using static analysis only + Translates a PyKokkos workunit to C++ using static analysis only """ def __init__( @@ -189,9 +189,9 @@ def translate_classtypes( self, classtypes: List[PyKokkosEntity], restrict_views: Set[str] ) -> List[cppast.RecordDecl]: """ - Translate all classtypes, i.e. classes that the workload uses internally + Translate all classtypes, i.e. classes that the workunit uses internally - :param classtypes: the list of classtypes needed by the workload + :param classtypes: the list of classtypes needed by the workunit :param restrict_views: the views with the restrict keyword :returns: a list of strings of translated source code """ @@ -363,7 +363,7 @@ def translate_functions( """ Translate all PyKokkos functions - :param source: the python source code of the workload + :param source: the python source code of the workunit :param restrict_views: the views with the restrict keyword :returns: a list of method declarations """ @@ -400,9 +400,9 @@ def translate_workunits( """ Translate the workunits - :param source: the python source code of the workload + :param source: the python source code of the workunit :param restrict_views: the views with the restrict keyword - :returns: a tuple of a dictionary mapping from workload name + :returns: a tuple of a dictionary mapping from workunit name to a tuple of operation name and source, and a boolean indicating whether the workunit has a call to pk.rand() """ diff --git a/pykokkos/core/visitors/__init__.py b/pykokkos/core/visitors/__init__.py index d2c5e252..a5ab93ee 100644 --- a/pykokkos/core/visitors/__init__.py +++ b/pykokkos/core/visitors/__init__.py @@ -1,6 +1,5 @@ from .classtype_visitor import ClasstypeVisitor from .constructor_visitor import ConstructorVisitor -from .debug_transformer import DebugTransformer from .kokkosfunction_visitor import KokkosFunctionVisitor from .parameter_visitor import ParameterVisitor from .pykokkos_visitor import PyKokkosVisitor diff --git a/pykokkos/core/visitors/constructor_visitor.py b/pykokkos/core/visitors/constructor_visitor.py index baa4318e..f2a31b46 100644 --- a/pykokkos/core/visitors/constructor_visitor.py +++ b/pykokkos/core/visitors/constructor_visitor.py @@ -23,7 +23,7 @@ def __init__( :param pk_import: the identifier used to access the PyKokkos package :param debug: if true, prints the python AST when an error is encountered - NOTE: **used by workloads & functors; depreciate with functions** + NOTE: **used by workloads & functors; depreciate with functions** """ if member_type not in ("fields", "views", "typeinfo", "randpool"): diff --git a/pykokkos/core/visitors/parameter_visitor.py b/pykokkos/core/visitors/parameter_visitor.py index 28a6c005..9f0a028b 100644 --- a/pykokkos/core/visitors/parameter_visitor.py +++ b/pykokkos/core/visitors/parameter_visitor.py @@ -17,7 +17,7 @@ def __init__( """ ParameterVisitor constructor - :param src: the python source code of the workload + :param src: the python source code of the workunit :param param_begin: where workunit argument begins (excluding tid/acc) :param pk_import: the identifier used to access the PyKokkos package :param debug: if true, prints the python AST when an error is encountered