JAX-AITER integrates AMD's AITER operator library into JAX, bringing AITER's high-performance attention kernels to JAX on ROCm via a stable FFI bridge and custom_vjp integration. It enables optimized MHA/FMHA (including variable-length attention) in JAX for both inference and training on AMD GPUs.
Status: experimental.
AITER is AMD’s centralized library of AI operators optimized for ROCm GPUs. It unifies multiple backends (C++, Python, CK, assembly, etc.) and exposes a consistent operator interface.
JAX-AITER builds on that foundation by providing:
- A JAX-native API: Operators are exposed as JAX functions with seamless custom_vjp (forward/backward) wiring; see the sketch after this list.
- Automatic operator dispatch: Dynamically chooses optimized implementations based on tensor shapes, data types, and options (e.g. causal, windowed).
- Zero-copy buffer exchange: Uses JAX FFI to avoid unnecessary transfers between JAX and AITER.
- Training-ready: Gradients flow natively through AITER kernels, enabling end-to-end differentiable pipelines.
- ROCm-native performance: Fully integrated with AMD GPU runtime and compiler stack.
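To illustrate how these pieces fit together, here is a minimal usage sketch. The flash-attn-style signature (tensor layout and the causal keyword) is an assumption for illustration; consult the tests for the authoritative API.

import jax
import jax.numpy as jnp
from jax_aiter.mha import flash_attn_func

# Assumed layout: (batch, seqlen, heads, head_dim), the common flash-attn convention.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 1024, 16, 128), dtype=jnp.bfloat16)
k = jax.random.normal(kk, (2, 1024, 16, 128), dtype=jnp.bfloat16)
v = jax.random.normal(kv, (2, 1024, 16, 128), dtype=jnp.bfloat16)

# Forward pass; `causal=True` is an assumed option mirroring the dispatch notes above.
out = flash_attn_func(q, k, v, causal=True)

# Because the op is wired with custom_vjp, gradients flow through the AITER kernels.
grads = jax.grad(lambda q, k, v: flash_attn_func(q, k, v, causal=True).sum(),
                 argnums=(0, 1, 2))(q, k, v)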
If a prebuilt wheel is available (e.g. from the project's Releases page):
pip install path/to/jax_aiter-<version>-*.whl
A custom build requires cmake, ninja, and pyyaml (all installable via pip):
pip install cmake ninja pyyaml
Environment setup (run from the top of the jax-aiter project tree):
export JA_ROOT_DIR="$PWD" # Set to the top of jax-aiter project tree
export AITER_SYMBOL_VISIBLE=1
export GPU_ARCHS=gfx950 # Example for MI350; use your GPU arch (e.g., gfx942 for MI300)
export AITER_ASM_DIR=/aiter-hsa-path/gfx950/ # Example for MI350
You can build natively or inside a ROCm container. Docker images for the latest ROCm JAX releases are available at:
https://hub.docker.com/r/rocm/jax/tags
We suggest using the latest JAX image:
docker pull rocm/jax:rocm7.0.2-jax0.6.0-py3.10-ubu22
Inside the container (or on your host with ROCm installed), proceed:
git clone --recursive git@github.com:ROCm/jax-aiter.git
Submodules:
- third_party/aiter
- third_party/pytorch
Statically build minimal PyTorch libraries (c10, torch_cpu, torch_hip, caffe2_nvrtc) and headers for linking.
Apply the caffe2_nvrtc static/PIC patch:
cd third_party/pytorch
git apply ../../scripts/torch_caffe.patch
cd -
Apply the AITER torch-removal patch:
cd third_party/aiter
git apply ../../scripts/aiter_torch_remove.patch
cd -
Run the static build script:
bash ./scripts/build_static_pytorch.sh
Script details:
- Source: third_party/pytorch
- Build dir: third_party/pytorch/build_static
- Install prefix: third_party/pytorch/build_static/install
- Tunables: ROCM_ARCH (target GPU arch), ROCM_PATH (default /opt/rocm), JOBS (default: nproc), PYTHON (default: python3)
Link the required static PyTorch objects and ROCm libs into a single .so:
make
Key paths (from Makefile):
- Output: build/aiter_build/libjax_aiter.so
- Static libs: third_party/pytorch/build_static/lib
- Include dirs: JAX FFI, PyTorch, and csrc/common
Build specific modules (example: varlen fwd+bwd):
python3 jax_aiter/jit/build_jit.py --module module_fmha_v3_varlen_fwd,module_fmha_v3_varlen_bwd
Build all configured modules:
python3 jax_aiter/jit/build_jit.py
Outputs (.so) are placed under build/aiter_build/.
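To sanity-check the build outputs:

from pathlib import Path

# List the module libraries produced by make and build_jit.py.
libs = sorted(Path("build/aiter_build").glob("*.so"))
print("\n".join(p.name for p in libs) or "no .so files found; run the build first")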
Notes:
- build_jit.py does not currently perform any JIT compilation; that may change in the future.
- build_jit.py patches AITER's core to redirect user JIT dir to build/aiter_build and inject PyTorch/JAX-FFI include paths.
- Ensure the static PyTorch build completed first; headers and libraries are expected under third_party/pytorch/build_static and build_static/install/include.
Smoke test:
python -c "from jax_aiter.mha import flash_attn_func, flash_attn_varlen; print('jax-aiter import OK')"
Run tests (requires JAX on ROCm and an AMD GPU):
pip install pytest-xdist
pytest -q tests/test_mha_varlen_ja.py
pytest -q tests/test_mha_ck_ja.py
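The varlen tests exercise the variable-length attention path. A minimal sketch of calling it directly, assuming a flash-attn-varlen-style interface (packed tokens plus cumulative cu_seqlens offsets; the keyword names below follow that convention and may differ from the real signature):

import jax
import jax.numpy as jnp
from jax_aiter.mha import flash_attn_varlen

# Two packed sequences of lengths 3 and 5 (8 tokens total), 4 heads, head_dim 64.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (8, 4, 64), dtype=jnp.float16)
k = jax.random.normal(kk, (8, 4, 64), dtype=jnp.float16)
v = jax.random.normal(kv, (8, 4, 64), dtype=jnp.float16)
cu_seqlens = jnp.array([0, 3, 8], dtype=jnp.int32)  # cumulative sequence offsets

# Assumed keywords (cu_seqlens_q/k, max_seqlen_q/k, causal) mirror the
# flash-attn varlen convention; adjust to the actual signature as needed.
out = flash_attn_varlen(q, k, v,
                        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
                        max_seqlen_q=5, max_seqlen_k=5, causal=True)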
- Arch/driver mismatch: Set both GPU_ARCHS (e.g., gfx950 for MI350) and ROCM_ARCH for the static PyTorch build, then rebuild:
  export GPU_ARCHS=<gfx*>
  env ROCM_ARCH=<gfx*> ./scripts/build_static_pytorch.sh
  make
- caffe2_nvrtc not found or not PIC: Ensure the patch (scripts/torch_caffe.patch) was applied, then rerun build_static_pytorch.sh.
- JIT cannot find PyTorch headers: Confirm third_party/pytorch/build_static/install/include exists. Re-run:
  python3 jax_aiter/jit/build_jit.py --verbose --module module_fmha_v3_varlen_fwd
- Symbol-not-found errors while loading .so files for MHA kernels: Confirm that libmha_fwd and libmha_bwd are built and loaded before loading the respective modules (see the sketch after this list).
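For the last item, one way to make the kernel symbols globally visible before importing the modules is to preload the libraries. This is a hypothetical sketch (the package may already handle load ordering itself), with paths assuming the default build/aiter_build layout:

import ctypes

# Preload the shared MHA kernel libraries with RTLD_GLOBAL so that module
# .so files loaded afterwards can resolve libmha_fwd/libmha_bwd symbols.
for lib in ("build/aiter_build/libmha_fwd.so", "build/aiter_build/libmha_bwd.so"):
    ctypes.CDLL(lib, mode=ctypes.RTLD_GLOBAL)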
- Static PyTorch build targets:
  - c10, torch_cpu, torch_hip, caffe2_nvrtc (static, PIC)
- JIT module config: see jax_aiter/jit/optCompilerConfig.json for available modules (the sketch after this list shows one way to enumerate them):
  - module_fmha_v3_varlen_fwd, module_fmha_v3_varlen_bwd
  - module_mha_varlen_fwd, module_mha_varlen_bwd
  - module_mha_fwd, module_mha_bwd
  - libmha_fwd, libmha_bwd
  - module_custom
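A quick way to enumerate the configured modules programmatically, assuming (hypothetically) that the top-level JSON keys name the modules:

import json

# Hypothetical sketch: list configured module names from the JIT config.
# Assumes the top-level keys of optCompilerConfig.json name the modules.
with open("jax_aiter/jit/optCompilerConfig.json") as f:
    cfg = json.load(f)
print(sorted(cfg))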