Skip to content

Commit 870e2fb

Browse files
committed
Add RWKV-6 and RWKV-7 recurrence operation modules with tests
- Implement RWKV-6 recurrence kernel with multi-head support and variable-length sequence packing. - Implement RWKV-7 Diagonal + Low-Rank (DPLR) recurrence kernel and its multiplicative parameterization. - Add unit tests for RWKV-6 and RWKV-7 to ensure output consistency with XLA implementations. - Include gradient shape tests for RWKV-6 and RWKV-7 to validate backward compatibility. - Ensure compatibility with Triton and XLA backends for both RWKV-6 and RWKV-7 operations.
1 parent 39f1f88 commit 870e2fb

30 files changed

Lines changed: 2989 additions & 8 deletions

README.md

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ ejKernel is a production-grade kernel library for JAX that provides highly optim
3434

3535
### State-of-the-Art Operations
3636

37-
- **20+ Deep Learning Operations**: Flash Attention v2, Ring Attention, Page Attention, Block Sparse, GLA, Lightning, State Space Models (Mamba), and more
37+
- **25+ Deep Learning Operations**: Flash Attention v2, Ring Attention, Page Attention, Block Sparse, GLA, Lightning, State Space Models (Mamba), RWKV (v4/v6/v7), and more
3838
- **Memory Efficiency**: Custom VJP implementations with O(N) memory complexity for attention
3939
- **Distributed Support**: Full shard_map integration for model and data parallelism
4040
- **Mixed Precision**: Comprehensive dtype support with automatic gradient conversion
@@ -266,6 +266,15 @@ kernel_with_custom_grad.defvjp(kernel_fwd, kernel_bwd)
266266
| **Prefill Page Attention** | Page attention prefill phase | O(N) | Separate prefill handling |
267267
| **Scaled Dot-Product Attention** | Standard attention | O(N²) | Basic reference implementation |
268268

269+
### Recurrent Linear Attention (RWKV)
270+
271+
| Operation | Description | Key Features |
272+
| ------------- | ------------------------------------ | -------------------------------------------------- |
273+
| **RWKV-4** | Time-mix recurrence | Numerically stable (α,β,ε) state, O(N) memory |
274+
| **RWKV-6** | Multi-head linear attention | Variable-length packing, reverse mode, O(N) memory |
275+
| **RWKV-7** | DPLR (Diagonal + Low-Rank) recurrence| (a,b) parameterization, state-space inspired |
276+
| **RWKV-7 Mul**| Multiplicative RWKV-7 variant | (kk,a) reparameterization for optimized kernels |
277+
269278
### Other Operations
270279

271280
| Operation | Description | Use Case |
@@ -307,6 +316,10 @@ kernel_with_custom_grad.defvjp(kernel_fwd, kernel_bwd)
307316
| Prefill Page Attention | - |||
308317
| State Space v1 | - | - ||
309318
| State Space v2 | - | - ||
319+
| RWKV-4 || - ||
320+
| RWKV-6 || - ||
321+
| RWKV-7 || - ||
322+
| RWKV-7 Mul || - ||
310323

311324
✅ = Production ready | 🚧 = Under development | - = Not available
312325

ejkernel/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
import os as _os
4242

4343
_os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
44-
__version__ = "0.0.40"
44+
__version__ = "0.0.45"
4545

4646
from . import kernels, modules, types, utils, xla_utils
4747
from .kernels import Backend, Platform, kernel_registry

ejkernel/kernels/_triton/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@
4545
from .ragged_page_attention_v3 import ragged_page_attention_v3
4646
from .recurrent import recurrent
4747
from .ring_attention import ring_attention
48+
from .rwkv4 import rwkv4
49+
from .rwkv6 import rwkv6
50+
from .rwkv7 import rwkv7, rwkv7_mul
4851
from .unified_attention import unified_attention
4952

