Releases: erfanzar/ejkernel

CUDA and CuTe Have Arrived

09 Feb 20:56

Highlights

  • Implemented the missing non-inference backward paths:
    • Triton rwkv6 backward (fixed-length + cu_seqlens varlen)
    • Triton rwkv7 backward (fixed-length + cu_seqlens varlen), with rwkv7_mul backprop flowing through RWKV7
    • CUDA blocksparse_attention backward via a CUDA-side dense analytical fallback (previously raised NotImplementedError)
    • XLA recurrent varlen (cu_seqlens) backward with packed-varlen/state normalization
  • Added a forward_autotune_only context manager in ejkernel.ops / ejkernel.ops.config that disables backward autotune validation while keeping forward autotuning active.
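Several of these backward paths take cu_seqlens for variable-length inputs. Under the convention such names usually denote (assumed here, not read from ejkernel's source), sequences are concatenated into one packed buffer and cu_seqlens holds cumulative offsets starting at 0, so sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1]. A minimal stdlib sketch; pack_sequences and unpack_sequences are hypothetical helpers:

```python
from itertools import accumulate


def pack_sequences(seqs):
    """Concatenate variable-length sequences and return (packed, cu_seqlens).

    cu_seqlens[i] is the start offset of sequence i in the packed buffer and
    cu_seqlens[-1] is the total packed length (assumed convention).
    """
    packed = [tok for seq in seqs for tok in seq]
    cu_seqlens = [0, *accumulate(len(seq) for seq in seqs)]
    return packed, cu_seqlens


def unpack_sequences(packed, cu_seqlens):
    """Split a packed buffer back into per-sequence lists using cu_seqlens."""
    return [packed[start:end] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]


seqs = [[1, 2, 3], [4], [5, 6]]
packed, cu_seqlens = pack_sequences(seqs)
print(packed, cu_seqlens)  # [1, 2, 3, 4, 5, 6] [0, 3, 4, 6]
assert unpack_sequences(packed, cu_seqlens) == seqs
```

A varlen recurrent kernel would typically reset its state at every cu_seqlens boundary instead of carrying it from one sequence into the next.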

Quantized Matmul / Backend Work

  • Added TPU Pallas hybrid QMM path:
    • Packed / predecode / XLA fallback dispatch
    • Shared forward + dX kernel family and custom VJP dX path
    • TPU legality gates and memory-aware heuristics
  • Expanded QMM backend normalization across Triton/CUDA/CuTe/XLA/Pallas.
  • CUDA QMM improvements:
    • cuBLASLt/CUTLASS GEMM backend options (EJKERNEL_QMM_CUDA_GEMM)
    • NF4 exact table lookup path
    • Expanded affine group-size coverage
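The affine group-size item refers to group-wise affine quantization, where every group_size consecutive weights share one scale and offset. A minimal stdlib sketch of the general technique, assuming the common min/scale variant (w is approximately code * scale + group_min) with unsigned codes; ejkernel's actual packing layout, zero-point convention, and rounding are not shown, and quantize_affine / dequantize_affine are hypothetical names:

```python
def quantize_affine(weights, group_size, bits=4):
    """Quantize a flat list of weights group by group.

    Each group keeps integer codes in [0, 2**bits - 1] plus a float scale and an
    offset (the group minimum), so that w is approximately code * scale + offset.
    """
    if len(weights) % group_size:
        raise ValueError("len(weights) must be a multiple of group_size")
    levels = 2**bits - 1
    codes, scales, offsets = [], [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / levels
        scales.append(scale)
        offsets.append(lo)
        for w in group:
            # A constant group (scale == 0) stores code 0 and is recovered exactly.
            q = 0 if scale == 0 else round((w - lo) / scale)
            codes.append(min(max(q, 0), levels))
    return codes, scales, offsets


def dequantize_affine(codes, scales, offsets, group_size):
    """Reconstruct approximate weights as code * scale + offset, per group."""
    return [q * scales[i // group_size] + offsets[i // group_size]
            for i, q in enumerate(codes)]


weights = [0.10, -0.40, 0.25, 0.90, -1.20, 0.05, 0.60, -0.30]
codes, scales, offsets = quantize_affine(weights, group_size=4)
approx = dequantize_affine(codes, scales, offsets, group_size=4)
# Each reconstructed weight is within half a quantization step of the original.
assert all(abs(w - a) <= scales[i // 4] / 2 + 1e-12
           for i, (w, a) in enumerate(zip(weights, approx)))
```

A quantized matmul then multiplies by these reconstructed weights; a fused kernel would typically reconstruct them tile by tile rather than materializing the full weight matrix.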

Runtime & Compatibility Fixes

  • Fixed Triton QMM tracer leak by removing global decode-table state.
  • Added JAX compatibility shim for pl.ANY movement.
  • Added single-device XLA fallback path for ring attention (axis_name=None).
  • Fixed recurrent varlen forward unpacking bug.
  • Flash attention FlashBias typing improvements.
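Ring attention processes one KV block per step (each block arriving from a neighbouring device) and merges partial results with a running maximum and normalizer. Because that merge is exact, a single-device path (axis_name=None) can fall back to ordinary attention and produce the same output. A minimal stdlib sketch for one query vector, assuming the standard online-softmax formulation rather than ejkernel's kernel; dense_attention and blockwise_attention are hypothetical names:

```python
import math


def dense_attention(q, keys, values):
    """Reference softmax(q . k) weighted sum of values for a single query."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def blockwise_attention(q, keys, values, block_size):
    """Visit KV in blocks, as successive ring steps would, using an online softmax."""
    dim = len(values[0])
    running_max, denom, acc = -math.inf, 0.0, [0.0] * dim
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial sums to the new maximum (exp(-inf) == 0.0 at step 0).
        correction = math.exp(running_max - new_max)
        denom *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            denom += w
            acc = [a + w * vd for a, vd in zip(acc, v)]
        running_max = new_max
    return [a / denom for a in acc]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.3, 0.3]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
dense = dense_attention(q, keys, values)
ring = blockwise_attention(q, keys, values, block_size=2)
assert all(abs(a - b) < 1e-12 for a, b in zip(dense, ring))
```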

Developer Experience / Quality

  • Refactored many kernel interfaces into split *_impl_fwd.py / *_impl_bwd.py.
  • Added/expanded docstrings across kernels/callib/ops.
  • Added structured unsupported-feature errors via EjkernelRuntimeError.
  • Minor style cleanup (consolidated multiline errors, sorted exports).

Test Coverage Added/Expanded

  • JIT grad/VJP + numerical sanity checks for:
    • RWKV6/RWKV7/RWKV7_MUL backward paths
    • CUDA blocksparse attention backward
    • XLA recurrent varlen backward
  • Expanded QMM routing/cache/memory-capacity/grad parity coverage across backends.

What's Changed

  • Add native CUDA kernels, quantized matmul, and kernel error handling by @erfanzar in #3

Full Changelog: v0.0.50...v0.0.55

ejkernel v0.0.50

01 Jan 16:56

Breaking Changes

  • Renamed attention_sink to softmax_aux across ragged_page_attention_v3 and unified_attention (modules + kernels + tests).
  • Renamed mla_attention to flash_mla in ejkernel.modules / ejkernel.modules.operations.
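softmax_aux (formerly attention_sink) holds extra "sink" logits. A common interpretation, assumed here rather than confirmed from ejkernel's kernels, is that these logits join the softmax denominator but carry no value vector, so some attention mass can drain into the sinks. A minimal stdlib sketch for one query row; softmax_with_sinks and attend_with_sinks are hypothetical names:

```python
import math


def softmax_with_sinks(scores, sink_logits):
    """Softmax over scores plus sink logits, returning only the weights on real keys.

    The sinks enlarge the normalizer, so the returned weights sum to less than 1.
    """
    m = max([*scores, *sink_logits])  # subtract the max for numerical stability
    exp_scores = [math.exp(s - m) for s in scores]
    denom = sum(exp_scores) + sum(math.exp(s - m) for s in sink_logits)
    return [e / denom for e in exp_scores]


def attend_with_sinks(scores, values, sink_logits):
    """Weighted sum of values; the sink entries contribute no value vector."""
    weights = softmax_with_sinks(scores, sink_logits)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]


w = softmax_with_sinks([1.0, 2.0, 3.0], sink_logits=[2.0])
assert sum(w) < 1.0  # part of the probability mass went to the sink
```

With no sinks the function reduces to an ordinary softmax.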

Added

  • New paged-attention ops with Triton + XLA backends:
    • chunked_prefill_paged_decode (updates block-tabled KV cache + runs unified paged attention).
    • decode_attention (paged decode attention returning output + LSE).
  • XLA backend support for flash_mla.
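Both new ops work on a block-tabled (paged) KV cache: each sequence owns a list of physical page indices, and logical token position t lives in page block_table[t // page_size] at slot t % page_size. A minimal stdlib sketch of that indexing, of appending tokens to the cache (the cache-update half of chunked_prefill_paged_decode), and of a single-query decode step that returns both the output and the log-sum-exp (LSE), as decode_attention is described to do. The layouts and helper names are assumptions, not ejkernel's API:

```python
import math


def write_kv(cache, block_table, position, kv, page_size):
    """Store one token's KV vector at its logical position via the block table."""
    page = block_table[position // page_size]
    cache[page][position % page_size] = kv


def read_kv(cache, block_table, seq_len, page_size):
    """Gather the first seq_len KV vectors of a sequence in logical order."""
    return [cache[block_table[t // page_size]][t % page_size] for t in range(seq_len)]


def decode_step(q, keys, values):
    """Single-query attention returning (output, lse), lse = log(sum(exp(scores)))."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    lse = m + math.log(total)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]
    return out, lse


page_size, num_pages = 2, 4
key_cache = [[None] * page_size for _ in range(num_pages)]
value_cache = [[None] * page_size for _ in range(num_pages)]
block_table = [3, 1]  # a sequence's pages need not be contiguous or ordered
tokens = [([1.0, 0.0], [1.0]), ([0.0, 1.0], [2.0]), ([1.0, 1.0], [3.0])]
for pos, (k, v) in enumerate(tokens):
    write_kv(key_cache, block_table, pos, k, page_size)
    write_kv(value_cache, block_table, pos, v, page_size)
keys = read_kv(key_cache, block_table, 3, page_size)
values = read_kv(value_cache, block_table, 3, page_size)
out, lse = decode_step([0.0, 0.0], keys, values)
```

Returning the LSE is what typically lets a caller merge partial attention results computed over different KV chunks, using the same rescaling as an online softmax.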

Changed

  • make_dummy_rpa_inputs now supports padded page tables via total_num_pages for smaller physical caches.
  • Attention-module call signatures now place cfg at the end (chunked_prefill_paged_decode, decode_attention, unified_attention).

Testing

  • Added/expanded Triton↔XLA equivalence tests for chunked_prefill_paged_decode and decode_attention, including shard_map coverage.

Full Changelog: v0.0.47...v0.0.50

Release 0.0.47

31 Dec 13:04

Release v0.0.47

  • Fixes false-positive "Signature mismatch" warnings for ragged_decode_attention during import-time kernel-registry validation.
  • Improves type-annotation normalization so equivalent JAX array spellings (Array / jax.Array / jaxlib._jax.Array) don’t trip signature checks when some backends use postponed annotations (from __future__ import annotations).
  • No API changes; this release cleans up diagnostics and makes cross-backend signature validation consistent.
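With from __future__ import annotations, inspect.signature reports annotations as strings, so one backend may expose the text "jax.Array" while another exposes the class object itself. A minimal stdlib sketch of normalizing both forms to one canonical token before comparing signatures; the alias set, annotation_key, and signatures_match are illustrative assumptions rather than ejkernel's registry code, and a stand-in Array class replaces the real JAX type so the sketch runs without JAX:

```python
import inspect

# Spellings treated as the same array type (illustrative alias set).
ARRAY_ALIASES = {"Array", "jax.Array", "jaxlib._jax.Array"}


def annotation_key(annotation):
    """Map an annotation (string or object) to a comparable canonical string."""
    if annotation is inspect.Parameter.empty:
        return None
    if isinstance(annotation, str):
        text = annotation.strip()
    elif isinstance(annotation, type):
        text = f"{annotation.__module__}.{annotation.__qualname__}"
    else:
        text = repr(annotation)
    return "Array" if text in ARRAY_ALIASES else text


def signatures_match(fn_a, fn_b):
    """Compare parameter names and normalized annotations of two callables."""
    def key(fn):
        params = inspect.signature(fn).parameters.values()
        return [(p.name, annotation_key(p.annotation)) for p in params]
    return key(fn_a) == key(fn_b)


class Array:
    """Stand-in for jax.Array so the sketch runs without JAX installed."""


Array.__module__ = "jaxlib._jax"


# Postponed (string) annotations, as a module using the __future__ import would report.
def triton_impl(query: "jax.Array", key: "jax.Array") -> "jax.Array": ...


# Eagerly evaluated annotations referring to the class object.
def xla_impl(query: Array, key: Array) -> Array: ...


assert signatures_match(triton_impl, xla_impl)
```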

What's Changed

  • Adjust typing a bit by @dlwh in #2

New Contributors

  • @dlwh made their first contribution in #2

Full Changelog: 0.0.45...v0.0.47

Release 0.0.45

29 Dec 21:51

Added

  • RWKV recurrence ops: RWKV-4, RWKV-6 (multi-head + variable-length packing + reverse), RWKV-7 DPLR, and RWKV-7 Mul (multiplicative reparameterization)
  • Triton + XLA kernel implementations for rwkv4, rwkv6, rwkv7, rwkv7_mul
  • Unit tests for RWKV kernels (output parity vs XLA + gradient shape checks)
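For reference, the RWKV-4 time-mixing ("WKV") operator is a decayed weighted average over past values, with an extra bonus u for the current token. A minimal single-channel stdlib sketch of the recurrent form described in the RWKV-4 paper, checked against the explicit sum it computes; w is the positive per-channel decay, and wkv_recurrent / wkv_direct are illustrative names, not ejkernel's rwkv4 signature. A production kernel would also track a running maximum exponent to avoid overflow, which this sketch omits:

```python
import math


def wkv_recurrent(w, u, ks, vs):
    """RWKV-4 WKV for one channel, using an O(1) recurrent state (a, b)."""
    decay = math.exp(-w)
    a, b = 0.0, 0.0  # decayed sums of exp(k_i) * v_i and exp(k_i) over past tokens
    out = []
    for k, v in zip(ks, vs):
        bonus = math.exp(u + k)  # the current token gets the extra weight u
        out.append((a + bonus * v) / (b + bonus))
        a = decay * a + math.exp(k) * v
        b = decay * b + math.exp(k)
    return out


def wkv_direct(w, u, ks, vs):
    """The same quantity written as an explicit sum over all previous tokens."""
    out = []
    for t in range(len(ks)):
        num = math.exp(u + ks[t]) * vs[t]
        den = math.exp(u + ks[t])
        for i in range(t):
            weight = math.exp(-(t - 1 - i) * w + ks[i])
            num += weight * vs[i]
            den += weight
        out.append(num / den)
    return out


ks = [0.1, -0.3, 0.5, 0.2]
vs = [1.0, -2.0, 0.5, 3.0]
assert all(abs(r - d) < 1e-12
           for r, d in zip(wkv_recurrent(0.7, 0.3, ks, vs), wkv_direct(0.7, 0.3, ks, vs)))
```

The recurrent form is what makes the operator linear in sequence length; the direct form is quadratic and serves only as a check.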

Changed

  • XLA block-sparse attention: delegates to dense attention when softmax_aux=None, improves KV-head replication for GQA/MQA, and tightens logits masking/soft-cap handling
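KV-head replication for GQA/MQA means several query heads read from one shared KV head: with num_q_heads query heads and num_kv_heads KV heads, query head h conventionally uses KV head h // (num_q_heads // num_kv_heads), and MQA is the case num_kv_heads == 1. A minimal stdlib sketch of that mapping and of expanding per-KV-head data to one entry per query head; the layout is the conventional one, assumed rather than read from ejkernel, and the helper names are hypothetical:

```python
def kv_head_for(q_head, num_q_heads, num_kv_heads):
    """Index of the KV head that query head q_head attends with."""
    if num_q_heads % num_kv_heads:
        raise ValueError("num_q_heads must be divisible by num_kv_heads")
    group = num_q_heads // num_kv_heads  # query heads sharing one KV head
    return q_head // group


def replicate_kv_heads(kv_per_head, num_q_heads):
    """Repeat each KV head's data so there is one entry per query head."""
    num_kv_heads = len(kv_per_head)
    return [kv_per_head[kv_head_for(h, num_q_heads, num_kv_heads)]
            for h in range(num_q_heads)]


print([kv_head_for(h, 8, 2) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

On arrays, the same expansion is jnp.repeat(kv, num_q_heads // num_kv_heads, axis=head_axis), which repeats each head consecutively.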

v0.0.40

26 Dec 16:04

Changes since v0.0.31.

Highlights

  • Triton is now optional: base installs no longer require Triton. Install GPU/Triton extras to enable GPU kernels.
  • TPU Flash Attention: sliding window + logits soft-cap support (forward/backward), plus cache layout fixes and stability improvements.
  • Ring Attention: Triton backend now uses a block-sparse inner kernel for correct global position/segment masking; q_position_ids/kv_position_ids added across TPU/Triton/XLA backends.
  • New/expanded kernels: prefill_page_attention for TPU Pallas + XLA; native_sparse_attention extended with per-query-block index/count layouts.
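Logit soft-capping squashes each attention score as cap * tanh(score / cap), keeping it inside (-cap, cap) while leaving small scores nearly unchanged. A causal sliding window restricts query position i to nearby earlier key positions, and the ring-attention change masks by global position and segment ids. A minimal stdlib sketch that takes explicit position and segment lists; the window convention used here (keys with 0 <= q_pos - k_pos < window, so the query itself counts toward the window) is an assumption because implementations differ, and all names are hypothetical:

```python
import math


def soft_cap(score, cap):
    """Smoothly bound a logit to (-cap, cap); near-linear for |score| << cap."""
    return cap * math.tanh(score / cap)


def allowed(q_pos, k_pos, q_seg, k_seg, window=None):
    """Causal, same-segment mask with an optional sliding window of `window` keys."""
    if q_seg != k_seg or k_pos > q_pos:
        return False
    return window is None or q_pos - k_pos < window


def masked_capped_scores(scores, q_positions, k_positions, q_segments, k_segments,
                         cap=None, window=None):
    """Soft-cap allowed scores and set disallowed entries to -inf before softmax."""
    out = []
    for i, row in enumerate(scores):
        new_row = []
        for j, s in enumerate(row):
            if not allowed(q_positions[i], k_positions[j], q_segments[i], k_segments[j], window):
                new_row.append(-math.inf)
            else:
                new_row.append(soft_cap(s, cap) if cap else s)
        out.append(new_row)
    return out
```

Passing global positions (rather than local row indices) is what keeps the mask correct when each device holds only a shard of the sequence.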

Fixes & Improvements

  • Numerical stability improvements across attention kernels (float32 accumulators, stable soft-cap/sink handling).
  • Ragged decode correctness fixes (in-bounds masking, fully-masked blocks, KV length < block size handling).
  • Triton compatibility updates (kernel API changes, wrapper unwrapping order support, min dot block size enforcement).
  • Correctness fixes in grouped_matmul(v2), mean_pooling, and page_attention.
  • Full XLA cache disabled by default to prevent GPU kernel launch failures.

Upgrade Notes

  • Persistent-cache version bumped for invalidation; previous caches will be rebuilt.
  • softmax_aux handling is unified to Float[Array, "num_sinks"] across attention kernels.
  • Tests were reorganized; module operation tests now live under test/modules/operations.

Install

pip install -U ejkernel

# GPU/Triton kernels
pip install -U "ejkernel[gpu]"
# or
pip install -U "ejkernel[triton]"

Full Changelog: v0.0.31...v0.0.40

v0.0.31

21 Dec 18:56

v0.0.31 - Initial Release

ejKernel: High-Performance JAX Kernels for Deep Learning

We're excited to announce the first public release of ejKernel — a production-grade kernel library for JAX with automatic multi-backend support, sophisticated autotuning, and seamless execution across GPUs, TPUs, and CPUs.

Highlights

  • Multi-Backend Support: Automatic platform detection with Triton (GPU), Pallas (TPU), and XLA (universal) backends
  • 7-Tier Configuration System: Intelligent kernel selection with persistent caching and autotuning
  • Memory Efficient: Custom VJP implementations that reduce attention-gradient memory from O(N²) to O(N) in sequence length N
  • Type Safe: Full jaxtyping annotations with runtime validation via beartype
  • Distributed Ready: Full shard_map integration for model and data parallelism

Attention Mechanisms (14)

| Operation | Description |
|---|---|
| Flash Attention v2 | Memory-efficient exact attention with sliding window and soft capping |
| Ring Attention | Distributed sequence parallelism for 100K+ token sequences |
| Page Attention | KV-cache optimized inference with continuous batching |
| Prefill Page Attention | Separate prefill phase handling |
| Ragged Page Attention v2/v3 | Variable-length paged attention with attention sinks |
| Ragged Decode Attention | Efficient variable-length decoding |
| Block Sparse Attention | Configurable local+global sparse patterns |
| GLA | Gated Linear Attention with dual gating mechanism |
| Lightning Attention | Layer-adaptive attention with EMA decay |
| Kernel Delta Attention | Delta-rule linear attention with decay control |
| MLA | Multi-head Latent Attention with compressed KV |
| Unified Attention | vLLM-style paged attention |
| Scaled Dot-Product Attention | Standard reference implementation |
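Delta-rule linear attention keeps a d_v x d_k state matrix S and, for each token, corrects the value currently stored under key k_t toward v_t instead of simply adding v_t k_tᵀ. A minimal stdlib sketch of one common gated form, S_t = alpha_t S_{t-1} (I - beta_t k_t k_tᵀ) + beta_t v_t k_tᵀ with output o_t = S_t q_t (the Gated DeltaNet formulation); whether the "decay control" in Kernel Delta Attention uses exactly this gating is an assumption, and the function names are hypothetical:

```python
def delta_rule_step(S, k, v, alpha, beta):
    """One gated delta-rule update of the d_v x d_k state S.

    S_new = alpha * S @ (I - beta * k k^T) + beta * v k^T, expanded entrywise as
    alpha*S[r][c] - alpha*beta*(S k)[r]*k[c] + beta*v[r]*k[c].
    """
    Sk = [sum(s_rc * k_c for s_rc, k_c in zip(row, k)) for row in S]
    return [[alpha * S[r][c] - alpha * beta * Sk[r] * k[c] + beta * v[r] * k[c]
             for c in range(len(k))]
            for r in range(len(S))]


def read_state(S, q):
    """Output o = S @ q."""
    return [sum(s_rc * q_c for s_rc, q_c in zip(row, q)) for row in S]


def delta_attention(qs, ks, vs, alphas, betas):
    """Run the recurrence over a sequence, emitting o_t = S_t @ q_t per token."""
    d_k, d_v = len(ks[0]), len(vs[0])
    S = [[0.0] * d_k for _ in range(d_v)]
    outs = []
    for q, k, v, alpha, beta in zip(qs, ks, vs, alphas, betas):
        S = delta_rule_step(S, k, v, alpha, beta)
        outs.append(read_state(S, q))
    return outs


# With alpha = beta = 1 and unit keys, a second write to the same key overwrites
# the stored value instead of adding to it.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
vals = [[3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
outs = delta_attention(keys, keys, vals, [1.0] * 3, [1.0] * 3)
```

Here outs[-1] is [7.0, 8.0], whereas plain additive linear attention (S += v kᵀ) would return [10.0, 12.0] at the last step because both writes to the first key would accumulate.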

State Space Models (2)

| Operation | Description |
|---|---|
| State Space v1 | Mamba1-style SSM with 2D A matrix |
| State Space v2 | Mamba2-style SSM with per-head scalar A, optional gated RMSNorm |
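In the Mamba2-style layer each head has a single scalar A, so every step decays the whole N-dimensional state by the same factor exp(dt_t * A). A minimal single-head, single-channel stdlib sketch of the recurrence h_t = exp(dt_t * A) h_{t-1} + dt_t * B_t x_t, y_t = C_t · h_t + D x_t, following the discretization commonly used in reference implementations; the optional gated RMSNorm is omitted, and ssm_scan_scalar_a is a hypothetical name:

```python
import math


def ssm_scan_scalar_a(xs, dts, Bs, Cs, A, D=0.0):
    """Sequential scan for one head/channel with scalar A (A < 0 gives decay).

    xs, dts: length-T lists of floats; Bs, Cs: length-T lists of N-dim vectors.
    """
    h = [0.0] * len(Bs[0])
    ys = []
    for x, dt, B, C in zip(xs, dts, Bs, Cs):
        decay = math.exp(dt * A)  # one scalar decay shared by all N state dims
        h = [decay * h_i + dt * b_i * x for h_i, b_i in zip(h, B)]
        ys.append(sum(c_i * h_i for c_i, h_i in zip(C, h)) + D * x)
    return ys
```

Restricting A to a scalar per head is what allows the scan to be rewritten in the chunked, matmul-friendly form Mamba2 kernels use; the sequential loop above is only the reference semantics.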

Other Operations (5)

| Operation | Description |
|---|---|
| Grouped MatMul | Efficient batched matrix ops for MoE |
| Grouped MatMul v2 | Enhanced with shard_map support |
| Mean Pooling | Variable-length sequence aggregation |
| Recurrent | Optimized RNN/LSTM/GRU operations |
| Native Sparse | Block-sparse matrix computations |
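Grouped matmul serves MoE layers: tokens are sorted by expert so that the rows of the left operand form contiguous groups, and group g is multiplied by its own weight matrix. A minimal stdlib sketch using a group_sizes list (rows per expert), following the common grouped-GEMM convention; it is not ejkernel's grouped_matmul signature, and the names are hypothetical:

```python
def matmul(a, b):
    """Plain (m x k) @ (k x n) on nested lists."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def grouped_matmul(lhs, rhs_per_group, group_sizes):
    """Multiply consecutive row groups of lhs by their group's weight matrix.

    lhs: (sum(group_sizes) x k) rows sorted by group; rhs_per_group[g]: (k x n).
    """
    if sum(group_sizes) != len(lhs):
        raise ValueError("group_sizes must sum to the number of lhs rows")
    out, start = [], 0
    for g, size in enumerate(group_sizes):
        out.extend(matmul(lhs[start:start + size], rhs_per_group[g]))
        start += size
    return out


lhs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 tokens, hidden size 2
experts = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]]  # identity and 2 * identity
out = grouped_matmul(lhs, experts, group_sizes=[1, 2])
# out == [[1.0, 2.0], [6.0, 8.0], [10.0, 12.0]]
```

Experts that receive no tokens simply get a group size of 0, which is why a size list (rather than a fixed rows-per-expert count) is the usual interface.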

Platform Support

| Backend | Hardware | Key Operations |
|---|---|---|
| Triton | NVIDIA/AMD GPU | Flash, Ring, Page, GLA, Lightning, MLA |
| Pallas | Google TPU | Flash, Ring, Page, Ragged, Block Sparse |
| XLA | GPU/TPU/CPU | All operations (universal fallback) |

Requirements

  • Python 3.11-3.13
  • JAX >= 0.8.0
  • Triton == 3.4.0 (GPU)

Installation

pip install ejkernel

# GPU support
pip install "ejkernel[gpu]"

# TPU support
pip install "ejkernel[tpu]"

Quick Example

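import jax
from ejkernel.modules import flash_attention

# Illustrative random inputs; the (batch, seq_len, num_heads, head_dim) layout is an assumption.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (2, 1024, 8, 64)
query = jax.random.normal(kq, shape)
key = jax.random.normal(kk, shape)
value = jax.random.normal(kv, shape)

output = flash_attention(
    query, key, value,
    causal=True,
    sliding_window=128,
    logits_soft_cap=30.0,
)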

Documentation

Full documentation: https://ejkernel.readthedocs.io/

Full Changelog: https://github.com/erfanzar/ejkernel/commits/v0.0.31