@@ -118,6 +118,8 @@ def get_triton_type(obj: Any) -> str:
118118 return f"*{ _JAX_TO_TRITON_TYPE_MAP [obj .dtype ]} "
119119 if isinstance (obj , tl .constexpr ):
120120 obj = obj .value
121+ if isinstance (obj , bool ): # True == isinstance(True, int) !!!
122+ return "B"
121123 if isinstance (obj , int ):
122124 if - (2 ** 31 ) <= obj < 2 ** 31 :
123125 return "i32"
@@ -133,8 +135,6 @@ def get_triton_type(obj: Any) -> str:
     return "fp64"
   if isinstance(obj, np.float32):
     return "fp32"
-  if isinstance(obj, bool):
-    return "B"
   if isinstance(obj, str):
     return "str"
   raise NotImplementedError(
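
The relocation matters because `bool` is a subclass of `int` in Python: with the old ordering, `True` matched `isinstance(obj, int)` first and was typed as "i32" instead of "B". A minimal standalone sketch of the fixed ordering (illustration only, not part of the commit; `type_name` is a hypothetical stand-in for `get_triton_type`):

    assert issubclass(bool, int) and isinstance(True, int)

    def type_name(obj):
      if isinstance(obj, bool):  # must run before the int check
        return "B"
      if isinstance(obj, int):
        return "i32"
      raise NotImplementedError(type(obj))

    assert type_name(True) == "B" and type_name(3) == "i32"
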
@@ -167,7 +167,43 @@ def get_cuda_backend(device, compute_capability):
     return backend
 
 
+_IS_HIPBackend_PATCHED = False
+def _patch_hip_backend():
+  """
+  Defuses a bomb planted in Triton's AMD-specific compilation path by
+  https://github.com/triton-lang/triton/commit/37ff43c5efd6e1b84c00a599ba070a501181e832#diff-33c9a103282c05c9d9d213b94450ae7481b6db8c3c6d810f54f175b4735a3c72
+  In short: an unconditional and entirely unnecessary `import torch` there
+  crashes the code when torch isn't installed.
+
+  Remove this patch once the pinned triton wheel version is >= the triton version containing the fix.
+  """
+  global _IS_HIPBackend_PATCHED
+  if _IS_HIPBackend_PATCHED:
+    return
+  _IS_HIPBackend_PATCHED = True
+
+  if not hasattr(hb.HIPBackend, "is_within_2gb"):
+    return
+  try:
+    hb.HIPBackend.is_within_2gb(1)
+    # If we get here, either torch is installed or the upstream code was fixed.
+  except ImportError:
+    # Redefine the poisoned implementation. It's very unlikely a user would
+    # change Python's package discovery paths before the real call to
+    # is_within_2gb() to make `import torch` succeed, so assume torch is absent.
+    def fixed_is_within_2gb(arg):
+      MAX_INT_32 = 2**31 - 1
+      if hasattr(arg, "ptr_range"):
+        return arg.ptr_range() <= MAX_INT_32
+      return False
+
+    hb.HIPBackend.is_within_2gb = fixed_is_within_2gb
+
+
 def get_hip_backend(device, compute_capability):
+  # TODO(Arech): remove _patch_hip_backend() once Triton releases a fix
+  _patch_hip_backend()
+
   arch = triton_kernel_call_lib.get_arch_details(device)
   arch = arch.split(":")[0]
   target = hb.GPUTarget("hip", arch, 64)
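
The patch uses a probe-and-replace pattern: call the suspect method once, and only when it raises ImportError swap in a safe reimplementation, so Triton versions that already contain the fix are left untouched. A minimal sketch under assumed names (`Backend` stands in for `hb.HIPBackend`, and the broken method body only approximates the upstream code):

    class Backend:
      @staticmethod
      def is_within_2gb(arg):
        import torch  # unconditional import: raises ImportError when torch is absent
        if hasattr(arg, "ptr_range"):
          return arg.ptr_range() <= 2**31 - 1
        return False

    def patch_backend():
      try:
        Backend.is_within_2gb(1)  # probe with a harmless int argument
      except ImportError:
        def fixed(arg):
          if hasattr(arg, "ptr_range"):
            return arg.ptr_range() <= 2**31 - 1
          return False
        Backend.is_within_2gb = fixed  # mirrors the assignment in the patch above

    patch_backend()
    assert Backend.is_within_2gb(1) is False  # works with or without torch
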
@@ -626,6 +662,14 @@ def prune_configs(configs, named_args, **kwargs):
             )
         )
       elif i not in equal_to_1:
+        # Convert TypedInt/TypedFloat subclasses to plain Python types,
+        # as nanobind's strict-mode integer caster rejects subclasses.
+        if isinstance(arg, bool):
+          arg = bool(arg)
+        elif isinstance(arg, int):
+          arg = int(arg)
+        elif isinstance(arg, float):
+          arg = float(arg)
         kernel_params.append(
             triton_kernel_call_lib.create_scalar_parameter(arg, dtype)
         )
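
The bool branch must come first here for the same reason as in get_triton_type: `isinstance(True, int)` is true, so `int(arg)` would silently coerce `True` to `1` and change the parameter's type. A small sketch with a hypothetical `TypedInt` subclass (the TypedInt/TypedFloat names come from the comment above; their definitions aren't shown in this diff):

    class TypedInt(int):  # hypothetical stand-in for the rejected subclass
      pass

    def to_plain(arg):
      if isinstance(arg, bool):  # checked first so True isn't coerced to 1
        return bool(arg)
      if isinstance(arg, int):
        return int(arg)  # TypedInt(5) -> plain int 5
      if isinstance(arg, float):
        return float(arg)
      return arg

    assert type(to_plain(TypedInt(5))) is int
    assert type(to_plain(True)) is bool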