5053
__all__ = (
@@ -61,5 +64,9 @@
6164
"recurrent",
6265
"recurrent_gla",
6366
"ring_attention",
67+
"rwkv4",
68+
"rwkv6",
69+
"rwkv7",
70+
"rwkv7_mul",
6471
"unified_attention",
6572
)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Triton backend for RWKV-4 time-mix recurrence."""
16+
17+
from ._interface import rwkv4
18+
19+
__all__ = [
20+
"rwkv4",
21+
]
22+
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""RWKV-4 recurrent time-mix kernel (Triton)."""
16+
17+
from __future__ import annotations
18+
19+
from functools import partial
20+
21+
import jax
22+
import jax.numpy as jnp
23+
import jaxtyping
24+
from beartype import beartype
25+
from jaxtyping import Array, Float
26+
27+
from ..._registry import Backend, Platform, kernel_registry
28+
from ..._xla.rwkv4 import rwkv4 as xla_rwkv4
29+
from ._triton_impl_fwd import fwd_triton_impl
30+
31+
32+
def _default_state(k):
    """Build the empty RWKV-4 state ``[B, 3, C]`` used when no state is supplied.

    The three slots along axis 1 are (alpha, beta, eps): alpha and beta start
    at zero and eps starts at ``-1e30``. Batch and channel sizes are taken
    from ``k`` (``[B, T, C]``); the result is float32.
    """
    bsz, _, chans = k.shape
    alpha0 = jnp.zeros((bsz, chans), dtype=jnp.float32)
    beta0 = jnp.zeros((bsz, chans), dtype=jnp.float32)
    eps0 = jnp.full((bsz, chans), -1e30, dtype=jnp.float32)
    return jnp.stack([alpha0, beta0, eps0], axis=1)


def _fwd_call(
    w: Float[Array, "chans"],
    u: Float[Array, "chans"],
    k: Float[Array, "batch seq_len chans"],
    v: Float[Array, "batch seq_len chans"],
    state: Float[Array, "batch three chans"] | None,
):
    """Forward rule of the custom VJP: run the Triton kernel and save residuals.

    Args:
        w: Raw time-decay parameter `[C]`; ``-exp(w)`` is what the kernel receives.
        u: Time-mix bias `[C]`.
        k: Key tensor `[B, T, C]`.
        v: Value tensor `[B, T, C]`.
        state: Optional initial state `[B, 3, C]` (alpha, beta, eps).

    Returns:
        ``((output, final_state), residual)`` where ``residual`` is
        ``(w, u, k, v, state)`` with ``state`` exactly as the caller passed it
        (possibly ``None``).
    """
    state_in = _default_state(k) if state is None else state

    w_neg = -jnp.exp(w.astype(jnp.float32))
    o, final_state = fwd_triton_impl(w_neg, u.astype(jnp.float32), k, v, state_in.astype(jnp.float32))
    # Keep the caller's original `state` (possibly None) in the residuals
    # instead of a Python bool flag: a bool is a pytree leaf and may come back
    # as a traced array under `jax.jit`, where branching on it fails, whereas
    # `None` is an empty pytree so `state is None` stays a static check.
    residual = (w, u, k, v, state)
    return (o, final_state), residual


def _bwd_call(
    residual,
    grads,
):
    """Backward rule of the custom VJP.

    Gradients are obtained by differentiating the XLA reference implementation
    (`xla_rwkv4`) at the saved inputs rather than through a dedicated Triton
    backward kernel.

    Args:
        residual: ``(w, u, k, v, state)`` as saved by `_fwd_call`.
        grads: ``(do, dstate)`` cotangents for the output and the final state.

    Returns:
        ``(dw, du, dk, dv, dstate_in)``; ``dstate_in`` is ``None`` when the
        forward call received ``state=None``.
    """
    (w, u, k, v, state) = residual
    do, dstate = grads

    state_was_none = state is None
    # The reference implementation is always called with a concrete state, so
    # rebuild the same default the forward pass used.
    state_in = _default_state(k) if state_was_none else state

    def f(w_, u_, k_, v_, state_):
        return xla_rwkv4(w_, u_, k_, v_, state_)

    _, vjp = jax.vjp(f, w, u, k, v, state_in)
    dw, du, dk, dv, dstate_in = vjp((do, dstate))
    if state_was_none:
        # The primal `state` argument was None, so its cotangent must be None too.
        dstate_in = None
    return dw, du, dk, dv, dstate_in
