Add CPU backend #322
Draft
stephen-huan wants to merge 3 commits into jax-ml:main
(This is more of an issue/feature request than a PR, but since I have a working prototype I figured I'd share it.)
Triton now has a CPU backend from triton-cpu, which compiles LLIR to assembly using LLVM. This PR adds support for it by using `jax.pure_callback` to wrap calling Triton kernels from Python (generating an XLA custom call). A proper implementation would add an OpenMP CPU launcher to jaxlib's `gpu_triton.py` akin to `triton_kernels.cc` (unfortunately, the CPU backend doesn't seem to fit neatly into jaxlib's existing Triton abstractions; for example, cuda/rocm appear to be mutually exclusive since they overwrite the same names in `gpu_triton.py`, while cpu can coexist with gpu). I don't have enough familiarity with C++/jaxlib/XLA to make this change myself, hence the feature request.

The motivation for adding a CPU backend is that it's faster than `TRITON_INTERPRET=1` and allows `jax.jit`'ing Triton kernels like on GPU. In addition, it would possibly allow Pallas kernels to be run on CPU without `interpret=True`, which is generally very slow. Pure JAX code can be run on either CPU or GPU with no code modifications, and it'd be nice if this were also true for Triton/Pallas kernels (for debugging/prototyping, but also to run fast on CPU itself).

Known limitations of this PR:
- Since `jax.pure_callback` is used instead of a C++ XLA custom call, kernel launch overhead is relatively high
- It doesn't use `triton_kernel_call` or the MLIR lowering (`triton_kernel_call_lowering`), so the behavior is slightly different
- `zeroed_outputs` doesn't receive meta parameters from Triton configurations

Passes (and is definitely completely overfit to) all tests except for those that count the number of compilations (as it doesn't use the MLIR lowering path) and `test_autotune_with_heuristics`, since Triton evaluates the configuration multiple times.
test_autotune_with_heuristicssince Triton evaluates the configuration multiple times.The first two commits are unrelated fixes to the tests which can be merged, and I've opened #321 with them verbatim.