Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion pykokkos/core/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def get_entities(self, style: PyKokkosStyles) -> Dict[str, PyKokkosEntity]:
check_entity: Callable[[ast.stmt], bool]

if style is PyKokkosStyles.workload:
return entities
check_entity = self.is_workload
elif style is PyKokkosStyles.functor:
check_entity = self.is_functor
elif style is PyKokkosStyles.workunit:
Expand Down Expand Up @@ -448,6 +448,35 @@ def is_classtype(node: ast.stmt, pk_import: str) -> bool:

return False

@staticmethod
def is_workload(node: ast.stmt, pk_import: str) -> bool:
"""
Checks if an ast node is a a PyKokkos workload

:param node: the node being checked
:param pk_import: the identifier used to access the PyKokkos package
:returns: true or false
"""

if not isinstance(node, ast.ClassDef):
return False

for decorator in node.decorator_list:
attribute = None

if isinstance(decorator, ast.Call):
attribute = decorator.func
elif isinstance(decorator, ast.Attribute):
attribute = decorator

if isinstance(attribute, ast.Attribute):
if attribute.value.id == pk_import and Decorator.is_workload(
attribute.attr
):
return True

return False

@staticmethod
def is_functor(node: ast.stmt, pk_import: str) -> bool:
"""
Expand Down
63 changes: 40 additions & 23 deletions pykokkos/core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,31 +518,40 @@ def get_arguments(
args: Dict[str, Any] = {}

entity_members: Dict[str, type]
is_workload: bool = not isinstance(entity, (Callable, list))

if policy is None:
raise RuntimeError("Execution policy is None")
if is_workload:
args.update(self.get_result_arguments(members))
entity_members = entity.__dict__
args["pk_exec_space_instance"] = km.get_execution_space_instance(
space
).instance

args.update(self.get_policy_arguments(policy))
is_functor: bool = hasattr(entity, "__self__")
if is_functor:
functor: object = entity.__self__
entity_members = functor.__dict__
self._convert_functor_arrays(entity_members)
else:
is_fused: bool = isinstance(entity, list)
if is_fused:
parsers = [
self.compiler.get_parser(get_metadata(e).path) for e in entity
]
entity_trees = [
this_parser.get_entity(get_metadata(this_entity).name).AST
for this_entity, this_parser in zip(entity, parsers)
]

kwargs, _ = fuse_workunit_kwargs_and_params(
entity_trees, kwargs, f"parallel_{operation}"
)
entity_members = kwargs
if policy is None:
raise RuntimeError("Execution policy is None")

args.update(self.get_policy_arguments(policy))
is_functor: bool = hasattr(entity, "__self__")
if is_functor:
functor: object = entity.__self__
entity_members = functor.__dict__
self._convert_functor_arrays(entity_members)
else:
is_fused: bool = isinstance(entity, list)
if is_fused:
parsers = [
self.compiler.get_parser(get_metadata(e).path) for e in entity
]
entity_trees = [
this_parser.get_entity(get_metadata(this_entity).name).AST
for this_entity, this_parser in zip(entity, parsers)
]

kwargs, _ = fuse_workunit_kwargs_and_params(
entity_trees, kwargs, f"parallel_{operation}"
)
entity_members = kwargs

args.update(self.get_fields(entity_members))
args.update(self.get_views(entity_members))
Expand Down Expand Up @@ -857,9 +866,17 @@ def get_module_setup_id(
if isinstance(entity, list):
entity = tuple(entity) # Since entity needs to be hashed

is_workload: bool = not isinstance(entity, (Callable, tuple))
is_functor: bool = hasattr(entity, "__self__")

if is_functor:
if is_workload:
workload_type: Type = type(entity)
module_setup_id: Tuple[Callable, str, ExecutionSpace] = (
workload_type,
workload_type.__module__,
space,
)
elif is_functor:
functor_type: Type = type(entity.__self__)
module_setup_id: Tuple[Callable, str, str, ExecutionSpace] = (
type(functor_type),
Expand Down
33 changes: 19 additions & 14 deletions pykokkos/core/translators/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def get_view_memory_space(view_type: cppast.ClassType, location: str) -> str:
def get_kernel_params(
members: PyKokkosMembers,
is_hierarchical: bool,
is_workload: bool,
real: Optional[str],
) -> Dict[str, str]:
"""
Expand Down Expand Up @@ -87,18 +88,19 @@ def get_kernel_params(

params[Keywords.DefaultExecSpaceInstance.value] = Keywords.DefaultExecSpace.value

params[Keywords.KernelName.value] = "const std::string&"
if not is_workload:
params[Keywords.KernelName.value] = "const std::string&"

if is_hierarchical:
params[Keywords.LeagueSize.value] = "int"
params[Keywords.TeamSize.value] = "int"
params[Keywords.VectorLength.value] = "int"
params[Keywords.ScratchSizeLevel.value] = "int"
params[Keywords.ScratchSizeValue.value] = "int"
params[Keywords.ScratchSizeIsPerTeam.value] = "bool"
else:
params[Keywords.ThreadsBegin.value] = "int"
params[Keywords.ThreadsEnd.value] = "int"
if is_hierarchical:
params[Keywords.LeagueSize.value] = "int"
params[Keywords.TeamSize.value] = "int"
params[Keywords.VectorLength.value] = "int"
params[Keywords.ScratchSizeLevel.value] = "int"
params[Keywords.ScratchSizeValue.value] = "int"
params[Keywords.ScratchSizeIsPerTeam.value] = "bool"
else:
params[Keywords.ThreadsBegin.value] = "int"
params[Keywords.ThreadsEnd.value] = "int"

params[Keywords.RandPoolSeed.value] = "int"
params[Keywords.RandPoolNumStates.value] = "int"
Expand Down Expand Up @@ -392,7 +394,10 @@ def generate_wrapper(
:returns: the wrapper source
"""

params: Dict[str, str] = get_kernel_params(members, is_hierarchical(workunit), real)
is_workload: bool = True if operation == "workload" else False
params: Dict[str, str] = get_kernel_params(
members, is_hierarchical(workunit), is_workload, real
)
return_type: str = get_return_type(operation, workunit)

args: List[str] = []
Expand Down Expand Up @@ -439,7 +444,7 @@ def generate_kernel(
"""

hierarchical: bool = is_hierarchical(workunit)
params: Dict[str, str] = get_kernel_params(members, hierarchical, real)
params: Dict[str, str] = get_kernel_params(members, hierarchical, False, real)
return_type: str = get_return_type(operation, workunit)
signature: str = generate_kernel_signature(return_type, kernel, params)

Expand Down Expand Up @@ -638,7 +643,7 @@ def bind_main_single(
functor += f"<{Keywords.DefaultExecSpace.value}>"

main: List[str] = translate_mains(source, functor, members, pk_import)
params: Dict[str, str] = get_kernel_params(members, False, real)
params: Dict[str, str] = get_kernel_params(members, False, True, real)

# fall back to the old hard-coded default
# for now--this includes cases where an
Expand Down
1 change: 1 addition & 0 deletions pykokkos/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
function,
functor,
main,
workload,
workunit,
)
from .execution_policy import (
Expand Down
12 changes: 12 additions & 0 deletions pykokkos/interface/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


class Decorator(Enum):
Workload = "workload"
Functor = "functor"
WorkUnit = "workunit"
KokkosClasstype = "classtype"
Expand Down Expand Up @@ -43,6 +44,10 @@ def is_space(decorator: str) -> bool:
def is_functor(decorator: str) -> bool:
return decorator == Decorator.Functor.value

@staticmethod
def is_workload(decorator: str) -> bool:
return decorator == Decorator.Workload.value


def functor(func=None, **kwargs):
if func is None:
Expand All @@ -68,6 +73,13 @@ def workunit(func=None, *, scratch=None, **kwargs):
return func


def workload(func=None, **kwargs):
if func is None:
return partial(functor)

return func


def classtype(func):
return func

Expand Down
Loading