Releases: erfanzar/ejkernel
CUDA and CuTe Have Arrived
Highlights
- Implemented missing non-inference backward paths:
  - Triton `rwkv6` backward (fixed-length + `cu_seqlens` varlen)
  - Triton `rwkv7` backward (fixed-length + `cu_seqlens` varlen), with `rwkv7_mul` backprop flowing through RWKV7
  - CUDA `blocksparse_attention` backward via CUDA-side dense analytical fallback (removed `NotImplementedError`)
  - XLA recurrent varlen (`cu_seqlens`) backward with packed-varlen/state normalization
- Added `forward_autotune_only` context manager in `ejkernel.ops`/`ejkernel.ops.config` to disable backward autotune validation while keeping forward autotune active.
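The `forward_autotune_only` behavior can be pictured as a scoped flag that a kernel dispatcher consults before tuning each pass. The sketch below is an illustrative re-implementation built on Python's `contextvars`, not ejkernel's code; `should_autotune` and `_BACKWARD_AUTOTUNE` are hypothetical names, and the real context manager's arguments (if any) may differ.

```python
import contextlib
import contextvars
from typing import Iterator

# Hypothetical flag; ejkernel's real config object may be structured differently.
_BACKWARD_AUTOTUNE = contextvars.ContextVar("backward_autotune", default=True)


@contextlib.contextmanager
def forward_autotune_only() -> Iterator[None]:
    """Illustrative version: disable backward autotune inside the block, leave forward alone."""
    token = _BACKWARD_AUTOTUNE.set(False)
    try:
        yield
    finally:
        # Restore the previous value even if the body raises.
        _BACKWARD_AUTOTUNE.reset(token)


def should_autotune(pass_name: str) -> bool:
    """Hypothetical query a dispatcher would make before tuning a pass."""
    if pass_name == "forward":
        return True
    return _BACKWARD_AUTOTUNE.get()


print(should_autotune("backward"))  # True
with forward_autotune_only():
    print(should_autotune("forward"), should_autotune("backward"))  # True False
print(should_autotune("backward"))  # True
```

Using a `ContextVar` rather than a module-level boolean keeps the setting scoped to the current thread or async task.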
Quantized Matmul / Backend Work
- Added TPU Pallas hybrid QMM path:
  - Packed / predecode / XLA fallback dispatch
  - Shared forward + dX kernel family and custom VJP dX path
  - TPU legality gates and memory-aware heuristics
- Expanded QMM backend normalization across Triton/CUDA/CuTe/XLA/Pallas.
- CUDA QMM improvements:
  - cuBLASLt/CUTLASS GEMM backend options (`EJKERNEL_QMM_CUDA_GEMM`)
  - NF4 exact table lookup path
  - Expanded affine group-size coverage.
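For orientation, this is roughly what an affine group-wise 4-bit QMM computes before any backend-specific packing or fusion. It is a plain JAX reference under stated assumptions, not one of ejkernel's kernels: codes are packed two per `uint8` (low nibble first), the weight is `[N, K]` with one scale and zero-point per group of `group_size` columns, and the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def unpack_int4(packed: jax.Array) -> jax.Array:
    """Split each uint8 into two 4-bit codes (low nibble first) along the last axis."""
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    # Interleave so codes keep their original order: [low0, high0, low1, high1, ...].
    return jnp.stack([low, high], axis=-1).reshape(*packed.shape[:-1], -1)


def dequantize_affine(codes, scales, zeros, group_size):
    """w[n, k] = (codes[n, k] - zeros[n, g]) * scales[n, g], where g = k // group_size."""
    n, k = codes.shape
    grouped = codes.reshape(n, k // group_size, group_size).astype(jnp.float32)
    w = (grouped - zeros[..., None]) * scales[..., None]
    return w.reshape(n, k)


def qmm_reference(x, packed, scales, zeros, group_size):
    """y = x @ W^T, with W stored as packed 4-bit affine codes of logical shape [N, K]."""
    w = dequantize_affine(unpack_int4(packed), scales, zeros, group_size)
    return x @ w.T
```

Optimized kernels avoid materializing `w` and dequantize tiles on the fly; the reference is only useful for parity checks.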
Runtime & Compatibility Fixes
- Fixed Triton QMM tracer leak by removing global decode-table state.
- Added JAX compatibility shim for `pl.ANY` movement.
- Added single-device XLA fallback path for ring attention (`axis_name=None`).
- Fixed recurrent varlen forward unpacking bug.
- Flash attention `FlashBias` typing improvements.
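With `axis_name=None` there is no device ring to rotate KV blocks around, so a single-device path only needs to visit the KV blocks locally. The sketch below shows that computation as a blockwise online softmax in JAX; it is an assumption about what such a fallback is equivalent to, not ejkernel's implementation (which may simply call dense attention), and `blockwise_attention` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def blockwise_attention(q, k, v, block_size):
    """Attention over KV blocks with an online softmax; q: [Tq, D], k: [Tk, D], v: [Tk, Dv]."""
    scale = q.shape[-1] ** -0.5
    tq = q.shape[0]
    m = jnp.full((tq,), -jnp.inf, dtype=jnp.float32)       # running row max
    l = jnp.zeros((tq,), dtype=jnp.float32)                # running softmax denominator
    acc = jnp.zeros((tq, v.shape[-1]), dtype=jnp.float32)  # running unnormalized output
    for start in range(0, k.shape[0], block_size):
        kb = k[start:start + block_size]
        vb = v[start:start + block_size]
        s = (q @ kb.T).astype(jnp.float32) * scale
        m_new = jnp.maximum(m, s.max(axis=-1))
        correction = jnp.exp(m - m_new)  # rescale earlier statistics to the new max
        p = jnp.exp(s - m_new[:, None])
        l = l * correction + p.sum(axis=-1)
        acc = acc * correction[:, None] + p @ vb.astype(jnp.float32)
        m = m_new
    return acc / l[:, None]
```

In a multi-device ring, each iteration's `kb`/`vb` would instead arrive from a neighbor; the accumulation rule stays the same.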
Developer Experience / Quality
- Refactored many kernel interfaces into split `*_impl_fwd.py`/`*_impl_bwd.py` files.
- Added/expanded docstrings across kernels/callib/ops.
- Added structured unsupported-feature errors via `EjkernelRuntimeError`.
- Minor style cleanup (consolidated multiline errors, sorted exports).
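The `*_impl_fwd.py`/`*_impl_bwd.py` split maps naturally onto `jax.custom_vjp`, where the forward rule returns residuals and the backward rule consumes them. As a minimal sketch (the soft-cap op and the function names are illustrative, not ejkernel's), saving only the output lets this backward rule recover the gradient without keeping the input:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def soft_cap(x, cap):
    return cap * jnp.tanh(x / cap)


def soft_cap_impl_fwd(x, cap):
    y = cap * jnp.tanh(x / cap)
    return y, y  # (primal output, residuals): only the output is saved


def soft_cap_impl_bwd(cap, y, g):
    # d/dx [cap * tanh(x / cap)] = 1 - tanh(x / cap)^2 = 1 - (y / cap)^2
    return (g * (1.0 - (y / cap) ** 2),)


soft_cap.defvjp(soft_cap_impl_fwd, soft_cap_impl_bwd)
```

With `nondiff_argnums`, JAX passes `cap` first to the backward rule and expects cotangents only for the differentiable arguments.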
Test Coverage Added/Expanded
- JIT grad/VJP + numerical sanity checks for:
  - RWKV6/RWKV7/RWKV7_MUL backward paths
  - CUDA blocksparse attention backward
  - XLA recurrent varlen backward
- Expanded QMM routing/cache/memory-capacity/grad parity coverage across backends.
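A numerical sanity check of the kind listed above can compare a jitted gradient against a central finite difference along a random direction. This is a generic JAX sketch, not ejkernel's test suite; `directional_grad_check` and its tolerances are assumptions chosen for float32.

```python
import jax
import jax.numpy as jnp


def directional_grad_check(f, x, key, eps=1e-3, rtol=1e-2, atol=1e-2):
    """Compare jit(grad(f)) with a central finite difference along a random unit direction."""
    u = jax.random.normal(key, x.shape, dtype=x.dtype)
    u = u / jnp.linalg.norm(u)
    analytic = jnp.vdot(jax.jit(jax.grad(f))(x), u)
    numeric = (f(x + eps * u) - f(x - eps * u)) / (2 * eps)
    ok = jnp.abs(analytic - numeric) <= atol + rtol * jnp.abs(numeric)
    return bool(ok), float(analytic), float(numeric)
```

A single random direction is cheap and catches most sign or scaling errors in a hand-written backward pass; checking several directions increases confidence.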
Full Changelog: v0.0.50...v0.0.55
ejkernel v0.0.50
Breaking Changes
- Renamed `attention_sink` to `softmax_aux` across `ragged_page_attention_v3` and `unified_attention` (modules + kernels + tests).
- Renamed `mla_attention` to `flash_mla` in `ejkernel.modules`/`ejkernel.modules.operations`.
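Because the old keyword names are gone, call sites need updating. Code that wraps these ops can ease the transition with a small forwarding decorator like the hypothetical one below; it is not part of ejkernel, and `my_attention_wrapper` stands in for your own wrapper.

```python
import functools
import warnings
from typing import Any, Callable


def rename_kwarg(old: str, new: str) -> Callable:
    """Decorator that forwards a deprecated keyword argument to its new name."""
    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if old in kwargs:
                if new in kwargs:
                    raise TypeError(f"pass either {old!r} or {new!r}, not both")
                warnings.warn(f"{old!r} is renamed to {new!r}", DeprecationWarning, stacklevel=2)
                kwargs[new] = kwargs.pop(old)
            return fn(*args, **kwargs)
        return wrapper
    return decorator


@rename_kwarg("attention_sink", "softmax_aux")
def my_attention_wrapper(q, k, v, softmax_aux=None):  # hypothetical user-side wrapper
    return softmax_aux
```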
Added
- New paged-attention ops with Triton + XLA backends:
  - `chunked_prefill_paged_decode` (updates block-tabled KV cache + runs unified paged attention).
  - `decode_attention` (paged decode attention returning output + LSE).
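To make the "output + LSE" contract concrete, here is a single-head JAX reference for paged decode attention. The page layout (`[num_pages, page_size, head_dim]`), the `block_tables`/`context_lens` names, and the function itself are assumptions for illustration and do not reflect ejkernel's actual signature. Returning the log-sum-exp lets a caller merge this result with another partial attention result computed over disjoint keys.

```python
import jax
import jax.numpy as jnp


def paged_decode_attention(q, k_pages, v_pages, block_tables, context_lens):
    """Single-head paged decode attention returning (output, LSE).

    q: [B, D]; k_pages/v_pages: [P, page_size, D];
    block_tables: [B, max_pages] physical page ids; context_lens: [B] valid tokens per row.
    """
    b, d = q.shape
    # Gather each sequence's pages into a contiguous [B, max_pages * page_size, D] view.
    k = k_pages[block_tables].reshape(b, -1, d)
    v = v_pages[block_tables].reshape(b, -1, v_pages.shape[-1])
    scores = jnp.einsum("bd,btd->bt", q, k) * d ** -0.5
    positions = jnp.arange(k.shape[1])
    # Mask slots past each sequence's length (including unused tail slots of the last page).
    scores = jnp.where(positions[None, :] < context_lens[:, None], scores, -jnp.inf)
    lse = jax.nn.logsumexp(scores, axis=-1)  # log of the softmax denominator per row
    probs = jnp.exp(scores - lse[:, None])
    return jnp.einsum("bt,btd->bd", probs, v), lse
```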
- XLA backend support for `flash_mla`.
Changed
- `make_dummy_rpa_inputs` now supports padded page tables via `total_num_pages` for smaller physical caches.
- Attention-module call signatures now place `cfg` at the end (`chunked_prefill_paged_decode`, `decode_attention`, `unified_attention`).
Testing
- Added/expanded Triton↔XLA equivalence tests for `chunked_prefill_paged_decode` and `decode_attention`, including `shard_map` coverage.
Full Changelog: v0.0.47...v0.0.50
Release 0.0.47
- Fixes false-positive `Signature mismatch` warnings for `ragged_decode_attention` during import-time kernel registry validation.
- Improves type-annotation normalization so equivalent JAX array spellings (`Array`/`jax.Array`/`jaxlib._jax.Array`) don’t trip signature checks when some backends use postponed annotations (`from __future__ import annotations`).
- No API changes; this release cleans up diagnostics and makes cross-backend signature validation consistent.
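The normalization idea can be sketched with `inspect`: postponed annotations arrive as strings while eager ones arrive as objects, so both are mapped to one canonical spelling before comparison. The alias set and function names here are hypothetical and only illustrate the approach, not ejkernel's registry code.

```python
import inspect
from typing import Any, Callable

# Spellings treated as one JAX array type (an assumption for this sketch).
_ARRAY_ALIASES = {"Array", "jax.Array", "jaxlib._jax.Array"}


def normalize_annotation(ann: Any) -> str:
    """Map an annotation object or a postponed (string) annotation to one canonical string."""
    if ann is inspect.Parameter.empty:
        return "<missing>"
    if isinstance(ann, str):
        text = ann.strip()
    else:
        module = getattr(ann, "__module__", "")
        name = getattr(ann, "__qualname__", repr(ann))
        text = name if module in ("", "builtins") else f"{module}.{name}"
    return "Array" if text in _ARRAY_ALIASES else text


def signatures_match(f: Callable, g: Callable) -> bool:
    """True if f and g share parameter names and have equivalent annotations."""
    sf, sg = inspect.signature(f), inspect.signature(g)
    pf, pg = list(sf.parameters.values()), list(sg.parameters.values())
    if [p.name for p in pf] != [p.name for p in pg]:
        return False
    pairs = [(a.annotation, b.annotation) for a, b in zip(pf, pg)]
    pairs.append((sf.return_annotation, sg.return_annotation))
    return all(normalize_annotation(a) == normalize_annotation(b) for a, b in pairs)


def triton_impl(q: "jax.Array", k: "Array", scale: float) -> "Array": ...
def xla_impl(q: "jaxlib._jax.Array", k: "jax.Array", scale: "float") -> "jax.Array": ...


print(signatures_match(triton_impl, xla_impl))  # True
```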
Full Changelog: 0.0.45...v0.0.47
Release 0.0.45
Added
- RWKV recurrence ops: RWKV-4, RWKV-6 (multi-head + variable-length packing + reverse), RWKV-7 DPLR, and RWKV-7 Mul (multiplicative reparameterization)
- Triton + XLA kernel implementations for rwkv4, rwkv6, rwkv7, rwkv7_mul
- Unit tests for RWKV kernels (output parity vs XLA + gradient shape checks)
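Variable-length packing usually means concatenating sequences along the time axis and passing `cu_seqlens` offsets so that recurrent state resets at each boundary. The JAX sketch below shows that mechanism on a generic first-order linear recurrence `h_t = decay_t * h_{t-1} + x_t`; it is not RWKV-6's actual update rule, and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def segment_starts(cu_seqlens, total_len):
    """Boolean mask marking the first token of every packed sequence."""
    return jnp.zeros((total_len,), dtype=bool).at[cu_seqlens[:-1]].set(True)


def packed_linear_recurrence(decay, x, cu_seqlens):
    """h_t = decay_t * h_{t-1} + x_t over packed tokens, resetting h to 0 at each sequence start.

    decay, x: [T, D] with T = cu_seqlens[-1]; cu_seqlens: [N + 1] cumulative offsets.
    """
    starts = segment_starts(cu_seqlens, x.shape[0])

    def step(h, inputs):
        a, xt, is_start = inputs
        h = jnp.where(is_start, 0.0, h)  # drop state carried over from the previous sequence
        h = a * h + xt
        return h, h

    h0 = jnp.zeros(x.shape[1:], dtype=x.dtype)
    _, hs = jax.lax.scan(step, h0, (decay, x, starts))
    return hs
```

For lengths `[3, 2, 4]`, `cu_seqlens = [0, 3, 5, 9]` and the state resets at tokens 0, 3, and 5.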
Changed
- XLA block-sparse attention: delegates to dense attention when softmax_aux=None, improves KV-head replication for GQA/MQA, and tightens logits masking/soft-cap handling
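KV-head replication for GQA/MQA typically repeats each KV head so that consecutive groups of query heads share it. A minimal JAX sketch, assuming a `[batch, seq, heads, head_dim]` layout (an assumption, not necessarily ejkernel's):

```python
import jax.numpy as jnp


def repeat_kv_heads(kv, num_q_heads):
    """Expand [batch, seq, num_kv_heads, dim] to [batch, seq, num_q_heads, dim].

    Query head h reads KV head h // (num_q_heads // num_kv_heads), the usual GQA grouping;
    MQA is the special case num_kv_heads == 1.
    """
    num_kv_heads = kv.shape[2]
    if num_q_heads % num_kv_heads != 0:
        raise ValueError(f"{num_q_heads} query heads not divisible by {num_kv_heads} KV heads")
    # jnp.repeat with an int count repeats each head in place: [h0, h0, h1, h1, ...].
    return jnp.repeat(kv, num_q_heads // num_kv_heads, axis=2)
```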
v0.0.40
Changes since v0.0.31.
Highlights
- Triton is now optional: base installs no longer require Triton. Install GPU/Triton extras to enable GPU kernels.
- TPU Flash Attention: sliding window + logits soft-cap support (forward/backward), plus cache layout fixes and stability improvements.
- Ring Attention: Triton backend now uses a block-sparse inner kernel for correct global position/segment masking; `q_position_ids`/`kv_position_ids` added across TPU/Triton/XLA backends.
- New/expanded kernels: `prefill_page_attention` for TPU Pallas + XLA; `native_sparse_attention` extended with per-query-block index/count layouts.
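As a reference for what sliding-window attention with a logits soft cap computes, here is a small dense JAX sketch. The window convention (query `i` sees keys `i - window < j <= i`), applying the cap after scaling, and the function name are assumptions; ejkernel's `sliding_window` parameter may use a different convention, such as a left/right pair.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, window=None, soft_cap=None):
    """Causal attention for q/k/v of shape [T, D] with optional sliding window and logit soft cap."""
    t, d = q.shape
    logits = (q @ k.T).astype(jnp.float32) * d ** -0.5
    if soft_cap is not None:
        # Smoothly bound logits to (-soft_cap, soft_cap) before masking.
        logits = soft_cap * jnp.tanh(logits / soft_cap)
    i = jnp.arange(t)[:, None]
    j = jnp.arange(t)[None, :]
    mask = j <= i  # causal
    if window is not None:
        mask = mask & (i - j < window)  # keep only the `window` most recent keys
    logits = jnp.where(mask, logits, -jnp.inf)
    return jax.nn.softmax(logits, axis=-1) @ v.astype(jnp.float32)
```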
Fixes & Improvements
- Numerical stability improvements across attention kernels (float32 accumulators, stable soft-cap/sink handling).
- Ragged decode correctness fixes (in-bounds masking, fully-masked blocks, KV length < block size handling).
- Triton compatibility updates (kernel API changes, wrapper unwrapping order support, min dot block size enforcement).
- Correctness fixes in `grouped_matmul` (v2), `mean_pooling`, and `page_attention`.
- Full XLA cache disabled by default to prevent GPU kernel launch failures.
Upgrade Notes
- Persistent-cache version bumped for invalidation; previous caches will be rebuilt.
- `softmax_aux` handling is unified to `Float[Array, "num_sinks"]` across attention kernels.
- Tests were reorganized; module operation tests now live under `test/modules/operations`.
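One common way to use sink logits such as `softmax_aux` is to append them as extra softmax columns that absorb probability mass and are then dropped. The JAX sketch below assumes one sink per head (`num_sinks == num_heads`), which may not match how ejkernel maps sinks to heads, and computes in float32 for stability; `softmax_with_sinks` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def softmax_with_sinks(logits, sinks):
    """Softmax over keys where a per-head sink logit absorbs part of the probability mass.

    logits: [H, Tq, Tk]; sinks: [H]. Returns probabilities over real keys only,
    so each row sums to less than 1.
    """
    h, tq, _ = logits.shape
    sink_col = jnp.broadcast_to(sinks[:, None, None], (h, tq, 1)).astype(jnp.float32)
    full = jnp.concatenate([logits.astype(jnp.float32), sink_col], axis=-1)
    probs = jax.nn.softmax(full, axis=-1)
    return probs[..., :-1]  # drop the sink column
```

A very negative sink recovers the ordinary softmax, while a large sink lets a head attend to "nothing".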
Install
```bash
pip install -U ejkernel
# GPU/Triton kernels
pip install -U "ejkernel[gpu]"
# or
pip install -U "ejkernel[triton]"
```
Full Changelog: v0.0.31...v0.0.40
v0.0.31
v0.0.31 - Initial Release
ejKernel: High-Performance JAX Kernels for Deep Learning
We're excited to announce the first public release of ejKernel — a production-grade kernel library for JAX with automatic multi-backend support, sophisticated autotuning, and seamless execution across GPUs, TPUs, and CPUs.
Highlights
- Multi-Backend Support: Automatic platform detection with Triton (GPU), Pallas (TPU), and XLA (universal) backends
- 7-Tier Configuration System: Intelligent kernel selection with persistent caching and autotuning
- Memory Efficient: Custom VJP implementations reducing gradient memory from O(N²) to O(N)
- Type Safe: Full jaxtyping annotations with runtime validation via beartype
- Distributed Ready: Full `shard_map` integration for model and data parallelism
Attention Mechanisms (14)
| Operation | Description |
|---|---|
| Flash Attention v2 | Memory-efficient exact attention with sliding window, soft capping |
| Ring Attention | Distributed sequence parallelism for 100K+ token sequences |
| Page Attention | KV-cache optimized inference with continuous batching |
| Prefill Page Attention | Separate prefill phase handling |
| Ragged Page Attention v2/v3 | Variable-length paged attention with attention sinks |
| Ragged Decode Attention | Efficient variable-length decoding |
| Block Sparse Attention | Configurable local+global sparse patterns |
| GLA | Gated Linear Attention with dual gating mechanism |
| Lightning Attention | Layer-adaptive attention with EMA decay |
| Kernel Delta Attention | Delta-rule linear attention with decay control |
| MLA | Multi-head Latent Attention with compressed KV |
| Unified Attention | vLLM-style paged attention |
| Scaled Dot-Product Attention | Standard reference implementation |
State Space Models (2)
| Operation | Description |
|---|---|
| State Space v1 | Mamba1-style SSM with 2D A matrix |
| State Space v2 | Mamba2-style SSM with per-head scalar A, optional gated RMSNorm |
Other Operations (5)
| Operation | Description |
|---|---|
| Grouped MatMul | Efficient batched matrix ops for MoE |
| Grouped MatMul v2 | Enhanced with shard_map support |
| Mean Pooling | Variable-length sequence aggregation |
| Recurrent | Optimized RNN/LSTM/GRU operations |
| Native Sparse | Block-sparse matrix computations |
Platform Support
| Backend | Hardware | Key Operations |
|---|---|---|
| Triton | NVIDIA/AMD GPU | Flash, Ring, Page, GLA, Lightning, MLA |
| Pallas | Google TPU | Flash, Ring, Page, Ragged, Block Sparse |
| XLA | GPU/TPU/CPU | All operations (universal fallback) |
Requirements
- Python 3.11-3.13
- JAX >= 0.8.0
- Triton == 3.4.0 (GPU)
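Installation
```bash
pip install ejkernel
# GPU support
pip install "ejkernel[gpu]"
# TPU support
pip install "ejkernel[tpu]"
```
Quick Example
```python
import jax
import jax.numpy as jnp
from ejkernel.modules import flash_attention

# Dummy inputs; the [batch, seq_len, num_heads, head_dim] layout is assumed here.
shape = (2, 1024, 8, 64)
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query, key, value = (jax.random.normal(k, shape, jnp.bfloat16) for k in (kq, kk, kv))

output = flash_attention(
    query, key, value,
    causal=True,
    sliding_window=128,
    logits_soft_cap=30.0,
)
```
Documentation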
Full documentation: https://ejkernel.readthedocs.io/
Full Changelog: https://github.com/erfanzar/ejkernel/commits/v0.0.31