69+
70+
71+
@partial(jax.custom_vjp)
def _rwkv4(
    w: Float[Array, "chans"],
    u: Float[Array, "chans"],
    k: Float[Array, "batch seq_len chans"],
    v: Float[Array, "batch seq_len chans"],
    state: Float[Array, "batch three chans"] | None = None,
) -> tuple[Float[Array, "batch seq_len chans"], Float[Array, "batch three chans"]]:
    """Differentiable RWKV-4 recurrence backed by the Triton forward kernel.

    When ``state`` is ``None`` an empty float32 state is created from the batch
    and channel sizes of ``k``: alpha and beta are zero and eps is ``-1e30``.
    ``w`` is mapped to ``-exp(w)`` and both ``w`` and ``u`` are cast to float32
    before the kernel is launched.

    Returns:
        Tuple of (output `[B, T, C]`, final_state `[B, 3, C]`).
    """
    if state is None:
        batch, _, channels = k.shape
        per_channel = (batch, channels)
        state = jnp.stack(
            [
                jnp.zeros(per_channel, dtype=jnp.float32),  # alpha
                jnp.zeros(per_channel, dtype=jnp.float32),  # beta
                jnp.full(per_channel, -1e30, dtype=jnp.float32),  # eps
            ],
            axis=1,
        )

    decay = -jnp.exp(w.astype(jnp.float32))
    bias = u.astype(jnp.float32)
    state_f32 = state.astype(jnp.float32)
    return fwd_triton_impl(decay, bias, k, v, state_f32)


