Skip to content
Open
2 changes: 1 addition & 1 deletion pykokkos/core/cpp_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def generate_cmake(
Copy CMakeLists.txt template and prepare CMake configuration variables

:param output_dir: the base directory
:param space: the execution space of the workload
:param space: the execution space of the workunit
:param enable_uvm: whether to enable CudaUVMSpace
:param compiler: what compiler to use
:returns: tuple of (cmake_args, module_name)
Expand Down
25 changes: 9 additions & 16 deletions pykokkos/core/module_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,29 +45,22 @@ def get_metadata(entity: Union[Callable[..., None], object]) -> EntityMetadata:
"""
Gets the name and filepath of an entity

:param entity: the workload or workunit function object
:param entity: the workunit function object
:returns: an EntityMetadata object
"""

name: str
filepath: str

if isinstance(entity, Callable):
# Workunit/functor
is_functor: bool = hasattr(entity, "__self__")
if is_functor:
entity_type: type = get_functor(entity)
name = entity_type.__name__
filepath = inspect.getfile(entity_type)
else:
name = entity.__name__
filepath = inspect.getfile(entity)

else:
# Workload
entity_type: type = type(entity)
# Workunit/functor
is_functor: bool = hasattr(entity, "__self__")
if is_functor:
entity_type: type = get_functor(entity)
name = entity_type.__name__
filepath = inspect.getfile(entity_type)
else:
name = entity.__name__
filepath = inspect.getfile(entity)

return EntityMetadata(entity, name, filepath)

Expand All @@ -88,7 +81,7 @@ def __init__(
"""
ModuleSetup constructor

:param entity: the functor/workunit/workload or list of workunits for fusion
:param entity: the functor/workunit or list of workunits for fusion
:param types_signature: hash/string to identify workunit signature against types
:param restricted_views: a set of view names that do not alias any other views
"""
Expand Down
12 changes: 2 additions & 10 deletions pykokkos/core/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ class PyKokkosStyles(Enum):
"""

functor = auto()
workload = auto()
workunit = auto()
classtype = auto()
fused = auto()
Expand All @@ -37,7 +36,7 @@ class PyKokkosEntity:

class Parser:
"""
Parse a PyKokkos workload and its dependencies
Parse a PyKokkos workunit and its dependencies
"""

def __init__(self, path: Optional[str], pk_import: Optional[str] = None):
Expand All @@ -56,12 +55,10 @@ def __init__(self, path: Optional[str], pk_import: Optional[str] = None):
self.tree = ast.parse("".join(self.lines))
self.path: Optional[str] = path
self.pk_import: str = self.get_import()
self.workloads: Dict[str, PyKokkosEntity] = {}
self.classtypes: Dict[str, PyKokkosEntity] = {}
self.functors: Dict[str, PyKokkosEntity] = {}
self.workunits: Dict[str, PyKokkosEntity] = {}

self.workloads = self.get_entities(PyKokkosStyles.workload)
self.classtypes = self.get_entities(PyKokkosStyles.classtype)
self.functors = self.get_entities(PyKokkosStyles.functor)
self.workunits = self.get_entities(PyKokkosStyles.workunit)
Expand All @@ -74,7 +71,6 @@ def __init__(self, path: Optional[str], pk_import: Optional[str] = None):
self.tree = ast.Module(body=[])
self.path = None
self.pk_import = pk_import
self.workloads: Dict[str, PyKokkosEntity] = {}
self.classtypes: Dict[str, PyKokkosEntity] = {}
self.functors: Dict[str, PyKokkosEntity] = {}
self.workunits: Dict[str, PyKokkosEntity] = {}
Expand Down Expand Up @@ -112,8 +108,6 @@ def get_entity(self, name: str) -> PyKokkosEntity:
:returns: the PyKokkosEntity representation of the entity
"""

if name in self.workloads:
return self.workloads[name]
if name in self.functors:
return self.functors[name]
if name in self.workunits:
Expand All @@ -132,9 +126,7 @@ def get_entities(self, style: PyKokkosStyles) -> Dict[str, PyKokkosEntity]:
entities: Dict[str, PyKokkosEntity] = {}
check_entity: Callable[[ast.stmt], bool]

if style is PyKokkosStyles.workload:
return entities
elif style is PyKokkosStyles.functor:
if style is PyKokkosStyles.functor:
check_entity = self.is_functor
elif style is PyKokkosStyles.workunit:
check_entity = self.is_workunit
Expand Down
40 changes: 0 additions & 40 deletions pykokkos/core/run_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,46 +17,6 @@
import pykokkos.kokkos_manager as km


def run_workload_debug(workload: object) -> None:
"""
Run a workload in Python

:param workload: the workload object
"""

workload_source: str = inspect.getsource(type(workload))
tree: ast.Module = ast.parse(workload_source)
classdef: ast.ClassDef = tree.body[0]

def get_annotated_functions(decorator: Decorator) -> Dict[str, ast.FunctionDef]:
visitor = ast.NodeVisitor()
functions: Dict[str, ast.FunctionDef] = {}

def visit_FunctionDef(node):
if node.decorator_list:
try:
node_decorator: str = node.decorator_list[0].id
except AttributeError:
node_decorator: str = node.decorator_list[0].attr

if decorator.value == node_decorator:
functions[node.name] = node

visitor.visit_FunctionDef = visit_FunctionDef
for method in classdef.body:
visitor.visit(method)

return functions

for name in get_annotated_functions(Decorator.KokkosMain):
kokkos_main = getattr(workload, name)
kokkos_main()

for name in get_annotated_functions(Decorator.KokkosCallback):
kokkos_callback = getattr(workload, name)
kokkos_callback()


def call_workunit(
operation: str,
workunit: Callable[..., None],
Expand Down
110 changes: 15 additions & 95 deletions pykokkos/core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

from .compiler import Compiler
from .module_setup import EntityMetadata, get_metadata, ModuleSetup
from .run_debug import run_workload_debug, run_workunit_debug
from .run_debug import run_workunit_debug


def _calculate_aligned_scratch_size(
Expand Down Expand Up @@ -112,38 +112,18 @@ def apply_scratch_spec(workunit: Callable, policy: TeamPolicy, **kwargs) -> None

class Runtime:
"""
Executes (and optionally compiles) PyKokkos workloads
Executes (and optionally compiles) PyKokkos workunits
"""

def __init__(self):
self.compiler: Compiler = Compiler()
self.tracer: Tracer = Tracer()

# cache module_setup objects using a workload/workunit and space tuple
# cache module_setup objects using a workunit and space tuple
self.module_setups: Dict[Tuple, ModuleSetup] = {}

self.fusion_strategy: Optional[str] = os.getenv("PK_FUSION")

def run_workload(self, space: ExecutionSpace, workload: object) -> None:
"""
Run the workload

:param space: the execution space of the workload
:param workload: the workload object
"""

if self.is_debug(space):
run_workload_debug(workload)
return

module_setup: ModuleSetup = self.get_module_setup(workload, space)
members: PyKokkosMembers = self.compiler.compile_object(
module_setup, space, km.is_uvm_enabled(), None, None, None, set()
)

self.execute(workload, module_setup, members, space)
self.run_callbacks(workload, members)

def precompile_workunit(
self,
workunit: Callable[..., None],
Expand Down Expand Up @@ -410,15 +390,15 @@ def execute(
"""
Imports the module containing the bindings and executes the necessary function

:param entity: the workload or workunit object
:param entity: the workunit object
:param module_path: the path to the compiled module
:param members: a collection of PyKokkos related members
:param space: the execution space
:param policy: the execution policy for workunits
:param name: the name of the kernel
:param operation: the name of the operation "for", "reduce", or "scan"
:param kwargs: the keyword arguments passed to the workunit
:returns: the result of the operation (None for "for" and workloads)
:returns: the result of the operation (None for "for")
"""

module_path: str
Expand All @@ -440,10 +420,6 @@ def execute(

result = self.call_wrapper(entity, members, args, module)

is_workunit_or_functor: bool = isinstance(entity, (Callable, list))
if not is_workunit_or_functor:
self.retrieve_results(entity, members, args)

return result

def import_module(self, module_name: str, module_path: str):
Expand Down Expand Up @@ -480,7 +456,7 @@ def get_arguments(
"""
Get the arguments for a wrapper function, including fields, views, etc

:param entity: the workload or workunit object
:param entity: the workunit object
:param members: a collection of PyKokkos related members
:param space: the execution space
:param policy: the execution policy of the operation
Expand Down Expand Up @@ -533,7 +509,7 @@ def call_wrapper(
"""
Call the wrapper in the imported module

:param entity: the workload or workunit object
:param entity: the workunit object
:param members: a collection of PyKokkos related members
:param args: the arguments to be passed to the wrapper
:param module: the imported module
Expand Down Expand Up @@ -592,31 +568,9 @@ def get_precision(self, members: PyKokkosMembers, args: Dict[str, Any]) -> str:

return precision

def get_result_arguments(self, members: PyKokkosMembers) -> Dict[str, Any]:
"""
Get the views that are passed as arguments to hold the results for workloads

:param members: a collection of PyKokkos related members
:returns: a dictionary of argument name to value
"""

args: Dict[str, Any] = {}

for result in members.reduction_result_queue:
name: str = f"reduction_result_{result}"
result_view = View([1], DataType.double, MemorySpace.HostSpace)
args[name] = result_view.array

for result in members.timer_result_queue:
name: str = f"timer_result_{result}"
result_view = View([1], DataType.double, MemorySpace.HostSpace)
args[name] = result_view.array

return args

def get_policy_arguments(self, policy: ExecutionPolicy) -> Dict[str, Any]:
"""
Get the arguments that are used for to hold the results for workloads
Get the arguments that are used for to hold the results for workunits

:param policy: the execution policy of the operation
:returns: a dictionary of argument name to value
Expand Down Expand Up @@ -668,9 +622,9 @@ def get_policy_arguments(self, policy: ExecutionPolicy) -> Dict[str, Any]:

def get_fields(self, members: Dict[str, type]) -> Dict[str, Any]:
"""
Gets all the primitive type fields from the workload object
Gets all the primitive type fields from the workunit object

:param workload: the dictionary containing all members
:param members: the dictionary containing all members
:returns: a dict mapping from field name to value
"""

Expand Down Expand Up @@ -725,9 +679,9 @@ def _convert_functor_arrays(self, members: Dict[str, Any]) -> None:

def get_views(self, members: Dict[str, type]) -> Dict[str, Any]:
"""
Gets all the views from the workload object
Gets all the views from the workunit object

:param workload: the dictionary containing all members
:param workunit: the dictionary containing all members
:returns: a dict mapping from view name to object
"""

Expand All @@ -738,40 +692,6 @@ def get_views(self, members: Dict[str, type]) -> Dict[str, Any]:

return views

def retrieve_results(
self, workload: object, members: PyKokkosMembers, args: Dict[str, Any]
) -> None:
"""
Get the results for workloads

:param workload: the workload object
:param members: a collection of PyKokkos related members
:param args: the arguments passed to the wrapper, including views that hold results
"""

for result in members.reduction_result_queue:
name: str = f"reduction_result_{result}"
view: View = args[name]
setattr(workload, result, view[0])

for result in members.timer_result_queue:
name: str = f"timer_result_{result}"
view: View = args[name]
setattr(workload, result, view[0])

def run_callbacks(self, workload: object, members: PyKokkosMembers) -> None:
"""
Run all methods in the workload that are annotated with @pk.callback

:param workload: the workload object
:param members: a collection of PyKokkos related members in workload
"""

callbacks = members.pk_callbacks
for name in callbacks:
callback = getattr(workload, name.declname)
callback()

def get_module_setup(
self,
entity: Union[object, Callable[..., None]],
Expand All @@ -782,7 +702,7 @@ def get_module_setup(
"""
Get the compiled module setup information unique to an entity + space

:param entity: the workload or workunit object
:param entity: the workunit object
:param space: the execution space
:param types_signature: Hash/identifer string for workunit module against data types
:param restrict_signature: Hash/identifer string for views that do not alias any other views
Expand Down Expand Up @@ -815,10 +735,10 @@ def get_module_setup_id(
"""
Get a unique module setup id for an entity + space
combination. For workunits, the idenitifier is just the
workunit and execution space. For workloads and functors, we
workunit and execution space. For functors, we
need the type of the class as well as the file containing it.

:param entity: the workload or workunit object
:param entity: the workunit object
:param space: the execution space
:param types_signature: optional identifier/hash string for
types of parameters against workunit module
Expand Down
Loading