Update to the latest JAX main; bugfixes; improvements for AMD/ROCm#371
Conversation
| if isinstance(obj, int): | ||
| if isinstance(obj, bool): # bool is a subclass of int and the test MUST go before int | ||
| return "B" | ||
| if isinstance(obj, int): # True == isinstance(True, int) !!! |
There was a problem hiding this comment.
which one? The second or both? They both kind of in sync. Someone made a mistake because forgot the second part...
There was a problem hiding this comment.
modified to leave a single comment
| # now testing if the original implementation fails: | ||
| try: | ||
| # passing an int should cause the exception | ||
| HIPBackend.is_within_2gb(1) |
There was a problem hiding this comment.
How about instead we check the Triton version and unconditionally monkeypatch HIPBackend if we know the version is bad?
There was a problem hiding this comment.
I didn't make a PR into Triton yet (going to do later today or tomorrow) and I can't forecast under what version index they will release fixed code... Would it be a patch, or a minor, or a major bump? ... Behavior based patching just doesn't care about that. The runtime cost is minuscule... Do you think it's a problem?
There was a problem hiding this comment.
Do you see any downsides in always patching? In theory the implementation in the HIPBackend might diverge, but given that this method is fairly straightforward, I think it's unlikely it'll diverge.
There was a problem hiding this comment.
Yeah, the implementation divergence is my main concern here... I'm way out of the context why this method exist for AMD backend and what actual problem it's supposed to solve (documenting things isn't fancy, heh?..), so wouldn't like to assume more than barely necessary (this patch is also not ideal in that respect, tbh, just didn't want to spend much time on it...).
The thing is, a solution where the patch is applied only after it verifies it throws in a default setting is safe to stay in place indefinitely long, irrespective of anything (just need to add a check if such a method present before calling it), - the patch just won't be applied as soon as the implementation ceases to throw, and no harm done. If the patch always overrides the implementation - someone must keep an eye on the upstream and remove the patch as soon as a fixed version is released to prevent hiding potential subsequent implementation changes... How realistic these changes? What could diverge is so small code? For example, they might add support for more objects types there: other tensors types, or a support for some internal data wrappers, analogous to our TypedInt subtypes (which are already causing issues with strict nanobind conversion rules, btw). So, imho, surprisingly quite a few things might go wrong. That's why I think the principle of least surprise is so important... Debugging of such a patched code is a next level of enjoyment, btw 😁
Is that something that concerns you about the pre-patch check (if the implementation always throws)? I agree that this isn't beneficial to put the patch in the generic code path, - putting it somewhere where it executes only on AMD platform is way better. I'll try to find such a place, so non-AMD users won't be affected at all. Would such a change be ok to you (leave a pre-patch check, but move everything under AMD-specific branch)?
| "strides_from_shape", | ||
| "__version__", | ||
| "__version_info__", | ||
| "utils", |
There was a problem hiding this comment.
Unrelated change?
Google uses 2 space indentation, but 4 space hanging indent.
There was a problem hiding this comment.
Hmmm, Right, something is off with my settings. Thanks, I'll fix it (in a couple of hours, more precisely, need to drop out a bit)
There was a problem hiding this comment.
Eeerm... So this was 4 space formatted and ruff has changed this to 2 spaces according to pyproject.toml. So the change is applied by the only formatter configured in the project and seem totally legit b/c of that... (but I'll revert the file back anyway, since no changes to it is needed, I've moved the patch elsewhere)
Which formatter do you use then if not ruff? I constantly have ruff reformatting Google's python, it's super annoying to contribute b/c of that...
There was a problem hiding this comment.
Reverted __init_py and moved the patch to AMD specific method in triton_lib.py. I hope, this addresses all your concerns not related to AMD.
bd71a86 to
5e9258a
Compare
54f2666 to
4d686a0
Compare
The PR bumps
jax-tritonimplementation to support the current latest JAXmainbranch (v9.1.0-). It also works with the released v0.8.2, v0.9.0 and hopefully would work for some time with the next JAX versions.Specifically, it:
pyproject.tomlsection name and removes an unused build dependency on"setuptools-scm".triton_lib.pyleading to aboolcheck being shadowed byintcheckTypedIntsubtype ofintto an integertriton_test.pyplatform agnostic by using a proper vendor-agnostic implementationimport torchbomb planted into AMD specific implementation of a Triton component. Related PR into Triton upstream: Fix unguardedimport torchin HIPBackend triton-lang/triton#9441Functionality was tested with a bundled test suite on AMD MI355X on the latest build of JAX from current
main