-
Notifications
You must be signed in to change notification settings - Fork 23
Recompile stale BASE_DIR
#415
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
e223c92
c979fb9
5c2e4e6
9f79823
d87b2c6
4acad26
7548ed5
5c5e734
ee7e739
19b56c0
a810722
a3a23c8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -82,13 +82,16 @@ def __init__( | |
| self, | ||
| entity: Union[Callable[..., None], type, List[Callable[..., None]]], | ||
| space: ExecutionSpace, | ||
| ast_signature: str, | ||
| *, | ||
| types_signature: Optional[str] = None, | ||
| restricted_views: Optional[Set[str]] = None, | ||
| ): | ||
| """ | ||
| ModuleSetup constructor | ||
|
|
||
| :param entity: the functor/workunit/workload or list of workunits for fusion | ||
| :param ast_signature: hash/string to identify workunit signature against AST | ||
| :param types_signature: hash/string to identify workunit signature against types | ||
| :param restricted_views: a set of view names that do not alias any other views | ||
| """ | ||
|
|
@@ -103,6 +106,7 @@ def __init__( | |
| self.metadata = [get_metadata(entity)] | ||
|
|
||
| self.space: ExecutionSpace = space | ||
| self.ast_signature = ast_signature | ||
| self.types_signature = types_signature | ||
| self.restrict_signature: Optional[str] = None | ||
| if restricted_views is not None: | ||
|
|
@@ -126,7 +130,12 @@ def __init__( | |
| self.main: Path = self.get_main_path() | ||
|
|
||
| self.output_dir: Optional[Path] = self.get_output_dir( | ||
| self.main, self.metadata, space, types_signature, self.restrict_signature | ||
| self.main, | ||
| self.metadata, | ||
| space, | ||
| ast_signature, | ||
| types_signature=types_signature, | ||
| restrict_signature=self.restrict_signature, | ||
| ) | ||
| self.gpu_module_files: List[str] = [] | ||
| if km.is_multi_gpu_enabled(): | ||
|
|
@@ -154,6 +163,8 @@ def get_output_dir( | |
| main: Path, | ||
| metadata: List[EntityMetadata], | ||
| space: ExecutionSpace, | ||
| ast_signature, | ||
| *, | ||
| types_signature: Optional[str] = None, | ||
| restrict_signature: Optional[str] = None, | ||
| ) -> Optional[Path]: | ||
|
|
@@ -163,6 +174,7 @@ def get_output_dir( | |
| :param main: the path to the main file in the current PyKokkos application | ||
| :param metadata: the metadata of the entity or fused entities being compiled | ||
| :param space: the execution space to compile for | ||
| :param ast_signature: hash/string to identify workunit signature against AST | ||
| :param types_signature: optional identifier/hash string for types of parameters | ||
| :param restrict_signature: optional identifier/hash string from the views that do not alias any other views | ||
| :returns: the path to the output directory for a specific execution space | ||
|
|
@@ -180,6 +192,8 @@ def get_output_dir( | |
| out_dir = out_dir / f"types_{types_signature}" | ||
| if restrict_signature is not None: | ||
| out_dir = out_dir / f"restrict_{restrict_signature}" | ||
| if ast_signature is not None: | ||
| out_dir = out_dir / f"AST_{ast_signature}" | ||
|
Comment on lines
+208
to
+209
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm so we have a hierarchy which has the ast signature as the highest priority ... but is there cases where the ast signature would not trigger but the others would?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, consider the scenario:
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but doesn't the AST change if we change types from float to double? Even unmanaged Views change their type, I am surprised that this would not trigger an AST hash fail
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, the AST will change, but we want to keep both float and double types cached in this scenario to avoid excessive recompilation (e.g., an adaptive time-stepper may use a float and double execution to get error estimates)
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. Could you add a comment explaining this intent? Future selves will thank us for it
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Addressed in a3a23c8 |
||
|
|
||
| out_dir = out_dir / space.value | ||
|
|
||
|
|
@@ -258,7 +272,8 @@ def is_compiled(self) -> bool: | |
| self.main, | ||
| self.metadata, | ||
| self.space, | ||
| self.types_signature, | ||
| self.restrict_signature, | ||
| self.ast_signature, | ||
| types_signature=self.types_signature, | ||
| restrict_signature=self.restrict_signature, | ||
| ) | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| import ast | ||
| import hashlib | ||
| from dataclasses import dataclass | ||
| from enum import Enum, auto | ||
| from typing import Callable, Dict, List, Optional, Tuple, Union | ||
|
|
@@ -79,6 +80,10 @@ def __init__(self, path: Optional[str], pk_import: Optional[str] = None): | |
| self.functors: Dict[str, PyKokkosEntity] = {} | ||
| self.workunits: Dict[str, PyKokkosEntity] = {} | ||
|
|
||
| # get parser signature | ||
| signature = ast.dump(self.tree) | ||
| self.signature = hashlib.md5(signature.encode()).hexdigest() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how expensive is this? I assume our ast's for kernels are small that this should not be an issue?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't do a performance test. This block is only done when the workunit is first seen by pykokkos, so it is a startup cost that will not add overhead during repeated execution. In most cases the ast's should be relatively small; e.g., in my Ewald code, the slowest time for this block of code is |
||
|
|
||
| def get_import(self) -> str: | ||
| """ | ||
| Get the pykokkos import identifier | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| """ | ||
| pytest: PyKokkos JIT kernel recompilation test. | ||
|
|
||
| Tests that modifying a @pk.workunit source file triggers recompilation | ||
| and that both versions of the kernel produce correct results. | ||
|
|
||
| Workflow: | ||
| 1. Write the "buggy" kernel (arr[i] += 2) to a temp module file. | ||
| 2. Import it and run it — verify each element increased by 2. | ||
| 3. Overwrite the module file with the "fixed" kernel (arr[i] += 1). | ||
| 4. Invalidate Python's module cache so the new source is picked up. | ||
| 5. Re-import and run — verify each element increased by 1. | ||
| """ | ||
|
|
||
| import sys | ||
| import textwrap | ||
| import importlib | ||
| from pathlib import Path | ||
|
|
||
| import numpy as np | ||
| import pykokkos as pk | ||
|
|
||
| BUGGY_SOURCE = textwrap.dedent( | ||
| """\ | ||
| import pykokkos as pk | ||
|
|
||
| @pk.workunit | ||
| def add1(i: int, arr: pk.View1D[int]): | ||
| arr[i] += 2 | ||
| """ | ||
| ) | ||
|
|
||
| CORRECT_SOURCE = textwrap.dedent( | ||
| """\ | ||
| import pykokkos as pk | ||
|
|
||
| @pk.workunit | ||
| def add1(i: int, arr: pk.View1D[int]): | ||
| arr[i] += 1 | ||
| """ | ||
| ) | ||
|
|
||
| MODULE_NAME = "_test_jit_kernel_add1" | ||
|
|
||
|
|
||
| def _load_fresh(path: Path): | ||
| """Force a fresh import of the module at *path*, bypassing sys.modules.""" | ||
| # Remove any previously cached version. | ||
| sys.modules.pop(MODULE_NAME, None) | ||
|
|
||
| spec = importlib.util.spec_from_file_location(MODULE_NAME, path) | ||
| module = importlib.util.module_from_spec(spec) | ||
| sys.modules[MODULE_NAME] = module | ||
| spec.loader.exec_module(module) | ||
| return module | ||
|
|
||
|
|
||
| # ---------- Test JIT Recompilation ---------- | ||
|
|
||
|
|
||
| def test_recompilation(tmp_path): | ||
| kernel_file = tmp_path / "_test_jit_recompile.py" | ||
|
|
||
| # ---- Stage 1: buggy kernel | ||
| # write buggy source | ||
| kernel_file.write_text(BUGGY_SOURCE, encoding="utf-8") | ||
|
|
||
| # load buggy source | ||
| mod_buggy = _load_fresh(kernel_file) | ||
|
|
||
| # run the buggy kernel | ||
| n = 5 | ||
| arr_buggy = np.zeros(n, dtype=np.int32) | ||
| pk.parallel_for(n, mod_buggy.add1, arr=arr_buggy) | ||
|
|
||
| # assert buggy array is correct | ||
|
kennykos marked this conversation as resolved.
Outdated
|
||
| try: | ||
| np.testing.assert_equal(arr_buggy, np.zeros(n, dtype=np.int32) + 2) | ||
| except AssertionError as e: | ||
| raise AssertionError("buggy kernel is incorrect") from e | ||
|
|
||
| # ---- Stage 2: correct kernel | ||
|
|
||
| # reload pykokkos to clear cache | ||
| importlib.reload(sys.modules["pykokkos"]) | ||
|
|
||
| kernel_file.write_text(CORRECT_SOURCE, encoding="utf-8") | ||
| mod_correct = _load_fresh(kernel_file) | ||
|
|
||
| arr_correct = np.zeros(n, dtype=np.int32) | ||
| pk.parallel_for(n, mod_correct.add1, arr=arr_correct) | ||
| expected = np.zeros(n, dtype=np.int32) + 1 | ||
| try: | ||
| np.testing.assert_equal(arr_correct, expected) | ||
| except AssertionError as e: | ||
| raise AssertionError( | ||
| f"kernel is incorrect\nactual: {arr_correct}\ndesired: {expected}" | ||
| ) from e | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this test does not check that values of signature on disk changed or something was renamed
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I kept it abstract in case we want to change in implementation details down the road (i.e., if we want to move away from renaming or change the directory structure). |
||
Uh oh!
There was an error while loading. Please reload this page.