Warning
These bindings are not actively maintained; prefer the ones from the flash-attn3-jax repo instead.
This repo contains bindings for FlashAttention3 in JAX. There are two versions of these bindings: a C++ version, jax_flash_attn, and a Rust version, jflash_attn.
The BSD-3 license that holds for the flash-attention repo also applies here.
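FlashAttention3 computes exact softmax attention, softmax(QKᵀ/√d)·V, inside a single fused kernel. For reference, a minimal non-fused version in the spirit of the einsum-based baseline used in the benchmarks could look like the sketch below. It uses NumPy (jax.numpy exposes the same einsum API). The function name attn_einsum and the [batch, seqlen, heads, head_dim] layout are assumptions for illustration, not the bindings' API.

```python
import numpy as np


def attn_einsum(q, k, v):
    """Plain softmax attention; q, k, v have shape [batch, seqlen, heads, head_dim]."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    # Scores for every (query, key) pair: [batch, heads, q_len, k_len].
    logits = np.einsum("bqhd,bkhd->bhqk", q, k) * scale
    # Numerically stable softmax over the key axis.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Weighted sum of values, returned in the input layout [batch, q_len, heads, head_dim].
    return np.einsum("bhqk,bkhd->bqhd", weights, v)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 128, 4, 64)) for _ in range(3))
out = attn_einsum(q, k, v)
print(out.shape)  # (2, 128, 4, 64)
```

This version materializes the full [batch, heads, q_len, k_len] score tensor, which grows quadratically with sequence length; avoiding that tensor is the point of the fused FlashAttention kernel.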
Build a wheel using the uv build system:
uv build --wheel
For parallel compilation of the CUDA kernels, you can set the CMAKE_BUILD_PARALLEL_LEVEL environment variable:
CMAKE_BUILD_PARALLEL_LEVEL=32 uv build --wheel
To build locally for development:
uv sync --group test
uv run test.py # run some tests and benchmarks
For the Rust version, build a Python package as a wheel with maturin build --release.
To build a Python package and install it in the current virtual environment, run maturin develop.
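Once either package is installed, a quick way to confirm that Python can find it is to look up the modules without importing them. This sketch assumes the importable module names are jax_flash_attn and jflash_attn, matching the package names above; the helper installed_bindings is hypothetical.

```python
import importlib.util

# Assumed module names, taken from the package names jax_flash_attn (C++) and jflash_attn (Rust).
BINDING_MODULES = ("jax_flash_attn", "jflash_attn")


def installed_bindings(names=BINDING_MODULES):
    """Map each top-level module name to whether it can be found in this environment."""
    # find_spec returns None for a missing top-level module instead of raising.
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, ok in installed_bindings().items():
        print(f"{name}: {'installed' if ok else 'missing'}")
```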
First, compile the C++ and/or Rust package and install it locally. Then run the tests with:
python test.py --bindings cpp
python test.py --bindings rust
Use the --bench flag to run the benchmarks instead of the tests:
python test.py --bindings cpp --bench True
python test.py --bindings rust --bench True
This measures the time spent in the attention layer for three different implementations:
flash-attn: uses the optimized flash-attention kernel.
attn-einsum: uses a simple attention implementation based on einsum.
attn-flax: uses flax.linen.dot_product_attention.
Timings cover only the forward pass for lines without a prefix, and both the forward and backward passes for lines that start with bwd. The number after the implementation name is the sequence length (the batch size is adapted so as to have a reasonable amount of computation).
flash-attn 512 0.96ms 71.6 TFLOPS (std 0.39ms, min 0.79ms, max 2.43ms)
attn-flax 512 1.90ms 36.1 TFLOPS (std 0.44ms, min 1.64ms, max 3.46ms)
flash-attn 1024 1.04ms 131.8 TFLOPS (std 0.25ms, min 0.88ms, max 1.74ms)
attn-flax 1024 1.13ms 122.0 TFLOPS (std 0.27ms, min 0.98ms, max 1.94ms)
flash-attn 2048 1.16ms 237.6 TFLOPS (std 0.13ms, min 1.08ms, max 1.58ms)
attn-flax 2048 1.44ms 191.2 TFLOPS (std 0.39ms, min 1.25ms, max 2.68ms)
flash-attn 4096 1.59ms 346.2 TFLOPS (std 0.30ms, min 1.45ms, max 2.82ms)
attn-flax 4096 1.91ms 287.8 TFLOPS (std 0.33ms, min 1.75ms, max 3.20ms)
flash-attn 8192 2.27ms 483.9 TFLOPS (std 0.18ms, min 2.16ms, max 3.05ms)
attn-flax 8192 2.97ms 370.4 TFLOPS (std 0.36ms, min 2.79ms, max 4.17ms)
flash-attn 16384 3.88ms 566.6 TFLOPS (std 0.29ms, min 3.71ms, max 4.67ms)
attn-flax 16384 22.14ms 99.3 TFLOPS (std 0.56ms, min 21.54ms, max 23.44ms)
bwd flash-attn 512 2.23ms 107.9 TFLOPS (std 0.30ms, min 2.04ms, max 2.93ms)
bwd attn-flax 512 3.30ms 72.9 TFLOPS (std 0.17ms, min 3.17ms, max 3.84ms)
bwd flash-attn 1024 2.54ms 189.4 TFLOPS (std 0.31ms, min 2.29ms, max 3.28ms)
bwd attn-flax 1024 4.79ms 100.4 TFLOPS (std 0.38ms, min 4.60ms, max 5.92ms)
bwd flash-attn 2048 3.29ms 292.1 TFLOPS (std 0.50ms, min 2.89ms, max 4.42ms)
bwd attn-flax 2048 7.66ms 125.5 TFLOPS (std 0.35ms, min 7.48ms, max 8.52ms)
bwd flash-attn 4096 4.25ms 452.7 TFLOPS (std 0.34ms, min 4.03ms, max 5.20ms)
bwd attn-flax 4096 13.70ms 140.4 TFLOPS (std 0.51ms, min 13.17ms, max 15.23ms)
bwd flash-attn 8192 7.86ms 489.7 TFLOPS (std 1.57ms, min 7.02ms, max 13.35ms)
bwd attn-flax 8192 25.31ms 152.0 TFLOPS (std 0.51ms, min 24.80ms, max 26.60ms)
bwd flash-attn 16384 13.62ms 565.3 TFLOPS (std 0.49ms, min 13.09ms, max 15.08ms)
bwd attn-flax 16384 47.84ms 160.9 TFLOPS (std 0.44ms, min 47.54ms, max 49.61ms)
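The TFLOPS column converts each timing into throughput, but the exact FLOP formula is not stated. The sketch below assumes the usual FlashAttention accounting: 4 · batch · heads · seqlen² · head_dim for the two forward matmuls (QKᵀ and PV), with the backward pass counted as 2.5× the forward. The bwd rows above are consistent with this: at every sequence length, TFLOPS × time for a bwd row is about 3.5× the matching forward row (forward plus 2.5× forward). The function names are hypothetical.

```python
def attention_flops(batch, seqlen, nheads, headdim, with_backward=False):
    """Assumed FLOP count for non-causal softmax attention."""
    # Q @ K^T and P @ V each cost 2 * seqlen^2 * headdim FLOPs per head.
    fwd = 4.0 * batch * nheads * seqlen * seqlen * headdim
    # Assumed convention: backward = 2.5x forward, so forward + backward = 3.5x forward.
    return 3.5 * fwd if with_backward else fwd


def to_tflops(flops, time_ms):
    """Throughput in TFLOPS for a FLOP count and a wall time in milliseconds."""
    return flops / (time_ms * 1e-3) / 1e12


def implied_flops(tflops, time_ms):
    """Invert a benchmark row: work done = throughput x time."""
    return tflops * 1e12 * time_ms * 1e-3


# (ms, TFLOPS) copied from the flash-attn rows at sequence length 16384.
fwd_row = (3.88, 566.6)
bwd_row = (13.62, 565.3)
ratio = implied_flops(bwd_row[1], bwd_row[0]) / implied_flops(fwd_row[1], fwd_row[0])
print(f"bwd/fwd FLOP ratio: {ratio:.2f}")  # about 3.5
```

Because the batch size is adapted per row, the absolute FLOP counts cannot be recovered from the table alone; only ratios such as this one can be checked.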