Skip to content

Commit 511eb22

Browse files
Merge pull request #371 from Arech8:arech_update_to_jax_main
PiperOrigin-RevId: 874562305
2 parents 048c0a5 + 4d686a0 commit 511eb22

3 files changed

Lines changed: 50 additions & 6 deletions

File tree

jax_triton/triton_lib.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ def get_triton_type(obj: Any) -> str:
118118
return f"*{_JAX_TO_TRITON_TYPE_MAP[obj.dtype]}"
119119
if isinstance(obj, tl.constexpr):
120120
obj = obj.value
121+
if isinstance(obj, bool): # True == isinstance(True, int) !!!
122+
return "B"
121123
if isinstance(obj, int):
122124
if -(2**31) <= obj < 2**31:
123125
return "i32"
@@ -133,8 +135,6 @@ def get_triton_type(obj: Any) -> str:
133135
return "fp64"
134136
if isinstance(obj, np.float32):
135137
return "fp32"
136-
if isinstance(obj, bool):
137-
return "B"
138138
if isinstance(obj, str):
139139
return "str"
140140
raise NotImplementedError(
@@ -167,7 +167,43 @@ def get_cuda_backend(device, compute_capability):
167167
return backend
168168

169169

170+
_IS_HIPBackend_PATCHED = False


def _patch_hip_backend():
  """Works around an unconditional ``import torch`` in Triton's AMD backend.

  Upstream Triton (commit 37ff43c5) added ``HIPBackend.is_within_2gb`` whose
  body imports torch unconditionally, so the AMD compilation path crashes
  whenever torch is not installed.  This probes the method once and, if the
  probe raises ImportError, swaps in a torch-free re-implementation.

  Remove this workaround once the pinned triton wheel version includes the
  upstream fix.
  """
  global _IS_HIPBackend_PATCHED
  if _IS_HIPBackend_PATCHED:
    return
  # Mark as done up front so the (possibly raising) probe runs only once.
  _IS_HIPBackend_PATCHED = True

  if not hasattr(hb.HIPBackend, "is_within_2gb"):
    return
  try:
    hb.HIPBackend.is_within_2gb(1)
    # Probe succeeded: torch is importable or the upstream code was fixed,
    # so there is nothing to patch.
  except ImportError:
    # Torch is absent.  It is very unlikely a user would extend the package
    # search path to make ``import torch`` succeed before the real call to
    # is_within_2gb(), so the replacement assumes torch stays unavailable.
    def fixed_is_within_2gb(arg):
      MAX_INT_32 = 2**31 - 1
      if hasattr(arg, "ptr_range"):
        return arg.ptr_range() <= MAX_INT_32
      return False

    hb.HIPBackend.is_within_2gb = fixed_is_within_2gb
201+
202+
170203
def get_hip_backend(device, compute_capability):
204+
# TODO(Arech): remove _patch_hip_backend() once Triton releases a fix
205+
_patch_hip_backend()
206+
171207
arch = triton_kernel_call_lib.get_arch_details(device)
172208
arch = arch.split(":")[0]
173209
target = hb.GPUTarget("hip", arch, 64)
@@ -626,6 +662,14 @@ def prune_configs(configs, named_args, **kwargs):
626662
)
627663
)
628664
elif i not in equal_to_1:
665+
# Convert TypedInt/TypedFloat subclasses to plain Python types,
666+
# as nanobind's strict-mode integer caster rejects subclasses.
667+
if isinstance(arg, bool):
668+
arg = bool(arg)
669+
elif isinstance(arg, int):
670+
arg = int(arg)
671+
elif isinstance(arg, float):
672+
arg = float(arg)
629673
kernel_params.append(
630674
triton_kernel_call_lib.create_scalar_parameter(arg, dtype)
631675
)

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ readme = "README.md"
66
requires-python = ">=3.11"
77
dependencies = [
88
"absl-py>=1.4.0",
9-
"jax>=0.6.0",
9+
"jax>=0.8.2",
1010
"triton>=3.6",
1111
]
1212

@@ -17,10 +17,10 @@ tests = [
1717

1818

1919
[build-system]
20-
requires = ["setuptools", "setuptools-scm"]
20+
requires = ["setuptools"]
2121
build-backend = "setuptools.build_meta"
2222

23-
[tools.setuptools]
23+
[tool.setuptools]
2424
packages = ["jax_triton"]
2525

2626
[tool.setuptools.dynamic]

tests/triton_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import numpy as np
2121
import triton
2222
import triton.language as tl
23-
from triton.language.extra.cuda import libdevice
23+
from triton.language.extra import libdevice
2424

2525

2626
@triton.jit

0 commit comments

Comments
 (0)