Skip to content

Commit 9713827

Browse files
authored
Support JITFunction in preload (#8794)
1 parent 6b67b3c commit 9713827

2 files changed

Lines changed: 82 additions & 8 deletions

File tree

python/test/unit/runtime/test_cache.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -825,3 +825,62 @@ def kernel(out_ptr, FUNC: tl.constexpr) -> None:
825825
Compiling with fn_a
826826
Compiling with fn_a after modification
827827
""")
828+
829+
830+
def test_preload_higher_order_kernels(device, fresh_triton_cache) -> None:
831+
832+
@triton.jit
833+
def fn_a():
834+
return 17
835+
836+
@triton.jit
837+
def fn_b():
838+
return 31
839+
840+
@triton.jit
841+
def kernel(out_ptr, FUNC: tl.constexpr) -> None:
842+
val = FUNC()
843+
tl.store(out_ptr, val)
844+
845+
device = getattr(torch, device).current_device()
846+
847+
# get the serialized specialization data
848+
specialization_data = None
849+
850+
def cache_hook(*args, **kwargs):
851+
nonlocal specialization_data
852+
specialization_data = kwargs["compile"]["specialization_data"]
853+
854+
triton.knobs.runtime.jit_cache_hook = cache_hook
855+
output = torch.empty((), device=device, dtype=torch.int32)
856+
compiled_kernel = kernel[(1, )](output, fn_a)
857+
assert output.item() == 17
858+
hash = compiled_kernel.hash
859+
assert specialization_data is not None
860+
861+
# clear the cache
862+
shutil.rmtree(fresh_triton_cache)
863+
kernel.device_caches[device][0].clear()
864+
865+
# preload the kernel
866+
kernel_preload = kernel.preload(specialization_data)
867+
assert kernel_preload.hash == hash
868+
assert len(kernel.device_caches[device][0]) == 1
869+
870+
# we should hit the cache and not compile anything
871+
counter = 0
872+
873+
def inc_counter(*args, **kwargs):
874+
nonlocal counter
875+
counter += 1
876+
877+
triton.knobs.runtime.jit_cache_hook = inc_counter
878+
final_kernel = kernel[(1, )](output, fn_a)
879+
assert counter == 0
880+
assert len(kernel.device_caches[device][0]) == 1
881+
assert final_kernel.hash == hash
882+
883+
# different function should compile and not hit the cache
884+
kernel[(1, )](output, fn_b)
885+
assert counter == 1
886+
assert output.item() == 31

python/triton/runtime/jit.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,9 @@ def __getitem__(self, grid) -> T:
373373

374374
def serialize_specialization_data(name, signature, constants, attrs, options, key):
375375
constants = {
376-
key: str(value) if value.__class__.__name__ == "dtype" else
377-
{"constexpr": value.value} if value.__class__.__name__ == "constexpr" else value
376+
key: str(value) if value.__class__.__name__ == "dtype" else {"constexpr": value.value}
377+
if value.__class__.__name__ == "constexpr" else {"jit_function": f"{value.module}:{value.fn.__qualname__}"}
378+
if value.__class__.__name__ == "JITFunction" else value
378379
for key, value in constants.items()
379380
}
380381

@@ -560,6 +561,9 @@ def _get_src(self):
560561
src = property(fget=_get_src, fset=_set_src)
561562

562563

564+
_triton_jit_function_registry = {}
565+
566+
563567
@dataclass
564568
class JitFunctionInfo:
565569
module: ModuleType
@@ -771,6 +775,8 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o
771775
self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
772776
self._repr = repr
773777
self.launch_metadata = launch_metadata
778+
# Register for simple deserialization of JITFunction constants
779+
_triton_jit_function_registry[f"{self.module}:{self.fn.__qualname__}"] = self
774780

775781
self.params = []
776782
for i, param in enumerate(self.signature.parameters.values()):
@@ -805,12 +811,21 @@ def preload(self, specialization_data):
805811
f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self._fn_name}")
806812
constant_keys = map(tuple, deserialized_obj['constant_keys'])
807813
constant_vals = deserialized_obj['constant_vals']
808-
constexprs = {
809-
key:
810-
tl.dtype(value) if tl.dtype.is_dtype(value) else
811-
tl.constexpr(value['constexpr']) if isinstance(value, dict) and 'constexpr' in value else value
812-
for key, value in zip(constant_keys, constant_vals)
813-
}
814+
815+
def _decode_constant(value):
816+
if tl.dtype.is_dtype(value):
817+
return tl.dtype(value)
818+
if isinstance(value, dict):
819+
if 'constexpr' in value:
820+
return tl.constexpr(value['constexpr'])
821+
if 'jit_function' in value:
822+
jf_key = value['jit_function']
823+
if jf_key in _triton_jit_function_registry:
824+
return _triton_jit_function_registry[jf_key]
825+
raise RuntimeError(f"Unable to resolve JITFunction {jf_key} for preload")
826+
return value
827+
828+
constexprs = {key: _decode_constant(value) for key, value in zip(constant_keys, constant_vals)}
814829
attrs_keys = map(tuple, deserialized_obj['attrs_keys'])
815830
attrs_vals = deserialized_obj['attrs_vals']
816831
attrs = dict(zip(attrs_keys, attrs_vals))

0 commit comments

Comments
 (0)