# Attach the forward/backward rules defined in this module.
_rwkv4.defvjp(_fwd_call, _bwd_call)
91+
92+
93+
@kernel_registry.register("rwkv4", Platform.TRITON, Backend.GPU)
@jaxtyping.jaxtyped(typechecker=beartype)
def rwkv4(
    w: Float[Array, "chans"],
    u: Float[Array, "chans"],
    k: Float[Array, "batch seq_len chans"],
    v: Float[Array, "batch seq_len chans"],
    state: Float[Array, "batch three chans"] | None = None,
) -> tuple[Float[Array, "batch seq_len chans"], Float[Array, "batch three chans"]]:
    """Run the RWKV-4 time-mix recurrence with the Triton GPU kernel.

    This is the type-checked public entry point registered as ``"rwkv4"`` for
    the Triton platform; it forwards directly to the custom-VJP function
    `_rwkv4`.

    Args:
        w: Raw time-decay parameter `[C]`; the kernel is fed ``-exp(w)``.
        u: Time-mix bias `[C]`.
        k: Key tensor `[B, T, C]`.
        v: Value tensor `[B, T, C]`.
        state: Optional initial state `[B, 3, C]` holding (alpha, beta, eps)
            along axis 1. If omitted, alpha and beta start at zero and eps
            at ``-1e30``.

    Returns:
        Tuple of (output `[B, T, C]`, final_state `[B, 3, C]`).
    """
    output, final_state = _rwkv4(w, u, k, v, state)
    return output, final_state
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""RWKV-4 forward pass Triton kernel implementation.
16+
17+
This module provides the Triton GPU kernel for RWKV-4 time-mix recurrence.
18+
The kernel processes sequences in a numerically-stable manner using the
19+
(alpha, beta, eps) state formulation.
20+
"""
21+
22+
from __future__ import annotations
23+
24+
import jax
25+
import triton
26+
import triton.language as tl
27+
from jax import numpy as jnp
28+
from jaxtyping import Array, Float
29+
30+
from ejkernel.callib import cdiv, triton_call
31+
32+
33+
@triton.jit
def _rwkv4_fwd_kernel(
    w_ptr,  # [C], already -exp(w_raw)
    u_ptr,  # [C]
    k_ptr,  # [B, T, C]
    v_ptr,  # [B, T, C]
    state_ptr,  # [B, 3, C]
    o_ptr,  # [B, T, C]
    state_out_ptr,  # [B, 3, C]
    T: tl.constexpr,
    C: tl.constexpr,
    BLOCK_C: tl.constexpr,
):
    """RWKV-4 forward recurrence for one (batch, channel-block) tile.

    Program axis 0 selects the batch element and axis 1 selects a block of
    ``BLOCK_C`` channels. The program walks the ``T`` time steps sequentially,
    carrying the (alpha, beta, eps) state in float32, writes one output row
    per step and finally stores the updated state.
    """
    # Promote the batch index to 64-bit so that the batch base offsets
    # (b * T * C and b * 3 * C) cannot wrap around for large B * T * C.
    # The per-step term `t * C` is still formed in 32-bit.
    b = tl.program_id(0).to(tl.int64)
    c_blk = tl.program_id(1)

    cs = c_blk * BLOCK_C + tl.arange(0, BLOCK_C)
    cmask = cs < C  # guards the tail block when C is not a multiple of BLOCK_C

    w = tl.load(w_ptr + cs, mask=cmask, other=0.0).to(tl.float32)
    u = tl.load(u_ptr + cs, mask=cmask, other=0.0).to(tl.float32)

    # State layout is [B, 3, C]: slot 0 = alpha, slot 1 = beta, slot 2 = eps.
    base_state = (b * 3) * C
    alpha = tl.load(state_ptr + base_state + 0 * C + cs, mask=cmask, other=0.0).to(tl.float32)
    beta = tl.load(state_ptr + base_state + 1 * C + cs, mask=cmask, other=0.0).to(tl.float32)
    eps = tl.load(state_ptr + base_state + 2 * C + cs, mask=cmask, other=-1e30).to(tl.float32)

    base_seq = b * T * C
    for t in range(0, T):
        off = base_seq + t * C + cs
        kt = tl.load(k_ptr + off, mask=cmask, other=0.0).to(tl.float32)
        vt = tl.load(v_ptr + off, mask=cmask, other=0.0).to(tl.float32)

        # Output for this step: subtract the running maximum `tau` before
        # exponentiating so both exponentials are <= 1.
        ukt = u + kt
        tau = tl.maximum(ukt, eps)
        e1a = tl.exp(eps - tau)
        e2a = tl.exp(ukt - tau)
        wkv = (e1a * alpha + e2a * vt) / (e1a * beta + e2a)
        tl.store(o_ptr + off, wkv.to(o_ptr.dtype.element_ty), mask=cmask)

        # State update: apply the decay `w` to the old exponent, then rescale
        # alpha/beta to the new maximum exponent `eps_next`.
        w_eps = w + eps
        eps_next = tl.maximum(w_eps, kt)
        e1b = tl.exp(w_eps - eps_next)
        e2b = tl.exp(kt - eps_next)
        alpha = e1b * alpha + e2b * vt
        beta = e1b * beta + e2b
        eps = eps_next

    base_state_out = (b * 3) * C
    tl.store(state_out_ptr + base_state_out + 0 * C + cs, alpha.to(tl.float32), mask=cmask)
    tl.store(state_out_ptr + base_state_out + 1 * C + cs, beta.to(tl.float32), mask=cmask)
    tl.store(state_out_ptr + base_state_out + 2 * C + cs, eps.to(tl.float32), mask=cmask)
85+
86+
87+
def fwd_triton_impl(
    w: Float[Array, "chans"],
    u: Float[Array, "chans"],
    k: Float[Array, "batch seq_len chans"],
    v: Float[Array, "batch seq_len chans"],
    state: Float[Array, "batch three chans"],
) -> tuple[Float[Array, "batch seq_len chans"], Float[Array, "batch three chans"]]:
    """Launch the RWKV-4 forward Triton kernel.

    The launch grid has one program per (batch element, channel block). The
    channel block size is 128, 64 or 32, picked from the channel count.

    Args:
        w: Negative exponentiated time-decay `[C]` (already `-exp(w_raw)`).
        u: Time-mix bias `[C]`.
        k: Key tensor `[B, T, C]`.
        v: Value tensor `[B, T, C]`.
        state: Initial state `[B, 3, C]` (alpha, beta, eps).

    Returns:
        Tuple of (output `[B, T, C]` in the dtype of ``v``,
        final_state `[B, 3, C]` in float32).
    """
    batch, seq_len, chans = k.shape

    # Pick the widest channel tile that does not exceed the channel count,
    # with 32 as the floor.
    if chans >= 128:
        block_c = 128
    elif chans >= 64:
        block_c = 64
    else:
        block_c = 32

    result_specs = [
        jax.ShapeDtypeStruct(k.shape, v.dtype),
        jax.ShapeDtypeStruct((batch, 3, chans), jnp.float32),
    ]

    o, state_out = triton_call(
        w,
        u,
        k,
        v,
        state,
        kernel=_rwkv4_fwd_kernel,
        out_shape=result_specs,
        name="ejkernel::triton::rwkv4_fwd",
        grid=(batch, cdiv(chans, block_c)),
        T=seq_len,
        C=chans,
        BLOCK_C=block_c,
    )
    return o, state_out
128+
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright 2025 The EasyDeL/ejKernel Author @erfanzar (Erfan Zare Chavoshi).
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Triton backend for RWKV-6 recurrence."""
16+
17+
from ._interface import rwkv6
18+
19+
__all__ = [
20+
"rwkv6",
21+
]
22+

0 commit comments

Comments
 (0)