diff --git a/pykokkos/core/cpp_setup.py b/pykokkos/core/cpp_setup.py index 350f607b..a85d8291 100644 --- a/pykokkos/core/cpp_setup.py +++ b/pykokkos/core/cpp_setup.py @@ -187,7 +187,7 @@ def generate_cmake( Copy CMakeLists.txt template and prepare CMake configuration variables :param output_dir: the base directory - :param space: the execution space of the workload + :param space: the execution space of the workunit :param enable_uvm: whether to enable CudaUVMSpace :param compiler: what compiler to use :returns: tuple of (cmake_args, module_name) diff --git a/pykokkos/core/module_setup.py b/pykokkos/core/module_setup.py index b05aabb5..0d5cdae0 100644 --- a/pykokkos/core/module_setup.py +++ b/pykokkos/core/module_setup.py @@ -45,29 +45,22 @@ def get_metadata(entity: Union[Callable[..., None], object]) -> EntityMetadata: """ Gets the name and filepath of an entity - :param entity: the workload or workunit function object + :param entity: the workunit function object :returns: an EntityMetadata object """ name: str filepath: str - if isinstance(entity, Callable): - # Workunit/functor - is_functor: bool = hasattr(entity, "__self__") - if is_functor: - entity_type: type = get_functor(entity) - name = entity_type.__name__ - filepath = inspect.getfile(entity_type) - else: - name = entity.__name__ - filepath = inspect.getfile(entity) - - else: - # Workload - entity_type: type = type(entity) + # Workunit/functor + is_functor: bool = hasattr(entity, "__self__") + if is_functor: + entity_type: type = get_functor(entity) name = entity_type.__name__ filepath = inspect.getfile(entity_type) + else: + name = entity.__name__ + filepath = inspect.getfile(entity) return EntityMetadata(entity, name, filepath) @@ -88,7 +81,7 @@ def __init__( """ ModuleSetup constructor - :param entity: the functor/workunit/workload or list of workunits for fusion + :param entity: the functor/workunit or list of workunits for fusion :param types_signature: hash/string to identify workunit signature against types :param restricted_views: a set of view names that do not alias any other views """ diff --git a/pykokkos/core/parsers/parser.py b/pykokkos/core/parsers/parser.py index f2913c61..adfa2a9b 100644 --- a/pykokkos/core/parsers/parser.py +++ b/pykokkos/core/parsers/parser.py @@ -14,7 +14,6 @@ class PyKokkosStyles(Enum): """ functor = auto() - workload = auto() workunit = auto() classtype = auto() fused = auto() @@ -37,7 +36,7 @@ class PyKokkosEntity: class Parser: """ - Parse a PyKokkos workload and its dependencies + Parse a PyKokkos workunit and its dependencies """ def __init__(self, path: Optional[str], pk_import: Optional[str] = None): @@ -56,12 +55,10 @@ def __init__(self, path: Optional[str], pk_import: Optional[str] = None): self.tree = ast.parse("".join(self.lines)) self.path: Optional[str] = path self.pk_import: str = self.get_import() - self.workloads: Dict[str, PyKokkosEntity] = {} self.classtypes: Dict[str, PyKokkosEntity] = {} self.functors: Dict[str, PyKokkosEntity] = {} self.workunits: Dict[str, PyKokkosEntity] = {} - self.workloads = self.get_entities(PyKokkosStyles.workload) self.classtypes = self.get_entities(PyKokkosStyles.classtype) self.functors = self.get_entities(PyKokkosStyles.functor) self.workunits = self.get_entities(PyKokkosStyles.workunit) @@ -74,7 +71,6 @@ def __init__(self, path: Optional[str], pk_import: Optional[str] = None): self.tree = ast.Module(body=[]) self.path = None self.pk_import = pk_import - self.workloads: Dict[str, PyKokkosEntity] = {} self.classtypes: Dict[str, PyKokkosEntity] = {} self.functors: Dict[str, PyKokkosEntity] = {} self.workunits: Dict[str, PyKokkosEntity] = {} @@ -112,8 +108,6 @@ def get_entity(self, name: str) -> PyKokkosEntity: :returns: the PyKokkosEntity representation of the entity """ - if name in self.workloads: - return self.workloads[name] if name in self.functors: return self.functors[name] if name in self.workunits: @@ -132,9 +126,7 @@ def get_entities(self, style: PyKokkosStyles) -> Dict[str, PyKokkosEntity]: entities: Dict[str, PyKokkosEntity] = {} check_entity: Callable[[ast.stmt], bool] - if style is PyKokkosStyles.workload: - return entities - elif style is PyKokkosStyles.functor: + if style is PyKokkosStyles.functor: check_entity = self.is_functor elif style is PyKokkosStyles.workunit: check_entity = self.is_workunit diff --git a/pykokkos/core/run_debug.py b/pykokkos/core/run_debug.py index 5cf87ed0..95a50616 100644 --- a/pykokkos/core/run_debug.py +++ b/pykokkos/core/run_debug.py @@ -17,46 +17,6 @@ import pykokkos.kokkos_manager as km -def run_workload_debug(workload: object) -> None: - """ - Run a workload in Python - - :param workload: the workload object - """ - - workload_source: str = inspect.getsource(type(workload)) - tree: ast.Module = ast.parse(workload_source) - classdef: ast.ClassDef = tree.body[0] - - def get_annotated_functions(decorator: Decorator) -> Dict[str, ast.FunctionDef]: - visitor = ast.NodeVisitor() - functions: Dict[str, ast.FunctionDef] = {} - - def visit_FunctionDef(node): - if node.decorator_list: - try: - node_decorator: str = node.decorator_list[0].id - except AttributeError: - node_decorator: str = node.decorator_list[0].attr - - if decorator.value == node_decorator: - functions[node.name] = node - - visitor.visit_FunctionDef = visit_FunctionDef - for method in classdef.body: - visitor.visit(method) - - return functions - - for name in get_annotated_functions(Decorator.KokkosMain): - kokkos_main = getattr(workload, name) - kokkos_main() - - for name in get_annotated_functions(Decorator.KokkosCallback): - kokkos_callback = getattr(workload, name) - kokkos_callback() - - def call_workunit( operation: str, workunit: Callable[..., None], diff --git a/pykokkos/core/runtime.py b/pykokkos/core/runtime.py index c5f76906..1232dacc 100644 --- a/pykokkos/core/runtime.py +++ b/pykokkos/core/runtime.py @@ -39,7 +39,7 @@ from .compiler import Compiler from .module_setup import EntityMetadata, get_metadata, ModuleSetup -from .run_debug import run_workload_debug, run_workunit_debug +from .run_debug import run_workunit_debug def _calculate_aligned_scratch_size( @@ -112,38 +112,18 @@ def apply_scratch_spec(workunit: Callable, policy: TeamPolicy, **kwargs) -> None class Runtime: """ - Executes (and optionally compiles) PyKokkos workloads + Executes (and optionally compiles) PyKokkos workunits """ def __init__(self): self.compiler: Compiler = Compiler() self.tracer: Tracer = Tracer() - # cache module_setup objects using a workload/workunit and space tuple + # cache module_setup objects using a workunit and space tuple self.module_setups: Dict[Tuple, ModuleSetup] = {} self.fusion_strategy: Optional[str] = os.getenv("PK_FUSION") - def run_workload(self, space: ExecutionSpace, workload: object) -> None: - """ - Run the workload - - :param space: the execution space of the workload - :param workload: the workload object - """ - - if self.is_debug(space): - run_workload_debug(workload) - return - - module_setup: ModuleSetup = self.get_module_setup(workload, space) - members: PyKokkosMembers = self.compiler.compile_object( - module_setup, space, km.is_uvm_enabled(), None, None, None, set() - ) - - self.execute(workload, module_setup, members, space) - self.run_callbacks(workload, members) - def precompile_workunit( self, workunit: Callable[..., None], @@ -410,7 +390,7 @@ def execute( """ Imports the module containing the bindings and executes the necessary function - :param entity: the workload or workunit object + :param entity: the workunit object :param module_path: the path to the compiled module :param members: a collection of PyKokkos related members :param space: the execution space @@ -418,7 +398,7 @@ def execute( :param name: the name of the kernel :param operation: the name of the operation "for", "reduce", or "scan" :param kwargs: the keyword arguments passed to the workunit - :returns: the result of the operation (None for "for" and workloads) + :returns: the result of the operation (None for "for") """ module_path: str @@ -440,10 +420,6 @@ def execute( result = self.call_wrapper(entity, members, args, module) - is_workunit_or_functor: bool = isinstance(entity, (Callable, list)) - if not is_workunit_or_functor: - self.retrieve_results(entity, members, args) - return result def import_module(self, module_name: str, module_path: str): @@ -480,7 +456,7 @@ def get_arguments( """ Get the arguments for a wrapper function, including fields, views, etc - :param entity: the workload or workunit object + :param entity: the workunit object :param members: a collection of PyKokkos related members :param space: the execution space :param policy: the execution policy of the operation @@ -533,7 +509,7 @@ def call_wrapper( """ Call the wrapper in the imported module - :param entity: the workload or workunit object + :param entity: the workunit object :param members: a collection of PyKokkos related members :param args: the arguments to be passed to the wrapper :param module: the imported module @@ -592,31 +568,9 @@ def get_precision(self, members: PyKokkosMembers, args: Dict[str, Any]) -> str: return precision - def get_result_arguments(self, members: PyKokkosMembers) -> Dict[str, Any]: - """ - Get the views that are passed as arguments to hold the results for workloads - - :param members: a collection of PyKokkos related members - :returns: a dictionary of argument name to value - """ - - args: Dict[str, Any] = {} - - for result in members.reduction_result_queue: - name: str = f"reduction_result_{result}" - result_view = View([1], DataType.double, MemorySpace.HostSpace) - args[name] = result_view.array - - for result in members.timer_result_queue: - name: str = f"timer_result_{result}" - result_view = View([1], DataType.double, MemorySpace.HostSpace) - args[name] = result_view.array - - return args - def get_policy_arguments(self, policy: ExecutionPolicy) -> Dict[str, Any]: """ - Get the arguments that are used for to hold the results for workloads + Get the arguments that are used for to hold the results for workunits :param policy: the execution policy of the operation :returns: a dictionary of argument name to value @@ -668,9 +622,9 @@ def get_policy_arguments(self, policy: ExecutionPolicy) -> Dict[str, Any]: def get_fields(self, members: Dict[str, type]) -> Dict[str, Any]: """ - Gets all the primitive type fields from the workload object + Gets all the primitive type fields from the workunit object - :param workload: the dictionary containing all members + :param members: the dictionary containing all members :returns: a dict mapping from field name to value """ @@ -725,9 +679,9 @@ def _convert_functor_arrays(self, members: Dict[str, Any]) -> None: def get_views(self, members: Dict[str, type]) -> Dict[str, Any]: """ - Gets all the views from the workload object + Gets all the views from the workunit object - :param workload: the dictionary containing all members + :param workunit: the dictionary containing all members :returns: a dict mapping from view name to object """ @@ -738,40 +692,6 @@ def get_views(self, members: Dict[str, type]) -> Dict[str, Any]: return views - def retrieve_results( - self, workload: object, members: PyKokkosMembers, args: Dict[str, Any] - ) -> None: - """ - Get the results for workloads - - :param workload: the workload object - :param members: a collection of PyKokkos related members - :param args: the arguments passed to the wrapper, including views that hold results - """ - - for result in members.reduction_result_queue: - name: str = f"reduction_result_{result}" - view: View = args[name] - setattr(workload, result, view[0]) - - for result in members.timer_result_queue: - name: str = f"timer_result_{result}" - view: View = args[name] - setattr(workload, result, view[0]) - - def run_callbacks(self, workload: object, members: PyKokkosMembers) -> None: - """ - Run all methods in the workload that are annotated with @pk.callback - - :param workload: the workload object - :param members: a collection of PyKokkos related members in workload - """ - - callbacks = members.pk_callbacks - for name in callbacks: - callback = getattr(workload, name.declname) - callback() - def get_module_setup( self, entity: Union[object, Callable[..., None]], @@ -782,7 +702,7 @@ def get_module_setup( """ Get the compiled module setup information unique to an entity + space - :param entity: the workload or workunit object + :param entity: the workunit object :param space: the execution space :param types_signature: Hash/identifer string for workunit module against data types :param restrict_signature: Hash/identifer string for views that do not alias any other views @@ -815,10 +735,10 @@ def get_module_setup_id( """ Get a unique module setup id for an entity + space combination. For workunits, the idenitifier is just the - workunit and execution space. For workloads and functors, we + workunit and execution space. For functors, we need the type of the class as well as the file containing it. - :param entity: the workload or workunit object + :param entity: the workunit object :param space: the execution space :param types_signature: optional identifier/hash string for types of parameters against workunit module diff --git a/pykokkos/core/translators/bindings.py b/pykokkos/core/translators/bindings.py index cbb4d837..b3bc30fd 100644 --- a/pykokkos/core/translators/bindings.py +++ b/pykokkos/core/translators/bindings.py @@ -4,22 +4,22 @@ from pykokkos.core import cppast from pykokkos.core.keywords import Keywords -from pykokkos.core.visitors import cpp_view_type, KokkosMainVisitor, visitors_util +from pykokkos.core.visitors import cpp_view_type, visitors_util from pykokkos.interface.data_types import DataType from .members import PyKokkosMembers -def is_hierarchical(workunit: Optional[cppast.MethodDecl]) -> bool: +def is_hierarchical(workunit: cppast.MethodDecl) -> bool: """ Checks if a workunit uses hierarchical parallelism by checking if it has a TeamMember instead of a thread ID - :param workunit: the workunit definition or None for a workload + :param workunit: the workunit definition :returns: true if hierarchical false otherwise """ if workunit is None: - return False + raise ValueError("workunit definition must be passed") # Iterate over each parameter (skipping the tag) for p in workunit.params[1:]: @@ -256,7 +256,7 @@ def get_return_type(operation: str, workunit: cppast.MethodDecl) -> str: """ Get the return type of a binding - :param operation: the type of the operation (for, reduce, scan, or workload) + :param operation: the type of the operation (for, reduce, or scan) :param workunit: the workunit for which the binding is being generated :returns: the return type as a string """ @@ -384,7 +384,7 @@ def generate_wrapper( Generate the wrapper that calls the kernel and its binding :param members: an object containing the fields and views - :param operation: the type of the operation (for, reduce, scan, or workload) + :param operation: the type of the operation (for, reduce, or scan) :param workunit: the workunit for which the binding is being generated :param wrapper: the name of the wrapper :param kernel: the name of the kernel @@ -562,145 +562,3 @@ def bind_workunits( bindings.append(bind_wrappers(module, wrapper_names)) return bindings - - -def translate_mains( - source: Tuple[List[str], int], - functor: str, - members: PyKokkosMembers, - pk_import: str, -) -> List[str]: - """ - Translate all PyKokkos main functions - - :param source: the python source code of the workload - :param functor: the name of the functor - :param members: an object containing the fields and views - :returns: a list of strings of translated source code - """ - - node_visitor = KokkosMainVisitor( - {}, - source, - members.views, - members.pk_workunits, - members.fields, - members.pk_functions, - members.classtype_methods, - functor, - pk_import, - debug=True, - ) - - translation: List[str] = [] - - for main in members.pk_mains.values(): - try: - translation.append(node_visitor.visit(main)) - except NotImplementedError: - print(f"Translation of {main.name} failed") - sys.exit(1) - - members.reduction_result_queue = node_visitor.reduction_result_queue - members.timer_result_queue = node_visitor.timer_result_queue - - return translation - - -def bind_main_single( - functor: str, - members: PyKokkosMembers, - source: Tuple[List[str], int], - pk_import: str, - precision: Optional[DataType], -) -> Tuple[str, str]: - """ - Generates the kernel and its python binding - - :param functor: the functor class name - :param members: an object containing the fields and views - :param source: the python source code of the workload - :param pk_import: the pykokkos import alias - :param precision: the precision for which to generate a binding - :returns: a tuple of strings containing the wrapper name, and the kernel and wrapper - """ - - wrapper_name: str = "wrapper" - kernel_name: str = "run" - - real: Optional[str] = None - if precision is not None: - real = visitors_util.view_dtypes[precision.name].value - wrapper_name += f"_{real}" - kernel_name += f"_{real}" - functor += f"<{Keywords.DefaultExecSpace.value},{real}>" - else: - functor += f"<{Keywords.DefaultExecSpace.value}>" - - main: List[str] = translate_mains(source, functor, members, pk_import) - params: Dict[str, str] = get_kernel_params(members, False, real) - - # fall back to the old hard-coded default - # for now--this includes cases where an - # accumulator is not even defined - acc_type = "double" - - for element in source[0]: - # TODO: support more types - if "pk.Acc" in element: - if "pk.int64" in element: - acc_type = "int64_t" - elif "pk.double" in element: - acc_type = "double" - - signature: str = generate_kernel_signature("void", kernel_name, params) - instantiation: str = generate_functor_instance(functor, members) - acc: str = f"{acc_type} {Keywords.Accumulator.value} = 0;" - body: str = "".join(main) - copy_back: str = generate_copy_back(members) - # fence: str = generate_fence_call() - - kernel: str = f"{signature} {{ {instantiation} {acc} {body} {copy_back} }}" - wrapper: str = generate_wrapper( - members, "workload", None, wrapper_name, kernel_name, real - ) - binding: str = f"{kernel} {wrapper}" - - return wrapper_name, binding - - -def bind_main( - functor: str, - members: PyKokkosMembers, - source: Tuple[List[str], int], - pk_import: str, - module: str, -) -> List[str]: - """ - Generates the kernel and its python binding - - :param functor: the functor class name - :param members: an object containing the fields and views - :param source: the python source code of the workload - :param pk_import: the pykokkos import alias - :param module: the name of the generated module - :returns: a list of strings containing the kernel, wrapper, and binding - """ - - bindings: List[str] = [] - wrapper_names: List[str] = [] - if members.has_real: - for d in DataType: - if d is DataType.real: - continue - w, b = bind_main_single(functor, members, source, pk_import, d) - bindings.append(b) - wrapper_names.append(w) - else: - w, b = bind_main_single(functor, members, source, pk_import, None) - bindings.append(b) - wrapper_names.append(w) - - bindings.append(bind_wrappers(module, wrapper_names)) - - return bindings diff --git a/pykokkos/core/translators/members.py b/pykokkos/core/translators/members.py index 3d8e8e64..45e48882 100644 --- a/pykokkos/core/translators/members.py +++ b/pykokkos/core/translators/members.py @@ -8,7 +8,6 @@ from pykokkos.core.parsers import PyKokkosEntity, PyKokkosStyles from pykokkos.core.visitors import ( ConstructorVisitor, - KokkosMainVisitor, ParameterVisitor, visitors_util, ) @@ -31,8 +30,6 @@ def __init__(self): self.pk_workunits: Dict[cppast.DeclRefExpr, ast.FunctionDef] = {} self.pk_functions: Dict[cppast.DeclRefExpr, ast.FunctionDef] = {} - self.pk_mains: Dict[cppast.DeclRefExpr, ast.FunctionDef] = {} - self.pk_callbacks: Dict[cppast.DeclRefExpr, ast.FunctionDef] = {} self.classtype_methods: Dict[cppast.DeclRefExpr, List[cppast.DeclRefExpr]] = {} @@ -56,13 +53,7 @@ def extract(self, entity: PyKokkosEntity, classtypes: List[PyKokkosEntity]) -> N source: Tuple[List[str], int] = entity.source pk_import: str = entity.pk_import - if entity.style is PyKokkosStyles.workload: - self.pk_mains = self.get_decorated_functions(AST, Decorator.KokkosMain) - self.fields = self.get_fields(AST, source, pk_import) - self.views = self.get_views(AST, source, pk_import) - self.random_pool = self.get_random_pool(AST, source, pk_import) - - elif entity.style is PyKokkosStyles.functor: + if entity.style is PyKokkosStyles.functor: self.fields = self.get_fields(AST, source, pk_import) self.views = self.get_views(AST, source, pk_import) self.random_pool = self.get_random_pool(AST, source, pk_import) @@ -104,14 +95,11 @@ def extract(self, entity: PyKokkosEntity, classtypes: List[PyKokkosEntity]) -> N if n in self.view_template_params: t.template_params.extend(self.view_template_params[n]) - if entity.style in (PyKokkosStyles.workload, PyKokkosStyles.functor): + if entity.style == PyKokkosStyles.functor: self.pk_workunits = self.get_decorated_functions(AST, Decorator.WorkUnit) self.pk_functions = self.get_decorated_functions( AST, Decorator.KokkosFunction ) - self.pk_callbacks = self.get_decorated_functions( - AST, Decorator.KokkosCallback - ) else: self.pk_workunits[cppast.DeclRefExpr(AST.name)] = AST self.pk_functions = self.get_decorated_functions( @@ -120,16 +108,6 @@ def extract(self, entity: PyKokkosEntity, classtypes: List[PyKokkosEntity]) -> N self.classtype_methods = self.get_classtype_methods(classtypes) - if entity.style is PyKokkosStyles.workload: - name: str = f"pk_functor_{entity.name}" - self.reduction_result_queue, self.timer_result_queue = self.get_queues( - source, name, pk_import - ) - - if len(self.pk_mains) > 1: - print("ERROR: Only one pk.main function can be translated") - sys.exit(1) - def get_fields( self, classdef: ast.ClassDef, source: Tuple[List[str], int], pk_import: str ) -> Dict[cppast.DeclRefExpr, cppast.PrimitiveType]: @@ -137,9 +115,12 @@ def get_fields( Get all fields (or instance variables) in classdef by parsing the constructor :param classdef: the classdef being parsed - :param source: the python source code of the workload + :param source: the python source code of the functor :param pk_import: the identifier used to access the PyKokkos package :returns: a dictionary mapping from field name to type + + + NOTE: **used by workloads & functors; depreciate with functions** """ visitor = ConstructorVisitor(source, "fields", pk_import, True) @@ -155,9 +136,11 @@ def get_views( Get all views defined in classdef by parsing the constructor :param classdef: the classdef to be parsed - :param source: the python source code of the workload + :param source: the python source code of the functor :param pk_import: the identifier used to access the PyKokkos package :returns: a dictionary mapping from view name to type (only dimensionality and type) + + NOTE: **used by workloads & functors; depreciate with functions** """ visitor = ConstructorVisitor(source, "views", pk_import, True) @@ -166,45 +149,6 @@ def get_views( return views - def get_queues( - self, source: Tuple[List[str], int], name: str, pk_import: str - ) -> Tuple[List[str], List[str]]: - """ - Get all fields assigned to a reduction result or timer result - - :param source: the python source code of the workload - :param name: the name of the workload - :param pk_import: the identifier used to access the PyKokkos package - :returns: two lists, one for the reduction results and one for the timer results - """ - - views = copy.deepcopy( - self.views - ) # Needed since KokkosMainVisitor modifies views - - # Copied from translate_mains() in bindings.py - node_visitor = KokkosMainVisitor( - {}, - source, - views, - self.pk_workunits, - self.fields, - self.pk_functions, - self.classtype_methods, - name, - pk_import, - debug=True, - ) - - for main in self.pk_mains.values(): - try: - node_visitor.visit(main) - except NotImplementedError: - print(f"Translation of {main.name} failed") - sys.exit(1) - - return (node_visitor.reduction_result_queue, node_visitor.timer_result_queue) - def get_real_views(self): """ Get all the views that contain a pk.real datatype @@ -240,7 +184,7 @@ def get_params( Gets all fields and views passed as parameters to the workunit :param functiondef: the functiondef to be parsed - :param source: the python source code of the workload + :param source: the python source code of the workunit :param param_begin: where workunit argument begins (excluding tid/acc) """ @@ -259,7 +203,7 @@ def get_view_template_params( Get the template parameters for all views defined in the constructor :param node: the classdef or functiondef to be parsed - :param source: the python source code of the workload + :param source: the python source code of the workunit :returns: a dictionary mapping from view name to a list of template parameters """ @@ -329,9 +273,11 @@ def get_random_pool( Gets the type of the random pool if it exists :param classdef: the classdef to be parsed - :param source: the python source code of the workload + :param source: the python source code of the functor :param pk_import: the identifier used to access the PyKokkos package :returns: the type of the random pool if it exists + + NOTE: **used by workloads & functors; depreciate with functions** """ visitor = ConstructorVisitor(source, "randpool", pk_import, True) diff --git a/pykokkos/core/translators/static.py b/pykokkos/core/translators/static.py index 3feb9294..b7615d59 100644 --- a/pykokkos/core/translators/static.py +++ b/pykokkos/core/translators/static.py @@ -15,7 +15,7 @@ WorkunitVisitor, ) -from .bindings import bind_main, bind_workunits +from .bindings import bind_workunits from .functor import generate_functor from .functor_cast import generate_cast from .members import PyKokkosMembers @@ -34,7 +34,7 @@ def generate_include_guard_end() -> str: class StaticTranslator: """ - Translates a PyKokkos workload to C++ using static analysis only + Translates a PyKokkos workunit to C++ using static analysis only """ def __init__( @@ -172,8 +172,6 @@ def check_symbols(self, classtypes: List[PyKokkosEntity], path: str) -> None: symbols_pass = SymbolsPass(self.pk_members, self.pk_import, path) error_messages: List[str] = [] - for AST in self.pk_members.pk_mains.values(): - error_messages.extend(symbols_pass.check_symbols(AST)) for AST in self.pk_members.pk_workunits.values(): error_messages.extend(symbols_pass.check_symbols(AST)) for AST in self.pk_members.pk_functions.values(): @@ -191,9 +189,9 @@ def translate_classtypes( self, classtypes: List[PyKokkosEntity], restrict_views: Set[str] ) -> List[cppast.RecordDecl]: """ - Translate all classtypes, i.e. classes that the workload uses internally + Translate all classtypes, i.e. classes that the workunit uses internally - :param classtypes: the list of classtypes needed by the workload + :param classtypes: the list of classtypes needed by the workunit :param restrict_views: the views with the restrict keyword :returns: a list of strings of translated source code """ @@ -365,7 +363,7 @@ def translate_functions( """ Translate all PyKokkos functions - :param source: the python source code of the workload + :param source: the python source code of the workunit :param restrict_views: the views with the restrict keyword :returns: a list of method declarations """ @@ -402,9 +400,9 @@ def translate_workunits( """ Translate the workunits - :param source: the python source code of the workload + :param source: the python source code of the workunit :param restrict_views: the views with the restrict keyword - :returns: a tuple of a dictionary mapping from workload name + :returns: a tuple of a dictionary mapping from workunit name to a tuple of operation name and source, and a boolean indicating whether the workunit has a call to pk.rand() """ @@ -511,14 +509,9 @@ def generate_bindings( """ bindings: List[str] - if entity.style is PyKokkosStyles.workload: - bindings = bind_main( - functor_name, self.pk_members, source, self.pk_import, self.module_file - ) - else: - bindings = bind_workunits( - functor_name, self.pk_members, workunits, self.module_file - ) + bindings = bind_workunits( + functor_name, self.pk_members, workunits, self.module_file + ) return bindings diff --git a/pykokkos/core/visitors/__init__.py b/pykokkos/core/visitors/__init__.py index ea01e9a8..a5ab93ee 100644 --- a/pykokkos/core/visitors/__init__.py +++ b/pykokkos/core/visitors/__init__.py @@ -1,8 +1,6 @@ from .classtype_visitor import ClasstypeVisitor from .constructor_visitor import ConstructorVisitor -from .debug_transformer import DebugTransformer from .kokkosfunction_visitor import KokkosFunctionVisitor -from .kokkosmain_visitor import KokkosMainVisitor from .parameter_visitor import ParameterVisitor from .pykokkos_visitor import PyKokkosVisitor from .visitors_util import cpp_view_type, parse_view_template_params diff --git a/pykokkos/core/visitors/constructor_visitor.py b/pykokkos/core/visitors/constructor_visitor.py index 69c94f01..f2a31b46 100644 --- a/pykokkos/core/visitors/constructor_visitor.py +++ b/pykokkos/core/visitors/constructor_visitor.py @@ -22,6 +22,8 @@ def __init__( :param member_type: specifies which members to retrieve, "fields", "views", or "typeinfo" :param pk_import: the identifier used to access the PyKokkos package :param debug: if true, prints the python AST when an error is encountered + + NOTE: **used by workloads & functors; depreciate with functions** """ if member_type not in ("fields", "views", "typeinfo", "randpool"): diff --git a/pykokkos/core/visitors/debug_transformer.py b/pykokkos/core/visitors/debug_transformer.py deleted file mode 100644 index e011ed49..00000000 --- a/pykokkos/core/visitors/debug_transformer.py +++ /dev/null @@ -1,32 +0,0 @@ -import ast - -from pykokkos.interface import Decorator - - -# in the Debug ExecSpace we need to wrap all instances of the accumulator -# variable in reduction workloads to work around python's lack of reference -# types for primative numbers -class DebugTransformer(ast.NodeTransformer): - def __init__(self): - self.inside_reduction = False - - def visit_FunctionDef(self, node): - if ( - node.decorator_list - and Decorator.is_work_unit(node.decorator_list[0].id) - and len(node.args.args) == 3 - ): - self.inside_reduction = True - self.acc = node.args.args[-1].arg - node.body = list(map(self.visit, node.body)) - self.inside_reduction = False - return node - - def visit_Name(self, node): - if self.inside_reduction and node.id == self.acc: - return ast.Subscript( - ast.Name(self.acc, ast.Load()), - ast.Index(ast.Constant(0, None)), - node.ctx, - ) - return node diff --git a/pykokkos/core/visitors/kokkosmain_visitor.py b/pykokkos/core/visitors/kokkosmain_visitor.py deleted file mode 100644 index 315a6a43..00000000 --- a/pykokkos/core/visitors/kokkosmain_visitor.py +++ /dev/null @@ -1,535 +0,0 @@ -import ast -from typing import List, Dict, Optional, Set, Union -from ast import FunctionDef - -from pykokkos.core import cppast -from pykokkos.core.keywords import Keywords -from pykokkos.interface import BinOp, BinSort, View, Iterate, TeamPolicy - -from . import visitors_util -from .pykokkos_visitor import PyKokkosVisitor - - -class KokkosMainVisitor(PyKokkosVisitor): - def __init__( - self, - env, - src, - views: Dict[str, View], - work_units: Dict[str, FunctionDef], - fields: Dict[cppast.DeclRefExpr, cppast.PrimitiveType], - kokkos_functions: Dict[str, FunctionDef], - dependency_methods: Dict[str, List[str]], - functor: str, - pk_import: str, - restrict_views: Set[str] = set(), - debug=False, - ): - super().__init__( - env, - src, - views, - work_units, - fields, - kokkos_functions, - dependency_methods, - pk_import, - restrict_views, - debug, - ) - - self.functor: str = functor - self.reduction_result_queue: List[str] = [] - self.timer_result_queue: List[str] = [] - - def visit_FunctionDef(self, node: ast.FunctionDef) -> str: - run_body: str = "" - serializer = cppast.Serializer() - for statement in node.body: - run_body += serializer.serialize(self.visit(statement)) - - return run_body - - def visit_Assign(self, node: ast.Assign) -> cppast.Stmt: - target = node.targets[0] - - if isinstance(node.value, ast.Call): - name: str = visitors_util.get_node_name(node.value.func) - - # Create Timer object - if name == "Timer": - decltype = cppast.ClassType("Kokkos::Timer") - declname = cppast.DeclRefExpr("timer") - return cppast.DeclStmt(cppast.VarDecl(decltype, declname, None)) - - # Call Timer.seconds() - if name == "seconds": - target_name: str = visitors_util.get_node_name(target) - if target_name not in self.timer_result_queue: - self.timer_result_queue.append(target_name) - - call = cppast.CallStmt(self.visit(node.value)) - target_ref = cppast.DeclRefExpr(target_name) - target_view_ref = cppast.DeclRefExpr(f"timer_result_{target_name}") - subscript = cppast.ArraySubscriptExpr( - target_view_ref, [cppast.IntegerLiteral(0)] - ) - assign_op = cppast.BinaryOperatorKind.Assign - - # Holds the result of the reduction temporarily - temp_ref = cppast.DeclRefExpr("pk_acc") - target_assign = cppast.AssignOperator([target_ref], temp_ref, assign_op) - view_assign = cppast.AssignOperator([subscript], target_ref, assign_op) - - return cppast.CompoundStmt([call, target_assign, view_assign]) - - if name in ("BinSort", "BinOp1D", "BinOp3D"): - args: List = node.value.args - # if not isinstance(args[0], ast.Attribute): - # self.error(node.value, "First argument has to be a view") - - view = cppast.DeclRefExpr(visitors_util.get_node_name(args[0])) - if view not in self.views: - self.error(args[0], "Undefined view") - - view_type: cppast.ClassType = self.views[view] - is_subview: bool = view_type is None - if is_subview: - parent_view_name: str = self.subviews[view.declname] - - # Need to remove "pk_d_" from the start of the - # view name to get the type of the parent - if parent_view_name.startswith("pk_d_"): - parent_view_name = parent_view_name.replace("pk_d_", "", 1) - parent_view = cppast.DeclRefExpr(parent_view_name) - view_type = self.views[parent_view] - - view_type_str: str = visitors_util.cpp_view_type(view_type) - - if name != "BinSort": - dimension: int = 1 if name == "BinOp1D" else 3 - cpp_type = cppast.DeclRefExpr( - BinOp.get_type(dimension, view_type_str) - ) - - # Do not translate the first argument (view) - constructor = cppast.CallExpr( - cpp_type, [self.visit(a) for a in args[1:]] - ) - - else: - bin_op_type: str = ( - f"decltype({visitors_util.get_node_name(args[1])})" - ) - - binsort_args: List[cppast.DeclRefExpr] = [ - self.visit(a) for a in args - ] - cpp_type = cppast.DeclRefExpr( - BinSort.get_type( - f"decltype({binsort_args[0].declname})", - bin_op_type, - Keywords.DefaultExecSpace.value, - ) - ) - constructor = cppast.CallExpr(cpp_type, binsort_args) - - cpp_target: cppast.DeclRefExpr = self.visit(target) - auto_type = cppast.ClassType("auto") - - return cppast.DeclStmt( - cppast.VarDecl(auto_type, cpp_target, constructor) - ) - - if name in ("get_bin_count", "get_bin_offsets", "get_permute_vector"): - if not isinstance(target, ast.Attribute) or target.value.id != "self": - self.error( - node, "Views defined in pk.main must be an instance variable" - ) - - cpp_target: str = visitors_util.get_node_name(target) - cpp_device_target = f"pk_d_{cpp_target}" - cpp_target_ref = cppast.DeclRefExpr(cpp_device_target) - sorter: cppast.DeclRefExpr = self.visit(node.value.func.value) - - initial_target_ref = cppast.DeclRefExpr( - f"_pk_{cpp_target_ref.declname}" - ) - - function = cppast.MemberCallExpr(sorter, cppast.DeclRefExpr(name), []) - - # Add to the dict of declarations made in pk.main - if name == "get_permute_vector": - # This occurs when a workload is executed multiple times - # Initially the view has not been defined in the workload, - # so it needs to be classified as a pkmain_view. - if cpp_target in self.views: - self.views[cpp_target_ref].add_template_param( - cppast.PrimitiveType(cppast.BuiltinType.INT) - ) - - return cppast.AssignOperator( - [cpp_target_ref], function, cppast.BinaryOperatorKind.Assign - ) - # return f"{cpp_target} = {sorter}.{name}();" - - self.pkmain_views[cpp_target_ref] = cppast.ClassType("View1D") - else: - self.pkmain_views[cpp_target_ref] = None - - auto_type = cppast.ClassType("auto") - decl = cppast.DeclStmt( - cppast.VarDecl(auto_type, initial_target_ref, function) - ) - - # resize the workload's vector to match the generated vector - resize_call = cppast.CallStmt( - cppast.CallExpr( - cppast.DeclRefExpr("Kokkos::resize"), - [ - cpp_target_ref, - cppast.MemberCallExpr( - initial_target_ref, - cppast.DeclRefExpr("extent"), - [cppast.IntegerLiteral(0)], - ), - ], - ) - ) - - copy_call = cppast.CallStmt( - cppast.CallExpr( - cppast.DeclRefExpr("Kokkos::deep_copy"), - [cpp_target_ref, initial_target_ref], - ) - ) - - # Assign to the functor after resizing - functor = cppast.DeclRefExpr("pk_f") - functor_access = cppast.MemberExpr(functor, cpp_target) - functor_assign = cppast.AssignOperator( - [functor_access], cpp_target_ref, cppast.BinaryOperatorKind.Assign - ) - - return cppast.CompoundStmt( - [decl, resize_call, copy_call, functor_assign] - ) - - # Assign result of parallel_reduce - if type(target) not in {ast.Name, ast.Subscript} and target.value.id == "self": - target_name: str = visitors_util.get_node_name(target) - if target_name not in self.reduction_result_queue: - self.reduction_result_queue.append(target_name) - - call = cppast.CallStmt(self.visit(node.value)) - target_ref = cppast.DeclRefExpr(target_name) - target_view_ref = cppast.DeclRefExpr(f"reduction_result_{target_name}") - subscript = cppast.ArraySubscriptExpr( - target_view_ref, [cppast.IntegerLiteral(0)] - ) - assign_op = cppast.BinaryOperatorKind.Assign - - # Holds the result of the reduction temporarily - temp_ref = cppast.DeclRefExpr("pk_acc") - target_assign = cppast.AssignOperator([target_ref], temp_ref, assign_op) - view_assign = cppast.AssignOperator([subscript], target_ref, assign_op) - - return cppast.CompoundStmt([call, target_assign, view_assign]) - - return super().visit_Assign(node) - - def visit_Attribute(self, node: ast.Attribute) -> cppast.DeclRefExpr: - name: str = visitors_util.get_node_name(node) - if name in self.work_units: - return cppast.DeclRefExpr(name) - - if node.value.id == "self": - if name in self.views: - return name - - return cppast.DeclRefExpr(name) - - return super().visit_Attribute(node) - - def visit_Lambda(self, node: ast.Lambda) -> cppast.Expr: - # NOTE: should handle args, kwonlyargs, varargs, kwargs, defaults - return self.visit(node.body) - - def visit_Subscript( - self, node: ast.Subscript - ) -> Union[cppast.ArraySubscriptExpr, cppast.CallExpr]: - call: Union[cppast.ArraySubscriptExpr, cppast.CallExpr] = ( - super().visit_Subscript(node) - ) - if isinstance(call, cppast.CallExpr): - view_name: str = call.function.declname - call._function._declname = f"pk_d_{view_name}" - - return call - - def visit_Call(self, node: ast.Call) -> Union[cppast.Expr, cppast.Stmt]: - name: str = visitors_util.get_node_name(node.func) - args: List[cppast.Expr] = [self.visit(a) for a in node.args] - - # Add pk_d_ before each view name to match mirror view names - s = cppast.Serializer() - for i in range(len(args)): - if args[i] in self.views: - if self.views[args[i]] is not None: - view: str = s.serialize(args[i]) - args[i] = cppast.DeclRefExpr(f"pk_d_{view}") - - # Nested parallelism - if name == "TeamPolicy": - function = cppast.DeclRefExpr(f"Kokkos::{name}") - if len(args) == 2: - args.append(cppast.IntegerLiteral(1)) - - policy = cppast.ConstructExpr(function, args) - - return policy - - if name in dir(TeamPolicy): - team_policy = self.visit(node.func.value) - return cppast.MemberCallExpr(team_policy, cppast.DeclRefExpr(name), args) - - elif name in ["RangePolicy", "MDRangePolicy"]: - rank = len(node.args[0].elts) - if rank == 0: - self.error(node.value, "RangePolicy dimension must be greater than 0") - if rank != len(node.args[1].elts): - self.error(node.value, "RangePolicy dimension mismatch") - - iter_outer = Iterate.Default - iter_inner = Iterate.Default - for keyword in node.keywords: - if keyword.arg == "rank": - explicit_rank = keyword.value.args[0].value - if explicit_rank != rank: - self.error(node.value, "RangePolicy dimension mismatch") - - iter_outer = getattr(Iterate, keyword.value.args[1].attr) - iter_inner = getattr(Iterate, keyword.value.args[2].attr) - - function = cppast.DeclRefExpr( - f"Kokkos::{name}>" - ) - policy = cppast.ConstructExpr(cppast.DeclRefExpr(f"Kokkos::{name}"), args) - if name == "MDRangePolicy": - policy.add_template_param( - cppast.DeclRefExpr( - f"Kokkos::Rank<{rank},{iter_outer.value},{iter_inner.value}>" - ) - ) - - return policy - - if name == "seconds": - fence = cppast.CallStmt( - cppast.CallExpr(cppast.DeclRefExpr("Kokkos::fence"), []) - ) - temp_decl = cppast.DeclRefExpr("pk_acc") - seconds = cppast.MemberCallExpr( - cppast.DeclRefExpr("timer"), cppast.DeclRefExpr("seconds"), [] - ) - result = cppast.AssignOperator( - [temp_decl], seconds, cppast.BinaryOperatorKind.Assign - ) - - return cppast.CompoundStmt([fence, result]) - - function = cppast.DeclRefExpr(f"Kokkos::{name}") - if name == "parallel_for": - arg_start: int = 0 # Accounts for the optional kernel name - kernel_name: Optional[cppast.StringLiteral] = None - if isinstance(args[0], cppast.StringLiteral): - kernel_name = args[0] - arg_start = 1 - - policy: cppast.ConstructExpr = args[arg_start] - policy = self.add_space_to_policy(policy) - - if isinstance(node.args[arg_start + 1], ast.Lambda): - decl: str = "KOKKOS_LAMBDA (" - tid = cppast.DeclRefExpr(node.args[arg_start + 1].args.args[0].arg) - - # if target exists - if len(args) == arg_start + 3: - target = cppast.ArraySubscriptExpr(args[arg_start + 2], [tid]) - args[arg_start + 1] = cppast.AssignOperator( - [target], args[arg_start + 1], cppast.BinaryOperatorKind.Assign - ) - - serializer = cppast.Serializer() - decl += f"int {tid.declname}) {{" - decl += serializer.serialize(args[arg_start + 1]) + ";}\n" - - call_args: List[cppast.Expr] = [policy, decl] - if kernel_name is not None: - call_args.insert(0, kernel_name) - - return cppast.CallExpr(function, call_args) - - else: - work_unit: str = args[arg_start + 1].declname - policy = self.add_workunit_to_policy(policy, work_unit) - - call_args: List[cppast.Expr] = [policy, cppast.DeclRefExpr("pk_f")] - if kernel_name is not None: - call_args.insert(0, kernel_name) - - return cppast.CallExpr(function, call_args) - - if name in ("parallel_reduce", "parallel_scan"): - arg_start: int = 0 # Accounts for the optional kernel name - kernel_name: Optional[cppast.StringLiteral] = None - if isinstance(args[0], cppast.StringLiteral): - kernel_name = args[0] - arg_start = 1 - - initial_value: cppast.Expr - if len(args) == arg_start + 3: - initial_value = args[arg_start + 2] - else: - initial_value = cppast.IntegerLiteral(0) - - acc_decl = cppast.DeclRefExpr("pk_acc") - init_var = cppast.BinaryOperator( - acc_decl, initial_value, cppast.BinaryOperatorKind.Assign - ) - - policy: cppast.ConstructExpr = args[arg_start] - policy = self.add_space_to_policy(policy) - - if isinstance(node.args[arg_start + 1], ast.Lambda): - decl: str = "KOKKOS_LAMBDA (" - tid = cppast.DeclRefExpr(node.args[arg_start + 1].args.args[0].arg) - acc = cppast.DeclRefExpr(node.args[arg_start + 1].args.args[1].arg) - - # assign to accumulator - args[arg_start + 1] = cppast.AssignOperator( - [acc], args[arg_start + 1], cppast.BinaryOperatorKind.Assign - ) - - serializer = cppast.Serializer() - decl += f"int {tid.declname}, double& {acc.declname}) {{" - decl += serializer.serialize(args[arg_start + 1]) + ";}\n" - - call_args: List[cppast.Expr] = [policy, decl, acc_decl] - if kernel_name is not None: - call_args.insert(0, kernel_name) - - call = cppast.CallExpr(function, call_args) - - else: - work_unit: str = args[arg_start + 1].declname - policy = self.add_workunit_to_policy(policy, work_unit) - - call_args: List[cppast.Expr] = [ - policy, - cppast.DeclRefExpr("pk_f"), - acc_decl, - ] - if kernel_name is not None: - call_args.insert(0, kernel_name) - - return cppast.CallExpr(function, call_args) - - return cppast.BinaryOperator( - init_var, call, cppast.BinaryOperatorKind.Comma - ) - - if name in dir(BinSort): - sorter: str = visitors_util.get_node_name(node.func.value) - sorter_ref = cppast.DeclRefExpr(sorter) - function = cppast.DeclRefExpr(name) - - return cppast.MemberCallExpr(sorter_ref, function, args) - - if name == "shmem_size": - if len(args) != 1: - self.error(node.func, "shmem_size() accepts only a single argument") - - func: ast.Attribute = node.func - cpp_view_type: str = self.get_scratch_view_type(func.value) - - if cpp_view_type is None: - self.error(func, "Wrong call to shmem_size()") - - view_type = cppast.ClassType(cpp_view_type) - call_expr = cppast.MemberCallExpr(view_type, name, args) - call_expr.is_static = True - - return call_expr - - return super().visit_Call(node) - - def visit_Constant(self, node: ast.Constant) -> cppast.DeclRefExpr: - if isinstance(node.value, str) and node.value == "auto": - return cppast.DeclRefExpr("Kokkos::AUTO") - - return super().visit_Constant(node) - - def add_space_to_policy( - self, policy: Union[cppast.ConstructExpr, cppast.MemberCallExpr] - ) -> Union[cppast.ConstructExpr, cppast.MemberCallExpr]: - """ - Add the execution space to the execution policy - - :param policy: the execution policy (could also be an integer) - :returns: the execution policy - """ - - # Replace the number of threads with a RangePolicy - if type(policy) not in (cppast.ConstructExpr, cppast.MemberCallExpr): - begin = cppast.IntegerLiteral(0) - policy = cppast.ConstructExpr( - cppast.DeclRefExpr("Kokkos::RangePolicy"), [begin, policy] - ) - - space = cppast.DeclRefExpr(Keywords.DefaultExecSpace.value) - policy_constructor = self.get_policy_constructor(policy) - policy_constructor.add_template_param(space) - - return policy - - def add_workunit_to_policy( - self, policy: Union[cppast.ConstructExpr, cppast.MemberCallExpr], work_unit: str - ) -> Union[cppast.ConstructExpr, cppast.MemberCallExpr]: - """ - Add the workunit tag to the execution policy - - :param policy: the execution policy (could also be an integer) - :param work_unit: the tag of the workunit - :returns: the execution policy - """ - - policy_constructor = self.get_policy_constructor(policy) - policy_constructor.add_template_param( - cppast.DeclRefExpr(f"{self.functor}::{work_unit}_tag") - ) - - return policy - - def get_policy_constructor( - self, policy: Union[cppast.ConstructExpr, cppast.MemberCallExpr] - ) -> cppast.ConstructExpr: - """ - Get the call to the policy constructor from the policy object - - :param: the policy object - :returns: the call to the constructor - """ - - if isinstance(policy, cppast.MemberCallExpr): - return policy.base - else: - return policy - - def generate_subview(self, node: ast.Assign, view_name: str) -> cppast.DeclStmt: - """ - Generate a subview in pk.main. This involves adding the - "pk_d_" prefix to the parent view. - """ - - return super().generate_subview(node, f"pk_d_{view_name}") diff --git a/pykokkos/core/visitors/parameter_visitor.py b/pykokkos/core/visitors/parameter_visitor.py index 28a6c005..9f0a028b 100644 --- a/pykokkos/core/visitors/parameter_visitor.py +++ b/pykokkos/core/visitors/parameter_visitor.py @@ -17,7 +17,7 @@ def __init__( """ ParameterVisitor constructor - :param src: the python source code of the workload + :param src: the python source code of the workunit :param param_begin: where workunit argument begins (excluding tid/acc) :param pk_import: the identifier used to access the PyKokkos package :param debug: if true, prints the python AST when an error is encountered diff --git a/pykokkos/interface/__init__.py b/pykokkos/interface/__init__.py index a14c2ffa..6e99e309 100644 --- a/pykokkos/interface/__init__.py +++ b/pykokkos/interface/__init__.py @@ -38,12 +38,10 @@ complex128, ) from .decorators import ( - callback, classtype, Decorator, function, functor, - main, workunit, ) from .execution_policy import ( @@ -67,7 +65,6 @@ from .mathematical_special_functions import cyl_bessel_j0, cyl_bessel_j1 from .memory_space import MemorySpace, get_default_memory_space from .parallel_dispatch import ( - execute, flush, parallel_for, parallel_reduce, diff --git a/pykokkos/interface/decorators.py b/pykokkos/interface/decorators.py index 2a753f92..330c2c74 100644 --- a/pykokkos/interface/decorators.py +++ b/pykokkos/interface/decorators.py @@ -7,8 +7,6 @@ class Decorator(Enum): WorkUnit = "workunit" KokkosClasstype = "classtype" KokkosFunction = "function" - KokkosMain = "main" - KokkosCallback = "callback" Space = "space" @staticmethod @@ -27,14 +25,6 @@ def is_kokkos_classtype(decorator: str) -> bool: def is_kokkos_function(decorator: str) -> bool: return decorator == Decorator.KokkosFunction.value - @staticmethod - def is_kokkos_main(decorator: str) -> bool: - return decorator == Decorator.KokkosMain.value - - @staticmethod - def is_kokkos_callback(decorator: str) -> bool: - return decorator == Decorator.KokkosCallback.value - @staticmethod def is_space(decorator: str) -> bool: return decorator == Decorator.Space.value @@ -74,11 +64,3 @@ def classtype(func): def function(func): return func - - -def main(func): - return func - - -def callback(func): - return func diff --git a/pykokkos/interface/parallel_dispatch.py b/pykokkos/interface/parallel_dispatch.py index 3b79d755..c746d858 100644 --- a/pykokkos/interface/parallel_dispatch.py +++ b/pykokkos/interface/parallel_dispatch.py @@ -370,12 +370,5 @@ def parallel_scan(*args, **kwargs) -> Union[float, int]: return reduce_body("scan", *args, **kwargs) -def execute(space: ExecutionSpace, workload: object) -> None: - if space is ExecutionSpace.Default: - runtime_singleton.runtime.run_workload(km.get_default_space(), workload) - else: - runtime_singleton.runtime.run_workload(space, workload) - - def flush(): runtime_singleton.runtime.flush_trace()