Skip to content

Commit bd71a86

Browse files
committed
Defuse the bomb in Triton
1 parent a170ac0 commit bd71a86

1 file changed

Lines changed: 56 additions & 11 deletions

File tree

jax_triton/__init__.py

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
"""Library for JAX-Triton integrations."""
1616

1717
__all__ = [
18-
"utils",
19-
"triton_call",
20-
"cdiv",
21-
"next_power_of_2",
22-
"strides_from_shape",
23-
"__version__",
24-
"__version_info__",
18+
"utils",
19+
"triton_call",
20+
"cdiv",
21+
"next_power_of_2",
22+
"strides_from_shape",
23+
"__version__",
24+
"__version_info__",
2525
]
2626

2727
from jax._src.lib import gpu_triton
@@ -38,10 +38,55 @@
3838
get_serialized_metadata = gpu_triton.get_serialized_metadata
3939
except AttributeError:
4040
raise ImportError(
41-
"jax-triton requires JAX to be installed with GPU support. The "
42-
"installation page on the JAX documentation website includes "
43-
"instructions for installing a supported version:\n"
44-
"https://jax.readthedocs.io/en/latest/installation.html"
41+
"jax-triton requires JAX to be installed with GPU support. The "
42+
"installation page on the JAX documentation website includes "
43+
"instructions for installing a supported version:\n"
44+
"https://jax.readthedocs.io/en/latest/installation.html"
4545
)
4646
else:
4747
del gpu_triton # Not part of the API.
48+
49+
###################### triton.backends.amd.compiler.HIPBackend patch begin ##############
50+
# this is patch fixing a bomb planted into Triton's AMD-specific compilation path by
51+
# https://github.com/triton-lang/triton/commit/37ff43c5efd6e1b84c00a599ba070a501181e832#diff-33c9a103282c05c9d9d213b94450ae7481b6db8c3c6d810f54f175b4735a3c72
52+
# In short: there's an unconditional and totally unnecessary "import torch" directive crashing
53+
# the code when torch isn't installed.
54+
#
55+
# Remove the patch once triton wheel package version is pinned to >= triton version with the fix.
56+
try:
57+
from triton.backends.amd.compiler import HIPBackend
58+
# if we're thrown at above, that's fine, Triton not available and it's handled later
59+
60+
# now testing if the original implementation fails:
61+
try:
62+
# passing an int should cause the exception
63+
HIPBackend.is_within_2gb(1)
64+
# if we're here, either the torch is installed and it's not a problem, or the code was fixed
65+
# and it's not a problem anymore too
66+
except ImportError:
67+
# redefining poisoned implementation
68+
def fixed_is_within_2gb(arg):
69+
# leaving the check as theoretically a user could update package discovery paths and the
70+
# module becomes importable.
71+
HAS_TORCH = False
72+
try:
73+
import torch # pytype: disable=import-error
74+
75+
HAS_TORCH = True
76+
except ImportError:
77+
pass
78+
79+
MAX_INT_32 = 2**31 - 1
80+
if hasattr(arg, "ptr_range"):
81+
return arg.ptr_range() <= MAX_INT_32
82+
if (
83+
HAS_TORCH and isinstance(arg, torch.Tensor) and hasattr(arg, "untyped_storage")
84+
):
85+
return arg.untyped_storage().size() <= MAX_INT_32
86+
return False
87+
88+
HIPBackend.is_within_2gb = fixed_is_within_2gb
89+
90+
except ImportError:
91+
pass
92+
###################### triton.backends.amd.compiler.HIPBackend patch end ##############

0 commit comments

Comments
 (0)