Skip to content
5 changes: 3 additions & 2 deletions pykokkos/core/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,9 @@ def compile_entity(
main,
module_setup.metadata,
space,
module_setup.types_signature,
module_setup.restrict_signature,
module_setup.ast_signature,
types_signature=module_setup.types_signature,
restrict_signature=module_setup.restrict_signature,
)
c_start: float = time.perf_counter()
cpp_setup.compile(
Expand Down
28 changes: 27 additions & 1 deletion pykokkos/core/cpp_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def compile(

def initialize_directory(self, name: Path) -> None:
"""
Creates an output directory, overwriting an existing directory with the same name
Creates an output directory, overwriting an existing directory with the same name.
Checks if an older AST directory exists, and if so copies over the relevant information.
Comment thread
JBludau marked this conversation as resolved.
Outdated

:param name: the name of the directory
"""
Expand All @@ -127,6 +128,31 @@ def initialize_directory(self, name: Path) -> None:
except OSError:
pass

# make the parent directory if necessary
try:
os.makedirs(name.parent, exist_ok=True)
except FileExistsError:
pass

Comment thread
JBludau marked this conversation as resolved.
# check if an older AST library exists
Comment thread
kennykos marked this conversation as resolved.
Outdated
old_dir = [
f for f in name.parent.parent.iterdir() if f.is_dir() and f != name.parent
]
if len(old_dir) == 1:
Comment thread
JBludau marked this conversation as resolved.
Outdated
# if an older AST directory exists, change the name to the new AST directory
old_dir = old_dir[0]
old_dir.rename(name.parent)
# CMakeCache.txt has hardcoded absolute paths to the old AST directory.
# Delete only the cache files so CMake reconfigures with correct paths
# while keeping object files in build/ for incremental recompilation.
stale_cache = name / "build" / "CMakeCache.txt"
if stale_cache.is_file():
stale_cache.unlink()
stale_cache_check = name / "build" / "CMakeFiles" / "cmake.check_cache"
if stale_cache_check.is_file():
stale_cache_check.unlink()

# make the name directory if necessary
try:
os.makedirs(name, exist_ok=True)
except FileExistsError:
Expand Down
21 changes: 18 additions & 3 deletions pykokkos/core/module_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,16 @@ def __init__(
self,
entity: Union[Callable[..., None], type, List[Callable[..., None]]],
space: ExecutionSpace,
ast_signature: str,
*,
types_signature: Optional[str] = None,
restricted_views: Optional[Set[str]] = None,
):
"""
ModuleSetup constructor

:param entity: the functor/workunit/workload or list of workunits for fusion
:param ast_signature: hash/string to identify workunit signature against AST
:param types_signature: hash/string to identify workunit signature against types
:param restricted_views: a set of view names that do not alias any other views
"""
Expand All @@ -103,6 +106,7 @@ def __init__(
self.metadata = [get_metadata(entity)]

self.space: ExecutionSpace = space
self.ast_signature = ast_signature
self.types_signature = types_signature
self.restrict_signature: Optional[str] = None
if restricted_views is not None:
Expand All @@ -126,7 +130,12 @@ def __init__(
self.main: Path = self.get_main_path()

self.output_dir: Optional[Path] = self.get_output_dir(
self.main, self.metadata, space, types_signature, self.restrict_signature
self.main,
self.metadata,
space,
ast_signature,
types_signature=types_signature,
restrict_signature=self.restrict_signature,
)
self.gpu_module_files: List[str] = []
if km.is_multi_gpu_enabled():
Expand Down Expand Up @@ -154,6 +163,8 @@ def get_output_dir(
main: Path,
metadata: List[EntityMetadata],
space: ExecutionSpace,
ast_signature,
*,
types_signature: Optional[str] = None,
restrict_signature: Optional[str] = None,
) -> Optional[Path]:
Expand All @@ -163,6 +174,7 @@ def get_output_dir(
:param main: the path to the main file in the current PyKokkos application
:param metadata: the metadata of the entity or fused entities being compiled
:param space: the execution space to compile for
:param ast_signature: hash/string to identify workunit signature against AST
:param types_signature: optional identifier/hash string for types of parameters
:param restrict_signature: optional identifier/hash string from the views that do not alias any other views
:returns: the path to the output directory for a specific execution space
Expand All @@ -180,6 +192,8 @@ def get_output_dir(
out_dir = out_dir / f"types_{types_signature}"
if restrict_signature is not None:
out_dir = out_dir / f"restrict_{restrict_signature}"
if ast_signature is not None:
out_dir = out_dir / f"AST_{ast_signature}"
Comment on lines +208 to +209
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, so we have a hierarchy in which the AST signature has the highest priority ... but are there cases where the AST signature would not trigger but the others would?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, consider the scenario:

  • workunit v0 is compiled with pk.float arrays;
  • an edit is made and we now have an AST hash for workunit v1;
  • workunit v1 is compiled with pk.double arrays;
  • workunit v1 is called again with pk.float arrays. In this case, the type hash would trigger, but the AST hash would not, as it is out of date.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but doesn't the AST change if we change types from float to double? Even unmanaged Views change their type, I am surprised that this would not trigger an AST hash fail

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the AST will change, but we want to keep both float and double types cached in this scenario to avoid excessive recompilation (e.g., an adaptive time-stepper may use a float and double execution to get error estimates)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Could you add a comment explaining this intent? Future selves will thank us for it

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in a3a23c8


out_dir = out_dir / space.value

Expand Down Expand Up @@ -258,7 +272,8 @@ def is_compiled(self) -> bool:
self.main,
self.metadata,
self.space,
self.types_signature,
self.restrict_signature,
self.ast_signature,
types_signature=self.types_signature,
restrict_signature=self.restrict_signature,
)
)
5 changes: 5 additions & 0 deletions pykokkos/core/parsers/parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import hashlib
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -79,6 +80,10 @@ def __init__(self, path: Optional[str], pk_import: Optional[str] = None):
self.functors: Dict[str, PyKokkosEntity] = {}
self.workunits: Dict[str, PyKokkosEntity] = {}

# get parser signature
signature = ast.dump(self.tree)
self.signature = hashlib.md5(signature.encode()).hexdigest()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How expensive is this? I assume our ASTs for kernels are small enough that this should not be an issue?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't do a performance test. This block is only done when the workunit is first seen by pykokkos, so it is a startup cost that will not add overhead during repeated execution.

In most cases the ASTs should be relatively small; e.g., in my Ewald code, the slowest time for this block of code is 0.013 seconds.


def get_import(self) -> str:
"""
Get the pykokkos import identifier
Expand Down
45 changes: 41 additions & 4 deletions pykokkos/core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List
import sysconfig
import hashlib

import numpy as np

Expand Down Expand Up @@ -174,6 +175,7 @@ def run_workload(self, space: ExecutionSpace, workload: object) -> None:
def precompile_workunit(
self,
workunit: Callable[..., None],
ast_signature: str,
space: ExecutionSpace,
updated_decorator: Optional[UpdatedDecorator],
updated_types: Optional[UpdatedTypes],
Expand All @@ -186,15 +188,21 @@ def precompile_workunit(
precompile the workunit

:param workunit: the workunit function object
:param ast_signature: Hash/identifer string for workunit module against AST
:param space: the ExecutionSpace for which the bindings are generated
:param updated_decorator: Object for decorator specifier
:param updated_types: Object with type inference information
:param types_signature: Hash/identifer string for workunit module against data types
:param restrict_views: a set of view names that do not alias any other views
:returns: the members the functor is containing
"""

module_setup: ModuleSetup = self.get_module_setup(
workunit, space, types_signature, restrict_signature
workunit,
space,
ast_signature,
types_signature=types_signature,
restrict_signature=restrict_signature,
)
members: PyKokkosMembers = self.compiler.compile_object(
module_setup,
Expand Down Expand Up @@ -342,9 +350,17 @@ def execute_workunit(
}
restrict_views, restrict_signature = get_restrict_views(view_dict)

# Set ast signature
if isinstance(parser, list):
ast_signature = "".join([p.signature for p in parser])
ast_signature = hashlib.md5(ast_signature.encode()).hexdigest()
else:
ast_signature = parser.signature

execution_space: ExecutionSpace = policy.space.space
members: PyKokkosMembers = self.precompile_workunit(
workunit,
ast_signature,
execution_space,
updated_decorator,
updated_types,
Expand All @@ -355,7 +371,11 @@ def execute_workunit(
)

module_setup: ModuleSetup = self.get_module_setup(
workunit, execution_space, types_signature, restrict_signature
workunit,
execution_space,
ast_signature,
types_signature=types_signature,
restrict_signature=restrict_signature,
)
return self.execute(
workunit,
Expand Down Expand Up @@ -812,6 +832,8 @@ def get_module_setup(
self,
entity: Union[object, Callable[..., None]],
space: ExecutionSpace,
ast_signature: str,
*,
types_signature: Optional[str] = None,
restrict_signature: Optional[str] = None,
) -> ModuleSetup:
Expand All @@ -820,6 +842,7 @@ def get_module_setup(

:param entity: the workload or workunit object
:param space: the execution space
:param ast_signature: Hash/identifer string for workunit module against AST
:param types_signature: Hash/identifer string for workunit module against data types
:param restrict_signature: Hash/identifer string for views that do not alias any other views
:returns: the ModuleSetup object
Expand All @@ -830,13 +853,23 @@ def get_module_setup(
)

module_setup_id = self.get_module_setup_id(
entity, space, types_signature, restrict_signature
entity,
space,
ast_signature,
types_signature=types_signature,
restrict_signature=restrict_signature,
)

if module_setup_id in self.module_setups:
return self.module_setups[module_setup_id]

module_setup = ModuleSetup(entity, space, types_signature, restrict_signature)
module_setup = ModuleSetup(
entity,
space,
ast_signature,
types_signature=types_signature,
restricted_views=restrict_signature,
)
self.module_setups[module_setup_id] = module_setup

return module_setup
Expand All @@ -845,6 +878,8 @@ def get_module_setup_id(
self,
entity: Union[object, Callable[..., None]],
space: ExecutionSpace,
ast_signature: str,
*,
types_signature: Optional[str] = None,
restrict_signature: Optional[str] = None,
) -> Tuple:
Expand All @@ -856,6 +891,7 @@ def get_module_setup_id(

:param entity: the workload or workunit object
:param space: the execution space
:param ast_signature: Hash/identifer string for workunit module against AST
:param types_signature: optional identifier/hash string for
types of parameters against workunit module
:param restrict_signature: Hash/identifer string for views
Expand Down Expand Up @@ -890,6 +926,7 @@ def get_module_setup_id(
module_setup_id_list.append(types_signature)
if restrict_signature is not None:
module_setup_id_list.append(restrict_signature)
module_setup_id_list.append(ast_signature)

module_setup_id = tuple(module_setup_id_list)

Expand Down
98 changes: 98 additions & 0 deletions tests/test_recompile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""
pytest: PyKokkos JIT kernel recompilation test.

Tests that modifying a @pk.workunit source file triggers recompilation
and that both versions of the kernel produce correct results.

Workflow:
1. Write the "buggy" kernel (arr[i] += 2) to a temp module file.
2. Import it and run it — verify each element increased by 2.
3. Overwrite the module file with the "fixed" kernel (arr[i] += 1).
4. Invalidate Python's module cache so the new source is picked up.
5. Re-import and run — verify each element increased by 1.
"""

import sys
import textwrap
import importlib
from pathlib import Path

import numpy as np
import pykokkos as pk

BUGGY_SOURCE = textwrap.dedent(
"""\
import pykokkos as pk

@pk.workunit
def add1(i: int, arr: pk.View1D[int]):
arr[i] += 2
"""
)

CORRECT_SOURCE = textwrap.dedent(
"""\
import pykokkos as pk

@pk.workunit
def add1(i: int, arr: pk.View1D[int]):
arr[i] += 1
"""
)

MODULE_NAME = "_test_jit_kernel_add1"


def _load_fresh(path: Path):
"""Force a fresh import of the module at *path*, bypassing sys.modules."""
# Remove any previously cached version.
sys.modules.pop(MODULE_NAME, None)

spec = importlib.util.spec_from_file_location(MODULE_NAME, path)
module = importlib.util.module_from_spec(spec)
sys.modules[MODULE_NAME] = module
spec.loader.exec_module(module)
return module


# ---------- Test JIT Recompilation ----------


def test_recompilation(tmp_path):
kernel_file = tmp_path / "_test_jit_recompile.py"

# ---- Stage 1: buggy kernel
# write buggy source
kernel_file.write_text(BUGGY_SOURCE, encoding="utf-8")

# load buggy source
mod_buggy = _load_fresh(kernel_file)

# run the buggy kernel
n = 5
arr_buggy = np.zeros(n, dtype=np.int32)
pk.parallel_for(n, mod_buggy.add1, arr=arr_buggy)

# assert buggy array is correct
Comment thread
kennykos marked this conversation as resolved.
Outdated
try:
np.testing.assert_equal(arr_buggy, np.zeros(n, dtype=np.int32) + 2)
except AssertionError as e:
raise AssertionError("buggy kernel is incorrect") from e

# ---- Stage 2: correct kernel

# reload pykokkos to clear cache
importlib.reload(sys.modules["pykokkos"])

kernel_file.write_text(CORRECT_SOURCE, encoding="utf-8")
mod_correct = _load_fresh(kernel_file)

arr_correct = np.zeros(n, dtype=np.int32)
pk.parallel_for(n, mod_correct.add1, arr=arr_correct)
expected = np.zeros(n, dtype=np.int32) + 1
try:
np.testing.assert_equal(arr_correct, expected)
except AssertionError as e:
raise AssertionError(
f"kernel is incorrect\nactual: {arr_correct}\ndesired: {expected}"
) from e
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test does not check that the signature values on disk changed or that something was renamed.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept it abstract in case we want to change implementation details down the road (i.e., if we want to move away from renaming or change the directory structure).

Loading