|
24 | 24 | TRITON_MODULE = "triton.language" |
25 | 25 | GLUON_MODULE = "triton.experimental.gluon.language" |
26 | 26 |
|
27 | | -# Global monotonic counter incremented on every _unsafe_update_src() call. |
28 | | -# Used by the Layer 1 fast path to detect source modifications on JITCallable |
29 | | -# objects that are passed as arguments (identity check alone is insufficient |
30 | | -# because the same Python object can have its source mutated in-place). |
31 | | -_src_update_version = 0 |
32 | | - |
33 | 27 | T = TypeVar("T") |
34 | 28 |
|
35 | 29 | # ----------------------------------------------------------------------------- |
@@ -552,8 +546,6 @@ def _unsafe_update_src(self, new_src): |
552 | 546 |
|
553 | 547 | Note that it is the callers responsibility to make sure any triton functions that call this function have the `.hash` value reset to None. |
554 | 548 | """ |
555 | | - global _src_update_version |
556 | | - _src_update_version += 1 |
557 | 549 | self.hash = None |
558 | 550 | self._src = new_src |
559 | 551 |
|
@@ -611,23 +603,6 @@ def convert_to_tuple_if_list(item): |
611 | 603 | return tuple(item) |
612 | 604 |
|
613 | 605 |
|
614 | | -class _DeviceCaches(defaultdict): |
615 | | - """A defaultdict that also invalidates the Layer 1 fast-path cache |
616 | | - (``_last_call``) whenever the in-memory kernel cache is cleared. |
617 | | - Without this, ``device_caches.clear()`` would wipe Layer 2 but |
618 | | - leave a stale Layer 1 entry, causing the fast path to return a |
619 | | - kernel that is no longer in the device cache.""" |
620 | | - |
621 | | - def __init__(self, jit_function, default_factory): |
622 | | - super().__init__(default_factory) |
623 | | - self._jit_function = jit_function |
624 | | - |
625 | | - def clear(self): |
626 | | - super().clear() |
627 | | - self._jit_function._last_call = None |
628 | | - self._jit_function._last_kwargs = {} |
629 | | - |
630 | | - |
631 | 606 | class JITFunction(JITCallable, KernelInterface[T]): |
632 | 607 |
|
633 | 608 | def is_gluon(self): |
@@ -806,8 +781,7 @@ def run(self, *args, grid, warmup, **kwargs): |
806 | 781 | # This is just N pointer comparisons with zero attribute access. |
807 | 782 | if not warmup and not self.pre_run_hooks: |
808 | 783 | last = self._last_call |
809 | | - if last is not None and last[0] is device and last[4] == _src_update_version and last[ |
810 | | - 5] == knobs.compilation.instrumentation_mode: |
| 784 | + if last is not None and last[0] is device: |
811 | 785 | last_args = last[1] |
812 | 786 | if len(args) == len(last_args): |
813 | 787 | identical = True |
@@ -948,13 +922,8 @@ def run(self, *args, grid, warmup, **kwargs): |
948 | 922 | # Populate fast-path caches for future calls. |
949 | 923 | # Store both raw args (for identity check) and bound_args values |
950 | 924 | # (for launching — includes default parameter values). |
951 | | - # Only populate when the fast path guard would allow reuse — |
952 | | - # if pre_run_hooks are active, the compiled kernel may depend on |
953 | | - # hook-controlled state that the fast path doesn't check. |
954 | | - if not self.pre_run_hooks: |
955 | | - self._last_call = (device, args, kernel, tuple(bound_args.values()), _src_update_version, |
956 | | - knobs.compilation.instrumentation_mode) |
957 | | - self._last_kwargs = _user_kwargs |
| 925 | + self._last_call = (device, args, kernel, tuple(bound_args.values())) |
| 926 | + self._last_kwargs = _user_kwargs |
958 | 927 | if fast_key is not None: |
959 | 928 | self._run_cache[fast_key] = kernel |
960 | 929 |
|
@@ -983,11 +952,10 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o |
983 | 952 | self.params.append(KernelParam(i, param, dns, dns_oa)) |
984 | 953 |
|
985 | 954 | # cache of just-in-time compiled kernels |
986 | | - self.device_caches = _DeviceCaches(self, self.create_binder) |
| 955 | + self.device_caches = defaultdict(self.create_binder) |
987 | 956 |
|
988 | 957 | # Last-call cache for identity-based fast path (Layer 1). |
989 | | - # Stores (device, args, kernel, bound_vals, src_update_version, instrumentation_mode) |
990 | | - # from the previous successful launch. |
| 958 | + # Stores (device, args, kernel, bound_vals) from the previous successful launch. |
991 | 959 | self._last_call = None |
992 | 960 | self._last_kwargs = {} |
993 | 961 | # Signature-based fast-path cache (Layer 2). |
|
0 commit comments