@@ ... @@
 """Library for JAX-Triton integrations."""

 __all__ = [
-    "utils",
-    "triton_call",
-    "cdiv",
-    "next_power_of_2",
-    "strides_from_shape",
-    "__version__",
-    "__version_info__",
+    "utils",
+    "triton_call",
+    "cdiv",
+    "next_power_of_2",
+    "strides_from_shape",
+    "__version__",
+    "__version_info__",
 ]

 from jax._src.lib import gpu_triton
@@ ... @@
   get_serialized_metadata = gpu_triton.get_serialized_metadata
 except AttributeError:
   raise ImportError(
-      "jax-triton requires JAX to be installed with GPU support. The "
-      "installation page on the JAX documentation website includes "
-      "instructions for installing a supported version:\n"
-      "https://jax.readthedocs.io/en/latest/installation.html"
+      "jax-triton requires JAX to be installed with GPU support. The "
+      "installation page on the JAX documentation website includes "
+      "instructions for installing a supported version:\n"
+      "https://jax.readthedocs.io/en/latest/installation.html"
   )
 else:
   del gpu_triton  # Not part of the API.
+
+###################### triton.backends.amd.compiler.HIPBackend patch begin ##############
+# This patch defuses a bomb planted in Triton's AMD-specific compilation path by
+# https://github.com/triton-lang/triton/commit/37ff43c5efd6e1b84c00a599ba070a501181e832#diff-33c9a103282c05c9d9d213b94450ae7481b6db8c3c6d810f54f175b4735a3c72
+# In short: an unconditional and entirely unnecessary `import torch` crashes the code
+# when torch is not installed.
+#
+# Remove this patch once the triton wheel is pinned to a version that includes the fix.
+try:
+  from triton.backends.amd.compiler import HIPBackend
+  # If the import above throws, that's fine: Triton isn't available, and that case is
+  # handled later.
+
+  # Now test whether the original implementation fails:
+  try:
+    # In the broken version any call raises ImportError when torch is missing, because
+    # the function imports torch unconditionally; an int is the simplest probe argument.
+    HIPBackend.is_within_2gb(1)
+    # If we get here, either torch is installed or the code was fixed upstream; either
+    # way there is nothing to patch.
+  except ImportError:
+    # Redefine the poisoned implementation.
+    def fixed_is_within_2gb(arg):
+      # Keep the torch check: in theory a user could extend the package discovery paths
+      # so that the module becomes importable after all.
+      HAS_TORCH = False
+      try:
+        import torch  # pytype: disable=import-error
+
+        HAS_TORCH = True
+      except ImportError:
+        pass
+
+      MAX_INT_32 = 2**31 - 1
+      if hasattr(arg, "ptr_range"):
+        return arg.ptr_range() <= MAX_INT_32
+      if HAS_TORCH and isinstance(arg, torch.Tensor) and hasattr(arg, "untyped_storage"):
+        return arg.untyped_storage().size() <= MAX_INT_32
+      return False
+
+    # Install as a staticmethod so the replacement behaves the same whether it is
+    # called via the class or via an instance.
+    HIPBackend.is_within_2gb = staticmethod(fixed_is_within_2gb)
+
+except ImportError:
+  pass
+###################### triton.backends.amd.compiler.HIPBackend patch end ##############
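For anyone who wants to sanity-check the patch locally, a minimal sketch along these lines exercises both branches of the replacement function. `DummyArg` and the probe sizes are hypothetical, and it assumes a Triton build whose AMD backend is importable; this is an illustration, not part of the change:

```python
# Hypothetical sanity check for the HIPBackend monkey-patch above.
import jax_triton  # noqa: F401  (importing jax_triton applies the patch if needed)

from triton.backends.amd.compiler import HIPBackend


class DummyArg:
  """Stand-in exposing the ptr_range() protocol checked by is_within_2gb."""

  def __init__(self, size):
    self._size = size

  def ptr_range(self):
    return self._size


assert HIPBackend.is_within_2gb(DummyArg(1024))       # well under 2 GiB
assert not HIPBackend.is_within_2gb(DummyArg(2**40))  # far over 2 GiB
assert not HIPBackend.is_within_2gb(object())         # no ptr_range, not a torch tensor
```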