
Update to the latest JAX main; bugfixes; improvements for AMD/ROCm #371

Merged
copybara-service[bot] merged 1 commit into jax-ml:main from Arech8:arech_update_to_jax_main
Feb 24, 2026

Conversation

Contributor

@Arech8 Arech8 commented Feb 11, 2026

This PR updates the jax-triton implementation to support the current JAX main branch (v9.1.0-). It also works with the released v0.8.2 and v0.9.0, and should hopefully keep working with the next few JAX versions.

Specifically, it:

  • fixes an invalid pyproject.toml section name and removes an unused build dependency on "setuptools-scm".
  • updates dependency versions to more recent ones
  • fixes the invalid order of type checks in triton_lib.py, which caused the bool check to be shadowed by the int check
  • implements a workaround for nanobind's strict type conversions, which make it impossible to, for example, cast a TypedInt subtype of int to an integer
  • makes triton_test.py platform-agnostic by using a proper vendor-agnostic implementation
  • provides a runtime hot patch for an import torch bomb planted in the AMD-specific implementation of a Triton component. Related upstream Triton PR: "Fix unguarded import torch in HIPBackend", triton-lang/triton#9441
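
The nanobind item in the list above can be illustrated with a minimal sketch. It assumes the workaround amounts to normalizing int subclasses to exact built-in ints before they reach a strictly typed binding; the description does not spell out the actual fix, and `TypedIntLike` and `to_exact_int` are hypothetical names used only for this illustration.

```python
class TypedIntLike(int):
  """Hypothetical stand-in for an int subclass such as TypedInt."""


def to_exact_int(value):
  """Returns an exact built-in int for an int or int-subclass instance.

  bool is rejected on purpose so that it is never silently treated as 0 or 1.
  """
  if isinstance(value, bool):
    raise TypeError("bool is not accepted as an integer here")
  if not isinstance(value, int):
    raise TypeError(f"expected an int, got {type(value).__name__}")
  # int() applied to an int-subclass instance yields an exact built-in int,
  # which a binding with strict (non-implicit) conversions can accept.
  return int(value)


wrapped = TypedIntLike(7)
plain = to_exact_int(wrapped)
```

Here `plain` has the exact type `int` while `wrapped` keeps its subclass type, so only the value handed to the binding is normalized.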

Functionality was tested with the bundled test suite on an AMD MI355X, using the latest build of JAX from current main:

tests$ pytest
====== test session starts =======
platform linux -- Python 3.11.14, pytest-8.4.2, pluggy-1.6.0
rootdir: /my/rocm-jax/jax-triton
configfile: pyproject.toml
plugins: csv-3.0.0, json-report-1.5.0, rerunfailures-16.1, reportlog-1.0.0, hypothesis-6.142.1, xdist-3.8.0, html-4.2.0, metadata-3.1.1
collected 162 items

cluster_test.py .s                                [  1%]
triton_call_test.py .............................. [ 68%]
.............................                   [ 98%]
triton_test.py ..                                  [100%]

====== 161 passed, 1 skipped in 25.00s ======

Collaborator

@superbobry superbobry left a comment

Thanks!

Comment thread jax_triton/triton_lib.py Outdated
if isinstance(obj, int):
if isinstance(obj, bool): # bool is a subclass of int and the test MUST go before int
  return "B"
if isinstance(obj, int): # True == isinstance(True, int) !!!

Collaborator

Remove the comment?

Contributor Author

Which one? The second or both? They are both kind of in sync. Someone made the mistake because they forgot the second part...

Contributor Author

Modified to leave a single comment.
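
The ordering issue discussed in this thread can be reproduced with a small sketch. Only the `"B"` code for bool comes from the snippet above; the `"int"` code and both function names are hypothetical placeholders.

```python
def type_code(obj):
  """Correct order: bool is tested before int."""
  # bool is a subclass of int, so isinstance(True, int) is True.
  if isinstance(obj, bool):
    return "B"
  if isinstance(obj, int):
    return "int"  # hypothetical placeholder code
  raise TypeError(f"unsupported type: {type(obj).__name__}")


def type_code_wrong_order(obj):
  """Buggy order: the int test shadows the bool test."""
  if isinstance(obj, int):
    return "int"  # hypothetical placeholder code
  if isinstance(obj, bool):
    return "B"  # unreachable for bool values
  raise TypeError(f"unsupported type: {type(obj).__name__}")


correct = type_code(True)
shadowed = type_code_wrong_order(True)
```

With the buggy order, a bool argument is classified as an integer and the bool branch is never reached.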

Comment thread jax_triton/__init__.py Outdated
# now testing if the original implementation fails:
try:
  # passing an int should cause the exception
  HIPBackend.is_within_2gb(1)

Collaborator

How about instead we check the Triton version and unconditionally monkeypatch HIPBackend if we know the version is bad?

Contributor Author

I haven't made a PR into Triton yet (going to do it later today or tomorrow), and I can't forecast under which version number they will release the fixed code... Would it be a patch, a minor, or a major bump? ... Behavior-based patching just doesn't care about that. The runtime cost is minuscule... Do you think it's a problem?

Collaborator

Do you see any downsides in always patching? In theory the implementation in the HIPBackend might diverge, but given that this method is fairly straightforward, I think it's unlikely it'll diverge.

Contributor Author

Yeah, implementation divergence is my main concern here... I'm way out of context on why this method exists for the AMD backend and what actual problem it's supposed to solve (documenting things isn't fancy, heh?..), so I wouldn't like to assume more than is barely necessary (this patch is also not ideal in that respect, tbh, I just didn't want to spend much time on it...).

The thing is, a solution where the patch is applied only after verifying that the original throws in a default setting is safe to leave in place indefinitely, irrespective of anything (it just needs an added check that the method is present before calling it): the patch simply won't be applied as soon as the implementation ceases to throw, and no harm is done. If the patch always overrides the implementation, someone must keep an eye on upstream and remove the patch as soon as a fixed version is released, to avoid hiding potential subsequent implementation changes... How realistic are these changes? What could diverge in such a small piece of code? For example, they might add support for more object types there: other tensor types, or support for some internal data wrappers, analogous to our TypedInt subtypes (which are already causing issues with strict nanobind conversion rules, btw). So, imho, surprisingly quite a few things might go wrong. That's why I think the principle of least surprise is so important... Debugging such patched code is a next level of enjoyment, btw 😁

Is there something that concerns you about the pre-patch check (whether the implementation always throws)? I agree that it isn't good to put the patch in the generic code path; putting it somewhere it executes only on the AMD platform is way better. I'll try to find such a place, so non-AMD users won't be affected at all. Would such a change be OK with you (keep the pre-patch check, but move everything under an AMD-specific branch)?
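
The behavior-based approach argued for in this comment can be sketched roughly as follows. This is an illustration under stated assumptions, not the code from the PR: `BrokenBackend`, `HealthyBackend`, `_replacement`, and `patch_if_broken` are hypothetical names, and the import of a nonexistent module only simulates an unguarded `import torch` on a system without torch.

```python
class BrokenBackend:
  """Hypothetical backend whose method has an unguarded import."""

  @staticmethod
  def is_within_2gb(arg):
    # Simulates a dependency that is not installed.
    import _hypothetical_missing_module_
    return False


class HealthyBackend:
  """Hypothetical backend whose method works without extra imports."""

  @staticmethod
  def is_within_2gb(arg):
    return arg < 2**31


def _replacement(arg):
  """Hypothetical fallback; the real patch body is not shown in this thread."""
  return False


def patch_if_broken(cls, name="is_within_2gb", probe_arg=1):
  """Patches cls.<name> only if it exists and the probe call raises ImportError."""
  original = getattr(cls, name, None)
  if original is None:
    return False  # method is absent upstream: nothing to probe or patch
  try:
    original(probe_arg)
  except ImportError:
    setattr(cls, name, staticmethod(_replacement))
    return True
  return False  # upstream works: leave it untouched


patched_broken = patch_if_broken(BrokenBackend)
patched_healthy = patch_if_broken(HealthyBackend)
patched_again = patch_if_broken(BrokenBackend)
```

Once the upstream method stops raising, the probe succeeds and the patch is skipped, so the check does not need to be tied to any particular Triton version.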

Comment thread jax_triton/__init__.py Outdated
"strides_from_shape",
"__version__",
"__version_info__",
"utils",

Collaborator

Unrelated change?

Google uses 2 space indentation, but 4 space hanging indent.

Contributor Author

Hmmm, right, something is off with my settings. Thanks, I'll fix it (in a couple of hours, more precisely; I need to drop out for a bit).

Contributor Author

Eeerm... So this was 4-space formatted and ruff changed it to 2 spaces according to pyproject.toml. So the change was applied by the only formatter configured in the project and seems totally legit b/c of that... (but I'll revert the file anyway, since no changes to it are needed; I've moved the patch elsewhere)

Which formatter do you use then, if not ruff? I constantly have ruff reformatting Google's Python; it's super annoying to contribute b/c of that...

Contributor Author

Reverted __init__.py and moved the patch to an AMD-specific method in triton_lib.py. I hope this addresses all your concerns not related to AMD.

@Arech8 Arech8 force-pushed the arech_update_to_jax_main branch from bd71a86 to 5e9258a Compare February 12, 2026 09:24
@Arech8 Arech8 requested a review from superbobry February 12, 2026 16:38
@Arech8 Arech8 force-pushed the arech_update_to_jax_main branch from 54f2666 to 4d686a0 Compare February 24, 2026 11:23
@copybara-service copybara-service Bot merged commit 511eb22 into jax-ml:main Feb 24, 2026
6 checks passed
@Arech8 Arech8 deleted the arech_update_to_jax_main branch March 4, 2026 12:20