Skip to content

Commit 80ced7d

Browse files
generatedunixname89002005232357meta-codesync[bot]
authored andcommitted
Revert D100485231 (#1247)
Summary: Pull Request resolved: #1247 This diff reverts D100485231 T264456316 Depends on D100485231 Reviewed By: tissue3 Differential Revision: D100527284 fbshipit-source-id: 44c77e0a1e2d1358b559e66e01ae25dd96d37bee
1 parent dcb9837 commit 80ced7d

1 file changed

Lines changed: 5 additions & 37 deletions

File tree

python/triton/runtime/jit.py

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,6 @@
2424
TRITON_MODULE = "triton.language"
2525
GLUON_MODULE = "triton.experimental.gluon.language"
2626

27-
# Global monotonic counter incremented on every _unsafe_update_src() call.
28-
# Used by the Layer 1 fast path to detect source modifications on JITCallable
29-
# objects that are passed as arguments (identity check alone is insufficient
30-
# because the same Python object can have its source mutated in-place).
31-
_src_update_version = 0
32-
3327
T = TypeVar("T")
3428

3529
# -----------------------------------------------------------------------------
@@ -552,8 +546,6 @@ def _unsafe_update_src(self, new_src):
552546
553547
Note that it is the callers responsibility to make sure any triton functions that call this function have the `.hash` value reset to None.
554548
"""
555-
global _src_update_version
556-
_src_update_version += 1
557549
self.hash = None
558550
self._src = new_src
559551

@@ -611,23 +603,6 @@ def convert_to_tuple_if_list(item):
611603
return tuple(item)
612604

613605

614-
class _DeviceCaches(defaultdict):
615-
"""A defaultdict that also invalidates the Layer 1 fast-path cache
616-
(``_last_call``) whenever the in-memory kernel cache is cleared.
617-
Without this, ``device_caches.clear()`` would wipe Layer 2 but
618-
leave a stale Layer 1 entry, causing the fast path to return a
619-
kernel that is no longer in the device cache."""
620-
621-
def __init__(self, jit_function, default_factory):
622-
super().__init__(default_factory)
623-
self._jit_function = jit_function
624-
625-
def clear(self):
626-
super().clear()
627-
self._jit_function._last_call = None
628-
self._jit_function._last_kwargs = {}
629-
630-
631606
class JITFunction(JITCallable, KernelInterface[T]):
632607

633608
def is_gluon(self):
@@ -806,8 +781,7 @@ def run(self, *args, grid, warmup, **kwargs):
806781
# This is just N pointer comparisons with zero attribute access.
807782
if not warmup and not self.pre_run_hooks:
808783
last = self._last_call
809-
if last is not None and last[0] is device and last[4] == _src_update_version and last[
810-
5] == knobs.compilation.instrumentation_mode:
784+
if last is not None and last[0] is device:
811785
last_args = last[1]
812786
if len(args) == len(last_args):
813787
identical = True
@@ -948,13 +922,8 @@ def run(self, *args, grid, warmup, **kwargs):
948922
# Populate fast-path caches for future calls.
949923
# Store both raw args (for identity check) and bound_args values
950924
# (for launching — includes default parameter values).
951-
# Only populate when the fast path guard would allow reuse —
952-
# if pre_run_hooks are active, the compiled kernel may depend on
953-
# hook-controlled state that the fast path doesn't check.
954-
if not self.pre_run_hooks:
955-
self._last_call = (device, args, kernel, tuple(bound_args.values()), _src_update_version,
956-
knobs.compilation.instrumentation_mode)
957-
self._last_kwargs = _user_kwargs
925+
self._last_call = (device, args, kernel, tuple(bound_args.values()))
926+
self._last_kwargs = _user_kwargs
958927
if fast_key is not None:
959928
self._run_cache[fast_key] = kernel
960929

@@ -983,11 +952,10 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o
983952
self.params.append(KernelParam(i, param, dns, dns_oa))
984953

985954
# cache of just-in-time compiled kernels
986-
self.device_caches = _DeviceCaches(self, self.create_binder)
955+
self.device_caches = defaultdict(self.create_binder)
987956

988957
# Last-call cache for identity-based fast path (Layer 1).
989-
# Stores (device, args, kernel, bound_vals, src_update_version, instrumentation_mode)
990-
# from the previous successful launch.
958+
# Stores (device, args, kernel, bound_vals) from the previous successful launch.
991959
self._last_call = None
992960
self._last_kwargs = {}
993961
# Signature-based fast-path cache (Layer 2).

0 commit comments

Comments
 (0)